Repository: Justin-Tan/generative-compression Branch: master Commit: 38d34c853836 Files: 24 Total size: 117.5 KB Directory structure: gitextract_gpypn45i/ ├── .gitignore ├── LICENSE ├── README.md ├── cGAN/ │ ├── config.py │ ├── data.py │ ├── model.py │ ├── network.py │ ├── train.py │ └── utils.py ├── checkpoints/ │ └── .gitignore ├── compress.py ├── config.py ├── data/ │ ├── .gitignore │ ├── cityscapes_paths_test.h5 │ ├── cityscapes_paths_train.h5 │ ├── cityscapes_paths_val.h5 │ └── resize_cityscapes.sh ├── data.py ├── model.py ├── network.py ├── samples/ │ └── .gitignore ├── tensorboard/ │ └── .gitignore ├── train.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # Data #data/ #checkpoints/ #tensorboard/ #samples/ *.log *.slurm *.ipynb *.out # C extensions *.so # Distribution / packaging .Python env/ build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ .installed.cfg *.egg # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover .hypothesis/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # pyenv .python-version # celery beat schedule file celerybeat-schedule # SageMath parsed files *.sage.py # dotenv .env # virtualenv .venv venv/ ENV/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ # other .DS_Store ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2018 JTan Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: README.md ================================================ # generative-compression TensorFlow Implementation for learned compression of images using Generative Adversarial Networks. The method was developed by Agustsson et. al. in [Generative Adversarial Networks for Extreme Learned Image Compression](https://arxiv.org/abs/1804.02958). The proposed idea is very interesting and their approach is well-described. ![Results from authors using C=4 bottleneck channels, global compression without semantic maps on the Kodak dataset](images/authors/kodak_GC_C4.png) ----------------------------- ## Usage The code depends on [Tensorflow 1.8](https://github.com/tensorflow/tensorflow) ```bash # Clone $ git clone https://github.com/Justin-Tan/generative-compression.git $ cd generative-compression # To train, check command line arguments $ python3 train.py -h # Run $ python3 train.py -opt momentum --name my_network ``` Training is conducted with batch size 1 and reconstructed samples / tensorboard summaries will be periodically written every certain number of steps (default is 128). Checkpoints are saved every 10 epochs. To compress a single image: ```bash # Compress $ python3 compress.py -r /path/to/model/checkpoint -i /path/to/image -o path/to/output/image ``` The compressed image will be saved as a side-by-side comparison with the original image under the path specified in `directories.samples` in `config.py`. If you are using the provided pretrained model with noise sampling, retain the hyperparameters under `config_test` in `config.py`, otherwise the parameters during test time should match the parameters set during training. 
*Note:* If you're willing to pay higher bitrates in exchange for much higher perceptual quality, you may want to check out this implementation of ["High-Fidelity Generative Image Compression"](https://github.com/Justin-Tan/high-fidelity-generative-compression), which is in the same vein but operates in higher bitrate regimes. Furthermore, it is capable of working with images of arbitrary size and resolution. ## Results These globally compressed images are from the test split of the Cityscapes `leftImg8bit` dataset. The decoder seems to hallucinate greenery in buildings, and vice-versa. #### Global conditional compression: Multiscale discriminator + feature-matching losses, C=8 channels - (compression to 0.072 bpp) **Epoch 38** ![cityscapes_e38](images/results/cGAN_epoch38.png) **Epoch 44** ![cityscapes_e44](images/results/cGAN_epoch44.png) **Epoch 47** ![cityscapes_e47](images/results/cGAN_epoch47.png) **Epoch 48** ![cityscapes_e48](images/results/cGAN_epoch48.png) ``` Show quantized C=4,8,16 channels image comparison ``` | Generator Loss | Discriminator Loss | |-------|-------| |![gen_loss](images/results/generator_loss.png) | ![discriminator_loss](images/results/discriminator_loss.png) | ## Pretrained Model You can find the pretrained model for global compression with a channel bottleneck of `C = 8` (corresponding to a 0.072 bpp representation) below. The model was subject to the multiscale discriminator and feature matching losses. Noise is sampled from a 128-dim normal distribution, passed through a DCGAN-like generator and concatenated to the quantized image representation. The model was trained for 55 epochs on the train split of the [Cityscapes](https://www.cityscapes-dataset.com/) `leftImg8bit` dataset for the images and used the `gtFine` dataset for the corresponding semantic maps. This should work with the default settings under `config_test` in `config.py`. A pretrained model for global conditional compression with a `C=8` bottleneck is also included.
This model was trained for 50 epochs with the same losses as above. Reconstruction is conditioned on semantic label maps (see the `cGAN/` folder and 'Conditional GAN usage'). * [Noise sampling model](https://drive.google.com/open?id=1gy6NJqlxflLDI1g9Rsileva-8G1ifsEC) * [Conditional GAN model](https://drive.google.com/open?id=1L3G4l8IQukNrsf3hjHv5xRhpNE77TD2k) ** Warning: Tensorflow 1.3 was used to train the models, but it appears to load without problems on Tensorflow 1.8. Please raise an issue if you have any problems. ## Details / extensions The network architectures are based on the description provided in the appendix of the original paper, which is in turn based on the paper [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://cs.stanford.edu/people/jcjohns/eccv16/) The multiscale discriminator loss used was originally proposed in the project [High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs](https://tcwang0509.github.io/pix2pixHD/), consult `network.py` for the implementation. If you would like to add an extension you can create a new method under the `Network` class, e.g. ```python @staticmethod def my_generator(z, **kwargs): """ Inputs: z: sampled noise Returns: upsampled image """ return tf.random_normal([z.get_shape()[0], height, width, channels], seed=42) ``` To change hyperparameters/toggle features use the knobs in `config.py`. (Bad form maybe, but I find it easier than a 20-line `argparse` specification). ### Data / Setup Training was done using the [ADE 20k dataset](http://groups.csail.mit.edu/vision/datasets/ADE20K/) and the [Cityscapes leftImg8bit dataset](https://www.cityscapes-dataset.com/). In the former case images are rescaled to width `512` px, and in the latter images are [resampled to `[512 x 1024]` prior to training](https://www.imagemagick.org/script/command-line-options.php#resample). An example script for resampling using `Imagemagick` is provided under `data/`.
In each case, you will need to create a Pandas dataframe containing a single column: `path`, which holds the absolute/relative path to the images. This should be saved as a `HDF5` file, and you should provide the path to this under the `directories` class in `config.py`. Examples for the Cityscapes dataset are provided in the `data` directory. ### Conditional GAN usage The conditional GAN implementation for global compression is in the `cGAN` directory. The cGAN implementation appears to yield images with the highest image quality, but this implementation remains experimental. In this implementation generation is conditioned on the information in the semantic label map of the selected image. You will need to download the `gtFine` dataset of annotation maps and append a separate column `semantic_map_paths` to the Pandas dataframe pointing to the corresponding images from the `gtFine` dataset. ### Dependencies * Python 3.6 * [Pandas](https://pandas.pydata.org/) * [TensorFlow 1.8](https://github.com/tensorflow/tensorflow) ### Todo: * Incorporate GAN noise sampling into the reconstructed image. The authors state that this step is optional and that the sampled noise is combined with the quantized representation but don't provide further details. Currently the model samples from a normal distribution and upsamples this using a DCGAN-like generator (see `network.py`) to be concatenated with the quantized image representation `w_hat`, but this appears to substantially increase the 'hallucination factor' in the reconstructed images. * Integrate VGG loss. * Experiment with WGAN-GP. * Experiment with spectral normalization. * Experiment with different generator architectures with noise sampling. * Extend to selective compression using semantic maps (contributions welcome).
### Resources * [Generative Adversarial Networks for Extreme Learned Image Compression](https://data.vision.ee.ethz.ch/aeirikur/extremecompression/#publication) * [CycleGAN](https://arxiv.org/pdf/1703.10593.pdf) * [High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs](https://tcwang0509.github.io/pix2pixHD/) ## More Results #### Global compression: Noise sampling, multiscale discriminator + feature-matching losses, C=8 channels - Compression to 0.072 bbp ![cityscapes_e45](images/results/noiseE45.png) ![cityscapes_e47](images/results/cGANe47.png) ![cityscapes_e51](images/results/noiseE51.png) ![cityscapes_e53](images/results/noiseE53.png) ![cityscapes_e54](images/results/noiseE54.png) ![cityscapes_e55](images/results/noiseE55.png) ![cityscapes_e56](images/results/noiseE56.png) ================================================ FILE: cGAN/config.py ================================================ #!/usr/bin/env python3 class config_train(object): mode = 'gan-train' num_epochs = 512 batch_size = 1 ema_decay = 0.999 G_learning_rate = 2e-4 D_learning_rate = 2e-4 lr_decay_rate = 2e-5 momentum = 0.9 weight_decay = 5e-4 noise_dim = 128 optimizer = 'adam' kernel_size = 3 diagnostic_steps = 256 # WGAN gradient_penalty = True lambda_gp = 10 weight_clipping = False max_c = 1e-2 n_critic_iterations = 20 # Compression lambda_X = 12 channel_bottleneck = 16 sample_noise = False use_vanilla_GAN = False use_feature_matching_loss = True upsample_dim = 256 multiscale = True feature_matching_weight = 10 use_conditional_GAN = False class config_test(object): mode = 'gan-test' num_epochs = 512 batch_size = 1 ema_decay = 0.999 G_learning_rate = 2e-4 D_learning_rate = 2e-4 lr_decay_rate = 2e-5 momentum = 0.9 weight_decay = 5e-4 noise_dim = 128 optimizer = 'adam' kernel_size = 3 diagnostic_steps = 256 # WGAN gradient_penalty = True lambda_gp = 10 weight_clipping = False max_c = 1e-2 n_critic_iterations = 5 # Compression lambda_X = 12 channel_bottleneck = 8 
sample_noise = True use_vanilla_GAN = False use_feature_matching_loss = True upsample_dim = 256 multiscale = True feature_matching_weight = 10 use_conditional_GAN = False class directories(object): # train = 'data/ADE20K_paths_train.h5' # test = 'data/ADE20K_paths_test.h5' train = 'data/sm_cityscapes_paths_train.h5' test = 'data/cityscapes_paths_test.h5' val = 'data/cityscapes_paths_val.h5' tensorboard = 'tensorboard' checkpoints = 'checkpoints' checkpoints_best = 'checkpoints/best' samples = 'samples/cityscapes' ================================================ FILE: cGAN/data.py ================================================ #!/usr/bin/python3 import tensorflow as tf import numpy as np import pandas as pd from config import directories class Data(object): @staticmethod def load_dataframe(filename, load_semantic_maps=False): df = pd.read_hdf(filename, key='df').sample(frac=1).reset_index(drop=True) if load_semantic_maps: return df['path'].values, df['semantic_map_path'].values else: return df['path'].values @staticmethod def load_dataset(image_paths, batch_size, test=False, augment=False, downsample=False, training_dataset='cityscapes', use_conditional_GAN=False, **kwargs): def _augment(image): # On-the-fly data augmentation image = tf.image.random_brightness(image, max_delta=0.1) image = tf.image.random_contrast(image, 0.5, 1.5) image = tf.image.random_flip_left_right(image) return image def _parser(image_path, semantic_map_path=None): def _aspect_preserving_width_resize(image, width=512): height_i = tf.shape(image)[0] # width_i = tf.shape(image)[1] # ratio = tf.to_float(width_i) / tf.to_float(height_i) # new_height = tf.to_int32(tf.to_float(height_i) / ratio) new_height = height_i - tf.floormod(height_i, 16) return tf.image.resize_image_with_crop_or_pad(image, new_height, width) def _image_decoder(path): im = tf.image.decode_png(tf.read_file(path), channels=3) im = tf.image.convert_image_dtype(im, dtype=tf.float32) return 2 * im - 1 # [0,1] -> [-1,1] (tanh 
range) image = _image_decoder(image_path) # Explicitly set the shape if you want a sanity check # or if you are using your own custom dataset, otherwise # the model is shape-agnostic as it is fully convolutional # im.set_shape([512,1024,3]) # downscaled cityscapes if use_conditional_GAN: # Semantic map only enabled for cityscapes semantic_map = _image_decoder(semantic_map_path) if training_dataset == 'ADE20k': image = _aspect_preserving_width_resize(image) # im.set_shape([None,512,3]) if use_conditional_GAN: if training_dataset == 'ADE20k': raise NotImplementedError('Conditional generation not implemented for ADE20k dataset.') return image, semantic_map else: return image print('Training on', training_dataset) if use_conditional_GAN: dataset = tf.data.Dataset.from_tensor_slices((image_paths, kwargs['semantic_map_paths'])) else: dataset = tf.data.Dataset.from_tensor_slices(image_paths) dataset = dataset.map(_parser) dataset = dataset.shuffle(buffer_size=8) dataset = dataset.batch(batch_size) if test: dataset = dataset.repeat() return dataset @staticmethod def load_cGAN_dataset(image_paths, semantic_map_paths, batch_size, test=False, augment=False, downsample=False, training_dataset='cityscapes'): """ Load image dataset with semantic label maps for conditional GAN """ def _parser(image_path, semantic_map_path): def _aspect_preserving_width_resize(image, width=512): # If training on ADE20k height_i = tf.shape(image)[0] new_height = height_i - tf.floormod(height_i, 16) return tf.image.resize_image_with_crop_or_pad(image, new_height, width) def _image_decoder(path): im = tf.image.decode_png(tf.read_file(image_path), channels=3) im = tf.image.convert_image_dtype(im, dtype=tf.float32) return 2 * im - 1 # [0,1] -> [-1,1] (tanh range) image, semantic_map = _image_decoder(image_path), _image_decoder(semantic_map_path) print('Training on', training_dataset) if training_dataset is 'ADE20k': image = _aspect_preserving_width_resize(image) semantic_map = 
_aspect_preserving_width_resize(semantic_map) # im.set_shape([512,1024,3]) # downscaled cityscapes return image, semantic_map dataset = tf.data.Dataset.from_tensor_slices(image_paths, semantic_map_paths) dataset = dataset.map(_parser) dataset = dataset.shuffle(buffer_size=8) dataset = dataset.batch(batch_size) if test: dataset = dataset.repeat() return dataset @staticmethod def load_inference(filenames, labels, batch_size, resize=(32,32)): # Single image estimation over multiple stochastic forward passes def _preprocess_inference(image_path, label, resize=(32,32)): # Preprocess individual images during inference image_path = tf.squeeze(image_path) image = tf.image.decode_png(tf.read_file(image_path)) image = tf.image.convert_image_dtype(image, dtype=tf.float32) image = tf.image.per_image_standardization(image) image = tf.image.resize_images(image, size=resize) return image, label dataset = tf.data.Dataset.from_tensor_slices((filenames, labels)) dataset = dataset.map(_preprocess_inference) dataset = dataset.batch(batch_size) return dataset ================================================ FILE: cGAN/model.py ================================================ #!/usr/bin/python3 import tensorflow as tf import numpy as np import glob, time, os from network import Network from data import Data from config import directories from utils import Utils class Model(): def __init__(self, config, paths, dataset, name='gan_compression', evaluate=False): # Build the computational graph print('Building computational graph ...') self.G_global_step = tf.Variable(0, trainable=False) self.D_global_step = tf.Variable(0, trainable=False) self.handle = tf.placeholder(tf.string, shape=[]) self.training_phase = tf.placeholder(tf.bool) # >>> Data handling self.path_placeholder = tf.placeholder(paths.dtype, paths.shape) self.test_path_placeholder = tf.placeholder(paths.dtype) self.semantic_map_path_placeholder = tf.placeholder(paths.dtype, paths.shape) self.test_semantic_map_path_placeholder = 
tf.placeholder(paths.dtype) train_dataset = Data.load_dataset(self.path_placeholder, config.batch_size, augment=False, training_dataset=dataset, use_conditional_GAN=config.use_conditional_GAN, semantic_map_paths=self.semantic_map_path_placeholder) test_dataset = Data.load_dataset(self.test_path_placeholder, config.batch_size, augment=False, training_dataset=dataset, use_conditional_GAN=config.use_conditional_GAN, semantic_map_paths=self.test_semantic_map_path_placeholder, test=True) self.iterator = tf.data.Iterator.from_string_handle(self.handle, train_dataset.output_types, train_dataset.output_shapes) self.train_iterator = train_dataset.make_initializable_iterator() self.test_iterator = test_dataset.make_initializable_iterator() if config.use_conditional_GAN: self.example, self.semantic_map = self.iterator.get_next() else: self.example = self.iterator.get_next() # Global generator: Encode -> quantize -> reconstruct # =======================================================================================================>>> with tf.variable_scope('generator'): self.feature_map = Network.encoder(self.example, config, self.training_phase, config.channel_bottleneck) self.w_hat = Network.quantizer(self.feature_map, config) if config.use_conditional_GAN: self.semantic_feature_map = Network.encoder(self.semantic_map, config, self.training_phase, config.channel_bottleneck, scope='semantic_map') self.w_hat_semantic = Network.quantizer(self.semantic_feature_map, config, scope='semantic_map') self.w_hat = tf.concat([self.w_hat, self.w_hat_semantic], axis=-1) if config.sample_noise is True: print('Sampling noise...') # noise_prior = tf.contrib.distributions.Uniform(-1., 1.) 
# self.noise_sample = noise_prior.sample([tf.shape(self.example)[0], config.noise_dim]) noise_prior = tf.contrib.distributions.MultivariateNormalDiag(loc=tf.zeros([config.noise_dim]), scale_diag=tf.ones([config.noise_dim])) v = noise_prior.sample(tf.shape(self.example)[0]) Gv = Network.dcgan_generator(v, config, self.training_phase, C=config.channel_bottleneck, upsample_dim=config.upsample_dim) self.z = tf.concat([self.w_hat, Gv], axis=-1) else: self.z = self.w_hat self.reconstruction = Network.decoder(self.z, config, self.training_phase, C=config.channel_bottleneck) print('Real image shape:', self.example.get_shape().as_list()) print('Reconstruction shape:', self.reconstruction.get_shape().as_list()) # Pass generated, real images to discriminator # =======================================================================================================>>> if config.use_conditional_GAN: # Model conditional distribution self.example = tf.concat([self.example, self.semantic_map], axis=-1) self.reconstruction = tf.concat([self.reconstruction, self.semantic_map], axis=-1) if config.multiscale: D_x, D_x2, D_x4, *Dk_x = Network.multiscale_discriminator(self.example, config, self.training_phase, use_sigmoid=config.use_vanilla_GAN, mode='real') D_Gz, D_Gz2, D_Gz4, *Dk_Gz = Network.multiscale_discriminator(self.reconstruction, config, self.training_phase, use_sigmoid=config.use_vanilla_GAN, mode='reconstructed', reuse=True) else: D_x = Network.discriminator(self.example, config, self.training_phase, use_sigmoid=config.use_vanilla_GAN) D_Gz = Network.discriminator(self.reconstruction, config, self.training_phase, use_sigmoid=config.use_vanilla_GAN, reuse=True) # Loss terms # =======================================================================================================>>> if config.use_vanilla_GAN is True: # Minimize JS divergence D_loss_real = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=D_x, labels=tf.ones_like(D_x))) D_loss_gen = 
tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=D_Gz, labels=tf.zeros_like(D_Gz))) self.D_loss = D_loss_real + D_loss_gen # G_loss = max log D(G(z)) self.G_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=D_Gz, labels=tf.ones_like(D_Gz))) else: # Minimize $\chi^2$ divergence self.D_loss = tf.reduce_mean(tf.square(D_x - 1.)) + tf.reduce_mean(tf.square(D_Gz)) self.G_loss = tf.reduce_mean(tf.square(D_Gz - 1.)) if config.multiscale: self.D_loss += tf.reduce_mean(tf.square(D_x2 - 1.)) + tf.reduce_mean(tf.square(D_x4 - 1.)) self.D_loss += tf.reduce_mean(tf.square(D_Gz2)) + tf.reduce_mean(tf.square(D_Gz4)) distortion_penalty = config.lambda_X * tf.losses.mean_squared_error(self.example, self.reconstruction) self.G_loss += distortion_penalty if config.use_feature_matching_loss: # feature extractor for generator D_x_layers, D_Gz_layers = [j for i in Dk_x for j in i], [j for i in Dk_Gz for j in i] feature_matching_loss = tf.reduce_sum([tf.reduce_mean(tf.abs(Dkx-Dkz)) for Dkx, Dkz in zip(D_x_layers, D_Gz_layers)]) self.G_loss += config.feature_matching_weight * feature_matching_loss # Optimization # =======================================================================================================>>> G_opt = tf.train.AdamOptimizer(learning_rate=config.G_learning_rate, beta1=0.5) D_opt = tf.train.AdamOptimizer(learning_rate=config.D_learning_rate, beta1=0.5) theta_G = Utils.scope_variables('generator') theta_D = Utils.scope_variables('discriminator') print('Generator parameters:', theta_G) print('Discriminator parameters:', theta_D) G_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='generator') D_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='discriminator') # Execute the update_ops before performing the train_step with tf.control_dependencies(G_update_ops): self.G_opt_op = G_opt.minimize(self.G_loss, name='G_opt', global_step=self.G_global_step, var_list=theta_G) with 
tf.control_dependencies(D_update_ops): self.D_opt_op = D_opt.minimize(self.D_loss, name='D_opt', global_step=self.D_global_step, var_list=theta_D) G_ema = tf.train.ExponentialMovingAverage(decay=config.ema_decay, num_updates=self.G_global_step) G_maintain_averages_op = G_ema.apply(theta_G) D_ema = tf.train.ExponentialMovingAverage(decay=config.ema_decay, num_updates=self.D_global_step) D_maintain_averages_op = D_ema.apply(theta_D) with tf.control_dependencies(G_update_ops+[self.G_opt_op]): self.G_train_op = tf.group(G_maintain_averages_op) with tf.control_dependencies(D_update_ops+[self.D_opt_op]): self.D_train_op = tf.group(D_maintain_averages_op) # >>> Monitoring # tf.summary.scalar('learning_rate', learning_rate) tf.summary.scalar('generator_loss', self.G_loss) tf.summary.scalar('discriminator_loss', self.D_loss) tf.summary.scalar('distortion_penalty', distortion_penalty) if config.use_feature_matching_loss: tf.summary.scalar('feature_matching_loss', feature_matching_loss) tf.summary.scalar('G_global_step', self.G_global_step) tf.summary.scalar('D_global_step', self.D_global_step) tf.summary.image('real_images', self.example[:,:,:,:3], max_outputs=4) tf.summary.image('compressed_images', self.reconstruction[:,:,:,:3], max_outputs=4) if config.use_conditional_GAN: tf.summary.image('semantic_map', self.semantic_map, max_outputs=4) self.merge_op = tf.summary.merge_all() self.train_writer = tf.summary.FileWriter( os.path.join(directories.tensorboard, '{}_train_{}'.format(name, time.strftime('%d-%m_%I:%M'))), graph=tf.get_default_graph()) self.test_writer = tf.summary.FileWriter( os.path.join(directories.tensorboard, '{}_test_{}'.format(name, time.strftime('%d-%m_%I:%M')))) ================================================ FILE: cGAN/network.py ================================================ """ Modular components of computational graph JTan 2018 """ import tensorflow as tf from utils import Utils class Network(object): @staticmethod def encoder(x, config, training, 
C, reuse=False, actv=tf.nn.relu, scope='image'): """ Process image x ([512,1024]) into a feature map of size W/16 x H/16 x C + C: Bottleneck depth, controls bpp + Output: Projection onto C channels, C = {2,4,8,16} """ init = tf.contrib.layers.xavier_initializer() print('<------------ Building global {} generator architecture ------------>'.format(scope)) def conv_block(x, filters, kernel_size=[3,3], strides=2, padding='same', actv=actv, init=init): bn_kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False} in_kwargs = {'center':True, 'scale': True} x = tf.layers.conv2d(x, filters, kernel_size, strides=strides, padding=padding, activation=None) # x = tf.layers.batch_normalization(x, **bn_kwargs) x = tf.contrib.layers.instance_norm(x, **in_kwargs) x = actv(x) return x with tf.variable_scope('encoder_{}'.format(scope), reuse=reuse): # Run convolutions f = [60, 120, 240, 480, 960] x = tf.pad(x, [[0, 0], [3, 3], [3, 3], [0, 0]], 'REFLECT') out = conv_block(x, filters=f[0], kernel_size=7, strides=1, padding='VALID', actv=actv) out = conv_block(out, filters=f[1], kernel_size=3, strides=2, actv=actv) out = conv_block(out, filters=f[2], kernel_size=3, strides=2, actv=actv) out = conv_block(out, filters=f[3], kernel_size=3, strides=2, actv=actv) out = conv_block(out, filters=f[4], kernel_size=3, strides=2, actv=actv) # Project channels onto space w/ dimension C # Feature maps have dimension W/16 x H/16 x C out = tf.pad(out, [[0, 0], [1, 1], [1, 1], [0, 0]], 'REFLECT') feature_map = conv_block(out, filters=C, kernel_size=3, strides=1, padding='VALID', actv=actv) return feature_map @staticmethod def quantizer(w, config, reuse=False, temperature=1, L=5, scope='image'): """ Quantize feature map over L centers to obtain discrete $\hat{w}$ + Centers: {-2,-1,0,1,2} + TODO: Toggle learnable centers? 
""" with tf.variable_scope('quantizer_{}'.format(scope, reuse=reuse)): centers = tf.cast(tf.range(-2,3), tf.float32) # Partition W into the Voronoi tesellation over the centers w_stack = tf.stack([w for _ in range(L)], axis=-1) w_hard = tf.cast(tf.argmin(tf.abs(w_stack - centers), axis=-1), tf.float32) + tf.reduce_min(centers) smx = tf.nn.softmax(-1.0/temperature * tf.abs(w_stack - centers), dim=-1) # Contract last dimension w_soft = tf.einsum('ijklm,m->ijkl', smx, centers) # w_soft = tf.tensordot(smx, centers, axes=((-1),(0))) # Treat quantization as differentiable for optimization w_bar = tf.round(tf.stop_gradient(w_hard - w_soft) + w_soft) return w_bar @staticmethod def decoder(w_bar, config, training, C, reuse=False, actv=tf.nn.relu, channel_upsample=960): """ Attempt to reconstruct the image from the quantized representation w_bar. Generated image should be consistent with the true image distribution while recovering the specific encoded image + C: Bottleneck depth, controls bpp - last dimension of encoder output + TODO: Concatenate quantized w_bar with noise sampled from prior """ init = tf.contrib.layers.xavier_initializer() def residual_block(x, n_filters, kernel_size=3, strides=1, actv=actv): init = tf.contrib.layers.xavier_initializer() # kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False} strides = [1,1] identity_map = x p = int((kernel_size-1)/2) res = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]], 'REFLECT') res = tf.layers.conv2d(res, filters=n_filters, kernel_size=kernel_size, strides=strides, activation=None, padding='VALID') res = actv(tf.contrib.layers.instance_norm(res)) res = tf.pad(res, [[0, 0], [p, p], [p, p], [0, 0]], 'REFLECT') res = tf.layers.conv2d(res, filters=n_filters, kernel_size=kernel_size, strides=strides, activation=None, padding='VALID') res = tf.contrib.layers.instance_norm(res) assert res.get_shape().as_list() == identity_map.get_shape().as_list(), 'Mismatched shapes between input/output!' 
out = tf.add(res, identity_map) return out def upsample_block(x, filters, kernel_size=[3,3], strides=2, padding='same', actv=actv, batch_norm=False): bn_kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False} in_kwargs = {'center':True, 'scale': True} x = tf.layers.conv2d_transpose(x, filters, kernel_size, strides=strides, padding=padding, activation=None) if batch_norm is True: x = tf.layers.batch_normalization(x, **bn_kwargs) else: x = tf.contrib.layers.instance_norm(x, **in_kwargs) x = actv(x) return x # Project channel dimension of w_bar to higher dimension # W_pc = tf.get_variable('W_pc_{}'.format(C), shape=[C, channel_upsample], initializer=init) # upsampled = tf.einsum('ijkl,lm->ijkm', w_bar, W_pc) with tf.variable_scope('decoder', reuse=reuse): w_bar = tf.pad(w_bar, [[0, 0], [1, 1], [1, 1], [0, 0]], 'REFLECT') upsampled = Utils.conv_block(w_bar, filters=960, kernel_size=3, strides=1, padding='VALID', actv=actv) # Process upsampled feature map with residual blocks res = residual_block(upsampled, 960, actv=actv) res = residual_block(res, 960, actv=actv) res = residual_block(res, 960, actv=actv) res = residual_block(res, 960, actv=actv) res = residual_block(res, 960, actv=actv) res = residual_block(res, 960, actv=actv) res = residual_block(res, 960, actv=actv) res = residual_block(res, 960, actv=actv) res = residual_block(res, 960, actv=actv) # Upsample to original dimensions - mirror decoder f = [480, 240, 120, 60] ups = upsample_block(res, f[0], 3, strides=[2,2], padding='same') ups = upsample_block(ups, f[1], 3, strides=[2,2], padding='same') ups = upsample_block(ups, f[2], 3, strides=[2,2], padding='same') ups = upsample_block(ups, f[3], 3, strides=[2,2], padding='same') ups = tf.pad(ups, [[0, 0], [3, 3], [3, 3], [0, 0]], 'REFLECT') ups = tf.layers.conv2d(ups, 3, kernel_size=7, strides=1, padding='VALID') out = tf.nn.tanh(ups) return out @staticmethod def discriminator(x, config, training, reuse=False, 
actv=tf.nn.leaky_relu, use_sigmoid=False, ksize=4): # x is either generator output G(z) or drawn from the real data distribution # Patch-GAN discriminator based on arXiv 1711.11585 # bn_kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False} in_kwargs = {'center':True, 'scale':True, 'activation_fn':actv} print('Shape of x:', x.get_shape().as_list()) with tf.variable_scope('discriminator', reuse=reuse): c1 = tf.layers.conv2d(x, 64, kernel_size=ksize, strides=2, padding='same', activation=actv) c2 = tf.layers.conv2d(c1, 128, kernel_size=ksize, strides=2, padding='same') c2 = actv(tf.contrib.layers.instance_norm(c2, **in_kwargs)) c3 = tf.layers.conv2d(c2, 256, kernel_size=ksize, strides=2, padding='same') c3 = actv(tf.contrib.layers.instance_norm(c3, **in_kwargs)) c4 = tf.layers.conv2d(c3, 512, kernel_size=ksize, strides=2, padding='same') c4 = actv(tf.contrib.layers.instance_norm(c4, **in_kwargs)) out = tf.layers.conv2d(c4, 1, kernel_size=ksize, strides=1, padding='same') if use_sigmoid is True: # Otherwise use LS-GAN out = tf.nn.sigmoid(out) return out @staticmethod def multiscale_discriminator(x, config, training, actv=tf.nn.leaky_relu, use_sigmoid=False, ksize=4, mode='real', reuse=False): # x is either generator output G(z) or drawn from the real data distribution # Multiscale + Patch-GAN discriminator architecture based on arXiv 1711.11585 print('<------------ Building multiscale discriminator architecture ------------>') if mode == 'real': print('Building discriminator D(x)') elif mode == 'reconstructed': print('Building discriminator D(G(z))') else: raise NotImplementedError('Invalid discriminator mode specified.') # Downsample input x2 = tf.layers.average_pooling2d(x, pool_size=3, strides=2, padding='same') x4 = tf.layers.average_pooling2d(x2, pool_size=3, strides=2, padding='same') print('Shape of x:', x.get_shape().as_list()) print('Shape of x downsampled by factor 2:', x2.get_shape().as_list()) print('Shape of x 
downsampled by factor 4:', x4.get_shape().as_list()) def discriminator(x, scope, actv=actv, use_sigmoid=use_sigmoid, ksize=ksize, reuse=reuse): # Returns patch-GAN output + intermediate layers with tf.variable_scope('discriminator_{}'.format(scope), reuse=reuse): c1 = tf.layers.conv2d(x, 64, kernel_size=ksize, strides=2, padding='same', activation=actv) c2 = Utils.conv_block(c1, filters=128, kernel_size=ksize, strides=2, padding='same', actv=actv) c3 = Utils.conv_block(c2, filters=256, kernel_size=ksize, strides=2, padding='same', actv=actv) c4 = Utils.conv_block(c3, filters=512, kernel_size=ksize, strides=2, padding='same', actv=actv) out = tf.layers.conv2d(c4, 1, kernel_size=ksize, strides=1, padding='same') if use_sigmoid is True: # Otherwise use LS-GAN out = tf.nn.sigmoid(out) return out, c1, c2, c3, c4 with tf.variable_scope('discriminator', reuse=reuse): disc, *Dk = discriminator(x, 'original') disc_downsampled_2, *Dk_2 = discriminator(x2, 'downsampled_2') disc_downsampled_4, *Dk_4 = discriminator(x4, 'downsampled_4') return disc, disc_downsampled_2, disc_downsampled_4, Dk, Dk_2, Dk_4 @staticmethod def dcgan_generator(z, config, training, C, reuse=False, actv=tf.nn.relu, kernel_size=5, upsample_dim=256): """ Upsample noise to concatenate with quantized representation w_bar. 
+ z: Drawn from latent distribution - [batch_size, noise_dim] + C: Bottleneck depth, controls bpp - last dimension of encoder output """ init = tf.contrib.layers.xavier_initializer() kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False} with tf.variable_scope('noise_generator', reuse=reuse): # [batch_size, 4, 8, dim] with tf.variable_scope('fc1', reuse=reuse): h2 = tf.layers.dense(z, units=4 * 8 * upsample_dim, activation=actv, kernel_initializer=init) # cifar-10 h2 = tf.layers.batch_normalization(h2, **kwargs) h2 = tf.reshape(h2, shape=[-1, 4, 8, upsample_dim]) # [batch_size, 8, 16, dim/2] with tf.variable_scope('upsample1', reuse=reuse): up1 = tf.layers.conv2d_transpose(h2, upsample_dim//2, kernel_size=kernel_size, strides=2, padding='same', activation=actv) up1 = tf.layers.batch_normalization(up1, **kwargs) # [batch_size, 16, 32, dim/4] with tf.variable_scope('upsample2', reuse=reuse): up2 = tf.layers.conv2d_transpose(up1, upsample_dim//4, kernel_size=kernel_size, strides=2, padding='same', activation=actv) up2 = tf.layers.batch_normalization(up2, **kwargs) # [batch_size, 32, 64, dim/8] with tf.variable_scope('upsample3', reuse=reuse): up3 = tf.layers.conv2d_transpose(up2, upsample_dim//8, kernel_size=kernel_size, strides=2, padding='same', activation=actv) # cifar-10 up3 = tf.layers.batch_normalization(up3, **kwargs) with tf.variable_scope('conv_out', reuse=reuse): out = tf.pad(up3, [[0, 0], [3, 3], [3, 3], [0, 0]], 'REFLECT') out = tf.layers.conv2d(out, C, kernel_size=7, strides=1, padding='VALID') return out @staticmethod def dcgan_discriminator(x, config, training, reuse=False, actv=tf.nn.relu): # x is either generator output G(z) or drawn from the real data distribution init = tf.contrib.layers.xavier_initializer() kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False} print('Shape of x:', x.get_shape().as_list()) x = tf.reshape(x, shape=[-1, 32, 32, 3]) # x = tf.reshape(x, shape=[-1, 
28, 28, 1]) with tf.variable_scope('discriminator', reuse=reuse): with tf.variable_scope('conv1', reuse=reuse): c1 = tf.layers.conv2d(x, 64, kernel_size=5, strides=2, padding='same', activation=actv) c1 = tf.layers.batch_normalization(c1, **kwargs) with tf.variable_scope('conv2', reuse=reuse): c2 = tf.layers.conv2d(c1, 128, kernel_size=5, strides=2, padding='same', activation=actv) c2 = tf.layers.batch_normalization(c2, **kwargs) with tf.variable_scope('fc1', reuse=reuse): fc1 = tf.contrib.layers.flatten(c2) # fc1 = tf.reshape(c2, shape=[-1, 8 * 8 * 128]) fc1 = tf.layers.dense(fc1, units=1024, activation=actv, kernel_initializer=init) fc1 = tf.layers.batch_normalization(fc1, **kwargs) with tf.variable_scope('out', reuse=reuse): out = tf.layers.dense(fc1, units=2, activation=None, kernel_initializer=init) return out @staticmethod def critic_grande(x, config, training, reuse=False, actv=tf.nn.relu, kernel_size=5, gradient_penalty=True): # x is either generator output G(z) or drawn from the real data distribution init = tf.contrib.layers.xavier_initializer() kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False} print('Shape of x:', x.get_shape().as_list()) x = tf.reshape(x, shape=[-1, 32, 32, 3]) # x = tf.reshape(x, shape=[-1, 28, 28, 1]) with tf.variable_scope('critic', reuse=reuse): with tf.variable_scope('conv1', reuse=reuse): c1 = tf.layers.conv2d(x, 64, kernel_size=kernel_size, strides=2, padding='same', activation=actv) if gradient_penalty is False: c1 = tf.layers.batch_normalization(c1, **kwargs) with tf.variable_scope('conv2', reuse=reuse): c2 = tf.layers.conv2d(c1, 128, kernel_size=kernel_size, strides=2, padding='same', activation=actv) if gradient_penalty is False: c2 = tf.layers.batch_normalization(c2, **kwargs) with tf.variable_scope('conv3', reuse=reuse): c3 = tf.layers.conv2d(c2, 256, kernel_size=kernel_size, strides=2, padding='same', activation=actv) if gradient_penalty is False: c3 = 
tf.layers.batch_normalization(c3, **kwargs) with tf.variable_scope('fc1', reuse=reuse): fc1 = tf.contrib.layers.flatten(c3) # fc1 = tf.reshape(c2, shape=[-1, 8 * 8 * 128]) fc1 = tf.layers.dense(fc1, units=1024, activation=actv, kernel_initializer=init) #fc1 = tf.layers.batch_normalization(fc1, **kwargs) with tf.variable_scope('out', reuse=reuse): out = tf.layers.dense(fc1, units=1, activation=None, kernel_initializer=init) return out @staticmethod def wrn(x, config, training, reuse=False, actv=tf.nn.relu): # Implements W-28-10 wide residual network # See Arxiv 1605.07146 network_width = 10 # k block_multiplicity = 2 # n filters = [16, 16, 32, 64] init = tf.contrib.layers.xavier_initializer() kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':True} def residual_block(x, n_filters, actv, keep_prob, training, project_shortcut=False, first_block=False): init = tf.contrib.layers.xavier_initializer() kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':True} if project_shortcut: strides = [2,2] if not first_block else [1,1] identity_map = tf.layers.conv2d(x, filters=n_filters, kernel_size=[1,1], strides=strides, kernel_initializer=init, padding='same') # identity_map = tf.layers.batch_normalization(identity_map, **kwargs) else: strides = [1,1] identity_map = x bn = tf.layers.batch_normalization(x, **kwargs) conv = tf.layers.conv2d(bn, filters=n_filters, kernel_size=[3,3], activation=actv, strides=strides, kernel_initializer=init, padding='same') bn = tf.layers.batch_normalization(conv, **kwargs) do = tf.layers.dropout(bn, rate=1-keep_prob, training=training) conv = tf.layers.conv2d(do, filters=n_filters, kernel_size=[3,3], activation=actv, kernel_initializer=init, padding='same') out = tf.add(conv, identity_map) return out def residual_block_2(x, n_filters, actv, keep_prob, training, project_shortcut=False, first_block=False): init = tf.contrib.layers.xavier_initializer() kwargs = {'center':True, 'scale':True, 
'training':training, 'fused':True, 'renorm':True} prev_filters = x.get_shape().as_list()[-1] if project_shortcut: strides = [2,2] if not first_block else [1,1] # identity_map = tf.layers.conv2d(x, filters=n_filters, kernel_size=[1,1], # strides=strides, kernel_initializer=init, padding='same') identity_map = tf.layers.average_pooling2d(x, strides, strides, 'valid') identity_map = tf.pad(identity_map, tf.constant([[0,0],[0,0],[0,0],[(n_filters-prev_filters)//2, (n_filters-prev_filters)//2]])) # identity_map = tf.layers.batch_normalization(identity_map, **kwargs) else: strides = [1,1] identity_map = x x = tf.layers.batch_normalization(x, **kwargs) x = tf.nn.relu(x) x = tf.layers.conv2d(x, filters=n_filters, kernel_size=[3,3], strides=strides, kernel_initializer=init, padding='same') x = tf.layers.batch_normalization(x, **kwargs) x = tf.nn.relu(x) x = tf.layers.dropout(x, rate=1-keep_prob, training=training) x = tf.layers.conv2d(x, filters=n_filters, kernel_size=[3,3], kernel_initializer=init, padding='same') out = tf.add(x, identity_map) return out with tf.variable_scope('wrn_conv', reuse=reuse): # Initial convolution ---------------------------------------------> with tf.variable_scope('conv0', reuse=reuse): conv = tf.layers.conv2d(x, filters[0], kernel_size=[3,3], activation=actv, kernel_initializer=init, padding='same') # Residual group 1 ------------------------------------------------> rb = conv f1 = filters[1]*network_width for n in range(block_multiplicity): with tf.variable_scope('group1/{}'.format(n), reuse=reuse): project_shortcut = True if n==0 else False rb = residual_block(rb, f1, actv, project_shortcut=project_shortcut, keep_prob=config.conv_keep_prob, training=training, first_block=True) # Residual group 2 ------------------------------------------------> f2 = filters[2]*network_width for n in range(block_multiplicity): with tf.variable_scope('group2/{}'.format(n), reuse=reuse): project_shortcut = True if n==0 else False rb = residual_block(rb, f2, 
actv, project_shortcut=project_shortcut, keep_prob=config.conv_keep_prob, training=training) # Residual group 3 ------------------------------------------------> f3 = filters[3]*network_width for n in range(block_multiplicity): with tf.variable_scope('group3/{}'.format(n), reuse=reuse): project_shortcut = True if n==0 else False rb = residual_block(rb, f3, actv, project_shortcut=project_shortcut, keep_prob=config.conv_keep_prob, training=training) # Avg pooling + output --------------------------------------------> with tf.variable_scope('output', reuse=reuse): bn = tf.nn.relu(tf.layers.batch_normalization(rb, **kwargs)) avp = tf.layers.average_pooling2d(bn, pool_size=[8,8], strides=[1,1], padding='valid') flatten = tf.contrib.layers.flatten(avp) out = tf.layers.dense(flatten, units=config.n_classes, kernel_initializer=init) return out @staticmethod def old_encoder(x, config, training, C, reuse=False, actv=tf.nn.relu): """ Process image x ([512,1024]) into a feature map of size W/16 x H/16 x C + C: Bottleneck depth, controls bpp + Output: Projection onto C channels, C = {2,4,8,16} """ # proj_channels = [2,4,8,16] init = tf.contrib.layers.xavier_initializer() def conv_block(x, filters, kernel_size=[3,3], strides=2, padding='same', actv=actv, init=init): in_kwargs = {'center':True, 'scale': True} x = tf.layers.conv2d(x, filters, kernel_size, strides=strides, padding=padding, activation=None) x = tf.contrib.layers.instance_norm(x, **in_kwargs) x = actv(x) return x with tf.variable_scope('encoder', reuse=reuse): # Run convolutions out = conv_block(x, kernel_size=3, strides=1, filters=160, actv=actv) out = conv_block(out, kernel_size=[3,3], strides=2, filters=320, actv=actv) out = conv_block(out, kernel_size=[3,3], strides=2, filters=480, actv=actv) out = conv_block(out, kernel_size=[3,3], strides=2, filters=640, actv=actv) out = conv_block(out, kernel_size=[3,3], strides=2, filters=800, actv=actv) out = conv_block(out, kernel_size=3, strides=1, filters=960, actv=actv) 
# Project channels onto lower-dimensional embedding space W = tf.get_variable('W_channel_{}'.format(C), shape=[960,C], initializer=init) feature_map = tf.einsum('ijkl,lm->ijkm', out, W) # feature_map = tf.tensordot(out, W, axes=((3),(0))) # Feature maps have dimension W/16 x H/16 x C return feature_map ================================================ FILE: cGAN/train.py ================================================ #!/usr/bin/python3 import tensorflow as tf import numpy as np import pandas as pd import time, os, sys import argparse # User-defined from network import Network from utils import Utils from data import Data from model import Model from config import config_train, directories tf.logging.set_verbosity(tf.logging.ERROR) def train(config, args): start_time = time.time() G_loss_best, D_loss_best = float('inf'), float('inf') ckpt = tf.train.get_checkpoint_state(directories.checkpoints) # Load data print('Training on dataset', args.dataset) if config.use_conditional_GAN: print('Using conditional GAN') paths, semantic_map_paths = Data.load_dataframe(directories.train, load_semantic_maps=True) test_paths, test_semantic_map_paths = Data.load_dataframe(directories.test, load_semantic_maps=True) else: paths = Data.load_dataframe(directories.train) test_paths = Data.load_dataframe(directories.test) # Build graph gan = Model(config, paths, name=args.name, dataset=args.dataset) saver = tf.train.Saver() if config.use_conditional_GAN: feed_dict_test_init = {gan.test_path_placeholder: test_paths, gan.test_semantic_map_path_placeholder: test_semantic_map_paths} feed_dict_train_init = {gan.path_placeholder: paths, gan.semantic_map_path_placeholder: semantic_map_paths} else: feed_dict_test_init = {gan.test_path_placeholder: test_paths} feed_dict_train_init = {gan.path_placeholder: paths} with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess: sess.run(tf.global_variables_initializer()) 
sess.run(tf.local_variables_initializer()) train_handle = sess.run(gan.train_iterator.string_handle()) test_handle = sess.run(gan.test_iterator.string_handle()) if args.restore_last and ckpt.model_checkpoint_path: # Continue training saved model saver.restore(sess, ckpt.model_checkpoint_path) print('{} restored.'.format(ckpt.model_checkpoint_path)) else: if args.restore_path: new_saver = tf.train.import_meta_graph('{}.meta'.format(args.restore_path)) new_saver.restore(sess, args.restore_path) print('{} restored.'.format(args.restore_path)) sess.run(gan.test_iterator.initializer, feed_dict=feed_dict_test_init) for epoch in range(config.num_epochs): sess.run(gan.train_iterator.initializer, feed_dict=feed_dict_train_init) # Run diagnostics G_loss_best, D_loss_best = Utils.run_diagnostics(gan, config, directories, sess, saver, train_handle, start_time, epoch, args.name, G_loss_best, D_loss_best) while True: try: # Update generator # for _ in range(8): feed_dict = {gan.training_phase: True, gan.handle: train_handle} sess.run(gan.G_train_op, feed_dict=feed_dict) # Update discriminator step, _ = sess.run([gan.D_global_step, gan.D_train_op], feed_dict=feed_dict) if step % config.diagnostic_steps == 0: G_loss_best, D_loss_best = Utils.run_diagnostics(gan, config, directories, sess, saver, train_handle, start_time, epoch, args.name, G_loss_best, D_loss_best) Utils.single_plot(epoch, step, sess, gan, train_handle, args.name, config) # for _ in range(4): # sess.run(gan.G_opt_op, feed_dict=feed_dict) except tf.errors.OutOfRangeError: print('End of epoch!') break except KeyboardInterrupt: save_path = saver.save(sess, os.path.join(directories.checkpoints, '{}_last.ckpt'.format(args.name)), global_step=epoch) print('Interrupted, model saved to: ', save_path) sys.exit() save_path = saver.save(sess, os.path.join(directories.checkpoints, '{}_end.ckpt'.format(args.name)), global_step=epoch) print("Training Complete. 
Model saved to file: {} Time elapsed: {:.3f} s".format(save_path, time.time()-start_time)) def main(**kwargs): parser = argparse.ArgumentParser() parser.add_argument("-rl", "--restore_last", help="restore last saved model", action="store_true") parser.add_argument("-r", "--restore_path", help="path to model to be restored", type=str) parser.add_argument("-opt", "--optimizer", default="adam", help="Selected optimizer", type=str) parser.add_argument("-name", "--name", default="gan-train", help="Checkpoint/Tensorboard label") parser.add_argument("-ds", "--dataset", default="cityscapes", help="choice of training dataset. Currently only supports cityscapes/ADE20k", choices=set(("cityscapes", "ADE20k")), type=str) args = parser.parse_args() # Launch training train(config_train, args) if __name__ == '__main__': main() ================================================ FILE: cGAN/utils.py ================================================ # -*- coding: utf-8 -*- # Diagnostic helper functions for Tensorflow session import tensorflow as tf import numpy as np import os, time import matplotlib as mpl mpl.use('Agg') import matplotlib.pyplot as plt import seaborn as sns from config import directories class Utils(object): @staticmethod def conv_block(x, filters, kernel_size=[3,3], strides=2, padding='same', actv=tf.nn.relu): in_kwargs = {'center':True, 'scale': True} x = tf.layers.conv2d(x, filters, kernel_size, strides=strides, padding=padding, activation=None) x = tf.contrib.layers.instance_norm(x, **in_kwargs) x = actv(x) return x @staticmethod def upsample_block(x, filters, kernel_size=[3,3], strides=2, padding='same', actv=tf.nn.relu): in_kwargs = {'center':True, 'scale': True} x = tf.layers.conv2d_transpose(x, filters, kernel_size, strides=strides, padding=padding, activation=None) x = tf.contrib.layers.instance_norm(x, **in_kwargs) x = actv(x) return x @staticmethod def residual_block(x, n_filters, kernel_size=3, strides=1, actv=tf.nn.relu): init = 
tf.contrib.layers.xavier_initializer() # kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False} strides = [1,1] identity_map = x p = int((kernel_size-1)/2) res = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]], 'REFLECT') res = tf.layers.conv2d(res, filters=n_filters, kernel_size=kernel_size, strides=strides, activation=None, padding='VALID') res = actv(tf.contrib.layers.instance_norm(res)) res = tf.pad(res, [[0, 0], [p, p], [p, p], [0, 0]], 'REFLECT') res = tf.layers.conv2d(res, filters=n_filters, kernel_size=kernel_size, strides=strides, activation=None, padding='VALID') res = tf.contrib.layers.instance_norm(res) assert res.get_shape().as_list() == identity_map.get_shape().as_list(), 'Mismatched shapes between input/output!' out = tf.add(res, identity_map) return out @staticmethod def get_available_gpus(): from tensorflow.python.client import device_lib local_device_protos = device_lib.list_local_devices() #return local_device_protos print('Available GPUs:') print([x.name for x in local_device_protos if x.device_type == 'GPU']) @staticmethod def scope_variables(name): with tf.variable_scope(name): return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=tf.get_variable_scope().name) @staticmethod def run_diagnostics(model, config, directories, sess, saver, train_handle, start_time, epoch, name, G_loss_best, D_loss_best): t0 = time.time() improved = '' sess.run(tf.local_variables_initializer()) feed_dict_test = {model.training_phase: False, model.handle: train_handle} try: G_loss, D_loss, summary = sess.run([model.G_loss, model.D_loss, model.merge_op], feed_dict=feed_dict_test) model.train_writer.add_summary(summary) except tf.errors.OutOfRangeError: G_loss, D_loss = float('nan'), float('nan') if G_loss < G_loss_best and D_loss < D_loss_best: G_loss_best, D_loss_best = G_loss, D_loss improved = '[*]' if epoch>5: save_path = saver.save(sess, os.path.join(directories.checkpoints_best, '{}_epoch{}.ckpt'.format(name, epoch)), 
global_step=epoch) print('Graph saved to file: {}'.format(save_path)) if epoch % 5 == 0 and epoch > 5: save_path = saver.save(sess, os.path.join(directories.checkpoints, '{}_epoch{}.ckpt'.format(name, epoch)), global_step=epoch) print('Graph saved to file: {}'.format(save_path)) print('Epoch {} | Generator Loss: {:.3f} | Discriminator Loss: {:.3f} | Rate: {} examples/s ({:.2f} s) {}'.format(epoch, G_loss, D_loss, int(config.batch_size/(time.time()-t0)), time.time() - start_time, improved)) return G_loss_best, D_loss_best @staticmethod def single_plot(epoch, global_step, sess, model, handle, name, config): real = model.example gen = model.reconstruction # Generate images from noise, using the generator network. r, g = sess.run([real, gen], feed_dict={model.training_phase:True, model.handle: handle}) images = list() for im, imtype in zip([r,g], ['real', 'gen']): im = ((im+1.0))/2 # [-1,1] -> [0,1] im = np.squeeze(im) im = im[:,:,:3] images.append(im) # Uncomment to plot real and generated samples separately # f = plt.figure() # plt.imshow(im) # plt.axis('off') # f.savefig("{}/gan_compression_{}_epoch{}_step{}_{}.pdf".format(directories.samples, name, epoch, # global_step, imtype), format='pdf', dpi=720, bbox_inches='tight', pad_inches=0) # plt.gcf().clear() # plt.close(f) comparison = np.hstack(images) f = plt.figure() plt.imshow(comparison) plt.axis('off') f.savefig("{}/gan_compression_{}_epoch{}_step{}_{}_comparison.pdf".format(directories.samples, name, epoch, global_step, imtype), format='pdf', dpi=720, bbox_inches='tight', pad_inches=0) plt.gcf().clear() plt.close(f) @staticmethod def weight_decay(weight_decay, var_label='DW'): """L2 weight decay loss.""" costs = [] for var in tf.trainable_variables(): if var.op.name.find(r'{}'.format(var_label)) > 0: costs.append(tf.nn.l2_loss(var)) return tf.multiply(weight_decay, tf.add_n(costs)) ================================================ FILE: checkpoints/.gitignore ================================================ * 
!.gitignore !best/ ================================================ FILE: compress.py ================================================ #!/usr/bin/python3 import tensorflow as tf import numpy as np import pandas as pd import time, os, sys import argparse # User-defined from network import Network from utils import Utils from data import Data from model import Model from config import config_test, directories tf.logging.set_verbosity(tf.logging.ERROR) def single_compress(config, args): start = time.time() ckpt = tf.train.get_checkpoint_state(directories.checkpoints) assert (ckpt.model_checkpoint_path), 'Missing checkpoint file!' if config.use_conditional_GAN: print('Using conditional GAN') paths, semantic_map_paths = np.array([args.image_path]), np.array([args.semantic_map_path]) else: paths = np.array([args.image_path]) gan = Model(config, paths, name='single_compress', dataset=args.dataset, evaluate=True) saver = tf.train.Saver() if config.use_conditional_GAN: feed_dict_init = {gan.path_placeholder: paths, gan.semantic_map_path_placeholder: semantic_map_paths} else: feed_dict_init = {gan.path_placeholder: paths} with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess: # Initialize variables sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) handle = sess.run(gan.train_iterator.string_handle()) if args.restore_last and ckpt.model_checkpoint_path: saver.restore(sess, ckpt.model_checkpoint_path) print('Most recent {} restored.'.format(ckpt.model_checkpoint_path)) else: if args.restore_path: new_saver = tf.train.import_meta_graph('{}.meta'.format(args.restore_path)) new_saver.restore(sess, args.restore_path) print('Previous checkpoint {} restored.'.format(args.restore_path)) sess.run(gan.train_iterator.initializer, feed_dict=feed_dict_init) eval_dict = {gan.training_phase: False, gan.handle: handle} if args.output_path is None: output = os.path.splitext(os.path.basename(args.image_path)) 
save_path = os.path.join(directories.samples, '{}_compressed.pdf'.format(output[0])) else: save_path = args.output_path Utils.single_plot(0, 0, sess, gan, handle, save_path, config, single_compress=True) print('Reconstruction saved to', save_path) return def main(**kwargs): parser = argparse.ArgumentParser() parser.add_argument("-rl", "--restore_last", help="restore last saved model", action="store_true") parser.add_argument("-r", "--restore_path", help="path to model to be restored", type=str) parser.add_argument("-i", "--image_path", help="path to image to compress", type=str) parser.add_argument("-sm", "--semantic_map_path", help="path to corresponding semantic map", type=str) parser.add_argument("-o", "--output_path", help="path to output image", type=str) parser.add_argument("-ds", "--dataset", default="cityscapes", help="choice of training dataset. Currently only supports cityscapes/ADE20k", choices=set(("cityscapes", "ADE20k")), type=str) args = parser.parse_args() # Launch training single_compress(config_test, args) if __name__ == '__main__': main() ================================================ FILE: config.py ================================================ #!/usr/bin/env python3 class config_train(object): mode = 'gan-train' num_epochs = 512 batch_size = 1 ema_decay = 0.999 G_learning_rate = 2e-4 D_learning_rate = 2e-4 lr_decay_rate = 2e-5 momentum = 0.9 weight_decay = 5e-4 noise_dim = 128 optimizer = 'adam' kernel_size = 3 diagnostic_steps = 256 # WGAN gradient_penalty = True lambda_gp = 10 weight_clipping = False max_c = 1e-2 n_critic_iterations = 20 # Compression lambda_X = 12 channel_bottleneck = 8 sample_noise = True use_vanilla_GAN = False use_feature_matching_loss = True upsample_dim = 256 multiscale = True feature_matching_weight = 10 use_conditional_GAN = False class config_test(object): mode = 'gan-test' num_epochs = 512 batch_size = 1 ema_decay = 0.999 G_learning_rate = 2e-4 D_learning_rate = 2e-4 lr_decay_rate = 2e-5 momentum = 0.9 
    weight_decay = 5e-4
    noise_dim = 128
    optimizer = 'adam'
    kernel_size = 3
    diagnostic_steps = 256

    # WGAN
    gradient_penalty = True
    lambda_gp = 10
    weight_clipping = False
    max_c = 1e-2
    n_critic_iterations = 5

    # Compression
    lambda_X = 12
    # channel_bottleneck controls the bpp of the compressed representation
    channel_bottleneck = 8
    sample_noise = True
    use_vanilla_GAN = False
    use_feature_matching_loss = True
    upsample_dim = 256
    multiscale = True
    feature_matching_weight = 10
    use_conditional_GAN = False


class directories(object):
    # Centralized filesystem layout used across train/compress scripts
    train = 'data/cityscapes_paths_train.h5'
    test = 'data/cityscapes_paths_test.h5'
    val = 'data/cityscapes_paths_val.h5'
    tensorboard = 'tensorboard'
    checkpoints = 'checkpoints'
    checkpoints_best = 'checkpoints/best'
    samples = 'samples/cityscapes'


================================================
FILE: data/.gitignore
================================================
*
!.gitignore
!resize_cityscapes.sh
!cityscapes_paths_train.h5
!cityscapes_paths_test.h5
!cityscapes_paths_val.h5


================================================
FILE: data/resize_cityscapes.sh
================================================
#!/bin/bash
# Author: Grace Han
# In place resampling to 512 x 1024 px
# Requires imagemagick on a *nix system
# Modify according to your directory structure

for f in ./**/*.png; do
 convert $f -resize 1024x512 $f
done


================================================
FILE: data.py
================================================
#!/usr/bin/python3
import tensorflow as tf
import numpy as np
import pandas as pd

from config import directories

class Data(object):

    @staticmethod
    def load_dataframe(filename, load_semantic_maps=False):
        # Reads the HDF5 path index, shuffles it, and returns numpy arrays of
        # image paths (and semantic-map paths when requested).
        df = pd.read_hdf(filename, key='df').sample(frac=1).reset_index(drop=True)

        if load_semantic_maps:
            return df['path'].values, df['semantic_map_path'].values
        else:
            return df['path'].values

    @staticmethod
    def load_dataset(image_paths, batch_size, test=False, augment=False,
                     downsample=False, training_dataset='cityscapes',
                     use_conditional_GAN=False, **kwargs):
        # Builds a tf.data pipeline that decodes PNGs to [-1,1] float tensors.
        # NOTE(review): the 'augment' and 'downsample' flags are accepted but
        # _augment is never applied in the pipeline - confirm intent.

        def _augment(image):
            # On-the-fly data augmentation
image = tf.image.random_brightness(image, max_delta=0.1) image = tf.image.random_contrast(image, 0.5, 1.5) image = tf.image.random_flip_left_right(image) return image def _parser(image_path, semantic_map_path=None): def _aspect_preserving_width_resize(image, width=512): height_i = tf.shape(image)[0] # width_i = tf.shape(image)[1] # ratio = tf.to_float(width_i) / tf.to_float(height_i) # new_height = tf.to_int32(tf.to_float(height_i) / ratio) new_height = height_i - tf.floormod(height_i, 16) return tf.image.resize_image_with_crop_or_pad(image, new_height, width) def _image_decoder(path): im = tf.image.decode_png(tf.read_file(path), channels=3) im = tf.image.convert_image_dtype(im, dtype=tf.float32) return 2 * im - 1 # [0,1] -> [-1,1] (tanh range) image = _image_decoder(image_path) # Explicitly set the shape if you want a sanity check # or if you are using your own custom dataset, otherwise # the model is shape-agnostic as it is fully convolutional # im.set_shape([512,1024,3]) # downscaled cityscapes if use_conditional_GAN: # Semantic map only enabled for cityscapes semantic_map = _image_decoder(semantic_map_path) if training_dataset == 'ADE20k': image = _aspect_preserving_width_resize(image) if use_conditional_GAN: semantic_map = _aspect_preserving_width_resize(semantic_map) # im.set_shape([None,512,3]) if use_conditional_GAN: return image, semantic_map else: return image print('Training on', training_dataset) if use_conditional_GAN: dataset = tf.data.Dataset.from_tensor_slices((image_paths, kwargs['semantic_map_paths'])) else: dataset = tf.data.Dataset.from_tensor_slices(image_paths) dataset = dataset.shuffle(buffer_size=8) dataset = dataset.map(_parser) dataset = dataset.cache() dataset = dataset.batch(batch_size) if test: dataset = dataset.repeat() return dataset @staticmethod def load_cGAN_dataset(image_paths, semantic_map_paths, batch_size, test=False, augment=False, downsample=False, training_dataset='cityscapes'): """ Load image dataset with semantic label 
maps for conditional GAN """ def _parser(image_path, semantic_map_path): def _aspect_preserving_width_resize(image, width=512): # If training on ADE20k height_i = tf.shape(image)[0] new_height = height_i - tf.floormod(height_i, 16) return tf.image.resize_image_with_crop_or_pad(image, new_height, width) def _image_decoder(path): im = tf.image.decode_png(tf.read_file(image_path), channels=3) im = tf.image.convert_image_dtype(im, dtype=tf.float32) return 2 * im - 1 # [0,1] -> [-1,1] (tanh range) image, semantic_map = _image_decoder(image_path), _image_decoder(semantic_map_path) print('Training on', training_dataset) if training_dataset is 'ADE20k': image = _aspect_preserving_width_resize(image) semantic_map = _aspect_preserving_width_resize(semantic_map) # im.set_shape([512,1024,3]) # downscaled cityscapes return image, semantic_map dataset = tf.data.Dataset.from_tensor_slices(image_paths, semantic_map_paths) dataset = dataset.map(_parser) dataset = dataset.shuffle(buffer_size=8) dataset = dataset.batch(batch_size) if test: dataset = dataset.repeat() return dataset @staticmethod def load_inference(filenames, labels, batch_size, resize=(32,32)): # Single image estimation over multiple stochastic forward passes def _preprocess_inference(image_path, label, resize=(32,32)): # Preprocess individual images during inference image_path = tf.squeeze(image_path) image = tf.image.decode_png(tf.read_file(image_path)) image = tf.image.convert_image_dtype(image, dtype=tf.float32) image = tf.image.per_image_standardization(image) image = tf.image.resize_images(image, size=resize) return image, label dataset = tf.data.Dataset.from_tensor_slices((filenames, labels)) dataset = dataset.map(_preprocess_inference) dataset = dataset.batch(batch_size) return dataset ================================================ FILE: model.py ================================================ #!/usr/bin/python3 import tensorflow as tf import numpy as np import glob, time, os from network import Network 
from data import Data
from config import directories
from utils import Utils

class Model():
    def __init__(self, config, paths, dataset, name='gan_compression', evaluate=False):
        """Assemble the full GAN-compression graph (TF1 static graph).

        Pipeline: data feed -> encoder -> quantizer -> (optional noise
        concat) -> decoder -> discriminator(s) -> losses -> optimizers ->
        summaries. When `evaluate` is True, construction stops after the
        generator (no discriminator/losses/optimizers are built).

        Args:
            config: hyperparameter namespace (batch size, learning rates, flags).
            paths: numpy array of image file paths (dtype/shape used for placeholders).
            dataset: training dataset name, e.g. 'cityscapes' or 'ADE20k'.
            name: label used for tensorboard run directories.
            evaluate: build generator only when True.
        """
        # Build the computational graph
        print('Building computational graph ...')

        # Separate step counters so G and D can be updated independently
        self.G_global_step = tf.Variable(0, trainable=False)
        self.D_global_step = tf.Variable(0, trainable=False)
        # String handle selects train vs. test iterator at session run time
        self.handle = tf.placeholder(tf.string, shape=[])
        self.training_phase = tf.placeholder(tf.bool)

        # >>> Data handling
        self.path_placeholder = tf.placeholder(paths.dtype, paths.shape)
        self.test_path_placeholder = tf.placeholder(paths.dtype)
        self.semantic_map_path_placeholder = tf.placeholder(paths.dtype, paths.shape)
        self.test_semantic_map_path_placeholder = tf.placeholder(paths.dtype)

        train_dataset = Data.load_dataset(self.path_placeholder,
                                          config.batch_size,
                                          augment=False,
                                          training_dataset=dataset,
                                          use_conditional_GAN=config.use_conditional_GAN,
                                          semantic_map_paths=self.semantic_map_path_placeholder)

        test_dataset = Data.load_dataset(self.test_path_placeholder,
                                         config.batch_size,
                                         augment=False,
                                         training_dataset=dataset,
                                         use_conditional_GAN=config.use_conditional_GAN,
                                         semantic_map_paths=self.test_semantic_map_path_placeholder,
                                         test=True)

        # Feedable iterator: one graph, switchable between train/test streams
        self.iterator = tf.data.Iterator.from_string_handle(self.handle,
                                                            train_dataset.output_types,
                                                            train_dataset.output_shapes)
        self.train_iterator = train_dataset.make_initializable_iterator()
        self.test_iterator = test_dataset.make_initializable_iterator()

        if config.use_conditional_GAN:
            self.example, self.semantic_map = self.iterator.get_next()
        else:
            self.example = self.iterator.get_next()

        # Global generator: Encode -> quantize -> reconstruct
        # =======================================================================================================>>>
        with tf.variable_scope('generator'):
            self.feature_map = Network.encoder(self.example, config, self.training_phase, config.channel_bottleneck)
            self.w_hat = Network.quantizer(self.feature_map, config)

            if config.use_conditional_GAN:
                # Encode + quantize the semantic map with its own weights,
                # then stack along channels with the image code
                self.semantic_feature_map = Network.encoder(self.semantic_map, config, self.training_phase,
                    config.channel_bottleneck, scope='semantic_map')
                self.w_hat_semantic = Network.quantizer(self.semantic_feature_map, config, scope='semantic_map')
                self.w_hat = tf.concat([self.w_hat, self.w_hat_semantic], axis=-1)

            if config.sample_noise is True:
                print('Sampling noise...')
                # noise_prior = tf.contrib.distributions.Uniform(-1., 1.)
                # self.noise_sample = noise_prior.sample([tf.shape(self.example)[0], config.noise_dim])
                noise_prior = tf.contrib.distributions.MultivariateNormalDiag(loc=tf.zeros([config.noise_dim]),
                    scale_diag=tf.ones([config.noise_dim]))
                v = noise_prior.sample(tf.shape(self.example)[0])
                # Upsample noise to feature-map resolution and append as channels
                Gv = Network.dcgan_generator(v, config, self.training_phase, C=config.channel_bottleneck,
                    upsample_dim=config.upsample_dim)
                self.z = tf.concat([self.w_hat, Gv], axis=-1)
            else:
                self.z = self.w_hat

            self.reconstruction = Network.decoder(self.z, config, self.training_phase, C=config.channel_bottleneck)

        print('Real image shape:', self.example.get_shape().as_list())
        print('Reconstruction shape:', self.reconstruction.get_shape().as_list())

        if evaluate:
            # Generator-only graph for inference/compression
            return

        # Pass generated, real images to discriminator
        # =======================================================================================================>>>
        if config.use_conditional_GAN:
            # Model conditional distribution: condition D on the semantic map
            self.example = tf.concat([self.example, self.semantic_map], axis=-1)
            self.reconstruction = tf.concat([self.reconstruction, self.semantic_map], axis=-1)

        if config.multiscale:
            D_x, D_x2, D_x4, *Dk_x = Network.multiscale_discriminator(self.example, config, self.training_phase,
                use_sigmoid=config.use_vanilla_GAN, mode='real')
            D_Gz, D_Gz2, D_Gz4, *Dk_Gz = Network.multiscale_discriminator(self.reconstruction, config,
                self.training_phase, use_sigmoid=config.use_vanilla_GAN, mode='reconstructed', reuse=True)
        else:
            D_x = Network.discriminator(self.example, config, self.training_phase, use_sigmoid=config.use_vanilla_GAN)
            D_Gz = Network.discriminator(self.reconstruction, config, self.training_phase,
                use_sigmoid=config.use_vanilla_GAN, reuse=True)

        # Loss terms
        # =======================================================================================================>>>
        if config.use_vanilla_GAN is True:
            # Minimize JS divergence
            # NOTE(review): sparse_softmax_cross_entropy_with_logits expects
            # integer class labels, but ones_like/zeros_like of the logits are
            # float tensors - sigmoid_cross_entropy_with_logits looks intended
            # here; confirm before changing (the LS-GAN branch below is the
            # default path).
            D_loss_real = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=D_x,
                labels=tf.ones_like(D_x)))
            D_loss_gen = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=D_Gz,
                labels=tf.zeros_like(D_Gz)))
            self.D_loss = D_loss_real + D_loss_gen
            # G_loss = max log D(G(z))
            self.G_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=D_Gz,
                labels=tf.ones_like(D_Gz)))
        else:
            # Minimize $\chi^2$ divergence (least-squares GAN objective)
            self.D_loss = tf.reduce_mean(tf.square(D_x - 1.)) + tf.reduce_mean(tf.square(D_Gz))
            self.G_loss = tf.reduce_mean(tf.square(D_Gz - 1.))

            if config.multiscale:
                # Add LS-GAN terms for the 1/2- and 1/4-scale discriminators
                self.D_loss += tf.reduce_mean(tf.square(D_x2 - 1.)) + tf.reduce_mean(tf.square(D_x4 - 1.))
                self.D_loss += tf.reduce_mean(tf.square(D_Gz2)) + tf.reduce_mean(tf.square(D_Gz4))

        # Weighted L2 reconstruction (distortion) term for the generator
        distortion_penalty = config.lambda_X * tf.losses.mean_squared_error(self.example, self.reconstruction)
        self.G_loss += distortion_penalty

        if config.use_feature_matching_loss:  # feature extractor for generator
            # L1 distance between discriminator activations on real vs. fake
            D_x_layers, D_Gz_layers = [j for i in Dk_x for j in i], [j for i in Dk_Gz for j in i]
            feature_matching_loss = tf.reduce_sum([tf.reduce_mean(tf.abs(Dkx-Dkz)) for Dkx, Dkz in
                zip(D_x_layers, D_Gz_layers)])
            self.G_loss += config.feature_matching_weight * feature_matching_loss

        # Optimization
        # =======================================================================================================>>>
        G_opt = tf.train.AdamOptimizer(learning_rate=config.G_learning_rate, beta1=0.5)
        D_opt = tf.train.AdamOptimizer(learning_rate=config.D_learning_rate, beta1=0.5)

        theta_G = Utils.scope_variables('generator')
        theta_D = Utils.scope_variables('discriminator')
        # print('Generator parameters:', theta_G)
        # print('Discriminator parameters:', theta_D)
        G_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='generator')
        D_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='discriminator')

        # Execute the update_ops before performing the train_step
        with tf.control_dependencies(G_update_ops):
            self.G_opt_op = G_opt.minimize(self.G_loss, name='G_opt', global_step=self.G_global_step,
                var_list=theta_G)
        with tf.control_dependencies(D_update_ops):
            self.D_opt_op = D_opt.minimize(self.D_loss, name='D_opt', global_step=self.D_global_step,
                var_list=theta_D)

        # Exponential moving averages of parameters (shadow copies only;
        # note the EMA values are applied but not swapped in at inference)
        G_ema = tf.train.ExponentialMovingAverage(decay=config.ema_decay, num_updates=self.G_global_step)
        G_maintain_averages_op = G_ema.apply(theta_G)
        D_ema = tf.train.ExponentialMovingAverage(decay=config.ema_decay, num_updates=self.D_global_step)
        D_maintain_averages_op = D_ema.apply(theta_D)

        with tf.control_dependencies(G_update_ops+[self.G_opt_op]):
            self.G_train_op = tf.group(G_maintain_averages_op)
        with tf.control_dependencies(D_update_ops+[self.D_opt_op]):
            self.D_train_op = tf.group(D_maintain_averages_op)

        # >>> Monitoring
        # tf.summary.scalar('learning_rate', learning_rate)
        tf.summary.scalar('generator_loss', self.G_loss)
        tf.summary.scalar('discriminator_loss', self.D_loss)
        tf.summary.scalar('distortion_penalty', distortion_penalty)
        if config.use_feature_matching_loss:
            tf.summary.scalar('feature_matching_loss', feature_matching_loss)
        tf.summary.scalar('G_global_step', self.G_global_step)
        tf.summary.scalar('D_global_step', self.D_global_step)
        # Only the first 3 channels are image data (semantic map may be appended)
        tf.summary.image('real_images', self.example[:,:,:,:3], max_outputs=4)
        tf.summary.image('compressed_images', self.reconstruction[:,:,:,:3], max_outputs=4)
        if config.use_conditional_GAN:
            tf.summary.image('semantic_map', self.semantic_map, max_outputs=4)
        self.merge_op = tf.summary.merge_all()

        self.train_writer = tf.summary.FileWriter(
            os.path.join(directories.tensorboard, '{}_train_{}'.format(name, time.strftime('%d-%m_%I:%M'))),
            graph=tf.get_default_graph())
        self.test_writer =
tf.summary.FileWriter(
            os.path.join(directories.tensorboard, '{}_test_{}'.format(name, time.strftime('%d-%m_%I:%M'))))



================================================ FILE: network.py ================================================
"""
Modular components of computational graph
JTan 2018
"""
import tensorflow as tf

from utils import Utils

class Network(object):

    @staticmethod
    def encoder(x, config, training, C, reuse=False, actv=tf.nn.relu, scope='image'):
        """
        Process image x ([512,1024]) into a feature map of size W/16 x H/16 x C
        + C: Bottleneck depth, controls bpp
        + Output: Projection onto C channels, C = {2,4,8,16}
        """
        init = tf.contrib.layers.xavier_initializer()
        print('<------------ Building global {} generator architecture ------------>'.format(scope))

        def conv_block(x, filters, kernel_size=[3,3], strides=2, padding='same', actv=actv, init=init):
            # conv -> instance norm -> activation
            # (batch-norm kwargs retained for reference; instance norm is used)
            bn_kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False}
            in_kwargs = {'center':True, 'scale': True}
            x = tf.layers.conv2d(x, filters, kernel_size, strides=strides, padding=padding, activation=None)
            # x = tf.layers.batch_normalization(x, **bn_kwargs)
            x = tf.contrib.layers.instance_norm(x, **in_kwargs)
            x = actv(x)
            return x

        with tf.variable_scope('encoder_{}'.format(scope), reuse=reuse):
            # Run convolutions: one stride-1 7x7 stem, then four stride-2
            # stages, giving the overall /16 spatial downsampling
            f = [60, 120, 240, 480, 960]
            # Reflection padding so the VALID 7x7 conv preserves spatial size
            x = tf.pad(x, [[0, 0], [3, 3], [3, 3], [0, 0]], 'REFLECT')
            out = conv_block(x, filters=f[0], kernel_size=7, strides=1, padding='VALID', actv=actv)
            out = conv_block(out, filters=f[1], kernel_size=3, strides=2, actv=actv)
            out = conv_block(out, filters=f[2], kernel_size=3, strides=2, actv=actv)
            out = conv_block(out, filters=f[3], kernel_size=3, strides=2, actv=actv)
            out = conv_block(out, filters=f[4], kernel_size=3, strides=2, actv=actv)

            # Project channels onto space w/ dimension C
            # Feature maps have dimension W/16 x H/16 x C
            out = tf.pad(out, [[0, 0], [1, 1], [1, 1], [0, 0]], 'REFLECT')
            feature_map = conv_block(out, filters=C, kernel_size=3, strides=1,
padding='VALID', actv=actv) return feature_map @staticmethod def quantizer(w, config, reuse=False, temperature=1, L=5, scope='image'): """ Quantize feature map over L centers to obtain discrete $\hat{w}$ + Centers: {-2,-1,0,1,2} + TODO: Toggle learnable centers? """ with tf.variable_scope('quantizer_{}'.format(scope, reuse=reuse)): centers = tf.cast(tf.range(-2,3), tf.float32) # Partition W into the Voronoi tesellation over the centers w_stack = tf.stack([w for _ in range(L)], axis=-1) w_hard = tf.cast(tf.argmin(tf.abs(w_stack - centers), axis=-1), tf.float32) + tf.reduce_min(centers) smx = tf.nn.softmax(-1.0/temperature * tf.abs(w_stack - centers), dim=-1) # Contract last dimension w_soft = tf.einsum('ijklm,m->ijkl', smx, centers) # w_soft = tf.tensordot(smx, centers, axes=((-1),(0))) # Treat quantization as differentiable for optimization w_bar = tf.round(tf.stop_gradient(w_hard - w_soft) + w_soft) return w_bar @staticmethod def decoder(w_bar, config, training, C, reuse=False, actv=tf.nn.relu, channel_upsample=960): """ Attempt to reconstruct the image from the quantized representation w_bar. 
Generated image should be consistent with the true image distribution while
        recovering the specific encoded image
        + C: Bottleneck depth, controls bpp - last dimension of encoder output
        + TODO: Concatenate quantized w_bar with noise sampled from prior
        """
        init = tf.contrib.layers.xavier_initializer()

        def residual_block(x, n_filters, kernel_size=3, strides=1, actv=actv):
            # Identity-skip block: two reflect-padded VALID convs with
            # instance norm; spatial size is preserved by the padding
            init = tf.contrib.layers.xavier_initializer()
            # kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False}
            strides = [1,1]
            identity_map = x

            p = int((kernel_size-1)/2)
            res = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]], 'REFLECT')
            res = tf.layers.conv2d(res, filters=n_filters, kernel_size=kernel_size,
                strides=strides, activation=None, padding='VALID')
            res = actv(tf.contrib.layers.instance_norm(res))

            res = tf.pad(res, [[0, 0], [p, p], [p, p], [0, 0]], 'REFLECT')
            res = tf.layers.conv2d(res, filters=n_filters, kernel_size=kernel_size,
                strides=strides, activation=None, padding='VALID')
            res = tf.contrib.layers.instance_norm(res)

            # Skip connection requires matching shapes (static check)
            assert res.get_shape().as_list() == identity_map.get_shape().as_list(), 'Mismatched shapes between input/output!'
            out = tf.add(res, identity_map)
            return out

        def upsample_block(x, filters, kernel_size=[3,3], strides=2, padding='same', actv=actv, batch_norm=False):
            # Transposed conv -> (batch|instance) norm -> activation
            bn_kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False}
            in_kwargs = {'center':True, 'scale': True}
            x = tf.layers.conv2d_transpose(x, filters, kernel_size, strides=strides, padding=padding, activation=None)
            if batch_norm is True:
                x = tf.layers.batch_normalization(x, **bn_kwargs)
            else:
                x = tf.contrib.layers.instance_norm(x, **in_kwargs)
            x = actv(x)
            return x

        # Project channel dimension of w_bar to higher dimension
        # W_pc = tf.get_variable('W_pc_{}'.format(C), shape=[C, channel_upsample], initializer=init)
        # upsampled = tf.einsum('ijkl,lm->ijkm', w_bar, W_pc)
        with tf.variable_scope('decoder', reuse=reuse):
            w_bar = tf.pad(w_bar, [[0, 0], [1, 1], [1, 1], [0, 0]], 'REFLECT')
            upsampled = Utils.conv_block(w_bar, filters=960, kernel_size=3, strides=1, padding='VALID', actv=actv)

            # Process upsampled feature map with residual blocks (9 total)
            res = residual_block(upsampled, 960, actv=actv)
            res = residual_block(res, 960, actv=actv)
            res = residual_block(res, 960, actv=actv)
            res = residual_block(res, 960, actv=actv)
            res = residual_block(res, 960, actv=actv)
            res = residual_block(res, 960, actv=actv)
            res = residual_block(res, 960, actv=actv)
            res = residual_block(res, 960, actv=actv)
            res = residual_block(res, 960, actv=actv)

            # Upsample to original dimensions - mirror decoder
            # (four stride-2 transposed convs undo the encoder's /16)
            f = [480, 240, 120, 60]
            ups = upsample_block(res, f[0], 3, strides=[2,2], padding='same')
            ups = upsample_block(ups, f[1], 3, strides=[2,2], padding='same')
            ups = upsample_block(ups, f[2], 3, strides=[2,2], padding='same')
            ups = upsample_block(ups, f[3], 3, strides=[2,2], padding='same')

            # Final reflect-padded 7x7 conv to 3 channels + tanh -> [-1,1]
            ups = tf.pad(ups, [[0, 0], [3, 3], [3, 3], [0, 0]], 'REFLECT')
            ups = tf.layers.conv2d(ups, 3, kernel_size=7, strides=1, padding='VALID')
            out = tf.nn.tanh(ups)

            return out

    @staticmethod
    def discriminator(x, config, training, reuse=False,
actv=tf.nn.leaky_relu, use_sigmoid=False, ksize=4):
        # x is either generator output G(z) or drawn from the real data distribution
        # Patch-GAN discriminator based on arXiv 1711.11585
        # bn_kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False}
        in_kwargs = {'center':True, 'scale':True, 'activation_fn':actv}
        print('Shape of x:', x.get_shape().as_list())

        with tf.variable_scope('discriminator', reuse=reuse):
            # Four stride-2 convs (64 -> 512 channels), then a 1-channel
            # patch-wise logit map
            c1 = tf.layers.conv2d(x, 64, kernel_size=ksize, strides=2, padding='same', activation=actv)
            c2 = tf.layers.conv2d(c1, 128, kernel_size=ksize, strides=2, padding='same')
            c2 = actv(tf.contrib.layers.instance_norm(c2, **in_kwargs))
            c3 = tf.layers.conv2d(c2, 256, kernel_size=ksize, strides=2, padding='same')
            c3 = actv(tf.contrib.layers.instance_norm(c3, **in_kwargs))
            c4 = tf.layers.conv2d(c3, 512, kernel_size=ksize, strides=2, padding='same')
            c4 = actv(tf.contrib.layers.instance_norm(c4, **in_kwargs))
            out = tf.layers.conv2d(c4, 1, kernel_size=ksize, strides=1, padding='same')

            if use_sigmoid is True:  # Otherwise use LS-GAN
                out = tf.nn.sigmoid(out)

        return out

    @staticmethod
    def multiscale_discriminator(x, config, training, actv=tf.nn.leaky_relu,
        use_sigmoid=False, ksize=4, mode='real', reuse=False):
        # x is either generator output G(z) or drawn from the real data distribution
        # Multiscale + Patch-GAN discriminator architecture based on arXiv 1711.11585
        # Returns the three patch-logit maps plus each discriminator's
        # intermediate activations (used for feature-matching loss).
        print('<------------ Building multiscale discriminator architecture ------------>')

        if mode == 'real':
            print('Building discriminator D(x)')
        elif mode == 'reconstructed':
            print('Building discriminator D(G(z))')
        else:
            raise NotImplementedError('Invalid discriminator mode specified.')

        # Downsample input by factors 2 and 4 for the coarser discriminators
        x2 = tf.layers.average_pooling2d(x, pool_size=3, strides=2, padding='same')
        x4 = tf.layers.average_pooling2d(x2, pool_size=3, strides=2, padding='same')

        print('Shape of x:', x.get_shape().as_list())
        print('Shape of x downsampled by factor 2:', x2.get_shape().as_list())
        print('Shape of x downsampled by factor 4:', x4.get_shape().as_list())

        def discriminator(x, scope, actv=actv, use_sigmoid=use_sigmoid, ksize=ksize, reuse=reuse):
            # Returns patch-GAN output + intermediate layers
            with tf.variable_scope('discriminator_{}'.format(scope), reuse=reuse):
                c1 = tf.layers.conv2d(x, 64, kernel_size=ksize, strides=2, padding='same', activation=actv)
                c2 = Utils.conv_block(c1, filters=128, kernel_size=ksize, strides=2, padding='same', actv=actv)
                c3 = Utils.conv_block(c2, filters=256, kernel_size=ksize, strides=2, padding='same', actv=actv)
                c4 = Utils.conv_block(c3, filters=512, kernel_size=ksize, strides=2, padding='same', actv=actv)
                out = tf.layers.conv2d(c4, 1, kernel_size=ksize, strides=1, padding='same')

                if use_sigmoid is True:  # Otherwise use LS-GAN
                    out = tf.nn.sigmoid(out)

            return out, c1, c2, c3, c4

        with tf.variable_scope('discriminator', reuse=reuse):
            disc, *Dk = discriminator(x, 'original')
            disc_downsampled_2, *Dk_2 = discriminator(x2, 'downsampled_2')
            disc_downsampled_4, *Dk_4 = discriminator(x4, 'downsampled_4')

        return disc, disc_downsampled_2, disc_downsampled_4, Dk, Dk_2, Dk_4

    @staticmethod
    def dcgan_generator(z, config, training, C, reuse=False, actv=tf.nn.relu, kernel_size=5, upsample_dim=256):
        """
        Upsample noise to concatenate with quantized representation w_bar.
