Repository: 1zb/pytorch-image-comp-rnn Branch: master Commit: d7150f285cfd Files: 20 Total size: 46.6 KB Directory structure: gitextract_yfbt7nhy/ ├── .gitignore ├── README.md ├── dataset.py ├── decoder.py ├── encoder.py ├── functions/ │ ├── __init__.py │ └── sign.py ├── metric.py ├── modules/ │ ├── __init__.py │ ├── conv_rnn.py │ └── sign.py ├── network.py ├── test/ │ ├── calc_ssim.sh │ ├── draw_rd.py │ ├── enc_dec.sh │ ├── get_kodak.sh │ ├── jpeg.sh │ ├── jpeg_ssim.csv │ └── lstm_ssim.csv └── train.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ *.pth *.pyc *.bak *.png *.jpg *.npz !rd.png !/kodim10.png !/bpp-0.125-0.133-ssim-0.865-0.827.png !/bpp-0.250-0.249-ssim-0.937-0.918.png !/bpp-0.375-0.381-ssim-0.963-0.951.png ================================================ FILE: README.md ================================================ # Full Resolution Image Compression with Recurrent Neural Networks https://arxiv.org/abs/1608.05148v2 ## Requirements - PyTorch 0.2.0 ## Train ` python train.py -f /path/to/your/images/folder/like/mscoco ` ## Encode and Decode ### Encode ` python encoder.py --model checkpoint/encoder_epoch_00000005.pth --input /path/to/your/example.png --cuda --output ex --iterations 16 ` This will output binary codes saved in `.npz` format. ### Decode ` python decoder.py --model checkpoint/encoder_epoch_00000005.pth --input /path/to/your/example.npz --cuda --output /path/to/output/folder ` This will output images of different quality levels. 
## Test

### Get Kodak dataset

```bash
bash test/get_kodak.sh
```

### Encode and decode with RNN model

```bash
bash test/enc_dec.sh
```

### Encode and decode with JPEG (use `convert` from ImageMagick)

```bash
bash test/jpeg.sh
```

### Calculate SSIM

```bash
bash test/calc_ssim.sh
```

### Draw rate-distortion curve

```bash
python test/draw_rd.py
```

## Result

LSTM (Additive Reconstruction), before entropy coding

### Rate-distortion

![Rate-distortion](rd.png)

### `kodim10.png`

Original Image

![Original Image](kodim10.png)

Below Left: LSTM, SSIM=0.865, bpp=0.125
Below Right: JPEG, SSIM=0.827, bpp=0.133

![bpp-0.125-0.133-ssim-0.865-0.827](bpp-0.125-0.133-ssim-0.865-0.827.png)

Below Left: LSTM, SSIM=0.937, bpp=0.250
Below Right: JPEG, SSIM=0.918, bpp=0.249

![bpp-0.250-0.249-ssim-0.937-0.918](bpp-0.250-0.249-ssim-0.937-0.918.png)

Below Left: LSTM, SSIM=0.963, bpp=0.375
Below Right: JPEG, SSIM=0.951, bpp=0.381

![bpp-0.375-0.381-ssim-0.963-0.951](bpp-0.375-0.381-ssim-0.963-0.951.png)

## What's inside

