[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# Data\n#data/\n#checkpoints/\n#tensorboard/\n#samples/\n*.log\n*.slurm\n*.ipynb\n*.out\n\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nenv/\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# dotenv\n.env\n\n# virtualenv\n.venv\nvenv/\nENV/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n\n# other\n.DS_Store\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2018 JTan\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# generative-compression\n\nTensorFlow Implementation for learned compression of images using Generative Adversarial Networks. The method was developed by Agustsson et. al. in [Generative Adversarial Networks for Extreme Learned Image Compression](https://arxiv.org/abs/1804.02958). The proposed idea is very interesting and their approach is well-described.\n\n![Results from authors using C=4 bottleneck channels, global compression without semantic maps on the Kodak dataset](images/authors/kodak_GC_C4.png)\n\n-----------------------------\n## Usage\nThe code depends on [Tensorflow 1.8](https://github.com/tensorflow/tensorflow)\n```bash\n# Clone\n$ git clone https://github.com/Justin-Tan/generative-compression.git\n$ cd generative-compression\n\n# To train, check command line arguments\n$ python3 train.py -h\n# Run\n$ python3 train.py -opt momentum --name my_network\n```\nTraining is conducted with batch size 1 and reconstructed samples / tensorboard summaries will be periodically written every certain number of steps (default is 128). Checkpoints are saved every 10 epochs. \n\nTo compress a single image:\n```bash\n# Compress\n$ python3 compress.py -r /path/to/model/checkpoint -i /path/to/image -o path/to/output/image\n```\nThe compressed image will be saved as a side-by-side comparison with the original image under the path specified in `directories.samples` in `config.py`. If you are using the provided pretrained model with noise sampling, retain the hyperparameters under `config_test` in `config.py`, otherwise the parameters during test time should match the parameters set during training.\n\n*Note:* If you're willing to pay higher bitrates in exchange for much higher perceptual quality, you may want to check out this implementation of [\"High-Fidelity Generative Image Compression\"](https://github.com/Justin-Tan/high-fidelity-generative-compression), which is in the same vein but operates in higher bitrate regimes. 
Furthermore, it is capable of working with images of arbitrary size and resolution.\n\n## Results\nThese globally compressed images are from the test split of the Cityscapes `leftImg8bit` dataset. The decoder seems to hallucinate greenery in buildings, and vice versa.\n\n#### Global conditional compression: Multiscale discriminator + feature-matching losses, C=8 channels (compression to 0.072 bpp)\n**Epoch 38**\n![cityscapes_e38](images/results/cGAN_epoch38.png)\n**Epoch 44**\n![cityscapes_e44](images/results/cGAN_epoch44.png)\n**Epoch 47**\n![cityscapes_e47](images/results/cGAN_epoch47.png)\n**Epoch 48**\n![cityscapes_e48](images/results/cGAN_epoch48.png)\n```\nShow quantized C=4,8,16 channels image comparison\n```\n| Generator Loss | Discriminator Loss |\n|-------|-------|\n|![gen_loss](images/results/generator_loss.png) | ![discriminator_loss](images/results/discriminator_loss.png) |\n\n## Pretrained Model\nYou can find the pretrained model for global compression with a channel bottleneck of `C = 8` (corresponding to a 0.072 bpp representation) below. The model was trained with the multiscale discriminator and feature-matching losses. Noise is sampled from a 128-dim normal distribution, passed through a DCGAN-like generator and concatenated with the quantized image representation. The model was trained for 55 epochs on the train split of the [Cityscapes](https://www.cityscapes-dataset.com/) `leftImg8bit` dataset for the images and used the `gtFine` dataset for the corresponding semantic maps. This should work with the default settings under `config_test` in `config.py`.\n\nA pretrained model for global conditional compression with a `C=8` bottleneck is also included. This model was trained for 50 epochs with the same losses as above. 
Reconstruction is conditioned on semantic label maps (see the `cGAN/` folder and 'Conditional GAN usage').\n\n* [Noise sampling model](https://drive.google.com/open?id=1gy6NJqlxflLDI1g9Rsileva-8G1ifsEC)\n* [Conditional GAN model](https://drive.google.com/open?id=1L3G4l8IQukNrsf3hjHv5xRhpNE77TD2k)\n\n**Warning:** TensorFlow 1.3 was used to train the models, but they appear to load without problems on TensorFlow 1.8. Please raise an issue if you have any problems.\n\n## Details / extensions\nThe network architectures are based on the description provided in the appendix of the original paper, which is in turn based on the paper [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://cs.stanford.edu/people/jcjohns/eccv16/). The multiscale discriminator loss used was originally proposed in the project [High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs](https://tcwang0509.github.io/pix2pixHD/); consult `network.py` for the implementation. If you would like to add an extension you can create a new method under the `Network` class, e.g.\n\n```python\n@staticmethod\ndef my_generator(z, **kwargs):\n    \"\"\"\n    Inputs:\n    z: sampled noise\n\n    Returns:\n    upsampled image\n    \"\"\"\n\n    return tf.random_normal([z.get_shape()[0], height, width, channels], seed=42)\n```\nTo change hyperparameters or toggle features, use the knobs in `config.py`. (Bad form, maybe, but I find it easier than a 20-line `argparse` specification.)\n\n### Data / Setup\nTraining was done using the [ADE 20k dataset](http://groups.csail.mit.edu/vision/datasets/ADE20K/) and the [Cityscapes leftImg8bit dataset](https://www.cityscapes-dataset.com/). In the former case images are rescaled to width `512` px, and in the latter images are [resampled to `[512 x 1024]` prior to training](https://www.imagemagick.org/script/command-line-options.php#resample). An example script for resampling using `ImageMagick` is provided under `data/`. 
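The paths file that `data.py` reads back with `pd.read_hdf(..., key='df')` can be produced with a few lines of Pandas; the sketch below is illustrative only (the glob pattern and output filename are assumptions to adapt to your layout):

```python
import glob
import os

import pandas as pd

# Gather paths to the (resampled) training images.
# NOTE: this glob pattern is an assumption; point it at your own dataset layout.
paths = sorted(glob.glob('data/cityscapes/leftImg8bit/train/**/*.png', recursive=True))

df = pd.DataFrame({'path': paths})

# data.py reads the file back with pd.read_hdf(filename, key='df'),
# so the key must be 'df'. Writing HDF5 via pandas requires the PyTables package.
os.makedirs('data', exist_ok=True)
df.to_hdf('data/cityscapes_paths_train.h5', key='df', mode='w')
```

The resulting file path should then be set under the `directories` class in `config.py`.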
In each case, you will need to create a Pandas dataframe containing a single column, `path`, which holds the absolute/relative path to each image. This should be saved as an `HDF5` file, and you should provide the path to this file under the `directories` class in `config.py`. Examples for the Cityscapes dataset are provided in the `data` directory.\n\n### Conditional GAN usage\nThe conditional GAN implementation for global compression is in the `cGAN` directory. The cGAN implementation appears to yield the highest image quality, but it remains experimental. In this implementation, generation is conditioned on the information in the semantic label map of the selected image. You will need to download the `gtFine` dataset of annotation maps and append a separate column `semantic_map_paths` to the Pandas dataframe pointing to the corresponding images from the `gtFine` dataset.\n\n### Dependencies\n* Python 3.6\n* [Pandas](https://pandas.pydata.org/)\n* [TensorFlow 1.8](https://github.com/tensorflow/tensorflow)\n\n### Todo:\n* Incorporate GAN noise sampling into the reconstructed image. The authors state that this step is optional and that the sampled noise is combined with the quantized representation, but don't provide further details. Currently the model samples from a normal distribution and upsamples this using a DCGAN-like generator (see `network.py`) to be concatenated with the quantized image representation `w_hat`, but this appears to substantially increase the 'hallucination factor' in the reconstructed images.\n* Integrate VGG loss.\n* Experiment with WGAN-GP.\n* Experiment with spectral normalization.\n* Experiment with different generator architectures with noise sampling. 
\n* Extend to selective compression using semantic maps (contributions welcome).\n\n### Resources\n* [Generative Adversarial Networks for Extreme Learned Image Compression](https://data.vision.ee.ethz.ch/aeirikur/extremecompression/#publication)\n* [CycleGAN](https://arxiv.org/pdf/1703.10593.pdf)\n* [High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs](https://tcwang0509.github.io/pix2pixHD/)\n\n## More Results\n#### Global compression: Noise sampling, multiscale discriminator + feature-matching losses, C=8 channels (compression to 0.072 bpp)\n![cityscapes_e45](images/results/noiseE45.png)\n![cityscapes_e47](images/results/cGANe47.png)\n![cityscapes_e51](images/results/noiseE51.png)\n![cityscapes_e53](images/results/noiseE53.png)\n![cityscapes_e54](images/results/noiseE54.png)\n![cityscapes_e55](images/results/noiseE55.png)\n![cityscapes_e56](images/results/noiseE56.png)\n"
  },
  {
    "path": "cGAN/config.py",
    "content": "#!/usr/bin/env python3\n\nclass config_train(object):\n    mode = 'gan-train'\n    num_epochs = 512\n    batch_size = 1\n    ema_decay = 0.999\n    G_learning_rate = 2e-4\n    D_learning_rate = 2e-4\n    lr_decay_rate = 2e-5\n    momentum = 0.9\n    weight_decay = 5e-4\n    noise_dim = 128\n    optimizer = 'adam'\n    kernel_size = 3\n    diagnostic_steps = 256\n\n    # WGAN\n    gradient_penalty = True\n    lambda_gp = 10\n    weight_clipping = False\n    max_c = 1e-2\n    n_critic_iterations = 20\n\n    # Compression\n    lambda_X = 12\n    channel_bottleneck = 16\n    sample_noise = False\n    use_vanilla_GAN = False\n    use_feature_matching_loss = True\n    upsample_dim = 256\n    multiscale = True\n    feature_matching_weight = 10\n    use_conditional_GAN = False\n\nclass config_test(object):\n    mode = 'gan-test'\n    num_epochs = 512\n    batch_size = 1\n    ema_decay = 0.999\n    G_learning_rate = 2e-4\n    D_learning_rate = 2e-4\n    lr_decay_rate = 2e-5\n    momentum = 0.9\n    weight_decay = 5e-4\n    noise_dim = 128\n    optimizer = 'adam'\n    kernel_size = 3\n    diagnostic_steps = 256\n\n    # WGAN\n    gradient_penalty = True\n    lambda_gp = 10\n    weight_clipping = False\n    max_c = 1e-2\n    n_critic_iterations = 5\n\n    # Compression\n    lambda_X = 12\n    channel_bottleneck = 8\n    sample_noise = True\n    use_vanilla_GAN = False\n    use_feature_matching_loss = True\n    upsample_dim = 256\n    multiscale = True\n    feature_matching_weight = 10\n    use_conditional_GAN = False\n\nclass directories(object):\n    # train = 'data/ADE20K_paths_train.h5'\n    # test = 'data/ADE20K_paths_test.h5'\n    train = 'data/sm_cityscapes_paths_train.h5'\n    test = 'data/cityscapes_paths_test.h5'\n    val = 'data/cityscapes_paths_val.h5'\n    tensorboard = 'tensorboard'\n    checkpoints = 'checkpoints'\n    checkpoints_best = 'checkpoints/best'\n    samples = 'samples/cityscapes'\n\n"
  },
  {
    "path": "cGAN/data.py",
    "content": "#!/usr/bin/python3\nimport tensorflow as tf\nimport numpy as np\nimport pandas as pd\nfrom config import directories\n\nclass Data(object):\n\n    @staticmethod\n    def load_dataframe(filename, load_semantic_maps=False):\n        df = pd.read_hdf(filename, key='df').sample(frac=1).reset_index(drop=True)\n\n        if load_semantic_maps:\n            return df['path'].values, df['semantic_map_path'].values\n        else:\n            return df['path'].values\n\n    @staticmethod\n    def load_dataset(image_paths, batch_size, test=False, augment=False, downsample=False,\n            training_dataset='cityscapes', use_conditional_GAN=False, **kwargs):\n\n        def _augment(image):\n            # On-the-fly data augmentation\n            image = tf.image.random_brightness(image, max_delta=0.1)\n            image = tf.image.random_contrast(image, 0.5, 1.5)\n            image = tf.image.random_flip_left_right(image)\n\n            return image\n\n        def _parser(image_path, semantic_map_path=None):\n\n            def _aspect_preserving_width_resize(image, width=512):\n                height_i = tf.shape(image)[0]\n                # width_i = tf.shape(image)[1]\n                # ratio = tf.to_float(width_i) / tf.to_float(height_i)\n                # new_height = tf.to_int32(tf.to_float(height_i) / ratio)\n                new_height = height_i - tf.floormod(height_i, 16)\n                return tf.image.resize_image_with_crop_or_pad(image, new_height, width)\n\n            def _image_decoder(path):\n                im = tf.image.decode_png(tf.read_file(path), channels=3)\n                im = tf.image.convert_image_dtype(im, dtype=tf.float32)\n                return 2 * im - 1 # [0,1] -> [-1,1] (tanh range)\n                    \n            image = _image_decoder(image_path)\n\n            # Explicitly set the shape if you want a sanity check\n            # or if you are using your own custom dataset, otherwise\n            # the model is 
shape-agnostic as it is fully convolutional\n\n            # im.set_shape([512,1024,3])  # downscaled cityscapes\n\n            if use_conditional_GAN:\n                # Semantic map only enabled for cityscapes\n                semantic_map = _image_decoder(semantic_map_path)           \n\n            if training_dataset == 'ADE20k':\n                image = _aspect_preserving_width_resize(image)\n                # im.set_shape([None,512,3])\n\n            if use_conditional_GAN:\n                if training_dataset == 'ADE20k':\n                    raise NotImplementedError('Conditional generation not implemented for ADE20k dataset.')\n                return image, semantic_map\n            else:\n                return image\n            \n\n        print('Training on', training_dataset)\n\n        if use_conditional_GAN:\n            dataset = tf.data.Dataset.from_tensor_slices((image_paths, kwargs['semantic_map_paths']))\n        else:\n            dataset = tf.data.Dataset.from_tensor_slices(image_paths)\n\n        dataset = dataset.map(_parser)\n        dataset = dataset.shuffle(buffer_size=8)\n        dataset = dataset.batch(batch_size)\n\n        if test:\n            dataset = dataset.repeat()\n\n        return dataset\n\n    @staticmethod\n    def load_cGAN_dataset(image_paths, semantic_map_paths, batch_size, test=False, augment=False, downsample=False,\n            training_dataset='cityscapes'):\n        \"\"\"\n        Load image dataset with semantic label maps for conditional GAN\n        \"\"\" \n\n        def _parser(image_path, semantic_map_path):\n            def _aspect_preserving_width_resize(image, width=512):\n                # If training on ADE20k\n                height_i = tf.shape(image)[0]\n                new_height = height_i - tf.floormod(height_i, 16)\n                    \n                return tf.image.resize_image_with_crop_or_pad(image, new_height, width)\n\n            def _image_decoder(path):\n                im = 
tf.image.decode_png(tf.read_file(path), channels=3)\n                im = tf.image.convert_image_dtype(im, dtype=tf.float32)\n                return 2 * im - 1 # [0,1] -> [-1,1] (tanh range)\n\n\n            image, semantic_map = _image_decoder(image_path), _image_decoder(semantic_map_path)\n\n            if training_dataset == 'ADE20k':\n                image = _aspect_preserving_width_resize(image)\n                semantic_map = _aspect_preserving_width_resize(semantic_map)\n\n            # im.set_shape([512,1024,3])  # downscaled cityscapes\n\n            return image, semantic_map\n\n        print('Training on', training_dataset)\n\n        dataset = tf.data.Dataset.from_tensor_slices((image_paths, semantic_map_paths))\n        dataset = dataset.map(_parser)\n        dataset = dataset.shuffle(buffer_size=8)\n        dataset = dataset.batch(batch_size)\n\n        if test:\n            dataset = dataset.repeat()\n\n        return dataset\n\n    @staticmethod\n    def load_inference(filenames, labels, batch_size, resize=(32,32)):\n\n        # Single image estimation over multiple stochastic forward passes\n\n        def _preprocess_inference(image_path, label, resize=(32,32)):\n            # Preprocess individual images during inference\n            image_path = tf.squeeze(image_path)\n            image = tf.image.decode_png(tf.read_file(image_path))\n            image = tf.image.convert_image_dtype(image, dtype=tf.float32)\n            image = tf.image.per_image_standardization(image)\n            image = tf.image.resize_images(image, size=resize)\n\n            return image, label\n\n        dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))\n        dataset = dataset.map(_preprocess_inference)\n        dataset = dataset.batch(batch_size)\n        \n        return dataset\n\n"
  },
  {
    "path": "cGAN/model.py",
    "content": "#!/usr/bin/python3\n    \nimport tensorflow as tf\nimport numpy as np\nimport glob, time, os\n\nfrom network import Network\nfrom data import Data\nfrom config import directories\nfrom utils import Utils\n\nclass Model():\n    def __init__(self, config, paths, dataset, name='gan_compression', evaluate=False):\n        # Build the computational graph\n        \n        print('Building computational graph ...')\n        self.G_global_step = tf.Variable(0, trainable=False)\n        self.D_global_step = tf.Variable(0, trainable=False)\n        self.handle = tf.placeholder(tf.string, shape=[])\n        self.training_phase = tf.placeholder(tf.bool)\n\n        # >>> Data handling\n        self.path_placeholder = tf.placeholder(paths.dtype, paths.shape)\n        self.test_path_placeholder = tf.placeholder(paths.dtype)            \n\n        self.semantic_map_path_placeholder = tf.placeholder(paths.dtype, paths.shape)\n        self.test_semantic_map_path_placeholder = tf.placeholder(paths.dtype)  \n\n        train_dataset = Data.load_dataset(self.path_placeholder,\n                                          config.batch_size,\n                                          augment=False,\n                                          training_dataset=dataset,\n                                          use_conditional_GAN=config.use_conditional_GAN,\n                                          semantic_map_paths=self.semantic_map_path_placeholder)\n\n        test_dataset = Data.load_dataset(self.test_path_placeholder,\n                                         config.batch_size,\n                                         augment=False,\n                                         training_dataset=dataset,\n                                         use_conditional_GAN=config.use_conditional_GAN,\n                                         semantic_map_paths=self.test_semantic_map_path_placeholder,\n                                         test=True)\n\n        self.iterator = 
tf.data.Iterator.from_string_handle(self.handle,\n                                                                    train_dataset.output_types,\n                                                                    train_dataset.output_shapes)\n\n        self.train_iterator = train_dataset.make_initializable_iterator()\n        self.test_iterator = test_dataset.make_initializable_iterator()\n\n        if config.use_conditional_GAN:\n            self.example, self.semantic_map = self.iterator.get_next()\n        else:\n            self.example = self.iterator.get_next()\n\n        # Global generator: Encode -> quantize -> reconstruct\n        # =======================================================================================================>>>\n        with tf.variable_scope('generator'):\n            self.feature_map = Network.encoder(self.example, config, self.training_phase, config.channel_bottleneck)\n            self.w_hat = Network.quantizer(self.feature_map, config)\n\n            if config.use_conditional_GAN:\n                self.semantic_feature_map = Network.encoder(self.semantic_map, config, self.training_phase, \n                    config.channel_bottleneck, scope='semantic_map')\n                self.w_hat_semantic = Network.quantizer(self.semantic_feature_map, config, scope='semantic_map')\n\n                self.w_hat = tf.concat([self.w_hat, self.w_hat_semantic], axis=-1)\n\n            if config.sample_noise is True:\n                print('Sampling noise...')\n                # noise_prior = tf.contrib.distributions.Uniform(-1., 1.)\n                # self.noise_sample = noise_prior.sample([tf.shape(self.example)[0], config.noise_dim])\n                noise_prior = tf.contrib.distributions.MultivariateNormalDiag(loc=tf.zeros([config.noise_dim]), scale_diag=tf.ones([config.noise_dim]))\n                v = noise_prior.sample(tf.shape(self.example)[0])\n                Gv = Network.dcgan_generator(v, config, self.training_phase, 
C=config.channel_bottleneck, upsample_dim=config.upsample_dim)\n                self.z = tf.concat([self.w_hat, Gv], axis=-1)\n            else:\n                self.z = self.w_hat\n\n            self.reconstruction = Network.decoder(self.z, config, self.training_phase, C=config.channel_bottleneck)\n\n        print('Real image shape:', self.example.get_shape().as_list())\n        print('Reconstruction shape:', self.reconstruction.get_shape().as_list())\n\n        # Pass generated, real images to discriminator\n        # =======================================================================================================>>>\n\n        if config.use_conditional_GAN:\n            # Model conditional distribution\n            self.example = tf.concat([self.example, self.semantic_map], axis=-1)\n            self.reconstruction = tf.concat([self.reconstruction, self.semantic_map], axis=-1)\n\n        if config.multiscale:\n            D_x, D_x2, D_x4, *Dk_x = Network.multiscale_discriminator(self.example, config, self.training_phase, \n                use_sigmoid=config.use_vanilla_GAN, mode='real')\n            D_Gz, D_Gz2, D_Gz4, *Dk_Gz = Network.multiscale_discriminator(self.reconstruction, config, self.training_phase, \n                use_sigmoid=config.use_vanilla_GAN, mode='reconstructed', reuse=True)\n        else:\n            D_x = Network.discriminator(self.example, config, self.training_phase, use_sigmoid=config.use_vanilla_GAN)\n            D_Gz = Network.discriminator(self.reconstruction, config, self.training_phase, use_sigmoid=config.use_vanilla_GAN, reuse=True)\n\n        # Loss terms\n        # =======================================================================================================>>>\n        if config.use_vanilla_GAN is True:\n            # Minimize JS divergence (binary cross-entropy on sigmoid logits)\n            D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x,\n                labels=tf.ones_like(D_x)))\n            D_loss_gen = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_Gz,\n                labels=tf.zeros_like(D_Gz)))\n            self.D_loss = D_loss_real + D_loss_gen\n            # G_loss = max log D(G(z))\n            self.G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_Gz,\n                labels=tf.ones_like(D_Gz)))\n        else:\n            # Minimize $\\chi^2$ divergence\n            self.D_loss = tf.reduce_mean(tf.square(D_x - 1.)) + tf.reduce_mean(tf.square(D_Gz))\n            self.G_loss = tf.reduce_mean(tf.square(D_Gz - 1.))\n\n            if config.multiscale:\n                self.D_loss += tf.reduce_mean(tf.square(D_x2 - 1.)) + tf.reduce_mean(tf.square(D_x4 - 1.))\n                self.D_loss += tf.reduce_mean(tf.square(D_Gz2)) + tf.reduce_mean(tf.square(D_Gz4))\n\n        distortion_penalty = config.lambda_X * tf.losses.mean_squared_error(self.example, self.reconstruction)\n        self.G_loss += distortion_penalty\n\n        if config.use_feature_matching_loss:  # feature extractor for generator\n            D_x_layers, D_Gz_layers = [j for i in Dk_x for j in i], [j for i in Dk_Gz for j in i]\n            feature_matching_loss = tf.reduce_sum([tf.reduce_mean(tf.abs(Dkx-Dkz)) for Dkx, Dkz in zip(D_x_layers, D_Gz_layers)])\n            self.G_loss += config.feature_matching_weight * feature_matching_loss\n\n\n        # Optimization\n        # =======================================================================================================>>>\n        G_opt = tf.train.AdamOptimizer(learning_rate=config.G_learning_rate, beta1=0.5)\n        D_opt = tf.train.AdamOptimizer(learning_rate=config.D_learning_rate, beta1=0.5)\n\n        theta_G = Utils.scope_variables('generator')\n        theta_D = Utils.scope_variables('discriminator')\n        print('Generator parameters:', theta_G)\n        print('Discriminator parameters:', theta_D)\n        G_update_ops = 
tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='generator')\n        D_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='discriminator')\n\n        # Execute the update_ops before performing the train_step\n        with tf.control_dependencies(G_update_ops):\n            self.G_opt_op = G_opt.minimize(self.G_loss, name='G_opt', global_step=self.G_global_step, var_list=theta_G)\n        with tf.control_dependencies(D_update_ops):\n            self.D_opt_op = D_opt.minimize(self.D_loss, name='D_opt', global_step=self.D_global_step, var_list=theta_D)\n\n        G_ema = tf.train.ExponentialMovingAverage(decay=config.ema_decay, num_updates=self.G_global_step)\n        G_maintain_averages_op = G_ema.apply(theta_G)\n        D_ema = tf.train.ExponentialMovingAverage(decay=config.ema_decay, num_updates=self.D_global_step)\n        D_maintain_averages_op = D_ema.apply(theta_D)\n\n        with tf.control_dependencies(G_update_ops+[self.G_opt_op]):\n            self.G_train_op = tf.group(G_maintain_averages_op)\n        with tf.control_dependencies(D_update_ops+[self.D_opt_op]):\n            self.D_train_op = tf.group(D_maintain_averages_op)\n\n        # >>> Monitoring\n        # tf.summary.scalar('learning_rate', learning_rate)\n        tf.summary.scalar('generator_loss', self.G_loss)\n        tf.summary.scalar('discriminator_loss', self.D_loss)\n        tf.summary.scalar('distortion_penalty', distortion_penalty)\n        if config.use_feature_matching_loss:\n            tf.summary.scalar('feature_matching_loss', feature_matching_loss)\n        tf.summary.scalar('G_global_step', self.G_global_step)\n        tf.summary.scalar('D_global_step', self.D_global_step)\n        tf.summary.image('real_images', self.example[:,:,:,:3], max_outputs=4)\n        tf.summary.image('compressed_images', self.reconstruction[:,:,:,:3], max_outputs=4)\n        if config.use_conditional_GAN:\n            tf.summary.image('semantic_map', self.semantic_map, max_outputs=4)\n        
self.merge_op = tf.summary.merge_all()\n\n        self.train_writer = tf.summary.FileWriter(\n            os.path.join(directories.tensorboard, '{}_train_{}'.format(name, time.strftime('%d-%m_%I:%M'))), graph=tf.get_default_graph())\n        self.test_writer = tf.summary.FileWriter(\n            os.path.join(directories.tensorboard, '{}_test_{}'.format(name, time.strftime('%d-%m_%I:%M'))))\n"
  },
  {
    "path": "cGAN/network.py",
    "content": "\"\"\" Modular components of computational graph\n    JTan 2018\n\"\"\"\nimport tensorflow as tf\nfrom utils import Utils\n\nclass Network(object):\n\n    @staticmethod\n    def encoder(x, config, training, C, reuse=False, actv=tf.nn.relu, scope='image'):\n        \"\"\"\n        Process image x ([512,1024]) into a feature map of size W/16 x H/16 x C\n         + C:       Bottleneck depth, controls bpp\n         + Output:  Projection onto C channels, C = {2,4,8,16}\n        \"\"\"\n        init = tf.contrib.layers.xavier_initializer()\n        print('<------------ Building global {} generator architecture ------------>'.format(scope))\n\n        def conv_block(x, filters, kernel_size=[3,3], strides=2, padding='same', actv=actv, init=init):\n            bn_kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False}\n            in_kwargs = {'center':True, 'scale': True}\n            x = tf.layers.conv2d(x, filters, kernel_size, strides=strides, padding=padding, activation=None)\n            # x = tf.layers.batch_normalization(x, **bn_kwargs)\n            x = tf.contrib.layers.instance_norm(x, **in_kwargs)\n            x = actv(x)\n            return x\n\n        with tf.variable_scope('encoder_{}'.format(scope), reuse=reuse):\n\n            # Run convolutions\n            f = [60, 120, 240, 480, 960]\n            x = tf.pad(x, [[0, 0], [3, 3], [3, 3], [0, 0]], 'REFLECT')\n            out = conv_block(x, filters=f[0], kernel_size=7, strides=1, padding='VALID', actv=actv)\n\n            out = conv_block(out, filters=f[1], kernel_size=3, strides=2, actv=actv)\n            out = conv_block(out, filters=f[2], kernel_size=3, strides=2, actv=actv)\n            out = conv_block(out, filters=f[3], kernel_size=3, strides=2, actv=actv)\n            out = conv_block(out, filters=f[4], kernel_size=3, strides=2, actv=actv)\n\n            # Project channels onto space w/ dimension C\n            # Feature maps have dimension W/16 x H/16 
x C\n            out = tf.pad(out, [[0, 0], [1, 1], [1, 1], [0, 0]], 'REFLECT')\n            feature_map = conv_block(out, filters=C, kernel_size=3, strides=1, padding='VALID', actv=actv)\n\n            return feature_map\n\n\n    @staticmethod\n    def quantizer(w, config, reuse=False, temperature=1, L=5, scope='image'):\n        \"\"\"\n        Quantize feature map over L centers to obtain discrete $\\hat{w}$\n         + Centers: {-2,-1,0,1,2}\n         + TODO:    Toggle learnable centers?\n        \"\"\"\n        with tf.variable_scope('quantizer_{}'.format(scope), reuse=reuse):\n\n            centers = tf.cast(tf.range(-2,3), tf.float32)\n            # Partition W into the Voronoi tessellation over the centers\n            w_stack = tf.stack([w for _ in range(L)], axis=-1)\n            w_hard = tf.cast(tf.argmin(tf.abs(w_stack - centers), axis=-1), tf.float32) + tf.reduce_min(centers)\n\n            smx = tf.nn.softmax(-1.0/temperature * tf.abs(w_stack - centers), dim=-1)\n            # Contract last dimension\n            w_soft = tf.einsum('ijklm,m->ijkl', smx, centers)  # w_soft = tf.tensordot(smx, centers, axes=((-1),(0)))\n\n            # Treat quantization as differentiable for optimization\n            w_bar = tf.round(tf.stop_gradient(w_hard - w_soft) + w_soft)\n\n            return w_bar\n\n\n    @staticmethod\n    def decoder(w_bar, config, training, C, reuse=False, actv=tf.nn.relu, channel_upsample=960):\n        \"\"\"\n        Attempt to reconstruct the image from the quantized representation w_bar.\n        Generated image should be consistent with the true image distribution while\n        recovering the specific encoded image\n        + C:        Bottleneck depth, controls bpp - last dimension of encoder output\n        + TODO:     Concatenate quantized w_bar with noise sampled from prior\n        \"\"\"\n        init = tf.contrib.layers.xavier_initializer()\n\n        def residual_block(x, n_filters, kernel_size=3, strides=1, 
actv=actv):\n            init = tf.contrib.layers.xavier_initializer()\n            # kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False}\n            strides = [1,1]\n            identity_map = x\n\n            p = int((kernel_size-1)/2)\n            res = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]], 'REFLECT')\n            res = tf.layers.conv2d(res, filters=n_filters, kernel_size=kernel_size, strides=strides,\n                    activation=None, padding='VALID')\n            res = actv(tf.contrib.layers.instance_norm(res))\n\n            res = tf.pad(res, [[0, 0], [p, p], [p, p], [0, 0]], 'REFLECT')\n            res = tf.layers.conv2d(res, filters=n_filters, kernel_size=kernel_size, strides=strides,\n                    activation=None, padding='VALID')\n            res = tf.contrib.layers.instance_norm(res)\n\n            assert res.get_shape().as_list() == identity_map.get_shape().as_list(), 'Mismatched shapes between input/output!'\n            out = tf.add(res, identity_map)\n\n            return out\n\n        def upsample_block(x, filters, kernel_size=[3,3], strides=2, padding='same', actv=actv, batch_norm=False):\n            bn_kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False}\n            in_kwargs = {'center':True, 'scale': True}\n            x = tf.layers.conv2d_transpose(x, filters, kernel_size, strides=strides, padding=padding, activation=None)\n            if batch_norm is True:\n                x = tf.layers.batch_normalization(x, **bn_kwargs)\n            else:\n                x = tf.contrib.layers.instance_norm(x, **in_kwargs)\n            x = actv(x)\n\n            return x\n\n        # Project channel dimension of w_bar to higher dimension\n        # W_pc = tf.get_variable('W_pc_{}'.format(C), shape=[C, channel_upsample], initializer=init)\n        # upsampled = tf.einsum('ijkl,lm->ijkm', w_bar, W_pc)\n        with tf.variable_scope('decoder', reuse=reuse):\n        
    w_bar = tf.pad(w_bar, [[0, 0], [1, 1], [1, 1], [0, 0]], 'REFLECT')\n            upsampled = Utils.conv_block(w_bar, filters=960, kernel_size=3, strides=1, padding='VALID', actv=actv)\n            \n            # Process upsampled feature map with residual blocks\n            res = upsampled\n            for _ in range(9):\n                res = residual_block(res, 960, actv=actv)\n\n            # Upsample to original dimensions - mirror encoder\n            f = [480, 240, 120, 60]\n\n            ups = upsample_block(res, f[0], 3, strides=[2,2], padding='same')\n            ups = upsample_block(ups, f[1], 3, strides=[2,2], padding='same')\n            ups = upsample_block(ups, f[2], 3, strides=[2,2], padding='same')\n            ups = upsample_block(ups, f[3], 3, strides=[2,2], padding='same')\n            \n            ups = tf.pad(ups, [[0, 0], [3, 3], [3, 3], [0, 0]], 'REFLECT')\n            ups = tf.layers.conv2d(ups, 3, kernel_size=7, strides=1, padding='VALID')\n\n            out = tf.nn.tanh(ups)\n\n            return out\n\n\n    @staticmethod\n    def discriminator(x, config, training, reuse=False, actv=tf.nn.leaky_relu, use_sigmoid=False, ksize=4):\n        # x is either generator output G(z) or drawn from the real data distribution\n        # Patch-GAN discriminator based on arXiv 1711.11585\n        # bn_kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False}\n        # actv is applied explicitly after each instance_norm below, so no activation_fn here\n        in_kwargs = {'center':True, 'scale':True}\n\n        print('Shape of x:', x.get_shape().as_list())\n\n        with tf.variable_scope('discriminator', 
reuse=reuse):\n            c1 = tf.layers.conv2d(x, 64, kernel_size=ksize, strides=2, padding='same', activation=actv)\n            c2 = tf.layers.conv2d(c1, 128, kernel_size=ksize, strides=2, padding='same')\n            c2 = actv(tf.contrib.layers.instance_norm(c2, **in_kwargs))\n            c3 = tf.layers.conv2d(c2, 256, kernel_size=ksize, strides=2, padding='same')\n            c3 = actv(tf.contrib.layers.instance_norm(c3, **in_kwargs))\n            c4 = tf.layers.conv2d(c3, 512, kernel_size=ksize, strides=2, padding='same')\n            c4 = actv(tf.contrib.layers.instance_norm(c4, **in_kwargs))\n\n            out = tf.layers.conv2d(c4, 1, kernel_size=ksize, strides=1, padding='same')\n\n            if use_sigmoid is True:  # Otherwise use LS-GAN\n                out = tf.nn.sigmoid(out)\n\n        return out\n\n\n    @staticmethod\n    def multiscale_discriminator(x, config, training, actv=tf.nn.leaky_relu, use_sigmoid=False, \n        ksize=4, mode='real', reuse=False):\n        # x is either generator output G(z) or drawn from the real data distribution\n        # Multiscale + Patch-GAN discriminator architecture based on arXiv 1711.11585\n        print('<------------ Building multiscale discriminator architecture ------------>')\n\n        if mode == 'real':\n            print('Building discriminator D(x)')\n        elif mode == 'reconstructed':\n            print('Building discriminator D(G(z))')\n        else:\n            raise NotImplementedError('Invalid discriminator mode specified.')\n\n        # Downsample input\n        x2 = tf.layers.average_pooling2d(x, pool_size=3, strides=2, padding='same')\n        x4 = tf.layers.average_pooling2d(x2, pool_size=3, strides=2, padding='same')\n\n        print('Shape of x:', x.get_shape().as_list())\n        print('Shape of x downsampled by factor 2:', x2.get_shape().as_list())\n        print('Shape of x downsampled by factor 4:', x4.get_shape().as_list())\n\n        def discriminator(x, scope, actv=actv, 
use_sigmoid=use_sigmoid, ksize=ksize, reuse=reuse):\n\n            # Returns patch-GAN output + intermediate layers\n\n            with tf.variable_scope('discriminator_{}'.format(scope), reuse=reuse):\n                c1 = tf.layers.conv2d(x, 64, kernel_size=ksize, strides=2, padding='same', activation=actv)\n                c2 = Utils.conv_block(c1, filters=128, kernel_size=ksize, strides=2, padding='same', actv=actv)\n                c3 = Utils.conv_block(c2, filters=256, kernel_size=ksize, strides=2, padding='same', actv=actv)\n                c4 = Utils.conv_block(c3, filters=512, kernel_size=ksize, strides=2, padding='same', actv=actv)\n                out = tf.layers.conv2d(c4, 1, kernel_size=ksize, strides=1, padding='same')\n\n                if use_sigmoid is True:  # Otherwise use LS-GAN\n                    out = tf.nn.sigmoid(out)\n\n            return out, c1, c2, c3, c4\n\n        with tf.variable_scope('discriminator', reuse=reuse):\n            disc, *Dk = discriminator(x, 'original')\n            disc_downsampled_2, *Dk_2 = discriminator(x2, 'downsampled_2')\n            disc_downsampled_4, *Dk_4 = discriminator(x4, 'downsampled_4')\n\n        return disc, disc_downsampled_2, disc_downsampled_4, Dk, Dk_2, Dk_4\n\n    @staticmethod\n    def dcgan_generator(z, config, training, C, reuse=False, actv=tf.nn.relu, kernel_size=5, upsample_dim=256):\n        \"\"\"\n        Upsample noise to concatenate with quantized representation w_bar.\n        + z:    Drawn from latent distribution - [batch_size, noise_dim]\n        + C:    Bottleneck depth, controls bpp - last dimension of encoder output\n        \"\"\"\n        init =  tf.contrib.layers.xavier_initializer()\n        kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False}\n        with tf.variable_scope('noise_generator', reuse=reuse):\n\n            # [batch_size, 4, 8, dim]\n            with tf.variable_scope('fc1', reuse=reuse):\n                h2 = 
tf.layers.dense(z, units=4 * 8 * upsample_dim, activation=actv, kernel_initializer=init)  # cifar-10\n                h2 = tf.layers.batch_normalization(h2, **kwargs)\n                h2 = tf.reshape(h2, shape=[-1, 4, 8, upsample_dim])\n\n            # [batch_size, 8, 16, dim/2]\n            with tf.variable_scope('upsample1', reuse=reuse):\n                up1 = tf.layers.conv2d_transpose(h2, upsample_dim//2, kernel_size=kernel_size, strides=2, padding='same', activation=actv)\n                up1 = tf.layers.batch_normalization(up1, **kwargs)\n\n            # [batch_size, 16, 32, dim/4]\n            with tf.variable_scope('upsample2', reuse=reuse):\n                up2 = tf.layers.conv2d_transpose(up1, upsample_dim//4, kernel_size=kernel_size, strides=2, padding='same', activation=actv)\n                up2 = tf.layers.batch_normalization(up2, **kwargs)\n            \n            # [batch_size, 32, 64, dim/8]\n            with tf.variable_scope('upsample3', reuse=reuse):\n                up3 = tf.layers.conv2d_transpose(up2, upsample_dim//8, kernel_size=kernel_size, strides=2, padding='same', activation=actv)  # cifar-10\n                up3 = tf.layers.batch_normalization(up3, **kwargs)\n\n            with tf.variable_scope('conv_out', reuse=reuse):\n                out = tf.pad(up3, [[0, 0], [3, 3], [3, 3], [0, 0]], 'REFLECT')\n                out = tf.layers.conv2d(out, C, kernel_size=7, strides=1, padding='VALID')\n\n        return out\n\n    @staticmethod\n    def dcgan_discriminator(x, config, training, reuse=False, actv=tf.nn.relu):\n        # x is either generator output G(z) or drawn from the real data distribution\n        init =  tf.contrib.layers.xavier_initializer()\n        kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False}\n        print('Shape of x:', x.get_shape().as_list())\n        x = tf.reshape(x, shape=[-1, 32, 32, 3]) \n        # x = tf.reshape(x, shape=[-1, 28, 28, 1]) \n\n        with 
tf.variable_scope('discriminator', reuse=reuse):\n            with tf.variable_scope('conv1', reuse=reuse):\n                c1 = tf.layers.conv2d(x, 64, kernel_size=5, strides=2, padding='same', activation=actv)\n                c1 = tf.layers.batch_normalization(c1, **kwargs)\n\n            with tf.variable_scope('conv2', reuse=reuse):\n                c2 = tf.layers.conv2d(c1, 128, kernel_size=5, strides=2, padding='same', activation=actv)\n                c2 = tf.layers.batch_normalization(c2, **kwargs)\n\n            with tf.variable_scope('fc1', reuse=reuse):\n                fc1 = tf.contrib.layers.flatten(c2)\n                # fc1 = tf.reshape(c2, shape=[-1, 8 * 8 * 128])\n                fc1 = tf.layers.dense(fc1, units=1024, activation=actv, kernel_initializer=init)\n                fc1 = tf.layers.batch_normalization(fc1, **kwargs)\n            \n            with tf.variable_scope('out', reuse=reuse):\n                out = tf.layers.dense(fc1, units=2, activation=None, kernel_initializer=init)\n\n        return out\n        \n\n    @staticmethod\n    def critic_grande(x, config, training, reuse=False, actv=tf.nn.relu, kernel_size=5, gradient_penalty=True):\n        # x is either generator output G(z) or drawn from the real data distribution\n        init =  tf.contrib.layers.xavier_initializer()\n        kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False}\n        print('Shape of x:', x.get_shape().as_list())\n        x = tf.reshape(x, shape=[-1, 32, 32, 3]) \n        # x = tf.reshape(x, shape=[-1, 28, 28, 1]) \n\n        with tf.variable_scope('critic', reuse=reuse):\n            with tf.variable_scope('conv1', reuse=reuse):\n                c1 = tf.layers.conv2d(x, 64, kernel_size=kernel_size, strides=2, padding='same', activation=actv)\n                if gradient_penalty is False:\n                    c1 = tf.layers.batch_normalization(c1, **kwargs)\n\n            with tf.variable_scope('conv2', reuse=reuse):\n 
               c2 = tf.layers.conv2d(c1, 128, kernel_size=kernel_size, strides=2, padding='same', activation=actv)\n                if gradient_penalty is False:\n                    c2 = tf.layers.batch_normalization(c2, **kwargs)\n\n            with tf.variable_scope('conv3', reuse=reuse):\n                c3 = tf.layers.conv2d(c2, 256, kernel_size=kernel_size, strides=2, padding='same', activation=actv)\n                if gradient_penalty is False:\n                    c3 = tf.layers.batch_normalization(c3, **kwargs)\n\n            with tf.variable_scope('fc1', reuse=reuse):\n                fc1 = tf.contrib.layers.flatten(c3)\n                # fc1 = tf.reshape(c2, shape=[-1, 8 * 8 * 128])\n                fc1 = tf.layers.dense(fc1, units=1024, activation=actv, kernel_initializer=init)\n                #fc1 = tf.layers.batch_normalization(fc1, **kwargs)\n            \n            with tf.variable_scope('out', reuse=reuse):\n                out = tf.layers.dense(fc1, units=1, activation=None, kernel_initializer=init)\n\n        return out\n\n    @staticmethod\n    def wrn(x, config, training, reuse=False, actv=tf.nn.relu):\n        # Implements W-28-10 wide residual network\n        # See Arxiv 1605.07146\n        network_width = 10 # k\n        block_multiplicity = 2 # n\n\n        filters = [16, 16, 32, 64]\n        init = tf.contrib.layers.xavier_initializer()\n        kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':True}\n\n        def residual_block(x, n_filters, actv, keep_prob, training, project_shortcut=False, first_block=False):\n            init = tf.contrib.layers.xavier_initializer()\n            kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':True}\n\n            if project_shortcut:\n                strides = [2,2] if not first_block else [1,1]\n                identity_map = tf.layers.conv2d(x, filters=n_filters, kernel_size=[1,1],\n                                   
strides=strides, kernel_initializer=init, padding='same')\n                # identity_map = tf.layers.batch_normalization(identity_map, **kwargs)\n            else:\n                strides = [1,1]\n                identity_map = x\n\n            bn = tf.layers.batch_normalization(x, **kwargs)\n            conv = tf.layers.conv2d(bn, filters=n_filters, kernel_size=[3,3], activation=actv,\n                       strides=strides, kernel_initializer=init, padding='same')\n\n            bn = tf.layers.batch_normalization(conv, **kwargs)\n            do = tf.layers.dropout(bn, rate=1-keep_prob, training=training)\n\n            conv = tf.layers.conv2d(do, filters=n_filters, kernel_size=[3,3], activation=actv,\n                       kernel_initializer=init, padding='same')\n            out = tf.add(conv, identity_map)\n\n            return out\n\n        def residual_block_2(x, n_filters, actv, keep_prob, training, project_shortcut=False, first_block=False):\n            init = tf.contrib.layers.xavier_initializer()\n            kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':True}\n            prev_filters = x.get_shape().as_list()[-1]\n            if project_shortcut:\n                strides = [2,2] if not first_block else [1,1]\n                # identity_map = tf.layers.conv2d(x, filters=n_filters, kernel_size=[1,1],\n                #                   strides=strides, kernel_initializer=init, padding='same')\n                identity_map = tf.layers.average_pooling2d(x, strides, strides, 'valid')\n                identity_map = tf.pad(identity_map, \n                    tf.constant([[0,0],[0,0],[0,0],[(n_filters-prev_filters)//2, (n_filters-prev_filters)//2]]))\n                # identity_map = tf.layers.batch_normalization(identity_map, **kwargs)\n            else:\n                strides = [1,1]\n                identity_map = x\n\n            x = tf.layers.batch_normalization(x, **kwargs)\n            x = tf.nn.relu(x)\n     
       x = tf.layers.conv2d(x, filters=n_filters, kernel_size=[3,3], strides=strides,\n                    kernel_initializer=init, padding='same')\n\n            x = tf.layers.batch_normalization(x, **kwargs)\n            x = tf.nn.relu(x)\n            x = tf.layers.dropout(x, rate=1-keep_prob, training=training)\n\n            x = tf.layers.conv2d(x, filters=n_filters, kernel_size=[3,3],\n                       kernel_initializer=init, padding='same')\n            out = tf.add(x, identity_map)\n\n            return out\n\n        with tf.variable_scope('wrn_conv', reuse=reuse):\n            # Initial convolution --------------------------------------------->\n            with tf.variable_scope('conv0', reuse=reuse):\n                conv = tf.layers.conv2d(x, filters[0], kernel_size=[3,3], activation=actv,\n                                        kernel_initializer=init, padding='same')\n            # Residual group 1 ------------------------------------------------>\n            rb = conv\n            f1 = filters[1]*network_width\n            for n in range(block_multiplicity):\n                with tf.variable_scope('group1/{}'.format(n), reuse=reuse):\n                    project_shortcut = True if n==0 else False\n                    rb = residual_block(rb, f1, actv, project_shortcut=project_shortcut,\n                            keep_prob=config.conv_keep_prob, training=training, first_block=True)\n            # Residual group 2 ------------------------------------------------>\n            f2 = filters[2]*network_width\n            for n in range(block_multiplicity):\n                with tf.variable_scope('group2/{}'.format(n), reuse=reuse):\n                    project_shortcut = True if n==0 else False\n                    rb = residual_block(rb, f2, actv, project_shortcut=project_shortcut,\n                            keep_prob=config.conv_keep_prob, training=training)\n            # Residual group 3 ------------------------------------------------>\n  
          f3 = filters[3]*network_width\n            for n in range(block_multiplicity):\n                with tf.variable_scope('group3/{}'.format(n), reuse=reuse):\n                    project_shortcut = True if n==0 else False\n                    rb = residual_block(rb, f3, actv, project_shortcut=project_shortcut,\n                            keep_prob=config.conv_keep_prob, training=training)\n            # Avg pooling + output -------------------------------------------->\n            with tf.variable_scope('output', reuse=reuse):\n                bn = tf.nn.relu(tf.layers.batch_normalization(rb, **kwargs))\n                avp = tf.layers.average_pooling2d(bn, pool_size=[8,8], strides=[1,1], padding='valid')\n                flatten = tf.contrib.layers.flatten(avp)\n                out = tf.layers.dense(flatten, units=config.n_classes, kernel_initializer=init)\n\n            return out\n\n\n    @staticmethod\n    def old_encoder(x, config, training, C, reuse=False, actv=tf.nn.relu):\n        \"\"\"\n        Process image x ([512,1024]) into a feature map of size W/16 x H/16 x C\n         + C:       Bottleneck depth, controls bpp\n         + Output:  Projection onto C channels, C = {2,4,8,16}\n        \"\"\"\n        # proj_channels = [2,4,8,16]\n        init = tf.contrib.layers.xavier_initializer()\n\n        def conv_block(x, filters, kernel_size=[3,3], strides=2, padding='same', actv=actv, init=init):\n            in_kwargs = {'center':True, 'scale': True}\n            x = tf.layers.conv2d(x, filters, kernel_size, strides=strides, padding=padding, activation=None)\n            x = tf.contrib.layers.instance_norm(x, **in_kwargs)\n            x = actv(x)\n            return x\n                \n        with tf.variable_scope('encoder', reuse=reuse):\n\n            # Run convolutions\n            out = conv_block(x, kernel_size=3, strides=1, filters=160, actv=actv)\n            out = conv_block(out, kernel_size=[3,3], strides=2, filters=320, actv=actv)\n      
      out = conv_block(out, kernel_size=[3,3], strides=2, filters=480, actv=actv)\n            out = conv_block(out, kernel_size=[3,3], strides=2, filters=640, actv=actv)\n            out = conv_block(out, kernel_size=[3,3], strides=2, filters=800, actv=actv)\n\n            out = conv_block(out, kernel_size=3, strides=1, filters=960, actv=actv)\n            # Project channels onto lower-dimensional embedding space\n            W = tf.get_variable('W_channel_{}'.format(C), shape=[960,C], initializer=init)\n            feature_map = tf.einsum('ijkl,lm->ijkm', out, W)  # feature_map = tf.tensordot(out, W, axes=((3),(0)))\n            \n            # Feature maps have dimension W/16 x H/16 x C\n            return feature_map\n\n\n"
  },
  {
    "path": "cGAN/train.py",
    "content": "#!/usr/bin/python3\nimport tensorflow as tf\nimport numpy as np\nimport pandas as pd\nimport time, os, sys\nimport argparse\n\n# User-defined\nfrom network import Network\nfrom utils import Utils\nfrom data import Data\nfrom model import Model\nfrom config import config_train, directories\n\ntf.logging.set_verbosity(tf.logging.ERROR)\n\ndef train(config, args):\n\n    start_time = time.time()\n    G_loss_best, D_loss_best = float('inf'), float('inf')\n    ckpt = tf.train.get_checkpoint_state(directories.checkpoints)\n\n    # Load data\n    print('Training on dataset', args.dataset)\n    if config.use_conditional_GAN:\n        print('Using conditional GAN')\n        paths, semantic_map_paths = Data.load_dataframe(directories.train, load_semantic_maps=True)\n        test_paths, test_semantic_map_paths = Data.load_dataframe(directories.test, load_semantic_maps=True)\n    else:\n        paths = Data.load_dataframe(directories.train)\n        test_paths = Data.load_dataframe(directories.test)\n\n    # Build graph\n    gan = Model(config, paths, name=args.name, dataset=args.dataset)\n    saver = tf.train.Saver()\n\n    if config.use_conditional_GAN:\n        feed_dict_test_init = {gan.test_path_placeholder: test_paths, \n                               gan.test_semantic_map_path_placeholder: test_semantic_map_paths}\n        feed_dict_train_init = {gan.path_placeholder: paths,\n                                gan.semantic_map_path_placeholder: semantic_map_paths}\n    else:\n        feed_dict_test_init = {gan.test_path_placeholder: test_paths}\n        feed_dict_train_init = {gan.path_placeholder: paths}\n\n    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:\n        sess.run(tf.global_variables_initializer())\n        sess.run(tf.local_variables_initializer())\n        train_handle = sess.run(gan.train_iterator.string_handle())\n        test_handle = sess.run(gan.test_iterator.string_handle())\n\n     
   if args.restore_last and ckpt.model_checkpoint_path:\n            # Continue training saved model\n            saver.restore(sess, ckpt.model_checkpoint_path)\n            print('{} restored.'.format(ckpt.model_checkpoint_path))\n        else:\n            if args.restore_path:\n                new_saver = tf.train.import_meta_graph('{}.meta'.format(args.restore_path))\n                new_saver.restore(sess, args.restore_path)\n                print('{} restored.'.format(args.restore_path))\n\n        sess.run(gan.test_iterator.initializer, feed_dict=feed_dict_test_init)\n\n        for epoch in range(config.num_epochs):\n\n            sess.run(gan.train_iterator.initializer, feed_dict=feed_dict_train_init)\n\n            # Run diagnostics\n            G_loss_best, D_loss_best = Utils.run_diagnostics(gan, config, directories, sess, saver, train_handle,\n                start_time, epoch, args.name, G_loss_best, D_loss_best)\n\n            while True:\n                try:\n                    # Update generator\n                    # for _ in range(8):\n                    feed_dict = {gan.training_phase: True, gan.handle: train_handle}\n                    sess.run(gan.G_train_op, feed_dict=feed_dict)\n\n                    # Update discriminator \n                    step, _ = sess.run([gan.D_global_step, gan.D_train_op], feed_dict=feed_dict)\n\n                    if step % config.diagnostic_steps == 0:\n                        G_loss_best, D_loss_best = Utils.run_diagnostics(gan, config, directories, sess, saver, train_handle,\n                            start_time, epoch, args.name, G_loss_best, D_loss_best)\n                        Utils.single_plot(epoch, step, sess, gan, train_handle, args.name, config)\n                        # for _ in range(4):\n                        #    sess.run(gan.G_opt_op, feed_dict=feed_dict)\n\n\n                except tf.errors.OutOfRangeError:\n                    print('End of epoch!')\n                    break\n\n      
          except KeyboardInterrupt:\n                    save_path = saver.save(sess, os.path.join(directories.checkpoints,\n                        '{}_last.ckpt'.format(args.name)), global_step=epoch)\n                    print('Interrupted, model saved to: ', save_path)\n                    sys.exit()\n\n        save_path = saver.save(sess, os.path.join(directories.checkpoints,\n                               '{}_end.ckpt'.format(args.name)),\n                               global_step=epoch)\n\n    print(\"Training Complete. Model saved to file: {} Time elapsed: {:.3f} s\".format(save_path, time.time()-start_time))\n\ndef main(**kwargs):\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"-rl\", \"--restore_last\", help=\"restore last saved model\", action=\"store_true\")\n    parser.add_argument(\"-r\", \"--restore_path\", help=\"path to model to be restored\", type=str)\n    parser.add_argument(\"-opt\", \"--optimizer\", default=\"adam\", help=\"Selected optimizer\", type=str)\n    parser.add_argument(\"-name\", \"--name\", default=\"gan-train\", help=\"Checkpoint/Tensorboard label\")\n    parser.add_argument(\"-ds\", \"--dataset\", default=\"cityscapes\", help=\"choice of training dataset. Currently only supports cityscapes/ADE20k\", choices=set((\"cityscapes\", \"ADE20k\")), type=str)\n    args = parser.parse_args()\n\n    # Launch training\n    train(config_train, args)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "cGAN/utils.py",
    "content": "# -*- coding: utf-8 -*-\n# Diagnostic helper functions for Tensorflow session\n\nimport tensorflow as tf\nimport numpy as np\nimport os, time\nimport matplotlib as mpl\nmpl.use('Agg')\nimport matplotlib.pyplot as plt\nimport seaborn as sns\n\nfrom config import directories\n\nclass Utils(object):\n    \n    @staticmethod\n    def conv_block(x, filters, kernel_size=[3,3], strides=2, padding='same', actv=tf.nn.relu):\n        in_kwargs = {'center':True, 'scale': True}\n        x = tf.layers.conv2d(x, filters, kernel_size, strides=strides, padding=padding, activation=None)\n        x = tf.contrib.layers.instance_norm(x, **in_kwargs)\n        x = actv(x)\n\n        return x\n\n    @staticmethod\n    def upsample_block(x, filters, kernel_size=[3,3], strides=2, padding='same', actv=tf.nn.relu):\n        in_kwargs = {'center':True, 'scale': True}\n        x = tf.layers.conv2d_transpose(x, filters, kernel_size, strides=strides, padding=padding, activation=None)\n        x = tf.contrib.layers.instance_norm(x, **in_kwargs)\n        x = actv(x)\n\n        return x\n\n    @staticmethod\n    def residual_block(x, n_filters, kernel_size=3, strides=1, actv=tf.nn.relu):\n        init = tf.contrib.layers.xavier_initializer()\n        # kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False}\n        strides = [1,1]\n        identity_map = x\n\n        p = int((kernel_size-1)/2)\n        res = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]], 'REFLECT')\n        res = tf.layers.conv2d(res, filters=n_filters, kernel_size=kernel_size, strides=strides,\n                activation=None, padding='VALID')\n        res = actv(tf.contrib.layers.instance_norm(res))\n\n        res = tf.pad(res, [[0, 0], [p, p], [p, p], [0, 0]], 'REFLECT')\n        res = tf.layers.conv2d(res, filters=n_filters, kernel_size=kernel_size, strides=strides,\n                activation=None, padding='VALID')\n        res = tf.contrib.layers.instance_norm(res)\n\n        
assert res.get_shape().as_list() == identity_map.get_shape().as_list(), 'Mismatched shapes between input/output!'\n        out = tf.add(res, identity_map)\n\n        return out\n\n    @staticmethod\n    def get_available_gpus():\n        from tensorflow.python.client import device_lib\n        local_device_protos = device_lib.list_local_devices()\n        #return local_device_protos\n        print('Available GPUs:')\n        print([x.name for x in local_device_protos if x.device_type == 'GPU'])\n\n    @staticmethod\n    def scope_variables(name):\n        with tf.variable_scope(name):\n            return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=tf.get_variable_scope().name)\n\n    @staticmethod\n    def run_diagnostics(model, config, directories, sess, saver, train_handle, start_time, epoch, name, G_loss_best, D_loss_best):\n        t0 = time.time()\n        improved = ''\n        sess.run(tf.local_variables_initializer())\n        feed_dict_test = {model.training_phase: False, model.handle: train_handle}\n\n        try:\n            G_loss, D_loss, summary = sess.run([model.G_loss, model.D_loss, model.merge_op], feed_dict=feed_dict_test)\n            model.train_writer.add_summary(summary)\n        except tf.errors.OutOfRangeError:\n            G_loss, D_loss = float('nan'), float('nan')\n\n        if G_loss < G_loss_best and D_loss < D_loss_best:\n            G_loss_best, D_loss_best = G_loss, D_loss\n            improved = '[*]'\n            if epoch>5:\n                save_path = saver.save(sess,\n                            os.path.join(directories.checkpoints_best, '{}_epoch{}.ckpt'.format(name, epoch)),\n                            global_step=epoch)\n                print('Graph saved to file: {}'.format(save_path))\n\n        if epoch % 5 == 0 and epoch > 5:\n            save_path = saver.save(sess, os.path.join(directories.checkpoints, '{}_epoch{}.ckpt'.format(name, epoch)), global_step=epoch)\n            print('Graph saved to file: 
{}'.format(save_path))\n\n        print('Epoch {} | Generator Loss: {:.3f} | Discriminator Loss: {:.3f} | Rate: {} examples/s ({:.2f} s) {}'.format(epoch, G_loss, D_loss, int(config.batch_size/(time.time()-t0)), time.time() - start_time, improved))\n\n        return G_loss_best, D_loss_best\n\n    @staticmethod\n    def single_plot(epoch, global_step, sess, model, handle, name, config, single_compress=False):\n\n        real = model.example\n        gen = model.reconstruction\n\n        # Run the generator to reconstruct the current batch alongside the real examples\n        r, g = sess.run([real, gen], feed_dict={model.training_phase: True, model.handle: handle})\n\n        images = list()\n\n        for im, imtype in zip([r, g], ['real', 'gen']):\n            im = (im + 1.0) / 2  # [-1,1] -> [0,1]\n            im = np.squeeze(im)\n            im = im[:, :, :3]\n            images.append(im)\n\n            # Uncomment to plot real and generated samples separately\n            # f = plt.figure()\n            # plt.imshow(im)\n            # plt.axis('off')\n            # f.savefig(\"{}/gan_compression_{}_epoch{}_step{}_{}.pdf\".format(directories.samples, name, epoch,\n            #                     global_step, imtype), format='pdf', dpi=720, bbox_inches='tight', pad_inches=0)\n            # plt.gcf().clear()\n            # plt.close(f)\n\n        comparison = np.hstack(images)\n        f = plt.figure()\n        plt.imshow(comparison)\n        plt.axis('off')\n\n        if single_compress:\n            # name holds the full output path when compressing a single image\n            save_path = name\n        else:\n            save_path = \"{}/gan_compression_{}_epoch{}_step{}_comparison.pdf\".format(directories.samples, name,\n                epoch, global_step)\n        f.savefig(save_path, format='pdf', dpi=720, bbox_inches='tight', pad_inches=0)\n        plt.gcf().clear()\n        plt.close(f)\n\n\n    @staticmethod\n    def weight_decay(weight_decay, var_label='DW'):\n        \"\"\"L2 weight decay loss.\"\"\"\n        costs = []\n        for var in tf.trainable_variables():\n            if var.op.name.find(r'{}'.format(var_label)) > 0:\n                costs.append(tf.nn.l2_loss(var))\n\n        return tf.multiply(weight_decay, tf.add_n(costs))\n\n"
  },
  {
    "path": "checkpoints/.gitignore",
    "content": "*\n!.gitignore\n!best/\n\n"
  },
  {
    "path": "compress.py",
    "content": "#!/usr/bin/python3\nimport tensorflow as tf\nimport numpy as np\nimport pandas as pd\nimport time, os, sys\nimport argparse\n\n# User-defined\nfrom network import Network\nfrom utils import Utils\nfrom data import Data\nfrom model import Model\nfrom config import config_test, directories\n\ntf.logging.set_verbosity(tf.logging.ERROR)\n\ndef single_compress(config, args):\n    start = time.time()\n    ckpt = tf.train.get_checkpoint_state(directories.checkpoints)\n    assert (ckpt.model_checkpoint_path), 'Missing checkpoint file!'\n\n    if config.use_conditional_GAN:\n        print('Using conditional GAN')\n        paths, semantic_map_paths = np.array([args.image_path]), np.array([args.semantic_map_path])\n    else:\n        paths = np.array([args.image_path])\n\n    gan = Model(config, paths, name='single_compress', dataset=args.dataset, evaluate=True)\n    saver = tf.train.Saver()\n\n    if config.use_conditional_GAN:\n        feed_dict_init = {gan.path_placeholder: paths,\n                          gan.semantic_map_path_placeholder: semantic_map_paths}\n    else:\n        feed_dict_init = {gan.path_placeholder: paths}\n\n    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:\n        # Initialize variables\n        sess.run(tf.global_variables_initializer())\n        sess.run(tf.local_variables_initializer())\n        handle = sess.run(gan.train_iterator.string_handle())\n\n        if args.restore_last and ckpt.model_checkpoint_path:\n            saver.restore(sess, ckpt.model_checkpoint_path)\n            print('Most recent {} restored.'.format(ckpt.model_checkpoint_path))\n        else:\n            if args.restore_path:\n                new_saver = tf.train.import_meta_graph('{}.meta'.format(args.restore_path))\n                new_saver.restore(sess, args.restore_path)\n                print('Previous checkpoint {} restored.'.format(args.restore_path))\n\n        
sess.run(gan.train_iterator.initializer, feed_dict=feed_dict_init)\n\n        if args.output_path is None:\n            output = os.path.splitext(os.path.basename(args.image_path))\n            save_path = os.path.join(directories.samples, '{}_compressed.pdf'.format(output[0]))\n        else:\n            save_path = args.output_path\n        Utils.single_plot(0, 0, sess, gan, handle, save_path, config, single_compress=True)\n        print('Reconstruction saved to', save_path)\n\n    return\n\n\ndef main(**kwargs):\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"-rl\", \"--restore_last\", help=\"restore last saved model\", action=\"store_true\")\n    parser.add_argument(\"-r\", \"--restore_path\", help=\"path to model to be restored\", type=str)\n    parser.add_argument(\"-i\", \"--image_path\", help=\"path to image to compress\", type=str)\n    parser.add_argument(\"-sm\", \"--semantic_map_path\", help=\"path to corresponding semantic map\", type=str)\n    parser.add_argument(\"-o\", \"--output_path\", help=\"path to output image\", type=str)\n    parser.add_argument(\"-ds\", \"--dataset\", default=\"cityscapes\", help=\"choice of training dataset. Currently only supports cityscapes/ADE20k\", choices=set((\"cityscapes\", \"ADE20k\")), type=str)\n    args = parser.parse_args()\n\n    # Compress the single input image\n    single_compress(config_test, args)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "config.py",
    "content": "#!/usr/bin/env python3\n\nclass config_train(object):\n    mode = 'gan-train'\n    num_epochs = 512\n    batch_size = 1\n    ema_decay = 0.999\n    G_learning_rate = 2e-4\n    D_learning_rate = 2e-4\n    lr_decay_rate = 2e-5\n    momentum = 0.9\n    weight_decay = 5e-4\n    noise_dim = 128\n    optimizer = 'adam'\n    kernel_size = 3\n    diagnostic_steps = 256\n\n    # WGAN\n    gradient_penalty = True\n    lambda_gp = 10\n    weight_clipping = False\n    max_c = 1e-2\n    n_critic_iterations = 20\n\n    # Compression\n    lambda_X = 12\n    channel_bottleneck = 8\n    sample_noise = True\n    use_vanilla_GAN = False\n    use_feature_matching_loss = True\n    upsample_dim = 256\n    multiscale = True\n    feature_matching_weight = 10\n    use_conditional_GAN = False\n\nclass config_test(object):\n    mode = 'gan-test'\n    num_epochs = 512\n    batch_size = 1\n    ema_decay = 0.999\n    G_learning_rate = 2e-4\n    D_learning_rate = 2e-4\n    lr_decay_rate = 2e-5\n    momentum = 0.9\n    weight_decay = 5e-4\n    noise_dim = 128\n    optimizer = 'adam'\n    kernel_size = 3\n    diagnostic_steps = 256\n\n    # WGAN\n    gradient_penalty = True\n    lambda_gp = 10\n    weight_clipping = False\n    max_c = 1e-2\n    n_critic_iterations = 5\n\n    # Compression\n    lambda_X = 12\n    channel_bottleneck = 8\n    sample_noise = True\n    use_vanilla_GAN = False\n    use_feature_matching_loss = True\n    upsample_dim = 256\n    multiscale = True\n    feature_matching_weight = 10\n    use_conditional_GAN = False\n\nclass directories(object):\n    train = 'data/cityscapes_paths_train.h5'\n    test = 'data/cityscapes_paths_test.h5'\n    val = 'data/cityscapes_paths_val.h5'\n    tensorboard = 'tensorboard'\n    checkpoints = 'checkpoints'\n    checkpoints_best = 'checkpoints/best'\n    samples = 'samples/cityscapes'\n\n"
  },
  {
    "path": "data/.gitignore",
    "content": "*\n!.gitignore\n!resize_cityscapes.sh\n!cityscapes_paths_train.h5\n!cityscapes_paths_test.h5\n!cityscapes_paths_val.h5\n\n"
  },
  {
    "path": "data/resize_cityscapes.sh",
    "content": "#!/bin/bash\n# Author: Grace Han\n# In-place resampling to 1024 x 512 px (width x height)\n# Requires ImageMagick on a *nix system\n# Modify according to your directory structure\n\nshopt -s globstar  # required for recursive ** matching in bash\n\nfor f in ./**/*.png; do\n    convert \"$f\" -resize 1024x512 \"$f\"\ndone\n"
  },
  {
    "path": "data.py",
    "content": "#!/usr/bin/python3\nimport tensorflow as tf\nimport numpy as np\nimport pandas as pd\nfrom config import directories\n\nclass Data(object):\n\n    @staticmethod\n    def load_dataframe(filename, load_semantic_maps=False):\n        df = pd.read_hdf(filename, key='df').sample(frac=1).reset_index(drop=True)\n\n        if load_semantic_maps:\n            return df['path'].values, df['semantic_map_path'].values\n        else:\n            return df['path'].values\n\n    @staticmethod\n    def load_dataset(image_paths, batch_size, test=False, augment=False, downsample=False,\n            training_dataset='cityscapes', use_conditional_GAN=False, **kwargs):\n\n        def _augment(image):\n            # On-the-fly data augmentation\n            image = tf.image.random_brightness(image, max_delta=0.1)\n            image = tf.image.random_contrast(image, 0.5, 1.5)\n            image = tf.image.random_flip_left_right(image)\n\n            return image\n\n        def _parser(image_path, semantic_map_path=None):\n\n            def _aspect_preserving_width_resize(image, width=512):\n                height_i = tf.shape(image)[0]\n                # width_i = tf.shape(image)[1]\n                # ratio = tf.to_float(width_i) / tf.to_float(height_i)\n                # new_height = tf.to_int32(tf.to_float(height_i) / ratio)\n                new_height = height_i - tf.floormod(height_i, 16)\n                return tf.image.resize_image_with_crop_or_pad(image, new_height, width)\n\n            def _image_decoder(path):\n                im = tf.image.decode_png(tf.read_file(path), channels=3)\n                im = tf.image.convert_image_dtype(im, dtype=tf.float32)\n                return 2 * im - 1 # [0,1] -> [-1,1] (tanh range)\n                    \n            image = _image_decoder(image_path)\n\n            # Explicitly set the shape if you want a sanity check\n            # or if you are using your own custom dataset, otherwise\n            # the model is 
shape-agnostic as it is fully convolutional\n\n            # im.set_shape([512,1024,3])  # downscaled cityscapes\n\n            if use_conditional_GAN:\n                # Semantic map only enabled for cityscapes\n                semantic_map = _image_decoder(semantic_map_path)           \n\n            if training_dataset == 'ADE20k':\n                image = _aspect_preserving_width_resize(image)\n                if use_conditional_GAN:\n                    semantic_map = _aspect_preserving_width_resize(semantic_map)\n                # im.set_shape([None,512,3])\n\n            if use_conditional_GAN:\n                return image, semantic_map\n            else:\n                return image\n            \n\n        print('Training on', training_dataset)\n\n        if use_conditional_GAN:\n            dataset = tf.data.Dataset.from_tensor_slices((image_paths, kwargs['semantic_map_paths']))\n        else:\n            dataset = tf.data.Dataset.from_tensor_slices(image_paths)\n\n        dataset = dataset.shuffle(buffer_size=8)\n        dataset = dataset.map(_parser)\n        dataset = dataset.cache()\n        dataset = dataset.batch(batch_size)\n\n        if test:\n            dataset = dataset.repeat()\n\n        return dataset\n\n    @staticmethod\n    def load_cGAN_dataset(image_paths, semantic_map_paths, batch_size, test=False, augment=False, downsample=False,\n            training_dataset='cityscapes'):\n        \"\"\"\n        Load image dataset with semantic label maps for conditional GAN\n        \"\"\" \n\n        def _parser(image_path, semantic_map_path):\n            def _aspect_preserving_width_resize(image, width=512):\n                # If training on ADE20k\n                height_i = tf.shape(image)[0]\n                new_height = height_i - tf.floormod(height_i, 16)\n                    \n                return tf.image.resize_image_with_crop_or_pad(image, new_height, width)\n\n            def _image_decoder(path):\n                im = 
tf.image.decode_png(tf.read_file(path), channels=3)\n                im = tf.image.convert_image_dtype(im, dtype=tf.float32)\n                return 2 * im - 1 # [0,1] -> [-1,1] (tanh range)\n\n            image, semantic_map = _image_decoder(image_path), _image_decoder(semantic_map_path)\n\n            if training_dataset == 'ADE20k':\n                image = _aspect_preserving_width_resize(image)\n                semantic_map = _aspect_preserving_width_resize(semantic_map)\n\n            # im.set_shape([512,1024,3])  # downscaled cityscapes\n\n            return image, semantic_map\n\n        print('Training on', training_dataset)\n\n        dataset = tf.data.Dataset.from_tensor_slices((image_paths, semantic_map_paths))\n        dataset = dataset.map(_parser)\n        dataset = dataset.shuffle(buffer_size=8)\n        dataset = dataset.batch(batch_size)\n\n        if test:\n            dataset = dataset.repeat()\n\n        return dataset\n\n    @staticmethod\n    def load_inference(filenames, labels, batch_size, resize=(32,32)):\n        # Single image estimation over multiple stochastic forward passes\n\n        def _preprocess_inference(image_path, label, resize=(32,32)):\n            # Preprocess individual images during inference\n            image_path = tf.squeeze(image_path)\n            image = tf.image.decode_png(tf.read_file(image_path))\n            image = tf.image.convert_image_dtype(image, dtype=tf.float32)\n            image = tf.image.per_image_standardization(image)\n            image = tf.image.resize_images(image, size=resize)\n\n            return image, label\n\n        dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))\n        dataset = dataset.map(_preprocess_inference)\n        dataset = dataset.batch(batch_size)\n\n        return dataset\n\n"
  },
  {
    "path": "model.py",
    "content": "#!/usr/bin/python3\n    \nimport tensorflow as tf\nimport numpy as np\nimport glob, time, os\n\nfrom network import Network\nfrom data import Data\nfrom config import directories\nfrom utils import Utils\n\nclass Model():\n    def __init__(self, config, paths, dataset, name='gan_compression', evaluate=False):\n\n        # Build the computational graph\n\n        print('Building computational graph ...')\n        self.G_global_step = tf.Variable(0, trainable=False)\n        self.D_global_step = tf.Variable(0, trainable=False)\n        self.handle = tf.placeholder(tf.string, shape=[])\n        self.training_phase = tf.placeholder(tf.bool)\n\n        # >>> Data handling\n        self.path_placeholder = tf.placeholder(paths.dtype, paths.shape)\n        self.test_path_placeholder = tf.placeholder(paths.dtype)            \n\n        self.semantic_map_path_placeholder = tf.placeholder(paths.dtype, paths.shape)\n        self.test_semantic_map_path_placeholder = tf.placeholder(paths.dtype)  \n\n        train_dataset = Data.load_dataset(self.path_placeholder,\n                                          config.batch_size,\n                                          augment=False,\n                                          training_dataset=dataset,\n                                          use_conditional_GAN=config.use_conditional_GAN,\n                                          semantic_map_paths=self.semantic_map_path_placeholder)\n\n        test_dataset = Data.load_dataset(self.test_path_placeholder,\n                                         config.batch_size,\n                                         augment=False,\n                                         training_dataset=dataset,\n                                         use_conditional_GAN=config.use_conditional_GAN,\n                                         semantic_map_paths=self.test_semantic_map_path_placeholder,\n                                         test=True)\n\n        self.iterator = 
tf.data.Iterator.from_string_handle(self.handle,\n                                                                    train_dataset.output_types,\n                                                                    train_dataset.output_shapes)\n\n        self.train_iterator = train_dataset.make_initializable_iterator()\n        self.test_iterator = test_dataset.make_initializable_iterator()\n\n        if config.use_conditional_GAN:\n            self.example, self.semantic_map = self.iterator.get_next()\n        else:\n            self.example = self.iterator.get_next()\n\n        # Global generator: Encode -> quantize -> reconstruct\n        # =======================================================================================================>>>\n        with tf.variable_scope('generator'):\n            self.feature_map = Network.encoder(self.example, config, self.training_phase, config.channel_bottleneck)\n            self.w_hat = Network.quantizer(self.feature_map, config)\n\n            if config.use_conditional_GAN:\n                self.semantic_feature_map = Network.encoder(self.semantic_map, config, self.training_phase, \n                    config.channel_bottleneck, scope='semantic_map')\n                self.w_hat_semantic = Network.quantizer(self.semantic_feature_map, config, scope='semantic_map')\n\n                self.w_hat = tf.concat([self.w_hat, self.w_hat_semantic], axis=-1)\n\n            if config.sample_noise is True:\n                print('Sampling noise...')\n                # noise_prior = tf.contrib.distributions.Uniform(-1., 1.)\n                # self.noise_sample = noise_prior.sample([tf.shape(self.example)[0], config.noise_dim])\n                noise_prior = tf.contrib.distributions.MultivariateNormalDiag(loc=tf.zeros([config.noise_dim]), scale_diag=tf.ones([config.noise_dim]))\n                v = noise_prior.sample(tf.shape(self.example)[0])\n                Gv = Network.dcgan_generator(v, config, self.training_phase, 
C=config.channel_bottleneck, upsample_dim=config.upsample_dim)\n                self.z = tf.concat([self.w_hat, Gv], axis=-1)\n            else:\n                self.z = self.w_hat\n\n            self.reconstruction = Network.decoder(self.z, config, self.training_phase, C=config.channel_bottleneck)\n\n        print('Real image shape:', self.example.get_shape().as_list())\n        print('Reconstruction shape:', self.reconstruction.get_shape().as_list())\n\n        if evaluate:\n            return\n\n        # Pass generated, real images to discriminator\n        # =======================================================================================================>>>\n\n        if config.use_conditional_GAN:\n            # Model conditional distribution\n            self.example = tf.concat([self.example, self.semantic_map], axis=-1)\n            self.reconstruction = tf.concat([self.reconstruction, self.semantic_map], axis=-1)\n\n        if config.multiscale:\n            D_x, D_x2, D_x4, *Dk_x = Network.multiscale_discriminator(self.example, config, self.training_phase, \n                use_sigmoid=config.use_vanilla_GAN, mode='real')\n            D_Gz, D_Gz2, D_Gz4, *Dk_Gz = Network.multiscale_discriminator(self.reconstruction, config, self.training_phase, \n                use_sigmoid=config.use_vanilla_GAN, mode='reconstructed', reuse=True)\n        else:\n            D_x = Network.discriminator(self.example, config, self.training_phase, use_sigmoid=config.use_vanilla_GAN)\n            D_Gz = Network.discriminator(self.reconstruction, config, self.training_phase, use_sigmoid=config.use_vanilla_GAN, reuse=True)\n\n        # Loss terms\n        # =======================================================================================================>>>\n        if config.use_vanilla_GAN is True:\n            # Minimize JS divergence\n            D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x,\n                labels=tf.ones_like(D_x)))\n            D_loss_gen = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_Gz,\n                labels=tf.zeros_like(D_Gz)))\n            self.D_loss = D_loss_real + D_loss_gen\n            # G_loss = max log D(G(z))\n            self.G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_Gz,\n                labels=tf.ones_like(D_Gz)))\n        else:\n            # Minimize $\chi^2$ divergence (least-squares GAN)\n            self.D_loss = tf.reduce_mean(tf.square(D_x - 1.)) + tf.reduce_mean(tf.square(D_Gz))\n            self.G_loss = tf.reduce_mean(tf.square(D_Gz - 1.))\n\n            if config.multiscale:\n                self.D_loss += tf.reduce_mean(tf.square(D_x2 - 1.)) + tf.reduce_mean(tf.square(D_x4 - 1.))\n                self.D_loss += tf.reduce_mean(tf.square(D_Gz2)) + tf.reduce_mean(tf.square(D_Gz4))\n\n        distortion_penalty = config.lambda_X * tf.losses.mean_squared_error(self.example, self.reconstruction)\n        self.G_loss += distortion_penalty\n\n        if config.use_feature_matching_loss:  # feature extractor for generator\n            D_x_layers, D_Gz_layers = [j for i in Dk_x for j in i], [j for i in Dk_Gz for j in i]\n            feature_matching_loss = tf.reduce_sum([tf.reduce_mean(tf.abs(Dkx-Dkz)) for Dkx, Dkz in zip(D_x_layers, D_Gz_layers)])\n            self.G_loss += config.feature_matching_weight * feature_matching_loss\n\n        # Optimization\n        # =======================================================================================================>>>\n        G_opt = tf.train.AdamOptimizer(learning_rate=config.G_learning_rate, beta1=0.5)\n        D_opt = tf.train.AdamOptimizer(learning_rate=config.D_learning_rate, beta1=0.5)\n\n        theta_G = Utils.scope_variables('generator')\n        theta_D = Utils.scope_variables('discriminator')\n        # print('Generator parameters:', theta_G)\n        # print('Discriminator parameters:', theta_D)\n        
G_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='generator')\n        D_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='discriminator')\n\n        # Execute the update_ops before performing the train_step\n        with tf.control_dependencies(G_update_ops):\n            self.G_opt_op = G_opt.minimize(self.G_loss, name='G_opt', global_step=self.G_global_step, var_list=theta_G)\n        with tf.control_dependencies(D_update_ops):\n            self.D_opt_op = D_opt.minimize(self.D_loss, name='D_opt', global_step=self.D_global_step, var_list=theta_D)\n\n        G_ema = tf.train.ExponentialMovingAverage(decay=config.ema_decay, num_updates=self.G_global_step)\n        G_maintain_averages_op = G_ema.apply(theta_G)\n        D_ema = tf.train.ExponentialMovingAverage(decay=config.ema_decay, num_updates=self.D_global_step)\n        D_maintain_averages_op = D_ema.apply(theta_D)\n\n        with tf.control_dependencies(G_update_ops+[self.G_opt_op]):\n            self.G_train_op = tf.group(G_maintain_averages_op)\n        with tf.control_dependencies(D_update_ops+[self.D_opt_op]):\n            self.D_train_op = tf.group(D_maintain_averages_op)\n\n        # >>> Monitoring\n        # tf.summary.scalar('learning_rate', learning_rate)\n        tf.summary.scalar('generator_loss', self.G_loss)\n        tf.summary.scalar('discriminator_loss', self.D_loss)\n        tf.summary.scalar('distortion_penalty', distortion_penalty)\n        if config.use_feature_matching_loss:\n            tf.summary.scalar('feature_matching_loss', feature_matching_loss)\n        tf.summary.scalar('G_global_step', self.G_global_step)\n        tf.summary.scalar('D_global_step', self.D_global_step)\n        tf.summary.image('real_images', self.example[:,:,:,:3], max_outputs=4)\n        tf.summary.image('compressed_images', self.reconstruction[:,:,:,:3], max_outputs=4)\n        if config.use_conditional_GAN:\n            tf.summary.image('semantic_map', self.semantic_map, 
max_outputs=4)\n        self.merge_op = tf.summary.merge_all()\n\n        self.train_writer = tf.summary.FileWriter(\n            os.path.join(directories.tensorboard, '{}_train_{}'.format(name, time.strftime('%d-%m_%I:%M'))), graph=tf.get_default_graph())\n        self.test_writer = tf.summary.FileWriter(\n            os.path.join(directories.tensorboard, '{}_test_{}'.format(name, time.strftime('%d-%m_%I:%M'))))\n"
  },
  {
    "path": "network.py",
    "content": "\"\"\" Modular components of computational graph\n    JTan 2018\n\"\"\"\nimport tensorflow as tf\nfrom utils import Utils\n\nclass Network(object):\n\n    @staticmethod\n    def encoder(x, config, training, C, reuse=False, actv=tf.nn.relu, scope='image'):\n        \"\"\"\n        Process image x ([512,1024]) into a feature map of size W/16 x H/16 x C\n         + C:       Bottleneck depth, controls bpp\n         + Output:  Projection onto C channels, C = {2,4,8,16}\n        \"\"\"\n        init = tf.contrib.layers.xavier_initializer()\n        print('<------------ Building global {} generator architecture ------------>'.format(scope))\n\n        def conv_block(x, filters, kernel_size=[3,3], strides=2, padding='same', actv=actv, init=init):\n            bn_kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False}\n            in_kwargs = {'center':True, 'scale': True}\n            x = tf.layers.conv2d(x, filters, kernel_size, strides=strides, padding=padding, activation=None)\n            # x = tf.layers.batch_normalization(x, **bn_kwargs)\n            x = tf.contrib.layers.instance_norm(x, **in_kwargs)\n            x = actv(x)\n            return x\n\n        with tf.variable_scope('encoder_{}'.format(scope), reuse=reuse):\n\n            # Run convolutions\n            f = [60, 120, 240, 480, 960]\n            x = tf.pad(x, [[0, 0], [3, 3], [3, 3], [0, 0]], 'REFLECT')\n            out = conv_block(x, filters=f[0], kernel_size=7, strides=1, padding='VALID', actv=actv)\n\n            out = conv_block(out, filters=f[1], kernel_size=3, strides=2, actv=actv)\n            out = conv_block(out, filters=f[2], kernel_size=3, strides=2, actv=actv)\n            out = conv_block(out, filters=f[3], kernel_size=3, strides=2, actv=actv)\n            out = conv_block(out, filters=f[4], kernel_size=3, strides=2, actv=actv)\n\n            # Project channels onto space w/ dimension C\n            # Feature maps have dimension W/16 x H/16 
x C\n            out = tf.pad(out, [[0, 0], [1, 1], [1, 1], [0, 0]], 'REFLECT')\n            feature_map = conv_block(out, filters=C, kernel_size=3, strides=1, padding='VALID', actv=actv)\n\n            return feature_map\n\n\n    @staticmethod\n    def quantizer(w, config, reuse=False, temperature=1, L=5, scope='image'):\n        \"\"\"\n        Quantize feature map over L centers to obtain discrete $\hat{w}$\n         + Centers: {-2,-1,0,1,2}\n         + TODO:    Toggle learnable centers?\n        \"\"\"\n        with tf.variable_scope('quantizer_{}'.format(scope), reuse=reuse):\n\n            centers = tf.cast(tf.range(-2,3), tf.float32)\n            # Partition W into the Voronoi tessellation over the centers\n            w_stack = tf.stack([w for _ in range(L)], axis=-1)\n            w_hard = tf.cast(tf.argmin(tf.abs(w_stack - centers), axis=-1), tf.float32) + tf.reduce_min(centers)\n\n            smx = tf.nn.softmax(-1.0/temperature * tf.abs(w_stack - centers), dim=-1)\n            # Contract last dimension\n            w_soft = tf.einsum('ijklm,m->ijkl', smx, centers)  # w_soft = tf.tensordot(smx, centers, axes=((-1),(0)))\n\n            # Treat quantization as differentiable for optimization (straight-through estimator)\n            w_bar = tf.round(tf.stop_gradient(w_hard - w_soft) + w_soft)\n\n            return w_bar\n\n\n    @staticmethod\n    def decoder(w_bar, config, training, C, reuse=False, actv=tf.nn.relu, channel_upsample=960):\n        \"\"\"\n        Attempt to reconstruct the image from the quantized representation w_bar.\n        Generated image should be consistent with the true image distribution while\n        recovering the specific encoded image\n        + C:        Bottleneck depth, controls bpp - last dimension of encoder output\n        + TODO:     Concatenate quantized w_bar with noise sampled from prior\n        \"\"\"\n        init = tf.contrib.layers.xavier_initializer()\n\n        def residual_block(x, n_filters, kernel_size=3, strides=1, 
actv=actv):\n            init = tf.contrib.layers.xavier_initializer()\n            # kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False}\n            strides = [1,1]\n            identity_map = x\n\n            p = int((kernel_size-1)/2)\n            res = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]], 'REFLECT')\n            res = tf.layers.conv2d(res, filters=n_filters, kernel_size=kernel_size, strides=strides,\n                    activation=None, padding='VALID')\n            res = actv(tf.contrib.layers.instance_norm(res))\n\n            res = tf.pad(res, [[0, 0], [p, p], [p, p], [0, 0]], 'REFLECT')\n            res = tf.layers.conv2d(res, filters=n_filters, kernel_size=kernel_size, strides=strides,\n                    activation=None, padding='VALID')\n            res = tf.contrib.layers.instance_norm(res)\n\n            assert res.get_shape().as_list() == identity_map.get_shape().as_list(), 'Mismatched shapes between input/output!'\n            out = tf.add(res, identity_map)\n\n            return out\n\n        def upsample_block(x, filters, kernel_size=[3,3], strides=2, padding='same', actv=actv, batch_norm=False):\n            bn_kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False}\n            in_kwargs = {'center':True, 'scale': True}\n            x = tf.layers.conv2d_transpose(x, filters, kernel_size, strides=strides, padding=padding, activation=None)\n            if batch_norm is True:\n                x = tf.layers.batch_normalization(x, **bn_kwargs)\n            else:\n                x = tf.contrib.layers.instance_norm(x, **in_kwargs)\n            x = actv(x)\n\n            return x\n\n        # Project channel dimension of w_bar to higher dimension\n        # W_pc = tf.get_variable('W_pc_{}'.format(C), shape=[C, channel_upsample], initializer=init)\n        # upsampled = tf.einsum('ijkl,lm->ijkm', w_bar, W_pc)\n        with tf.variable_scope('decoder', reuse=reuse):\n        
    w_bar = tf.pad(w_bar, [[0, 0], [1, 1], [1, 1], [0, 0]], 'REFLECT')\n            upsampled = Utils.conv_block(w_bar, filters=960, kernel_size=3, strides=1, padding='VALID', actv=actv)\n            \n            # Process upsampled feature map with residual blocks\n            res = residual_block(upsampled, 960, actv=actv)\n            res = residual_block(res, 960, actv=actv)\n            res = residual_block(res, 960, actv=actv)\n            res = residual_block(res, 960, actv=actv)\n            res = residual_block(res, 960, actv=actv)\n            res = residual_block(res, 960, actv=actv)\n            res = residual_block(res, 960, actv=actv)\n            res = residual_block(res, 960, actv=actv)\n            res = residual_block(res, 960, actv=actv)\n\n            # Upsample to original dimensions - mirror decoder\n            f = [480, 240, 120, 60]\n\n            ups = upsample_block(res, f[0], 3, strides=[2,2], padding='same')\n            ups = upsample_block(ups, f[1], 3, strides=[2,2], padding='same')\n            ups = upsample_block(ups, f[2], 3, strides=[2,2], padding='same')\n            ups = upsample_block(ups, f[3], 3, strides=[2,2], padding='same')\n            \n            ups = tf.pad(ups, [[0, 0], [3, 3], [3, 3], [0, 0]], 'REFLECT')\n            ups = tf.layers.conv2d(ups, 3, kernel_size=7, strides=1, padding='VALID')\n\n            out = tf.nn.tanh(ups)\n\n            return out\n\n\n    @staticmethod\n    def discriminator(x, config, training, reuse=False, actv=tf.nn.leaky_relu, use_sigmoid=False, ksize=4):\n        # x is either generator output G(z) or drawn from the real data distribution\n        # Patch-GAN discriminator based on arXiv 1711.11585\n        # bn_kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False}\n        in_kwargs = {'center':True, 'scale':True, 'activation_fn':actv}\n\n        print('Shape of x:', x.get_shape().as_list())\n\n        with tf.variable_scope('discriminator', 
reuse=reuse):\n            c1 = tf.layers.conv2d(x, 64, kernel_size=ksize, strides=2, padding='same', activation=actv)\n            c2 = tf.layers.conv2d(c1, 128, kernel_size=ksize, strides=2, padding='same')\n            c2 = actv(tf.contrib.layers.instance_norm(c2, **in_kwargs))\n            c3 = tf.layers.conv2d(c2, 256, kernel_size=ksize, strides=2, padding='same')\n            c3 = actv(tf.contrib.layers.instance_norm(c3, **in_kwargs))\n            c4 = tf.layers.conv2d(c3, 512, kernel_size=ksize, strides=2, padding='same')\n            c4 = actv(tf.contrib.layers.instance_norm(c4, **in_kwargs))\n\n            out = tf.layers.conv2d(c4, 1, kernel_size=ksize, strides=1, padding='same')\n\n            if use_sigmoid is True:  # Otherwise use LS-GAN\n                out = tf.nn.sigmoid(out)\n\n        return out\n\n\n    @staticmethod\n    def multiscale_discriminator(x, config, training, actv=tf.nn.leaky_relu, use_sigmoid=False, \n        ksize=4, mode='real', reuse=False):\n        # x is either generator output G(z) or drawn from the real data distribution\n        # Multiscale + Patch-GAN discriminator architecture based on arXiv 1711.11585\n        print('<------------ Building multiscale discriminator architecture ------------>')\n\n        if mode == 'real':\n            print('Building discriminator D(x)')\n        elif mode == 'reconstructed':\n            print('Building discriminator D(G(z))')\n        else:\n            raise NotImplementedError('Invalid discriminator mode specified.')\n\n        # Downsample input\n        x2 = tf.layers.average_pooling2d(x, pool_size=3, strides=2, padding='same')\n        x4 = tf.layers.average_pooling2d(x2, pool_size=3, strides=2, padding='same')\n\n        print('Shape of x:', x.get_shape().as_list())\n        print('Shape of x downsampled by factor 2:', x2.get_shape().as_list())\n        print('Shape of x downsampled by factor 4:', x4.get_shape().as_list())\n\n        def discriminator(x, scope, actv=actv, 
use_sigmoid=use_sigmoid, ksize=ksize, reuse=reuse):\n\n            # Returns patch-GAN output + intermediate layers\n\n            with tf.variable_scope('discriminator_{}'.format(scope), reuse=reuse):\n                c1 = tf.layers.conv2d(x, 64, kernel_size=ksize, strides=2, padding='same', activation=actv)\n                c2 = Utils.conv_block(c1, filters=128, kernel_size=ksize, strides=2, padding='same', actv=actv)\n                c3 = Utils.conv_block(c2, filters=256, kernel_size=ksize, strides=2, padding='same', actv=actv)\n                c4 = Utils.conv_block(c3, filters=512, kernel_size=ksize, strides=2, padding='same', actv=actv)\n                out = tf.layers.conv2d(c4, 1, kernel_size=ksize, strides=1, padding='same')\n\n                if use_sigmoid is True:  # Otherwise use LS-GAN\n                    out = tf.nn.sigmoid(out)\n\n            return out, c1, c2, c3, c4\n\n        with tf.variable_scope('discriminator', reuse=reuse):\n            disc, *Dk = discriminator(x, 'original')\n            disc_downsampled_2, *Dk_2 = discriminator(x2, 'downsampled_2')\n            disc_downsampled_4, *Dk_4 = discriminator(x4, 'downsampled_4')\n\n        return disc, disc_downsampled_2, disc_downsampled_4, Dk, Dk_2, Dk_4\n\n    @staticmethod\n    def dcgan_generator(z, config, training, C, reuse=False, actv=tf.nn.relu, kernel_size=5, upsample_dim=256):\n        \"\"\"\n        Upsample noise to concatenate with quantized representation w_bar.\n        + z:    Drawn from latent distribution - [batch_size, noise_dim]\n        + C:    Bottleneck depth, controls bpp - last dimension of encoder output\n        \"\"\"\n        init =  tf.contrib.layers.xavier_initializer()\n        kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False}\n        with tf.variable_scope('noise_generator', reuse=reuse):\n\n            # [batch_size, 4, 8, dim]\n            with tf.variable_scope('fc1', reuse=reuse):\n                h2 = 
tf.layers.dense(z, units=4 * 8 * upsample_dim, activation=actv, kernel_initializer=init)  # cifar-10\n                h2 = tf.layers.batch_normalization(h2, **kwargs)\n                h2 = tf.reshape(h2, shape=[-1, 4, 8, upsample_dim])\n\n            # [batch_size, 8, 16, dim/2]\n            with tf.variable_scope('upsample1', reuse=reuse):\n                up1 = tf.layers.conv2d_transpose(h2, upsample_dim//2, kernel_size=kernel_size, strides=2, padding='same', activation=actv)\n                up1 = tf.layers.batch_normalization(up1, **kwargs)\n\n            # [batch_size, 16, 32, dim/4]\n            with tf.variable_scope('upsample2', reuse=reuse):\n                up2 = tf.layers.conv2d_transpose(up1, upsample_dim//4, kernel_size=kernel_size, strides=2, padding='same', activation=actv)\n                up2 = tf.layers.batch_normalization(up2, **kwargs)\n            \n            # [batch_size, 32, 64, dim/8]\n            with tf.variable_scope('upsample3', reuse=reuse):\n                up3 = tf.layers.conv2d_transpose(up2, upsample_dim//8, kernel_size=kernel_size, strides=2, padding='same', activation=actv)  # cifar-10\n                up3 = tf.layers.batch_normalization(up3, **kwargs)\n\n            with tf.variable_scope('conv_out', reuse=reuse):\n                out = tf.pad(up3, [[0, 0], [3, 3], [3, 3], [0, 0]], 'REFLECT')\n                out = tf.layers.conv2d(out, C, kernel_size=7, strides=1, padding='VALID')\n\n        return out\n\n    @staticmethod\n    def dcgan_discriminator(x, config, training, reuse=False, actv=tf.nn.relu):\n        # x is either generator output G(z) or drawn from the real data distribution\n        init =  tf.contrib.layers.xavier_initializer()\n        kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False}\n        print('Shape of x:', x.get_shape().as_list())\n        x = tf.reshape(x, shape=[-1, 32, 32, 3]) \n        # x = tf.reshape(x, shape=[-1, 28, 28, 1]) \n\n        with 
tf.variable_scope('discriminator', reuse=reuse):\n            with tf.variable_scope('conv1', reuse=reuse):\n                c1 = tf.layers.conv2d(x, 64, kernel_size=5, strides=2, padding='same', activation=actv)\n                c1 = tf.layers.batch_normalization(c1, **kwargs)\n\n            with tf.variable_scope('conv2', reuse=reuse):\n                c2 = tf.layers.conv2d(c1, 128, kernel_size=5, strides=2, padding='same', activation=actv)\n                c2 = tf.layers.batch_normalization(c2, **kwargs)\n\n            with tf.variable_scope('fc1', reuse=reuse):\n                fc1 = tf.contrib.layers.flatten(c2)\n                # fc1 = tf.reshape(c2, shape=[-1, 8 * 8 * 128])\n                fc1 = tf.layers.dense(fc1, units=1024, activation=actv, kernel_initializer=init)\n                fc1 = tf.layers.batch_normalization(fc1, **kwargs)\n            \n            with tf.variable_scope('out', reuse=reuse):\n                out = tf.layers.dense(fc1, units=2, activation=None, kernel_initializer=init)\n\n        return out\n        \n\n    @staticmethod\n    def critic_grande(x, config, training, reuse=False, actv=tf.nn.relu, kernel_size=5, gradient_penalty=True):\n        # x is either generator output G(z) or drawn from the real data distribution\n        init =  tf.contrib.layers.xavier_initializer()\n        kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False}\n        print('Shape of x:', x.get_shape().as_list())\n        x = tf.reshape(x, shape=[-1, 32, 32, 3]) \n        # x = tf.reshape(x, shape=[-1, 28, 28, 1]) \n\n        with tf.variable_scope('critic', reuse=reuse):\n            with tf.variable_scope('conv1', reuse=reuse):\n                c1 = tf.layers.conv2d(x, 64, kernel_size=kernel_size, strides=2, padding='same', activation=actv)\n                if gradient_penalty is False:\n                    c1 = tf.layers.batch_normalization(c1, **kwargs)\n\n            with tf.variable_scope('conv2', reuse=reuse):\n 
               c2 = tf.layers.conv2d(c1, 128, kernel_size=kernel_size, strides=2, padding='same', activation=actv)\n                if gradient_penalty is False:\n                    c2 = tf.layers.batch_normalization(c2, **kwargs)\n\n            with tf.variable_scope('conv3', reuse=reuse):\n                c3 = tf.layers.conv2d(c2, 256, kernel_size=kernel_size, strides=2, padding='same', activation=actv)\n                if gradient_penalty is False:\n                    c3 = tf.layers.batch_normalization(c3, **kwargs)\n\n            with tf.variable_scope('fc1', reuse=reuse):\n                fc1 = tf.contrib.layers.flatten(c3)\n                # fc1 = tf.reshape(c2, shape=[-1, 8 * 8 * 128])\n                fc1 = tf.layers.dense(fc1, units=1024, activation=actv, kernel_initializer=init)\n                #fc1 = tf.layers.batch_normalization(fc1, **kwargs)\n            \n            with tf.variable_scope('out', reuse=reuse):\n                out = tf.layers.dense(fc1, units=1, activation=None, kernel_initializer=init)\n\n        return out\n\n    @staticmethod\n    def wrn(x, config, training, reuse=False, actv=tf.nn.relu):\n        # Implements a wide residual network, see arXiv 1605.07146\n        # Depth = 6n+4, so n = 2 below gives WRN-16-10 (WRN-28-10 requires n = 4)\n        network_width = 10 # k\n        block_multiplicity = 2 # n\n\n        filters = [16, 16, 32, 64]\n        init = tf.contrib.layers.xavier_initializer()\n        kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':True}\n\n        def residual_block(x, n_filters, actv, keep_prob, training, project_shortcut=False, first_block=False):\n            init = tf.contrib.layers.xavier_initializer()\n            kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':True}\n\n            if project_shortcut:\n                strides = [2,2] if not first_block else [1,1]\n                identity_map = tf.layers.conv2d(x, filters=n_filters, kernel_size=[1,1],\n                                   
strides=strides, kernel_initializer=init, padding='same')\n                # identity_map = tf.layers.batch_normalization(identity_map, **kwargs)\n            else:\n                strides = [1,1]\n                identity_map = x\n\n            bn = tf.layers.batch_normalization(x, **kwargs)\n            conv = tf.layers.conv2d(bn, filters=n_filters, kernel_size=[3,3], activation=actv,\n                       strides=strides, kernel_initializer=init, padding='same')\n\n            bn = tf.layers.batch_normalization(conv, **kwargs)\n            do = tf.layers.dropout(bn, rate=1-keep_prob, training=training)\n\n            conv = tf.layers.conv2d(do, filters=n_filters, kernel_size=[3,3], activation=actv,\n                       kernel_initializer=init, padding='same')\n            out = tf.add(conv, identity_map)\n\n            return out\n\n        def residual_block_2(x, n_filters, actv, keep_prob, training, project_shortcut=False, first_block=False):\n            init = tf.contrib.layers.xavier_initializer()\n            kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':True}\n            prev_filters = x.get_shape().as_list()[-1]\n            if project_shortcut:\n                strides = [2,2] if not first_block else [1,1]\n                # identity_map = tf.layers.conv2d(x, filters=n_filters, kernel_size=[1,1],\n                #                   strides=strides, kernel_initializer=init, padding='same')\n                identity_map = tf.layers.average_pooling2d(x, strides, strides, 'valid')\n                identity_map = tf.pad(identity_map, \n                    tf.constant([[0,0],[0,0],[0,0],[(n_filters-prev_filters)//2, (n_filters-prev_filters)//2]]))\n                # identity_map = tf.layers.batch_normalization(identity_map, **kwargs)\n            else:\n                strides = [1,1]\n                identity_map = x\n\n            x = tf.layers.batch_normalization(x, **kwargs)\n            x = tf.nn.relu(x)\n     
       x = tf.layers.conv2d(x, filters=n_filters, kernel_size=[3,3], strides=strides,\n                    kernel_initializer=init, padding='same')\n\n            x = tf.layers.batch_normalization(x, **kwargs)\n            x = tf.nn.relu(x)\n            x = tf.layers.dropout(x, rate=1-keep_prob, training=training)\n\n            x = tf.layers.conv2d(x, filters=n_filters, kernel_size=[3,3],\n                       kernel_initializer=init, padding='same')\n            out = tf.add(x, identity_map)\n\n            return out\n\n        with tf.variable_scope('wrn_conv', reuse=reuse):\n            # Initial convolution --------------------------------------------->\n            with tf.variable_scope('conv0', reuse=reuse):\n                conv = tf.layers.conv2d(x, filters[0], kernel_size=[3,3], activation=actv,\n                                        kernel_initializer=init, padding='same')\n            # Residual group 1 ------------------------------------------------>\n            rb = conv\n            f1 = filters[1]*network_width\n            for n in range(block_multiplicity):\n                with tf.variable_scope('group1/{}'.format(n), reuse=reuse):\n                    project_shortcut = True if n==0 else False\n                    rb = residual_block(rb, f1, actv, project_shortcut=project_shortcut,\n                            keep_prob=config.conv_keep_prob, training=training, first_block=True)\n            # Residual group 2 ------------------------------------------------>\n            f2 = filters[2]*network_width\n            for n in range(block_multiplicity):\n                with tf.variable_scope('group2/{}'.format(n), reuse=reuse):\n                    project_shortcut = True if n==0 else False\n                    rb = residual_block(rb, f2, actv, project_shortcut=project_shortcut,\n                            keep_prob=config.conv_keep_prob, training=training)\n            # Residual group 3 ------------------------------------------------>\n  
          f3 = filters[3]*network_width\n            for n in range(block_multiplicity):\n                with tf.variable_scope('group3/{}'.format(n), reuse=reuse):\n                    project_shortcut = True if n==0 else False\n                    rb = residual_block(rb, f3, actv, project_shortcut=project_shortcut,\n                            keep_prob=config.conv_keep_prob, training=training)\n            # Avg pooling + output -------------------------------------------->\n            with tf.variable_scope('output', reuse=reuse):\n                bn = tf.nn.relu(tf.layers.batch_normalization(rb, **kwargs))\n                avp = tf.layers.average_pooling2d(bn, pool_size=[8,8], strides=[1,1], padding='valid')\n                flatten = tf.contrib.layers.flatten(avp)\n                out = tf.layers.dense(flatten, units=config.n_classes, kernel_initializer=init)\n\n            return out\n\n\n    @staticmethod\n    def old_encoder(x, config, training, C, reuse=False, actv=tf.nn.relu):\n        \"\"\"\n        Process image x ([512,1024]) into a feature map of size W/16 x H/16 x C\n         + C:       Bottleneck depth, controls bpp\n         + Output:  Projection onto C channels, C = {2,4,8,16}\n        \"\"\"\n        # proj_channels = [2,4,8,16]\n        init = tf.contrib.layers.xavier_initializer()\n\n        def conv_block(x, filters, kernel_size=[3,3], strides=2, padding='same', actv=actv, init=init):\n            in_kwargs = {'center':True, 'scale': True}\n            x = tf.layers.conv2d(x, filters, kernel_size, strides=strides, padding=padding, activation=None)\n            x = tf.contrib.layers.instance_norm(x, **in_kwargs)\n            x = actv(x)\n            return x\n                \n        with tf.variable_scope('encoder', reuse=reuse):\n\n            # Run convolutions\n            out = conv_block(x, kernel_size=3, strides=1, filters=160, actv=actv)\n            out = conv_block(out, kernel_size=[3,3], strides=2, filters=320, actv=actv)\n      
      out = conv_block(out, kernel_size=[3,3], strides=2, filters=480, actv=actv)\n            out = conv_block(out, kernel_size=[3,3], strides=2, filters=640, actv=actv)\n            out = conv_block(out, kernel_size=[3,3], strides=2, filters=800, actv=actv)\n\n            out = conv_block(out, kernel_size=3, strides=1, filters=960, actv=actv)\n            # Project channels onto lower-dimensional embedding space\n            W = tf.get_variable('W_channel_{}'.format(C), shape=[960,C], initializer=init)\n            feature_map = tf.einsum('ijkl,lm->ijkm', out, W)  # feature_map = tf.tensordot(out, W, axes=((3),(0)))\n            \n            # Feature maps have dimension W/16 x H/16 x C\n            return feature_map\n\n\n"
  },
  {
    "path": "samples/.gitignore",
    "content": "*\n!.gitignore\n\n"
  },
  {
    "path": "tensorboard/.gitignore",
    "content": "*\n!.gitignore\n"
  },
  {
    "path": "train.py",
    "content": "#!/usr/bin/python3\nimport tensorflow as tf\nimport numpy as np\nimport pandas as pd\nimport time, os, sys\nimport argparse\n\n# User-defined\nfrom network import Network\nfrom utils import Utils\nfrom data import Data\nfrom model import Model\nfrom config import config_train, directories\n\ntf.logging.set_verbosity(tf.logging.ERROR)\n\ndef train(config, args):\n\n    start_time = time.time()\n    G_loss_best, D_loss_best = float('inf'), float('inf')\n    ckpt = tf.train.get_checkpoint_state(directories.checkpoints)\n\n    # Load data\n    print('Training on dataset', args.dataset)\n    if config.use_conditional_GAN:\n        print('Using conditional GAN')\n        paths, semantic_map_paths = Data.load_dataframe(directories.train, load_semantic_maps=True)\n        test_paths, test_semantic_map_paths = Data.load_dataframe(directories.test, load_semantic_maps=True)\n    else:\n        paths = Data.load_dataframe(directories.train)\n        test_paths = Data.load_dataframe(directories.test)\n\n    # Build graph\n    gan = Model(config, paths, name=args.name, dataset=args.dataset)\n    saver = tf.train.Saver()\n\n    if config.use_conditional_GAN:\n        feed_dict_test_init = {gan.test_path_placeholder: test_paths, \n                               gan.test_semantic_map_path_placeholder: test_semantic_map_paths}\n        feed_dict_train_init = {gan.path_placeholder: paths,\n                                gan.semantic_map_path_placeholder: semantic_map_paths}\n    else:\n        feed_dict_test_init = {gan.test_path_placeholder: test_paths}\n        feed_dict_train_init = {gan.path_placeholder: paths}\n\n    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:\n        sess.run(tf.global_variables_initializer())\n        sess.run(tf.local_variables_initializer())\n        train_handle = sess.run(gan.train_iterator.string_handle())\n        test_handle = sess.run(gan.test_iterator.string_handle())\n\n     
   if args.restore_last and ckpt and ckpt.model_checkpoint_path:\n            # Continue training saved model\n            saver.restore(sess, ckpt.model_checkpoint_path)\n            print('{} restored.'.format(ckpt.model_checkpoint_path))\n        else:\n            if args.restore_path:\n                new_saver = tf.train.import_meta_graph('{}.meta'.format(args.restore_path))\n                new_saver.restore(sess, args.restore_path)\n                print('{} restored.'.format(args.restore_path))\n\n        sess.run(gan.test_iterator.initializer, feed_dict=feed_dict_test_init)\n\n        for epoch in range(config.num_epochs):\n\n            sess.run(gan.train_iterator.initializer, feed_dict=feed_dict_train_init)\n\n            # Run diagnostics\n            G_loss_best, D_loss_best = Utils.run_diagnostics(gan, config, directories, sess, saver, train_handle,\n                start_time, epoch, args.name, G_loss_best, D_loss_best)\n\n            while True:\n                try:\n                    # Update generator\n                    # for _ in range(8):\n                    feed_dict = {gan.training_phase: True, gan.handle: train_handle}\n                    sess.run(gan.G_train_op, feed_dict=feed_dict)\n\n                    # Update discriminator \n                    step, _ = sess.run([gan.D_global_step, gan.D_train_op], feed_dict=feed_dict)\n\n                    if step % config.diagnostic_steps == 0:\n                        G_loss_best, D_loss_best = Utils.run_diagnostics(gan, config, directories, sess, saver, train_handle,\n                            start_time, epoch, args.name, G_loss_best, D_loss_best)\n                        Utils.single_plot(epoch, step, sess, gan, train_handle, args.name, config)\n                        # for _ in range(4):\n                        #    sess.run(gan.G_train_op, feed_dict=feed_dict)\n\n\n                except tf.errors.OutOfRangeError:\n                    print('End of epoch!')\n                    break\n\n    
            except KeyboardInterrupt:\n                    save_path = saver.save(sess, os.path.join(directories.checkpoints,\n                        '{}_last.ckpt'.format(args.name)), global_step=epoch)\n                    print('Interrupted, model saved to: ', save_path)\n                    sys.exit()\n\n        save_path = saver.save(sess, os.path.join(directories.checkpoints,\n                               '{}_end.ckpt'.format(args.name)),\n                               global_step=epoch)\n\n    print(\"Training Complete. Model saved to file: {} Time elapsed: {:.3f} s\".format(save_path, time.time()-start_time))\n\ndef main(**kwargs):\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"-rl\", \"--restore_last\", help=\"restore last saved model\", action=\"store_true\")\n    parser.add_argument(\"-r\", \"--restore_path\", help=\"path to model to be restored\", type=str)\n    parser.add_argument(\"-opt\", \"--optimizer\", default=\"adam\", help=\"Selected optimizer\", type=str)\n    parser.add_argument(\"-name\", \"--name\", default=\"gan-train\", help=\"Checkpoint/Tensorboard label\")\n    parser.add_argument(\"-ds\", \"--dataset\", default=\"cityscapes\", help=\"choice of training dataset. Currently only supports cityscapes/ADE20k\", choices=set((\"cityscapes\", \"ADE20k\")), type=str)\n    args = parser.parse_args()\n\n    # Launch training\n    train(config_train, args)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "utils.py",
    "content": "# -*- coding: utf-8 -*-\n# Diagnostic helper functions for Tensorflow session\n\nimport tensorflow as tf\nimport numpy as np\nimport os, time\nimport matplotlib as mpl\nmpl.use('Agg')\nimport matplotlib.pyplot as plt\nimport seaborn as sns\n\nfrom config import directories\n\nclass Utils(object):\n    \n    @staticmethod\n    def conv_block(x, filters, kernel_size=[3,3], strides=2, padding='same', actv=tf.nn.relu):\n        in_kwargs = {'center':True, 'scale': True}\n        x = tf.layers.conv2d(x, filters, kernel_size, strides=strides, padding=padding, activation=None)\n        x = tf.contrib.layers.instance_norm(x, **in_kwargs)\n        x = actv(x)\n\n        return x\n\n    @staticmethod\n    def upsample_block(x, filters, kernel_size=[3,3], strides=2, padding='same', actv=tf.nn.relu):\n        in_kwargs = {'center':True, 'scale': True}\n        x = tf.layers.conv2d_transpose(x, filters, kernel_size, strides=strides, padding=padding, activation=None)\n        x = tf.contrib.layers.instance_norm(x, **in_kwargs)\n        x = actv(x)\n\n        return x\n\n    @staticmethod\n    def residual_block(x, n_filters, kernel_size=3, strides=1, actv=tf.nn.relu):\n        init = tf.contrib.layers.xavier_initializer()\n        # kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False}\n        strides = [1,1]\n        identity_map = x\n\n        p = int((kernel_size-1)/2)\n        res = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]], 'REFLECT')\n        res = tf.layers.conv2d(res, filters=n_filters, kernel_size=kernel_size, strides=strides,\n                activation=None, padding='VALID')\n        res = actv(tf.contrib.layers.instance_norm(res))\n\n        res = tf.pad(res, [[0, 0], [p, p], [p, p], [0, 0]], 'REFLECT')\n        res = tf.layers.conv2d(res, filters=n_filters, kernel_size=kernel_size, strides=strides,\n                activation=None, padding='VALID')\n        res = tf.contrib.layers.instance_norm(res)\n\n        
assert res.get_shape().as_list() == identity_map.get_shape().as_list(), 'Mismatched shapes between input/output!'\n        out = tf.add(res, identity_map)\n\n        return out\n\n    @staticmethod\n    def get_available_gpus():\n        from tensorflow.python.client import device_lib\n        local_device_protos = device_lib.list_local_devices()\n        #return local_device_protos\n        print('Available GPUs:')\n        print([x.name for x in local_device_protos if x.device_type == 'GPU'])\n\n    @staticmethod\n    def scope_variables(name):\n        with tf.variable_scope(name):\n            return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=tf.get_variable_scope().name)\n\n    @staticmethod\n    def run_diagnostics(model, config, directories, sess, saver, train_handle, start_time, epoch, name, G_loss_best, D_loss_best):\n        t0 = time.time()\n        improved = ''\n        sess.run(tf.local_variables_initializer())\n        feed_dict_test = {model.training_phase: False, model.handle: train_handle}\n\n        try:\n            G_loss, D_loss, summary = sess.run([model.G_loss, model.D_loss, model.merge_op], feed_dict=feed_dict_test)\n            model.train_writer.add_summary(summary)\n        except tf.errors.OutOfRangeError:\n            G_loss, D_loss = float('nan'), float('nan')\n\n        if G_loss < G_loss_best and D_loss < D_loss_best:\n            G_loss_best, D_loss_best = G_loss, D_loss\n            improved = '[*]'\n            if epoch>5:\n                save_path = saver.save(sess,\n                            os.path.join(directories.checkpoints_best, '{}_epoch{}.ckpt'.format(name, epoch)),\n                            global_step=epoch)\n                print('Graph saved to file: {}'.format(save_path))\n\n        if epoch % 5 == 0 and epoch > 5:\n            save_path = saver.save(sess, os.path.join(directories.checkpoints, '{}_epoch{}.ckpt'.format(name, epoch)), global_step=epoch)\n            print('Graph saved to file: 
{}'.format(save_path))\n\n        print('Epoch {} | Generator Loss: {:.3f} | Discriminator Loss: {:.3f} | Rate: {} examples/s ({:.2f} s) {}'.format(epoch, G_loss, D_loss, int(config.batch_size/(time.time()-t0)), time.time() - start_time, improved))\n\n        return G_loss_best, D_loss_best\n\n    @staticmethod\n    def single_plot(epoch, global_step, sess, model, handle, name, config, single_compress=False):\n\n        real = model.example\n        gen = model.reconstruction\n\n        # Fetch a real example and its reconstruction from the generator network.\n        r, g = sess.run([real, gen], feed_dict={model.training_phase:True, model.handle: handle})\n\n        images = list()\n\n        for im, imtype in zip([r,g], ['real', 'gen']):\n            im = (im + 1.0) / 2  # [-1,1] -> [0,1]\n            im = np.squeeze(im)\n            im = im[:,:,:3]\n            images.append(im)\n\n            # Uncomment to plot real and generated samples separately\n            # f = plt.figure()\n            # plt.imshow(im)\n            # plt.axis('off')\n            # f.savefig(\"{}/gan_compression_{}_epoch{}_step{}_{}.pdf\".format(directories.samples, name, epoch,\n            #                     global_step, imtype), format='pdf', dpi=720, bbox_inches='tight', pad_inches=0)\n            # plt.gcf().clear()\n            # plt.close(f)\n\n        comparison = np.hstack(images)\n        f = plt.figure()\n        plt.imshow(comparison)\n        plt.axis('off')\n        if single_compress:\n            f.savefig(name, format='pdf', dpi=720, bbox_inches='tight', pad_inches=0)\n        else:\n            f.savefig(\"{}/gan_compression_{}_epoch{}_step{}_comparison.pdf\".format(directories.samples, name, epoch,\n                global_step), format='pdf', dpi=720, bbox_inches='tight', pad_inches=0)\n        plt.gcf().clear()\n        plt.close(f)\n\n\n    @staticmethod\n    def weight_decay(weight_decay, var_label='DW'):\n        \"\"\"L2 weight decay loss.\"\"\"\n        costs = []\n 
       for var in tf.trainable_variables():\n            if var.op.name.find(r'{}'.format(var_label)) > 0:\n                costs.append(tf.nn.l2_loss(var))\n\n        return tf.multiply(weight_decay, tf.add_n(costs))\n\n"
  }
]