[
  {
    "path": "LICENSE.md",
    "content": "MIT License (MIT)\n\nCopyright (c) 2017 Prajit Ramachandran, Tom Le Paine, Pooya Khorrami, Mohammad Babaeizadeh\n\nPermission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the \"Software\"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE."
  },
  {
    "path": "README.md",
    "content": "\n# Fast PixelCNN++: speedy image generation\n\n*Real time generation of 16 32-by-32 images. Naive generation (left) vs. fast generation (right).*\n\n\n<p align=\"center\">\n  <img src=\"assets/speedup.gif\" width=\"650\" />\n</p>\n\nWe speed up the image generation algorithm of [PixelCNN++](https://github.com/openai/pixel-cnn) by avoiding redundant computation through caching. Naive generation discards computation that can be re-used and performs additional computation that will not be used to generate a particular pixel.  Naive generation can take up to 11 minutes to generate 16 32-by-32 images on a Tesla K40 GPU. By re-using previous computation and only performing the minimum amount of computation required, we achieve up to a 183 times speedup over the naive generation algorithm.\n\n<p align=\"center\">\n  <img src=\"assets/pixelcnn_speedup.png\" width=\"480\" />\n</p>\n\n## How to run\nWe have tested our code with Python 3 and TensorFlow 1.0. You may need to make small changes for other versions of Python or TensorFlow.\n\nInstructions to run:\n* Install [TensorFlow 1.0](https://www.tensorflow.org/install/), Numpy, and Matplotlib\n* Download and unzip [OpenAI's pretrained PixelCNN++ model](http://alpha.openai.com/pxpp.zip). After unzipping, there should be a file called `params_cifar.ckpt`\n* Run the script with `CUDA_VISIBLE_DEVICES=0 python generate.py --checkpoint=/path/to/params_cifar.ckpt --save_dir=/path/to/save/generated/images`\n\nThe script will continually generate images in a loop and write out the images to `--save_dir`. You can exit the script at any time by interrupting with Control-C.\n\n## How it works\n\n### What is PixelCNN++, and why should I use it?\n\nPixelCNN++ is a generative model that uses all previously generated pixels as information to generate the next pixel. 
That is, to generate the 10th pixel in the image, PixelCNN++ will look at pixels 1-9 to model the output distribution of pixel 10: `P(pixel 10 | pixel 1, ..., pixel 9)`. Similarly, pixel 11 will look at pixels 1-10, and this process continues for all pixels. This property makes PixelCNN an *autoregressive* model, where each pixel is modeled by the history of previous pixels. What makes PixelCNN unique is that it uses clever, fast methods to coalesce information from previous pixels, which is crucial for training speed.\n\nPixelCNN was [originally developed by DeepMind](https://arxiv.org/abs/1606.05328) and [improved upon by OpenAI](https://openreview.net/pdf?id=BJrFC6ceg) in PixelCNN++. These models have achieved state-of-the-art results on a variety of image generation benchmarks. They are straightforward to train, have a large capacity to model complex inputs, and are able to generate crisp, attractive images. For example, PixelCNN has [recently been used for superresolution](https://arxiv.org/abs/1702.00783).\n\nOne of the main downsides of autoregressive models compared to other generative models like [Generative Adversarial Networks](https://arxiv.org/abs/1701.00160) and [Variational Autoencoders](https://arxiv.org/abs/1606.05908) is that autoregressive models must generate pixels one at a time, whereas other methods can generate the entire image at once. Our method speeds up the generation process for PixelCNN.\n\n### Speeding up a simple 1D example with dilation\n\nBefore jumping into the details of speeding up PixelCNN, let's focus on a simpler 1D autoregressive model: [Wavenet](https://arxiv.org/abs/1609.03499). The details presented here are similar to those of our other repository, [Fast Wavenet](https://github.com/tomlepaine/fast-wavenet), which you can refer to for more details.\n\n![](assets/wavenet.png)\n\nThe Wavenet graph (on the left) looks like a binary tree. A node is convolved with its `nth` previous neighbor, where `n` is a power of 2. 
Since `n`, the *dilation*, is increasing by a factor of two every layer, the range of nodes that are combined together, the *receptive field*, increases exponentially. On every generation step, information from all nodes in the receptive field (8 in the picture) must be combined. A naive generation algorithm simply repeats the entire tree of computation for every generation step. This is easy to implement, but slow.\n\nYou may have noticed that when generating consecutive outputs, a large portion of the tree is reused. For example, call the current step in the picture `t` and imagine generating the output for `t + 2`. In this case, three of the four orange nodes in the first hidden layer can be reused! It is a waste of time to recompute them.\n\nThis brings us to the core of our method: caching previously computed hidden states. As illustrated in the image on the right, we maintain a cache for each layer which holds previously computed hidden states. The size of the cache is equal to the dilation factor of the hidden layer since the model must look back `n` steps at the hidden layer. The cache acts like a queue: the oldest hidden state is popped from the front of the queue, which is exactly equivalent to the normal dilated convolution. After a hidden state is computed, it must then be pushed into the back of the queue, to be used exactly `n` steps from now. This process repeats itself, giving a fast generation algorithm that avoids the exponential computation of the naive approach.\n\n### Speeding up strided convolutions\n\nThe previous section used dilated convolutions. In this case, node `t` is convolved with node `t - n`, and node `t + 1` is convolved with node `t + 1 - n`. This implies that the number of hidden states in a layer is equal to the number of inputs, making caching straightforward. 
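The queue idea can be sketched end-to-end in a few lines of NumPy. The model below is a hypothetical toy (a stack of 2-tap linear dilated convolutions with made-up random weights), not the real Wavenet architecture, but it shows that popping and pushing per-layer queues reproduces the naive tree computation exactly:

```python
import numpy as np

np.random.seed(0)
dilations = [1, 2, 4]                              # dilation doubles every layer
weights = [np.random.randn(2) for _ in dilations]  # 2-tap filter per layer

def naive_output(x, t):
    '''Recompute the entire tree just to produce the output at time t.'''
    h = x
    for d, w in zip(dilations, weights):
        shifted = np.concatenate([np.zeros(d), h[:-d]])  # h from d steps ago
        h = w[0] * shifted + w[1] * h
    return h[t]

def fast_generate(x):
    '''Generate every output, caching hidden states in per-layer queues.'''
    queues = [[0.0] * d for d in dilations]  # queue length equals the dilation
    outputs = []
    for t in range(len(x)):
        h = x[t]
        for q, w in zip(queues, weights):
            oldest = q.pop(0)  # state from d steps ago, popped off the front
            q.append(h)        # cache the new state for use d steps from now
            h = w[0] * oldest + w[1] * h
        outputs.append(h)
    return np.array(outputs)

x = np.random.randn(8)
fast = fast_generate(x)
naive = np.array([naive_output(x, t) for t in range(len(x))])
assert np.allclose(fast, naive)  # same outputs, without redundant computation
```

The naive version redoes work proportional to the receptive field on every step, while the queue version does a constant amount of work per layer per step.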
However, using strided convolutions makes the problem more difficult because the number of states in a hidden layer is different from the number of inputs.\n\nStrided convolutions are downsampling layers. This means that there are fewer hidden states than inputs. A typical convolution will convolve over a local neighborhood and then slide 1 position over and repeat the procedure. For example, nodes `t - 1` and `t` will be convolved, and then nodes `t` and `t + 1` will be convolved. Striding affects the number of positions that the convolution will slide over. In the previous example, the stride is 1. However, when the stride is greater than 1, the input is downsampled. For example, take the stride to be 2. Nodes `t - 1` and `t` will be convolved, and then nodes `t + 1` and `t + 2` will be convolved, since the convolution has slid 2 positions over. This means every pair of inputs to the layer only produces one output, so the number of hidden states is smaller than the number of inputs.\n\nSimilarly, there are upsampling layers, which are strided transposed convolutions. With a stride of `s`, upsampling layers will produce `s` outputs for every input to the layer. This increases the number of hidden states compared to the number of inputs to the layer. PixelCNN++ uses 2 downsampling layers followed by 2 upsampling layers, each of stride 2, meaning that the number of generated pixels is the same as the number of input pixels (i.e. `D / 2 / 2 * 2 * 2 = D`). A detailed explanation of strides and transposed convolutions can be found [here](https://github.com/vdumoulin/conv_arithmetic) and [here](https://arxiv.org/abs/1603.07285).\n\nBecause of the differing number of hidden states, caches cannot be updated at every timestep. Thus, each cache has an additional property `cache every`, where the cache is only updated every `cache every` steps. Every downsampling layer increases the `cache every` property of the layer by the stride. 
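To make the counting concrete, here is a hypothetical toy example in NumPy: a 2-tap filter over 8 inputs yields 7 hidden states with stride 1 but only 4 with stride 2, and stacked stride-2 downsampling layers compound the effect, which is exactly what `cache every` tracks:

```python
import numpy as np

x = np.arange(8.0)        # 8 inputs
w = np.array([0.5, 0.5])  # illustrative 2-tap averaging filter

# Stride 1: slide one position at a time -> one state per input position.
stride1 = np.array([w @ x[i:i + 2] for i in range(0, 7, 1)])

# Stride 2: slide two positions at a time -> one state per pair of inputs.
stride2 = np.array([w @ x[i:i + 2] for i in range(0, 7, 2)])
print(len(stride1), len(stride2))  # 7 4

# A layer only produces a new state every `cache every` steps, and each
# stride-2 downsampling layer scales `cache every` by its stride.
cache_every = [1]
for stride in [2, 2]:  # the two stride-2 downsampling layers
    cache_every.append(cache_every[-1] * stride)
print(cache_every)  # [1, 2, 4]
```

On odd timesteps the stride-2 layer simply has nothing new to compute, so its cache is left untouched.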
Conversely, every upsampling layer decreases the `cache every` property of the layer by the stride.\n\n![](assets/strided.png)\n\nThe figure above shows an example model with 2 upsampling and 2 downsampling layers each with a stride of 2. Orange nodes are computed in the current timestep, blue nodes are previously cached states, and gray nodes are not involved in the current timestep.\n* At the first timestep `t = 0`, the first input is used to compute and cache all nodes for which there is sufficient information to generate, including the first four outputs.\n* At `t = 1`, there are no nodes that have sufficient information to be computed, but the output for `t = 1` has already been computed at `t = 0`.\n* At `t = 2`, there is one new node that now has sufficient information to be computed, although the output for `t = 2` has also been computed at `t = 0`. \n* The `t = 3` scenario is similar to `t = 1`.\n* At `t = 4`, there is enough information to compute multiple hidden states and generate the next four outputs. This is analogous to the `t = 0` scenario.\n* `t = 5` is analogous to `t = 1`, and this cycle is followed for all future time steps.\n\nIn our code, we also use a property `run every` which is equal to the `cache every` property of the next layer. This allows us to avoid computation if the next layer is simply going to ignore its input.\n\n### Speeding up PixelCNN++\n\nAfter understanding the previous sections, it should seem relatively straightforward to generalize the 1D example to the 2D case. Indeed, our method generalizes with very few changes. The caches for each layer are now 2D, with a height equal to the filter height and a width equal to the image width. After an entire row is generated, the oldest row of the cache is popped and the new row is pushed. Because strided convolutions are used, we use the `cache every` idea detailed in the previous section.\n\nPixelCNN maintains two streams of computation: a vertical stream and a horizontal stream. 
Oversimplifying the details a bit, the vertical stream looks at all pixels above the current pixel while the horizontal stream looks at all pixels immediately left of the current pixel, satisfying the autoregressive property in the 2D case (see the PixelCNN papers for a more precise explanation). The horizontal stream also takes in the vertical stream as another input. In our code, we compute the vertical stream one row at a time, cache it, and use it to compute the horizontal stream (and the generated output) one pixel at a time.\n\nAnd with this, we are able to achieve orders of magnitude speedups for PixelCNN++ generation! Increasing the batch size demonstrates the scalability of our method. While the naive implementation scales linearly with the batch size (because of 100% GPU utilization), our method enjoys superior scaling because of its minimal computational requirements.\n\n<p align=\"center\">\n  <img src=\"assets/pixelcnn_speedup.png\" width=\"480\" />\n</p>\n\n### Beyond PixelCNN++\n\nThe core concepts detailed here will easily generalize to different modalities. For example, speeding up the [Video Pixel Network](https://arxiv.org/abs/1610.00527) for generating videos should be straightforward, and will probably produce even more impressive speedups because of the higher computational demand. 
We look forward to hearing about practical uses of fast generation for convolutional autoregressive models!\n\n## Authors\n\n* [Prajit Ramachandran](https://github.com/PrajitR)\n* [Tom Le Paine](https://github.com/tomlepaine) \n* [Pooya Khorrami](https://github.com/pkhorrami4) \n* [Mohammad Babaeizadeh](https://github.com/mbz)\n\nIf you found this work useful, please cite our [paper](https://arxiv.org/abs/1704.06001).\n\n```\n@article{ramachandran2017fast,\n  title={Fast Generation for Convolutional Autoregressive Models},\n  author={Ramachandran, Prajit and Paine, Tom Le and Khorrami, Pooya and Babaeizadeh, Mohammad and Chang, Shiyu and Zhang, Yang and Hasegawa-Johnson, Mark A and Campbell, Roy H and Huang, Thomas S},\n  journal={arXiv preprint arXiv:1704.06001},\n  year={2017}\n}\n```\n"
  },
  {
    "path": "fast_pixel_cnn_pp/__init__.py",
    "content": ""
  },
  {
    "path": "fast_pixel_cnn_pp/fast_nn.py",
    "content": "from . import nn\n\nimport tensorflow as tf\nfrom tensorflow.contrib.framework.python.ops import add_arg_scope\nimport numpy as np\n\nfrom collections import namedtuple\nimport math\n\nLayerInfo = namedtuple('LayerInfo', [\n    'image_size', 'batch', 'image_height', 'image_width', 'image_channels',\n    'filter_size', 'filter_height', 'filter_width', 'filter_channels',\n    'input_channels', 'nonlinearity'\n])\n\nRESET_CACHE_COLLECTION = 'reset_cache'\n\n\ndef down_shift(image):\n    '''Shift all rows down by one, using zeros as the first row and throwing away the last row.'''\n    all_image_except_last_row = image[:, :-1, :, :]\n    zero_row = np.zeros_like(image[:, :1, :, :])\n    return np.concatenate([zero_row, all_image_except_last_row], axis=1)\n\n\ndef right_shift(image):\n    '''Shift all columns right by one, using zeros as the first column and throwing away the last column.'''\n    all_image_except_last_column = image[:, :, :-1, :]\n    zero_column = np.zeros_like(image[:, :, :1, :])\n    return np.concatenate([zero_column, all_image_except_last_column], axis=2)\n\n\ndef _extract_layer_info(network_info, input_, nonlinearity):\n    '''Utility function to extract information about the current layer.'''\n    image_size, filter_size = network_info\n    batch, image_height, image_width, image_channels = image_size\n    filter_height, filter_width, filter_channels = filter_size\n    input_channels = int(input_.get_shape()[-1])\n    if nonlinearity is None:\n        nonlinearity = tf.identity\n    return LayerInfo(image_size, batch, image_height, image_width,\n                     image_channels, filter_size, filter_height, filter_width,\n                     filter_channels, input_channels, nonlinearity)\n\n\ndef _create_cache(batch, cache_height, cache_width, channels):\n    '''Creates a cache, which is used to avoid redundant computation.'''\n    cache = tf.Variable(\n        initial_value=np.zeros((batch, cache_height, cache_width, 
channels)),\n        dtype=tf.float32,\n        name='cache',\n        trainable=False)\n    # Reset the cache between generations.\n    reset_cache = cache.assign(tf.zeros_like(cache))\n    tf.add_to_collection(RESET_CACHE_COLLECTION, reset_cache)\n    return cache\n\n\ndef reset_cache_op():\n    '''Returns an op to reset all created caches. Used between different generation calls.'''\n    return tf.group(*tf.get_collection(RESET_CACHE_COLLECTION))\n\n\ndef _get_conv_variables(filter_size, input_channels, scope_name, counters):\n    '''Creates and returns variables used for convolution.'''\n    filter_height, filter_width, filter_channels = filter_size\n    with tf.variable_scope(nn.get_name(scope_name, counters)):\n        V = tf.get_variable(\n            'V',\n            [filter_height, filter_width, input_channels, filter_channels],\n            dtype=tf.float32)\n        g = tf.get_variable('g', [filter_channels], dtype=tf.float32)\n        b = tf.get_variable('b', [filter_channels], dtype=tf.float32)\n    return V, g, b\n\n\ndef _get_conv2d_variables(filter_size, input_channels, counters):\n    '''Creates and returns the variables used for a normal 2D convolution.'''\n    V, g, b = _get_conv_variables(filter_size, input_channels, 'conv2d',\n                                  counters)\n    filter_channels = filter_size[-1]\n    W = tf.reshape(g, [1, 1, 1, filter_channels]) * tf.nn.l2_normalize(\n        V, [0, 1, 2])  # Weight normalization.\n    return W, b\n\n\ndef _get_deconv2d_variables(filter_size, input_channels, counters):\n    '''Creates and returns the variables used for a 2D transposed convolution (deconvolution).'''\n    V, g, b = _get_conv_variables(filter_size, input_channels, 'deconv2d',\n                                  counters)\n    filter_channels = filter_size[-1]\n    W = tf.reshape(g, [1, 1, filter_channels, 1]) * tf.nn.l2_normalize(\n        V, [0, 1, 3])  # Weight normalization.\n    return W, b\n\n\ndef _mod_equal_0(row_or_col, 
every):\n    '''Returns a boolean tensor representing (row_or_col % every == 0)'''\n    return tf.equal(tf.mod(row_or_col, every), 0)\n\n\ndef _roll_cache(cache):\n    '''Pop off the oldest row of the cache to make space for the newest row of input.'''\n    batch, _, cache_width, channels = cache.get_shape()\n    without_dropped_row = cache[:, 1:, :, :]\n    zero_row = tf.zeros([batch, 1, cache_width, channels])\n    rolled_cache = tf.concat([without_dropped_row, zero_row], 1)\n    return cache.assign(rolled_cache)\n\n\n@add_arg_scope\ndef down_shifted_conv2d(row_input,\n                        network_info,\n                        stride,\n                        row,\n                        cache_every,\n                        run_every,\n                        nonlinearity=None,\n                        counters={}):\n    '''Performs a convolution for the vertical stack.'''\n    li = _extract_layer_info(network_info, row_input, nonlinearity)\n\n    ## Create cache.\n    cache_height = li.filter_height  # Just large enough to fit the filter.\n    padding = li.filter_width // 2  # Horizontal padding to make VALID convolution maintain the width of input. 
\n    cache_width = li.image_width + 2 * padding  # Cache width is the image width plus padding to the left and right.\n    cache = _create_cache(li.batch, cache_height, cache_width,\n                          li.input_channels)\n\n    ## Update cache.\n    should_cache = _mod_equal_0(row, cache_every)\n    cache_func = lambda: cache[:, -1:, padding:(padding + li.image_width), :].assign(row_input)\n    do_nothing_cache_func = lambda: row_input\n    assign_to_cache = tf.cond(should_cache, cache_func, do_nothing_cache_func)\n\n    ## Compute output.\n    W, b = _get_conv2d_variables(li.filter_size, li.input_channels, counters)\n    with tf.control_dependencies([assign_to_cache]):\n        should_run = _mod_equal_0(row, run_every)\n\n        # Compute output for the entire row.\n        run_func = lambda: li.nonlinearity(tf.nn.conv2d(cache, W, [1, 1, stride, 1], 'VALID') + b)\n\n        output_width = int(math.ceil(li.image_width / float(stride)))\n        do_nothing_run_func = lambda: tf.zeros([li.batch, 1, output_width, li.filter_channels])\n\n        outputs = tf.cond(should_run, run_func, do_nothing_run_func)\n        outputs.set_shape([li.batch, 1, output_width, li.filter_channels])\n\n        # Ensure that roll_cache() is run, and only after computing the outputs.\n        with tf.control_dependencies([outputs]):\n            roll_cache_op = tf.cond(should_cache, lambda: _roll_cache(cache),\n                                    lambda: cache)\n            with tf.control_dependencies([roll_cache_op]):\n                outputs = tf.identity(outputs)\n\n    return outputs\n\n\n@add_arg_scope\ndef down_right_shifted_conv2d(pixel_input,\n                              network_info,\n                              row,\n                              col,\n                              cache_every,\n                              run_every,\n                              nonlinearity=None,\n                              counters={}):\n    '''Performs a convolution for the 
horizontal stack.'''\n    li = _extract_layer_info(network_info, pixel_input, nonlinearity)\n\n    ## Create cache.\n    cache_height = li.filter_height  # Just large enough to fit the filter.\n    left_pad = li.filter_width - 1  # Only need left padding because always convolving to the left.\n    cache_width = li.image_width + left_pad\n    cache = _create_cache(li.batch, cache_height, cache_width,\n                          li.input_channels)\n    cache_col = col // cache_every  # Accounts for downsampling due to stride in previous layers.\n\n    ## Update cache.\n    should_cache = tf.logical_and(\n        _mod_equal_0(row, cache_every), _mod_equal_0(col, cache_every))\n\n    pixel_col = cache_col + left_pad  # Accounts for padding in the cache.\n    cache_func = lambda: cache[:, -1:, pixel_col:(pixel_col + 1), :].assign(pixel_input)\n\n    do_nothing_cache_func = lambda: pixel_input\n\n    assign_to_cache = tf.cond(should_cache, cache_func, do_nothing_cache_func)\n\n    ## Compute output.\n    W, b = _get_conv2d_variables(li.filter_size, li.input_channels, counters)\n    with tf.control_dependencies([assign_to_cache]):\n        should_run = tf.logical_and(\n            _mod_equal_0(row, run_every), _mod_equal_0(col, run_every))\n\n        # Extract the local neighborhood of the current column in the cache to be convolved with the filter.\n        # This is simply a matrix multiply, since the neighborhood is the size of the filter.\n        width_start = cache_col\n        width_end = width_start + li.filter_width\n        cache_neighborhood = cache[:, :, width_start:width_end, :]\n        run_func = lambda: li.nonlinearity(tf.nn.conv2d(cache_neighborhood, W, [1, 1, 1, 1], 'VALID') + b)\n\n        do_nothing_run_func = lambda: tf.zeros([li.batch, 1, 1, li.filter_channels])\n\n        outputs = tf.cond(should_run, run_func, do_nothing_run_func)\n        outputs.set_shape([li.batch, 1, 1, li.filter_channels])\n\n        # Ensure that roll_cache() is run, and only 
after computing the outputs.\n        with tf.control_dependencies([outputs]):\n            # Roll out an entire row of the cache only after generating output for the last column.\n            is_end_of_row = tf.equal(cache_col, li.image_width - 1)\n            should_roll = tf.logical_and(should_cache, is_end_of_row)\n            maybe_roll = tf.cond(should_roll, lambda: _roll_cache(cache),\n                                 lambda: cache)\n            with tf.control_dependencies([maybe_roll]):\n                outputs = tf.identity(outputs)\n\n    return outputs\n\n\ndef _create_deconv_cache(li, stride):\n    '''Creates the cache for the two deconv layers.'''\n    cache_height = li.filter_height  # Just large enough to fit the filter.\n    # The deconv increases the number of outputs `stride` times.\n    # The extra width comes from the tf.nn.conv2d_transpose() operation.\n    cache_width = li.image_width * stride + li.filter_width - 1\n    cache = _create_cache(li.batch, cache_height, cache_width,\n                          li.filter_channels)\n    return cache, cache_height, cache_width\n\n\n@add_arg_scope\ndef down_shifted_deconv2d(row_input,\n                          network_info,\n                          row,\n                          cache_every,\n                          run_every,\n                          stride=2,\n                          nonlinearity=None,\n                          counters={}):\n    '''Performs a transposed convolution for the vertical stack.'''\n    li = _extract_layer_info(network_info, row_input, nonlinearity)\n\n    ## Create cache.\n    cache, cache_height, cache_width = _create_deconv_cache(li, stride)\n\n    ## Update cache.\n    should_cache = _mod_equal_0(row, cache_every)\n\n    W, b = _get_deconv2d_variables(li.filter_size, li.input_channels, counters)\n\n    def cache_func():\n        # Compute deconv output for the entire row.\n        outputs = tf.nn.conv2d_transpose(\n            row_input,\n            
W,\n            output_shape=[\n                li.batch, cache_height, cache_width, li.filter_channels\n            ],\n            strides=[1, stride, stride, 1],\n            padding='VALID')\n        outputs = li.nonlinearity(outputs + b)\n\n        # Store the output in the cache.\n        with tf.control_dependencies([outputs]):\n            # With stride=2, this is simply cache.assign(outputs) since the old rows in the cache\n            # will all have been rolled out. \n            update_cache = cache.assign(cache + outputs)\n        return update_cache\n\n    do_nothing_cache_func = lambda: tf.zeros_like(cache)\n\n    assign_to_cache = tf.cond(should_cache, cache_func, do_nothing_cache_func)\n\n    ## Compute output.\n    with tf.control_dependencies([assign_to_cache]):\n        should_run = _mod_equal_0(row, run_every)\n\n        def run_func():\n            # The cache stores the deconv output, so just return the next (first) row and roll.\n            output = cache[:, 0:1, 1:-1, :]\n            with tf.control_dependencies([output]):\n                with tf.control_dependencies([_roll_cache(cache)]):\n                    output = tf.identity(output)\n            return output\n\n        do_nothing_run_func = lambda: tf.zeros([li.batch, 1, cache_width - 2, li.filter_channels])\n\n        outputs = tf.cond(should_run, run_func, do_nothing_run_func)\n        outputs.set_shape([li.batch, 1, cache_width - 2, li.filter_channels])\n\n    return outputs\n\n\n@add_arg_scope\ndef down_right_shifted_deconv2d(pixel_input,\n                                network_info,\n                                row,\n                                col,\n                                cache_every,\n                                run_every,\n                                stride=2,\n                                nonlinearity=None,\n                                counters={}):\n    '''Performs a transposed convolution for the horizontal stack.'''\n    li = 
_extract_layer_info(network_info, pixel_input, nonlinearity)\n\n    ## Create cache.\n    cache, cache_height, cache_width = _create_deconv_cache(li, stride)\n\n    ## Update cache.\n    should_cache = tf.logical_and(\n        _mod_equal_0(row, cache_every), _mod_equal_0(col, cache_every))\n\n    W, b = _get_deconv2d_variables(li.filter_size, li.input_channels, counters)\n\n    def cache_func():\n        outputs = tf.nn.conv2d_transpose(\n            pixel_input,\n            W,\n            output_shape=[\n                li.batch, li.filter_height, li.filter_width, li.filter_channels\n            ],\n            strides=[1, stride, stride, 1],\n            padding='VALID')\n        outputs = li.nonlinearity(outputs + b)\n\n        # Store the output in the cache.\n        with tf.control_dependencies([outputs]):\n            cache_col = col // cache_every\n            update_cache = cache[:, :, (stride * cache_col):(stride * (\n                cache_col + 1)), :].assign(outputs)\n        return update_cache\n\n    do_nothing_cache_func = lambda: tf.zeros([li.batch, li.filter_height, li.filter_width, li.filter_channels])\n\n    assign_to_cache = tf.cond(should_cache, cache_func, do_nothing_cache_func)\n\n    ## Compute output.\n    with tf.control_dependencies([assign_to_cache]):\n        should_run = tf.logical_and(\n            _mod_equal_0(row, run_every), _mod_equal_0(col, run_every))\n\n        def run_func():\n            output_col = col // run_every\n            output = cache[:, 0:1, output_col:(output_col + 1), :]\n\n            # Only roll after the end of the row has been reached.\n            with tf.control_dependencies([output]):\n                is_end_of_row = tf.equal(output_col,\n                                         cache_width - li.filter_width)\n                maybe_roll = tf.cond(is_end_of_row, lambda: _roll_cache(cache),\n                                     lambda: cache)\n                with tf.control_dependencies([maybe_roll]):\n   
                 output = tf.identity(output)\n            return output\n\n        do_nothing_run_func = lambda: tf.zeros([li.batch, 1, 1, li.filter_channels])\n\n        outputs = tf.cond(should_run, run_func, do_nothing_run_func)\n        outputs.set_shape([li.batch, 1, 1, li.filter_channels])\n\n    return outputs\n\n\ndef sum_rightshift_downshift(rightshifted_pixel, downshifted_row, col):\n    '''Sums the vertical and horizontal stack.'''\n    downshifted_pixel = downshifted_row[:, :, col:(col + 1), :]\n    return rightshifted_pixel + downshifted_pixel\n\n\ndef _conditional_info(h, batch, filter_channels, counters):\n    '''Computes the conditional information for the resnet layer.'''\n    with tf.variable_scope(nn.get_name('conditional_weights', counters)):\n        hw = tf.get_variable(\n            'hw',\n            shape=[h.get_shape()[-1], 2 * filter_channels],\n            dtype=tf.float32,\n            initializer=tf.random_normal_initializer(0, 0.05),\n            trainable=True)\n        conditional_info = tf.reshape(\n            tf.matmul(h, hw), [batch, 1, 1, 2 * filter_channels])\n        return conditional_info\n\n\ndef _gated_nonlinearity(out):\n    a, b = tf.split(out, 2, 3)\n    return a * tf.nn.sigmoid(b)\n\n\n@add_arg_scope\ndef gated_resnet_vstack_only(row_input,\n                             network_info,\n                             row,\n                             cache_every,\n                             run_every,\n                             extra_row_input=None,\n                             h=None,\n                             nonlinearity=None,\n                             counters={}):\n    '''Performs gated resnet computations for the vertical stack.'''\n    li = _extract_layer_info(network_info, row_input, nonlinearity)\n\n    out = li.nonlinearity(row_input)\n    out = down_shifted_conv2d(\n        out,\n        network_info,\n        stride=1,\n        row=row,\n        cache_every=cache_every,\n        
run_every=run_every,\n        nonlinearity=None,\n        counters=counters)\n    if extra_row_input is not None:\n        # For skip connections between downsampling and upsampling layers.\n        out += nn.nin(\n            li.nonlinearity(extra_row_input),\n            li.filter_channels,\n            counters=counters)\n\n    out = li.nonlinearity(out)\n    network_info = (li.image_size, (li.filter_height, li.filter_width, 2 *\n                                    li.filter_channels))\n    out = down_shifted_conv2d(\n        out,\n        network_info,\n        stride=1,\n        row=row,\n        cache_every=cache_every,\n        run_every=run_every,\n        nonlinearity=None,\n        counters=counters)\n\n    if h is not None:\n        out += _conditional_info(h, li.batch, li.filter_channels, counters)\n\n    out = row_input + _gated_nonlinearity(out)\n    return out\n\n\n@add_arg_scope\ndef gated_resnet_hstack(pixel_input,\n                        v_stack_row_input,\n                        network_info,\n                        row,\n                        col,\n                        cache_every,\n                        run_every,\n                        extra_pixel_input=None,\n                        h=None,\n                        nonlinearity=None,\n                        counters={}):\n    '''Performs gated resnet computations for the horizontal stack.'''\n    li = _extract_layer_info(network_info, pixel_input, nonlinearity)\n\n    out = li.nonlinearity(pixel_input)\n    out = down_right_shifted_conv2d(\n        out,\n        network_info,\n        row=row,\n        col=col,\n        cache_every=cache_every,\n        run_every=run_every,\n        nonlinearity=None,\n        counters=counters)\n\n    # Horizontal stack also takes in as input the vertical stack.\n    cache_col = col // cache_every  # Compensates for striding in previous layers.\n    v_stack_pixel = v_stack_row_input[:, :, cache_col:(cache_col + 1), :]\n    v_shape = 
v_stack_pixel.get_shape()\n    v_stack_pixel.set_shape([li.batch, 1, 1, li.input_channels])\n\n    if extra_pixel_input is not None:\n        # For skip connections between downsampling and upsampling layers.\n        v_stack_pixel = tf.concat([v_stack_pixel, extra_pixel_input], 3)\n\n    out += nn.nin(\n        li.nonlinearity(v_stack_pixel), li.filter_channels, counters=counters)\n    out = li.nonlinearity(out)\n    network_info = (li.image_size, (li.filter_height, li.filter_width, 2 *\n                                    li.filter_channels))\n    out = down_right_shifted_conv2d(\n        out,\n        network_info,\n        row=row,\n        col=col,\n        cache_every=cache_every,\n        run_every=run_every,\n        nonlinearity=None,\n        counters=counters)\n\n    if h is not None:\n        out += _conditional_info(h, li.batch, li.filter_channels, counters)\n\n    out = pixel_input + _gated_nonlinearity(out)\n    return out\n"
  },
  {
    "path": "fast_pixel_cnn_pp/model.py",
    "content": "from . import fast_nn\nfrom . import nn\n\nimport tensorflow as tf\nfrom tensorflow.contrib.framework.python.ops import arg_scope\nimport numpy as np\n\nUPDATE_V_STACK = 'update_v_stack'\n\n\ndef undo_zeroth_row_bias_when_downshifting(row_output, row):\n    '''The down_shifted_conv2d adds a bias to the row of all zeros. This removes that bias.'''\n    return tf.cond(\n        tf.equal(row, 0), lambda: tf.zeros_like(row_output),\n        lambda: row_output)\n\n\ndef undo_zeroth_column_bias_when_rightshifting(pixel_output, col):\n    '''The down_shifted_conv2d adds a bias to the column of all zeros. This removes that bias.'''\n    return tf.cond(\n        tf.equal(col, 0), lambda: tf.zeros_like(pixel_output),\n        lambda: pixel_output)\n\n\ndef cache_v_stack_variable(v_stack_variable):\n    '''Caches vertical stack hidden states. This avoids the need to pass the computed\n        vertical stack in the feed_dict, which would involve CPU to GPU transfers.'''\n    cache = tf.Variable(\n        initial_value=np.zeros(v_stack_variable.get_shape().as_list()),\n        name='v_stack_cache',\n        dtype=tf.float32)\n    update_v_stack_cache = cache.assign(v_stack_variable)\n    tf.add_to_collection(UPDATE_V_STACK, update_v_stack_cache)\n    reset_cache = cache.assign(tf.zeros_like(cache))\n    tf.add_to_collection(fast_nn.RESET_CACHE_COLLECTION, reset_cache)\n    return cache\n\n\ndef model_spec(row_input,\n               pixel_input,\n               row,\n               col,\n               image_size,\n               h=None,\n               nr_resnet=5,\n               nr_filters=160,\n               nr_logistic_mix=10,\n               resnet_nonlinearity='concat_elu',\n               seed=None):\n    '''Creates the model. 
Follows the same model_spec structure as the original PixelCNN++.'''\n    counters = {}\n    with arg_scope(\n        [\n            fast_nn.down_shifted_conv2d, fast_nn.down_right_shifted_conv2d,\n            fast_nn.down_shifted_deconv2d, fast_nn.down_right_shifted_deconv2d,\n            fast_nn.gated_resnet_vstack_only, fast_nn.gated_resnet_hstack,\n            nn.dense\n        ],\n            counters=counters):\n\n        # Parse resnet nonlinearity argument.\n        if resnet_nonlinearity == 'concat_elu':\n            resnet_nonlinearity = nn.concat_elu\n        elif resnet_nonlinearity == 'elu':\n            resnet_nonlinearity = tf.nn.elu\n        elif resnet_nonlinearity == 'relu':\n            resnet_nonlinearity = tf.nn.relu\n        else:\n            raise ValueError('resnet nonlinearity ' + resnet_nonlinearity +\n                             ' is not supported')\n\n        with arg_scope(\n            [fast_nn.gated_resnet_vstack_only, fast_nn.gated_resnet_hstack],\n                nonlinearity=resnet_nonlinearity,\n                h=h):\n\n            u_filter = [2, 3, nr_filters]\n            ul_filter = [2, 2, nr_filters]\n            cache_every, run_every = 1, 1\n\n            ## Downsampling pass.\n\n            # The initial computation to the network. Importantly, it is assumed that the\n            # vertical stack inputs are already downshifted, and the horizontal stack inputs\n            # are already rightshifted. 
\n            v_stack = []\n            u_list_input = fast_nn.down_shifted_conv2d(\n                row_input, (image_size, u_filter),\n                stride=1,\n                row=row,\n                cache_every=cache_every,\n                run_every=run_every)\n            u_list = [\n                undo_zeroth_row_bias_when_downshifting(u_list_input, row)\n            ]\n            v_stack.append(u_list[-1])\n\n            downshift_hstack_input = fast_nn.down_shifted_conv2d(\n                row_input, (image_size, [1, 3, nr_filters]),\n                stride=1,\n                row=row,\n                cache_every=cache_every,\n                run_every=run_every)\n            downshift_hstack_input = undo_zeroth_row_bias_when_downshifting(\n                downshift_hstack_input, row)\n            downshift_hstack_input = cache_v_stack_variable(\n                downshift_hstack_input)\n            v_stack.append(downshift_hstack_input)\n            rightshift_hstack_input = fast_nn.down_right_shifted_conv2d(\n                pixel_input, (image_size, [2, 1, nr_filters]),\n                row=row,\n                col=col,\n                cache_every=cache_every,\n                run_every=run_every)\n            rightshift_hstack_input = undo_zeroth_column_bias_when_rightshifting(\n                rightshift_hstack_input, col)\n            ul_list = [\n                fast_nn.sum_rightshift_downshift(rightshift_hstack_input,\n                                                 downshift_hstack_input, col)\n            ]\n\n            # Gated resnet layers.\n            image_size = (image_size[0], image_size[1], image_size[2],\n                          nr_filters)\n            for rep in range(nr_resnet):\n                u_list.append(\n                    fast_nn.gated_resnet_vstack_only(\n                        u_list[-1], (image_size, u_filter),\n                        row=row,\n                        cache_every=cache_every,\n                
        run_every=run_every,\n                        nonlinearity=resnet_nonlinearity))\n                v_stack.append(u_list[-1])\n                ul_list.append(\n                    fast_nn.gated_resnet_hstack(\n                        ul_list[-1],\n                        cache_v_stack_variable(u_list[-1]), (image_size,\n                                                             ul_filter),\n                        row=row,\n                        col=col,\n                        cache_every=cache_every,\n                        run_every=run_every,\n                        nonlinearity=resnet_nonlinearity))\n\n            # Downsample.\n            cache_every, run_every = 1, 2\n            u_list.append(\n                fast_nn.down_shifted_conv2d(\n                    u_list[-1], (image_size, u_filter),\n                    stride=2,\n                    row=row,\n                    cache_every=cache_every,\n                    run_every=run_every))\n            v_stack.append(u_list[-1])\n            ul_list.append(\n                fast_nn.down_right_shifted_conv2d(\n                    ul_list[-1], (image_size, ul_filter),\n                    row=row,\n                    col=col,\n                    cache_every=cache_every,\n                    run_every=run_every))\n\n            cache_every, run_every = 2, 2\n            image_size = (image_size[0], image_size[1] // 2,\n                          image_size[2] // 2, nr_filters)\n\n            # Gated resnet layers.\n            for rep in range(nr_resnet):\n                u_list.append(\n                    fast_nn.gated_resnet_vstack_only(\n                        u_list[-1], (image_size, u_filter),\n                        row=row,\n                        cache_every=cache_every,\n                        run_every=run_every,\n                        nonlinearity=resnet_nonlinearity))\n                v_stack.append(u_list[-1])\n                ul_list.append(\n                    
fast_nn.gated_resnet_hstack(\n                        ul_list[-1],\n                        cache_v_stack_variable(u_list[-1]), (image_size,\n                                                             ul_filter),\n                        row=row,\n                        col=col,\n                        cache_every=cache_every,\n                        run_every=run_every,\n                        nonlinearity=resnet_nonlinearity))\n\n            # Downsample.\n            cache_every, run_every = 2, 4\n            u_list.append(\n                fast_nn.down_shifted_conv2d(\n                    u_list[-1], (image_size, u_filter),\n                    stride=2,\n                    row=row,\n                    cache_every=cache_every,\n                    run_every=run_every))\n            v_stack.append(u_list[-1])\n            ul_list.append(\n                fast_nn.down_right_shifted_conv2d(\n                    ul_list[-1], (image_size, ul_filter),\n                    row=row,\n                    col=col,\n                    cache_every=cache_every,\n                    run_every=run_every))\n\n            cache_every, run_every = 4, 4\n            image_size = (image_size[0], image_size[1] // 2,\n                          image_size[2] // 2, nr_filters)\n\n            # Gated resnet layers.\n            for rep in range(nr_resnet):\n                u_list.append(\n                    fast_nn.gated_resnet_vstack_only(\n                        u_list[-1], (image_size, u_filter),\n                        row=row,\n                        cache_every=cache_every,\n                        run_every=run_every,\n                        nonlinearity=resnet_nonlinearity))\n                v_stack.append(u_list[-1])\n                ul_list.append(\n                    fast_nn.gated_resnet_hstack(\n                        ul_list[-1],\n                        cache_v_stack_variable(u_list[-1]), (image_size,\n                                                       
      ul_filter),\n                        row=row,\n                        col=col,\n                        cache_every=cache_every,\n                        run_every=run_every,\n                        nonlinearity=resnet_nonlinearity))\n\n            # Upsampling pass.\n            u = u_list.pop()\n            ul = ul_list.pop()\n            for rep in range(nr_resnet):\n                u = fast_nn.gated_resnet_vstack_only(\n                    u, (image_size, u_filter),\n                    extra_row_input=u_list.pop(),\n                    row=row,\n                    cache_every=cache_every,\n                    run_every=run_every,\n                    nonlinearity=resnet_nonlinearity)\n                v_stack.append(u)\n                ul = fast_nn.gated_resnet_hstack(\n                    ul,\n                    cache_v_stack_variable(u), (image_size, ul_filter),\n                    extra_pixel_input=ul_list.pop(),\n                    row=row,\n                    col=col,\n                    cache_every=cache_every,\n                    run_every=run_every,\n                    nonlinearity=resnet_nonlinearity)\n\n            # Upsample.\n            cache_every, run_every = 4, 2\n            u = fast_nn.down_shifted_deconv2d(\n                u, (image_size, u_filter),\n                stride=2,\n                row=row,\n                cache_every=cache_every,\n                run_every=run_every)\n            v_stack.append(u)\n            ul = fast_nn.down_right_shifted_deconv2d(\n                ul, (image_size, ul_filter),\n                row=row,\n                col=col,\n                cache_every=cache_every,\n                run_every=run_every)\n\n            cache_every, run_every = 2, 2\n            image_size = (image_size[0], image_size[1] * 2, image_size[2] * 2,\n                          nr_filters)\n\n            # Gated resnet layers.\n            for rep in range(nr_resnet + 1):\n                u = 
fast_nn.gated_resnet_vstack_only(\n                    u, (image_size, u_filter),\n                    extra_row_input=u_list.pop(),\n                    row=row,\n                    cache_every=cache_every,\n                    run_every=run_every,\n                    nonlinearity=resnet_nonlinearity)\n                v_stack.append(u)\n                ul = fast_nn.gated_resnet_hstack(\n                    ul,\n                    cache_v_stack_variable(u), (image_size, ul_filter),\n                    extra_pixel_input=ul_list.pop(),\n                    row=row,\n                    col=col,\n                    cache_every=cache_every,\n                    run_every=run_every,\n                    nonlinearity=resnet_nonlinearity)\n\n            # Upsample.    \n            cache_every, run_every = 2, 1\n            u = fast_nn.down_shifted_deconv2d(\n                u, (image_size, u_filter),\n                stride=2,\n                row=row,\n                cache_every=cache_every,\n                run_every=run_every)\n            v_stack.append(u)\n            ul = fast_nn.down_right_shifted_deconv2d(\n                ul, (image_size, ul_filter),\n                row=row,\n                col=col,\n                cache_every=cache_every,\n                run_every=run_every)\n\n            cache_every, run_every = 1, 1\n            image_size = (image_size[0], image_size[1] * 2, image_size[2] * 2,\n                          nr_filters)\n\n            # Gated resnet layers.\n            for rep in range(nr_resnet + 1):\n                u = fast_nn.gated_resnet_vstack_only(\n                    u, (image_size, u_filter),\n                    extra_row_input=u_list.pop(),\n                    row=row,\n                    cache_every=cache_every,\n                    run_every=run_every,\n                    nonlinearity=resnet_nonlinearity)\n                v_stack.append(u)\n                ul = fast_nn.gated_resnet_hstack(\n                    ul,\n   
                 cache_v_stack_variable(u), (image_size, ul_filter),\n                    extra_pixel_input=ul_list.pop(),\n                    row=row,\n                    col=col,\n                    cache_every=cache_every,\n                    run_every=run_every,\n                    nonlinearity=resnet_nonlinearity)\n\n            assert len(u_list) == 0\n            assert len(ul_list) == 0\n\n            x_out = nn.nin(tf.nn.elu(ul), 10 * nr_logistic_mix)\n            sample = nn.sample_from_discretized_mix_logistic(\n                x_out, nr_logistic_mix, seed=seed)\n            cache_v_stack = tf.group(*tf.get_collection(UPDATE_V_STACK))\n\n            return sample, x_out, cache_v_stack\n"
  },
  {
    "path": "fast_pixel_cnn_pp/nn.py",
    "content": "\"\"\"\nA mostly copied but slightly modified version of OpenAI's pixel_cnn_pp/nn.py \n\"\"\"\n\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.contrib.framework.python.ops import add_arg_scope\n\ndef int_shape(x):\n    return list(map(int, x.get_shape()))\n\ndef concat_elu(x):\n    \"\"\" like concatenated ReLU (http://arxiv.org/abs/1603.05201), but then with ELU \"\"\"\n    x_shape = x.get_shape().as_list()\n    axis = len(x_shape) - 1\n    out = tf.nn.elu(tf.concat([x, -x], axis))\n    out.set_shape(x_shape[:-1] + [x_shape[-1] * 2])\n    return out\n\ndef log_sum_exp(x):\n    \"\"\" numerically stable log_sum_exp implementation that prevents overflow \"\"\"\n    axis = len(x.get_shape())-1\n    m = tf.reduce_max(x, axis)\n    m2 = tf.reduce_max(x, axis, keep_dims=True)\n    return m + tf.log(tf.reduce_sum(tf.exp(x-m2), axis))\n\ndef log_prob_from_logits(x):\n    \"\"\" numerically stable log_softmax implementation that prevents overflow \"\"\"\n    axis = len(x.get_shape())-1\n    m = tf.reduce_max(x, axis, keep_dims=True)\n    return x - m - tf.log(tf.reduce_sum(tf.exp(x-m), axis, keep_dims=True))\n\ndef discretized_mix_logistic_loss(x,l,sum_all=True):\n    \"\"\" log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval \"\"\"\n    xs = int_shape(x) # true image (i.e. labels) to regress to, e.g. (B,32,32,3)\n    ls = int_shape(l) # predicted distribution, e.g. 
(B,32,32,100)\n    nr_mix = int(ls[-1] / 10) # here and below: unpacking the params of the mixture of logistics\n    logit_probs = l[:,:,:,:nr_mix]\n    l = tf.reshape(l[:,:,:,nr_mix:], xs + [nr_mix*3])\n    means = l[:,:,:,:,:nr_mix]\n    log_scales = tf.maximum(l[:,:,:,:,nr_mix:2*nr_mix], -7.)\n    coeffs = tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix])\n    x = tf.reshape(x, xs + [1]) + tf.zeros(xs + [nr_mix]) # here and below: getting the means and adjusting them based on preceding sub-pixels\n    m2 = tf.reshape(means[:,:,:,1,:] + coeffs[:, :, :, 0, :] * x[:, :, :, 0, :], [xs[0],xs[1],xs[2],1,nr_mix])\n    m3 = tf.reshape(means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] + coeffs[:, :, :, 2, :] * x[:, :, :, 1, :], [xs[0],xs[1],xs[2],1,nr_mix])\n    means = tf.concat([tf.reshape(means[:,:,:,0,:], [xs[0],xs[1],xs[2],1,nr_mix]), m2, m3], 3)\n    centered_x = x - means\n    inv_stdv = tf.exp(-log_scales)\n    plus_in = inv_stdv * (centered_x + 1./255.)\n    cdf_plus = tf.nn.sigmoid(plus_in)\n    min_in = inv_stdv * (centered_x - 1./255.)\n    cdf_min = tf.nn.sigmoid(min_in)\n    log_cdf_plus = plus_in - tf.nn.softplus(plus_in) # log probability for edge case of 0 (before scaling)\n    log_one_minus_cdf_min = -tf.nn.softplus(min_in) # log probability for edge case of 255 (before scaling)\n    cdf_delta = cdf_plus - cdf_min # probability for all other cases\n    mid_in = inv_stdv * centered_x\n    log_pdf_mid = mid_in - log_scales - 2.*tf.nn.softplus(mid_in) # log probability in the center of the bin, to be used in extreme cases (not actually used in our code)\n\n    # now select the right output: left edge case, right edge case, normal case, extremely low prob case (doesn't actually happen for us)\n\n    # this is what we are really doing, but using the robust version below for extreme cases in other applications and to avoid NaN issue with tf.where()\n    # log_probs = tf.where(x < -0.999, log_cdf_plus, tf.where(x > 0.999, log_one_minus_cdf_min, tf.log(cdf_delta)))\n\n    # robust version, that still works if probabilities are below 1e-5 (which never happens in our code)\n    # tensorflow backpropagates through tf.where() by multiplying with zero instead of selecting: this requires us to use some ugly tricks to avoid potential NaNs\n    # the 1e-12 in tf.maximum(cdf_delta, 1e-12) is never actually used as output, it's purely there to get around the tf.where() gradient issue\n    # if the probability on a sub-pixel is below 1e-5, we use an approximation based on the assumption that the log-density is constant in the bin of the observed sub-pixel value\n    log_probs = tf.where(x < -0.999, log_cdf_plus, tf.where(x > 0.999, log_one_minus_cdf_min, tf.where(cdf_delta > 1e-5, tf.log(tf.maximum(cdf_delta, 1e-12)), log_pdf_mid - np.log(127.5))))\n\n    log_probs = tf.reduce_sum(log_probs,3) + log_prob_from_logits(logit_probs)\n    if sum_all:\n        return -tf.reduce_sum(log_sum_exp(log_probs))\n    else:\n        return -tf.reduce_sum(log_sum_exp(log_probs),[1,2])\n\ndef sample_from_discretized_mix_logistic(l,nr_mix,seed=None):\n    ls = int_shape(l)\n    xs = ls[:-1] + [3]\n    # unpack parameters\n    logit_probs = l[:, :, :, :nr_mix]\n    l = tf.reshape(l[:, :, :, nr_mix:], xs + [nr_mix*3])\n    # sample mixture indicator from softmax\n    sel = tf.one_hot(tf.argmax(logit_probs - tf.log(-tf.log(tf.random_uniform(logit_probs.get_shape(), minval=1e-5, maxval=1. - 1e-5, seed=seed))), 3), depth=nr_mix, dtype=tf.float32)    \n    sel = tf.reshape(sel, xs[:-1] + [1,nr_mix])\n    # select logistic parameters\n    means = tf.reduce_sum(l[:,:,:,:,:nr_mix]*sel,4)\n    log_scales = tf.maximum(tf.reduce_sum(l[:,:,:,:,nr_mix:2*nr_mix]*sel,4), -7.)\n    coeffs = tf.reduce_sum(tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix])*sel,4)\n    # sample from logistic & clip to interval\n    # we don't actually round to the nearest 8bit value when sampling\n    \n    u = tf.random_uniform(means.get_shape(), minval=1e-5, maxval=1. 
- 1e-5, seed=(seed + 1 if seed is not None else None))\n    \n    x = means + tf.exp(log_scales)*(tf.log(u) - tf.log(1. - u))\n    x0 = tf.minimum(tf.maximum(x[:,:,:,0], -1.), 1.)\n    x1 = tf.minimum(tf.maximum(x[:,:,:,1] + coeffs[:,:,:,0]*x0, -1.), 1.)\n    x2 = tf.minimum(tf.maximum(x[:,:,:,2] + coeffs[:,:,:,1]*x0 + coeffs[:,:,:,2]*x1, -1.), 1.)\n    return tf.concat([tf.reshape(x0,xs[:-1]+[1]), tf.reshape(x1,xs[:-1]+[1]), tf.reshape(x2,xs[:-1]+[1])], 3)\n\ndef get_var_maybe_avg(var_name, ema, **kwargs):\n    ''' utility for retrieving polyak averaged params '''\n    v = tf.get_variable(var_name, **kwargs)\n    if ema is not None:\n        v = ema.average(v)\n    return v\n\ndef get_vars_maybe_avg(var_names, ema, **kwargs):\n    ''' utility for retrieving polyak averaged params '''\n    vars = []\n    for vn in var_names:\n        vars.append(get_var_maybe_avg(vn, ema, **kwargs))\n    return vars\n\ndef adam_updates(params, cost_or_grads, lr=0.001, mom1=0.9, mom2=0.999):\n    ''' Adam optimizer '''\n    updates = []\n    if type(cost_or_grads) is not list:\n        grads = tf.gradients(cost_or_grads, params)\n    else:\n        grads = cost_or_grads\n    t = tf.Variable(1., 'adam_t')\n    for p, g in zip(params, grads):\n        mg = tf.Variable(tf.zeros(p.get_shape()), p.name + '_adam_mg')\n        if mom1>0:\n            v = tf.Variable(tf.zeros(p.get_shape()), p.name + '_adam_v')\n            v_t = mom1*v + (1. - mom1)*g\n            v_hat = v_t / (1. - tf.pow(mom1,t))\n            updates.append(v.assign(v_t))\n        else:\n            v_hat = g\n        mg_t = mom2*mg + (1. - mom2)*tf.square(g)\n        mg_hat = mg_t / (1. 
- tf.pow(mom2,t))\n        g_t = v_hat / tf.sqrt(mg_hat + 1e-8)\n        p_t = p - lr * g_t\n        updates.append(mg.assign(mg_t))\n        updates.append(p.assign(p_t))\n    updates.append(t.assign_add(1))\n    return tf.group(*updates)\n\ndef get_name(layer_name, counters):\n    ''' utlity for keeping track of layer names '''\n    if not layer_name in counters:\n        counters[layer_name] = 0\n    name = layer_name + '_' + str(counters[layer_name])\n    counters[layer_name] += 1\n    return name\n\n@add_arg_scope\ndef dense(x, num_units, nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs):\n    ''' fully connected layer '''\n    name = get_name('dense', counters)\n    with tf.variable_scope(name):\n        if init:\n            # data based initialization of parameters\n            V = tf.get_variable('V', [int(x.get_shape()[1]),num_units], tf.float32, tf.random_normal_initializer(0, 0.05), trainable=True)\n            V_norm = tf.nn.l2_normalize(V.initialized_value(), [0])\n            x_init = tf.matmul(x, V_norm)\n            m_init, v_init = tf.nn.moments(x_init, [0])\n            scale_init = init_scale/tf.sqrt(v_init + 1e-10)\n            g = tf.get_variable('g', dtype=tf.float32, initializer=scale_init, trainable=True)\n            b = tf.get_variable('b', dtype=tf.float32, initializer=-m_init*scale_init, trainable=True)\n            x_init = tf.reshape(scale_init,[1,num_units])*(x_init-tf.reshape(m_init,[1,num_units]))\n            if nonlinearity is not None:\n                x_init = nonlinearity(x_init)\n            return x_init\n\n        else:\n            #V,g,b = get_vars_maybe_avg(['V','g','b'], ema)\n            V = tf.get_variable('V', [int(x.get_shape()[1]),num_units], tf.float32)\n            g = tf.get_variable('g', [num_units], tf.float32)\n            b = tf.get_variable('b', [num_units], tf.float32)\n            if ema is not None:\n                V, g, b = ema.average(V), ema.average(g), ema.average(b)\n   
         #tf.assert_variables_initialized([V,g,b])\n\n            # use weight normalization (Salimans & Kingma, 2016)\n            x = tf.matmul(x, V)\n            scaler = g/tf.sqrt(tf.reduce_sum(tf.square(V),[0]))\n            x = tf.reshape(scaler,[1,num_units])*x + tf.reshape(b,[1,num_units])\n\n            # apply nonlinearity\n            if nonlinearity is not None:\n                x = nonlinearity(x)\n            return x\n\n@add_arg_scope\ndef conv2d(x, num_filters, filter_size=[3,3], stride=[1,1], pad='SAME', nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs):\n    ''' convolutional layer '''\n    name = get_name('conv2d', counters)\n    with tf.variable_scope(name):\n        if init:\n            # data based initialization of parameters\n            V = tf.get_variable('V', filter_size+[int(x.get_shape()[-1]),num_filters], tf.float32, tf.random_normal_initializer(0, 0.05), trainable=True)\n            V_norm = tf.nn.l2_normalize(V.initialized_value(), [0,1,2])\n            x_init = tf.nn.conv2d(x, V_norm, [1]+stride+[1], pad)\n            m_init, v_init = tf.nn.moments(x_init, [0,1,2])\n            scale_init = init_scale/tf.sqrt(v_init + 1e-8)\n            g = tf.get_variable('g', dtype=tf.float32, initializer=scale_init, trainable=True)\n            b = tf.get_variable('b', dtype=tf.float32, initializer=-m_init*scale_init, trainable=True)\n            x_init = tf.reshape(scale_init,[1,1,1,num_filters])*(x_init-tf.reshape(m_init,[1,1,1,num_filters]))\n            if nonlinearity is not None:\n                x_init = nonlinearity(x_init)\n            return x_init\n\n        else:\n            V, g, b = get_vars_maybe_avg(['V', 'g', 'b'], ema)\n            tf.assert_variables_initialized([V,g,b])\n\n            # use weight normalization (Salimans & Kingma, 2016)\n            W = tf.reshape(g,[1,1,1,num_filters])*tf.nn.l2_normalize(V,[0,1,2])\n\n            # calculate convolutional layer output\n            x = 
tf.nn.bias_add(tf.nn.conv2d(x, W, [1]+stride+[1], pad), b)\n\n            # apply nonlinearity\n            if nonlinearity is not None:\n                x = nonlinearity(x)\n            return x\n\n@add_arg_scope\ndef deconv2d(x, num_filters, filter_size=[3,3], stride=[1,1], pad='SAME', nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs):\n    ''' transposed convolutional layer '''\n    name = get_name('deconv2d', counters)\n    xs = int_shape(x)\n    if pad=='SAME':\n        target_shape = [xs[0], xs[1]*stride[0], xs[2]*stride[1], num_filters]\n    else:\n        target_shape = [xs[0], xs[1]*stride[0] + filter_size[0]-1, xs[2]*stride[1] + filter_size[1]-1, num_filters]\n    with tf.variable_scope(name):\n        if init:\n            # data based initialization of parameters\n            V = tf.get_variable('V', filter_size+[num_filters,int(x.get_shape()[-1])], tf.float32, tf.random_normal_initializer(0, 0.05), trainable=True)\n            V_norm = tf.nn.l2_normalize(V.initialized_value(), [0,1,3])\n            x_init = tf.nn.conv2d_transpose(x, V_norm, target_shape, [1]+stride+[1], padding=pad)\n            m_init, v_init = tf.nn.moments(x_init, [0,1,2])\n            scale_init = init_scale/tf.sqrt(v_init + 1e-8)\n            g = tf.get_variable('g', dtype=tf.float32, initializer=scale_init, trainable=True)\n            b = tf.get_variable('b', dtype=tf.float32, initializer=-m_init*scale_init, trainable=True)\n            x_init = tf.reshape(scale_init,[1,1,1,num_filters])*(x_init-tf.reshape(m_init,[1,1,1,num_filters]))\n            if nonlinearity is not None:\n                x_init = nonlinearity(x_init)\n            return x_init\n\n        else:\n            V, g, b = get_vars_maybe_avg(['V', 'g', 'b'], ema)\n            tf.assert_variables_initialized([V,g,b])\n\n            # use weight normalization (Salimans & Kingma, 2016)\n            W = tf.reshape(g,[1,1,num_filters,1])*tf.nn.l2_normalize(V,[0,1,3])\n\n            # 
calculate convolutional layer output\n            x = tf.nn.conv2d_transpose(x, W, target_shape, [1]+stride+[1], padding=pad)\n            x = tf.nn.bias_add(x, b)\n\n            # apply nonlinearity\n            if nonlinearity is not None:\n                x = nonlinearity(x)\n            return x\n\n@add_arg_scope\ndef nin(x, num_units, **kwargs):\n    \"\"\" a network in network layer (1x1 CONV) \"\"\"\n    s = int_shape(x)\n    x = tf.reshape(x, [np.prod(s[:-1]),s[-1]])\n    x = dense(x, num_units, **kwargs)\n    return tf.reshape(x, s[:-1]+[num_units])\n\n''' meta-layer consisting of multiple base layers '''\n\n@add_arg_scope\ndef gated_resnet(x, a=None, h=None, nonlinearity=concat_elu, conv=conv2d, init=False, counters={}, ema=None, dropout_p=0., **kwargs):\n    xs = int_shape(x)\n    num_filters = xs[-1]\n\n    c1 = conv(nonlinearity(x), num_filters, counters=counters, init=init)\n    if a is not None: # add short-cut connection if auxiliary input 'a' is given\n        c1 += nin(nonlinearity(a), num_filters, counters=counters, init=init)\n    c1 = nonlinearity(c1)\n    if dropout_p > 0:\n        c1 = tf.nn.dropout(c1, keep_prob=1. 
- dropout_p)\n    c2 = conv(c1, num_filters * 2, init_scale=0.1, counters=counters, init=init)\n\n    # add projection of h vector if included: conditional generation\n    if h is not None:\n        with tf.variable_scope(get_name('conditional_weights', counters)):\n            hw = get_var_maybe_avg('hw', ema, shape=[int_shape(h)[-1], 2 * num_filters], dtype=tf.float32,\n                                    initializer=tf.random_normal_initializer(0, 0.05), trainable=True)\n        if init:\n            hw = hw.initialized_value()\n        c2 += tf.reshape(tf.matmul(h, hw), [xs[0], 1, 1, 2 * num_filters])\n\n    a, b = tf.split(c2, 2, 3)\n    c3 = a * tf.nn.sigmoid(b)\n    return x + c3\n\n''' utilities for shifting the image around, efficient alternative to masking convolutions '''\n\ndef down_shift(x):\n    xs = int_shape(x)\n    return tf.concat([tf.zeros([xs[0],1,xs[2],xs[3]]), x[:,:xs[1]-1,:,:]], 1)\n\ndef right_shift(x):\n    xs = int_shape(x)\n    return tf.concat([tf.zeros([xs[0],xs[1],1,xs[3]]), x[:,:,:xs[2]-1,:]], 2)\n\n@add_arg_scope\ndef down_shifted_conv2d(x, num_filters, filter_size=[2,3], stride=[1,1], **kwargs):\n    x = tf.pad(x, [[0,0],[filter_size[0]-1,0], [int((filter_size[1]-1)/2),int((filter_size[1]-1)/2)],[0,0]])\n    return conv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs)\n\n@add_arg_scope\ndef down_shifted_deconv2d(x, num_filters, filter_size=[2,3], stride=[1,1], **kwargs):\n    x = deconv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs)\n    xs = int_shape(x)\n    return x[:,:(xs[1]-filter_size[0]+1),int((filter_size[1]-1)/2):(xs[2]-int((filter_size[1]-1)/2)),:]\n\n@add_arg_scope\ndef down_right_shifted_conv2d(x, num_filters, filter_size=[2,2], stride=[1,1], **kwargs):\n    x = tf.pad(x, [[0,0],[filter_size[0]-1, 0], [filter_size[1]-1, 0],[0,0]])\n    return conv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs)\n\n@add_arg_scope\ndef down_right_shifted_deconv2d(x, num_filters, filter_size=[2,2], stride=[1,1], **kwargs):\n    x = deconv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs)\n    xs = int_shape(x)\n    return x[:,:(xs[1]-filter_size[0]+1),:(xs[2]-filter_size[1]+1),:]"
  },
  {
    "path": "fast_pixel_cnn_pp/plotting.py",
    "content": "'''\nCopied from OpenAI's pixel_cnn_pp/plotting.py\n'''\n\nimport numpy as np\nimport matplotlib\nmatplotlib.use('Agg')\nfrom matplotlib import pyplot as plt\n\n# Plot image examples.\ndef plot_img(img, title=None):\n    plt.figure()\n    plt.imshow(img, interpolation='nearest')\n    if title is not None:\n        plt.title(title)\n    plt.axis('off')\n    plt.tight_layout()\n    plt.show(block=False)\n\ndef img_stretch(img):\n    img = img.astype(float)\n    img -= np.min(img)\n    img /= np.max(img)+1e-12\n    return img\n\ndef img_tile(imgs, aspect_ratio=1.0, tile_shape=None, border=1,\n             border_color=0, stretch=False):\n    ''' Tile images in a grid.\n    If tile_shape is provided only as many images as specified in tile_shape\n    will be included in the output.\n    '''\n\n    # Prepare images\n    if stretch:\n        imgs = img_stretch(imgs)\n    imgs = np.array(imgs)\n    if imgs.ndim != 3 and imgs.ndim != 4:\n        raise ValueError('imgs has wrong number of dimensions.')\n    n_imgs = imgs.shape[0]\n\n    # Grid shape\n    img_shape = np.array(imgs.shape[1:3])\n    if tile_shape is None:\n        img_aspect_ratio = img_shape[1] / float(img_shape[0])\n        aspect_ratio *= img_aspect_ratio\n        tile_height = int(np.ceil(np.sqrt(n_imgs * aspect_ratio)))\n        tile_width = int(np.ceil(np.sqrt(n_imgs / aspect_ratio)))\n        grid_shape = np.array((tile_height, tile_width))\n    else:\n        assert len(tile_shape) == 2\n        grid_shape = np.array(tile_shape)\n\n    # Tile image shape\n    tile_img_shape = np.array(imgs.shape[1:])\n    tile_img_shape[:2] = (img_shape[:2] + border) * grid_shape[:2] - border\n\n    # Assemble tile image\n    tile_img = np.empty(tile_img_shape)\n    tile_img[:] = border_color\n    for i in range(grid_shape[0]):\n        for j in range(grid_shape[1]):\n            img_idx = j + i*grid_shape[1]\n            if img_idx >= n_imgs:\n                # No more images - stop filling out the 
grid.\n                break\n            img = imgs[img_idx]\n            yoff = (img_shape[0] + border) * i\n            xoff = (img_shape[1] + border) * j\n            tile_img[yoff:yoff+img_shape[0], xoff:xoff+img_shape[1], ...] = img\n\n    return tile_img\n\ndef conv_filter_tile(filters):\n    n_filters, n_channels, height, width = filters.shape\n    tile_shape = None\n    if n_channels == 3:\n        # Interpret 3 color channels as RGB\n        filters = np.transpose(filters, (0, 2, 3, 1))\n    else:\n        # Organize tile such that each row corresponds to a filter and the\n        # columns are the filter channels\n        tile_shape = (n_channels, n_filters)\n        filters = np.transpose(filters, (1, 0, 2, 3))\n        filters = np.resize(filters, (n_filters*n_channels, height, width))\n    filters = img_stretch(filters)\n    return img_tile(filters, tile_shape=tile_shape)\n    \ndef scale_to_unit_interval(ndar, eps=1e-8):\n  \"\"\" Scales all values in the ndarray ndar to be between 0 and 1 \"\"\"\n  ndar = ndar.copy()\n  ndar -= ndar.min()\n  ndar *= 1.0 / (ndar.max() + eps)\n  return ndar\n\n\ndef tile_raster_images(X, img_shape, tile_shape, tile_spacing=(0, 0),\n                       scale_rows_to_unit_interval=True,\n                       output_pixel_vals=True):\n  \"\"\"\n  Transform an array with one flattened image per row, into an array in\n  which images are reshaped and layed out like tiles on a floor.\n\n  This function is useful for visualizing datasets whose rows are images,\n  and also columns of matrices for transforming those rows\n  (such as the first layer of a neural net).\n\n  :type X: a 2-D ndarray or a tuple of 4 channels, elements of which can\n  be 2-D ndarrays or None;\n  :param X: a 2-D array in which every row is a flattened image.\n\n  :type img_shape: tuple; (height, width)\n  :param img_shape: the original shape of each image\n\n  :type tile_shape: tuple; (rows, cols)\n  :param tile_shape: the number of images to tile 
(rows, cols)\n\n  :param output_pixel_vals: if output should be pixel values (i.e. int8\n  values) or floats\n\n  :param scale_rows_to_unit_interval: if the values need to be scaled before\n  being plotted to [0,1] or not\n\n\n  :returns: array suitable for viewing as an image.\n  (See:`PIL.Image.fromarray`.)\n  :rtype: a 2-d array with same dtype as X.\n\n  \"\"\"\n\n  assert len(img_shape) == 2\n  assert len(tile_shape) == 2\n  assert len(tile_spacing) == 2\n\n  # The expression below can be re-written in a more C style as\n  # follows :\n  #\n  # out_shape = [0,0]\n  # out_shape[0] = (img_shape[0] + tile_spacing[0]) * tile_shape[0] -\n  #                tile_spacing[0]\n  # out_shape[1] = (img_shape[1] + tile_spacing[1]) * tile_shape[1] -\n  #                tile_spacing[1]\n  out_shape = [(ishp + tsp) * tshp - tsp for ishp, tshp, tsp\n                      in zip(img_shape, tile_shape, tile_spacing)]\n\n  if isinstance(X, tuple):\n      assert len(X) == 4\n      # Create an output numpy ndarray to store the image\n      if output_pixel_vals:\n          out_array = np.zeros((out_shape[0], out_shape[1], 4), dtype='uint8')\n      else:\n          out_array = np.zeros((out_shape[0], out_shape[1], 4), dtype=X.dtype)\n\n      #colors default to 0, alpha defaults to 1 (opaque)\n      if output_pixel_vals:\n          channel_defaults = [0, 0, 0, 255]\n      else:\n          channel_defaults = [0., 0., 0., 1.]\n\n      for i in range(4):\n          if X[i] is None:\n              # if channel is None, fill it with zeros of the correct\n              # dtype\n              out_array[:, :, i] = np.zeros(out_shape,\n                      dtype='uint8' if output_pixel_vals else out_array.dtype\n                      ) + channel_defaults[i]\n          else:\n              # use a recurrent call to compute the channel and store it\n              # in the output\n              out_array[:, :, i] = tile_raster_images(X[i], img_shape, tile_shape, tile_spacing, 
scale_rows_to_unit_interval, output_pixel_vals)\n      return out_array\n\n  else:\n      # if we are dealing with only one channel\n      H, W = img_shape\n      Hs, Ws = tile_spacing\n\n      # generate a matrix to store the output\n      out_array = np.zeros(out_shape, dtype='uint8' if output_pixel_vals else X.dtype)\n\n\n      for tile_row in range(tile_shape[0]):\n          for tile_col in range(tile_shape[1]):\n              if tile_row * tile_shape[1] + tile_col < X.shape[0]:\n                  if scale_rows_to_unit_interval:\n                      # if we should scale values to be between 0 and 1\n                      # do this by calling the `scale_to_unit_interval`\n                      # function\n                      this_img = scale_to_unit_interval(X[tile_row * tile_shape[1] + tile_col].reshape(img_shape))\n                  else:\n                      this_img = X[tile_row * tile_shape[1] + tile_col].reshape(img_shape)\n                  # add the slice to the corresponding position in the\n                  # output array\n                  out_array[\n                      tile_row * (H+Hs): tile_row * (H + Hs) + H,\n                      tile_col * (W+Ws): tile_col * (W + Ws) + W\n                      ] \\\n                      = this_img * (255 if output_pixel_vals else 1)\n      return out_array\n\n"
  },
  {
    "path": "fast_pixel_cnn_pp/test_components.py",
    "content": "from . import nn\nfrom . import fast_nn\n\nimport tensorflow as tf\nimport numpy as np\n\nimport math\nimport unittest\nfrom collections import namedtuple\n\nPlaceholders = namedtuple(\n    'Placeholders',\n    ['full_input', 'pixel_input', 'row_input', 'row_id', 'col_id'])\n\n\nclass FastPixelCNNPPTest(tf.test.TestCase):\n    def test_down_shifted_conv2d_layer1_stride1(self):\n        self._test_down_shifted(num_layers=1)\n\n    def test_down_shifted_conv2d_layer1_stride1_1by3(self):\n        self._test_down_shifted(num_layers=1, filter_size=[1, 3])\n\n    def test_down_shifted_conv2d_layer2_stride1(self):\n        self._test_down_shifted(num_layers=2)\n\n    def test_down_shifted_conv2d_layer3_stride1(self):\n        self._test_down_shifted(num_layers=3)\n\n    def test_down_shifted_conv2d_layer2_stride1_2(self):\n        self._test_down_shifted(num_layers=2, strides=[1, 2])\n\n    def test_down_shifted_conv2d_layer1_stride2(self):\n        self._test_down_shifted(num_layers=1, strides=[2])\n\n    def test_down_shifted_conv2d_layer3_stride1_2_2(self):\n        self._test_down_shifted(num_layers=3, strides=[1, 2, 2])\n\n    def test_down_shifted_deconv2d_layer1_stride2(self):\n        self._test_down_shifted(num_layers=1, strides=[2], layers=['deconv'])\n\n    def test_down_shifted_deconv2d_layer2_stride2(self):\n        self._test_down_shifted(\n            num_layers=2, strides=[2, 2], layers=['deconv', 'deconv'])\n\n    def test_down_shifted_conv2d_deconv2d_layer2_stride2(self):\n        self._test_down_shifted(\n            num_layers=2, strides=[2, 2], layers=['conv', 'deconv'])\n\n    def test_down_shifted_conv2d_conv2d_2_deconv_2_layer2(self):\n        self._test_down_shifted(\n            num_layers=3, strides=[1, 2, 2], layers=['conv', 'conv', 'deconv'])\n\n    def test_down_shifted_conv2d_conv2d_deconv2d_deconv2d_layer4_stride2(self):\n        self._test_down_shifted(\n            num_layers=4,\n            strides=[2, 2, 2, 2],\n         
   layers=['conv', 'conv', 'deconv', 'deconv'])\n\n    def test_down_shifted_conv2d_deconv2d_conv2d_deconv2d_layer4_stride2(self):\n        self._test_down_shifted(\n            num_layers=4,\n            strides=[2, 2, 2, 2],\n            layers=['conv', 'deconv', 'conv', 'deconv'])\n\n    def test_down_right_shifted_conv2d_layer1_stride1_1by3(self):\n        self._test_down_right_shifted(\n            batch_size=1, size=8, num_layers=1, filter_size=[1, 3])\n\n    def test_down_right_shifted_conv2d_layer1_stride1(self):\n        self._test_down_right_shifted(batch_size=1, size=8, num_layers=1)\n\n    def test_down_right_shifted_conv2d_layer2_stride1(self):\n        self._test_down_right_shifted(batch_size=1, size=8, num_layers=2)\n\n    def test_down_right_shifted_conv2d_layer3_stride1(self):\n        self._test_down_right_shifted(batch_size=1, size=8, num_layers=3)\n\n    def test_down_right_shifted_conv2d_layer1_stride2(self):\n        self._test_down_right_shifted(\n            batch_size=1, size=8, num_layers=1, strides=[2])\n\n    def test_down_right_shifted_conv2d_layer2_stride2(self):\n        self._test_down_right_shifted(\n            batch_size=1, size=8, num_layers=2, strides=[2, 2])\n\n    def test_down_right_shifted_conv2d_layer3_stride2(self):\n        self._test_down_right_shifted(\n            batch_size=1, size=16, num_layers=3, strides=[2, 1, 2])\n\n    def test_down_right_shifted_deconv2d_layer1_stride2(self):\n        self._test_down_right_shifted(\n            batch_size=1, size=4, num_layers=1, layers=['deconv'], strides=[2])\n\n    def test_down_right_shifted_deconv2d_layer2_stride2(self):\n        self._test_down_right_shifted(\n            batch_size=1,\n            size=4,\n            num_layers=2,\n            layers=['deconv', 'deconv'],\n            strides=[2, 2])\n\n    def test_down_right_shifted_deconv2d_layer3_stride2(self):\n        self._test_down_right_shifted(\n            batch_size=1,\n            size=4,\n            
num_layers=3,\n            layers=['deconv', 'deconv', 'deconv'],\n            strides=[2, 2, 2])\n\n    def test_down_right_shifted_conv2d_deconv2d_layer2_stride2(self):\n        self._test_down_right_shifted(\n            num_layers=2, strides=[2, 2], layers=['conv', 'deconv'])\n\n    def test_down_right_shifted_conv2d_conv2d_2_deconv_2_layer2(self):\n        self._test_down_right_shifted(\n            num_layers=3, strides=[1, 2, 2], layers=['conv', 'conv', 'deconv'])\n\n    def test_down_right_shifted_conv2d_conv2d_deconv2d_deconv2d_layer4_stride2(\n            self):\n        self._test_down_right_shifted(\n            num_layers=4,\n            strides=[2, 2, 2, 2],\n            layers=['conv', 'conv', 'deconv', 'deconv'])\n\n    def test_down_right_shifted_conv2d_deconv2d_conv2d_deconv2d_layer4_stride2(\n            self):\n        self._test_down_right_shifted(\n            num_layers=4,\n            strides=[2, 2, 2, 2],\n            layers=['conv', 'deconv', 'conv', 'deconv'])\n\n    def test_sum_rightshift_downshift(self):\n        self._test_sum_rightshift_downshift(size=32)\n\n    def test_gated_resnet_vstack_only_basic(self):\n        self._gated_resnet_vstack_only()\n\n    def test_gated_resnet_vstack_only_use_h(self):\n        self._gated_resnet_vstack_only(use_h=True)\n\n    def test_gated_resnet_vstack_only_basic_3layers(self):\n        self._gated_resnet_vstack_only(num_layers=3)\n\n    def test_gated_resnet_vstack_only_use_h_3layers(self):\n        self._gated_resnet_vstack_only(use_h=True, num_layers=3)\n\n    def test_gated_resnet_vstack_only_use_extra_row_input(self):\n        self._gated_resnet_vstack_only(use_extra_row_input=True)\n\n    def test_gated_resnet_vstack_only_use_extra_row_input_3layers(self):\n        self._gated_resnet_vstack_only(use_extra_row_input=True, num_layers=3)\n\n    def test_gated_resnet_vstack_only_use_extra_row_input_and_use_h(self):\n        self._gated_resnet_vstack_only(use_extra_row_input=True, use_h=True)\n\n 
   def test_gated_resnet_vstack_only_use_extra_row_input__and_use_h_3layers(\n            self):\n        self._gated_resnet_vstack_only(\n            use_extra_row_input=True, use_h=True, num_layers=3)\n\n    def test_gated_resnet_hstack_basic(self):\n        self._gated_resnet_hstack()\n\n    def test_gated_resnet_hstack_use_h(self):\n        self._gated_resnet_hstack(use_h=True)\n\n    def test_gated_resnet_hstack_basic_3layers(self):\n        self._gated_resnet_hstack(num_layers=3)\n\n    def test_gated_resnet_hstack_use_h_3layers(self):\n        self._gated_resnet_hstack(use_h=True, num_layers=3)\n\n    def test_gated_resnet_hstack_use_extra_pixel_input(self):\n        self._gated_resnet_hstack(use_extra_pixel_input=True)\n\n    def test_gated_resnet_hstack_use_extra_pixel_input_3layers(self):\n        self._gated_resnet_hstack(use_extra_pixel_input=True, num_layers=3)\n\n    def test_gated_resnet_hstack_use_extra_pixel_input_and_use_h(self):\n        self._gated_resnet_hstack(use_extra_pixel_input=True, use_h=True)\n\n    def test_gated_resnet_hstack_use_extra_pixel_input_and_use_h_3layers(self):\n        self._gated_resnet_hstack(\n            use_extra_pixel_input=True, use_h=True, num_layers=3)\n\n    def _get_placeholders(self, image_size):\n        '''Creates all placeholders.'''\n        batch_size, size, _, input_channels = image_size\n        full_input = tf.placeholder(\n            tf.float32, [batch_size, size, size, input_channels],\n            name='full_input')\n        pixel_input = tf.placeholder(\n            tf.float32, [batch_size, 1, 1, input_channels], name='pixel_input')\n        row_input = tf.placeholder(\n            tf.float32, [batch_size, 1, size, input_channels],\n            name='row_input')\n        row_id = tf.placeholder(tf.int32, [], name='row_id')\n        col_id = tf.placeholder(tf.int32, [], name='col_id')\n        return Placeholders(full_input, pixel_input, row_input, row_id, col_id)\n\n    def _setup_test_equal(self, 
sess, nn_out, full_input, image_size,\n                          output_image_size):\n        '''Sets up both _test_*_equal() methods by initializing variables and outputs.'''\n        np.random.seed(2702)\n        x = np.random.randn(*image_size)\n        # nn layers use data dependent initialization.\n        # Data dependent initialization requires a batch of initial data,\n        # which we pass through with a feed dict.\n        sess.run(tf.global_variables_initializer(), {full_input: x})\n\n        # Calculate ground truth output.\n        ground_truth_output = sess.run(nn_out, {full_input: x})\n\n        # Create variable that holds output.\n        if output_image_size is None:\n            output_image_size = image_size\n        fast_output = np.zeros(output_image_size)\n\n        # Calculate the increase in output size compared to the input size.\n        # This is useful when only deconv (upsampling) layers are used.\n        side_length = image_size[2]\n        width_ratio = output_image_size[2] // image_size[2]\n        image_increase_factor = max(1, width_ratio)\n\n        # Reset the cache to be safe.\n        sess.run(fast_nn.reset_cache_op())\n\n        return x, ground_truth_output, fast_output, side_length, image_increase_factor\n\n    def _test_rows_equal(self,\n                         sess,\n                         fast_nn_out,\n                         nn_out,\n                         placeholders,\n                         image_size,\n                         output_image_size=None,\n                         run_every=1):\n        '''Tests if vertical stack outputs (one row at a time) of our code and OpenAI code are equal.'''\n        (x, ground_truth_output, fast_output, side_length,\n         image_increase_factor) = self._setup_test_equal(\n             sess, nn_out, placeholders.full_input, image_size,\n             output_image_size)\n\n        # Generate fast output.\n        for row in range(side_length):\n            x_row_input 
= x[:, row:(row + 1), :, :]\n            # image_increase_factor is relevant when only deconvs are used.\n            # It just runs each row of input multiple times to populate the upsampled output.\n            for inner_iteration in range(image_increase_factor):\n                row_compensated = image_increase_factor * row + inner_iteration\n                feed_dict = {\n                    placeholders.row_input: x_row_input,\n                    placeholders.row_id: row_compensated\n                }\n                row_output = sess.run(fast_nn_out, feed_dict)\n\n                if row_compensated % run_every == 0:\n                    # The run_every division is for downsampling,\n                    # because the output is smaller than the input.\n                    output_row = row_compensated // run_every\n                    fast_output[:, output_row:(output_row + 1),\n                                                :, :] = row_output\n\n        # Within a tolerance.\n        self.assertTrue(np.allclose(ground_truth_output, fast_output))\n        # Exact match.\n        self.assertTrue(\n            np.max(np.abs(ground_truth_output - fast_output)) == 0.0)\n\n    def _test_pixels_equal(self,\n                           sess,\n                           fast_nn_out,\n                           nn_out,\n                           placeholders,\n                           image_size,\n                           output_image_size=None,\n                           run_every=1,\n                           atol=1e-6):\n        '''Tests if horizontal stack outputs (one pixel at a time) of our code and OpenAI code are equal.'''\n        (x, ground_truth_output, fast_output, side_length,\n         image_increase_factor) = self._setup_test_equal(\n             sess, nn_out, placeholders.full_input, image_size,\n             output_image_size)\n\n        # Generate fast output.\n        for row in range(side_length):\n            # image_increase_factor is 
relevant when only deconvs are used.\n            # It just runs each row and column of input multiple times to populate the upsampled output.\n            for inner_row_iteration in range(image_increase_factor):\n                row_compensated = image_increase_factor * row + inner_row_iteration\n                x_row_input = x[:, row:(row + 1), :, :]\n                for col in range(side_length):\n                    x_pixel_input = x[:, row:(row + 1), col:(col + 1), :]\n                    for inner_col_iteration in range(image_increase_factor):\n                        col_compensated = image_increase_factor * col + inner_col_iteration\n                        feed_dict = {\n                            placeholders.pixel_input: x_pixel_input,\n                            placeholders.row_id: row_compensated,\n                            placeholders.col_id: col_compensated,\n                            placeholders.row_input: x_row_input\n                        }\n\n                        pixel_output = sess.run(fast_nn_out, feed_dict)\n\n                        # The run_every division is for downsampling,\n                        # because the output is smaller than the input.\n                        if row_compensated % run_every == 0 and col_compensated % run_every == 0:\n                            output_row = row_compensated // run_every\n                            output_col = col_compensated // run_every\n                            fast_output[:, output_row:(output_row + 1),\n                                        output_col:(output_col + 1\n                                                    ), :] = pixel_output\n\n        self.assertTrue(\n            np.allclose(ground_truth_output, fast_output, atol=atol))\n\n    def _setup_conv_tests(self, batch_size, size, channels, filter_size,\n                          strides, layers, num_layers):\n        '''Sets up the conv tests by computing basic layer information.'''\n        image_size = 
(batch_size, size, size, channels)\n        full_filter_size = filter_size + [channels]\n        if strides is None:\n            strides = [1 for _ in range(num_layers)]\n        assert len(strides) == num_layers\n        if layers is None:\n            layers = ['conv' for _ in range(num_layers)]\n        assert len(layers) == num_layers\n        return image_size, full_filter_size, strides, layers\n\n    def _compute_conv_fast_nn_out(self, compute_output_func, network_input,\n                                  image_size, strides, layers):\n        '''Computes cached convolutions, handling downsampling and upsampling.'''\n        batch_size, size, _, nr_filters = image_size\n        num_layers = len(layers)\n\n        # Computes the final output size taking into account downsampling and upsampling.\n        output_size = size\n        for stride, layer_type in zip(strides, layers):\n            if layer_type == 'conv':\n                output_size = output_size // stride\n            else:\n                output_size = output_size * stride\n        output_image_size = (batch_size, output_size, output_size, nr_filters)\n\n        # When running only deconvs, the output size gets bigger than the input size.\n        # For generation, each input must be run multiple times to populate the output.\n        image_increase_factor = max(output_size // size, 1)\n        cumulative_stride = max(1, image_increase_factor)\n\n        fast_nn_out = network_input\n        counters = {}\n        layer_input_size = size\n\n        # Run the network.\n        for layer_num in range(num_layers):\n            stride = strides[layer_num]\n            layer_type = layers[layer_num]\n\n            # The run_every of one layer is the cache_every of the next layer.\n            # These increase after downsampling since fewer inputs correspond to an output.\n            # These decrease after upsampling since more inputs correspond to an output.\n            cache_every = 
cumulative_stride\n            if layer_type == 'conv':\n                run_every = cumulative_stride * stride\n            else:\n                run_every = max(1, cumulative_stride // stride)\n\n            input_image_size = (batch_size, layer_input_size, layer_input_size,\n                                nr_filters)\n            cumulative_stride = run_every\n\n            fast_nn_out = compute_output_func(fast_nn_out, layer_type,\n                                              input_image_size, stride,\n                                              cache_every, run_every, counters)\n\n            # The size of the input to the next layer.\n            if layer_type == 'conv':\n                layer_input_size = layer_input_size // stride  # Downsampling.\n            else:\n                layer_input_size = layer_input_size * stride  # Upsampling.\n\n        return fast_nn_out, output_image_size, run_every\n\n    def _test_down_shifted(self,\n                           batch_size=10,\n                           size=16,\n                           channels=7,\n                           num_layers=1,\n                           filter_size=[2, 3],\n                           strides=None,\n                           layers=None,\n                           nonlinearity=tf.sigmoid):\n        '''Tests the down_shifted convolution for the vertical stack.'''\n\n        def get_conv_function(module, layer_type):\n            '''Returns the matching conv or deconv function.'''\n            if layer_type == 'conv':\n                return module.down_shifted_conv2d\n            elif layer_type == 'deconv':\n                return module.down_shifted_deconv2d\n            else:\n                raise ValueError('Unknown layer_type %s' % layer_type)\n\n        image_size, full_filter_size, strides, layers = self._setup_conv_tests(\n            batch_size, size, channels, filter_size, strides, layers,\n            num_layers)\n\n        with self.test_session() as 
sess:\n            placeholders = self._get_placeholders(image_size)\n\n            # OpenAI output.\n            def compute_ground_truth(init):\n                nn_out = placeholders.full_input\n                counters = {}\n                for layer_num in range(num_layers):\n                    stride = strides[layer_num]\n                    layer_func = get_conv_function(nn, layers[layer_num])\n                    nn_out = layer_func(\n                        nn_out,\n                        num_filters=channels,\n                        filter_size=filter_size,\n                        stride=[stride, stride],\n                        nonlinearity=nonlinearity,\n                        counters=counters,\n                        init=init)\n                return nn_out\n\n            compute_ground_truth(init=True)\n            tf.get_variable_scope().reuse_variables()\n            nn_out = compute_ground_truth(init=False)\n\n            # Our output.\n            def compute_output_func(fast_nn_out, layer_type, input_image_size,\n                                    stride, cache_every, run_every, counters):\n                layer_func = get_conv_function(fast_nn, layer_type)\n                return layer_func(\n                    fast_nn_out,\n                    network_info=(input_image_size, full_filter_size),\n                    stride=stride,\n                    row=placeholders.row_id,\n                    cache_every=cache_every,\n                    run_every=run_every,\n                    counters=counters,\n                    nonlinearity=nonlinearity)\n\n            fast_nn_out, output_image_size, run_every = self._compute_conv_fast_nn_out(\n                compute_output_func, placeholders.row_input, image_size,\n                strides, layers)\n\n            self._test_rows_equal(\n                sess,\n                fast_nn_out,\n                nn_out,\n                placeholders,\n                image_size,\n                
output_image_size=output_image_size,\n                run_every=run_every)\n\n    def _test_down_right_shifted(self,\n                                 batch_size=10,\n                                 size=16,\n                                 channels=7,\n                                 num_layers=1,\n                                 filter_size=[2, 2],\n                                 strides=None,\n                                 layers=None,\n                                 nonlinearity=tf.sigmoid):\n        '''Tests the down_right_shifted convolution for the horizontal stack.'''\n\n        def get_conv_function(module, layer_type):\n            '''Returns the matching conv or deconv function.'''\n            if layer_type == 'conv':\n                return module.down_right_shifted_conv2d\n            elif layer_type == 'deconv':\n                return module.down_right_shifted_deconv2d\n            else:\n                raise ValueError('Unknown layer_type %s' % layer_type)\n\n        image_size, full_filter_size, strides, layers = self._setup_conv_tests(\n            batch_size, size, channels, filter_size, strides, layers,\n            num_layers)\n\n        with self.test_session() as sess:\n\n            placeholders = self._get_placeholders(image_size)\n\n            # OpenAI output.\n            def compute_ground_truth(init):\n                nn_out = placeholders.full_input\n                counters = {}\n                for layer_num in range(num_layers):\n                    stride = strides[layer_num]\n                    layer_func = get_conv_function(nn, layers[layer_num])\n                    nn_out = layer_func(\n                        nn_out,\n                        num_filters=channels,\n                        filter_size=filter_size,\n                        stride=[stride, stride],\n                        nonlinearity=nonlinearity,\n                        counters=counters,\n                        init=init)\n                return 
nn_out\n\n            compute_ground_truth(init=True)\n            tf.get_variable_scope().reuse_variables()\n            nn_out = compute_ground_truth(init=False)\n\n            # Our output.\n            def compute_output_func(fast_nn_out, layer_type, input_image_size,\n                                    stride, cache_every, run_every, counters):\n                layer_func = get_conv_function(fast_nn, layer_type)\n                return layer_func(\n                    fast_nn_out,\n                    network_info=(input_image_size, full_filter_size),\n                    row=placeholders.row_id,\n                    col=placeholders.col_id,\n                    cache_every=cache_every,\n                    run_every=run_every,\n                    counters=counters,\n                    nonlinearity=nonlinearity)\n\n            fast_nn_out, output_image_size, run_every = self._compute_conv_fast_nn_out(\n                compute_output_func, placeholders.pixel_input, image_size,\n                strides, layers)\n\n            self._test_pixels_equal(\n                sess,\n                fast_nn_out,\n                nn_out,\n                placeholders,\n                image_size,\n                output_image_size=output_image_size,\n                run_every=run_every)\n\n    def _gated_resnet_vstack_only(self,\n                                  batch_size=10,\n                                  size=16,\n                                  channels=7,\n                                  num_layers=1,\n                                  filter_size=[2, 3],\n                                  use_h=False,\n                                  use_extra_row_input=False,\n                                  nonlinearity=tf.sigmoid):\n        '''Tests the gated resnet layers for the vertical stack.'''\n        image_size = (batch_size, size, size, channels)\n        full_filter_size = filter_size + [channels]\n\n        np.random.seed(2702)\n        with 
self.test_session() as sess:\n            placeholders = self._get_placeholders(image_size)\n\n            # Conditional information and skip connections.\n            h, a = None, None\n            if use_h:\n                h = tf.constant(\n                    np.random.randn(batch_size, 20), dtype=tf.float32)\n            if use_extra_row_input:\n                a = placeholders.full_input\n\n            # OpenAI output.\n            def compute_ground_truth(init):\n                counters = {}\n                nn_out = placeholders.full_input\n                for _ in range(num_layers):\n                    nn_out = nn.gated_resnet(\n                        nn_out,\n                        a=a,\n                        h=h,\n                        conv=nn.down_shifted_conv2d,\n                        nonlinearity=nonlinearity,\n                        counters=counters,\n                        init=init)\n                return nn_out\n\n            compute_ground_truth(init=True)\n            tf.get_variable_scope().reuse_variables()\n            nn_out = compute_ground_truth(init=False)\n\n            # Our output.\n            counters = {}\n            fast_nn_out = placeholders.row_input\n            if use_extra_row_input:\n                a = placeholders.row_input\n            for _ in range(num_layers):\n                fast_nn_out = fast_nn.gated_resnet_vstack_only(\n                    fast_nn_out, (image_size, full_filter_size),\n                    placeholders.row_id,\n                    extra_row_input=a,\n                    h=h,\n                    cache_every=1,\n                    run_every=1,\n                    nonlinearity=nonlinearity,\n                    counters=counters)\n\n        self._test_rows_equal(sess, fast_nn_out, nn_out, placeholders,\n                              image_size)\n\n    def _gated_resnet_hstack(self,\n                             batch_size=10,\n                             size=16,\n                     
        channels=7,\n                             filter_size=[2, 2],\n                             num_layers=1,\n                             use_h=False,\n                             use_extra_pixel_input=False,\n                             nonlinearity=tf.sigmoid):\n        '''Tests the gated resnet layers for the horizontal stack.'''\n        image_size = (batch_size, size, size, channels)\n        full_filter_size = filter_size + [channels]\n\n        with self.test_session() as sess:\n            placeholders = self._get_placeholders(image_size)\n\n            # Conditional information and skip connections.\n            h, a = None, placeholders.full_input\n            if use_h:\n                h = tf.constant(\n                    np.random.randn(batch_size, 20), dtype=tf.float32)\n            if use_extra_pixel_input:\n                a = tf.concat([a, 2 * placeholders.full_input], 3)\n\n            # OpenAI output.\n            def compute_ground_truth(init):\n                counters = {}\n                nn_out = placeholders.full_input\n                for _ in range(num_layers):\n                    nn_out = nn.gated_resnet(\n                        nn_out,\n                        a=a,\n                        h=h,\n                        conv=nn.down_right_shifted_conv2d,\n                        nonlinearity=nonlinearity,\n                        counters=counters,\n                        init=init)\n                return nn_out\n\n            compute_ground_truth(init=True)\n            tf.get_variable_scope().reuse_variables()\n            nn_out = compute_ground_truth(init=False)\n\n            # Our output.\n            extra_pixel_input = None\n            if use_extra_pixel_input:\n                extra_pixel_input = 2 * placeholders.pixel_input\n\n            counters = {}\n            fast_nn_out = placeholders.pixel_input\n            for _ in range(num_layers):\n                fast_nn_out = fast_nn.gated_resnet_hstack(\n            
        fast_nn_out,\n                    placeholders.row_input, (image_size, full_filter_size),\n                    h=h,\n                    row=placeholders.row_id,\n                    col=placeholders.col_id,\n                    cache_every=1,\n                    run_every=1,\n                    extra_pixel_input=extra_pixel_input,\n                    nonlinearity=nonlinearity,\n                    counters=counters)\n\n        self._test_pixels_equal(sess, fast_nn_out, nn_out, placeholders,\n                                image_size)\n\n    def _test_sum_rightshift_downshift(self,\n                                       batch_size=10,\n                                       size=16,\n                                       channels=7,\n                                       nonlinearity=tf.sigmoid):\n        '''Tests the sum of the vertical and horizontal stack.'''\n        image_size = (batch_size, size, size, channels)\n\n        with self.test_session() as sess:\n            placeholders = self._get_placeholders(image_size)\n\n            # OpenAI output.\n            def compute_ground_truth(init):\n                counters = {}\n                nn_v_stack = nn.down_shifted_conv2d(\n                    placeholders.full_input,\n                    num_filters=channels,\n                    filter_size=[1, 3],\n                    stride=[1, 1],\n                    nonlinearity=nonlinearity,\n                    counters=counters,\n                    init=init)\n                nn_h_stack = nn.down_right_shifted_conv2d(\n                    placeholders.full_input,\n                    num_filters=channels,\n                    filter_size=[2, 1],\n                    stride=[1, 1],\n                    nonlinearity=nonlinearity,\n                    counters=counters,\n                    init=init)\n                return nn_v_stack + nn_h_stack\n\n            compute_ground_truth(init=True)\n            
tf.get_variable_scope().reuse_variables()\n            nn_out = compute_ground_truth(init=False)\n\n            # Our output\n            counters, stride, cache_every, run_every = {}, 1, 1, 1\n            fast_nn_v_stack = fast_nn.down_shifted_conv2d(\n                placeholders.row_input,\n                network_info=(image_size, [1, 3, channels]),\n                stride=stride,\n                row=placeholders.row_id,\n                cache_every=cache_every,\n                run_every=run_every,\n                counters=counters,\n                nonlinearity=nonlinearity)\n            fast_nn_h_stack = fast_nn.down_right_shifted_conv2d(\n                placeholders.pixel_input,\n                network_info=(image_size, [2, 1, channels]),\n                row=placeholders.row_id,\n                col=placeholders.col_id,\n                cache_every=cache_every,\n                run_every=run_every,\n                counters=counters,\n                nonlinearity=nonlinearity)\n            fast_nn_out = fast_nn.sum_rightshift_downshift(\n                fast_nn_h_stack, fast_nn_v_stack, placeholders.col_id)\n\n        self._test_pixels_equal(sess, fast_nn_out, nn_out, placeholders,\n                                image_size)\n"
  },
  {
    "path": "fast_pixel_cnn_pp/test_end_to_end.py",
    "content": "from . import model\nfrom . import fast_nn\n\nimport tensorflow as tf\nimport numpy as np\n\nimport os\nimport unittest\n\n\nclass FastPixelCNNPPEndToEndTest(tf.test.TestCase):\n    def test_end_to_end(self):\n        with self.test_session() as sess:\n            print('Creating model')\n            image_size = (10, 32, 32, 4)\n            batch_size, image_height, image_width, image_channels = image_size\n\n            # Create placeholders.\n            row_input = tf.placeholder(\n                tf.float32, [batch_size, 1, image_width, image_channels],\n                name='row_input')\n            pixel_input = tf.placeholder(\n                tf.float32, [batch_size, 1, 1, image_channels],\n                name='pixel_input')\n            row_id = tf.placeholder(tf.int32, [], name='row_id')\n            col_id = tf.placeholder(tf.int32, [], name='col_id')\n            ema = tf.train.ExponentialMovingAverage(0.9995)\n\n            # Create the model.\n            model_spec = tf.make_template('model', model.model_spec)\n            sample, fast_nn_out, v_stack = model_spec(\n                row_input, pixel_input, row_id, col_id, image_size)\n\n            # Initialize the caches.\n            cache_variables = [\n                v for v in tf.global_variables() if 'cache' in v.name\n            ]\n            sess.run(tf.variables_initializer(cache_variables))\n\n            # Load the pretrained model\n            print('Restoring variables')\n            vars_to_restore = {\n                k: v\n                for k, v in ema.variables_to_restore().items()\n                if 'cache' not in k\n            }\n            saver = tf.train.Saver(vars_to_restore)\n            ckpt_path = None\n            assert ckpt_path, 'Provide a path to the checkpoint in this file'\n            saver.restore(sess, ckpt_path)\n\n            # Create the fixed random input.\n            np.random.seed(2702)\n            x = np.random.randint(0, 256, 
size=(10, 32, 32, 3))\n            x = np.cast[np.float32]((x - 127.5) / 127.5)\n            x_pad = np.concatenate(\n                (x, np.ones((batch_size, 32, 32, 1))), axis=3)\n            x_downshift = fast_nn.down_shift(x_pad)\n            x_rightshift = fast_nn.right_shift(x_pad)\n\n            # Holds the output.\n            num_output_features = 10 * 10\n            output_features = np.zeros(\n                (batch_size, 32, 32, num_output_features))\n\n            # Compute all features.\n            print('Computing features')\n            sess.run(fast_nn.reset_cache_op())\n            for row in range(image_height):\n                x_row_input = x_downshift[:, row:(row + 1), :, :]\n                sess.run(v_stack, {row_input: x_row_input, row_id: row})\n\n                for col in range(image_width):\n                    x_pixel_input = x_rightshift[:, row:(row + 1),\n                                                 col:(col + 1), :]\n                    feed_dict = {\n                        row_id: row,\n                        col_id: col,\n                        pixel_input: x_pixel_input\n                    }\n                    pixel_features = sess.run(fast_nn_out, feed_dict)\n                    output_features[:, row:(row + 1), col:(\n                        col + 1), :] = pixel_features\n\n            ground_truth_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),\n                                             'ground_truth_output.npy')\n            ground_truth_features = np.load(ground_truth_file)\n            total_features = np.prod(output_features[0].shape)\n            for i in range(batch_size):\n                self.assertTrue(\n                    np.allclose(\n                        output_features[i, :, :, :],\n                        ground_truth_features[i, :, :, :],\n                        atol=1e-4))\n"
  },
  {
    "path": "generate.py",
    "content": "import fast_pixel_cnn_pp.model as model\nimport fast_pixel_cnn_pp.fast_nn as fast_nn\nimport fast_pixel_cnn_pp.plotting as plotting\n\nimport tensorflow as tf\nimport numpy as np\nimport matplotlib.pyplot as plt\n\nimport argparse\nimport time\nimport os\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\n    '-b',\n    '--batch_size',\n    type=int,\n    default=16,\n    help='Number of images to generate simultaneously')\nparser.add_argument(\n    '-i',\n    '--image_size',\n    type=int,\n    default=32,\n    help='Height and width of the image')\nparser.add_argument(\n    '-s', '--seed', type=int, default=2702, help='Seed for random generation')\nparser.add_argument(\n    '-c',\n    '--checkpoint',\n    type=str,\n    default='/home/mbz/pixel_cnn_pp/params_cifar.ckpt',\n    help='Location of the pretrained checkpoint')\nparser.add_argument(\n    '-v',\n    '--save_dir',\n    type=str,\n    default='/tmp',\n    help='Location to save generated images to')\nargs = parser.parse_args()\n\ng = tf.Graph()\nwith g.as_default():\n    print('Creating model')\n    input_channels = 4  # 3 channels for RGB and 1 channel of all ones \n    image_size = (args.batch_size, args.image_size, args.image_size,\n                  input_channels)\n\n    row_input = tf.placeholder(\n        tf.float32, [args.batch_size, 1, args.image_size, input_channels],\n        name='row_input')\n    pixel_input = tf.placeholder(\n        tf.float32, [args.batch_size, 1, 1, input_channels],\n        name='pixel_input')\n    row_id = tf.placeholder(tf.int32, [], name='row_id')\n    col_id = tf.placeholder(tf.int32, [], name='col_id')\n    ema = tf.train.ExponentialMovingAverage(0.9995)\n\n    model_spec = tf.make_template('model', model.model_spec)\n    sample, fast_nn_out, v_stack = model_spec(\n        row_input, pixel_input, row_id, col_id, image_size, seed=args.seed)\n\n    all_cache_variables = [\n        v for v in tf.global_variables() if 'cache' in v.name\n    ]\n   
 initialize_cache = tf.variables_initializer(all_cache_variables)\n    reset_cache = fast_nn.reset_cache_op()\n\n    vars_to_restore = {\n        k: v\n        for k, v in ema.variables_to_restore().items() if 'cache' not in k\n    }\n    saver = tf.train.Saver(vars_to_restore)\n\n    output_images = np.zeros(\n        (args.batch_size, args.image_size, args.image_size, 3))\n\n    sess = tf.Session()\n    sess.run(initialize_cache)\n    print('Loading checkpoint %s' % args.checkpoint)\n    saver.restore(sess, args.checkpoint)\n\n    batch = 0\n    while True:\n        print('Generating')\n        sess.run(reset_cache)\n        start_time = time.time()\n        for row in range(args.image_size):\n            # Implicit downshift.\n            if row == 0:\n                x_row_input = np.zeros(\n                    (args.batch_size, 1, args.image_size, input_channels))\n            else:\n                x_row_input = output_images[:, (row - 1):row, :, :]\n                x_row_input = np.concatenate(\n                    (x_row_input, np.ones(\n                        (args.batch_size, 1, args.image_size, 1))),\n                    axis=3)\n\n            sess.run(v_stack, {row_input: x_row_input, row_id: row})\n\n            for col in range(args.image_size):\n                # Implicit rightshift.\n                if col == 0:\n                    x_pixel_input = np.zeros(\n                        (args.batch_size, 1, 1, input_channels))\n                else:\n                    x_pixel_input = output_images[:, row:(row + 1),\n                                                  (col - 1):col, :]\n                    x_pixel_input = np.concatenate(\n                        (x_pixel_input, np.ones((args.batch_size, 1, 1, 1))),\n                        axis=3)\n\n                feed_dict = {\n                    row_id: row,\n                    col_id: col,\n                    pixel_input: x_pixel_input\n                }\n                pixel_output = 
sess.run(sample, feed_dict)\n                output_images[:, row:(row + 1),\n                              col:(col + 1), :] = pixel_output\n\n        end_time = time.time()\n        print('Time taken to generate %d images: %.2f seconds' %\n              (args.batch_size, end_time - start_time))\n\n        plt.close('all')\n        image_tile = plotting.img_tile(\n            output_images, border_color=1.0, stretch=True)\n        plotting.plot_img(image_tile)\n        plt.savefig(os.path.join(args.save_dir, 'images_%d.png' % batch))\n\n        batch += 1\n"
  }
]