- `train.py`: Main program for training.
- `encoder.py` and `decoder.py`: Encoder and decoder.
- `dataset.py`: Utils for reading images.
- `metric.py`: Functions for calculating MS-SSIM and PSNR.
- `network.py`: Modules of encoder and decoder.
- `modules/conv_rnn.py`: ConvLSTM module.
- `functions/sign.py`: Forward and backward for binary quantization.
## Official Repo https://github.com/tensorflow/models/tree/master/compression ================================================ FILE: dataset.py ================================================ # modified from https://github.com/desimone/vision/blob/fb74c76d09bcc2594159613d5bdadd7d4697bb11/torchvision/datasets/folder.py import os import os.path import torch from torchvision import transforms import torch.utils.data as data from PIL import Image IMG_EXTENSIONS = [ '.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', ] def is_image_file(filename): return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) def default_loader(path): return Image.open(path).convert('RGB') class ImageFolder(data.Dataset): """ ImageFolder can be used to load images where there are no labels.""" def __init__(self, root, transform=None, loader=default_loader): images = [] for filename in os.listdir(root): if is_image_file(filename): images.append('{}'.format(filename)) self.root = root self.imgs = images self.transform = transform self.loader = loader def __getitem__(self, index): filename = self.imgs[index] try: img = self.loader(os.path.join(self.root, filename)) except: return torch.zeros((3, 32, 32)) if self.transform is not None: img = self.transform(img) return img def __len__(self): return len(self.imgs) ================================================ FILE: decoder.py ================================================ import os import argparse import numpy as np from scipy.misc import imread, imresize, imsave import torch from torch.autograd import Variable parser = argparse.ArgumentParser() parser.add_argument('--model', required=True, type=str, help='path to model') parser.add_argument('--input', required=True, type=str, help='input codes') parser.add_argument('--output', default='.', type=str, help='output folder') parser.add_argument('--cuda', action='store_true', help='enables cuda') parser.add_argument( '--iterations', type=int, 
default=16, help='unroll iterations') args = parser.parse_args() content = np.load(args.input) codes = np.unpackbits(content['codes']) codes = np.reshape(codes, content['shape']).astype(np.float32) * 2 - 1 codes = torch.from_numpy(codes) iters, batch_size, channels, height, width = codes.size() height = height * 16 width = width * 16 codes = Variable(codes, volatile=True) import network decoder = network.DecoderCell() decoder.eval() decoder.load_state_dict(torch.load(args.model)) decoder_h_1 = (Variable( torch.zeros(batch_size, 512, height // 16, width // 16), volatile=True), Variable( torch.zeros(batch_size, 512, height // 16, width // 16), volatile=True)) decoder_h_2 = (Variable( torch.zeros(batch_size, 512, height // 8, width // 8), volatile=True), Variable( torch.zeros(batch_size, 512, height // 8, width // 8), volatile=True)) decoder_h_3 = (Variable( torch.zeros(batch_size, 256, height // 4, width // 4), volatile=True), Variable( torch.zeros(batch_size, 256, height // 4, width // 4), volatile=True)) decoder_h_4 = (Variable( torch.zeros(batch_size, 128, height // 2, width // 2), volatile=True), Variable( torch.zeros(batch_size, 128, height // 2, width // 2), volatile=True)) if args.cuda: decoder = decoder.cuda() codes = codes.cuda() decoder_h_1 = (decoder_h_1[0].cuda(), decoder_h_1[1].cuda()) decoder_h_2 = (decoder_h_2[0].cuda(), decoder_h_2[1].cuda()) decoder_h_3 = (decoder_h_3[0].cuda(), decoder_h_3[1].cuda()) decoder_h_4 = (decoder_h_4[0].cuda(), decoder_h_4[1].cuda()) image = torch.zeros(1, 3, height, width) + 0.5 for iters in range(min(args.iterations, codes.size(0))): output, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4 = decoder( codes[iters], decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4) image = image + output.data.cpu() imsave( os.path.join(args.output, '{:02d}.png'.format(iters)), np.squeeze(image.numpy().clip(0, 1) * 255.0).astype(np.uint8) .transpose(1, 2, 0)) ================================================ FILE: encoder.py 
================================================ import argparse import numpy as np from scipy.misc import imread, imresize, imsave import torch from torch.autograd import Variable parser = argparse.ArgumentParser() parser.add_argument( '--model', '-m', required=True, type=str, help='path to model') parser.add_argument( '--input', '-i', required=True, type=str, help='input image') parser.add_argument( '--output', '-o', required=True, type=str, help='output codes') parser.add_argument('--cuda', '-g', action='store_true', help='enables cuda') parser.add_argument( '--iterations', type=int, default=16, help='unroll iterations') args = parser.parse_args() image = imread(args.input, mode='RGB') image = torch.from_numpy( np.expand_dims( np.transpose(image.astype(np.float32) / 255.0, (2, 0, 1)), 0)) batch_size, input_channels, height, width = image.size() assert height % 32 == 0 and width % 32 == 0 image = Variable(image, volatile=True) import network encoder = network.EncoderCell() binarizer = network.Binarizer() decoder = network.DecoderCell() encoder.eval() binarizer.eval() decoder.eval() encoder.load_state_dict(torch.load(args.model)) binarizer.load_state_dict( torch.load(args.model.replace('encoder', 'binarizer'))) decoder.load_state_dict(torch.load(args.model.replace('encoder', 'decoder'))) encoder_h_1 = (Variable( torch.zeros(batch_size, 256, height // 4, width // 4), volatile=True), Variable( torch.zeros(batch_size, 256, height // 4, width // 4), volatile=True)) encoder_h_2 = (Variable( torch.zeros(batch_size, 512, height // 8, width // 8), volatile=True), Variable( torch.zeros(batch_size, 512, height // 8, width // 8), volatile=True)) encoder_h_3 = (Variable( torch.zeros(batch_size, 512, height // 16, width // 16), volatile=True), Variable( torch.zeros(batch_size, 512, height // 16, width // 16), volatile=True)) decoder_h_1 = (Variable( torch.zeros(batch_size, 512, height // 16, width // 16), volatile=True), Variable( torch.zeros(batch_size, 512, height // 16, 
width // 16), volatile=True)) decoder_h_2 = (Variable( torch.zeros(batch_size, 512, height // 8, width // 8), volatile=True), Variable( torch.zeros(batch_size, 512, height // 8, width // 8), volatile=True)) decoder_h_3 = (Variable( torch.zeros(batch_size, 256, height // 4, width // 4), volatile=True), Variable( torch.zeros(batch_size, 256, height // 4, width // 4), volatile=True)) decoder_h_4 = (Variable( torch.zeros(batch_size, 128, height // 2, width // 2), volatile=True), Variable( torch.zeros(batch_size, 128, height // 2, width // 2), volatile=True)) if args.cuda: encoder = encoder.cuda() binarizer = binarizer.cuda() decoder = decoder.cuda() image = image.cuda() encoder_h_1 = (encoder_h_1[0].cuda(), encoder_h_1[1].cuda()) encoder_h_2 = (encoder_h_2[0].cuda(), encoder_h_2[1].cuda()) encoder_h_3 = (encoder_h_3[0].cuda(), encoder_h_3[1].cuda()) decoder_h_1 = (decoder_h_1[0].cuda(), decoder_h_1[1].cuda()) decoder_h_2 = (decoder_h_2[0].cuda(), decoder_h_2[1].cuda()) decoder_h_3 = (decoder_h_3[0].cuda(), decoder_h_3[1].cuda()) decoder_h_4 = (decoder_h_4[0].cuda(), decoder_h_4[1].cuda()) codes = [] res = image - 0.5 for iters in range(args.iterations): encoded, encoder_h_1, encoder_h_2, encoder_h_3 = encoder( res, encoder_h_1, encoder_h_2, encoder_h_3) code = binarizer(encoded) output, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4 = decoder( code, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4) res = res - output codes.append(code.data.cpu().numpy()) print('Iter: {:02d}; Loss: {:.06f}'.format(iters, res.data.abs().mean())) codes = (np.stack(codes).astype(np.int8) + 1) // 2 export = np.packbits(codes.reshape(-1)) np.savez_compressed(args.output, shape=codes.shape, codes=export) ================================================ FILE: functions/__init__.py ================================================ from .sign import Sign ================================================ FILE: functions/sign.py ================================================ import torch 
from torch.autograd import Function class Sign(Function): """ Variable Rate Image Compression with Recurrent Neural Networks https://arxiv.org/abs/1511.06085 """ def __init__(self): super(Sign, self).__init__() @staticmethod def forward(ctx, input, is_training=True): # Apply quantization noise while only training if is_training: prob = input.new(input.size()).uniform_() x = input.clone() x[(1 - input) / 2 <= prob] = 1 x[(1 - input) / 2 > prob] = -1 return x else: return input.sign() @staticmethod def backward(ctx, grad_output): return grad_output, None ================================================ FILE: metric.py ================================================ ## some function borrowed from ## https://github.com/tensorflow/models/blob/master/compression/image_encoder/msssim.py """Python implementation of MS-SSIM. Usage: python msssim.py --original_image=original.png --compared_image=distorted.png """ import argparse import numpy as np from scipy import signal from scipy.ndimage.filters import convolve from PIL import Image parser = argparse.ArgumentParser() parser.add_argument('--metric', '-m', type=str, default='all', help='metric') parser.add_argument( '--original-image', '-o', type=str, required=True, help='original image') parser.add_argument( '--compared-image', '-c', type=str, required=True, help='compared image') args = parser.parse_args() def _FSpecialGauss(size, sigma): """Function to mimic the 'fspecial' gaussian MATLAB function.""" radius = size // 2 offset = 0.0 start, stop = -radius, radius + 1 if size % 2 == 0: offset = 0.5 stop -= 1 x, y = np.mgrid[offset + start:stop, offset + start:stop] assert len(x) == size g = np.exp(-((x**2 + y**2) / (2.0 * sigma**2))) return g / g.sum() def _SSIMForMultiScale(img1, img2, max_val=255, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03): """Return the Structural Similarity Map between `img1` and `img2`. 
This function attempts to match the functionality of ssim_index_new.m by Zhou Wang: http://www.cns.nyu.edu/~lcv/ssim/msssim.zip Arguments: img1: Numpy array holding the first RGB image batch. img2: Numpy array holding the second RGB image batch. max_val: the dynamic range of the images (i.e., the difference between the maximum the and minimum allowed values). filter_size: Size of blur kernel to use (will be reduced for small images). filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced for small images). k1: Constant used to maintain stability in the SSIM calculation (0.01 in the original paper). k2: Constant used to maintain stability in the SSIM calculation (0.03 in the original paper). Returns: Pair containing the mean SSIM and contrast sensitivity between `img1` and `img2`. Raises: RuntimeError: If input images don't have the same shape or don't have four dimensions: [batch_size, height, width, depth]. """ if img1.shape != img2.shape: raise RuntimeError( 'Input images must have the same shape (%s vs. %s).', img1.shape, img2.shape) if img1.ndim != 4: raise RuntimeError('Input images must have four dimensions, not %d', img1.ndim) img1 = img1.astype(np.float64) img2 = img2.astype(np.float64) _, height, width, _ = img1.shape # Filter size can't be larger than height or width of images. size = min(filter_size, height, width) # Scale down sigma if a smaller filter size is used. sigma = size * filter_sigma / filter_size if filter_size else 0 if filter_size: window = np.reshape(_FSpecialGauss(size, sigma), (1, size, size, 1)) mu1 = signal.fftconvolve(img1, window, mode='valid') mu2 = signal.fftconvolve(img2, window, mode='valid') sigma11 = signal.fftconvolve(img1 * img1, window, mode='valid') sigma22 = signal.fftconvolve(img2 * img2, window, mode='valid') sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid') else: # Empty blur kernel so no need to convolve. 
mu1, mu2 = img1, img2 sigma11 = img1 * img1 sigma22 = img2 * img2 sigma12 = img1 * img2 mu11 = mu1 * mu1 mu22 = mu2 * mu2 mu12 = mu1 * mu2 sigma11 -= mu11 sigma22 -= mu22 sigma12 -= mu12 # Calculate intermediate values used by both ssim and cs_map. c1 = (k1 * max_val)**2 c2 = (k2 * max_val)**2 v1 = 2.0 * sigma12 + c2 v2 = sigma11 + sigma22 + c2 ssim = np.mean((((2.0 * mu12 + c1) * v1) / ((mu11 + mu22 + c1) * v2))) cs = np.mean(v1 / v2) return ssim, cs def MultiScaleSSIM(img1, img2, max_val=255, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03, weights=None): """Return the MS-SSIM score between `img1` and `img2`. This function implements Multi-Scale Structural Similarity (MS-SSIM) Image Quality Assessment according to Zhou Wang's paper, "Multi-scale structural similarity for image quality assessment" (2003). Link: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf Author's MATLAB implementation: http://www.cns.nyu.edu/~lcv/ssim/msssim.zip Arguments: img1: Numpy array holding the first RGB image batch. img2: Numpy array holding the second RGB image batch. max_val: the dynamic range of the images (i.e., the difference between the maximum the and minimum allowed values). filter_size: Size of blur kernel to use (will be reduced for small images). filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced for small images). k1: Constant used to maintain stability in the SSIM calculation (0.01 in the original paper). k2: Constant used to maintain stability in the SSIM calculation (0.03 in the original paper). weights: List of weights for each level; if none, use five levels and the weights from the original paper. Returns: MS-SSIM score between `img1` and `img2`. Raises: RuntimeError: If input images don't have the same shape or don't have four dimensions: [batch_size, height, width, depth]. """ if img1.shape != img2.shape: raise RuntimeError( 'Input images must have the same shape (%s vs. 
%s).', img1.shape, img2.shape) if img1.ndim != 4: raise RuntimeError('Input images must have four dimensions, not %d', img1.ndim) # Note: default weights don't sum to 1.0 but do match the paper / matlab code. weights = np.array(weights if weights else [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]) levels = weights.size downsample_filter = np.ones((1, 2, 2, 1)) / 4.0 im1, im2 = [x.astype(np.float64) for x in [img1, img2]] mssim = np.array([]) mcs = np.array([]) for _ in range(levels): ssim, cs = _SSIMForMultiScale( im1, im2, max_val=max_val, filter_size=filter_size, filter_sigma=filter_sigma, k1=k1, k2=k2) mssim = np.append(mssim, ssim) mcs = np.append(mcs, cs) filtered = [ convolve(im, downsample_filter, mode='reflect') for im in [im1, im2] ] im1, im2 = [x[:, ::2, ::2, :] for x in filtered] return (np.prod(mcs[0:levels - 1]**weights[0:levels - 1]) * (mssim[levels - 1]**weights[levels - 1])) def msssim(original, compared): if isinstance(original, str): original = np.array(Image.open(original).convert('RGB'), dtype=np.float32) if isinstance(compared, str): compared = np.array(Image.open(compared).convert('RGB'), dtype=np.float32) original = original[None, ...] if original.ndim == 3 else original compared = compared[None, ...] if compared.ndim == 3 else compared return MultiScaleSSIM(original, compared, max_val=255) def psnr(original, compared): if isinstance(original, str): original = np.array(Image.open(original).convert('RGB'), dtype=np.float32) if isinstance(compared, str): compared = np.array(Image.open(compared).convert('RGB'), dtype=np.float32) mse = np.mean(np.square(original - compared)) psnr = np.clip( np.multiply(np.log10(255. * 255. 
/ mse[mse > 0.]), 10.), 0., 99.99)[0] return psnr def main(): if args.metric != 'psnr': print(msssim(args.original_image, args.compared_image), end='') if args.metric != 'ssim': print(psnr(args.original_image, args.compared_image), end='') if __name__ == '__main__': main() ================================================ FILE: modules/__init__.py ================================================ from .conv_rnn import ConvLSTMCell #, ConvLSTM from .sign import Sign ================================================ FILE: modules/conv_rnn.py ================================================ import torch.nn as nn import torch.nn.functional as F import torch from torch.autograd import Variable from torch.nn.modules.utils import _pair class ConvRNNCellBase(nn.Module): def __repr__(self): s = ( '{name}({input_channels}, {hidden_channels}, kernel_size={kernel_size}' ', stride={stride}') if self.padding != (0, ) * len(self.padding): s += ', padding={padding}' if self.dilation != (1, ) * len(self.dilation): s += ', dilation={dilation}' s += ', hidden_kernel_size={hidden_kernel_size}' s += ')' return s.format(name=self.__class__.__name__, **self.__dict__) class ConvLSTMCell(ConvRNNCellBase): def __init__(self, input_channels, hidden_channels, kernel_size=3, stride=1, padding=0, dilation=1, hidden_kernel_size=1, bias=True): super(ConvLSTMCell, self).__init__() self.input_channels = input_channels self.hidden_channels = hidden_channels self.kernel_size = _pair(kernel_size) self.stride = _pair(stride) self.padding = _pair(padding) self.dilation = _pair(dilation) self.hidden_kernel_size = _pair(hidden_kernel_size) hidden_padding = _pair(hidden_kernel_size // 2) gate_channels = 4 * self.hidden_channels self.conv_ih = nn.Conv2d( in_channels=self.input_channels, out_channels=gate_channels, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, dilation=self.dilation, bias=bias) self.conv_hh = nn.Conv2d( in_channels=self.hidden_channels, out_channels=gate_channels, 
kernel_size=hidden_kernel_size, stride=1, padding=hidden_padding, dilation=1, bias=bias) self.reset_parameters() def reset_parameters(self): self.conv_ih.reset_parameters() self.conv_hh.reset_parameters() def forward(self, input, hidden): hx, cx = hidden gates = self.conv_ih(input) + self.conv_hh(hx) ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) ingate = F.sigmoid(ingate) forgetgate = F.sigmoid(forgetgate) cellgate = F.tanh(cellgate) outgate = F.sigmoid(outgate) cy = (forgetgate * cx) + (ingate * cellgate) hy = outgate * F.tanh(cy) return hy, cy ================================================ FILE: modules/sign.py ================================================ import torch import torch.nn as nn from functions import Sign as SignFunction class Sign(nn.Module): def __init__(self): super().__init__() def forward(self, x): return SignFunction.apply(x, self.training) ================================================ FILE: network.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from modules import ConvLSTMCell, Sign class EncoderCell(nn.Module): def __init__(self): super(EncoderCell, self).__init__() self.conv = nn.Conv2d( 3, 64, kernel_size=3, stride=2, padding=1, bias=False) self.rnn1 = ConvLSTMCell( 64, 256, kernel_size=3, stride=2, padding=1, hidden_kernel_size=1, bias=False) self.rnn2 = ConvLSTMCell( 256, 512, kernel_size=3, stride=2, padding=1, hidden_kernel_size=1, bias=False) self.rnn3 = ConvLSTMCell( 512, 512, kernel_size=3, stride=2, padding=1, hidden_kernel_size=1, bias=False) def forward(self, input, hidden1, hidden2, hidden3): x = self.conv(input) hidden1 = self.rnn1(x, hidden1) x = hidden1[0] hidden2 = self.rnn2(x, hidden2) x = hidden2[0] hidden3 = self.rnn3(x, hidden3) x = hidden3[0] return x, hidden1, hidden2, hidden3 class Binarizer(nn.Module): def __init__(self): super(Binarizer, self).__init__() self.conv = nn.Conv2d(512, 32, kernel_size=1, bias=False) self.sign = 
Sign() def forward(self, input): feat = self.conv(input) x = F.tanh(feat) return self.sign(x) class DecoderCell(nn.Module): def __init__(self): super(DecoderCell, self).__init__() self.conv1 = nn.Conv2d( 32, 512, kernel_size=1, stride=1, padding=0, bias=False) self.rnn1 = ConvLSTMCell( 512, 512, kernel_size=3, stride=1, padding=1, hidden_kernel_size=1, bias=False) self.rnn2 = ConvLSTMCell( 128, 512, kernel_size=3, stride=1, padding=1, hidden_kernel_size=1, bias=False) self.rnn3 = ConvLSTMCell( 128, 256, kernel_size=3, stride=1, padding=1, hidden_kernel_size=3, bias=False) self.rnn4 = ConvLSTMCell( 64, 128, kernel_size=3, stride=1, padding=1, hidden_kernel_size=3, bias=False) self.conv2 = nn.Conv2d( 32, 3, kernel_size=1, stride=1, padding=0, bias=False) def forward(self, input, hidden1, hidden2, hidden3, hidden4): x = self.conv1(input) hidden1 = self.rnn1(x, hidden1) x = hidden1[0] x = F.pixel_shuffle(x, 2) hidden2 = self.rnn2(x, hidden2) x = hidden2[0] x = F.pixel_shuffle(x, 2) hidden3 = self.rnn3(x, hidden3) x = hidden3[0] x = F.pixel_shuffle(x, 2) hidden4 = self.rnn4(x, hidden4) x = hidden4[0] x = F.pixel_shuffle(x, 2) x = F.tanh(self.conv2(x)) / 2 return x, hidden1, hidden2, hidden3, hidden4 ================================================ FILE: test/calc_ssim.sh ================================================ #!/bin/bash LSTM=test/lstm_ssim.csv JPEG=test/jpeg_ssim.csv echo -n "" > $LSTM for i in {01..24..1}; do echo Processing test/decoded/kodim$i for j in {00..15..1}; do echo -n `python metric.py -m ssim -o test/images/kodim$i.png -c test/decoded/kodim$i/$j.png`', ' >> $LSTM done echo "" >> $LSTM done echo -n "" > $JPEG for i in {01..24..1}; do echo Processing test/jpeg/kodim$i for j in {01..20..1}; do echo -n `python metric.py -m ssim -o test/images/kodim$i.png -c test/jpeg/kodim$i/$j.jpg`', ' >> $JPEG done echo "" >> $JPEG done ================================================ FILE: test/draw_rd.py ================================================ import os 
import numpy as np
from scipy.misc import imread
import matplotlib.pyplot as plt

# Toggle: True -> average over the 24 Kodak images and draw connected
# rate-distortion curves; False -> scatter every (image, setting) point.
line = True

# lstm_ssim.csv: 24 rows (Kodak images) x 16 columns (iterations); the
# trailing comma in each row produces an all-NaN last column, dropped here.
lstm_ssim = np.genfromtxt('test/lstm_ssim.csv', delimiter=',')
lstm_ssim = lstm_ssim[:, :-1]
if line:
    lstm_ssim = np.mean(lstm_ssim, axis=0)
    # Iteration k costs k/192*24 bpp -- NOTE(review): presumably 32 code
    # channels per 16x16 block (32/256 = 24/192 bpp per iteration); confirm
    # against network.Binarizer.
    lstm_bpp = np.arange(1, 17) / 192 * 24
    plt.plot(lstm_bpp, lstm_ssim, label='LSTM', marker='o')
else:
    lstm_bpp = np.stack([np.arange(1, 17) for _ in range(24)]) / 192 * 24
    plt.scatter(
        lstm_bpp.reshape(-1), lstm_ssim.reshape(-1), label='LSTM', marker='o')

# jpeg_ssim.csv: 24 rows x 20 quality levels; same trailing-column drop.
jpeg_ssim = np.genfromtxt('test/jpeg_ssim.csv', delimiter=',')
jpeg_ssim = jpeg_ssim[:, :-1]
if line:
    jpeg_ssim = np.mean(jpeg_ssim, axis=0)

# Actual JPEG bitrate: file size in bits divided by pixel count
# (`.size // 3` strips the channel dimension of the decoded RGB array).
jpeg_bpp = np.array([
    os.path.getsize('test/jpeg/kodim{:02d}/{:02d}.jpg'.format(i, q)) * 8 /
    (imread('test/jpeg/kodim{:02d}/{:02d}.jpg'.format(i, q)).size // 3)
    for i in range(1, 25) for q in range(1, 21)
]).reshape(24, 20)
if line:
    jpeg_bpp = np.mean(jpeg_bpp, axis=0)
    plt.plot(jpeg_bpp, jpeg_ssim, label='JPEG', marker='x')
else:
    plt.scatter(
        jpeg_bpp.reshape(-1), jpeg_ssim.reshape(-1), label='JPEG', marker='x')

plt.xlim(0., 2.)
plt.ylim(0.7, 1.0) plt.xlabel('bit per pixel') plt.ylabel('MS-SSIM') plt.legend() plt.show() ================================================ FILE: test/enc_dec.sh ================================================ #!/bin/bash for i in {01..24..1}; do echo Encoding test/images/kodim$i.png mkdir -p test/codes python encoder.py --model checkpoint/encoder_epoch_00000066.pth --input test/images/kodim$i.png --cuda --output test/codes/kodim$i --iterations 16 echo Decoding test/codes/kodim$i.npz mkdir -p test/decoded/kodim$i python decoder.py --model checkpoint/decoder_epoch_00000066.pth --input test/codes/kodim$i.npz --cuda --output test/decoded/kodim$i done ================================================ FILE: test/get_kodak.sh ================================================ #!/bin/bash mkdir -p test/images for i in {01..24..1}; do echo ${i} wget http://r0k.us/graphics/kodak/kodak/kodim${i}.png -O test/images/kodim${i}.png done ================================================ FILE: test/jpeg.sh ================================================ #!/bin/bash for i in {01..24..1}; do echo JPEG Encoding test/images/kodim$i.png mkdir -p test/jpeg/kodim$i for j in {1..20..1}; do convert test/images/kodim$i.png -quality $(($j*5)) -sampling-factor 4:2:0 test/jpeg/kodim$i/`printf "%02d" $j`.jpg done done ================================================ FILE: test/jpeg_ssim.csv ================================================ 0.818072219541, 0.915486738863, 0.941250388079, 0.954959594001, 0.964432174175, 0.969521944618, 0.974450024801, 0.977023476788, 0.979556769243, 0.981440100069, 0.98303158495, 0.984619828345, 0.986521991089, 0.988233765029, 0.990034353477, 0.991908912291, 0.993859786213, 0.995872272301, 0.997704734009, 0.999369880278, 0.666065535368, 0.833785271322, 0.862238250732, 0.894437174757, 0.913606004326, 0.927403225046, 0.936700087174, 0.942087207378, 0.947302906233, 0.952881599351, 0.957102998846, 0.961079845306, 0.964679737793, 0.969085669754, 0.972862674891, 
0.977389816095, 0.981929271455, 0.986764889825, 0.992090856598, 0.997867382574, 0.810456808165, 0.885532082425, 0.923215832246, 0.942269281224, 0.951013313071, 0.960641959857, 0.965020974147, 0.968983036738, 0.972687587703, 0.97536435516, 0.977124722602, 0.979205659824, 0.981645499592, 0.983797485181, 0.985868845072, 0.988406175337, 0.990561443404, 0.992977803974, 0.995363893283, 0.998186608337, 0.736205708103, 0.864789258571, 0.903917595396, 0.928365043206, 0.940992130793, 0.950190769917, 0.957041193005, 0.961645055775, 0.965608843423, 0.968846156644, 0.971464980647, 0.974338971545, 0.976288817151, 0.979326195514, 0.981731800268, 0.985073759122, 0.988183892013, 0.991310684778, 0.99484060445, 0.998539576456, 0.839808316647, 0.914569680204, 0.944585553268, 0.958189541858, 0.965933250076, 0.971345068489, 0.975120264002, 0.978007417795, 0.980112049762, 0.982031704186, 0.983617064118, 0.985301594451, 0.986864130976, 0.988550054893, 0.990140431145, 0.991912697905, 0.993659191703, 0.995525236463, 0.997451078786, 0.999445634028, 0.774314920949, 0.875787337352, 0.923727573855, 0.943054472644, 0.95321690632, 0.961164793684, 0.967588086634, 0.971259154783, 0.974502697083, 0.976536008639, 0.979070907986, 0.980730423348, 0.982772471532, 0.985208074663, 0.987072532816, 0.989603098041, 0.991848356056, 0.993989496887, 0.996261532852, 0.998833407115, 0.853028224878, 0.924594399288, 0.944577324582, 0.962892871899, 0.968175881476, 0.973912836384, 0.977611559175, 0.980013739677, 0.981372066609, 0.983489296322, 0.984974095691, 0.986316785418, 0.987796973658, 0.989294643654, 0.990415707239, 0.991958140094, 0.993469699667, 0.99504535025, 0.996735156379, 0.998906782293, 0.880688294539, 0.933204609681, 0.953650276172, 0.964610058969, 0.971655363423, 0.976231891224, 0.979402609806, 0.981754260409, 0.983510269845, 0.98503201627, 0.986234509006, 0.987594582558, 0.988831108808, 0.990315455876, 0.991614978601, 0.993112910369, 0.994592787429, 0.996128854807, 0.997766030171, 0.999473104286, 
0.853720880853, 0.912554857475, 0.930711615275, 0.950455648101, 0.960958503215, 0.96667403636, 0.971362349207, 0.974072622141, 0.975613226504, 0.978781040524, 0.979969048509, 0.981158076869, 0.982883137674, 0.984722620635, 0.98658900113, 0.988314796092, 0.99036007625, 0.992362984817, 0.994781216862, 0.998299761454, 0.81076220338, 0.893706653475, 0.927362493596, 0.943724930231, 0.955211693836, 0.962355553594, 0.96829254466, 0.97169265604, 0.9743439019, 0.976751528211, 0.978728976699, 0.980505251819, 0.982242719802, 0.98445680889, 0.986286665859, 0.988491766444, 0.990776417926, 0.992905056354, 0.99523724111, 0.998527390073, 0.791555686892, 0.881440991243, 0.920860393213, 0.942552828487, 0.953270936085, 0.959839333526, 0.966037446908, 0.969412253347, 0.972687657063, 0.975263831002, 0.977073908411, 0.979106240394, 0.981338251384, 0.984181227447, 0.985917501452, 0.988715661162, 0.991071452023, 0.993500123657, 0.996109324665, 0.999023620304, 0.801456772286, 0.878467768641, 0.908098341977, 0.933136250773, 0.946448137076, 0.955111533806, 0.963027394102, 0.967258044795, 0.970678461713, 0.974186408589, 0.975688326117, 0.977937704633, 0.980300465358, 0.982588089404, 0.985135222036, 0.987727956878, 0.990090763787, 0.992444149444, 0.995132471845, 0.998223031362, 0.797066626991, 0.891473428482, 0.927754609701, 0.944168588194, 0.953960755811, 0.960570142061, 0.96609929426, 0.969401647648, 0.97267245227, 0.975204503945, 0.977409332209, 0.979649966483, 0.981643846133, 0.984198624552, 0.986564293774, 0.989219718804, 0.991824162369, 0.994323089168, 0.996765571563, 0.999382383983, 0.785606294091, 0.88697381026, 0.922042917455, 0.942085547292, 0.952637067226, 0.95956056371, 0.965247607479, 0.968692477325, 0.971654628021, 0.974079838834, 0.976237478008, 0.978466192909, 0.980687087441, 0.983108513663, 0.985172089232, 0.987865648168, 0.990607337903, 0.993466132139, 0.996502037444, 0.999093888752, 0.791792877213, 0.87257820722, 0.910251232828, 0.931499697455, 0.943952143348, 
0.951516025861, 0.957952121781, 0.962339434864, 0.966062146417, 0.968616514774, 0.971381676811, 0.973583176861, 0.976055057192, 0.979084758838, 0.98170749589, 0.984744675444, 0.987623582667, 0.990827225403, 0.994500953634, 0.998296977796, 0.77155359043, 0.869021329171, 0.916734846729, 0.939085193953, 0.951543198728, 0.959581932049, 0.965130869692, 0.970274846734, 0.973566829975, 0.976120453022, 0.978316404127, 0.980331703541, 0.982562198606, 0.984891459627, 0.986832139561, 0.98925186408, 0.991631382422, 0.993799878089, 0.996076708336, 0.998724170375, 0.834596697018, 0.910219914842, 0.938953487334, 0.954867094446, 0.965021373662, 0.971626167075, 0.975397860726, 0.978190765304, 0.980255255374, 0.982316415294, 0.983438406514, 0.985042910404, 0.986538091767, 0.988296632234, 0.989480478459, 0.991169147712, 0.992838525537, 0.994484549662, 0.996349006533, 0.998796129593, 0.789379799628, 0.883671741466, 0.918305025089, 0.937438007277, 0.948741402993, 0.955676286943, 0.961312308858, 0.965026524442, 0.96851536141, 0.971303982655, 0.973488277598, 0.975794458236, 0.978215624859, 0.980917182109, 0.983425193859, 0.986089845307, 0.988720428405, 0.991544643936, 0.994909928794, 0.998959191331, 0.798779556224, 0.879522751508, 0.920868515488, 0.940119021497, 0.951843311114, 0.959017895362, 0.965208194261, 0.969082003508, 0.97221027879, 0.974691498988, 0.976855260305, 0.978953094688, 0.981167317793, 0.983437790312, 0.985786929537, 0.988162887114, 0.990644384686, 0.993026404153, 0.995507443715, 0.998621645017, 0.879263930866, 0.922020986557, 0.946493775014, 0.958297320378, 0.964859488123, 0.970312211669, 0.97424567735, 0.976421680027, 0.978413032059, 0.979756913163, 0.981009201268, 0.982435429622, 0.98394030031, 0.98552902602, 0.987216102866, 0.988920029566, 0.990524992279, 0.992390079391, 0.994639444871, 0.998357673395, 0.827950174866, 0.894831372155, 0.9313264476, 0.94757593276, 0.956105682938, 0.962399480035, 0.967538440479, 0.971214741258, 0.973124877585, 0.976048275534, 
0.977788559952, 0.979504668032, 0.981543513231, 0.983649526565, 0.985455734324, 0.987721369114, 0.990020663907, 0.992286543302, 0.994788654436, 0.998404057223, 0.754610909207, 0.860747233451, 0.899736485589, 0.925743580664, 0.939166571882, 0.948339482047, 0.954553566064, 0.958885994659, 0.962909621012, 0.966353236056, 0.969046685166, 0.971977721769, 0.974626509781, 0.977835303452, 0.980470725918, 0.983736275578, 0.986964243248, 0.990258297647, 0.993951518828, 0.998556710426, 0.784958874194, 0.874722011293, 0.915389887756, 0.934343596754, 0.948237654386, 0.95672974228, 0.963437883085, 0.96720024908, 0.970772553663, 0.973495160679, 0.975921421699, 0.978227890716, 0.980404390223, 0.983010440205, 0.985215187951, 0.987733337063, 0.98989913274, 0.992273798878, 0.994922548059, 0.99824153924, 0.822296667468, 0.89834966427, 0.933957434279, 0.951462994733, 0.95978880675, 0.965804629849, 0.970633032468, 0.973789346233, 0.976510124286, 0.978522727364, 0.980401981947, 0.982057448136, 0.984052101154, 0.986046587711, 0.98771061484, 0.989891431917, 0.992010784627, 0.994204644975, 0.996543052165, 0.99909102487, ================================================ FILE: test/lstm_ssim.csv ================================================ 0.769036023469, 0.88985832751, 0.927751848949, 0.944983151872, 0.95575157289, 0.963951705505, 0.97091182084, 0.975998458878, 0.979411580032, 0.981903538902, 0.984059023874, 0.985847690249, 0.9873427695, 0.988557280396, 0.989808215992, 0.990955339337, 0.805634389938, 0.885551865434, 0.926222844478, 0.945195184339, 0.956706397486, 0.964635434939, 0.970561396161, 0.975371321751, 0.979008896093, 0.981623105097, 0.983889966323, 0.985693922222, 0.987117323848, 0.988116453261, 0.989126652794, 0.990035372754, 0.886382589545, 0.944226634293, 0.966099553101, 0.97574891786, 0.981699305186, 0.985303117183, 0.988158479953, 0.990353107687, 0.991742329762, 0.992736636125, 0.993586103345, 0.994247922172, 0.994823658832, 0.99523818175, 0.99566695018, 0.996024457998, 
0.844943736432, 0.912763739137, 0.941348917006, 0.955485600323, 0.964352299005, 0.971086350909, 0.976031220021, 0.980033889346, 0.983037018013, 0.985100156288, 0.986891083059, 0.98828168163, 0.989371458473, 0.990204501149, 0.991008062343, 0.991775083665, 0.807296336647, 0.901032376198, 0.936630910374, 0.953867965935, 0.965072366127, 0.973137076713, 0.979042150279, 0.983166487096, 0.985988824175, 0.987817180773, 0.989464580939, 0.990820235888, 0.991892422931, 0.992712253547, 0.993501561405, 0.994148098694, 0.785345752053, 0.884711958245, 0.92776368098, 0.948715836445, 0.96019570665, 0.96605401155, 0.972573249337, 0.97720752253, 0.980241043367, 0.982484654905, 0.984849828368, 0.986701720731, 0.988188381493, 0.98922008024, 0.990398367023, 0.991365949801, 0.879784256538, 0.951585216892, 0.972381608615, 0.981022083203, 0.985810271759, 0.988693576841, 0.990952401157, 0.992588896661, 0.993784304761, 0.994579427454, 0.995314920728, 0.99586908212, 0.996291444393, 0.996577798766, 0.996861139444, 0.997104778181, 0.815477354243, 0.906116356367, 0.944359558351, 0.96028214433, 0.969452971301, 0.975073789021, 0.979816675642, 0.983303430092, 0.985631084073, 0.987351580806, 0.988992309204, 0.99024746499, 0.991323941253, 0.992077409918, 0.992875703722, 0.993551435751, 0.898231591406, 0.952767120478, 0.971324990566, 0.979139320561, 0.983235587546, 0.985925513104, 0.98813798392, 0.989849456611, 0.991087155229, 0.991997953486, 0.992918185114, 0.993614367759, 0.994164453638, 0.994560726328, 0.994937149117, 0.995284527261, 0.8749003714, 0.942176320572, 0.966895893079, 0.977493170796, 0.982875921764, 0.986274282021, 0.98850906544, 0.990252593185, 0.991524822651, 0.992438630258, 0.993323088693, 0.994034014192, 0.994607996745, 0.994996589457, 0.995366557151, 0.99569469237, 0.831185109757, 0.908127863885, 0.94129870614, 0.957395739465, 0.966623340753, 0.971948863236, 0.977120326656, 0.980795107721, 0.983424173894, 0.9853323917, 0.987235341184, 0.988695696085, 0.989908689415, 0.990740343679, 
0.991648831536, 0.992402349259, 0.867455122658, 0.929286374978, 0.956373268572, 0.968857918522, 0.975855487208, 0.979544787408, 0.983269284681, 0.985944004, 0.987814957273, 0.989295675642, 0.990662162931, 0.991705980646, 0.992601345026, 0.99321774356, 0.993851078984, 0.994396313967, 0.741714560572, 0.852678337269, 0.89776940045, 0.921657297609, 0.937268311557, 0.948777493887, 0.958587922105, 0.966219484116, 0.970930915102, 0.973981680428, 0.976954171415, 0.979158752294, 0.981113945529, 0.982622981303, 0.984346679365, 0.985939206912, 0.806023887853, 0.898208741526, 0.934023153024, 0.951794401119, 0.962675974888, 0.969812260388, 0.975737969034, 0.980133270486, 0.983291193208, 0.985354416391, 0.98719947195, 0.98876036237, 0.990016818176, 0.990903992042, 0.991817776314, 0.992603235282, 0.885580139856, 0.938517703814, 0.959134612721, 0.968447444909, 0.974006513851, 0.978175881773, 0.981340248165, 0.983964849591, 0.985997455292, 0.987552480984, 0.988913071725, 0.989989662136, 0.990858978786, 0.991507864087, 0.992101253603, 0.992706639095, 0.843116295553, 0.918865830119, 0.948986302851, 0.963840566631, 0.971863394331, 0.976064780745, 0.980732356364, 0.98399885847, 0.986220002076, 0.987800243738, 0.989471739682, 0.990764010755, 0.991805901688, 0.992578822333, 0.993415281055, 0.994092268046, 0.88346061529, 0.946000867977, 0.96760658804, 0.976501600498, 0.981740918991, 0.985389590011, 0.988097110038, 0.990101000036, 0.991452711903, 0.992372729231, 0.993280593371, 0.993979273128, 0.994576920437, 0.994992603394, 0.99541165183, 0.995793186385, 0.801573520738, 0.894556913607, 0.930557148943, 0.947742213628, 0.958336710515, 0.966270887606, 0.972043805126, 0.976686743156, 0.979997380949, 0.982263084747, 0.984371624652, 0.986043574351, 0.987418653777, 0.988435017399, 0.98943820625, 0.990384899709, 0.824336788203, 0.914307147717, 0.947162910211, 0.961113167281, 0.969251988341, 0.974708549133, 0.979348243062, 0.982888751196, 0.985039998783, 0.986690243229, 0.988265858776, 
0.989457451692, 0.990425556398, 0.991142353579, 0.991913402107, 0.992646966198, 0.910171904704, 0.952375683078, 0.967737628914, 0.975113825208, 0.979525760009, 0.982903358103, 0.985497414947, 0.987532990507, 0.989022084385, 0.990094546854, 0.991046427233, 0.991786419928, 0.992476722408, 0.992941886737, 0.9934275669, 0.993905147004, 0.864225726494, 0.929980029382, 0.955778654113, 0.96740495639, 0.974178985114, 0.978346408898, 0.982328012772, 0.985109106162, 0.987107132775, 0.988477856206, 0.989870184226, 0.990930157844, 0.991839380567, 0.992475550555, 0.993157023569, 0.993742594496, 0.803770468068, 0.894709100591, 0.93001725197, 0.946972008686, 0.957454439962, 0.965257644855, 0.971242168162, 0.97578798359, 0.979145612293, 0.981419222339, 0.983528566216, 0.985134373586, 0.986496893548, 0.98753676895, 0.988504824958, 0.989410139829, 0.887667125483, 0.950280631321, 0.971841626657, 0.980425626652, 0.984775140336, 0.987822408265, 0.989908020125, 0.991464951332, 0.992696046157, 0.993535144861, 0.99432064875, 0.994901414575, 0.995374559063, 0.995705134049, 0.996001978174, 0.996274517893, 0.828518548366, 0.907992823067, 0.940938577175, 0.956594052005, 0.966157200834, 0.972833437098, 0.977987634362, 0.981907039284, 0.984516045128, 0.986238836359, 0.987891164633, 0.989148044611, 0.990264907045, 0.991071889081, 0.991900734924, 0.992623228933, ================================================ FILE: train.py ================================================ import time import os import argparse import numpy as np import torch import torch.optim as optim import torch.optim.lr_scheduler as LS import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable import torch.utils.data as data from torchvision import transforms parser = argparse.ArgumentParser() parser.add_argument( '--batch-size', '-N', type=int, default=32, help='batch size') parser.add_argument( '--train', '-f', required=True, type=str, help='folder of training images') parser.add_argument( 
'--max-epochs', '-e', type=int, default=200, help='max epochs') parser.add_argument('--lr', type=float, default=0.0005, help='learning rate') # parser.add_argument('--cuda', '-g', action='store_true', help='enables cuda') parser.add_argument( '--iterations', type=int, default=16, help='unroll iterations') parser.add_argument('--checkpoint', type=int, help='unroll iterations') args = parser.parse_args() ## load 32x32 patches from images import dataset train_transform = transforms.Compose([ transforms.RandomCrop((32, 32)), transforms.ToTensor(), ]) train_set = dataset.ImageFolder(root=args.train, transform=train_transform) train_loader = data.DataLoader( dataset=train_set, batch_size=args.batch_size, shuffle=True, num_workers=1) print('total images: {}; total batches: {}'.format( len(train_set), len(train_loader))) ## load networks on GPU import network encoder = network.EncoderCell().cuda() binarizer = network.Binarizer().cuda() decoder = network.DecoderCell().cuda() solver = optim.Adam( [ { 'params': encoder.parameters() }, { 'params': binarizer.parameters() }, { 'params': decoder.parameters() }, ], lr=args.lr) def resume(epoch=None): if epoch is None: s = 'iter' epoch = 0 else: s = 'epoch' encoder.load_state_dict( torch.load('checkpoint/encoder_{}_{:08d}.pth'.format(s, epoch))) binarizer.load_state_dict( torch.load('checkpoint/binarizer_{}_{:08d}.pth'.format(s, epoch))) decoder.load_state_dict( torch.load('checkpoint/decoder_{}_{:08d}.pth'.format(s, epoch))) def save(index, epoch=True): if not os.path.exists('checkpoint'): os.mkdir('checkpoint') if epoch: s = 'epoch' else: s = 'iter' torch.save(encoder.state_dict(), 'checkpoint/encoder_{}_{:08d}.pth'.format( s, index)) torch.save(binarizer.state_dict(), 'checkpoint/binarizer_{}_{:08d}.pth'.format(s, index)) torch.save(decoder.state_dict(), 'checkpoint/decoder_{}_{:08d}.pth'.format( s, index)) # resume() scheduler = LS.MultiStepLR(solver, milestones=[3, 10, 20, 50, 100], gamma=0.5) last_epoch = 0 if 
args.checkpoint: resume(args.checkpoint) last_epoch = args.checkpoint scheduler.last_epoch = last_epoch - 1 for epoch in range(last_epoch + 1, args.max_epochs + 1): scheduler.step() for batch, data in enumerate(train_loader): batch_t0 = time.time() ## init lstm state encoder_h_1 = (Variable(torch.zeros(data.size(0), 256, 8, 8).cuda()), Variable(torch.zeros(data.size(0), 256, 8, 8).cuda())) encoder_h_2 = (Variable(torch.zeros(data.size(0), 512, 4, 4).cuda()), Variable(torch.zeros(data.size(0), 512, 4, 4).cuda())) encoder_h_3 = (Variable(torch.zeros(data.size(0), 512, 2, 2).cuda()), Variable(torch.zeros(data.size(0), 512, 2, 2).cuda())) decoder_h_1 = (Variable(torch.zeros(data.size(0), 512, 2, 2).cuda()), Variable(torch.zeros(data.size(0), 512, 2, 2).cuda())) decoder_h_2 = (Variable(torch.zeros(data.size(0), 512, 4, 4).cuda()), Variable(torch.zeros(data.size(0), 512, 4, 4).cuda())) decoder_h_3 = (Variable(torch.zeros(data.size(0), 256, 8, 8).cuda()), Variable(torch.zeros(data.size(0), 256, 8, 8).cuda())) decoder_h_4 = (Variable(torch.zeros(data.size(0), 128, 16, 16).cuda()), Variable(torch.zeros(data.size(0), 128, 16, 16).cuda())) patches = Variable(data.cuda()) solver.zero_grad() losses = [] res = patches - 0.5 bp_t0 = time.time() for _ in range(args.iterations): encoded, encoder_h_1, encoder_h_2, encoder_h_3 = encoder( res, encoder_h_1, encoder_h_2, encoder_h_3) codes = binarizer(encoded) output, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4 = decoder( codes, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4) res = res - output losses.append(res.abs().mean()) bp_t1 = time.time() loss = sum(losses) / args.iterations loss.backward() solver.step() batch_t1 = time.time() print( '[TRAIN] Epoch[{}]({}/{}); Loss: {:.6f}; Backpropagation: {:.4f} sec; Batch: {:.4f} sec'. 
format(epoch, batch + 1, len(train_loader), loss.data[0], bp_t1 - bp_t0, batch_t1 - batch_t0)) print(('{:.4f} ' * args.iterations + '\n').format(* [l.data[0] for l in losses])) index = (epoch - 1) * len(train_loader) + batch ## save checkpoint every 500 training steps if index % 500 == 0: save(0, False) save(epoch)