[
  {
    "path": ".gitignore",
    "content": "*.pth\n*.pyc\n*.bak\n*.png\n*.jpg\n*.npz\n!rd.png\n!/kodim10.png\n!/bpp-0.125-0.133-ssim-0.865-0.827.png\n!/bpp-0.250-0.249-ssim-0.937-0.918.png\n!/bpp-0.375-0.381-ssim-0.963-0.951.png\n"
  },
  {
    "path": "README.md",
    "content": "# Full Resolution Image Compression with Recurrent Neural Networks\nhttps://arxiv.org/abs/1608.05148v2\n\n## Requirements\n- Python 3\n- PyTorch 0.2.0\n- torchvision, NumPy and Pillow\n- SciPy < 1.2 (the scripts use `scipy.misc.imread` / `imsave`, which were removed in SciPy 1.2)\n- matplotlib (only for `test/draw_rd.py`)\n\n## Train\n```bash\npython train.py -f /path/to/your/images/folder/like/mscoco\n```\n\n## Encode and Decode\n### Encode\n```bash\npython encoder.py --model checkpoint/encoder_epoch_00000005.pth --input /path/to/your/example.png --cuda --output ex --iterations 16\n```\n\nThe height and width of the input image must both be multiples of 32. The binarizer and decoder weights are loaded from the `--model` path with `encoder` replaced by `binarizer` and `decoder`, respectively.\n\nThis will save the binary codes to `ex.npz` (NumPy `.npz` format).\n\n### Decode\n```bash\npython decoder.py --model checkpoint/decoder_epoch_00000005.pth --input /path/to/your/example.npz --cuda --output /path/to/output/folder\n```\n\nThis will output one image per iteration (`00.png`, `01.png`, ...). Each image is reconstructed from one more iteration of codes than the previous one, so later images use more bits and generally have higher quality.\n\n## Test\n### Get Kodak dataset\n```bash\nbash test/get_kodak.sh\n```\n\n### Encode and decode with RNN model\n```bash\nbash test/enc_dec.sh\n```\n\n### Encode and decode with JPEG (use `convert` from ImageMagick)\n```bash\nbash test/jpeg.sh\n```\n\n### Calculate MS-SSIM\n```bash\nbash test/calc_ssim.sh\n```\n\n### Draw rate-distortion curve\n```bash\npython test/draw_rd.py\n```\n\n## Result\nLSTM (additive reconstruction); LSTM bit rates are measured before entropy coding.\n\n### Rate-distortion\n![Rate-distortion](rd.png)\n\n### `kodim10.png`\n\nOriginal Image\n\n![Original Image](kodim10.png)\n\nBelow Left: LSTM, SSIM=0.865, bpp=0.125\n\nBelow Right: JPEG, SSIM=0.827, bpp=0.133\n\n![bpp-0.125-0.133-ssim-0.865-0.827](bpp-0.125-0.133-ssim-0.865-0.827.png)\n\nBelow Left: LSTM, SSIM=0.937, bpp=0.250\n\nBelow Right: JPEG, SSIM=0.918, bpp=0.249\n\n![bpp-0.250-0.249-ssim-0.937-0.918](bpp-0.250-0.249-ssim-0.937-0.918.png)\n\nBelow Left: LSTM, SSIM=0.963, bpp=0.375\n\nBelow Right: JPEG, SSIM=0.951, bpp=0.381\n\n![bpp-0.375-0.381-ssim-0.963-0.951](bpp-0.375-0.381-ssim-0.963-0.951.png)\n\n## What's inside\n- `train.py`: Main program for training.\n- `encoder.py` and `decoder.py`: Command-line encoder and decoder.\n- `dataset.py`: Utilities for reading images.\n- `metric.py`: Functions for calculating MS-SSIM and PSNR.\n- `network.py`: Encoder, binarizer and decoder modules.\n- `modules/conv_rnn.py`: ConvLSTM module.\n- `modules/sign.py`: `Sign` module wrapping the binary quantization function.\n- `functions/sign.py`: Forward and backward for binary quantization.\n\n## Official Repo\nhttps://github.com/tensorflow/models/tree/master/compression\n"
  },
  {
    "path": "dataset.py",
    "content": "# modified from https://github.com/desimone/vision/blob/fb74c76d09bcc2594159613d5bdadd7d4697bb11/torchvision/datasets/folder.py\n\nimport os\nimport os.path\n\nimport torch\nfrom torchvision import transforms\nimport torch.utils.data as data\nfrom PIL import Image\n\nIMG_EXTENSIONS = [\n    '.jpg',\n    '.JPG',\n    '.jpeg',\n    '.JPEG',\n    '.png',\n    '.PNG',\n    '.ppm',\n    '.PPM',\n    '.bmp',\n    '.BMP',\n]\n\n\ndef is_image_file(filename):\n    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)\n\n\ndef default_loader(path):\n    return Image.open(path).convert('RGB')\n\n\nclass ImageFolder(data.Dataset):\n    \"\"\" ImageFolder can be used to load images where there are no labels.\"\"\"\n\n    def __init__(self, root, transform=None, loader=default_loader):\n        images = []\n        for filename in os.listdir(root):\n            if is_image_file(filename):\n                images.append('{}'.format(filename))\n\n        self.root = root\n        self.imgs = images\n        self.transform = transform\n        self.loader = loader\n\n    def __getitem__(self, index):\n        filename = self.imgs[index]\n        try:\n            img = self.loader(os.path.join(self.root, filename))\n        except:\n            return torch.zeros((3, 32, 32))\n\n        if self.transform is not None:\n            img = self.transform(img)\n        return img\n\n    def __len__(self):\n        return len(self.imgs)\n"
  },
  {
    "path": "decoder.py",
    "content": "import os\nimport argparse\n\nimport numpy as np\nfrom scipy.misc import imread, imresize, imsave\n\nimport torch\nfrom torch.autograd import Variable\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--model', required=True, type=str, help='path to model')\nparser.add_argument('--input', required=True, type=str, help='input codes')\nparser.add_argument('--output', default='.', type=str, help='output folder')\nparser.add_argument('--cuda', action='store_true', help='enables cuda')\nparser.add_argument(\n    '--iterations', type=int, default=16, help='unroll iterations')\nargs = parser.parse_args()\n\ncontent = np.load(args.input)\ncodes = np.unpackbits(content['codes'])\ncodes = np.reshape(codes, content['shape']).astype(np.float32) * 2 - 1\n\ncodes = torch.from_numpy(codes)\niters, batch_size, channels, height, width = codes.size()\nheight = height * 16\nwidth = width * 16\n\ncodes = Variable(codes, volatile=True)\n\nimport network\n\ndecoder = network.DecoderCell()\ndecoder.eval()\n\ndecoder.load_state_dict(torch.load(args.model))\n\ndecoder_h_1 = (Variable(\n    torch.zeros(batch_size, 512, height // 16, width // 16), volatile=True),\n               Variable(\n                   torch.zeros(batch_size, 512, height // 16, width // 16),\n                   volatile=True))\ndecoder_h_2 = (Variable(\n    torch.zeros(batch_size, 512, height // 8, width // 8), volatile=True),\n               Variable(\n                   torch.zeros(batch_size, 512, height // 8, width // 8),\n                   volatile=True))\ndecoder_h_3 = (Variable(\n    torch.zeros(batch_size, 256, height // 4, width // 4), volatile=True),\n               Variable(\n                   torch.zeros(batch_size, 256, height // 4, width // 4),\n                   volatile=True))\ndecoder_h_4 = (Variable(\n    torch.zeros(batch_size, 128, height // 2, width // 2), volatile=True),\n               Variable(\n                   torch.zeros(batch_size, 128, height // 2, width // 
2),\n                   volatile=True))\n\nif args.cuda:\n    decoder = decoder.cuda()\n\n    codes = codes.cuda()\n\n    decoder_h_1 = (decoder_h_1[0].cuda(), decoder_h_1[1].cuda())\n    decoder_h_2 = (decoder_h_2[0].cuda(), decoder_h_2[1].cuda())\n    decoder_h_3 = (decoder_h_3[0].cuda(), decoder_h_3[1].cuda())\n    decoder_h_4 = (decoder_h_4[0].cuda(), decoder_h_4[1].cuda())\n\nimage = torch.zeros(1, 3, height, width) + 0.5\nfor iters in range(min(args.iterations, codes.size(0))):\n\n    output, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4 = decoder(\n        codes[iters], decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4)\n    image = image + output.data.cpu()\n\n    imsave(\n        os.path.join(args.output, '{:02d}.png'.format(iters)),\n        np.squeeze(image.numpy().clip(0, 1) * 255.0).astype(np.uint8)\n        .transpose(1, 2, 0))\n"
  },
  {
    "path": "encoder.py",
    "content": "import argparse\n\nimport numpy as np\nfrom scipy.misc import imread, imresize, imsave\n\nimport torch\nfrom torch.autograd import Variable\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\n    '--model', '-m', required=True, type=str, help='path to model')\nparser.add_argument(\n    '--input', '-i', required=True, type=str, help='input image')\nparser.add_argument(\n    '--output', '-o', required=True, type=str, help='output codes')\nparser.add_argument('--cuda', '-g', action='store_true', help='enables cuda')\nparser.add_argument(\n    '--iterations', type=int, default=16, help='unroll iterations')\nargs = parser.parse_args()\n\nimage = imread(args.input, mode='RGB')\nimage = torch.from_numpy(\n    np.expand_dims(\n        np.transpose(image.astype(np.float32) / 255.0, (2, 0, 1)), 0))\nbatch_size, input_channels, height, width = image.size()\nassert height % 32 == 0 and width % 32 == 0\n\nimage = Variable(image, volatile=True)\n\nimport network\n\nencoder = network.EncoderCell()\nbinarizer = network.Binarizer()\ndecoder = network.DecoderCell()\n\nencoder.eval()\nbinarizer.eval()\ndecoder.eval()\n\nencoder.load_state_dict(torch.load(args.model))\nbinarizer.load_state_dict(\n    torch.load(args.model.replace('encoder', 'binarizer')))\ndecoder.load_state_dict(torch.load(args.model.replace('encoder', 'decoder')))\n\nencoder_h_1 = (Variable(\n    torch.zeros(batch_size, 256, height // 4, width // 4), volatile=True),\n               Variable(\n                   torch.zeros(batch_size, 256, height // 4, width // 4),\n                   volatile=True))\nencoder_h_2 = (Variable(\n    torch.zeros(batch_size, 512, height // 8, width // 8), volatile=True),\n               Variable(\n                   torch.zeros(batch_size, 512, height // 8, width // 8),\n                   volatile=True))\nencoder_h_3 = (Variable(\n    torch.zeros(batch_size, 512, height // 16, width // 16), volatile=True),\n               Variable(\n                   
torch.zeros(batch_size, 512, height // 16, width // 16),\n                   volatile=True))\n\ndecoder_h_1 = (Variable(\n    torch.zeros(batch_size, 512, height // 16, width // 16), volatile=True),\n               Variable(\n                   torch.zeros(batch_size, 512, height // 16, width // 16),\n                   volatile=True))\ndecoder_h_2 = (Variable(\n    torch.zeros(batch_size, 512, height // 8, width // 8), volatile=True),\n               Variable(\n                   torch.zeros(batch_size, 512, height // 8, width // 8),\n                   volatile=True))\ndecoder_h_3 = (Variable(\n    torch.zeros(batch_size, 256, height // 4, width // 4), volatile=True),\n               Variable(\n                   torch.zeros(batch_size, 256, height // 4, width // 4),\n                   volatile=True))\ndecoder_h_4 = (Variable(\n    torch.zeros(batch_size, 128, height // 2, width // 2), volatile=True),\n               Variable(\n                   torch.zeros(batch_size, 128, height // 2, width // 2),\n                   volatile=True))\n\nif args.cuda:\n    encoder = encoder.cuda()\n    binarizer = binarizer.cuda()\n    decoder = decoder.cuda()\n\n    image = image.cuda()\n\n    encoder_h_1 = (encoder_h_1[0].cuda(), encoder_h_1[1].cuda())\n    encoder_h_2 = (encoder_h_2[0].cuda(), encoder_h_2[1].cuda())\n    encoder_h_3 = (encoder_h_3[0].cuda(), encoder_h_3[1].cuda())\n\n    decoder_h_1 = (decoder_h_1[0].cuda(), decoder_h_1[1].cuda())\n    decoder_h_2 = (decoder_h_2[0].cuda(), decoder_h_2[1].cuda())\n    decoder_h_3 = (decoder_h_3[0].cuda(), decoder_h_3[1].cuda())\n    decoder_h_4 = (decoder_h_4[0].cuda(), decoder_h_4[1].cuda())\n\ncodes = []\nres = image - 0.5\nfor iters in range(args.iterations):\n    encoded, encoder_h_1, encoder_h_2, encoder_h_3 = encoder(\n        res, encoder_h_1, encoder_h_2, encoder_h_3)\n\n    code = binarizer(encoded)\n\n    output, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4 = decoder(\n        code, decoder_h_1, decoder_h_2, 
decoder_h_3, decoder_h_4)\n\n    res = res - output\n    codes.append(code.data.cpu().numpy())\n\n    print('Iter: {:02d}; Loss: {:.06f}'.format(iters, res.data.abs().mean()))\n\ncodes = (np.stack(codes).astype(np.int8) + 1) // 2\n\nexport = np.packbits(codes.reshape(-1))\n\nnp.savez_compressed(args.output, shape=codes.shape, codes=export)\n"
  },
  {
    "path": "functions/__init__.py",
    "content": "from .sign import Sign\n"
  },
  {
    "path": "functions/sign.py",
    "content": "import torch\nfrom torch.autograd import Function\n\n\nclass Sign(Function):\n    \"\"\"\n    Variable Rate Image Compression with Recurrent Neural Networks\n    https://arxiv.org/abs/1511.06085\n    \"\"\"\n\n    def __init__(self):\n        super(Sign, self).__init__()\n\n    @staticmethod\n    def forward(ctx, input, is_training=True):\n        # Apply quantization noise while only training\n        if is_training:\n            prob = input.new(input.size()).uniform_()\n            x = input.clone()\n            x[(1 - input) / 2 <= prob] = 1\n            x[(1 - input) / 2 > prob] = -1\n            return x\n        else:\n            return input.sign()\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        return grad_output, None\n"
  },
  {
    "path": "metric.py",
    "content": "## some function borrowed from\n## https://github.com/tensorflow/models/blob/master/compression/image_encoder/msssim.py\n\"\"\"Python implementation of MS-SSIM.\n\nUsage:\n\npython msssim.py --original_image=original.png --compared_image=distorted.png\n\"\"\"\nimport argparse\n\nimport numpy as np\nfrom scipy import signal\nfrom scipy.ndimage.filters import convolve\nfrom PIL import Image\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--metric', '-m', type=str, default='all', help='metric')\nparser.add_argument(\n    '--original-image', '-o', type=str, required=True, help='original image')\nparser.add_argument(\n    '--compared-image', '-c', type=str, required=True, help='compared image')\nargs = parser.parse_args()\n\n\ndef _FSpecialGauss(size, sigma):\n    \"\"\"Function to mimic the 'fspecial' gaussian MATLAB function.\"\"\"\n    radius = size // 2\n    offset = 0.0\n    start, stop = -radius, radius + 1\n    if size % 2 == 0:\n        offset = 0.5\n        stop -= 1\n    x, y = np.mgrid[offset + start:stop, offset + start:stop]\n    assert len(x) == size\n    g = np.exp(-((x**2 + y**2) / (2.0 * sigma**2)))\n    return g / g.sum()\n\n\ndef _SSIMForMultiScale(img1,\n                       img2,\n                       max_val=255,\n                       filter_size=11,\n                       filter_sigma=1.5,\n                       k1=0.01,\n                       k2=0.03):\n    \"\"\"Return the Structural Similarity Map between `img1` and `img2`.\n\n  This function attempts to match the functionality of ssim_index_new.m by\n  Zhou Wang: http://www.cns.nyu.edu/~lcv/ssim/msssim.zip\n\n  Arguments:\n    img1: Numpy array holding the first RGB image batch.\n    img2: Numpy array holding the second RGB image batch.\n    max_val: the dynamic range of the images (i.e., the difference between the\n      maximum the and minimum allowed values).\n    filter_size: Size of blur kernel to use (will be reduced for small images).\n    
filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced\n      for small images).\n    k1: Constant used to maintain stability in the SSIM calculation (0.01 in\n      the original paper).\n    k2: Constant used to maintain stability in the SSIM calculation (0.03 in\n      the original paper).\n\n  Returns:\n    Pair containing the mean SSIM and contrast sensitivity between `img1` and\n    `img2`.\n\n  Raises:\n    RuntimeError: If input images don't have the same shape or don't have four\n      dimensions: [batch_size, height, width, depth].\n  \"\"\"\n    if img1.shape != img2.shape:\n        raise RuntimeError(\n            'Input images must have the same shape (%s vs. %s).', img1.shape,\n            img2.shape)\n    if img1.ndim != 4:\n        raise RuntimeError('Input images must have four dimensions, not %d',\n                           img1.ndim)\n\n    img1 = img1.astype(np.float64)\n    img2 = img2.astype(np.float64)\n    _, height, width, _ = img1.shape\n\n    # Filter size can't be larger than height or width of images.\n    size = min(filter_size, height, width)\n\n    # Scale down sigma if a smaller filter size is used.\n    sigma = size * filter_sigma / filter_size if filter_size else 0\n\n    if filter_size:\n        window = np.reshape(_FSpecialGauss(size, sigma), (1, size, size, 1))\n        mu1 = signal.fftconvolve(img1, window, mode='valid')\n        mu2 = signal.fftconvolve(img2, window, mode='valid')\n        sigma11 = signal.fftconvolve(img1 * img1, window, mode='valid')\n        sigma22 = signal.fftconvolve(img2 * img2, window, mode='valid')\n        sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid')\n    else:\n        # Empty blur kernel so no need to convolve.\n        mu1, mu2 = img1, img2\n        sigma11 = img1 * img1\n        sigma22 = img2 * img2\n        sigma12 = img1 * img2\n\n    mu11 = mu1 * mu1\n    mu22 = mu2 * mu2\n    mu12 = mu1 * mu2\n    sigma11 -= mu11\n    sigma22 -= mu22\n    sigma12 -= 
mu12\n\n    # Calculate intermediate values used by both ssim and cs_map.\n    c1 = (k1 * max_val)**2\n    c2 = (k2 * max_val)**2\n    v1 = 2.0 * sigma12 + c2\n    v2 = sigma11 + sigma22 + c2\n    ssim = np.mean((((2.0 * mu12 + c1) * v1) / ((mu11 + mu22 + c1) * v2)))\n    cs = np.mean(v1 / v2)\n    return ssim, cs\n\n\ndef MultiScaleSSIM(img1,\n                   img2,\n                   max_val=255,\n                   filter_size=11,\n                   filter_sigma=1.5,\n                   k1=0.01,\n                   k2=0.03,\n                   weights=None):\n    \"\"\"Return the MS-SSIM score between `img1` and `img2`.\n\n  This function implements Multi-Scale Structural Similarity (MS-SSIM) Image\n  Quality Assessment according to Zhou Wang's paper, \"Multi-scale structural\n  similarity for image quality assessment\" (2003).\n  Link: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf\n\n  Author's MATLAB implementation:\n  http://www.cns.nyu.edu/~lcv/ssim/msssim.zip\n\n  Arguments:\n    img1: Numpy array holding the first RGB image batch.\n    img2: Numpy array holding the second RGB image batch.\n    max_val: the dynamic range of the images (i.e., the difference between the\n      maximum the and minimum allowed values).\n    filter_size: Size of blur kernel to use (will be reduced for small images).\n    filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced\n      for small images).\n    k1: Constant used to maintain stability in the SSIM calculation (0.01 in\n      the original paper).\n    k2: Constant used to maintain stability in the SSIM calculation (0.03 in\n      the original paper).\n    weights: List of weights for each level; if none, use five levels and the\n      weights from the original paper.\n\n  Returns:\n    MS-SSIM score between `img1` and `img2`.\n\n  Raises:\n    RuntimeError: If input images don't have the same shape or don't have four\n      dimensions: [batch_size, height, width, depth].\n  \"\"\"\n    
if img1.shape != img2.shape:\n        raise RuntimeError(\n            'Input images must have the same shape (%s vs. %s).', img1.shape,\n            img2.shape)\n    if img1.ndim != 4:\n        raise RuntimeError('Input images must have four dimensions, not %d',\n                           img1.ndim)\n\n    # Note: default weights don't sum to 1.0 but do match the paper / matlab code.\n    weights = np.array(weights if weights else\n                       [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])\n    levels = weights.size\n    downsample_filter = np.ones((1, 2, 2, 1)) / 4.0\n    im1, im2 = [x.astype(np.float64) for x in [img1, img2]]\n    mssim = np.array([])\n    mcs = np.array([])\n    for _ in range(levels):\n        ssim, cs = _SSIMForMultiScale(\n            im1,\n            im2,\n            max_val=max_val,\n            filter_size=filter_size,\n            filter_sigma=filter_sigma,\n            k1=k1,\n            k2=k2)\n        mssim = np.append(mssim, ssim)\n        mcs = np.append(mcs, cs)\n        filtered = [\n            convolve(im, downsample_filter, mode='reflect')\n            for im in [im1, im2]\n        ]\n        im1, im2 = [x[:, ::2, ::2, :] for x in filtered]\n    return (np.prod(mcs[0:levels - 1]**weights[0:levels - 1]) *\n            (mssim[levels - 1]**weights[levels - 1]))\n\n\ndef msssim(original, compared):\n    if isinstance(original, str):\n        original = np.array(Image.open(original).convert('RGB'), dtype=np.float32)\n    if isinstance(compared, str):\n        compared = np.array(Image.open(compared).convert('RGB'), dtype=np.float32)\n\n    original = original[None, ...] if original.ndim == 3 else original\n    compared = compared[None, ...] 
if compared.ndim == 3 else compared\n\n    return MultiScaleSSIM(original, compared, max_val=255)\n\n\ndef psnr(original, compared):\n    if isinstance(original, str):\n        original = np.array(Image.open(original).convert('RGB'), dtype=np.float32)\n    if isinstance(compared, str):\n        compared = np.array(Image.open(compared).convert('RGB'), dtype=np.float32)\n\n    mse = np.mean(np.square(original - compared))\n    psnr = np.clip(\n        np.multiply(np.log10(255. * 255. / mse[mse > 0.]), 10.), 0., 99.99)[0]\n    return psnr\n\n\ndef main():\n    if args.metric != 'psnr':\n        print(msssim(args.original_image, args.compared_image), end='')\n    if args.metric != 'ssim':\n        print(psnr(args.original_image, args.compared_image), end='')\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "modules/__init__.py",
    "content": "from .conv_rnn import ConvLSTMCell  #, ConvLSTM\nfrom .sign import Sign\n"
  },
  {
    "path": "modules/conv_rnn.py",
    "content": "import torch.nn as nn\nimport torch.nn.functional as F\nimport torch\nfrom torch.autograd import Variable\nfrom torch.nn.modules.utils import _pair\n\n\nclass ConvRNNCellBase(nn.Module):\n    def __repr__(self):\n        s = (\n            '{name}({input_channels}, {hidden_channels}, kernel_size={kernel_size}'\n            ', stride={stride}')\n        if self.padding != (0, ) * len(self.padding):\n            s += ', padding={padding}'\n        if self.dilation != (1, ) * len(self.dilation):\n            s += ', dilation={dilation}'\n        s += ', hidden_kernel_size={hidden_kernel_size}'\n        s += ')'\n        return s.format(name=self.__class__.__name__, **self.__dict__)\n\n\nclass ConvLSTMCell(ConvRNNCellBase):\n    def __init__(self,\n                 input_channels,\n                 hidden_channels,\n                 kernel_size=3,\n                 stride=1,\n                 padding=0,\n                 dilation=1,\n                 hidden_kernel_size=1,\n                 bias=True):\n        super(ConvLSTMCell, self).__init__()\n        self.input_channels = input_channels\n        self.hidden_channels = hidden_channels\n\n        self.kernel_size = _pair(kernel_size)\n        self.stride = _pair(stride)\n        self.padding = _pair(padding)\n        self.dilation = _pair(dilation)\n\n        self.hidden_kernel_size = _pair(hidden_kernel_size)\n\n        hidden_padding = _pair(hidden_kernel_size // 2)\n\n        gate_channels = 4 * self.hidden_channels\n        self.conv_ih = nn.Conv2d(\n            in_channels=self.input_channels,\n            out_channels=gate_channels,\n            kernel_size=self.kernel_size,\n            stride=self.stride,\n            padding=self.padding,\n            dilation=self.dilation,\n            bias=bias)\n\n        self.conv_hh = nn.Conv2d(\n            in_channels=self.hidden_channels,\n            out_channels=gate_channels,\n            kernel_size=hidden_kernel_size,\n            stride=1,\n   
         padding=hidden_padding,\n            dilation=1,\n            bias=bias)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.conv_ih.reset_parameters()\n        self.conv_hh.reset_parameters()\n\n    def forward(self, input, hidden):\n        hx, cx = hidden\n        gates = self.conv_ih(input) + self.conv_hh(hx)\n\n        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)\n\n        ingate = F.sigmoid(ingate)\n        forgetgate = F.sigmoid(forgetgate)\n        cellgate = F.tanh(cellgate)\n        outgate = F.sigmoid(outgate)\n\n        cy = (forgetgate * cx) + (ingate * cellgate)\n        hy = outgate * F.tanh(cy)\n\n        return hy, cy\n"
  },
  {
    "path": "modules/sign.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom functions import Sign as SignFunction\n\n\nclass Sign(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, x):\n        return SignFunction.apply(x, self.training)\n"
  },
  {
    "path": "network.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom modules import ConvLSTMCell, Sign\n\n\nclass EncoderCell(nn.Module):\n    def __init__(self):\n        super(EncoderCell, self).__init__()\n\n        self.conv = nn.Conv2d(\n            3, 64, kernel_size=3, stride=2, padding=1, bias=False)\n        self.rnn1 = ConvLSTMCell(\n            64,\n            256,\n            kernel_size=3,\n            stride=2,\n            padding=1,\n            hidden_kernel_size=1,\n            bias=False)\n        self.rnn2 = ConvLSTMCell(\n            256,\n            512,\n            kernel_size=3,\n            stride=2,\n            padding=1,\n            hidden_kernel_size=1,\n            bias=False)\n        self.rnn3 = ConvLSTMCell(\n            512,\n            512,\n            kernel_size=3,\n            stride=2,\n            padding=1,\n            hidden_kernel_size=1,\n            bias=False)\n\n    def forward(self, input, hidden1, hidden2, hidden3):\n        x = self.conv(input)\n\n        hidden1 = self.rnn1(x, hidden1)\n        x = hidden1[0]\n\n        hidden2 = self.rnn2(x, hidden2)\n        x = hidden2[0]\n\n        hidden3 = self.rnn3(x, hidden3)\n        x = hidden3[0]\n\n        return x, hidden1, hidden2, hidden3\n\n\nclass Binarizer(nn.Module):\n    def __init__(self):\n        super(Binarizer, self).__init__()\n        self.conv = nn.Conv2d(512, 32, kernel_size=1, bias=False)\n        self.sign = Sign()\n\n    def forward(self, input):\n        feat = self.conv(input)\n        x = F.tanh(feat)\n        return self.sign(x)\n\n\nclass DecoderCell(nn.Module):\n    def __init__(self):\n        super(DecoderCell, self).__init__()\n\n        self.conv1 = nn.Conv2d(\n            32, 512, kernel_size=1, stride=1, padding=0, bias=False)\n        self.rnn1 = ConvLSTMCell(\n            512,\n            512,\n            kernel_size=3,\n            stride=1,\n            padding=1,\n            hidden_kernel_size=1,\n  
          bias=False)\n        self.rnn2 = ConvLSTMCell(\n            128,\n            512,\n            kernel_size=3,\n            stride=1,\n            padding=1,\n            hidden_kernel_size=1,\n            bias=False)\n        self.rnn3 = ConvLSTMCell(\n            128,\n            256,\n            kernel_size=3,\n            stride=1,\n            padding=1,\n            hidden_kernel_size=3,\n            bias=False)\n        self.rnn4 = ConvLSTMCell(\n            64,\n            128,\n            kernel_size=3,\n            stride=1,\n            padding=1,\n            hidden_kernel_size=3,\n            bias=False)\n        self.conv2 = nn.Conv2d(\n            32, 3, kernel_size=1, stride=1, padding=0, bias=False)\n\n    def forward(self, input, hidden1, hidden2, hidden3, hidden4):\n        x = self.conv1(input)\n\n        hidden1 = self.rnn1(x, hidden1)\n        x = hidden1[0]\n        x = F.pixel_shuffle(x, 2)\n\n        hidden2 = self.rnn2(x, hidden2)\n        x = hidden2[0]\n        x = F.pixel_shuffle(x, 2)\n\n        hidden3 = self.rnn3(x, hidden3)\n        x = hidden3[0]\n        x = F.pixel_shuffle(x, 2)\n\n        hidden4 = self.rnn4(x, hidden4)\n        x = hidden4[0]\n        x = F.pixel_shuffle(x, 2)\n\n        x = F.tanh(self.conv2(x)) / 2\n        return x, hidden1, hidden2, hidden3, hidden4\n"
  },
  {
    "path": "test/calc_ssim.sh",
    "content": "#!/bin/bash\n\nLSTM=test/lstm_ssim.csv\nJPEG=test/jpeg_ssim.csv\n\necho -n \"\" > $LSTM\nfor i in {01..24..1}; do\n  echo Processing test/decoded/kodim$i\n  for j in {00..15..1}; do\n    echo -n `python metric.py -m ssim -o test/images/kodim$i.png -c test/decoded/kodim$i/$j.png`', ' >> $LSTM\n  done\n  echo \"\" >> $LSTM\ndone\n\necho -n \"\" > $JPEG\nfor i in {01..24..1}; do\n  echo Processing test/jpeg/kodim$i\n  for j in {01..20..1}; do\n    echo -n `python metric.py -m ssim -o test/images/kodim$i.png -c test/jpeg/kodim$i/$j.jpg`', ' >> $JPEG\n  done\n  echo \"\" >> $JPEG\ndone\n"
  },
  {
    "path": "test/draw_rd.py",
    "content": "import os\n\nimport numpy as np\nfrom scipy.misc import imread\nimport matplotlib.pyplot as plt\n\nline = True\n\nlstm_ssim = np.genfromtxt('test/lstm_ssim.csv', delimiter=',')\nlstm_ssim = lstm_ssim[:, :-1]\nif line:\n    lstm_ssim = np.mean(lstm_ssim, axis=0)\n    lstm_bpp = np.arange(1, 17) / 192 * 24\n    plt.plot(lstm_bpp, lstm_ssim, label='LSTM', marker='o')\nelse:\n    lstm_bpp = np.stack([np.arange(1, 17) for _ in range(24)]) / 192 * 24\n    plt.scatter(\n        lstm_bpp.reshape(-1), lstm_ssim.reshape(-1), label='LSTM', marker='o')\n\njpeg_ssim = np.genfromtxt('test/jpeg_ssim.csv', delimiter=',')\njpeg_ssim = jpeg_ssim[:, :-1]\nif line:\n    jpeg_ssim = np.mean(jpeg_ssim, axis=0)\n\njpeg_bpp = np.array([\n    os.path.getsize('test/jpeg/kodim{:02d}/{:02d}.jpg'.format(i, q)) * 8 /\n    (imread('test/jpeg/kodim{:02d}/{:02d}.jpg'.format(i, q)).size // 3)\n    for i in range(1, 25) for q in range(1, 21)\n]).reshape(24, 20)\n\nif line:\n    jpeg_bpp = np.mean(jpeg_bpp, axis=0)\n    plt.plot(jpeg_bpp, jpeg_ssim, label='JPEG', marker='x')\nelse:\n    plt.scatter(\n        jpeg_bpp.reshape(-1), jpeg_ssim.reshape(-1), label='JPEG', marker='x')\n\nplt.xlim(0., 2.)\nplt.ylim(0.7, 1.0)\nplt.xlabel('bit per pixel')\nplt.ylabel('MS-SSIM')\nplt.legend()\nplt.show()\n"
  },
  {
    "path": "test/enc_dec.sh",
    "content": "#!/bin/bash\n\nfor i in {01..24..1}; do\n  echo Encoding test/images/kodim$i.png\n  mkdir -p test/codes\n  python encoder.py --model checkpoint/encoder_epoch_00000066.pth --input test/images/kodim$i.png --cuda --output test/codes/kodim$i --iterations 16\n\n  echo Decoding test/codes/kodim$i.npz\n  mkdir -p test/decoded/kodim$i\n  python decoder.py --model checkpoint/decoder_epoch_00000066.pth --input test/codes/kodim$i.npz --cuda --output test/decoded/kodim$i\ndone\n"
  },
  {
    "path": "test/get_kodak.sh",
    "content": "#!/bin/bash\n\nmkdir -p test/images\n\nfor i in {01..24..1}; do\n  echo ${i}\n  wget http://r0k.us/graphics/kodak/kodak/kodim${i}.png -O test/images/kodim${i}.png\ndone\n"
  },
  {
    "path": "test/jpeg.sh",
    "content": "#!/bin/bash\n\nfor i in {01..24..1}; do\n  echo JPEG Encoding test/images/kodim$i.png\n  mkdir -p test/jpeg/kodim$i\n  for j in {1..20..1}; do\n    convert test/images/kodim$i.png -quality $(($j*5)) -sampling-factor 4:2:0 test/jpeg/kodim$i/`printf \"%02d\" $j`.jpg\n  done\ndone\n"
  },
  {
    "path": "test/jpeg_ssim.csv",
    "content": "0.818072219541, 0.915486738863, 0.941250388079, 0.954959594001, 0.964432174175, 0.969521944618, 0.974450024801, 0.977023476788, 0.979556769243, 0.981440100069, 0.98303158495, 0.984619828345, 0.986521991089, 0.988233765029, 0.990034353477, 0.991908912291, 0.993859786213, 0.995872272301, 0.997704734009, 0.999369880278, \n0.666065535368, 0.833785271322, 0.862238250732, 0.894437174757, 0.913606004326, 0.927403225046, 0.936700087174, 0.942087207378, 0.947302906233, 0.952881599351, 0.957102998846, 0.961079845306, 0.964679737793, 0.969085669754, 0.972862674891, 0.977389816095, 0.981929271455, 0.986764889825, 0.992090856598, 0.997867382574, \n0.810456808165, 0.885532082425, 0.923215832246, 0.942269281224, 0.951013313071, 0.960641959857, 0.965020974147, 0.968983036738, 0.972687587703, 0.97536435516, 0.977124722602, 0.979205659824, 0.981645499592, 0.983797485181, 0.985868845072, 0.988406175337, 0.990561443404, 0.992977803974, 0.995363893283, 0.998186608337, \n0.736205708103, 0.864789258571, 0.903917595396, 0.928365043206, 0.940992130793, 0.950190769917, 0.957041193005, 0.961645055775, 0.965608843423, 0.968846156644, 0.971464980647, 0.974338971545, 0.976288817151, 0.979326195514, 0.981731800268, 0.985073759122, 0.988183892013, 0.991310684778, 0.99484060445, 0.998539576456, \n0.839808316647, 0.914569680204, 0.944585553268, 0.958189541858, 0.965933250076, 0.971345068489, 0.975120264002, 0.978007417795, 0.980112049762, 0.982031704186, 0.983617064118, 0.985301594451, 0.986864130976, 0.988550054893, 0.990140431145, 0.991912697905, 0.993659191703, 0.995525236463, 0.997451078786, 0.999445634028, \n0.774314920949, 0.875787337352, 0.923727573855, 0.943054472644, 0.95321690632, 0.961164793684, 0.967588086634, 0.971259154783, 0.974502697083, 0.976536008639, 0.979070907986, 0.980730423348, 0.982772471532, 0.985208074663, 0.987072532816, 0.989603098041, 0.991848356056, 0.993989496887, 0.996261532852, 0.998833407115, \n0.853028224878, 0.924594399288, 0.944577324582, 
0.962892871899, 0.968175881476, 0.973912836384, 0.977611559175, 0.980013739677, 0.981372066609, 0.983489296322, 0.984974095691, 0.986316785418, 0.987796973658, 0.989294643654, 0.990415707239, 0.991958140094, 0.993469699667, 0.99504535025, 0.996735156379, 0.998906782293, \n0.880688294539, 0.933204609681, 0.953650276172, 0.964610058969, 0.971655363423, 0.976231891224, 0.979402609806, 0.981754260409, 0.983510269845, 0.98503201627, 0.986234509006, 0.987594582558, 0.988831108808, 0.990315455876, 0.991614978601, 0.993112910369, 0.994592787429, 0.996128854807, 0.997766030171, 0.999473104286, \n0.853720880853, 0.912554857475, 0.930711615275, 0.950455648101, 0.960958503215, 0.96667403636, 0.971362349207, 0.974072622141, 0.975613226504, 0.978781040524, 0.979969048509, 0.981158076869, 0.982883137674, 0.984722620635, 0.98658900113, 0.988314796092, 0.99036007625, 0.992362984817, 0.994781216862, 0.998299761454, \n0.81076220338, 0.893706653475, 0.927362493596, 0.943724930231, 0.955211693836, 0.962355553594, 0.96829254466, 0.97169265604, 0.9743439019, 0.976751528211, 0.978728976699, 0.980505251819, 0.982242719802, 0.98445680889, 0.986286665859, 0.988491766444, 0.990776417926, 0.992905056354, 0.99523724111, 0.998527390073, \n0.791555686892, 0.881440991243, 0.920860393213, 0.942552828487, 0.953270936085, 0.959839333526, 0.966037446908, 0.969412253347, 0.972687657063, 0.975263831002, 0.977073908411, 0.979106240394, 0.981338251384, 0.984181227447, 0.985917501452, 0.988715661162, 0.991071452023, 0.993500123657, 0.996109324665, 0.999023620304, \n0.801456772286, 0.878467768641, 0.908098341977, 0.933136250773, 0.946448137076, 0.955111533806, 0.963027394102, 0.967258044795, 0.970678461713, 0.974186408589, 0.975688326117, 0.977937704633, 0.980300465358, 0.982588089404, 0.985135222036, 0.987727956878, 0.990090763787, 0.992444149444, 0.995132471845, 0.998223031362, \n0.797066626991, 0.891473428482, 0.927754609701, 0.944168588194, 0.953960755811, 0.960570142061, 0.96609929426, 0.969401647648, 
0.97267245227, 0.975204503945, 0.977409332209, 0.979649966483, 0.981643846133, 0.984198624552, 0.986564293774, 0.989219718804, 0.991824162369, 0.994323089168, 0.996765571563, 0.999382383983, \n0.785606294091, 0.88697381026, 0.922042917455, 0.942085547292, 0.952637067226, 0.95956056371, 0.965247607479, 0.968692477325, 0.971654628021, 0.974079838834, 0.976237478008, 0.978466192909, 0.980687087441, 0.983108513663, 0.985172089232, 0.987865648168, 0.990607337903, 0.993466132139, 0.996502037444, 0.999093888752, \n0.791792877213, 0.87257820722, 0.910251232828, 0.931499697455, 0.943952143348, 0.951516025861, 0.957952121781, 0.962339434864, 0.966062146417, 0.968616514774, 0.971381676811, 0.973583176861, 0.976055057192, 0.979084758838, 0.98170749589, 0.984744675444, 0.987623582667, 0.990827225403, 0.994500953634, 0.998296977796, \n0.77155359043, 0.869021329171, 0.916734846729, 0.939085193953, 0.951543198728, 0.959581932049, 0.965130869692, 0.970274846734, 0.973566829975, 0.976120453022, 0.978316404127, 0.980331703541, 0.982562198606, 0.984891459627, 0.986832139561, 0.98925186408, 0.991631382422, 0.993799878089, 0.996076708336, 0.998724170375, \n0.834596697018, 0.910219914842, 0.938953487334, 0.954867094446, 0.965021373662, 0.971626167075, 0.975397860726, 0.978190765304, 0.980255255374, 0.982316415294, 0.983438406514, 0.985042910404, 0.986538091767, 0.988296632234, 0.989480478459, 0.991169147712, 0.992838525537, 0.994484549662, 0.996349006533, 0.998796129593, \n0.789379799628, 0.883671741466, 0.918305025089, 0.937438007277, 0.948741402993, 0.955676286943, 0.961312308858, 0.965026524442, 0.96851536141, 0.971303982655, 0.973488277598, 0.975794458236, 0.978215624859, 0.980917182109, 0.983425193859, 0.986089845307, 0.988720428405, 0.991544643936, 0.994909928794, 0.998959191331, \n0.798779556224, 0.879522751508, 0.920868515488, 0.940119021497, 0.951843311114, 0.959017895362, 0.965208194261, 0.969082003508, 0.97221027879, 0.974691498988, 0.976855260305, 0.978953094688, 
0.981167317793, 0.983437790312, 0.985786929537, 0.988162887114, 0.990644384686, 0.993026404153, 0.995507443715, 0.998621645017, \n0.879263930866, 0.922020986557, 0.946493775014, 0.958297320378, 0.964859488123, 0.970312211669, 0.97424567735, 0.976421680027, 0.978413032059, 0.979756913163, 0.981009201268, 0.982435429622, 0.98394030031, 0.98552902602, 0.987216102866, 0.988920029566, 0.990524992279, 0.992390079391, 0.994639444871, 0.998357673395, \n0.827950174866, 0.894831372155, 0.9313264476, 0.94757593276, 0.956105682938, 0.962399480035, 0.967538440479, 0.971214741258, 0.973124877585, 0.976048275534, 0.977788559952, 0.979504668032, 0.981543513231, 0.983649526565, 0.985455734324, 0.987721369114, 0.990020663907, 0.992286543302, 0.994788654436, 0.998404057223, \n0.754610909207, 0.860747233451, 0.899736485589, 0.925743580664, 0.939166571882, 0.948339482047, 0.954553566064, 0.958885994659, 0.962909621012, 0.966353236056, 0.969046685166, 0.971977721769, 0.974626509781, 0.977835303452, 0.980470725918, 0.983736275578, 0.986964243248, 0.990258297647, 0.993951518828, 0.998556710426, \n0.784958874194, 0.874722011293, 0.915389887756, 0.934343596754, 0.948237654386, 0.95672974228, 0.963437883085, 0.96720024908, 0.970772553663, 0.973495160679, 0.975921421699, 0.978227890716, 0.980404390223, 0.983010440205, 0.985215187951, 0.987733337063, 0.98989913274, 0.992273798878, 0.994922548059, 0.99824153924, \n0.822296667468, 0.89834966427, 0.933957434279, 0.951462994733, 0.95978880675, 0.965804629849, 0.970633032468, 0.973789346233, 0.976510124286, 0.978522727364, 0.980401981947, 0.982057448136, 0.984052101154, 0.986046587711, 0.98771061484, 0.989891431917, 0.992010784627, 0.994204644975, 0.996543052165, 0.99909102487, \n"
  },
  {
    "path": "test/lstm_ssim.csv",
    "content": "0.769036023469, 0.88985832751, 0.927751848949, 0.944983151872, 0.95575157289, 0.963951705505, 0.97091182084, 0.975998458878, 0.979411580032, 0.981903538902, 0.984059023874, 0.985847690249, 0.9873427695, 0.988557280396, 0.989808215992, 0.990955339337, \n0.805634389938, 0.885551865434, 0.926222844478, 0.945195184339, 0.956706397486, 0.964635434939, 0.970561396161, 0.975371321751, 0.979008896093, 0.981623105097, 0.983889966323, 0.985693922222, 0.987117323848, 0.988116453261, 0.989126652794, 0.990035372754, \n0.886382589545, 0.944226634293, 0.966099553101, 0.97574891786, 0.981699305186, 0.985303117183, 0.988158479953, 0.990353107687, 0.991742329762, 0.992736636125, 0.993586103345, 0.994247922172, 0.994823658832, 0.99523818175, 0.99566695018, 0.996024457998, \n0.844943736432, 0.912763739137, 0.941348917006, 0.955485600323, 0.964352299005, 0.971086350909, 0.976031220021, 0.980033889346, 0.983037018013, 0.985100156288, 0.986891083059, 0.98828168163, 0.989371458473, 0.990204501149, 0.991008062343, 0.991775083665, \n0.807296336647, 0.901032376198, 0.936630910374, 0.953867965935, 0.965072366127, 0.973137076713, 0.979042150279, 0.983166487096, 0.985988824175, 0.987817180773, 0.989464580939, 0.990820235888, 0.991892422931, 0.992712253547, 0.993501561405, 0.994148098694, \n0.785345752053, 0.884711958245, 0.92776368098, 0.948715836445, 0.96019570665, 0.96605401155, 0.972573249337, 0.97720752253, 0.980241043367, 0.982484654905, 0.984849828368, 0.986701720731, 0.988188381493, 0.98922008024, 0.990398367023, 0.991365949801, \n0.879784256538, 0.951585216892, 0.972381608615, 0.981022083203, 0.985810271759, 0.988693576841, 0.990952401157, 0.992588896661, 0.993784304761, 0.994579427454, 0.995314920728, 0.99586908212, 0.996291444393, 0.996577798766, 0.996861139444, 0.997104778181, \n0.815477354243, 0.906116356367, 0.944359558351, 0.96028214433, 0.969452971301, 0.975073789021, 0.979816675642, 0.983303430092, 0.985631084073, 0.987351580806, 0.988992309204, 0.99024746499, 
0.991323941253, 0.992077409918, 0.992875703722, 0.993551435751, \n0.898231591406, 0.952767120478, 0.971324990566, 0.979139320561, 0.983235587546, 0.985925513104, 0.98813798392, 0.989849456611, 0.991087155229, 0.991997953486, 0.992918185114, 0.993614367759, 0.994164453638, 0.994560726328, 0.994937149117, 0.995284527261, \n0.8749003714, 0.942176320572, 0.966895893079, 0.977493170796, 0.982875921764, 0.986274282021, 0.98850906544, 0.990252593185, 0.991524822651, 0.992438630258, 0.993323088693, 0.994034014192, 0.994607996745, 0.994996589457, 0.995366557151, 0.99569469237, \n0.831185109757, 0.908127863885, 0.94129870614, 0.957395739465, 0.966623340753, 0.971948863236, 0.977120326656, 0.980795107721, 0.983424173894, 0.9853323917, 0.987235341184, 0.988695696085, 0.989908689415, 0.990740343679, 0.991648831536, 0.992402349259, \n0.867455122658, 0.929286374978, 0.956373268572, 0.968857918522, 0.975855487208, 0.979544787408, 0.983269284681, 0.985944004, 0.987814957273, 0.989295675642, 0.990662162931, 0.991705980646, 0.992601345026, 0.99321774356, 0.993851078984, 0.994396313967, \n0.741714560572, 0.852678337269, 0.89776940045, 0.921657297609, 0.937268311557, 0.948777493887, 0.958587922105, 0.966219484116, 0.970930915102, 0.973981680428, 0.976954171415, 0.979158752294, 0.981113945529, 0.982622981303, 0.984346679365, 0.985939206912, \n0.806023887853, 0.898208741526, 0.934023153024, 0.951794401119, 0.962675974888, 0.969812260388, 0.975737969034, 0.980133270486, 0.983291193208, 0.985354416391, 0.98719947195, 0.98876036237, 0.990016818176, 0.990903992042, 0.991817776314, 0.992603235282, \n0.885580139856, 0.938517703814, 0.959134612721, 0.968447444909, 0.974006513851, 0.978175881773, 0.981340248165, 0.983964849591, 0.985997455292, 0.987552480984, 0.988913071725, 0.989989662136, 0.990858978786, 0.991507864087, 0.992101253603, 0.992706639095, \n0.843116295553, 0.918865830119, 0.948986302851, 0.963840566631, 0.971863394331, 0.976064780745, 0.980732356364, 0.98399885847, 0.986220002076, 
0.987800243738, 0.989471739682, 0.990764010755, 0.991805901688, 0.992578822333, 0.993415281055, 0.994092268046, \n0.88346061529, 0.946000867977, 0.96760658804, 0.976501600498, 0.981740918991, 0.985389590011, 0.988097110038, 0.990101000036, 0.991452711903, 0.992372729231, 0.993280593371, 0.993979273128, 0.994576920437, 0.994992603394, 0.99541165183, 0.995793186385, \n0.801573520738, 0.894556913607, 0.930557148943, 0.947742213628, 0.958336710515, 0.966270887606, 0.972043805126, 0.976686743156, 0.979997380949, 0.982263084747, 0.984371624652, 0.986043574351, 0.987418653777, 0.988435017399, 0.98943820625, 0.990384899709, \n0.824336788203, 0.914307147717, 0.947162910211, 0.961113167281, 0.969251988341, 0.974708549133, 0.979348243062, 0.982888751196, 0.985039998783, 0.986690243229, 0.988265858776, 0.989457451692, 0.990425556398, 0.991142353579, 0.991913402107, 0.992646966198, \n0.910171904704, 0.952375683078, 0.967737628914, 0.975113825208, 0.979525760009, 0.982903358103, 0.985497414947, 0.987532990507, 0.989022084385, 0.990094546854, 0.991046427233, 0.991786419928, 0.992476722408, 0.992941886737, 0.9934275669, 0.993905147004, \n0.864225726494, 0.929980029382, 0.955778654113, 0.96740495639, 0.974178985114, 0.978346408898, 0.982328012772, 0.985109106162, 0.987107132775, 0.988477856206, 0.989870184226, 0.990930157844, 0.991839380567, 0.992475550555, 0.993157023569, 0.993742594496, \n0.803770468068, 0.894709100591, 0.93001725197, 0.946972008686, 0.957454439962, 0.965257644855, 0.971242168162, 0.97578798359, 0.979145612293, 0.981419222339, 0.983528566216, 0.985134373586, 0.986496893548, 0.98753676895, 0.988504824958, 0.989410139829, \n0.887667125483, 0.950280631321, 0.971841626657, 0.980425626652, 0.984775140336, 0.987822408265, 0.989908020125, 0.991464951332, 0.992696046157, 0.993535144861, 0.99432064875, 0.994901414575, 0.995374559063, 0.995705134049, 0.996001978174, 0.996274517893, \n0.828518548366, 0.907992823067, 0.940938577175, 0.956594052005, 0.966157200834, 
0.972833437098, 0.977987634362, 0.981907039284, 0.984516045128, 0.986238836359, 0.987891164633, 0.989148044611, 0.990264907045, 0.991071889081, 0.991900734924, 0.992623228933, \n"
  },
  {
    "path": "train.py",
    "content": "import time\nimport os\nimport argparse\n\nimport numpy as np\n\nimport torch\nimport torch.optim as optim\nimport torch.optim.lr_scheduler as LS\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\nimport torch.utils.data as data\nfrom torchvision import transforms\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\n    '--batch-size', '-N', type=int, default=32, help='batch size')\nparser.add_argument(\n    '--train', '-f', required=True, type=str, help='folder of training images')\nparser.add_argument(\n    '--max-epochs', '-e', type=int, default=200, help='max epochs')\nparser.add_argument('--lr', type=float, default=0.0005, help='learning rate')\n# parser.add_argument('--cuda', '-g', action='store_true', help='enables cuda')\nparser.add_argument(\n    '--iterations', type=int, default=16, help='unroll iterations')\nparser.add_argument('--checkpoint', type=int, help='unroll iterations')\nargs = parser.parse_args()\n\n## load 32x32 patches from images\nimport dataset\n\ntrain_transform = transforms.Compose([\n    transforms.RandomCrop((32, 32)),\n    transforms.ToTensor(),\n])\n\ntrain_set = dataset.ImageFolder(root=args.train, transform=train_transform)\n\ntrain_loader = data.DataLoader(\n    dataset=train_set, batch_size=args.batch_size, shuffle=True, num_workers=1)\n\nprint('total images: {}; total batches: {}'.format(\n    len(train_set), len(train_loader)))\n\n## load networks on GPU\nimport network\n\nencoder = network.EncoderCell().cuda()\nbinarizer = network.Binarizer().cuda()\ndecoder = network.DecoderCell().cuda()\n\nsolver = optim.Adam(\n    [\n        {\n            'params': encoder.parameters()\n        },\n        {\n            'params': binarizer.parameters()\n        },\n        {\n            'params': decoder.parameters()\n        },\n    ],\n    lr=args.lr)\n\n\ndef resume(epoch=None):\n    if epoch is None:\n        s = 'iter'\n        epoch = 0\n    else:\n        s = 
'epoch'\n\n    encoder.load_state_dict(\n        torch.load('checkpoint/encoder_{}_{:08d}.pth'.format(s, epoch)))\n    binarizer.load_state_dict(\n        torch.load('checkpoint/binarizer_{}_{:08d}.pth'.format(s, epoch)))\n    decoder.load_state_dict(\n        torch.load('checkpoint/decoder_{}_{:08d}.pth'.format(s, epoch)))\n\n\ndef save(index, epoch=True):\n    if not os.path.exists('checkpoint'):\n        os.mkdir('checkpoint')\n\n    if epoch:\n        s = 'epoch'\n    else:\n        s = 'iter'\n\n    torch.save(encoder.state_dict(), 'checkpoint/encoder_{}_{:08d}.pth'.format(\n        s, index))\n\n    torch.save(binarizer.state_dict(),\n               'checkpoint/binarizer_{}_{:08d}.pth'.format(s, index))\n\n    torch.save(decoder.state_dict(), 'checkpoint/decoder_{}_{:08d}.pth'.format(\n        s, index))\n\n\n# resume()\n\nscheduler = LS.MultiStepLR(solver, milestones=[3, 10, 20, 50, 100], gamma=0.5)\n\nlast_epoch = 0\nif args.checkpoint:\n    resume(args.checkpoint)\n    last_epoch = args.checkpoint\n    scheduler.last_epoch = last_epoch - 1\n\nfor epoch in range(last_epoch + 1, args.max_epochs + 1):\n\n    scheduler.step()\n\n    for batch, data in enumerate(train_loader):\n        batch_t0 = time.time()\n\n        ## init lstm state\n        encoder_h_1 = (Variable(torch.zeros(data.size(0), 256, 8, 8).cuda()),\n                       Variable(torch.zeros(data.size(0), 256, 8, 8).cuda()))\n        encoder_h_2 = (Variable(torch.zeros(data.size(0), 512, 4, 4).cuda()),\n                       Variable(torch.zeros(data.size(0), 512, 4, 4).cuda()))\n        encoder_h_3 = (Variable(torch.zeros(data.size(0), 512, 2, 2).cuda()),\n                       Variable(torch.zeros(data.size(0), 512, 2, 2).cuda()))\n\n        decoder_h_1 = (Variable(torch.zeros(data.size(0), 512, 2, 2).cuda()),\n                       Variable(torch.zeros(data.size(0), 512, 2, 2).cuda()))\n        decoder_h_2 = (Variable(torch.zeros(data.size(0), 512, 4, 4).cuda()),\n                       
Variable(torch.zeros(data.size(0), 512, 4, 4).cuda()))\n        decoder_h_3 = (Variable(torch.zeros(data.size(0), 256, 8, 8).cuda()),\n                       Variable(torch.zeros(data.size(0), 256, 8, 8).cuda()))\n        decoder_h_4 = (Variable(torch.zeros(data.size(0), 128, 16, 16).cuda()),\n                       Variable(torch.zeros(data.size(0), 128, 16, 16).cuda()))\n\n        patches = Variable(data.cuda())\n\n        solver.zero_grad()\n\n        losses = []\n\n        res = patches - 0.5\n\n        bp_t0 = time.time()\n\n        for _ in range(args.iterations):\n            encoded, encoder_h_1, encoder_h_2, encoder_h_3 = encoder(\n                res, encoder_h_1, encoder_h_2, encoder_h_3)\n\n            codes = binarizer(encoded)\n\n            output, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4 = decoder(\n                codes, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4)\n\n            res = res - output\n            losses.append(res.abs().mean())\n\n        bp_t1 = time.time()\n\n        loss = sum(losses) / args.iterations\n        loss.backward()\n\n        solver.step()\n\n        batch_t1 = time.time()\n\n        print(\n            '[TRAIN] Epoch[{}]({}/{}); Loss: {:.6f}; Backpropagation: {:.4f} sec; Batch: {:.4f} sec'.\n            format(epoch, batch + 1,\n                   len(train_loader), loss.data[0], bp_t1 - bp_t0, batch_t1 -\n                   batch_t0))\n        print(('{:.4f} ' * args.iterations +\n               '\\n').format(* [l.data[0] for l in losses]))\n\n        index = (epoch - 1) * len(train_loader) + batch\n\n        ## save checkpoint every 500 training steps\n        if index % 500 == 0:\n            save(0, False)\n\n    save(epoch)\n"
  }
]