+ z: Drawn from latent distribution - [batch_size, noise_dim]
        + C: Bottleneck depth, controls bpp - last dimension of encoder output
        """
        init = tf.contrib.layers.xavier_initializer()
        kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False}

        with tf.variable_scope('noise_generator', reuse=reuse):

            # Dense projection of z, reshaped to [batch_size, 4, 8, dim]
            with tf.variable_scope('fc1', reuse=reuse):
                h2 = tf.layers.dense(z, units=4 * 8 * upsample_dim, activation=actv, kernel_initializer=init)  # cifar-10
                h2 = tf.layers.batch_normalization(h2, **kwargs)
                h2 = tf.reshape(h2, shape=[-1, 4, 8, upsample_dim])

            # [batch_size, 8, 16, dim/2]
            with tf.variable_scope('upsample1', reuse=reuse):
                up1 = tf.layers.conv2d_transpose(h2, upsample_dim//2, kernel_size=kernel_size, strides=2,
                    padding='same', activation=actv)
                up1 = tf.layers.batch_normalization(up1, **kwargs)

            # [batch_size, 16, 32, dim/4]
            with tf.variable_scope('upsample2', reuse=reuse):
                up2 = tf.layers.conv2d_transpose(up1, upsample_dim//4, kernel_size=kernel_size, strides=2,
                    padding='same', activation=actv)
                up2 = tf.layers.batch_normalization(up2, **kwargs)

            # [batch_size, 32, 64, dim/8]
            with tf.variable_scope('upsample3', reuse=reuse):
                up3 = tf.layers.conv2d_transpose(up2, upsample_dim//8, kernel_size=kernel_size, strides=2,
                    padding='same', activation=actv)  # cifar-10
                up3 = tf.layers.batch_normalization(up3, **kwargs)

            # Final reflect-padded 7x7 conv projects to C channels so the
            # output can be concatenated with w_hat along channels
            with tf.variable_scope('conv_out', reuse=reuse):
                out = tf.pad(up3, [[0, 0], [3, 3], [3, 3], [0, 0]], 'REFLECT')
                out = tf.layers.conv2d(out, C, kernel_size=7, strides=1, padding='VALID')

        return out

    @staticmethod
    def dcgan_discriminator(x, config, training, reuse=False, actv=tf.nn.relu):
        # x is either generator output G(z) or drawn from the real data distribution
        # DCGAN-style discriminator; hard-coded for 32x32x3 (cifar-10) inputs
        init = tf.contrib.layers.xavier_initializer()
        kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False}
        print('Shape of x:', x.get_shape().as_list())
        x = tf.reshape(x, shape=[-1, 32, 32, 3])
        # x = tf.reshape(x, shape=[-1,
28, 28, 1])
        with tf.variable_scope('discriminator', reuse=reuse):
            with tf.variable_scope('conv1', reuse=reuse):
                c1 = tf.layers.conv2d(x, 64, kernel_size=5, strides=2, padding='same', activation=actv)
                c1 = tf.layers.batch_normalization(c1, **kwargs)

            with tf.variable_scope('conv2', reuse=reuse):
                c2 = tf.layers.conv2d(c1, 128, kernel_size=5, strides=2, padding='same', activation=actv)
                c2 = tf.layers.batch_normalization(c2, **kwargs)

            with tf.variable_scope('fc1', reuse=reuse):
                fc1 = tf.contrib.layers.flatten(c2)
                # fc1 = tf.reshape(c2, shape=[-1, 8 * 8 * 128])
                fc1 = tf.layers.dense(fc1, units=1024, activation=actv, kernel_initializer=init)
                fc1 = tf.layers.batch_normalization(fc1, **kwargs)

            # Two-way logits (real/fake as classes)
            with tf.variable_scope('out', reuse=reuse):
                out = tf.layers.dense(fc1, units=2, activation=None, kernel_initializer=init)

        return out

    @staticmethod
    def critic_grande(x, config, training, reuse=False, actv=tf.nn.relu, kernel_size=5, gradient_penalty=True):
        # x is either generator output G(z) or drawn from the real data distribution
        # WGAN-style critic: scalar output, batch norm disabled when a
        # gradient penalty is used (WGAN-GP recommendation)
        init = tf.contrib.layers.xavier_initializer()
        kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False}
        print('Shape of x:', x.get_shape().as_list())
        x = tf.reshape(x, shape=[-1, 32, 32, 3])
        # x = tf.reshape(x, shape=[-1, 28, 28, 1])
        with tf.variable_scope('critic', reuse=reuse):
            with tf.variable_scope('conv1', reuse=reuse):
                c1 = tf.layers.conv2d(x, 64, kernel_size=kernel_size, strides=2, padding='same', activation=actv)
                if gradient_penalty is False:
                    c1 = tf.layers.batch_normalization(c1, **kwargs)

            with tf.variable_scope('conv2', reuse=reuse):
                c2 = tf.layers.conv2d(c1, 128, kernel_size=kernel_size, strides=2, padding='same', activation=actv)
                if gradient_penalty is False:
                    c2 = tf.layers.batch_normalization(c2, **kwargs)

            with tf.variable_scope('conv3', reuse=reuse):
                c3 = tf.layers.conv2d(c2, 256, kernel_size=kernel_size, strides=2, padding='same', activation=actv)
                if gradient_penalty is False:
                    c3 = tf.layers.batch_normalization(c3, **kwargs)

            with tf.variable_scope('fc1', reuse=reuse):
                fc1 = tf.contrib.layers.flatten(c3)
                # fc1 = tf.reshape(c2, shape=[-1, 8 * 8 * 128])
                fc1 = tf.layers.dense(fc1, units=1024, activation=actv, kernel_initializer=init)
                #fc1 = tf.layers.batch_normalization(fc1, **kwargs)

            # Unbounded scalar critic score (no sigmoid - Wasserstein)
            with tf.variable_scope('out', reuse=reuse):
                out = tf.layers.dense(fc1, units=1, activation=None, kernel_initializer=init)

        return out

    @staticmethod
    def wrn(x, config, training, reuse=False, actv=tf.nn.relu):
        # Implements W-28-10 wide residual network
        # See Arxiv 1605.07146
        network_width = 10  # k
        block_multiplicity = 2  # n
        filters = [16, 16, 32, 64]
        init = tf.contrib.layers.xavier_initializer()
        kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':True}

        def residual_block(x, n_filters, actv, keep_prob, training, project_shortcut=False, first_block=False):
            # BN -> conv -> BN -> dropout -> conv with (possibly projected) skip
            init = tf.contrib.layers.xavier_initializer()
            kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':True}
            if project_shortcut:
                # 1x1 conv matches channel count; stride 2 downsamples except
                # in the first group
                strides = [2,2] if not first_block else [1,1]
                identity_map = tf.layers.conv2d(x, filters=n_filters, kernel_size=[1,1],
                    strides=strides, kernel_initializer=init, padding='same')
                # identity_map = tf.layers.batch_normalization(identity_map, **kwargs)
            else:
                strides = [1,1]
                identity_map = x

            bn = tf.layers.batch_normalization(x, **kwargs)
            conv = tf.layers.conv2d(bn, filters=n_filters, kernel_size=[3,3], activation=actv,
                strides=strides, kernel_initializer=init, padding='same')
            bn = tf.layers.batch_normalization(conv, **kwargs)
            do = tf.layers.dropout(bn, rate=1-keep_prob, training=training)
            conv = tf.layers.conv2d(do, filters=n_filters, kernel_size=[3,3], activation=actv,
                kernel_initializer=init, padding='same')
            out = tf.add(conv, identity_map)

            return out

        def residual_block_2(x, n_filters, actv, keep_prob, training, project_shortcut=False, first_block=False):
            # Alternative pre-activation block with zero-padded pooling shortcut
            init = tf.contrib.layers.xavier_initializer()
            kwargs = {'center':True, 'scale':True,
'training':training, 'fused':True, 'renorm':True}
            prev_filters = x.get_shape().as_list()[-1]
            if project_shortcut:
                strides = [2,2] if not first_block else [1,1]
                # identity_map = tf.layers.conv2d(x, filters=n_filters, kernel_size=[1,1],
                #    strides=strides, kernel_initializer=init, padding='same')
                # Parameter-free shortcut: average-pool then zero-pad channels
                identity_map = tf.layers.average_pooling2d(x, strides, strides, 'valid')
                identity_map = tf.pad(identity_map, tf.constant([[0,0],[0,0],[0,0],
                    [(n_filters-prev_filters)//2, (n_filters-prev_filters)//2]]))
                # identity_map = tf.layers.batch_normalization(identity_map, **kwargs)
            else:
                strides = [1,1]
                identity_map = x

            # Pre-activation ordering: BN -> ReLU -> conv (twice), dropout between
            x = tf.layers.batch_normalization(x, **kwargs)
            x = tf.nn.relu(x)
            x = tf.layers.conv2d(x, filters=n_filters, kernel_size=[3,3], strides=strides,
                kernel_initializer=init, padding='same')
            x = tf.layers.batch_normalization(x, **kwargs)
            x = tf.nn.relu(x)
            x = tf.layers.dropout(x, rate=1-keep_prob, training=training)
            x = tf.layers.conv2d(x, filters=n_filters, kernel_size=[3,3],
                kernel_initializer=init, padding='same')
            out = tf.add(x, identity_map)

            return out

        with tf.variable_scope('wrn_conv', reuse=reuse):
            # Initial convolution --------------------------------------------->
            with tf.variable_scope('conv0', reuse=reuse):
                conv = tf.layers.conv2d(x, filters[0], kernel_size=[3,3], activation=actv,
                    kernel_initializer=init, padding='same')

            # Residual group 1 ------------------------------------------------>
            rb = conv
            f1 = filters[1]*network_width
            for n in range(block_multiplicity):
                with tf.variable_scope('group1/{}'.format(n), reuse=reuse):
                    # First block of each group projects the shortcut
                    project_shortcut = True if n==0 else False
                    rb = residual_block(rb, f1, actv, project_shortcut=project_shortcut,
                        keep_prob=config.conv_keep_prob, training=training, first_block=True)

            # Residual group 2 ------------------------------------------------>
            f2 = filters[2]*network_width
            for n in range(block_multiplicity):
                with tf.variable_scope('group2/{}'.format(n), reuse=reuse):
                    project_shortcut = True if n==0 else False
                    rb = residual_block(rb, f2, actv, project_shortcut=project_shortcut,
                        keep_prob=config.conv_keep_prob, training=training)

            # Residual group 3 ------------------------------------------------>
            f3 = filters[3]*network_width
            for n in range(block_multiplicity):
                with tf.variable_scope('group3/{}'.format(n), reuse=reuse):
                    project_shortcut = True if n==0 else False
                    rb = residual_block(rb, f3, actv, project_shortcut=project_shortcut,
                        keep_prob=config.conv_keep_prob, training=training)

            # Avg pooling + output -------------------------------------------->
            with tf.variable_scope('output', reuse=reuse):
                bn = tf.nn.relu(tf.layers.batch_normalization(rb, **kwargs))
                avp = tf.layers.average_pooling2d(bn, pool_size=[8,8], strides=[1,1], padding='valid')
                flatten = tf.contrib.layers.flatten(avp)
                out = tf.layers.dense(flatten, units=config.n_classes, kernel_initializer=init)

        return out

    @staticmethod
    def old_encoder(x, config, training, C, reuse=False, actv=tf.nn.relu):
        """
        Process image x ([512,1024]) into a feature map of size W/16 x H/16 x C
        + C: Bottleneck depth, controls bpp
        + Output: Projection onto C channels, C = {2,4,8,16}
        """
        # proj_channels = [2,4,8,16]
        init = tf.contrib.layers.xavier_initializer()

        def conv_block(x, filters, kernel_size=[3,3], strides=2, padding='same', actv=actv, init=init):
            # conv -> instance norm -> activation
            in_kwargs = {'center':True, 'scale': True}
            x = tf.layers.conv2d(x, filters, kernel_size, strides=strides, padding=padding, activation=None)
            x = tf.contrib.layers.instance_norm(x, **in_kwargs)
            x = actv(x)
            return x

        with tf.variable_scope('encoder', reuse=reuse):
            # Run convolutions
            out = conv_block(x, kernel_size=3, strides=1, filters=160, actv=actv)
            out = conv_block(out, kernel_size=[3,3], strides=2, filters=320, actv=actv)
            out = conv_block(out, kernel_size=[3,3], strides=2, filters=480, actv=actv)
            out = conv_block(out, kernel_size=[3,3], strides=2, filters=640, actv=actv)
            out = conv_block(out, kernel_size=[3,3], strides=2, filters=800, actv=actv)
            out = conv_block(out, kernel_size=3, strides=1, filters=960, actv=actv)
            # Project channels onto lower-dimensional embedding space
            W = tf.get_variable('W_channel_{}'.format(C), shape=[960,C], initializer=init)
            feature_map = tf.einsum('ijkl,lm->ijkm', out, W)
            # feature_map = tf.tensordot(out, W, axes=((3),(0)))

            # Feature maps have dimension W/16 x H/16 x C
            return feature_map



================================================ FILE: samples/.gitignore ================================================
*
!.gitignore


================================================ FILE: tensorboard/.gitignore ================================================
*
!.gitignore


================================================ FILE: train.py ================================================
#!/usr/bin/python3
import tensorflow as tf
import numpy as np
import pandas as pd
import time, os, sys
import argparse

# User-defined
from network import Network
from utils import Utils
from data import Data
from model import Model
from config import config_train, directories

tf.logging.set_verbosity(tf.logging.ERROR)

def train(config, args):
    """Full training loop: build graph, restore (optionally), then alternate
    one generator update and one discriminator update per step, saving
    checkpoints on diagnostics/interrupt/completion."""

    start_time = time.time()
    G_loss_best, D_loss_best = float('inf'), float('inf')
    # NOTE(review): get_checkpoint_state returns None when no checkpoint
    # exists, so `ckpt.model_checkpoint_path` below would raise
    # AttributeError with --restore_last on a fresh run - confirm and guard.
    ckpt = tf.train.get_checkpoint_state(directories.checkpoints)

    # Load data
    print('Training on dataset', args.dataset)
    if config.use_conditional_GAN:
        print('Using conditional GAN')
        paths, semantic_map_paths = Data.load_dataframe(directories.train, load_semantic_maps=True)
        test_paths, test_semantic_map_paths = Data.load_dataframe(directories.test, load_semantic_maps=True)
    else:
        paths = Data.load_dataframe(directories.train)
        test_paths = Data.load_dataframe(directories.test)

    # Build graph
    gan = Model(config, paths, name=args.name, dataset=args.dataset)
    saver = tf.train.Saver()

    if config.use_conditional_GAN:
        feed_dict_test_init = {gan.test_path_placeholder: test_paths,
            gan.test_semantic_map_path_placeholder: test_semantic_map_paths}
        feed_dict_train_init = {gan.path_placeholder: paths,
            gan.semantic_map_path_placeholder: semantic_map_paths}
    else:
        feed_dict_test_init = {gan.test_path_placeholder: test_paths}
        feed_dict_train_init = {gan.path_placeholder: paths}

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        train_handle = sess.run(gan.train_iterator.string_handle())
        test_handle = sess.run(gan.test_iterator.string_handle())

        if args.restore_last and ckpt.model_checkpoint_path:
            # Continue training saved model
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('{} restored.'.format(ckpt.model_checkpoint_path))
        else:
            if args.restore_path:
                new_saver = tf.train.import_meta_graph('{}.meta'.format(args.restore_path))
                new_saver.restore(sess, args.restore_path)
                print('{} restored.'.format(args.restore_path))

        sess.run(gan.test_iterator.initializer, feed_dict=feed_dict_test_init)

        for epoch in range(config.num_epochs):
            sess.run(gan.train_iterator.initializer, feed_dict=feed_dict_train_init)

            # Run diagnostics
            G_loss_best, D_loss_best = Utils.run_diagnostics(gan, config, directories, sess, saver, train_handle,
                start_time, epoch, args.name, G_loss_best, D_loss_best)

            while True:
                try:
                    # Update generator
                    # for _ in range(8):
                    feed_dict = {gan.training_phase: True, gan.handle: train_handle}
                    sess.run(gan.G_train_op, feed_dict=feed_dict)

                    # Update discriminator
                    step, _ = sess.run([gan.D_global_step, gan.D_train_op], feed_dict=feed_dict)

                    if step % config.diagnostic_steps == 0:
                        G_loss_best, D_loss_best = Utils.run_diagnostics(gan, config, directories, sess, saver,
                            train_handle, start_time, epoch, args.name, G_loss_best, D_loss_best)
                        Utils.single_plot(epoch, step, sess, gan, train_handle, args.name, config)
                        # for _ in range(4):
                        #     sess.run(gan.G_train_op, feed_dict=feed_dict)

                except tf.errors.OutOfRangeError:
                    # Train iterator exhausted - epoch finished
                    print('End of epoch!')
                    break

                except KeyboardInterrupt:
                    # Save before exiting on Ctrl-C
                    save_path = saver.save(sess, os.path.join(directories.checkpoints,
                        '{}_last.ckpt'.format(args.name)), global_step=epoch)
                    print('Interrupted, model saved to: ', save_path)
                    sys.exit()

        save_path = saver.save(sess, os.path.join(directories.checkpoints,
            '{}_end.ckpt'.format(args.name)), global_step=epoch)

    print("Training Complete. Model saved to file: {} Time elapsed: {:.3f} s".format(save_path, time.time()-start_time))

def main(**kwargs):
    # Parse CLI flags and launch training with the static train config
    parser = argparse.ArgumentParser()
    parser.add_argument("-rl", "--restore_last", help="restore last saved model", action="store_true")
    parser.add_argument("-r", "--restore_path", help="path to model to be restored", type=str)
    parser.add_argument("-opt", "--optimizer", default="adam", help="Selected optimizer", type=str)
    parser.add_argument("-name", "--name", default="gan-train", help="Checkpoint/Tensorboard label")
    parser.add_argument("-ds", "--dataset", default="cityscapes",
        help="choice of training dataset. Currently only supports cityscapes/ADE20k",
        choices=set(("cityscapes", "ADE20k")), type=str)
    args = parser.parse_args()

    # Launch training
    train(config_train, args)

if __name__ == '__main__':
    main()


================================================ FILE: utils.py ================================================
# -*- coding: utf-8 -*-
# Diagnostic helper functions for Tensorflow session
import tensorflow as tf
import numpy as np
import os, time

import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns

from config import directories

class Utils(object):

    @staticmethod
    def conv_block(x, filters, kernel_size=[3,3], strides=2, padding='same', actv=tf.nn.relu):
        # Shared conv -> instance norm -> activation building block
        in_kwargs = {'center':True, 'scale': True}
        x = tf.layers.conv2d(x, filters, kernel_size, strides=strides, padding=padding, activation=None)
        x = tf.contrib.layers.instance_norm(x, **in_kwargs)
        x = actv(x)
        return x

    @staticmethod
    def upsample_block(x, filters, kernel_size=[3,3], strides=2, padding='same', actv=tf.nn.relu):
        # Shared transposed-conv -> instance norm -> activation block
        in_kwargs = {'center':True, 'scale': True}
        x = tf.layers.conv2d_transpose(x, filters, kernel_size, strides=strides, padding=padding, activation=None)
        x =
tf.contrib.layers.instance_norm(x, **in_kwargs) x = actv(x) return x @staticmethod def residual_block(x, n_filters, kernel_size=3, strides=1, actv=tf.nn.relu): init = tf.contrib.layers.xavier_initializer() # kwargs = {'center':True, 'scale':True, 'training':training, 'fused':True, 'renorm':False} strides = [1,1] identity_map = x p = int((kernel_size-1)/2) res = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]], 'REFLECT') res = tf.layers.conv2d(res, filters=n_filters, kernel_size=kernel_size, strides=strides, activation=None, padding='VALID') res = actv(tf.contrib.layers.instance_norm(res)) res = tf.pad(res, [[0, 0], [p, p], [p, p], [0, 0]], 'REFLECT') res = tf.layers.conv2d(res, filters=n_filters, kernel_size=kernel_size, strides=strides, activation=None, padding='VALID') res = tf.contrib.layers.instance_norm(res) assert res.get_shape().as_list() == identity_map.get_shape().as_list(), 'Mismatched shapes between input/output!' out = tf.add(res, identity_map) return out @staticmethod def get_available_gpus(): from tensorflow.python.client import device_lib local_device_protos = device_lib.list_local_devices() #return local_device_protos print('Available GPUs:') print([x.name for x in local_device_protos if x.device_type == 'GPU']) @staticmethod def scope_variables(name): with tf.variable_scope(name): return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=tf.get_variable_scope().name) @staticmethod def run_diagnostics(model, config, directories, sess, saver, train_handle, start_time, epoch, name, G_loss_best, D_loss_best): t0 = time.time() improved = '' sess.run(tf.local_variables_initializer()) feed_dict_test = {model.training_phase: False, model.handle: train_handle} try: G_loss, D_loss, summary = sess.run([model.G_loss, model.D_loss, model.merge_op], feed_dict=feed_dict_test) model.train_writer.add_summary(summary) except tf.errors.OutOfRangeError: G_loss, D_loss = float('nan'), float('nan') if G_loss < G_loss_best and D_loss < D_loss_best: G_loss_best, 
D_loss_best = G_loss, D_loss improved = '[*]' if epoch>5: save_path = saver.save(sess, os.path.join(directories.checkpoints_best, '{}_epoch{}.ckpt'.format(name, epoch)), global_step=epoch) print('Graph saved to file: {}'.format(save_path)) if epoch % 5 == 0 and epoch > 5: save_path = saver.save(sess, os.path.join(directories.checkpoints, '{}_epoch{}.ckpt'.format(name, epoch)), global_step=epoch) print('Graph saved to file: {}'.format(save_path)) print('Epoch {} | Generator Loss: {:.3f} | Discriminator Loss: {:.3f} | Rate: {} examples/s ({:.2f} s) {}'.format(epoch, G_loss, D_loss, int(config.batch_size/(time.time()-t0)), time.time() - start_time, improved)) return G_loss_best, D_loss_best @staticmethod def single_plot(epoch, global_step, sess, model, handle, name, config, single_compress=False): real = model.example gen = model.reconstruction # Generate images from noise, using the generator network. r, g = sess.run([real, gen], feed_dict={model.training_phase:True, model.handle: handle}) images = list() for im, imtype in zip([r,g], ['real', 'gen']): im = ((im+1.0))/2 # [-1,1] -> [0,1] im = np.squeeze(im) im = im[:,:,:3] images.append(im) # Uncomment to plot real and generated samples separately # f = plt.figure() # plt.imshow(im) # plt.axis('off') # f.savefig("{}/gan_compression_{}_epoch{}_step{}_{}.pdf".format(directories.samples, name, epoch, # global_step, imtype), format='pdf', dpi=720, bbox_inches='tight', pad_inches=0) # plt.gcf().clear() # plt.close(f) comparison = np.hstack(images) f = plt.figure() plt.imshow(comparison) plt.axis('off') if single_compress: f.savefig(name, format='pdf', dpi=720, bbox_inches='tight', pad_inches=0) else: f.savefig("{}/gan_compression_{}_epoch{}_step{}_{}_comparison.pdf".format(directories.samples, name, epoch, global_step, imtype), format='pdf', dpi=720, bbox_inches='tight', pad_inches=0) plt.gcf().clear() plt.close(f) @staticmethod def weight_decay(weight_decay, var_label='DW'): """L2 weight decay loss.""" costs = [] for var 
in tf.trainable_variables(): if var.op.name.find(r'{}'.format(var_label)) > 0: costs.append(tf.nn.l2_loss(var)) return tf.multiply(weight_decay, tf.add_n(costs))