[
  {
    "path": "README.md",
    "content": "# Progressive Growing of GANs inference in PyTorch with CelebA training snapshot\n\n\n## Description\nThis is a [PyTorch](http://pytorch.org/) inference implementation of the original Theano/Lasagne code.\n\nI recreated the network as described in the paper by [Karras et al.](http://research.nvidia.com/publication/2017-10_Progressive-Growing-of) \nSince some layers were missing in PyTorch, these were implemented as well. \nThe network and the layers can be found in `model.py`.\n\nFor the demo, a [100-celeb-hq-1024x1024-ours snapshot](https://drive.google.com/drive/folders/0B4qLcYyJmiz0bWJ5bHdKT0d6UXc) was used, which was made publicly available by the authors.\nSince I couldn't find any model converter between Theano/Lasagne and PyTorch, I used a quick and dirty script to transfer the weights between the models (`transfer_weights.py`).\n\nThis repo does not provide the code for training the networks.\n\n### Simple inference\nTo run the demo, simply execute `predict.py`.\nYou can specify other weights with the `--weights` flag.\n\nExample image:\n\n![Example image](https://raw.githubusercontent.com/ptrblck/prog_gans_pytorch_inference/master/example_small.png)\n\n\n### Latent space interpolation\nTo try the latent space interpolation, use `latent_interp.py`.\nAll output images will be saved in `./interp`.\n\nYou can choose between the \"gaussian interpolation\" introduced in the original paper\nand the \"slerp interpolation\" introduced by Tom White in his paper [Sampling Generative Networks](https://arxiv.org/abs/1609.04468v3)\nusing the `--type` argument.\n\nUse `--filter` to change the gaussian filter size for the gaussian interpolation and `--interp` for the interpolation steps\nfor the slerp interpolation.\n\nThe following arguments are defined:\n\n  * `--weights` - path to pretrained PyTorch state dict\n  * `--output` - Directory for storing interpolated images\n  * `--batch_size` - batch size for `DataLoader`\n  * `--num_workers` - 
number of workers for `DataLoader`\n  * `--type` {gauss, slerp} - interpolation type\n  * `--nb_latents` - number of latent vectors to generate\n  * `--filter` - gaussian filter length for interpolating latent space (gauss interpolation)\n  * `--interp` - interpolation length between each latent vector (slerp interpolation)\n  * `--seed` - random seed for numpy and PyTorch\n  * `--cuda` - use GPU\n\nThe total number of generated frames depends on the interpolation technique used.\n\nFor gaussian interpolation, the number of generated frames equals `nb_latents`, while the slerp interpolation generates `nb_latents * interp` frames.\n\nExample interpolation:\n\n![Example interpolation](https://raw.githubusercontent.com/ptrblck/prog_gans_pytorch_inference/master/example_interp.gif)\n\n### Live latent space interpolation\nA live demo of the latent space interpolation using PyGame can be seen in `pygame_interp_demo.py`.\n\nUse the `--size` argument to change the output window size.\n\nThe following arguments are defined:\n\n  * `--weights` - path to pretrained PyTorch state dict\n  * `--num_workers` - number of workers for `DataLoader`\n  * `--type` {gauss, slerp} - interpolation type\n  * `--nb_latents` - number of latent vectors to generate\n  * `--filter` - gaussian filter length for interpolating latent space (gauss interpolation)\n  * `--interp` - interpolation length between each latent vector (slerp interpolation)\n  * `--size` - PyGame window size\n  * `--seed` - random seed for numpy and PyTorch\n  * `--cuda` - use GPU\n\n### Transferring weights\nThe pretrained Lasagne weights can be transferred to a PyTorch state dict using `transfer_weights.py`.\n\nTo transfer other snapshots from the paper (other than CelebA), you have to modify the model architecture accordingly and use the corresponding weights.\n\n### Environment\nThe code was tested on Ubuntu 16.04 with an NVIDIA GTX 1080 using PyTorch v.0.2.0_4.\n\n  * `transfer_weights.py` needs Theano and Lasagne to 
load the pretrained weights.\n  * `pygame_interp_demo.py` needs PyGame to visualize the output.\n\nA single forward pass took approx. 0.031 seconds.\n\n\n## Links\n\n* [Original code (Theano/Lasagne implementation)](https://github.com/tkarras/progressive_growing_of_gans)\n\n* [Paper (research.nvidia.com)](http://research.nvidia.com/publication/2017-10_Progressive-Growing-of)\n\n\n## License\n\nThis code is a modified form of the original code under the [CC BY-NC](https://creativecommons.org/licenses/by-nc/4.0/legalcode) license with the following copyright notice:\n\n```\n# Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n```\n\nAccording to Section 3, I hereby identify [Tero Karras et al. and NVIDIA](https://github.com/tkarras) as the original authors of the material.\n\n\n"
  },
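  {
    "path": "examples/slerp_sketch.py",
    "content": "#!/usr/bin/env python2\n# -*- coding: utf-8 -*-\n# Illustrative sketch, NOT part of the original repo: spherical linear\n# interpolation ('slerp') between two latent vectors, as introduced in\n# Tom White's 'Sampling Generative Networks' (https://arxiv.org/abs/1609.04468)\n# and used by latent_interp.py. Pure NumPy; the Generator is not required.\nfrom __future__ import print_function\n\nimport numpy as np\n\n\ndef slerp(t, v0, v1):\n    # Interpolate along the great circle between v0 and v1 (t in [0, 1]).\n    omega = np.arccos(np.clip(\n        np.dot(v0 / np.linalg.norm(v0), v1 / np.linalg.norm(v1)), -1.0, 1.0))\n    if np.abs(np.sin(omega)) < 1e-7:  # (anti)parallel vectors: fall back to lerp\n        return (1.0 - t) * v0 + t * v1\n    return (np.sin((1.0 - t) * omega) * v0 + np.sin(t * omega) * v1) / np.sin(omega)\n\n\nif __name__ == '__main__':\n    rng = np.random.RandomState(187)\n    z0, z1 = rng.randn(512), rng.randn(512)\n    # 50 frames, matching the default --interp step count of latent_interp.py\n    frames = [slerp(t, z0, z1) for t in np.linspace(0.0, 1.0, 50)]\n    print('generated %d interpolation steps' % len(frames))\n"
  },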
  {
    "path": "latent_interp.py",
    "content": "#!/usr/bin/env python2\n# -*- coding: utf-8 -*-\n\"\"\"\nSample code for inference of Progressive Growing of GANs paper\n(https://github.com/tkarras/progressive_growing_of_gans)\nusing a CelebA snapshot\n\"\"\"\n\nfrom __future__ import print_function\nimport argparse\nimport os\n\nimport numpy as np\n\nimport torch\nfrom torch.autograd import Variable\nfrom torch.utils.data.dataloader import DataLoader\n\nfrom model import Generator\nfrom utils import LatentDataset, save_images\n\n\ninterp_types = ['gauss', 'slerp']\nuse_cuda = False\n\nparser = argparse.ArgumentParser(description='Interpolation demo')\nparser.add_argument(\n    '--weights',\n    default='100_celeb_hq_network-snapshot-010403.pth',\n    type=str,\n    metavar='PATH',\n    help='path to PyTorch state dict')\nparser.add_argument(\n    '--output',\n    type=str,\n    default='./interp',\n    help='Directory for storing interpolated images')\nparser.add_argument(\n    '--batch_size',\n    default=1,\n    type=int,\n    help='batch size')\nparser.add_argument(\n    '--num_workers',\n    default=1,\n    type=int,\n    help='number of workers for DataLoader')\nparser.add_argument(\n    '--type',\n    default='gauss',\n    choices=interp_types,\n    help='interpolation types: ' +\n         ' | '.join(interp_types) +\n         ' (default: gauss)')\nparser.add_argument(\n    '--nb_latents',\n    default=10,\n    type=int,\n    help='number of latent vectors to generate')\nparser.add_argument(\n    '--filter',\n    default=2,\n    type=int,\n    help='gauss filter length for latent vector smoothing (\'gauss\' interp)')\nparser.add_argument(\n    '--interp',\n    default=50,\n    type=int,\n    help='interpolation length between latents (\'slerp\' interp)')\nparser.add_argument(\n    '--seed',\n    default=187,\n    type=int,\n    help='Random seed')\nparser.add_argument(\n    '--cuda',\n    dest='cuda',\n    action='store_true',\n    help='Use GPU for processing')\n\n\ndef run(args):\n    
global use_cuda\n    \n    print('Loading Generator')\n    model = Generator()\n    model.load_state_dict(torch.load(args.weights))\n    \n    if use_cuda:\n        model = model.cuda()\n        pin_memory = True\n    else:\n        pin_memory = False\n    \n    # Generate latent data\n    latent_dataset = LatentDataset(interp_type=args.type,\n                                   nb_latents=args.nb_latents,\n                                   filter_latents=args.filter,\n                                   nb_interp=args.interp)\n    latent_loader = DataLoader(latent_dataset,\n                               batch_size=args.batch_size,\n                               num_workers=args.num_workers,\n                               shuffle=False,\n                               pin_memory=pin_memory)\n    \n    print('Processing')\n    for i, data in enumerate(latent_loader):\n        if use_cuda:\n            data = data.cuda()\n        data = Variable(data, volatile=True)\n\n        output = model(data)\n\n        if use_cuda:\n            output = output.cpu()\n    \n        images_np = output.data.numpy()\n    \n        save_images(images_np, args.output, i*args.batch_size)\n\n\ndef main():\n    global use_cuda\n    args = parser.parse_args()\n\n    if not args.weights:\n        print('No PyTorch state dict path provided. Exiting...')\n        return\n\n    if args.cuda:\n        use_cuda = True\n\n    np.random.seed(args.seed)\n    torch.manual_seed(args.seed)\n    if use_cuda:\n        torch.cuda.manual_seed(args.seed)\n    \n    if not os.path.exists(args.output):\n        os.mkdir(args.output)\n\n    run(args)\n\nif __name__ == '__main__':\n    main()"
  },
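  {
    "path": "examples/inference_sketch.py",
    "content": "#!/usr/bin/env python2\n# -*- coding: utf-8 -*-\n# Illustrative sketch, NOT part of the original repo: a minimal forward pass\n# through the Generator, mirroring what predict.py does. Assumes the converted\n# snapshot '100_celeb_hq_network-snapshot-010403.pth' is present (see\n# transfer_weights.py). Uses PyTorch 0.2-era idioms (Variable, volatile=True)\n# to match the rest of the repo.\nfrom __future__ import print_function\n\nimport torch\nfrom torch.autograd import Variable\n\nfrom model import Generator\n\nif __name__ == '__main__':\n    model = Generator()\n    model.load_state_dict(torch.load('100_celeb_hq_network-snapshot-010403.pth'))\n\n    # The first conv block expects a 512-dim latent as a 1x1 feature map (NCHW).\n    latent = torch.randn(1, 512, 1, 1)\n    output = model(Variable(latent, volatile=True))\n\n    # Eight upscale blocks take the initial 4x4 map to a 1024x1024 RGB image.\n    print(output.size())\n"
  },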
  {
    "path": "model.py",
    "content": "#!/usr/bin/env python2\n# -*- coding: utf-8 -*-\n\"\"\"\nThis work is based on the Theano/Lasagne implementation of\nProgressive Growing of GANs paper from tkarras:\nhttps://github.com/tkarras/progressive_growing_of_gans\n\nPyTorch Model definition\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom collections import OrderedDict\n\n\nclass PixelNormLayer(nn.Module):\n    def __init__(self):\n        super(PixelNormLayer, self).__init__()\n\n    def forward(self, x):\n        return x / torch.sqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)\n\n\nclass WScaleLayer(nn.Module):\n    def __init__(self, size):\n        super(WScaleLayer, self).__init__()\n        self.scale = nn.Parameter(torch.randn([1]))\n        self.b = nn.Parameter(torch.randn(size))\n        self.size = size\n\n    def forward(self, x):\n        x_size = x.size()\n        x = x * self.scale + self.b.view(1, -1, 1, 1).expand(\n            x_size[0], self.size, x_size[2], x_size[3])\n\n        return x\n\n\nclass NormConvBlock(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size, padding):\n        super(NormConvBlock, self).__init__()\n        self.norm = PixelNormLayer()\n        self.conv = nn.Conv2d(\n            in_channels, out_channels, kernel_size, 1, padding, bias=False)\n        self.wscale = WScaleLayer(out_channels)\n\n    def forward(self, x):\n        x = self.norm(x)\n        x = self.conv(x)\n        x = F.leaky_relu(self.wscale(x), negative_slope=0.2)\n        return x\n\n\nclass NormUpscaleConvBlock(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size, padding):\n        super(NormUpscaleConvBlock, self).__init__()\n        self.norm = PixelNormLayer()\n        self.up = nn.Upsample(scale_factor=2, mode='nearest')\n        self.conv = nn.Conv2d(\n            in_channels, out_channels, kernel_size, 1, padding, bias=False)\n        self.wscale = WScaleLayer(out_channels)\n\n    def 
forward(self, x):\n        x = self.norm(x)\n        x = self.up(x)\n        x = self.conv(x)\n        x = F.leaky_relu(self.wscale(x), negative_slope=0.2)\n        return x\n\n\nclass Generator(nn.Module):\n    def __init__(self):\n        super(Generator, self).__init__()\n\n        self.features = nn.Sequential(\n            NormConvBlock(512, 512, kernel_size=4, padding=3),\n            NormConvBlock(512, 512, kernel_size=3, padding=1),\n            NormUpscaleConvBlock(512, 512, kernel_size=3, padding=1),\n            NormConvBlock(512, 512, kernel_size=3, padding=1),\n            NormUpscaleConvBlock(512, 512, kernel_size=3, padding=1),\n            NormConvBlock(512, 512, kernel_size=3, padding=1),\n            NormUpscaleConvBlock(512, 512, kernel_size=3, padding=1),\n            NormConvBlock(512, 512, kernel_size=3, padding=1),\n            NormUpscaleConvBlock(512, 256, kernel_size=3, padding=1),\n            NormConvBlock(256, 256, kernel_size=3, padding=1),\n            NormUpscaleConvBlock(256, 128, kernel_size=3, padding=1),\n            NormConvBlock(128, 128, kernel_size=3, padding=1),\n            NormUpscaleConvBlock(128, 64, kernel_size=3, padding=1),\n            NormConvBlock(64, 64, kernel_size=3, padding=1),\n            NormUpscaleConvBlock(64, 32, kernel_size=3, padding=1),\n            NormConvBlock(32, 32, kernel_size=3, padding=1),\n            NormUpscaleConvBlock(32, 16, kernel_size=3, padding=1),\n            NormConvBlock(16, 16, kernel_size=3, padding=1))\n\n        self.output = nn.Sequential(OrderedDict([\n                        ('norm', PixelNormLayer()),\n                        ('conv', nn.Conv2d(16,\n                                           3,\n                                           kernel_size=1,\n                                           padding=0,\n                                           bias=False)),\n                        ('wscale', WScaleLayer(3))\n                    ]))\n\n    def forward(self, x):\n      
  x = self.features(x)\n        x = self.output(x)\n        return x\n"
  },
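  {
    "path": "examples/pixelnorm_sketch.py",
    "content": "#!/usr/bin/env python2\n# -*- coding: utf-8 -*-\n# Illustrative sketch, NOT part of the original repo: a numeric check of\n# PixelNormLayer from model.py. After the layer, the feature vector at every\n# spatial position should have (approximately) unit RMS norm.\nfrom __future__ import print_function\n\nimport torch\nfrom torch.autograd import Variable\n\nfrom model import PixelNormLayer\n\nif __name__ == '__main__':\n    x = Variable(torch.randn(2, 512, 4, 4))\n    y = PixelNormLayer()(x)\n    # RMS over the channel dimension; should be ~1.0 everywhere.\n    rms = torch.sqrt(torch.mean(y.data ** 2, dim=1, keepdim=True))\n    print('max deviation from unit RMS: %.6f' % (rms - 1.0).abs().max())\n"
  },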
  {
    "path": "network.py",
    "content": "# Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\nimport sys\nimport imp\nimport inspect\nimport copy\nimport collections\nimport numpy as np\nimport theano\nfrom theano import tensor as T\nimport lasagne\nimport cPickle\n\n# NOTE: Do not reference config.py here!\n# Instead, specify all network parameters as build function arguments.\n\n#----------------------------------------------------------------------------\n# Convenience.\n\nfrom lasagne.layers import InputLayer, Conv2DLayer, DenseLayer, NINLayer\nfrom lasagne.layers import Upscale2DLayer, Pool2DLayer, GlobalPoolLayer, MaxPool2DLayer\nfrom lasagne.layers import ReshapeLayer, ElemwiseSumLayer, ConcatLayer, FlattenLayer\nfrom lasagne.layers import NonlinearityLayer, ScaleLayer\n\nlinear,  ilinear  = lasagne.nonlinearities.linear,            lasagne.init.HeNormal(1.0)\nrelu,    irelu    = lasagne.nonlinearities.rectify,           lasagne.init.HeNormal('relu')\nlrelu,   ilrelu   = lasagne.nonlinearities.LeakyRectify(0.2), lasagne.init.HeNormal('relu')\nvlrelu            = lasagne.nonlinearities.LeakyRectify(0.3)\nelu,     ielu     = lasagne.nonlinearities.elu,               lasagne.init.HeNormal('relu')\ntanh,    itanh    = lasagne.nonlinearities.tanh,              lasagne.init.HeNormal(1.0)\nsigmoid, isigmoid = lasagne.nonlinearities.sigmoid,           lasagne.init.HeNormal(1.0)\nclip,    iclip    = lambda x: T.clip(x, 0, 1),                lasagne.init.HeNormal('relu')\n\ndef Tsum    (*args, **kwargs): return T.sum (*args, dtype=theano.config.floatX, acc_dtype=theano.config.floatX, **kwargs)\ndef Tmean   (*args, **kwargs): return T.mean(*args, dtype=theano.config.floatX, 
acc_dtype=theano.config.floatX, **kwargs)\ndef Tstd    (*args, **kwargs): return T.std (*args, **kwargs)\ndef Tstdeps (val, **kwargs):   return T.sqrt(Tmean(T.square(val - Tmean(val, **kwargs)), **kwargs) + 1.0e-8)\ndef Downscale2DLayer(incoming, scale_factor, **kwargs): return Pool2DLayer(incoming, pool_size=scale_factor, mode='average_exc_pad', **kwargs)\n\n#----------------------------------------------------------------------------\n# Wrapper class for Lasagne networks for robust pickling.\n\nclass Network(object):\n    def __init__(self, **build_func_spec):\n        self.build_func_spec    = build_func_spec       # dict(func='func_name', **kwargs)\n        self.build_module_src   = inspect.getsource(sys.modules[__name__]) # For pickle import.\n        self.input_layers       = [] # One or more.\n        self.output_layers      = [] # One or more.\n        self.input_shapes       = [] # Including minibatch dimension.\n        self.output_shapes      = [] # Including minibatch dimension.\n        self.input_shape        = () # For first input layer.\n        self.output_shape       = () # For first output layer.\n        #self.arbitrary_field   = ...# Arbitrary fields returned by the build func.\n        self.__dict__.update(self._call_build_func(globals()))\n        self._call_build_func_from_src() # Make sure that pickle import will work.\n\n    def eval(self, *inputs, **kwargs): # eval(input) => output --OR-- eval(primary_input, secondary_input, ...) 
=> primary_output, secondary_output, ...\n        ignore_unused_inputs = kwargs.pop('ignore_unused_inputs', False)\n        expect_num_outputs = kwargs.pop('expect_num_outputs', None)\n        assert len(inputs) >= len(self.input_layers)\n        assert len(inputs) == len(self.input_layers) or ignore_unused_inputs\n        input_dict = dict(zip(self.input_layers, inputs[:len(self.input_layers)]))\n        outputs = lasagne.layers.get_output(self.output_layers, input_dict, **kwargs)\n        if expect_num_outputs is not None:\n            outputs += [None] * max(expect_num_outputs - len(outputs), 0)\n        return outputs[0] if len(outputs) == 1 else tuple(outputs)\n\n    def eval_d(self, *inputs, **kwargs):\n        return self.eval(*inputs, deterministic=True, **kwargs)\n\n    def eval_nd(self, *inputs, **kwargs):\n        return self.eval(*inputs, deterministic=False, **kwargs)\n\n    def eval_multi(self, *inputs, **kwargs): # eval(input_batch1, input_batch2, ...) => output_batch1, output_batch2, ... 
--OR-- eval([list], [list]) => [list], [list]\n        input_lists = [input if isinstance(input, list) or isinstance(input, tuple) else [input] for input in inputs]\n        combo_inputs = [T.concatenate(spliced_input, axis=0) for spliced_input in zip(*input_lists)]\n        combo_outputs = self.eval(*combo_inputs, **kwargs)\n        combo_outputs = combo_outputs if isinstance(combo_outputs, tuple) else [combo_outputs]\n        output_ranges = [sum(input_lists[j][0].shape[0] for j in xrange(i)) for i in xrange(len(input_lists))]\n        output_ranges = [(begin, begin + input_list[0].shape[0]) for input_list, begin in zip(input_lists, output_ranges)]\n        spliced_outputs = [[combo_output[begin : end] for begin, end in output_ranges] for combo_output in combo_outputs]\n        output_lists = [outputs[0] if len(outputs) == 1 else outputs for outputs in zip(*spliced_outputs)]\n        return output_lists[0] if len(output_lists) == 1 else tuple(output_lists)\n\n    def find_layer(self, name):\n        for layer in lasagne.layers.get_all_layers(self.output_layers):\n            if layer.name == name:\n                return layer\n        return None\n\n    def trainable_params(self):\n        return lasagne.layers.get_all_params(self.output_layers, trainable=True)\n\n    def toplevel_params(self): # returns dict(name=shared)\n        return {name: value for name, value in self.__dict__.iteritems() if isinstance(value, theano.compile.SharedVariable)}\n\n    def get_toplevel_param_values(self): # returns dict(name=value)\n        return {name: shared.get_value() for name, shared in self.toplevel_params().iteritems()}\n\n    def set_toplevel_param_values(self, value_dict): # accepts dict(name=value)\n        for name, shared in self.toplevel_params().iteritems():\n            if name in value_dict:\n                shared.set_value(value_dict[name])\n\n    def create_temporally_smoothed_version(self, beta=0.99, explicit_updates=True):\n        # Create shallow copy of 
the network.\n        net = Network.__new__(Network)\n        net.__dict__.update(self.__dict__)\n        layer_map = {layer: copy.copy(layer) for layer in lasagne.layers.get_all_layers(net.output_layers)}\n        net.input_layers = [layer_map[layer] for layer in net.input_layers]\n        net.output_layers = [layer_map[layer] for layer in net.output_layers]\n        for layer in layer_map.itervalues():\n            if hasattr(layer, 'input_layer'): layer.input_layer = layer_map[layer.input_layer]\n            if hasattr(layer, 'input_layers'): layer.input_layers = [layer_map[input] for input in layer.input_layers]\n\n        # Override trainable parameters with their smoothed versions.\n        if explicit_updates: net.updates = []\n        for layer in layer_map.itervalues():\n            orig_params = layer.params\n            param_map = dict()\n            for name, orig in layer.__dict__.items():\n                try:\n                    if orig in orig_params and 'trainable' in orig_params[orig] and beta > 0.0:\n                        smoothed = theano.shared(orig.get_value())\n                        param_map[orig] = smoothed\n                        updated = beta * smoothed + (1.0 - beta) * orig\n                        if explicit_updates: # explicit_updates=True: You need to explicitly include net.updates in a Theano function to update the weights.\n                            layer.__dict__[name] = smoothed\n                            net.updates.append((smoothed, updated))\n                        else: # explicit_updates=False: Weights are updated automatically every time the net is evaluated.\n                            layer.__dict__[name + '_param'] = orig # for print_network_topology_info()\n                            layer.__dict__[name] = updated\n                            smoothed.default_update = updated\n                except TypeError: # if orig is not hashable\n                    pass\n            layer.params = 
collections.OrderedDict()\n            for param, tags in orig_params.iteritems():\n                layer.params[param_map.get(param, param)] = copy.copy(tags)\n        return net\n\n    def _call_build_func(self, module_globals):\n        func_params = dict(self.build_func_spec)\n        func_name = func_params['func']\n        del func_params['func']\n        if 'subfunc' in func_params:\n            func_params['subfunc'] = module_globals[func_params['subfunc']]     # str --> function\n        func_result = module_globals[func_name](**func_params)\n\n        # func_result can be one of the following:\n        #   output_layer\n        #   [first_output_layer, second_output_layer, ...]\n        #   dict(output_layers=<one-or-more>)\n        #   dict(input_layers=<one-or-more>, output_layers=<one-or-more>)\n        #   dict(input_layers=<one-or-more>, output_layers=<one-or-more>, arbitray_field=arbitrary_value, ...)\n\n        # Convert output layer list to canonical form.\n        r = dict(func_result) if isinstance(func_result, dict) else dict(output_layers=func_result)\n        assert 'output_layers' in r\n        if isinstance(r['output_layers'], lasagne.layers.Layer):\n            r['output_layers'] = [r['output_layers']]\n\n        # Convert input layer list to canonical form.\n        if 'input_layers' not in r:\n            r['input_layers'] = [l for l in lasagne.layers.get_all_layers(r['output_layers']) if isinstance(l, InputLayer)]\n        elif isinstance(r['input_layers'], lasagne.layers.Layer):\n            r['input_layers'] = [r['input_layers']]\n\n        # Check that input/output layers are specified correctly.\n        assert isinstance(r['input_layers'], list) and len(r['input_layers']) >= 1\n        assert isinstance(r['output_layers'], list) and len(r['output_layers']) >= 1\n        assert all(isinstance(layer, InputLayer) for layer in r['input_layers'])\n\n        # Fill in input/output shapes.\n        r['input_shapes'] = 
lasagne.layers.get_output_shape(r['input_layers'])\n        r['output_shapes'] = lasagne.layers.get_output_shape(r['output_layers'])\n        r['input_shape'] = r['input_shapes'][0]\n        r['output_shape'] = r['output_shapes'][0]\n        return r\n\n    def _call_build_func_from_src(self):\n        tmp_module = imp.new_module('network_tmp_module')\n        exec self.build_module_src in tmp_module.__dict__\n        globals()['tmp_modules'] = globals().get('tmp_modules', []) + [tmp_module] # Work around issues with GC.\n        return self._call_build_func(tmp_module.__dict__)\n\n    def __getstate__(self): # Pickle export.\n        return {\n            'build_func_spec':  self.build_func_spec,\n            'build_module_src': self.build_module_src,\n            'param_values':     lasagne.layers.get_all_param_values(self.output_layers),\n            'toplevel_params':  self.get_toplevel_param_values()}\n\n    def __setstate__(self, state): # Pickle import.\n        self.build_func_spec    = state['build_func_spec']\n        self.build_module_src   = state['build_module_src']\n        self.__dict__.update(self._call_build_func_from_src())\n        lasagne.layers.set_all_param_values(self.output_layers, state['param_values'])\n        self.set_toplevel_param_values(state.get('toplevel_params', dict()))\n\n#----------------------------------------------------------------------------\n# Mark all parameters in the last layer as non-trainable.\n\ndef non_trainable(net):\n    for tags in net.params.itervalues():\n        tags -= {'trainable', 'regularizable'}\n    return net\n\n#----------------------------------------------------------------------------\n# Resize activation tensor 'v' of shape 'si' to match shape 'so'.\n\ndef resize_activations(v, si, so):\n    assert len(si) == len(so) and si[0] == so[0]\n\n    # Decrease feature maps.\n    if si[1] > so[1]:\n        v = v[:, :so[1]]\n\n    # Shrink spatial axes.\n    if len(si) == 4 and (si[2] > so[2] or si[3] > 
so[3]):\n        assert si[2] % so[2] == 0 and si[3] % so[3] == 0\n        ws = (si[2] / so[2], si[3] / so[3])\n        v = T.signal.pool.pool_2d(v, ws=ws, stride=ws, ignore_border=True, pad=(0,0), mode='average_exc_pad')\n\n    # Extend spatial axes.\n    for i in xrange(2, len(si)):\n        if si[i] < so[i]:\n            assert so[i] % si[i] == 0\n            v = T.extra_ops.repeat(v, so[i] / si[i], i)\n\n    # Increase feature maps.\n    if si[1] < so[1]:\n        z = T.zeros((v.shape[0], so[1] - si[1]) + so[2:], dtype=v.dtype)\n        v = T.concatenate([v, z], axis=1)\n    return v\n\n#----------------------------------------------------------------------------\n# Resolution selector for fading in new layers during progressive growing.\n\nclass LODSelectLayer(lasagne.layers.MergeLayer):\n    def __init__(self, incomings, cur_lod, first_incoming_lod=0, ref_idx=0, **kwargs):\n        super(LODSelectLayer, self).__init__(incomings, **kwargs)\n        self.cur_lod = cur_lod\n        self.first_incoming_lod = first_incoming_lod\n        self.ref_idx = ref_idx\n\n    def get_output_shape_for(self, input_shapes):\n        return input_shapes[self.ref_idx]\n\n    def get_output_for(self, inputs, min_lod=None, max_lod=None, **kwargs):\n        v = [resize_activations(input, shape, self.input_shapes[self.ref_idx]) for input, shape in zip(inputs, self.input_shapes)]\n        lo = np.clip(int(np.floor(min_lod - self.first_incoming_lod)), 0, len(v)-1) if min_lod is not None else 0\n        hi = np.clip(int(np.ceil(max_lod - self.first_incoming_lod)), lo, len(v)-1) if max_lod is not None else len(v)-1\n        t = self.cur_lod - self.first_incoming_lod\n        r = v[hi]\n        for i in xrange(hi-1, lo-1, -1): # i = hi-1, hi-2, ..., lo\n            r = theano.ifelse.ifelse(T.lt(t, i+1), v[i] * ((i+1)-t) + v[i+1] * (t-i), r)\n        if lo < hi:\n            r = theano.ifelse.ifelse(T.le(t, lo), v[lo], r)\n        return 
r\n\n#----------------------------------------------------------------------------\n# Pixelwise feature vector normalization.\n\nclass PixelNormLayer(lasagne.layers.Layer):\n    def __init__(self, incoming, **kwargs):\n        super(PixelNormLayer, self).__init__(incoming, **kwargs)\n    def get_output_for(self, v, **kwargs):\n        return v / T.sqrt(Tmean(v**2, axis=1, keepdims=True) + 1.0e-8)\n\n#----------------------------------------------------------------------------\n# Applies equalized learning rate to the preceding layer.\n\nclass WScaleLayer(lasagne.layers.Layer):\n    def __init__(self, incoming, **kwargs):\n        super(WScaleLayer, self).__init__(incoming, **kwargs)\n        W = incoming.W.get_value()\n        scale = np.sqrt(np.mean(W ** 2))\n        incoming.W.set_value(W / scale)\n        self.scale = self.add_param(scale, (), name='scale', trainable=False)\n        self.b = None\n        if hasattr(incoming, 'b') and incoming.b is not None:\n            b = incoming.b.get_value()\n            self.b = self.add_param(b, b.shape, name='b', regularizable=False)\n            del incoming.params[incoming.b]\n            incoming.b = None\n        self.nonlinearity = lasagne.nonlinearities.linear\n        if hasattr(incoming, 'nonlinearity') and incoming.nonlinearity is not None:\n            self.nonlinearity = incoming.nonlinearity\n            incoming.nonlinearity = lasagne.nonlinearities.linear\n\n    def get_output_for(self, v, **kwargs):\n        v = v * self.scale\n        if self.b is not None:\n            pattern = ['x', 0] + ['x'] * (v.ndim - 2)\n            v = v + self.b.dimshuffle(*pattern)\n        return self.nonlinearity(v)\n\n#----------------------------------------------------------------------------\n# Minibatch stat concatenation layer. 
\n# - func is the function to use for the activations across minibatch\n# - averaging tells how much averaging to use ('all', 'spatial', 'none')\n\nclass MinibatchStatConcatLayer(lasagne.layers.Layer):\n    def __init__(self, incoming, func, averaging, **kwargs):\n        super(MinibatchStatConcatLayer, self).__init__(incoming, **kwargs)\n        self.func = func\n        self.averaging = averaging\n\n    def get_output_shape_for(self, input_shape):\n        s = list(input_shape)\n        if self.averaging == 'all': s[1] += 1\n        elif self.averaging == 'flat': s[1] += 1\n        elif self.averaging.startswith('group'): s[1] += int(self.averaging[len('group'):])\n        else: s[1] *= 2\n        return tuple(s)\n\n    def get_output_for(self, input, **kwargs):\n        s = list(input.shape)\n        vals = self.func(input,axis=0,keepdims=True)                # per activation, over minibatch dim\n        if self.averaging == 'all':                                 # average everything --> 1 value per minibatch\n            vals = Tmean(vals,keepdims=True)\n            reps = s; reps[1]=1\n            vals = T.tile(vals,reps)\n        elif self.averaging == 'spatial':                           # average spatial locations\n            if len(s) == 4:\n                vals = Tmean(vals,axis=(2,3),keepdims=True)\n            reps = s; reps[1]=1\n            vals = T.tile(vals,reps)\n        elif self.averaging == 'none':                              # no averaging, pass on all information\n            vals = T.repeat(vals,repeats=s[0],axis=0)\n        elif self.averaging == 'gpool':                             # EXPERIMENTAL: compute variance (func) over minibatch AND spatial locations.\n            if len(s) == 4:\n                vals = self.func(input,axis=(0,2,3),keepdims=True)\n            reps = s; reps[1]=1\n            vals = T.tile(vals,reps)\n        elif self.averaging == 'flat':\n            vals = self.func(input,keepdims=True)                   # 
variance of ALL activations --> 1 value per minibatch\n            reps = s; reps[1]=1\n            vals = T.tile(vals,reps)\n        elif self.averaging.startswith('group'):                    # average everything over n groups of feature maps --> n values per minibatch\n            n = int(self.averaging[len('group'):])\n            vals = vals.reshape((1, n, s[1]/n, s[2], s[3]))\n            vals = Tmean(vals, axis=(2,3,4), keepdims=True)\n            vals = vals.reshape((1, n, 1, 1))\n            reps = s; reps[1] = 1\n            vals = T.tile(vals, reps)\n        else:\n            raise ValueError('Invalid averaging mode', self.averaging)\n        return T.concatenate([input, vals], axis=1)\n\n#----------------------------------------------------------------------------\n# Generalized dropout layer. Supports arbitrary subsets of axes and different\n# modes. Mainly used to inject multiplicative Gaussian noise in the network.\n\nclass GDropLayer(lasagne.layers.Layer):\n    def __init__(self, incoming, mode='mul', strength=0.4, axes=(0,1), normalize=False, **kwargs):\n        super(GDropLayer, self).__init__(incoming, **kwargs)\n        assert mode in ('drop', 'mul', 'prop')\n        self.random     = theano.sandbox.rng_mrg.MRG_RandomStreams(lasagne.random.get_rng().randint(1, 2147462579))\n        self.mode       = mode\n        self.strength   = strength\n        self.axes       = [axes] if isinstance(axes, int) else list(axes)\n        self.normalize  = normalize # If true, retain overall signal variance.\n        self.gain       = None      # For experimentation.\n\n    def get_output_for(self, input, deterministic=False, **kwargs):\n        if self.gain is not None:\n            input = input * self.gain\n        if deterministic or not self.strength:\n            return input\n\n        in_shape  = self.input_shape\n        in_axes   = range(len(in_shape))\n        in_shape  = [in_shape[axis] if in_shape[axis] is not None else input.shape[axis] for axis 
in in_axes] # None => Theano expr\n        rnd_shape = [in_shape[axis] for axis in self.axes]\n        broadcast = [self.axes.index(axis) if axis in self.axes else 'x' for axis in in_axes]\n        one       = T.constant(1)\n\n        if self.mode == 'drop':\n            p = one - self.strength\n            rnd = self.random.binomial(tuple(rnd_shape), p=p, dtype=input.dtype) / p\n\n        elif self.mode == 'mul':\n            rnd = (one + self.strength) ** self.random.normal(tuple(rnd_shape), dtype=input.dtype)\n\n        elif self.mode == 'prop':\n            coef = self.strength * T.constant(np.sqrt(np.float32(self.input_shape[1])))\n            rnd = self.random.normal(tuple(rnd_shape), dtype=input.dtype) * coef + one\n\n        else:\n            raise ValueError('Invalid GDropLayer mode', self.mode)\n\n        if self.normalize:\n            rnd = rnd / T.sqrt(Tmean(rnd ** 2, axis=1, keepdims=True))\n        return input * rnd.dimshuffle(broadcast)\n\n#----------------------------------------------------------------------------\n# Layer normalization. 
Custom reimplementation based on the paper:\n# https://arxiv.org/abs/1607.06450\n\nclass LayerNormLayer(lasagne.layers.Layer):\n    def __init__(self, incoming, epsilon=1.0e-4, **kwargs):\n        super(LayerNormLayer, self).__init__(incoming, **kwargs)\n        self.epsilon = epsilon\n        self.gain = self.add_param(np.float32(1.0), (), name='gain', trainable=True)\n        self.b = None\n        if hasattr(incoming, 'b') and incoming.b is not None: # steal bias\n            b = incoming.b.get_value()\n            self.b = self.add_param(b, b.shape, name='b', regularizable=False)\n            del incoming.params[incoming.b]\n            incoming.b = None\n        self.nonlinearity = lasagne.nonlinearities.linear\n        if hasattr(incoming, 'nonlinearity') and incoming.nonlinearity is not None: # steal nonlinearity\n            self.nonlinearity = incoming.nonlinearity\n            incoming.nonlinearity = lasagne.nonlinearities.linear\n\n    def get_output_for(self, v, **kwargs):\n        avg_axes = range(1, len(self.input_shape))\n        v = v - Tmean(v, axis=avg_axes, keepdims=True) # subtract mean\n        v = v * T.inv(T.sqrt(Tmean(T.square(v), axis=avg_axes, keepdims=True) + self.epsilon)) # divide by stdev\n        v = v * self.gain # multiply by gain\n        if self.b is not None:\n            pattern = ['x', 0] + ['x'] * (v.ndim - 2)\n            v = v + self.b.dimshuffle(*pattern) # apply bias\n        return self.nonlinearity(v) # apply nonlinearity\n\n#----------------------------------------------------------------------------\n# Generator network template used in the paper.\n\ndef G_paper(\n    num_channels        = 1,        # Overridden based on dataset.\n    resolution          = 32,       # Overridden based on dataset.\n    label_size          = 0,        # Overridden based on dataset.\n    fmap_base           = 4096,\n    fmap_decay          = 1.0,\n    fmap_max            = 256,\n    latent_size         = None,\n    normalize_latents   = 
True,\n    use_wscale          = True,\n    use_pixelnorm       = True,\n    use_leakyrelu       = True,\n    use_batchnorm       = False,\n    tanh_at_end         = None,\n    **kwargs):\n\n    R = int(np.log2(resolution))\n    assert resolution == 2**R and resolution >= 4\n    cur_lod = theano.shared(np.float32(0.0))\n    def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)\n    def PN(layer): return PixelNormLayer(layer, name=layer.name+'pn') if use_pixelnorm else layer\n    def BN(layer): return lasagne.layers.batch_norm(layer) if use_batchnorm else layer\n    def WS(layer): return WScaleLayer(layer, name=layer.name+'S') if use_wscale else layer\n    if latent_size is None: latent_size = nf(0)\n    (act, iact) = (lrelu, ilrelu) if use_leakyrelu else (relu, irelu)\n\n    input_layers = [InputLayer(name='Glatents', shape=[None, latent_size])]\n    net = input_layers[-1]\n    if normalize_latents:\n        net = PixelNormLayer(net, name='Glnorm')\n    if label_size:\n        input_layers += [InputLayer(name='Glabels', shape=[None, label_size])]\n        net = ConcatLayer(name='Gina', incomings=[net, input_layers[-1]])\n\n    net = ReshapeLayer(name='Ginb', incoming=net, shape=[[0], [1], 1, 1])\n    net = PN(BN(WS(Conv2DLayer(net, name='G1a', num_filters=nf(1), filter_size=4, pad='full', nonlinearity=act, W=iact))))\n    net = PN(BN(WS(Conv2DLayer(net, name='G1b', num_filters=nf(1), filter_size=3, pad=1,      nonlinearity=act, W=iact))))\n    lods  = [net]\n\n    for I in xrange(2, R): # I = 2, 3, ..., R-1\n        net = Upscale2DLayer(net, name='G%dup' % I, scale_factor=2)\n        net = PN(BN(WS(Conv2DLayer(net, name='G%da'  % I, num_filters=nf(I), filter_size=3, pad=1, nonlinearity=act, W=iact))))\n        net = PN(BN(WS(Conv2DLayer(net, name='G%db'  % I, num_filters=nf(I), filter_size=3, pad=1, nonlinearity=act, W=iact))))\n        lods += [net]\n\n    lods = [WS(NINLayer(l, name='Glod%d' % i, num_units=num_channels, 
nonlinearity=linear, W=ilinear)) for i, l in enumerate(reversed(lods))]\n    output_layer = LODSelectLayer(name='Glod', incomings=lods, cur_lod=cur_lod, first_incoming_lod=0)\n    if tanh_at_end is not None:\n        output_layer = NonlinearityLayer(output_layer, name='Gtanh', nonlinearity=tanh)\n        if tanh_at_end != 1.0:\n            output_layer = non_trainable(ScaleLayer(output_layer, name='Gtanhs', scales=lasagne.init.Constant(tanh_at_end)))\n    return dict(input_layers=input_layers, output_layers=[output_layer], cur_lod=cur_lod)\n\n#----------------------------------------------------------------------------\n# Discriminator network template used in the paper.\n\ndef D_paper(\n    num_channels    = 1,        # Overridden based on dataset.\n    resolution      = 32,       # Overridden based on dataset.\n    label_size      = 0,        # Overridden based on dataset.\n    fmap_base       = 4096,\n    fmap_decay      = 1.0,\n    fmap_max        = 256,\n    mbstat_func     = 'Tstdeps',\n    mbstat_avg      = 'all',\n    mbdisc_kernels  = None,\n    use_wscale      = True,\n    use_gdrop       = True,\n    use_layernorm   = False,\n    **kwargs):\n\n    R = int(np.log2(resolution))\n    assert resolution == 2**R and resolution >= 4\n    cur_lod = theano.shared(np.float32(0.0))\n    gdrop_strength = theano.shared(np.float32(0.0))\n    def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)\n    def GD(layer): return GDropLayer(layer, name=layer.name+'gd', mode='prop', strength=gdrop_strength) if use_gdrop else layer\n    def LN(layer): return LayerNormLayer(layer, name=layer.name+'ln') if use_layernorm else layer\n    def WS(layer): return WScaleLayer(layer, name=layer.name+'ws') if use_wscale else layer\n\n    input_layer = InputLayer(name='Dimages', shape=[None, num_channels, 2**R, 2**R])\n    net = WS(NINLayer(input_layer, name='D%dx' % (R-1), num_units=nf(R-1), nonlinearity=lrelu, W=ilrelu))\n\n    for I in xrange(R-1, 1, -1): # 
I = R-1, R-2, ..., 2\n        net = LN(WS(Conv2DLayer     (GD(net),     name='D%db'   % I, num_filters=nf(I),   filter_size=3, pad=1, nonlinearity=lrelu, W=ilrelu)))\n        net = LN(WS(Conv2DLayer     (GD(net),     name='D%da'   % I, num_filters=nf(I-1), filter_size=3, pad=1, nonlinearity=lrelu, W=ilrelu)))\n        net =       Downscale2DLayer(net,         name='D%ddn'  % I, scale_factor=2)\n        lod =       Downscale2DLayer(input_layer, name='D%dxs'  % (I-1), scale_factor=2**(R-I))\n        lod =    WS(NINLayer        (lod,         name='D%dx'   % (I-1), num_units=nf(I-1), nonlinearity=lrelu, W=ilrelu))\n        net =       LODSelectLayer  (             name='D%dlod' % (I-1), incomings=[net, lod], cur_lod=cur_lod, first_incoming_lod=R-I-1)\n\n    if mbstat_avg is not None:\n        net = MinibatchStatConcatLayer(net, name='Dstat', func=globals()[mbstat_func], averaging=mbstat_avg)\n\n    net = LN(WS(Conv2DLayer(GD(net), name='D1b', num_filters=nf(1), filter_size=3, pad=1, nonlinearity=lrelu, W=ilrelu)))\n    net = LN(WS(Conv2DLayer(GD(net), name='D1a', num_filters=nf(0), filter_size=4, pad=0, nonlinearity=lrelu, W=ilrelu)))\n\n    if mbdisc_kernels:\n        import minibatch_discrimination\n        net = minibatch_discrimination.MinibatchLayer(net, name='Dmd', num_kernels=mbdisc_kernels)\n\n    output_layers = [WS(DenseLayer(net, name='Dscores', num_units=1, nonlinearity=linear, W=ilinear))]\n    if label_size:\n        output_layers += [WS(DenseLayer(net, name='Dlabels', num_units=label_size, nonlinearity=linear, W=ilinear))]\n    return dict(input_layers=[input_layer], output_layers=output_layers, cur_lod=cur_lod, gdrop_strength=gdrop_strength)\n\n#----------------------------------------------------------------------------\n# Crippled generator for MNIST mode recovery experiment.\n\ndef G_mnist_mode_recovery(\n    num_channels        = 1,\n    resolution          = 32,\n    fmap_base           = 64,\n    fmap_decay          = 1.0,\n    fmap_max            
= 256,\n    latent_size         = None,\n    label_size          = 10,\n    normalize_latents   = True,\n    use_wscale          = False,\n    use_pixelnorm       = False,\n    use_batchnorm       = True,\n    tanh_at_end         = True,\n    progressive         = False,\n    **kwargs):\n\n    R = int(np.log2(resolution))\n    assert resolution == 2**R and resolution >= 4\n    cur_lod = theano.shared(np.float32(0.0))\n    def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)\n    def PN(layer): return PixelNormLayer(layer, name=layer.name+'pn') if use_pixelnorm else layer\n    def BN(layer): return lasagne.layers.batch_norm(layer) if use_batchnorm else layer\n    def WS(layer): return WScaleLayer(layer, name=layer.name+'S') if use_wscale else layer\n    if latent_size is None: latent_size = nf(0)\n\n    input_layers = [InputLayer(name='Glatents', shape=[None, latent_size])]\n    net = input_layers[-1]\n    if normalize_latents:\n        net = PixelNormLayer(net, name='Glnorm')\n    if label_size:\n        input_layers += [InputLayer(name='Glabels', shape=[None, label_size])]\n        net = ConcatLayer (name='Gina', incomings=[net, input_layers[-1]])\n\n    net = ReshapeLayer(name='Ginb', incoming=net, shape=[[0], [1], 1, 1])\n    net = PN(BN(WS(Conv2DLayer(net, name='G1a', num_filters=64, filter_size=4, pad='full', nonlinearity=vlrelu, W=irelu))))\n\n    lods  = [net]\n    for I in xrange(2, R): # I = 2, 3, ..., R-1\n        net = Upscale2DLayer(net, name='G%dup' % I, scale_factor=2)\n        net = PN(BN(WS(Conv2DLayer(net, name='G%da'  % I, num_filters=nf(I-1), filter_size=3, pad=1, nonlinearity=vlrelu, W=irelu))))\n        lods += [net]\n\n    if progressive:\n        lods = [WS(Conv2DLayer(l, name='Glod%d' % i, num_filters=num_channels, filter_size=3, pad=1, nonlinearity=linear, W=ilinear)) for i, l in enumerate(reversed(lods))]        # Should be this\n        #lods = [WS(NINLayer(l, name='Glod%d' % i, num_units=num_channels, 
nonlinearity=linear, W=ilinear)) for i, l in enumerate(reversed(lods))]                                  # .. but this is better\n        output_layer = LODSelectLayer(name='Glod', incomings=lods, cur_lod=cur_lod, first_incoming_lod=0)\n    else:\n        net = WS(Conv2DLayer(net, name='toRGB', num_filters=num_channels, filter_size=3, pad=1, nonlinearity=linear, W=ilinear))                                                    # Should be this\n        #net = WS(NINLayer(net, name='toRGB', num_units=num_channels, nonlinearity=linear, W=ilinear))                                                                              # .. but this is better\n        output_layer = net\n\n    if tanh_at_end:\n        output_layer = NonlinearityLayer(output_layer, name='Gtanh', nonlinearity=tanh)\n\n    return dict(input_layers=input_layers, output_layers=[output_layer], cur_lod=cur_lod)\n\n#----------------------------------------------------------------------------\n# Crippled discriminator for MNIST mode recovery experiment.\n\ndef D_mnist_mode_recovery(\n    num_channels    = 1,\n    resolution      = 32,\n    fmap_base       = 64,\n    fmap_decay      = 1.0,\n    fmap_max        = 256,\n    mbstat_func     = 'Tstdeps',\n    mbstat_avg      = None,         #'all',\n    label_size      = 0,\n    use_wscale      = False,\n    use_gdrop       = False,\n    use_layernorm   = False,\n    use_batchnorm   = True,\n    X               = 2,\n    progressive     = False,\n    **kwargs):\n\n    R = int(np.log2(resolution))\n    assert resolution == 2**R and resolution >= 4\n    cur_lod = theano.shared(np.float32(0.0))\n    gdrop_strength = theano.shared(np.float32(0.0))\n    def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))) // X, fmap_max)\n    def GD(layer): return GDropLayer(layer, name=layer.name+'gd', mode='prop', strength=gdrop_strength) if use_gdrop else layer\n    def LN(layer): return LayerNormLayer(layer, name=layer.name+'ln') if use_layernorm else layer\n    
def WS(layer): return WScaleLayer(layer, name=layer.name+'ws') if use_wscale else layer\n    def BN(layer): return lasagne.layers.batch_norm(layer) if use_batchnorm else layer\n\n    net = input_layer = InputLayer(name='Dimages', shape=[None, num_channels, 2**R, 2**R])\n    for I in xrange(R-1, 1, -1): # I = R-1, R-2, ..., 2     (i.e. 4,3,2)\n        net = BN(LN(WS(Conv2DLayer     (GD(net),     name='D%da'   % I, num_filters=nf(I-1), filter_size=3, pad=1, nonlinearity=lrelu, W=ilrelu))))\n        net =       Downscale2DLayer(net,         name='D%ddn'  % I, scale_factor=2)\n        if progressive:\n            lod =       Downscale2DLayer(input_layer, name='D%dxs'  % (I-1), scale_factor=2**(R-I))\n            lod =    WS(NINLayer        (lod,         name='D%dx'   % (I-1), num_units=nf(I-1), nonlinearity=lrelu, W=ilrelu))\n            net =       LODSelectLayer  (             name='D%dlod' % (I-1), incomings=[net, lod], cur_lod=cur_lod, first_incoming_lod=R-I-1)\n\n    if mbstat_avg is not None:\n        net = MinibatchStatConcatLayer(net, name='Dstat', func=globals()[mbstat_func], averaging=mbstat_avg)\n\n    net = FlattenLayer(GD(net), name='Dflatten')\n    output_layers = [WS(DenseLayer(net, name='Dscores', num_units=1, nonlinearity=linear, W=ilinear))]\n\n    if label_size:\n        output_layers += [WS(DenseLayer(net, name='Dlabels', num_units=label_size, nonlinearity=linear, W=ilinear))]\n    return dict(input_layers=[input_layer], output_layers=output_layers, cur_lod=cur_lod, gdrop_strength=gdrop_strength)\n\n#----------------------------------------------------------------------------\n# Load a simple MNIST classifier.\n\ndef load_mnist_classifier(pkl_path):\n    nl = lasagne.nonlinearities.LeakyRectify(0.1)\n\n    net = InputLayer((None, 1, 32, 32))\n    net = Conv2DLayer(net, 32, (3, 3), pad='same', nonlinearity=nl)\n    net = Conv2DLayer(net, 32, (3, 3), pad='same', nonlinearity=nl)\n    net = MaxPool2DLayer(net, (2, 2))\n    net = Conv2DLayer(net, 55, 
(3, 3), pad='same', nonlinearity=nl)\n    net = Conv2DLayer(net, 55, (3, 3), pad='same', nonlinearity=nl)\n    net = MaxPool2DLayer(net, (2, 2))\n    net = Conv2DLayer(net, 96, (3, 3), pad=0, nonlinearity=nl)\n    net = Conv2DLayer(net, 96, (3, 3), pad=0, nonlinearity=nl)\n    net = MaxPool2DLayer(net, (2, 2))\n    net = DenseLayer(net, num_units=10, nonlinearity=lasagne.nonlinearities.softmax)\n\n    with open(pkl_path, 'rb') as file:\n        lasagne.layers.set_all_param_values(net, cPickle.load(file))\n    return net\n\n#----------------------------------------------------------------------------\n"
  },
  {
    "path": "predict.py",
    "content": "#!/usr/bin/env python2\n# -*- coding: utf-8 -*-\n\"\"\"\nSample code for inference of Progressive Growing of GANs paper\n(https://github.com/tkarras/progressive_growing_of_gans)\nusing a CelebA snapshot\n\"\"\"\n\nfrom __future__ import print_function\nimport argparse\n\nimport torch\nfrom torch.autograd import Variable\n\nfrom model import Generator\n\nfrom utils import scale_image\n\nimport matplotlib.pyplot as plt\n\n\nparser = argparse.ArgumentParser(description='Inference demo')\nparser.add_argument(\n    '--weights',\n    default='100_celeb_hq_network-snapshot-010403.pth',\n    type=str,\n    metavar='PATH',\n    help='path to PyTorch state dict')\nparser.add_argument('--cuda', dest='cuda', action='store_true')\n\nseed = 2809\nuse_cuda = False\n\ntorch.manual_seed(seed)\nif use_cuda:\n    torch.cuda.manual_seed(seed)\n\ndef run(args):\n    global use_cuda\n\n    print('Loading Generator')\n    model = Generator()\n    model.load_state_dict(torch.load(args.weights))\n\n    # Generate latent vector\n    x = torch.randn(1, 512, 1, 1)\n\n    if use_cuda:\n        model = model.cuda()\n        x = x.cuda()\n\n    x = Variable(x, volatile=True)\n\n    print('Executing forward pass')\n    images = model(x)\n\n    if use_cuda:\n        images = images.cpu()\n\n    images_np = images.data.numpy().transpose(0, 2, 3, 1)\n    image_np = scale_image(images_np[0, ...])\n\n    print('Output')\n    plt.figure()\n    plt.imshow(image_np)\n    plt.show()\n\n\ndef main():\n    global use_cuda\n    args = parser.parse_args()\n\n    if not args.weights:\n        print('No PyTorch state dict path provided. Exiting...')\n        return\n\n    if args.cuda:\n        use_cuda = True\n\n    run(args)\n\nif __name__ == '__main__':\n    main()"
  },
  {
    "path": "pygame_interp_demo.py",
    "content": "#!/usr/bin/env python2\n# -*- coding: utf-8 -*-\n\"\"\"\nSample code for inference of Progressive Growing of GANs paper\n(https://github.com/tkarras/progressive_growing_of_gans)\nusing a CelebA snapshot\n\"\"\"\n\nfrom __future__ import print_function\nimport argparse\n\nimport numpy as np\n\nimport torch\nfrom torch.autograd import Variable\nfrom torch.utils.data.dataloader import DataLoader\n\nfrom model import Generator\nfrom utils import LatentDataset, scale_image_paper\n\nimport pygame\n\ninterp_types = ['gauss', 'slerp']\nuse_cuda = False\n\nparser = argparse.ArgumentParser(description='Interpolation demo')\nparser.add_argument(\n    '--weights',\n    default='100_celeb_hq_network-snapshot-010403.pth',\n    type=str,\n    metavar='PATH',\n    help='path to PyTorch state dict')\nparser.add_argument(\n    '--num_workers',\n    default=1,\n    type=int,\n    help='number of workers for DataLoader')\nparser.add_argument(\n    '--type',\n    default='gauss',\n    choices=interp_types,\n    help='interpolation types: ' + ' | '.join(interp_types) +\n    ' (default: gauss)')\nparser.add_argument(\n    '--nb_latents',\n    default=10,\n    type=int,\n    help='number of latent vectors to generate')\nparser.add_argument(\n    '--filter',\n    default=2,\n    type=int,\n    help='gauss filter length for latent vector smoothing (\\'gauss\\' interp)')\nparser.add_argument(\n    '--interp',\n    default=50,\n    type=int,\n    help='interpolation length between latents (\\'slerp\\' interp)')\nparser.add_argument('--size', default=256, type=int, help='pygame window size')\nparser.add_argument('--seed', default=187, type=int, help='Random seed')\nparser.add_argument(\n    '--cuda', dest='cuda', action='store_true', help='Use GPU for processing')\n\n\ndef run(args):\n    global use_cuda\n\n    # Init PYGame\n    pygame.init()\n    display = pygame.display.set_mode((args.size, args.size), 0)\n\n    print('Loading Generator')\n    model = Generator()\n    
model.load_state_dict(torch.load(args.weights))\n\n    if use_cuda:\n        model = model.cuda()\n        pin_memory = True\n    else:\n        pin_memory = False\n\n    # Generate latent data\n    latent_dataset = LatentDataset(\n        interp_type=args.type,\n        nb_latents=args.nb_latents,\n        filter_latents=args.filter,\n        nb_interp=args.interp)\n    latent_loader = DataLoader(\n        latent_dataset,\n        batch_size=1,  # since we want to see it 'live'\n        num_workers=args.num_workers,\n        shuffle=False,\n        pin_memory=pin_memory)\n\n    print('Processing')\n    for i, data in enumerate(latent_loader):\n        if use_cuda:\n            data = data.cuda()\n        data = Variable(data, volatile=True)\n\n        output = model(data)\n\n        if use_cuda:\n            output = output.cpu()\n\n        image = output.data.numpy()[0, ...].transpose(1, 2, 0)\n        image = np.rot90(scale_image_paper(image, [-1, 1], [0, 255]))\n        snapshot = pygame.surfarray.make_surface(image)\n        snapshot = pygame.transform.scale(snapshot, (args.size, args.size))\n        display.blit(snapshot, (0, 0))\n        pygame.display.flip()\n\n\ndef main():\n    global use_cuda\n    args = parser.parse_args()\n\n    if not args.weights:\n        print('No PyTorch state dict path provided. Exiting...')\n        return\n\n    if args.cuda:\n        use_cuda = True\n\n    np.random.seed(args.seed)\n    torch.manual_seed(args.seed)\n    if use_cuda:\n        torch.cuda.manual_seed(args.seed)\n\n    run(args)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "transfer_weights.py",
    "content": "#!/usr/bin/env python2\n# -*- coding: utf-8 -*-\n\"\"\"\nThis work is based on the Theano/Lasagne implementation of\nProgressive Growing of GANs paper from tkarras:\nhttps://github.com/tkarras/progressive_growing_of_gans\n\nScript for weight transfer (lasagne - PyTorch)\n\"\"\"\n\nfrom __future__ import print_function\nimport argparse\n\nimport numpy as np\n\nimport os\nimport cPickle\n\nimport torch\n\nimport theano\nimport theano.tensor as T\nimport lasagne\n\nfrom model import Generator\n\nparser = argparse.ArgumentParser(description='Weight transfer script')\nparser.add_argument(\n    '--weights',\n    default='',\n    type=str,\n    metavar='PATH',\n    help='path to lasagne checkpoint (default: none)')\nparser.add_argument(\n    '--output',\n    type=str,\n    default='./output',\n    help='Directory for storing PyTorch weight output')\n\n\ndef init_model(model, conv_weights, wscale_weights, nin_weights,\n               nin_wscale_weights):\n    for feat_layer, conv_w, wscale_w in zip(model.features, conv_weights,\n                                            wscale_weights):\n        # Get Conv weights and flip them (lasagne default)\n        curr_conv_w = np.copy(conv_w.W.get_value()[:, :, ::-1, ::-1])\n        feat_layer.conv.weight.data = torch.FloatTensor(curr_conv_w)\n\n        # Get WScale weights\n        feat_layer.wscale.scale.data = torch.FloatTensor(\n            wscale_w.scale.get_value().reshape(1, ))\n        feat_layer.wscale.b.data = torch.FloatTensor(wscale_w.b.get_value())\n\n    # Last layer has to be handled differently, since a NIN layer was used in\n    # lasagne (basically 1x1 conv in PyTorch)\n    model.output.conv.weight.data = torch.FloatTensor(\n        nin_weights.W.get_value().T).unsqueeze_(2).unsqueeze_(3)\n    model.output.wscale.scale.data = torch.FloatTensor(\n        nin_wscale_weights.scale.get_value().reshape(1, ))\n    model.output.wscale.b.data = torch.FloatTensor(\n        
nin_wscale_weights.b.get_value())\n\n\ndef compare_results(model, G, use_cuda=False):\n    from torch.autograd import Variable\n\n    # Create random latent vector\n    example_latents = np.random.randn(1, 512).astype(np.float32)\n\n    # Create theano expressions\n    latents_var = T.TensorType(\n        'float32', [False] * len(example_latents.shape))('latents_var')\n    lod = 0.0\n    images_expr = G.eval(\n        latents_var, min_lod=lod, max_lod=lod, ignore_unused_inputs=True)\n    gen_fn = theano.function(\n        [latents_var], images_expr, on_unused_input='ignore')\n\n    # Generate reference image\n    images_ref = gen_fn(example_latents[:1])\n\n    # Use same latent vector for our model (we need [1, 512, 1, 1])\n    x = torch.from_numpy(example_latents[:, :, np.newaxis, np.newaxis])\n\n    if use_cuda:\n        x = x.cuda()\n        model = model.cuda()\n\n    x = Variable(x, volatile=True)\n    images = model(x)\n\n    if use_cuda:\n        images = images.cpu()\n\n    images = images.data.numpy()\n    print('Sum of abs error: {}'.format(np.sum(np.abs(images_ref - images))))\n\n\ndef run(args):\n    # Get lasagne weights\n    lasagne_weights_path = args.weights\n\n    print('Loading lasagne weights')\n    with open(lasagne_weights_path, \"rb\") as f:\n        _, _, G = cPickle.load(f)\n\n    # Set output layer\n    lasagne_output_layer = G.find_layer('Glod0S')\n\n    # Get all layers up to output layer\n    lasagne_layers = lasagne.layers.get_all_layers(lasagne_output_layer)\n\n    # Get weights for each layer type\n    conv_weights = [l for l in lasagne_layers if 'Conv' in str(l)]\n\n    # Skip last wscale layer weights, since these belong to the NIN layer\n    wscale_weights = [l for l in lasagne_layers if 'WScale' in str(l)][:-1]\n\n    # Get NIN weights (these should be the two last layers)\n    nin_weights = lasagne_layers[-2]\n    nin_wscale_weights = lasagne_layers[-1]  # get last wscale layer weight\n\n    print('Initializing PyTorch model')\n  
  model = Generator()\n    init_model(model, conv_weights, wscale_weights, nin_weights,\n               nin_wscale_weights)\n\n    if args.output:\n        _, model_name = os.path.split(args.weights)\n        model_name = model_name.replace('.pkl', '.pth')\n        output_path = os.path.join(args.output, model_name)\n        print('Saving model to {}'.format(output_path))\n        torch.save(model.state_dict(), output_path)\n\n\ndef main():\n    args = parser.parse_args()\n\n    if not args.weights:\n        print('No lasagne checkpoint defined. Exiting...')\n        return\n\n    if not os.path.exists(args.output):\n        os.mkdir(args.output)\n\n    run(args)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "utils.py",
    "content": "# -*- coding: utf-8 -*-\n\"\"\"\nThis work is based on the Theano/Lasagne implementation of\nProgressive Growing of GANs paper from tkarras:\nhttps://github.com/tkarras/progressive_growing_of_gans\n\nUtils\n\"\"\"\n\nimport numpy as np\n\nfrom scipy import ndimage\nfrom scipy.misc import imsave\n\nimport os\n\nimport torch\nfrom torch.utils.data import Dataset\n\n\ndef scale_image(image):\n    image -= image.min()\n    image /= image.max()\n    image *= 255\n    return image.astype(np.uint8)\n\n\ndef scale_image_paper(image, drange_in, drange_out):\n    '''\n    Re-implemented according to\n    https://github.com/tkarras/progressive_growing_of_gans/blob/master/misc.py\n    '''\n    scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (np.float32(drange_in[1]) - np.float32(drange_in[0]))\n    bias = (np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale)\n    image = np.clip(image * scale + bias, 0, 255).astype(np.uint8)\n    return image\n\n\ndef save_images(images, output_dir, start_idx=0):\n    for i, image in enumerate(images):\n        image = scale_image_paper(image, [-1, 1], [0, 255])\n        image = image.transpose(1, 2, 0) # CHW -> HWC\n        image_path = os.path.join(output_dir,\n                                  'image{:04d}.png'.format(i+start_idx))\n        imsave(image_path, image)\n\n\ndef get_gaussian_latents(nb_latents, filter_latents):\n    latents = np.random.randn(nb_latents, 512, 1, 1).astype(np.float32)\n    latents = ndimage.gaussian_filter(latents,\n                                      [filter_latents, 0, 0, 0],\n                                      mode='wrap')\n    latents /= np.sqrt(np.mean(latents**2))\n    return latents\n\n\ndef slerp(val, low, high):\n    '''\n    original: Animating Rotation with Quaternion Curves, Ken Shoemake\n    \n    https://arxiv.org/abs/1609.04468\n    Code: https://github.com/soumith/dcgan.torch/issues/14, Tom White\n    '''\n    omega = 
np.arccos(np.dot(low/np.linalg.norm(low), high/np.linalg.norm(high)))\n    so = np.sin(omega)\n    return np.sin((1.0-val)*omega) / so * low + np.sin(val*omega)/so * high\n\n\ndef get_slerp_interp(nb_latents, nb_interp):\n    low = np.random.randn(512)\n    \n    latent_interps = np.empty(shape=(0, 512), dtype=np.float32)\n    for _ in range(nb_latents):\n        high = np.random.randn(512)#low + np.random.randn(512) * 0.7\n        \n        interp_vals = np.linspace(1./nb_interp, 1, num=nb_interp)\n        latent_interp = np.array([slerp(v, low, high) for v in interp_vals],\n                                  dtype=np.float32)\n        \n        latent_interps = np.vstack((latent_interps, latent_interp))\n        low = high\n\n    return latent_interps[:, :, np.newaxis, np.newaxis]\n\n\nclass LatentDataset(Dataset):\n    def __init__(self, interp_type='gauss', nb_latents=1, filter_latents=3,\n                 nb_interp=50):\n        if interp_type=='gauss':\n            latents = get_gaussian_latents(nb_latents, filter_latents)\n        elif interp_type=='slerp':\n            latents = get_slerp_interp(nb_latents, nb_interp)\n        self.data = torch.from_numpy(latents)\n\n    def __getitem__(self, index):\n        return self.data[index]\n\n    def __len__(self):\n        return len(self.data)\n"
  }
]