Repository: PrajitR/fast-pixel-cnn Branch: master Commit: 6a88afb95207 Files: 11 Total size: 105.2 KB Directory structure: gitextract_i13tz309/ ├── LICENSE.md ├── README.md ├── fast_pixel_cnn_pp/ │ ├── __init__.py │ ├── fast_nn.py │ ├── ground_truth_output.npy │ ├── model.py │ ├── nn.py │ ├── plotting.py │ ├── test_components.py │ └── test_end_to_end.py └── generate.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: LICENSE.md ================================================ MIT License (MIT) Copyright (c) 2017 Prajit Ramachandran, Tom Le Paine, Pooya Khorrami, Mohammad Babaeizadeh Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # Fast PixelCNN++: speedy image generation *Real time generation of 16 32-by-32 images. Naive generation (left) vs. fast generation (right).*

We speed up the image generation algorithm of [PixelCNN++](https://github.com/openai/pixel-cnn) by avoiding redundant computation through caching. Naive generation discards computation that can be re-used and performs additional computation that will not be used to generate a particular pixel. Naive generation can take up to 11 minutes to generate 16 32-by-32 images on a Tesla K40 GPU. By re-using previous computation and only performing the minimum amount of computation required, we achieve up to a 183 times speedup over the naive generation algorithm.

## How to run We have tested our code with Python 3 and TensorFlow 1.0. You may need to make small changes for other versions of Python or TensorFlow. Instructions to run: * Install [TensorFlow 1.0](https://www.tensorflow.org/install/), Numpy, and Matplotlib * Download and unzip [OpenAI's pretrained PixelCNN++ model](http://alpha.openai.com/pxpp.zip). After unzipping, there should be a file called `params_cifar.ckpt` * Run the script with `CUDA_VISIBLE_DEVICES=0 python generate.py --checkpoint=/path/to/params_cifar.ckpt --save_dir=/path/to/save/generated/images` The script will continually generate images in a loop and write out the images to `--save_dir`. You can exit the script at any time by interrupting with Control-C. ## How it works ### What is PixelCNN++, and why should I use it? PixelCNN++ is a generative model that uses all previously generated pixels as information to generate the next pixel. That is, to generate the 10th pixel in the image, PixelCNN++ will look at pixels 1-9 to model the output distribution of pixel 10: `P(pixel 10 | pixel 1, ..., pixel 9)`. Similarly, pixel 11 will look at pixels 1-10, and this process continues for all pixels. This property makes PixelCNN an *autoregressive* model, where each pixel is modeled by the history of previous pixels. What makes PixelCNN unique is that it uses clever, fast methods to coalesce information from previous pixels, which is crucial for training speed. PixelCNN was [original developed by DeepMind](https://arxiv.org/abs/1606.05328) and [improved upon by OpenAI](https://openreview.net/pdf?id=BJrFC6ceg) in PixelCNN++. These models have achieved state-of-the-art results on a variety of image generation benchmarks. They are straightforward to train, have a large capacity to model complex inputs, and are able to generate crisp, attractive images. For example, PixelCNN has [recently been used for superresolution](https://arxiv.org/abs/1702.00783). One of the main downsides of autoregressive models compared to other generative models like [Generative Adversarial Networks](https://arxiv.org/abs/1701.00160) and [Variational Autoencoders](https://arxiv.org/abs/1606.05908) is that autoregressive models must generate pixels one at a time, whereas other methods can generate the entire image at once. Our method speeds up the generation process for PixelCNN. ### Speeding up a simple 1D example with dilation Before jumping into the details of speeding up PixelCNN, let's focus on a simpler 1D autoregressive model: [Wavenet](https://arxiv.org/abs/1609.03499). The details presented here are the similar to that of our other repository, [Fast Wavenet](https://github.com/tomlepaine/fast-wavenet), which you can refer to for more details. ![](assets/wavenet.png) The Wavenet graph (on the left) looks like a binary tree. A node is convolved with it's `nth` previous neighbor, where `n` is a power of 2. Since `n`, the *dilation*, is increasing by a factor of two every layer, the range of nodes that are combined together, the *receptive field*, increases exponentially. On every generation step, information from all nodes in the receptive field (8 in the picture) must be combined. A naive generation algorithm simply repeats the entire tree of computation for every generation step. This is easy to implement, but is slow. You may have noticed that when generating consecutive outputs, a large portion of the tree is reused. For example, call the current step in the picture `t` and imagine generating the output for `t + 2`. In this case, three of the four orange nodes in the first hidden layer can be reused! It is a waste of time to recompute them. This brings us to the core of our method: caching previously computed hidden states. As illustrated in the image on the right, we maintain a cache for each layer which holds previously computed hidden states. The size of the cache is equal to the dilation factor of the hidden layer since the model must look back `n` steps at the hidden layer. The cache acts like a queue: the oldest hidden state is popped from the the front of the queue, which is exactly equivalent to the normal dilated convolution. After a hidden state is computed, it must then be pushed into the back of the queue, to be used exactly `n` steps in future from now. This process repeats itself, giving a fast generation algorithm that avoids the exponential computation of the naive approach. ### Speeding up strided convolutions The previous section used dilated convolutions. In this case, node `t` is convolved with node `t - n`, and node `t + 1` is convolved with node `t + 1 - n`. This implies that the number of hidden states in a layer is equal to the number of inputs, making caching straightforward. However, using strided convolutions makes the problem more difficult because the number of states in a hidden layer is different from the number of inputs. Strided convolutions are downsampling layers. This means that there are fewer hidden states than inputs. A typical convolution will convolve over a local neighborhood and then slide 1 position over and repeat the procedure. For example, nodes `t - 1` and `t` will be convolved, and then nodes `t` and `t + 1` will be convolved. Striding affects the number of positions that the convolution will slide over. In the previous example, the stride is 1. However, when the stride is greater than 1, the input is downsampled. For example, take the stride to be 2. Nodes `t - 1` and `t` will be convolved, and then nodes `t + 1` and `t + 2` will be convolved, since the convolution has slided 2 positions over. This means every pair of inputs to the layer only produces one output, so the number of hidden states is smaller than the number of inputs. Similarly, there are upsampling layers, which are strided transposed convolutions. With a stride of `s`, upsampling layers will produce `s` outputs for every input to the layer. This increases the number of hidden states compared to the number of inputs to the layer. PixelCNN++ uses 2 downsampling layers followed by 2 upsampling layers, each of stride 2, meaning that the number of generated pixels is the same as the number of input pixels (i.e. `D / 2 / 2 * 2 * 2 = D`). A detailed explanation of strides and transposed convolutions can be found [here](https://github.com/vdumoulin/conv_arithmetic) and [here](https://arxiv.org/abs/1603.07285). Because of the differing number of hidden states, caches cannot be updated in every timestep. Thus, each cache has an additional property `cache every`, where the cache is only updated every `cache every` steps. Every downsampling layer increases the `cache every` property of the layer by the stride. Conversely, every upsampling layer decreases the `cache every` property of the layer by the stride. ![](assets/strided.png) The figure above shows an example model with 2 upsampling and 2 downsampling layers each with a stride of 2. Orange nodes are computed in the current timestep, blue nodes are previously cached states, and gray nodes are not involved in the current timestep. * At the first timestep `t = 0`, the first input is used to compute and cache all nodes for which there is sufficient information to generate, including the first four outputs. * At `t = 1`, there are no nodes that have sufficient information to be computed, but the output for `t = 1` has already been computed at `t = 0`. * At `t = 2`, there is one new node that now has sufficient information to be computed, although the output for `t = 2` has also been computed at `t = 0`. * The `t = 3` scenario is similar to `t = 1`. * At `t = 4`, there is enough information to compute multiple hidden states and generate the next four outputs. This is analogous to the `t = 0` scenario. * `t = 5` is analogous to `t = 1`, and this cycle is followed for all future time steps. In our code, we also use a property `run every` which is equal to the `cache every` property of the next layer. This allows us to avoid computation if the next layer is simply going to ignore its input. ### Speeding up PixelCNN++ After understanding the previous sections, it should seem relatively straightforward to generalize the 1D example to the 2D case. Indeed, our method generalizes with very few changes. The caches for each layer are now 2D, with a height equal to the filter height and a width equal to the image width. After an entire row is generated, the oldest row of the cache is popped and the new row is pushed. Because strided convolutions are used, we use the `cache every` idea detailed in the previous section. PixelCNN maintains two streams of computation: a vertical stream and a horizontal stream. Oversimplifying the details a bit, the vertical stream looks at all pixels above the current pixel while the horizontal stream looks at all pixels immediately left of the current pixel, satisfying the autoregressive property in the 2D case (see the PixelCNN papers for a more precise explanation). The horizontal stream also takes in the vertical stream as another input. In our code, we compute the vertical stream one row at a time, cache it, and use it to compute the horizontal stream (and the generated output) one pixel at a time. And with this, we are able to achieve orders of magnitude speedups for PixelCNN++ generation! Increasing the batch size demonstrates the scalability of our method. While the naive implementation scales linearly with the batch size (because of 100% GPU utilization), our method enjoys superior scaling because of its minimal computational requirements.

### Beyond PixelCNN++ The core concepts detailed here will easily generalize for different modalities. For example, speeding up the [Video Pixel Network](https://arxiv.org/abs/1610.00527) for generating videos should look straightforward, and will probably produce even more impressive speedups because of the higher computational demand. We look forward to hearing about practical uses of fast generation for convolutional autoregressive models! ## Authors * [Prajit Ramachandran](https://github.com/PrajitR) * [Tom Le Paine](https://github.com/tomlepaine) * [Pooya Khorrami](https://github.com/pkhorrami4) * [Mohammad Babaeizadeh](https://github.com/mbz) If you found this work useful, please cite our [paper](https://arxiv.org/abs/1704.06001). ``` @article{ramachandran2017fast, title={Fast Generation for Convolutional Autoregressive Models}, author={Ramachandran, Prajit and Paine, Tom Le and Khorrami, Pooya and Babaeizadeh, Mohammad and Chang, Shiyu and Zhang, Yang and Hasegawa-Johnson, Mark A and Campbell, Roy H and Huang, Thomas S}, journal={arXiv preprint arXiv:1704.06001}, year={2017} } ``` ================================================ FILE: fast_pixel_cnn_pp/__init__.py ================================================ ================================================ FILE: fast_pixel_cnn_pp/fast_nn.py ================================================ from . import nn import tensorflow as tf from tensorflow.contrib.framework.python.ops import add_arg_scope import numpy as np from collections import namedtuple import math LayerInfo = namedtuple('LayerInfo', [ 'image_size', 'batch', 'image_height', 'image_width', 'image_channels', 'filter_size', 'filter_height', 'filter_width', 'filter_channels', 'input_channels', 'nonlinearity' ]) RESET_CACHE_COLLECTION = 'reset_cache' def down_shift(image): '''Shift all rows down by one, using zeros as the first row and throwing away the last row.''' all_image_except_last_row = image[:, :-1, :, :] zero_row = np.zeros_like(image[:, :1, :, :]) return np.concatenate([zero_row, all_image_except_last_row], axis=1) def right_shift(image): '''Shift all columns right by one, using zeros as the first column and throwing away the last column.''' all_image_except_last_column = image[:, :, :-1, :] zero_column = np.zeros_like(image[:, :, :1, :]) return np.concatenate([zero_column, all_image_except_last_column], axis=2) def _extract_layer_info(network_info, input_, nonlinearity): '''Utility function to extract information about the current layer.''' image_size, filter_size = network_info batch, image_height, image_width, image_channels = image_size filter_height, filter_width, filter_channels = filter_size input_channels = int(input_.get_shape()[-1]) if nonlinearity is None: nonlinearity = tf.identity return LayerInfo(image_size, batch, image_height, image_width, image_channels, filter_size, filter_height, filter_width, filter_channels, input_channels, nonlinearity) def _create_cache(batch, cache_height, cache_width, channels): '''Creates a cache, which is used to avoid redundant computation.''' cache = tf.Variable( initial_value=np.zeros((batch, cache_height, cache_width, channels)), dtype=tf.float32, name='cache', trainable=False) # Reset the cache between generations. reset_cache = cache.assign(tf.zeros_like(cache)) tf.add_to_collection(RESET_CACHE_COLLECTION, reset_cache) return cache def reset_cache_op(): '''Returns an op to reset all created caches. Used between different generation calls.''' return tf.group(*tf.get_collection(RESET_CACHE_COLLECTION)) def _get_conv_variables(filter_size, input_channels, scope_name, counters): '''Creates and returns variables used for convolution.''' filter_height, filter_width, filter_channels = filter_size with tf.variable_scope(nn.get_name(scope_name, counters)): V = tf.get_variable( 'V', [filter_height, filter_width, input_channels, filter_channels], dtype=tf.float32) g = tf.get_variable('g', [filter_channels], dtype=tf.float32) b = tf.get_variable('b', [filter_channels], dtype=tf.float32) return V, g, b def _get_conv2d_variables(filter_size, input_channels, counters): '''Creates and returns the variables used for a normal 2D convolution.''' V, g, b = _get_conv_variables(filter_size, input_channels, 'conv2d', counters) filter_channels = filter_size[-1] W = tf.reshape(g, [1, 1, 1, filter_channels]) * tf.nn.l2_normalize( V, [0, 1, 2]) # Weight normalization. return W, b def _get_deconv2d_variables(filter_size, input_channels, counters): '''Creates and returns the variables used for a 2D transposed convolution (deconvolution).''' V, g, b = _get_conv_variables(filter_size, input_channels, 'deconv2d', counters) filter_channels = filter_size[-1] W = tf.reshape(g, [1, 1, filter_channels, 1]) * tf.nn.l2_normalize( V, [0, 1, 3]) # Weight normalization. return W, b def _mod_equal_0(row_or_col, every): '''Returns a boolean tensor representing (row_or_col % every == 0)''' return tf.equal(tf.mod(row_or_col, every), 0) def _roll_cache(cache): '''Pop off the oldest row of the cache to make space for the newest row of input.''' batch, _, cache_width, channels = cache.get_shape() without_dropped_row = cache[:, 1:, :, :] zero_row = tf.zeros([batch, 1, cache_width, channels]) rolled_cache = tf.concat([without_dropped_row, zero_row], 1) return cache.assign(rolled_cache) @add_arg_scope def down_shifted_conv2d(row_input, network_info, stride, row, cache_every, run_every, nonlinearity=None, counters={}): '''Performs a convolution for the vertical stack.''' li = _extract_layer_info(network_info, row_input, nonlinearity) ## Create cache. cache_height = li.filter_height # Just large enough to fit the filter. padding = li.filter_width // 2 # Horizontal padding to make VALID convolution maintain the width of input. cache_width = li.image_width + 2 * padding # Cache width is the image width plus padding to the left and right. cache = _create_cache(li.batch, cache_height, cache_width, li.input_channels) ## Update cache. should_cache = _mod_equal_0(row, cache_every) cache_func = lambda: cache[:, -1:, padding:(padding + li.image_width), :].assign(row_input) do_nothing_cache_func = lambda: row_input assign_to_cache = tf.cond(should_cache, cache_func, do_nothing_cache_func) ## Compute output. W, b = _get_conv2d_variables(li.filter_size, li.input_channels, counters) with tf.control_dependencies([assign_to_cache]): should_run = _mod_equal_0(row, run_every) # Compute output for the entire row. run_func = lambda: li.nonlinearity(tf.nn.conv2d(cache, W, [1, 1, stride, 1], 'VALID') + b) output_width = int(math.ceil(li.image_width / float(stride))) do_nothing_run_func = lambda: tf.zeros([li.batch, 1, output_width, li.filter_channels]) outputs = tf.cond(should_run, run_func, do_nothing_run_func) outputs.set_shape([li.batch, 1, output_width, li.filter_channels]) # Ensure that roll_cache() is run, and only after computing the outputs. with tf.control_dependencies([outputs]): roll_cache_op = tf.cond(should_cache, lambda: _roll_cache(cache), lambda: cache) with tf.control_dependencies([roll_cache_op]): outputs = tf.identity(outputs) return outputs @add_arg_scope def down_right_shifted_conv2d(pixel_input, network_info, row, col, cache_every, run_every, nonlinearity=None, counters={}): '''Performs a convolution for the horizontal stack.''' li = _extract_layer_info(network_info, pixel_input, nonlinearity) ## Create cache. cache_height = li.filter_height # Just large enough to fit the filter. left_pad = li.filter_width - 1 # Only need left padding because always convolving to the left. cache_width = li.image_width + left_pad cache = _create_cache(li.batch, cache_height, cache_width, li.input_channels) cache_col = col // cache_every # Accounts for downsampling due to stride in previous layers. ## Update cache. should_cache = tf.logical_and( _mod_equal_0(row, cache_every), _mod_equal_0(col, cache_every)) pixel_col = cache_col + left_pad # Accounts for padding in the cache. cache_func = lambda: cache[:, -1:, pixel_col:(pixel_col + 1), :].assign(pixel_input) do_nothing_cache_func = lambda: pixel_input assign_to_cache = tf.cond(should_cache, cache_func, do_nothing_cache_func) ## Compute output. W, b = _get_conv2d_variables(li.filter_size, li.input_channels, counters) with tf.control_dependencies([assign_to_cache]): should_run = tf.logical_and( _mod_equal_0(row, run_every), _mod_equal_0(col, run_every)) # Extract the local neighborhood of the current column in the cache to be convolved with the filter. # This is simply a matrix multiply, since the neighborhood is the size of the filter. width_start = cache_col width_end = width_start + li.filter_width cache_neighborhood = cache[:, :, width_start:width_end, :] run_func = lambda: li.nonlinearity(tf.nn.conv2d(cache_neighborhood, W, [1, 1, 1, 1], 'VALID') + b) do_nothing_run_func = lambda: tf.zeros([li.batch, 1, 1, li.filter_channels]) outputs = tf.cond(should_run, run_func, do_nothing_run_func) outputs.set_shape([li.batch, 1, 1, li.filter_channels]) # Ensure that roll_cache() is run, and only after computing the outputs. with tf.control_dependencies([outputs]): # Roll out an entire row of the cache only after generating output for the last column. is_end_of_row = tf.equal(cache_col, li.image_width - 1) should_roll = tf.logical_and(should_cache, is_end_of_row) maybe_roll = tf.cond(should_roll, lambda: _roll_cache(cache), lambda: cache) with tf.control_dependencies([maybe_roll]): outputs = tf.identity(outputs) return outputs def _create_deconv_cache(li, stride): '''Creates the cache for the two deconv layers.''' cache_height = li.filter_height # Just large enough to fit the filter. # The deconv will increases the number of outputs `stride` times. # The extra width comes from the tf.nn.conv2d_transpose() operation. cache_width = li.image_width * stride + li.filter_width - 1 cache = _create_cache(li.batch, cache_height, cache_width, li.filter_channels) return cache, cache_height, cache_width @add_arg_scope def down_shifted_deconv2d(row_input, network_info, row, cache_every, run_every, stride=2, nonlinearity=None, counters={}): '''Performs a transposed convolution for the vertical stack.''' li = _extract_layer_info(network_info, row_input, nonlinearity) ## Create cache. cache, cache_height, cache_width = _create_deconv_cache(li, stride) ## Update cache. should_cache = _mod_equal_0(row, cache_every) W, b = _get_deconv2d_variables(li.filter_size, li.input_channels, counters) def cache_func(): # Compute deconv output for the entire row. outputs = tf.nn.conv2d_transpose( row_input, W, output_shape=[ li.batch, cache_height, cache_width, li.filter_channels ], strides=[1, stride, stride, 1], padding='VALID') outputs = li.nonlinearity(outputs + b) # Store the output in the cache. with tf.control_dependencies([outputs]): # With stride=2, this is simply cache.assign(outputs) since the old rows in the cache # will all have been rolled out. update_cache = cache.assign(cache + outputs) return update_cache do_nothing_cache_func = lambda: tf.zeros_like(cache) assign_to_cache = tf.cond(should_cache, cache_func, do_nothing_cache_func) ## Compute output. with tf.control_dependencies([assign_to_cache]): should_run = _mod_equal_0(row, run_every) def run_func(): # The cache stores the deconv output, so just return the next (first) row and roll. output = cache[:, 0:1, 1:-1, :] with tf.control_dependencies([output]): with tf.control_dependencies([_roll_cache(cache)]): output = tf.identity(output) return output do_nothing_run_func = lambda: tf.zeros([li.batch, 1, cache_width - 2, li.filter_channels]) outputs = tf.cond(should_run, run_func, do_nothing_run_func) outputs.set_shape([li.batch, 1, cache_width - 2, li.filter_channels]) return outputs @add_arg_scope def down_right_shifted_deconv2d(pixel_input, network_info, row, col, cache_every, run_every, stride=2, nonlinearity=None, counters={}): '''Performs a transposed convolution for the horizontal stack.''' li = _extract_layer_info(network_info, pixel_input, nonlinearity) ## Create cache. cache, cache_height, cache_width = _create_deconv_cache(li, stride) ## Update cache. should_cache = tf.logical_and( _mod_equal_0(row, cache_every), _mod_equal_0(col, cache_every)) W, b = _get_deconv2d_variables(li.filter_size, li.input_channels, counters) def cache_func(): outputs = tf.nn.conv2d_transpose( pixel_input, W, output_shape=[ li.batch, li.filter_height, li.filter_width, li.filter_channels ], strides=[1, stride, stride, 1], padding='VALID') outputs = li.nonlinearity(outputs + b) # Store the output in the cache. with tf.control_dependencies([outputs]): cache_col = col // cache_every update_cache = cache[:, :, (stride * cache_col):(stride * ( cache_col + 1)), :].assign(outputs) return update_cache do_nothing_cache_func = lambda: tf.zeros([li.batch, li.filter_height, li.filter_width, li.filter_channels]) assign_to_cache = tf.cond(should_cache, cache_func, do_nothing_cache_func) ## Compute output. with tf.control_dependencies([assign_to_cache]): should_run = tf.logical_and( _mod_equal_0(row, run_every), _mod_equal_0(col, run_every)) def run_func(): output_col = col // run_every output = cache[:, 0:1, output_col:(output_col + 1), :] # Only roll after the end of the row has been reached. with tf.control_dependencies([output]): is_end_of_row = tf.equal(output_col, cache_width - li.filter_width) maybe_roll = tf.cond(is_end_of_row, lambda: _roll_cache(cache), lambda: cache) with tf.control_dependencies([maybe_roll]): output = tf.identity(output) return output do_nothing_run_func = lambda: tf.zeros([li.batch, 1, 1, li.filter_channels]) outputs = tf.cond(should_run, run_func, do_nothing_run_func) outputs.set_shape([li.batch, 1, 1, li.filter_channels]) return outputs def sum_rightshift_downshift(rightshifted_pixel, downshifted_row, col): '''Sums the vertical and horizontal stack.''' downshifted_pixel = downshifted_row[:, :, col:(col + 1), :] return rightshifted_pixel + downshifted_pixel def _conditional_info(h, batch, filter_channels, counters): '''Computes the conditional information for the resnet layer.''' with tf.variable_scope(nn.get_name('conditional_weights', counters)): hw = tf.get_variable( 'hw', shape=[h.get_shape()[-1], 2 * filter_channels], dtype=tf.float32, initializer=tf.random_normal_initializer(0, 0.05), trainable=True) conditional_info = tf.reshape( tf.matmul(h, hw), [batch, 1, 1, 2 * filter_channels]) return conditional_info def _gated_nonlinearity(out): a, b = tf.split(out, 2, 3) return a * tf.nn.sigmoid(b) @add_arg_scope def gated_resnet_vstack_only(row_input, network_info, row, cache_every, run_every, extra_row_input=None, h=None, nonlinearity=None, counters={}): '''Performs gated resnet computations for the vertical stack.''' li = _extract_layer_info(network_info, row_input, nonlinearity) out = li.nonlinearity(row_input) out = down_shifted_conv2d( out, network_info, stride=1, row=row, cache_every=cache_every, run_every=run_every, nonlinearity=None, counters=counters) if extra_row_input is not None: # For skip connections between downsampling and upsampling layers. out += nn.nin( li.nonlinearity(extra_row_input), li.filter_channels, counters=counters) out = li.nonlinearity(out) network_info = (li.image_size, (li.filter_height, li.filter_width, 2 * li.filter_channels)) out = down_shifted_conv2d( out, network_info, stride=1, row=row, cache_every=cache_every, run_every=run_every, nonlinearity=None, counters=counters) if h is not None: out += _conditional_info(h, li.batch, li.filter_channels, counters) out = row_input + _gated_nonlinearity(out) return out @add_arg_scope def gated_resnet_hstack(pixel_input, v_stack_row_input, network_info, row, col, cache_every, run_every, extra_pixel_input=None, h=None, nonlinearity=None, counters={}): '''Performs gated resnet computations for the horizontal stack.''' li = _extract_layer_info(network_info, pixel_input, nonlinearity) out = li.nonlinearity(pixel_input) out = down_right_shifted_conv2d( out, network_info, row=row, col=col, cache_every=cache_every, run_every=run_every, nonlinearity=None, counters=counters) # Horizontal stack also takes in as input the vertical stack. cache_col = col // cache_every # Compensates for striding in previous layers. v_stack_pixel = v_stack_row_input[:, :, cache_col:(cache_col + 1), :] v_shape = v_stack_pixel.get_shape() v_stack_pixel.set_shape([li.batch, 1, 1, li.input_channels]) if extra_pixel_input is not None: # For skip connections between downsampling and upsampling layers. v_stack_pixel = tf.concat([v_stack_pixel, extra_pixel_input], 3) out += nn.nin( li.nonlinearity(v_stack_pixel), li.filter_channels, counters=counters) out = li.nonlinearity(out) network_info = (li.image_size, (li.filter_height, li.filter_width, 2 * li.filter_channels)) out = down_right_shifted_conv2d( out, network_info, row=row, col=col, cache_every=cache_every, run_every=run_every, nonlinearity=None, counters=counters) if h is not None: out += _conditional_info(h, li.batch, li.filter_channels, counters) out = pixel_input + _gated_nonlinearity(out) return out ================================================ FILE: fast_pixel_cnn_pp/model.py ================================================ from . import fast_nn from . import nn import tensorflow as tf from tensorflow.contrib.framework.python.ops import arg_scope import numpy as np UPDATE_V_STACK = 'update_v_stack' def undo_zeroth_row_bias_when_downshifting(row_output, row): '''The down_shifted_conv2d adds a bias to the row of all zeros. This removes that bias.''' return tf.cond( tf.equal(row, 0), lambda: tf.zeros_like(row_output), lambda: row_output) def undo_zeroth_column_bias_when_rightshifting(pixel_output, col): '''The down_shifted_conv2d adds a bias to the column of all zeros. This removes that bias.''' return tf.cond( tf.equal(col, 0), lambda: tf.zeros_like(pixel_output), lambda: pixel_output) def cache_v_stack_variable(v_stack_variable): '''Caches vertical stack hidden states. This avoids the need to pass the computed vertical stack in the feed_dict, which would involve CPU to GPU transfers.''' cache = tf.Variable( initial_value=np.zeros(v_stack_variable.get_shape().as_list()), name='v_stack_cache', dtype=tf.float32) update_v_stack_cache = cache.assign(v_stack_variable) tf.add_to_collection(UPDATE_V_STACK, update_v_stack_cache) reset_cache = cache.assign(tf.zeros_like(cache)) tf.add_to_collection(fast_nn.RESET_CACHE_COLLECTION, reset_cache) return cache def model_spec(row_input, pixel_input, row, col, image_size, h=None, nr_resnet=5, nr_filters=160, nr_logistic_mix=10, resnet_nonlinearity='concat_elu', seed=None): '''Creates the model. Follows the same model_spec structure as the original PixelCNN++.''' counters = {} with arg_scope( [ fast_nn.down_shifted_conv2d, fast_nn.down_right_shifted_conv2d, fast_nn.down_shifted_deconv2d, fast_nn.down_right_shifted_deconv2d, fast_nn.gated_resnet_vstack_only, fast_nn.gated_resnet_hstack, nn.dense ], counters=counters): # Parse resnet nonlinearity argument. if resnet_nonlinearity == 'concat_elu': resnet_nonlinearity = nn.concat_elu elif resnet_nonlinearity == 'elu': resnet_nonlinearity = tf.nn.elu elif resnet_nonlinearity == 'relu': resnet_nonlinearity = tf.nn.relu else: raise ('resnet nonlinearity ' + resnet_nonlinearity + ' is not supported') with arg_scope( [fast_nn.gated_resnet_vstack_only, fast_nn.gated_resnet_hstack], nonlinearity=resnet_nonlinearity, h=h): u_filter = [2, 3, nr_filters] ul_filter = [2, 2, nr_filters] cache_every, run_every = 1, 1 ## Downsampling pass. # The initial computation to the network. Importantly, it is assumed that the # vertical stack inputs are already downshifted, and the horizontal stack inputs # are already rightshifted. v_stack = [] u_list_input = fast_nn.down_shifted_conv2d( row_input, (image_size, u_filter), stride=1, row=row, cache_every=cache_every, run_every=run_every) u_list = [ undo_zeroth_row_bias_when_downshifting(u_list_input, row) ] v_stack.append(u_list[-1]) downshift_hstack_input = fast_nn.down_shifted_conv2d( row_input, (image_size, [1, 3, nr_filters]), stride=1, row=row, cache_every=cache_every, run_every=run_every) downshift_hstack_input = undo_zeroth_row_bias_when_downshifting( downshift_hstack_input, row) downshift_hstack_input = cache_v_stack_variable( downshift_hstack_input) v_stack.append(downshift_hstack_input) rightshift_hstack_input = fast_nn.down_right_shifted_conv2d( pixel_input, (image_size, [2, 1, nr_filters]), row=row, col=col, cache_every=cache_every, run_every=run_every) rightshift_hstack_input = undo_zeroth_column_bias_when_rightshifting( rightshift_hstack_input, col) ul_list = [ fast_nn.sum_rightshift_downshift(rightshift_hstack_input, downshift_hstack_input, col) ] # Gated resnet layers. image_size = (image_size[0], image_size[1], image_size[2], nr_filters) for rep in range(nr_resnet): u_list.append( fast_nn.gated_resnet_vstack_only( u_list[-1], (image_size, u_filter), row=row, cache_every=cache_every, run_every=run_every, nonlinearity=resnet_nonlinearity)) v_stack.append(u_list[-1]) ul_list.append( fast_nn.gated_resnet_hstack( ul_list[-1], cache_v_stack_variable(u_list[-1]), (image_size, ul_filter), row=row, col=col, cache_every=cache_every, run_every=run_every, nonlinearity=resnet_nonlinearity)) # Downsample. cache_every, run_every = 1, 2 u_list.append( fast_nn.down_shifted_conv2d( u_list[-1], (image_size, u_filter), stride=2, row=row, cache_every=cache_every, run_every=run_every)) v_stack.append(u_list[-1]) ul_list.append( fast_nn.down_right_shifted_conv2d( ul_list[-1], (image_size, ul_filter), row=row, col=col, cache_every=cache_every, run_every=run_every)) cache_every, run_every = 2, 2 image_size = (image_size[0], image_size[1] // 2, image_size[2] // 2, nr_filters) # Gated resnet layers. for rep in range(nr_resnet): u_list.append( fast_nn.gated_resnet_vstack_only( u_list[-1], (image_size, u_filter), row=row, cache_every=cache_every, run_every=run_every, nonlinearity=resnet_nonlinearity)) v_stack.append(u_list[-1]) ul_list.append( fast_nn.gated_resnet_hstack( ul_list[-1], cache_v_stack_variable(u_list[-1]), (image_size, ul_filter), row=row, col=col, cache_every=cache_every, run_every=run_every, nonlinearity=resnet_nonlinearity)) # Downsample. cache_every, run_every = 2, 4 u_list.append( fast_nn.down_shifted_conv2d( u_list[-1], (image_size, u_filter), stride=2, row=row, cache_every=cache_every, run_every=run_every)) v_stack.append(u_list[-1]) ul_list.append( fast_nn.down_right_shifted_conv2d( ul_list[-1], (image_size, ul_filter), row=row, col=col, cache_every=cache_every, run_every=run_every)) cache_every, run_every = 4, 4 image_size = (image_size[0], image_size[1] // 2, image_size[2] // 2, nr_filters) # Gated resnet layers. for rep in range(nr_resnet): u_list.append( fast_nn.gated_resnet_vstack_only( u_list[-1], (image_size, u_filter), row=row, cache_every=cache_every, run_every=run_every, nonlinearity=resnet_nonlinearity)) v_stack.append(u_list[-1]) ul_list.append( fast_nn.gated_resnet_hstack( ul_list[-1], cache_v_stack_variable(u_list[-1]), (image_size, ul_filter), row=row, col=col, cache_every=cache_every, run_every=run_every, nonlinearity=resnet_nonlinearity)) # Upsampling pass. u = u_list.pop() ul = ul_list.pop() for rep in range(nr_resnet): u = fast_nn.gated_resnet_vstack_only( u, (image_size, u_filter), extra_row_input=u_list.pop(), row=row, cache_every=cache_every, run_every=run_every, nonlinearity=resnet_nonlinearity) v_stack.append(u) ul = fast_nn.gated_resnet_hstack( ul, cache_v_stack_variable(u), (image_size, ul_filter), extra_pixel_input=ul_list.pop(), row=row, col=col, cache_every=cache_every, run_every=run_every, nonlinearity=resnet_nonlinearity) # Upsample. cache_every, run_every = 4, 2 u = fast_nn.down_shifted_deconv2d( u, (image_size, u_filter), stride=2, row=row, cache_every=cache_every, run_every=run_every) v_stack.append(u) ul = fast_nn.down_right_shifted_deconv2d( ul, (image_size, ul_filter), row=row, col=col, cache_every=cache_every, run_every=run_every) cache_every, run_every = 2, 2 image_size = (image_size[0], image_size[1] * 2, image_size[2] * 2, nr_filters) # Gated resnet layers. for rep in range(nr_resnet + 1): u = fast_nn.gated_resnet_vstack_only( u, (image_size, u_filter), extra_row_input=u_list.pop(), row=row, cache_every=cache_every, run_every=run_every, nonlinearity=resnet_nonlinearity) v_stack.append(u) ul = fast_nn.gated_resnet_hstack( ul, cache_v_stack_variable(u), (image_size, ul_filter), extra_pixel_input=ul_list.pop(), row=row, col=col, cache_every=cache_every, run_every=run_every, nonlinearity=resnet_nonlinearity) # Upsample. cache_every, run_every = 2, 1 u = fast_nn.down_shifted_deconv2d( u, (image_size, u_filter), stride=2, row=row, cache_every=cache_every, run_every=run_every) v_stack.append(u) ul = fast_nn.down_right_shifted_deconv2d( ul, (image_size, ul_filter), row=row, col=col, cache_every=cache_every, run_every=run_every) cache_every, run_every = 1, 1 image_size = (image_size[0], image_size[1] * 2, image_size[2] * 2, nr_filters) # Gated resnet layers. for rep in range(nr_resnet + 1): u = fast_nn.gated_resnet_vstack_only( u, (image_size, u_filter), extra_row_input=u_list.pop(), row=row, cache_every=cache_every, run_every=run_every, nonlinearity=resnet_nonlinearity) v_stack.append(u) ul = fast_nn.gated_resnet_hstack( ul, cache_v_stack_variable(u), (image_size, ul_filter), extra_pixel_input=ul_list.pop(), row=row, col=col, cache_every=cache_every, run_every=run_every, nonlinearity=resnet_nonlinearity) assert len(u_list) == 0 assert len(ul_list) == 0 x_out = nn.nin(tf.nn.elu(ul), 10 * nr_logistic_mix) sample = nn.sample_from_discretized_mix_logistic( x_out, nr_logistic_mix, seed=seed) cache_v_stack = tf.group(*tf.get_collection(UPDATE_V_STACK)) return sample, x_out, cache_v_stack ================================================ FILE: fast_pixel_cnn_pp/nn.py ================================================ """ A mostly copied but slightly modified version of OpenAI's pixel_cnn_pp/nn.py """ import numpy as np import tensorflow as tf from tensorflow.contrib.framework.python.ops import add_arg_scope def int_shape(x): return list(map(int, x.get_shape())) def concat_elu(x): """ like concatenated ReLU (http://arxiv.org/abs/1603.05201), but then with ELU """ x_shape = x.get_shape().as_list() axis = len(x_shape) - 1 out = tf.nn.elu(tf.concat([x, -x], axis)) out.set_shape(x_shape[:-1] + [x_shape[-1] * 2]) return out def log_sum_exp(x): """ numerically stable log_sum_exp implementation that prevents overflow """ axis = len(x.get_shape())-1 m = tf.reduce_max(x, axis) m2 = tf.reduce_max(x, axis, keep_dims=True) return m + tf.log(tf.reduce_sum(tf.exp(x-m2), axis)) def log_prob_from_logits(x): """ numerically stable log_softmax implementation that prevents overflow """ axis = len(x.get_shape())-1 m = tf.reduce_max(x, axis, keep_dims=True) return x - m - tf.log(tf.reduce_sum(tf.exp(x-m), axis, keep_dims=True)) def discretized_mix_logistic_loss(x,l,sum_all=True): """ log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval """ xs = int_shape(x) # true image (i.e. labels) to regress to, e.g. (B,32,32,3) ls = int_shape(l) # predicted distribution, e.g. (B,32,32,100) nr_mix = int(ls[-1] / 10) # here and below: unpacking the params of the mixture of logistics logit_probs = l[:,:,:,:nr_mix] l = tf.reshape(l[:,:,:,nr_mix:], xs + [nr_mix*3]) means = l[:,:,:,:,:nr_mix] log_scales = tf.maximum(l[:,:,:,:,nr_mix:2*nr_mix], -7.) coeffs = tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix]) x = tf.reshape(x, xs + [1]) + tf.zeros(xs + [nr_mix]) # here and below: getting the means and adjusting them based on preceding sub-pixels m2 = tf.reshape(means[:,:,:,1,:] + coeffs[:, :, :, 0, :] * x[:, :, :, 0, :], [xs[0],xs[1],xs[2],1,nr_mix]) m3 = tf.reshape(means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] + coeffs[:, :, :, 2, :] * x[:, :, :, 1, :], [xs[0],xs[1],xs[2],1,nr_mix]) means = tf.concat(3,[tf.reshape(means[:,:,:,0,:], [xs[0],xs[1],xs[2],1,nr_mix]), m2, m3]) centered_x = x - means inv_stdv = tf.exp(-log_scales) plus_in = inv_stdv * (centered_x + 1./255.) cdf_plus = tf.nn.sigmoid(plus_in) min_in = inv_stdv * (centered_x - 1./255.) cdf_min = tf.nn.sigmoid(min_in) log_cdf_plus = plus_in - tf.nn.softplus(plus_in) # log probability for edge case of 0 (before scaling) log_one_minus_cdf_min = -tf.nn.softplus(min_in) # log probability for edge case of 255 (before scaling) cdf_delta = cdf_plus - cdf_min # probability for all other cases mid_in = inv_stdv * centered_x log_pdf_mid = mid_in - log_scales - 2.*tf.nn.softplus(mid_in) # log probability in the center of the bin, to be used in extreme cases (not actually used in our code) # now select the right output: left edge case, right edge case, normal case, extremely low prob case (doesn't actually happen for us) # this is what we are really doing, but using the robust version below for extreme cases in other applications and to avoid NaN issue with tf.select() # log_probs = tf.select(x < -0.999, log_cdf_plus, tf.select(x > 0.999, log_one_minus_cdf_min, tf.log(cdf_delta))) # robust version, that still works if probabilities are below 1e-5 (which never happens in our code) # tensorflow backpropagates through tf.select() by multiplying with zero instead of selecting: this requires use to use some ugly tricks to avoid potential NaNs # the 1e-12 in tf.maximum(cdf_delta, 1e-12) is never actually used as output, it's purely there to get around the tf.select() gradient issue # if the probability on a sub-pixel is below 1e-5, we use an approximation based on the assumption that the log-density is constant in the bin of the observed sub-pixel value log_probs = tf.select(x < -0.999, log_cdf_plus, tf.select(x > 0.999, log_one_minus_cdf_min, tf.select(cdf_delta > 1e-5, tf.log(tf.maximum(cdf_delta, 1e-12)), log_pdf_mid - np.log(127.5)))) log_probs = tf.reduce_sum(log_probs,3) + log_prob_from_logits(logit_probs) if sum_all: return -tf.reduce_sum(log_sum_exp(log_probs)) else: return -tf.reduce_sum(log_sum_exp(log_probs),[1,2]) def sample_from_discretized_mix_logistic(l,nr_mix,seed=None): ls = int_shape(l) xs = ls[:-1] + [3] # unpack parameters logit_probs = l[:, :, :, :nr_mix] l = tf.reshape(l[:, :, :, nr_mix:], xs + [nr_mix*3]) # sample mixture indicator from softmax sel = tf.one_hot(tf.argmax(logit_probs - tf.log(-tf.log(tf.random_uniform(logit_probs.get_shape(), minval=1e-5, maxval=1. - 1e-5, seed=seed))), 3), depth=nr_mix, dtype=tf.float32) sel = tf.reshape(sel, xs[:-1] + [1,nr_mix]) # select logistic parameters means = tf.reduce_sum(l[:,:,:,:,:nr_mix]*sel,4) log_scales = tf.maximum(tf.reduce_sum(l[:,:,:,:,nr_mix:2*nr_mix]*sel,4), -7.) coeffs = tf.reduce_sum(tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix])*sel,4) # sample from logistic & clip to interval # we don't actually round to the nearest 8bit value when sampling u = tf.random_uniform(means.get_shape(), minval=1e-5, maxval=1. - 1e-5, seed=(seed + 1 if seed is not None else None)) x = means + tf.exp(log_scales)*(tf.log(u) - tf.log(1. - u)) x0 = tf.minimum(tf.maximum(x[:,:,:,0], -1.), 1.) x1 = tf.minimum(tf.maximum(x[:,:,:,1] + coeffs[:,:,:,0]*x0, -1.), 1.) x2 = tf.minimum(tf.maximum(x[:,:,:,2] + coeffs[:,:,:,1]*x0 + coeffs[:,:,:,2]*x1, -1.), 1.) return tf.concat([tf.reshape(x0,xs[:-1]+[1]), tf.reshape(x1,xs[:-1]+[1]), tf.reshape(x2,xs[:-1]+[1])], 3) def get_var_maybe_avg(var_name, ema, **kwargs): ''' utility for retrieving polyak averaged params ''' v = tf.get_variable(var_name, **kwargs) if ema is not None: v = ema.average(v) return v def get_vars_maybe_avg(var_names, ema, **kwargs): ''' utility for retrieving polyak averaged params ''' vars = [] for vn in var_names: vars.append(get_var_maybe_avg(vn, ema, **kwargs)) return vars def adam_updates(params, cost_or_grads, lr=0.001, mom1=0.9, mom2=0.999): ''' Adam optimizer ''' updates = [] if type(cost_or_grads) is not list: grads = tf.gradients(cost_or_grads, params) else: grads = cost_or_grads t = tf.Variable(1., 'adam_t') for p, g in zip(params, grads): mg = tf.Variable(tf.zeros(p.get_shape()), p.name + '_adam_mg') if mom1>0: v = tf.Variable(tf.zeros(p.get_shape()), p.name + '_adam_v') v_t = mom1*v + (1. - mom1)*g v_hat = v_t / (1. - tf.pow(mom1,t)) updates.append(v.assign(v_t)) else: v_hat = g mg_t = mom2*mg + (1. - mom2)*tf.square(g) mg_hat = mg_t / (1. - tf.pow(mom2,t)) g_t = v_hat / tf.sqrt(mg_hat + 1e-8) p_t = p - lr * g_t updates.append(mg.assign(mg_t)) updates.append(p.assign(p_t)) updates.append(t.assign_add(1)) return tf.group(*updates) def get_name(layer_name, counters): ''' utlity for keeping track of layer names ''' if not layer_name in counters: counters[layer_name] = 0 name = layer_name + '_' + str(counters[layer_name]) counters[layer_name] += 1 return name @add_arg_scope def dense(x, num_units, nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs): ''' fully connected layer ''' name = get_name('dense', counters) with tf.variable_scope(name): if init: # data based initialization of parameters V = tf.get_variable('V', [int(x.get_shape()[1]),num_units], tf.float32, tf.random_normal_initializer(0, 0.05), trainable=True) V_norm = tf.nn.l2_normalize(V.initialized_value(), [0]) x_init = tf.matmul(x, V_norm) m_init, v_init = tf.nn.moments(x_init, [0]) scale_init = init_scale/tf.sqrt(v_init + 1e-10) g = tf.get_variable('g', dtype=tf.float32, initializer=scale_init, trainable=True) b = tf.get_variable('b', dtype=tf.float32, initializer=-m_init*scale_init, trainable=True) x_init = tf.reshape(scale_init,[1,num_units])*(x_init-tf.reshape(m_init,[1,num_units])) if nonlinearity is not None: x_init = nonlinearity(x_init) return x_init else: #V,g,b = get_vars_maybe_avg(['V','g','b'], ema) V = tf.get_variable('V', [int(x.get_shape()[1]),num_units], tf.float32) g = tf.get_variable('g', [num_units], tf.float32) b = tf.get_variable('b', [num_units], tf.float32) if ema is not None: V, g, b = ema.average(V), ema.average(g), ema.average(b) #tf.assert_variables_initialized([V,g,b]) # use weight normalization (Salimans & Kingma, 2016) x = tf.matmul(x, V) scaler = g/tf.sqrt(tf.reduce_sum(tf.square(V),[0])) x = tf.reshape(scaler,[1,num_units])*x + tf.reshape(b,[1,num_units]) # apply nonlinearity if nonlinearity is not None: x = nonlinearity(x) return x @add_arg_scope def conv2d(x, num_filters, filter_size=[3,3], stride=[1,1], pad='SAME', nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs): ''' convolutional layer ''' name = get_name('conv2d', counters) with tf.variable_scope(name): if init: # data based initialization of parameters V = tf.get_variable('V', filter_size+[int(x.get_shape()[-1]),num_filters], tf.float32, tf.random_normal_initializer(0, 0.05), trainable=True) V_norm = tf.nn.l2_normalize(V.initialized_value(), [0,1,2]) x_init = tf.nn.conv2d(x, V_norm, [1]+stride+[1], pad) m_init, v_init = tf.nn.moments(x_init, [0,1,2]) scale_init = init_scale/tf.sqrt(v_init + 1e-8) g = tf.get_variable('g', dtype=tf.float32, initializer=scale_init, trainable=True) b = tf.get_variable('b', dtype=tf.float32, initializer=-m_init*scale_init, trainable=True) x_init = tf.reshape(scale_init,[1,1,1,num_filters])*(x_init-tf.reshape(m_init,[1,1,1,num_filters])) if nonlinearity is not None: x_init = nonlinearity(x_init) return x_init else: V, g, b = get_vars_maybe_avg(['V', 'g', 'b'], ema) tf.assert_variables_initialized([V,g,b]) # use weight normalization (Salimans & Kingma, 2016) W = tf.reshape(g,[1,1,1,num_filters])*tf.nn.l2_normalize(V,[0,1,2]) # calculate convolutional layer output x = tf.nn.bias_add(tf.nn.conv2d(x, W, [1]+stride+[1], pad), b) # apply nonlinearity if nonlinearity is not None: x = nonlinearity(x) return x @add_arg_scope def deconv2d(x, num_filters, filter_size=[3,3], stride=[1,1], pad='SAME', nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs): ''' transposed convolutional layer ''' name = get_name('deconv2d', counters) xs = int_shape(x) if pad=='SAME': target_shape = [xs[0], xs[1]*stride[0], xs[2]*stride[1], num_filters] else: target_shape = [xs[0], xs[1]*stride[0] + filter_size[0]-1, xs[2]*stride[1] + filter_size[1]-1, num_filters] with tf.variable_scope(name): if init: # data based initialization of parameters V = tf.get_variable('V', filter_size+[num_filters,int(x.get_shape()[-1])], tf.float32, tf.random_normal_initializer(0, 0.05), trainable=True) V_norm = tf.nn.l2_normalize(V.initialized_value(), [0,1,3]) x_init = tf.nn.conv2d_transpose(x, V_norm, target_shape, [1]+stride+[1], padding=pad) m_init, v_init = tf.nn.moments(x_init, [0,1,2]) scale_init = init_scale/tf.sqrt(v_init + 1e-8) g = tf.get_variable('g', dtype=tf.float32, initializer=scale_init, trainable=True) b = tf.get_variable('b', dtype=tf.float32, initializer=-m_init*scale_init, trainable=True) x_init = tf.reshape(scale_init,[1,1,1,num_filters])*(x_init-tf.reshape(m_init,[1,1,1,num_filters])) if nonlinearity is not None: x_init = nonlinearity(x_init) return x_init else: V, g, b = get_vars_maybe_avg(['V', 'g', 'b'], ema) tf.assert_variables_initialized([V,g,b]) # use weight normalization (Salimans & Kingma, 2016) W = tf.reshape(g,[1,1,num_filters,1])*tf.nn.l2_normalize(V,[0,1,3]) # calculate convolutional layer output x = tf.nn.conv2d_transpose(x, W, target_shape, [1]+stride+[1], padding=pad) x = tf.nn.bias_add(x, b) # apply nonlinearity if nonlinearity is not None: x = nonlinearity(x) return x @add_arg_scope def nin(x, num_units, **kwargs): """ a network in network layer (1x1 CONV) """ s = int_shape(x) x = tf.reshape(x, [np.prod(s[:-1]),s[-1]]) x = dense(x, num_units, **kwargs) return tf.reshape(x, s[:-1]+[num_units]) ''' meta-layer consisting of multiple base layers ''' @add_arg_scope def gated_resnet(x, a=None, h=None, nonlinearity=concat_elu, conv=conv2d, init=False, counters={}, ema=None, dropout_p=0., **kwargs): xs = int_shape(x) num_filters = xs[-1] c1 = conv(nonlinearity(x), num_filters, counters=counters, init=init) if a is not None: # add short-cut connection if auxiliary input 'a' is given c1 += nin(nonlinearity(a), num_filters, counters=counters, init=init) c1 = nonlinearity(c1) if dropout_p > 0: c1 = tf.nn.dropout(c1, keep_prob=1. - dropout_p) c2 = conv(c1, num_filters * 2, init_scale=0.1, counters=counters, init=init) # add projection of h vector if included: conditional generation if h is not None: with tf.variable_scope(get_name('conditional_weights', counters)): hw = get_var_maybe_avg('hw', ema, shape=[int_shape(h)[-1], 2 * num_filters], dtype=tf.float32, initializer=tf.random_normal_initializer(0, 0.05), trainable=True) if init: hw = hw.initialized_value() c2 += tf.reshape(tf.matmul(h, hw), [xs[0], 1, 1, 2 * num_filters]) a, b = tf.split(c2, 2, 3) c3 = a * tf.nn.sigmoid(b) return x + c3 ''' utilities for shifting the image around, efficient alternative to masking convolutions ''' def down_shift(x): xs = int_shape(x) return tf.concat(1,[tf.zeros([xs[0],1,xs[2],xs[3]]), x[:,:xs[1]-1,:,:]]) def right_shift(x): xs = int_shape(x) return tf.concat(2,[tf.zeros([xs[0],xs[1],1,xs[3]]), x[:,:,:xs[2]-1,:]]) @add_arg_scope def down_shifted_conv2d(x, num_filters, filter_size=[2,3], stride=[1,1], **kwargs): x = tf.pad(x, [[0,0],[filter_size[0]-1,0], [int((filter_size[1]-1)/2),int((filter_size[1]-1)/2)],[0,0]]) return conv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs) @add_arg_scope def down_shifted_deconv2d(x, num_filters, filter_size=[2,3], stride=[1,1], **kwargs): x = deconv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs) xs = int_shape(x) return x[:,:(xs[1]-filter_size[0]+1),int((filter_size[1]-1)/2):(xs[2]-int((filter_size[1]-1)/2)),:] @add_arg_scope def down_right_shifted_conv2d(x, num_filters, filter_size=[2,2], stride=[1,1], **kwargs): x = tf.pad(x, [[0,0],[filter_size[0]-1, 0], [filter_size[1]-1, 0],[0,0]]) return conv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs) @add_arg_scope def down_right_shifted_deconv2d(x, num_filters, filter_size=[2,2], stride=[1,1], **kwargs): x = deconv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs) xs = int_shape(x) return x[:,:(xs[1]-filter_size[0]+1):,:(xs[2]-filter_size[1]+1),:] ================================================ FILE: fast_pixel_cnn_pp/plotting.py ================================================ ''' Copied from OpenAI's pixel_cnn_pp/plotting.py ''' import numpy as np import matplotlib matplotlib.use('Agg') from matplotlib import pyplot as plt # Plot image examples. def plot_img(img, title=None): plt.figure() plt.imshow(img, interpolation='nearest') if title is not None: plt.title(title) plt.axis('off') plt.tight_layout() plt.show(block=False) def img_stretch(img): img = img.astype(float) img -= np.min(img) img /= np.max(img)+1e-12 return img def img_tile(imgs, aspect_ratio=1.0, tile_shape=None, border=1, border_color=0, stretch=False): ''' Tile images in a grid. If tile_shape is provided only as many images as specified in tile_shape will be included in the output. ''' # Prepare images if stretch: imgs = img_stretch(imgs) imgs = np.array(imgs) if imgs.ndim != 3 and imgs.ndim != 4: raise ValueError('imgs has wrong number of dimensions.') n_imgs = imgs.shape[0] # Grid shape img_shape = np.array(imgs.shape[1:3]) if tile_shape is None: img_aspect_ratio = img_shape[1] / float(img_shape[0]) aspect_ratio *= img_aspect_ratio tile_height = int(np.ceil(np.sqrt(n_imgs * aspect_ratio))) tile_width = int(np.ceil(np.sqrt(n_imgs / aspect_ratio))) grid_shape = np.array((tile_height, tile_width)) else: assert len(tile_shape) == 2 grid_shape = np.array(tile_shape) # Tile image shape tile_img_shape = np.array(imgs.shape[1:]) tile_img_shape[:2] = (img_shape[:2] + border) * grid_shape[:2] - border # Assemble tile image tile_img = np.empty(tile_img_shape) tile_img[:] = border_color for i in range(grid_shape[0]): for j in range(grid_shape[1]): img_idx = j + i*grid_shape[1] if img_idx >= n_imgs: # No more images - stop filling out the grid. break img = imgs[img_idx] yoff = (img_shape[0] + border) * i xoff = (img_shape[1] + border) * j tile_img[yoff:yoff+img_shape[0], xoff:xoff+img_shape[1], ...] = img return tile_img def conv_filter_tile(filters): n_filters, n_channels, height, width = filters.shape tile_shape = None if n_channels == 3: # Interpret 3 color channels as RGB filters = np.transpose(filters, (0, 2, 3, 1)) else: # Organize tile such that each row corresponds to a filter and the # columns are the filter channels tile_shape = (n_channels, n_filters) filters = np.transpose(filters, (1, 0, 2, 3)) filters = np.resize(filters, (n_filters*n_channels, height, width)) filters = img_stretch(filters) return img_tile(filters, tile_shape=tile_shape) def scale_to_unit_interval(ndar, eps=1e-8): """ Scales all values in the ndarray ndar to be between 0 and 1 """ ndar = ndar.copy() ndar -= ndar.min() ndar *= 1.0 / (ndar.max() + eps) return ndar def tile_raster_images(X, img_shape, tile_shape, tile_spacing=(0, 0), scale_rows_to_unit_interval=True, output_pixel_vals=True): """ Transform an array with one flattened image per row, into an array in which images are reshaped and layed out like tiles on a floor. This function is useful for visualizing datasets whose rows are images, and also columns of matrices for transforming those rows (such as the first layer of a neural net). :type X: a 2-D ndarray or a tuple of 4 channels, elements of which can be 2-D ndarrays or None; :param X: a 2-D array in which every row is a flattened image. :type img_shape: tuple; (height, width) :param img_shape: the original shape of each image :type tile_shape: tuple; (rows, cols) :param tile_shape: the number of images to tile (rows, cols) :param output_pixel_vals: if output should be pixel values (i.e. int8 values) or floats :param scale_rows_to_unit_interval: if the values need to be scaled before being plotted to [0,1] or not :returns: array suitable for viewing as an image. (See:`PIL.Image.fromarray`.) :rtype: a 2-d array with same dtype as X. """ assert len(img_shape) == 2 assert len(tile_shape) == 2 assert len(tile_spacing) == 2 # The expression below can be re-written in a more C style as # follows : # # out_shape = [0,0] # out_shape[0] = (img_shape[0] + tile_spacing[0]) * tile_shape[0] - # tile_spacing[0] # out_shape[1] = (img_shape[1] + tile_spacing[1]) * tile_shape[1] - # tile_spacing[1] out_shape = [(ishp + tsp) * tshp - tsp for ishp, tshp, tsp in zip(img_shape, tile_shape, tile_spacing)] if isinstance(X, tuple): assert len(X) == 4 # Create an output numpy ndarray to store the image if output_pixel_vals: out_array = np.zeros((out_shape[0], out_shape[1], 4), dtype='uint8') else: out_array = np.zeros((out_shape[0], out_shape[1], 4), dtype=X.dtype) #colors default to 0, alpha defaults to 1 (opaque) if output_pixel_vals: channel_defaults = [0, 0, 0, 255] else: channel_defaults = [0., 0., 0., 1.] for i in range(4): if X[i] is None: # if channel is None, fill it with zeros of the correct # dtype out_array[:, :, i] = np.zeros(out_shape, dtype='uint8' if output_pixel_vals else out_array.dtype ) + channel_defaults[i] else: # use a recurrent call to compute the channel and store it # in the output out_array[:, :, i] = tile_raster_images(X[i], img_shape, tile_shape, tile_spacing, scale_rows_to_unit_interval, output_pixel_vals) return out_array else: # if we are dealing with only one channel H, W = img_shape Hs, Ws = tile_spacing # generate a matrix to store the output out_array = np.zeros(out_shape, dtype='uint8' if output_pixel_vals else X.dtype) for tile_row in range(tile_shape[0]): for tile_col in range(tile_shape[1]): if tile_row * tile_shape[1] + tile_col < X.shape[0]: if scale_rows_to_unit_interval: # if we should scale values to be between 0 and 1 # do this by calling the `scale_to_unit_interval` # function this_img = scale_to_unit_interval(X[tile_row * tile_shape[1] + tile_col].reshape(img_shape)) else: this_img = X[tile_row * tile_shape[1] + tile_col].reshape(img_shape) # add the slice to the corresponding position in the # output array out_array[ tile_row * (H+Hs): tile_row * (H + Hs) + H, tile_col * (W+Ws): tile_col * (W + Ws) + W ] \ = this_img * (255 if output_pixel_vals else 1) return out_array ================================================ FILE: fast_pixel_cnn_pp/test_components.py ================================================ from . import nn from . import fast_nn import tensorflow as tf import numpy as np import math import unittest from collections import namedtuple Placeholders = namedtuple( 'Placeholders', ['full_input', 'pixel_input', 'row_input', 'row_id', 'col_id']) class FastPixelCNNPPTest(tf.test.TestCase): def test_down_shifted_conv2d_layer1_stride1(self): self._test_down_shifted(num_layers=1) def test_down_shifted_conv2d_layer1_stride1_1by3(self): self._test_down_shifted(num_layers=1, filter_size=[1, 3]) def test_down_shifted_conv2d_layer2_stride1(self): self._test_down_shifted(num_layers=2) def test_down_shifted_conv2d_layer3_stride1(self): self._test_down_shifted(num_layers=3) def test_down_shifted_conv2d_layer2_stride1_2(self): self._test_down_shifted(num_layers=2, strides=[1, 2]) def test_down_shifted_conv2d_layer1_stride2(self): self._test_down_shifted(num_layers=1, strides=[2]) def test_down_shifted_conv2d_layer3_stride1_2_2(self): self._test_down_shifted(num_layers=3, strides=[1, 2, 2]) def test_down_shifted_deconv2d_layer1_stride2(self): self._test_down_shifted(num_layers=1, strides=[2], layers=['deconv']) def test_down_shifted_deconv2d_layer2_stride2(self): self._test_down_shifted( num_layers=2, strides=[2, 2], layers=['deconv', 'deconv']) def test_down_shifted_conv2d_deconv2d_layer2_stride2(self): self._test_down_shifted( num_layers=2, strides=[2, 2], layers=['conv', 'deconv']) def test_down_shifted_conv2d_conv2d_2_deconv_2_layer2(self): self._test_down_shifted( num_layers=3, strides=[1, 2, 2], layers=['conv', 'conv', 'deconv']) def test_down_shifted_conv2d_conv2d_deconv2d_deconv2d_layer4_stride2(self): self._test_down_shifted( num_layers=4, strides=[2, 2, 2, 2], layers=['conv', 'conv', 'deconv', 'deconv']) def test_down_shifted_conv2d_deconv2d_conv2d_deconv2d_layer4_stride2(self): self._test_down_shifted( num_layers=4, strides=[2, 2, 2, 2], layers=['conv', 'deconv', 'conv', 'deconv']) def test_down_right_shifted_conv2d_layer1_stride1_1by3(self): self._test_down_right_shifted( batch_size=1, size=8, num_layers=1, filter_size=[1, 3]) def test_down_right_shifted_conv2d_layer1_stride1(self): self._test_down_right_shifted(batch_size=1, size=8, num_layers=1) def test_down_right_shifted_conv2d_layer2_stride1(self): self._test_down_right_shifted(batch_size=1, size=8, num_layers=2) def test_down_right_shifted_conv2d_layer3_stride1(self): self._test_down_right_shifted(batch_size=1, size=8, num_layers=3) def test_down_right_shifted_conv2d_layer1_stride2(self): self._test_down_right_shifted( batch_size=1, size=8, num_layers=1, strides=[2]) def test_down_right_shifted_conv2d_layer2_stride2(self): self._test_down_right_shifted( batch_size=1, size=8, num_layers=2, strides=[2, 2]) def test_down_right_shifted_conv2d_layer3_stride2(self): self._test_down_right_shifted( batch_size=1, size=16, num_layers=3, strides=[2, 1, 2]) def test_down_right_shifted_deconv2d_layer1_stride2(self): self._test_down_right_shifted( batch_size=1, size=4, num_layers=1, layers=['deconv'], strides=[2]) def test_down_right_shifted_deconv2d_layer2_stride2(self): self._test_down_right_shifted( batch_size=1, size=4, num_layers=2, layers=['deconv', 'deconv'], strides=[2, 2]) def test_down_right_shifted_deconv2d_layer3_stride2(self): self._test_down_right_shifted( batch_size=1, size=4, num_layers=3, layers=['deconv', 'deconv', 'deconv'], strides=[2, 2, 2]) def test_down_right_shifted_conv2d_deconv2d_layer2_stride2(self): self._test_down_right_shifted( num_layers=2, strides=[2, 2], layers=['conv', 'deconv']) def test_down_right_shifted_conv2d_conv2d_2_deconv_2_layer2(self): self._test_down_right_shifted( num_layers=3, strides=[1, 2, 2], layers=['conv', 'conv', 'deconv']) def test_down_right_shifted_conv2d_conv2d_deconv2d_deconv2d_layer4_stride2( self): self._test_down_right_shifted( num_layers=4, strides=[2, 2, 2, 2], layers=['conv', 'conv', 'deconv', 'deconv']) def test_down_right_shifted_conv2d_deconv2d_conv2d_deconv2d_layer4_stride2( self): self._test_down_right_shifted( num_layers=4, strides=[2, 2, 2, 2], layers=['conv', 'deconv', 'conv', 'deconv']) def test_sum_rightshift_downshift(self): self._test_sum_rightshift_downshift(size=32) def test_gated_resnet_vstack_only_basic(self): self._gated_resnet_vstack_only() def test_gated_resnet_vstack_only_use_h(self): self._gated_resnet_vstack_only(use_h=True) def test_gated_resnet_vstack_only_basic_3layers(self): self._gated_resnet_vstack_only(num_layers=3) def test_gated_resnet_vstack_only_use_h_3layers(self): self._gated_resnet_vstack_only(use_h=True, num_layers=3) def test_gated_resnet_vstack_only_use_extra_row_input(self): self._gated_resnet_vstack_only(use_extra_row_input=True) def test_gated_resnet_vstack_only_use_extra_row_input_3layers(self): self._gated_resnet_vstack_only(use_extra_row_input=True, num_layers=3) def test_gated_resnet_vstack_only_use_extra_row_input_and_use_h(self): self._gated_resnet_vstack_only(use_extra_row_input=True, use_h=True) def test_gated_resnet_vstack_only_use_extra_row_input__and_use_h_3layers( self): self._gated_resnet_vstack_only( use_extra_row_input=True, use_h=True, num_layers=3) def test_gated_resnet_hstack_basic(self): self._gated_resnet_hstack() def test_gated_resnet_hstack_use_h(self): self._gated_resnet_hstack(use_h=True) def test_gated_resnet_hstack_basic_3layers(self): self._gated_resnet_hstack(num_layers=3) def test_gated_resnet_hstack_use_h_3layers(self): self._gated_resnet_hstack(use_h=True, num_layers=3) def test_gated_resnet_hstack_use_extra_pixel_input(self): self._gated_resnet_hstack(use_extra_pixel_input=True) def test_gated_resnet_hstack_use_extra_pixel_input_3layers(self): self._gated_resnet_hstack(use_extra_pixel_input=True, num_layers=3) def test_gated_resnet_hstack_use_extra_pixel_input_and_use_h(self): self._gated_resnet_hstack(use_extra_pixel_input=True, use_h=True) def test_gated_resnet_hstack_use_extra_pixel_input_and_use_h_3layers(self): self._gated_resnet_hstack( use_extra_pixel_input=True, use_h=True, num_layers=3) def _get_placeholders(self, image_size): '''Creates all placeholders.''' batch_size, size, _, input_channels = image_size full_input = tf.placeholder( tf.float32, [batch_size, size, size, input_channels], name='full_input') pixel_input = tf.placeholder( tf.float32, [batch_size, 1, 1, input_channels], name='pixel_input') row_input = tf.placeholder( tf.float32, [batch_size, 1, size, input_channels], name='row_input') row_id = tf.placeholder(tf.int32, [], name='row_id') col_id = tf.placeholder(tf.int32, [], name='col_id') return Placeholders(full_input, pixel_input, row_input, row_id, col_id) def _setup_test_equal(self, sess, nn_out, full_input, image_size, output_image_size): '''Sets up both _test_*_equals() methods by initializing variables and outputs.''' np.random.seed(2702) x = np.random.randn(*image_size) # nn layers use data dependent initialization. # Data dependent initialization requires a batch of initial data, # which we pass through with a feed dict. sess.run(tf.global_variables_initializer(), {full_input: x}) # Calculate ground truth output. ground_truth_output = sess.run(nn_out, {full_input: x}) # Create variable that holds output. if output_image_size is None: output_image_size = image_size fast_output = np.zeros(output_image_size) # Calculate the increase in output size compared to the input size. # This is useful when only deconv (upsampling) layers are used. side_length = image_size[2] width_ratio = output_image_size[2] // image_size[2] image_increase_factor = max(1, width_ratio) # Reset the cache to be safe. sess.run(fast_nn.reset_cache_op()) return x, ground_truth_output, fast_output, side_length, image_increase_factor def _test_rows_equal(self, sess, fast_nn_out, nn_out, placeholders, image_size, output_image_size=None, run_every=1): '''Tests if vertical stack outputs (one row at a time) of our code and OpenAI code are equal.''' (x, ground_truth_output, fast_output, side_length, image_increase_factor) = self._setup_test_equal( sess, nn_out, placeholders.full_input, image_size, output_image_size) # Generate fast output. for row in range(side_length): x_row_input = x[:, row:(row + 1), :, :] # image_increase_factor is relevant when only deconvs are used. # It just runs each row of input multiple times to populate the upsampled output. for inner_iteration in range(image_increase_factor): row_compensated = image_increase_factor * row + inner_iteration feed_dict = { placeholders.row_input: x_row_input, placeholders.row_id: row_compensated } row_output = sess.run(fast_nn_out, feed_dict) if row_compensated % run_every == 0: # The run_every division is for downsampling, # because the output is smaller than the input. output_row = row_compensated // run_every fast_output[:, output_row:(output_row + 1), :, :] = row_output # Within a tolerance. self.assertTrue(np.allclose(ground_truth_output, fast_output)) # Exact match. self.assertTrue( np.max(np.abs(ground_truth_output - fast_output)) == 0.0) def _test_pixels_equal(self, sess, fast_nn_out, nn_out, placeholders, image_size, output_image_size=None, run_every=1, atol=1e-6): '''Tests if horizontal stack outputs (one pixel at a time) of our code and OpenAI code are equal.''' (x, ground_truth_output, fast_output, side_length, image_increase_factor) = self._setup_test_equal( sess, nn_out, placeholders.full_input, image_size, output_image_size) # Generate fast output. for row in range(side_length): # image_increase_factor is relevant when only deconvs are used. # It just runs each row and column of input multiple times to populate the upsampled output. for inner_row_iteration in range(image_increase_factor): row_compensated = image_increase_factor * row + inner_row_iteration x_row_input = x[:, row:(row + 1), :, :] for col in range(side_length): x_pixel_input = x[:, row:(row + 1), col:(col + 1), :] for inner_col_iteration in range(image_increase_factor): col_compensated = image_increase_factor * col + inner_col_iteration feed_dict = { placeholders.pixel_input: x_pixel_input, placeholders.row_id: row_compensated, placeholders.col_id: col_compensated, placeholders.row_input: x_row_input } pixel_output = sess.run(fast_nn_out, feed_dict) # The run_every division is for downsampling, # because the output is smaller than the input. if row_compensated % run_every == 0 and col_compensated % run_every == 0: output_row = row_compensated // run_every output_col = col_compensated // run_every fast_output[:, output_row:(output_row + 1), output_col:(output_col + 1 ), :] = pixel_output self.assertTrue( np.allclose(ground_truth_output, fast_output, atol=atol)) def _setup_conv_tests(self, batch_size, size, channels, filter_size, strides, layers, num_layers): '''Sets up the conv tests by computing basic layer information.''' image_size = (batch_size, size, size, channels) full_filter_size = filter_size + [channels] if strides is None: strides = [1 for _ in range(num_layers)] assert len(strides) == num_layers if layers is None: layers = ['conv' for _ in range(num_layers)] assert len(layers) == num_layers return image_size, full_filter_size, strides, layers def _compute_conv_fast_nn_out(self, compute_output_func, network_input, image_size, strides, layers): '''Computes cached convolutions, handling downsampling and upsampling.''' batch_size, size, _, nr_filters = image_size num_layers = len(layers) # Computes the final output size taking into account downsampling and upsampling. output_size = size for stride, layer_type in zip(strides, layers): if layer_type == 'conv': output_size = output_size // stride else: output_size = output_size * stride output_image_size = (batch_size, output_size, output_size, nr_filters) # When running only deconvs, the output size gets bigger than the input size. # For generation, each input must be run multiple times to populate the output. image_increase_factor = max(output_size // size, 1) cumulative_stride = max(1, image_increase_factor) fast_nn_out = network_input counters = {} layer_input_size = size # Run the network. for layer_num in range(num_layers): stride = strides[layer_num] layer_type = layers[layer_num] # The run_every of one layer is the cache_every of the next layer. # These increase after downsampling since fewer inputs correspond to an output. # These decrease after downsampling since more inputs correspond to an output. cache_every = cumulative_stride if layer_type == 'conv': run_every = cumulative_stride * stride else: run_every = max(1, cumulative_stride // stride) input_image_size = (batch_size, layer_input_size, layer_input_size, nr_filters) cumulative_stride = run_every fast_nn_out = compute_output_func(fast_nn_out, layer_type, input_image_size, stride, cache_every, run_every, counters) # The size of the input to the next layer. if layer_type == 'conv': layer_input_size = layer_input_size // stride # Downsampling. else: layer_input_size = layer_input_size * stride # Upsampling. return fast_nn_out, output_image_size, run_every def _test_down_shifted(self, batch_size=10, size=16, channels=7, num_layers=1, filter_size=[2, 3], strides=None, layers=None, nonlinearity=tf.sigmoid): '''Tests the down_shifted convolution for the vertical stack.''' def get_conv_function(module, layer_type): '''Returns the matching conv or deconv function.''' if layer_type == 'conv': return module.down_shifted_conv2d elif layer_type == 'deconv': return module.down_shifted_deconv2d else: raise ValueError('Unknown layer_type %s' % layer_type) image_size, full_filter_size, strides, layers = self._setup_conv_tests( batch_size, size, channels, filter_size, strides, layers, num_layers) with self.test_session() as sess: placeholders = self._get_placeholders(image_size) # OpenAI output. def compute_ground_truth(init): nn_out = placeholders.full_input counters = {} for layer_num in range(num_layers): stride = strides[layer_num] layer_func = get_conv_function(nn, layers[layer_num]) nn_out = layer_func( nn_out, num_filters=channels, filter_size=filter_size, stride=[stride, stride], nonlinearity=nonlinearity, counters=counters, init=init) return nn_out compute_ground_truth(init=True) tf.get_variable_scope().reuse_variables() nn_out = compute_ground_truth(init=False) # Our output. def compute_output_func(fast_nn_out, layer_type, input_image_size, stride, cache_every, run_every, counters): layer_func = get_conv_function(fast_nn, layer_type) return layer_func( fast_nn_out, network_info=(input_image_size, full_filter_size), stride=stride, row=placeholders.row_id, cache_every=cache_every, run_every=run_every, counters=counters, nonlinearity=nonlinearity) fast_nn_out, output_image_size, run_every = self._compute_conv_fast_nn_out( compute_output_func, placeholders.row_input, image_size, strides, layers) self._test_rows_equal( sess, fast_nn_out, nn_out, placeholders, image_size, output_image_size=output_image_size, run_every=run_every) def _test_down_right_shifted(self, batch_size=10, size=16, channels=7, num_layers=1, filter_size=[2, 2], strides=None, layers=None, nonlinearity=tf.sigmoid): '''Tests the down_shifted convolution for the vertical stack.''' def get_conv_function(module, layer_type): '''Returns the matching conv or deconv function.''' if layer_type == 'conv': return module.down_right_shifted_conv2d elif layer_type == 'deconv': return module.down_right_shifted_deconv2d else: raise ValueError('Unknown layer_type %s' % layer_type) image_size, full_filter_size, strides, layers = self._setup_conv_tests( batch_size, size, channels, filter_size, strides, layers, num_layers) with self.test_session() as sess: placeholders = self._get_placeholders(image_size) # OpenAI output. def compute_ground_truth(init): nn_out = placeholders.full_input counters = {} for layer_num in range(num_layers): stride = strides[layer_num] layer_func = get_conv_function(nn, layers[layer_num]) nn_out = layer_func( nn_out, num_filters=channels, filter_size=filter_size, stride=[stride, stride], nonlinearity=nonlinearity, counters=counters, init=init) return nn_out compute_ground_truth(init=True) tf.get_variable_scope().reuse_variables() nn_out = compute_ground_truth(init=False) # Our output. def compute_output_func(fast_nn_out, layer_type, input_image_size, stride, cache_every, run_every, counters): layer_func = get_conv_function(fast_nn, layer_type) return layer_func( fast_nn_out, network_info=(input_image_size, full_filter_size), row=placeholders.row_id, col=placeholders.col_id, cache_every=cache_every, run_every=run_every, counters=counters, nonlinearity=nonlinearity) fast_nn_out, output_image_size, run_every = self._compute_conv_fast_nn_out( compute_output_func, placeholders.pixel_input, image_size, strides, layers) self._test_pixels_equal( sess, fast_nn_out, nn_out, placeholders, image_size, output_image_size=output_image_size, run_every=run_every) def _gated_resnet_vstack_only(self, batch_size=10, size=16, channels=7, num_layers=1, filter_size=[2, 3], use_h=False, use_extra_row_input=False, nonlinearity=tf.sigmoid): '''Tests the gated resnet layers for the vertical stack.''' image_size = (batch_size, size, size, channels) full_filter_size = filter_size + [channels] np.random.seed(2702) with self.test_session() as sess: placeholders = self._get_placeholders(image_size) # Conditional information and skip connections. h, a = None, None if use_h: h = tf.constant( np.random.randn(batch_size, 20), dtype=tf.float32) if use_extra_row_input: a = placeholders.full_input # OpenAI output. def compute_ground_truth(init): counters = {} nn_out = placeholders.full_input for _ in range(num_layers): nn_out = nn.gated_resnet( nn_out, a=a, h=h, conv=nn.down_shifted_conv2d, nonlinearity=nonlinearity, counters=counters, init=init) return nn_out compute_ground_truth(init=True) tf.get_variable_scope().reuse_variables() nn_out = compute_ground_truth(init=False) # Our output. counters = {} fast_nn_out = placeholders.row_input if use_extra_row_input: a = placeholders.row_input for _ in range(num_layers): fast_nn_out = fast_nn.gated_resnet_vstack_only( fast_nn_out, (image_size, full_filter_size), placeholders.row_id, extra_row_input=a, h=h, cache_every=1, run_every=1, nonlinearity=nonlinearity, counters=counters) self._test_rows_equal(sess, fast_nn_out, nn_out, placeholders, image_size) def _gated_resnet_hstack(self, batch_size=10, size=16, channels=7, filter_size=[2, 2], num_layers=1, use_h=False, use_extra_pixel_input=False, nonlinearity=tf.sigmoid): '''Tests the gated resnet layers for the horizontal stack.''' image_size = (batch_size, size, size, channels) full_filter_size = filter_size + [channels] with self.test_session() as sess: placeholders = self._get_placeholders(image_size) # Conditional information and skip connections. h, a = None, placeholders.full_input if use_h: h = tf.constant( np.random.randn(batch_size, 20), dtype=tf.float32) if use_extra_pixel_input: a = tf.concat([a, 2 * placeholders.full_input], 3) # OpenAI output. def compute_ground_truth(init): counters = {} nn_out = placeholders.full_input for _ in range(num_layers): nn_out = nn.gated_resnet( nn_out, a=a, h=h, conv=nn.down_right_shifted_conv2d, nonlinearity=nonlinearity, counters=counters, init=init) return nn_out compute_ground_truth(init=True) tf.get_variable_scope().reuse_variables() nn_out = compute_ground_truth(init=False) # Our output. extra_pixel_input = None if use_extra_pixel_input: extra_pixel_input = 2 * placeholders.pixel_input counters = {} fast_nn_out = placeholders.pixel_input for _ in range(num_layers): fast_nn_out = fast_nn.gated_resnet_hstack( fast_nn_out, placeholders.row_input, (image_size, full_filter_size), h=h, row=placeholders.row_id, col=placeholders.col_id, cache_every=1, run_every=1, extra_pixel_input=extra_pixel_input, nonlinearity=nonlinearity, counters=counters) self._test_pixels_equal(sess, fast_nn_out, nn_out, placeholders, image_size) def _test_sum_rightshift_downshift(self, batch_size=10, size=16, channels=7, nonlinearity=tf.sigmoid): '''Tests the sum of the vertical and horizontal stack.''' image_size = (batch_size, size, size, channels) with self.test_session() as sess: placeholders = self._get_placeholders(image_size) # OpenAI output. def compute_ground_truth(init): counters = {} nn_v_stack = nn.down_shifted_conv2d( placeholders.full_input, num_filters=channels, filter_size=[1, 3], stride=[1, 1], nonlinearity=nonlinearity, counters=counters, init=init) nn_h_stack = nn.down_right_shifted_conv2d( placeholders.full_input, num_filters=channels, filter_size=[2, 1], stride=[1, 1], nonlinearity=nonlinearity, counters=counters, init=init) return nn_v_stack + nn_h_stack compute_ground_truth(init=True) tf.get_variable_scope().reuse_variables() nn_out = compute_ground_truth(init=False) # Our output counters, stride, cache_every, run_every = {}, 1, 1, 1 fast_nn_v_stack = fast_nn.down_shifted_conv2d( placeholders.row_input, network_info=(image_size, [1, 3, channels]), stride=stride, row=placeholders.row_id, cache_every=cache_every, run_every=run_every, counters=counters, nonlinearity=nonlinearity) fast_nn_h_stack = fast_nn.down_right_shifted_conv2d( placeholders.pixel_input, network_info=(image_size, [2, 1, channels]), row=placeholders.row_id, col=placeholders.col_id, cache_every=cache_every, run_every=run_every, counters=counters, nonlinearity=nonlinearity) fast_nn_out = fast_nn.sum_rightshift_downshift( fast_nn_h_stack, fast_nn_v_stack, placeholders.col_id) self._test_pixels_equal(sess, fast_nn_out, nn_out, placeholders, image_size) ================================================ FILE: fast_pixel_cnn_pp/test_end_to_end.py ================================================ from . import model from . import fast_nn import tensorflow as tf import numpy as np import os import unittest class FastPixelCNNPPEndToEndTest(tf.test.TestCase): def test_end_to_end(self): with self.test_session() as sess: print('Creating model') image_size = (10, 32, 32, 4) batch_size, image_height, image_width, image_channels = image_size # Create placeholders. row_input = tf.placeholder( tf.float32, [batch_size, 1, image_width, image_channels], name='row_input') pixel_input = tf.placeholder( tf.float32, [batch_size, 1, 1, image_channels], name='pixel_input') row_id = tf.placeholder(tf.int32, [], name='row_id') col_id = tf.placeholder(tf.int32, [], name='col_id') ema = tf.train.ExponentialMovingAverage(0.9995) # Create the model. model_spec = tf.make_template('model', model.model_spec) sample, fast_nn_out, v_stack = model_spec( row_input, pixel_input, row_id, col_id, image_size) # Initialize the caches. cache_variables = [ v for v in tf.global_variables() if 'cache' in v.name ] sess.run(tf.variables_initializer(cache_variables)) # Load the pretrained model print('Restoring variables') vars_to_restore = { k: v for k, v in ema.variables_to_restore().items() if 'cache' not in k } saver = tf.train.Saver(vars_to_restore) ckpt_path = None assert ckpt_path, 'Provide a path to the checkpoint in this file' saver.restore(sess, ckpt_path) # Create the fixed random input. np.random.seed(2702) x = np.random.randint(0, 256, size=(10, 32, 32, 3)) x = np.cast[np.float32]((x - 127.5) / 127.5) x_pad = np.concatenate( (x, np.ones((batch_size, 32, 32, 1))), axis=3) x_downshift = fast_nn.down_shift(x_pad) x_rightshift = fast_nn.right_shift(x_pad) # Holds the output. num_output_features = 10 * 10 output_features = np.zeros( (batch_size, 32, 32, num_output_features)) # Compute all features. print('Computing features') sess.run(fast_nn.reset_cache_op()) for row in range(image_height): x_row_input = x_downshift[:, row:(row + 1), :, :] sess.run(v_stack, {row_input: x_row_input, row_id: row}) for col in range(image_width): x_pixel_input = x_rightshift[:, row:(row + 1), col:(col + 1), :] feed_dict = { row_id: row, col_id: col, pixel_input: x_pixel_input } pixel_features = sess.run(fast_nn_out, feed_dict) output_features[:, row:(row + 1), col:( col + 1), :] = pixel_features ground_truth_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'ground_truth_output.npy') ground_truth_features = np.load(ground_truth_file) total_features = np.prod(output_features[0].shape) for i in range(batch_size): self.assertTrue( np.allclose( output_features[i, :, :, :], ground_truth_features[i, :, :, :], atol=1e-4)) ================================================ FILE: generate.py ================================================ import fast_pixel_cnn_pp.model as model import fast_pixel_cnn_pp.fast_nn as fast_nn import fast_pixel_cnn_pp.plotting as plotting import tensorflow as tf import numpy as np import matplotlib.pyplot as plt import argparse import time import os parser = argparse.ArgumentParser() parser.add_argument( '-b', '--batch_size', type=int, default=16, help='Number of images to generate simultaneously') parser.add_argument( '-i', '--image_size', type=int, default=32, help='Height and width of the image') parser.add_argument( '-s', '--seed', type=int, default=2702, help='Seed for random generation') parser.add_argument( '-c', '--checkpoint', type=str, default='/home/mbz/pixel_cnn_pp/params_cifar.ckpt', help='Location of the pretrained checkpoint') parser.add_argument( '-v', '--save_dir', type=str, default='/tmp', help='Location to save generated images to') args = parser.parse_args() g = tf.Graph() with g.as_default(): print('Creating model') input_channels = 4 # 3 channels for RGB and 1 channel of all ones image_size = (args.batch_size, args.image_size, args.image_size, input_channels) row_input = tf.placeholder( tf.float32, [args.batch_size, 1, args.image_size, input_channels], name='row_input') pixel_input = tf.placeholder( tf.float32, [args.batch_size, 1, 1, input_channels], name='pixel_input') row_id = tf.placeholder(tf.int32, [], name='row_id') col_id = tf.placeholder(tf.int32, [], name='col_id') ema = tf.train.ExponentialMovingAverage(0.9995) model_spec = tf.make_template('model', model.model_spec) sample, fast_nn_out, v_stack = model_spec( row_input, pixel_input, row_id, col_id, image_size, seed=args.seed) all_cache_variables = [ v for v in tf.global_variables() if 'cache' in v.name ] initialize_cache = tf.variables_initializer(all_cache_variables) reset_cache = fast_nn.reset_cache_op() vars_to_restore = { k: v for k, v in ema.variables_to_restore().items() if 'cache' not in k } saver = tf.train.Saver(vars_to_restore) output_images = np.zeros( (args.batch_size, args.image_size, args.image_size, 3)) sess = tf.Session() sess.run(initialize_cache) print('Loading checkpoint %s' % args.checkpoint) saver.restore(sess, args.checkpoint) batch = 0 while True: print('Generating') sess.run(reset_cache) start_time = time.time() for row in range(args.image_size): # Implicit downshift. if row == 0: x_row_input = np.zeros( (args.batch_size, 1, args.image_size, input_channels)) else: x_row_input = output_images[:, (row - 1):row, :, :] x_row_input = np.concatenate( (x_row_input, np.ones( (args.batch_size, 1, args.image_size, 1))), axis=3) sess.run(v_stack, {row_input: x_row_input, row_id: row}) for col in range(args.image_size): # Implicit rightshift. if col == 0: x_pixel_input = np.zeros( (args.batch_size, 1, 1, input_channels)) else: x_pixel_input = output_images[:, row:(row + 1), (col - 1):col, :] x_pixel_input = np.concatenate( (x_pixel_input, np.ones((args.batch_size, 1, 1, 1))), axis=3) feed_dict = { row_id: row, col_id: col, pixel_input: x_pixel_input } pixel_output = sess.run(sample, feed_dict) output_images[:, row:(row + 1), col:(col + 1), :] = pixel_output end_time = time.time() print('Time taken to generate %d images: %.2f seconds' % (args.batch_size, end_time - start_time)) plt.close('all') image_tile = plotting.img_tile( output_images, border_color=1.0, stretch=True) plotting.plot_img(image_tile) plt.savefig(os.path.join(args.save_dir, 'images_%d.png' % batch)) batch += 1 plt.show()