Repository: 1zb/pytorch-image-comp-rnn
Branch: master
Commit: d7150f285cfd
Files: 20
Total size: 46.6 KB
Directory structure:
gitextract_yfbt7nhy/
├── .gitignore
├── README.md
├── dataset.py
├── decoder.py
├── encoder.py
├── functions/
│ ├── __init__.py
│ └── sign.py
├── metric.py
├── modules/
│ ├── __init__.py
│ ├── conv_rnn.py
│ └── sign.py
├── network.py
├── test/
│ ├── calc_ssim.sh
│ ├── draw_rd.py
│ ├── enc_dec.sh
│ ├── get_kodak.sh
│ ├── jpeg.sh
│ ├── jpeg_ssim.csv
│ └── lstm_ssim.csv
└── train.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*.pth
*.pyc
*.bak
*.png
*.jpg
*.npz
!rd.png
!/kodim10.png
!/bpp-0.125-0.133-ssim-0.865-0.827.png
!/bpp-0.250-0.249-ssim-0.937-0.918.png
!/bpp-0.375-0.381-ssim-0.963-0.951.png
================================================
FILE: README.md
================================================
# Full Resolution Image Compression with Recurrent Neural Networks
https://arxiv.org/abs/1608.05148v2
## Requirements
- PyTorch 0.2.0 (with torchvision; the scripts also use numpy, scipy, Pillow, and matplotlib for the plots)
## Train
```bash
python train.py -f /path/to/your/images/folder/like/mscoco
```
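Checkpoints are written to `checkpoint/` after every epoch, named like `encoder_epoch_00000005.pth`, `binarizer_epoch_00000005.pth`, and `decoder_epoch_00000005.pth`, plus a rolling `*_iter_00000000.pth` snapshot every 500 steps; pass `--checkpoint N` to resume from epoch `N`.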
## Encode and Decode
### Encode
```bash
python encoder.py --model checkpoint/encoder_epoch_00000005.pth --input /path/to/your/example.png --cuda --output ex --iterations 16
```
This will output binary codes saved in `.npz` format.
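The archive holds two arrays: `shape`, the code tensor shape `(iterations, batch, 32, height/16, width/16)`, and `codes`, the `{-1, +1}` bits remapped to `{0, 1}` and packed with `np.packbits`. A minimal sketch for inspecting it, mirroring the loading code in `decoder.py` (the filename `ex.npz` is just the example output above):

```python
import numpy as np

content = np.load('ex.npz')
shape = content['shape']                  # (iterations, batch, 32, H/16, W/16)
codes = np.unpackbits(content['codes'])   # packed bytes back to {0, 1} bits
codes = codes.reshape(shape).astype(np.float32) * 2 - 1   # remap to {-1, +1}

# 32 bits per 16x16 pixel block per iteration = 0.125 bpp per iteration
print('code tensor:', codes.shape)
print('bpp at this iteration count:', shape[0] * shape[2] / 256.0)
```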
### Decode
```bash
python decoder.py --model checkpoint/decoder_epoch_00000005.pth --input /path/to/your/example.npz --cuda --output /path/to/output/folder
```
This will output images of different quality levels.
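Decoding is progressive: `00.png` is reconstructed from the first iteration's codes alone, and each later image adds one more refinement. Every iteration spends 32 bits per 16×16 block, so image `k` costs (k + 1) × 32 / 256 = (k + 1) × 0.125 bpp before entropy coding, up to 2.0 bpp at `15.png`.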
## Test
### Get Kodak dataset
```bash
bash test/get_kodak.sh
```
### Encode and decode with RNN model
```bash
bash test/enc_dec.sh
```
### Encode and decode with JPEG (uses `convert` from ImageMagick)
```bash
bash test/jpeg.sh
```
### Calculate SSIM
```bash
bash test/calc_ssim.sh
```
### Draw rate-distortion curve
```bash
python test/draw_rd.py
```
## Result
LSTM (Additive Reconstruction), before entropy coding
### Rate-distortion
### `kodim10.png`
Original Image
Below Left: LSTM, SSIM=0.865, bpp=0.125
Below Right: JPEG, SSIM=0.827, bpp=0.133
Below Left: LSTM, SSIM=0.937, bpp=0.250
Below Right: JPEG, SSIM=0.918, bpp=0.249
Below Left: LSTM, SSIM=0.963, bpp=0.375
Below Right: JPEG, SSIM=0.951, bpp=0.381
## What's inside
- `train.py`: Main program for training.
- `encoder.py` and `decoder.py`: Encoding and decoding scripts.
- `dataset.py`: Utils for reading images.
- `metric.py`: Functions for calculating MS-SSIM and PSNR.
- `network.py`: Modules of encoder and decoder.
- `modules/conv_rnn.py`: ConvLSTM module.
- `functions/sign.py`: Forward and backward for binary quantization.
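For orientation, here is a minimal sketch of how these pieces fit together on a single patch, condensed from the training loop in `train.py` (PyTorch 0.2-era `Variable` API; the random tensor stands in for a real 32×32 crop):

```python
import torch
from torch.autograd import Variable

import network

encoder = network.EncoderCell()
binarizer = network.Binarizer()
decoder = network.DecoderCell()

def state(channels, size):
    # zero-initialized (h, c) pair for one ConvLSTM layer
    return (Variable(torch.zeros(1, channels, size, size)),
            Variable(torch.zeros(1, channels, size, size)))

e1, e2, e3 = state(256, 8), state(512, 4), state(512, 2)
d1, d2, d3, d4 = state(512, 2), state(512, 4), state(256, 8), state(128, 16)

patch = Variable(torch.rand(1, 3, 32, 32))   # stand-in for a real image crop
res = patch - 0.5
for step in range(2):                        # two of the 16 unroll iterations
    encoded, e1, e2, e3 = encoder(res, e1, e2, e3)
    code = binarizer(encoded)                # (1, 32, 2, 2), entries in {-1, +1}
    output, d1, d2, d3, d4 = decoder(code, d1, d2, d3, d4)
    res = res - output                       # the next iteration codes this residual
    print(step, res.abs().mean().data[0])
```

At full resolution, `encoder.py` and `decoder.py` do the same thing with the state shapes scaled to the image size, hence the `height % 32 == 0` assertion in `encoder.py`.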
## Official Repo
https://github.com/tensorflow/models/tree/master/compression
================================================
FILE: dataset.py
================================================
# modified from https://github.com/desimone/vision/blob/fb74c76d09bcc2594159613d5bdadd7d4697bb11/torchvision/datasets/folder.py
import os
import os.path
import torch
from torchvision import transforms
import torch.utils.data as data
from PIL import Image
IMG_EXTENSIONS = [
'.jpg',
'.JPG',
'.jpeg',
'.JPEG',
'.png',
'.PNG',
'.ppm',
'.PPM',
'.bmp',
'.BMP',
]
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def default_loader(path):
return Image.open(path).convert('RGB')
class ImageFolder(data.Dataset):
""" ImageFolder can be used to load images where there are no labels."""
def __init__(self, root, transform=None, loader=default_loader):
images = []
for filename in os.listdir(root):
if is_image_file(filename):
images.append('{}'.format(filename))
self.root = root
self.imgs = images
self.transform = transform
self.loader = loader
def __getitem__(self, index):
filename = self.imgs[index]
        try:
            img = self.loader(os.path.join(self.root, filename))
        except Exception:
            # skip unreadable files: return a blank patch instead of crashing
            return torch.zeros((3, 32, 32))
if self.transform is not None:
img = self.transform(img)
return img
def __len__(self):
return len(self.imgs)
================================================
FILE: decoder.py
================================================
import os
import argparse
import numpy as np
from scipy.misc import imsave
import torch
from torch.autograd import Variable
parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True, type=str, help='path to model')
parser.add_argument('--input', required=True, type=str, help='input codes')
parser.add_argument('--output', default='.', type=str, help='output folder')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument(
'--iterations', type=int, default=16, help='unroll iterations')
args = parser.parse_args()
content = np.load(args.input)
codes = np.unpackbits(content['codes'])
codes = np.reshape(codes, content['shape']).astype(np.float32) * 2 - 1
codes = torch.from_numpy(codes)
iters, batch_size, channels, height, width = codes.size()
height = height * 16
width = width * 16
codes = Variable(codes, volatile=True)
import network
decoder = network.DecoderCell()
decoder.eval()
decoder.load_state_dict(torch.load(args.model))
decoder_h_1 = (Variable(
torch.zeros(batch_size, 512, height // 16, width // 16), volatile=True),
Variable(
torch.zeros(batch_size, 512, height // 16, width // 16),
volatile=True))
decoder_h_2 = (Variable(
torch.zeros(batch_size, 512, height // 8, width // 8), volatile=True),
Variable(
torch.zeros(batch_size, 512, height // 8, width // 8),
volatile=True))
decoder_h_3 = (Variable(
torch.zeros(batch_size, 256, height // 4, width // 4), volatile=True),
Variable(
torch.zeros(batch_size, 256, height // 4, width // 4),
volatile=True))
decoder_h_4 = (Variable(
torch.zeros(batch_size, 128, height // 2, width // 2), volatile=True),
Variable(
torch.zeros(batch_size, 128, height // 2, width // 2),
volatile=True))
if args.cuda:
decoder = decoder.cuda()
codes = codes.cuda()
decoder_h_1 = (decoder_h_1[0].cuda(), decoder_h_1[1].cuda())
decoder_h_2 = (decoder_h_2[0].cuda(), decoder_h_2[1].cuda())
decoder_h_3 = (decoder_h_3[0].cuda(), decoder_h_3[1].cuda())
decoder_h_4 = (decoder_h_4[0].cuda(), decoder_h_4[1].cuda())
# start from mid-gray and add the decoded residual at each iteration
image = torch.zeros(batch_size, 3, height, width) + 0.5
for iters in range(min(args.iterations, codes.size(0))):
output, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4 = decoder(
codes[iters], decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4)
image = image + output.data.cpu()
imsave(
os.path.join(args.output, '{:02d}.png'.format(iters)),
np.squeeze(image.numpy().clip(0, 1) * 255.0).astype(np.uint8)
.transpose(1, 2, 0))
================================================
FILE: encoder.py
================================================
import argparse
import numpy as np
from scipy.misc import imread
import torch
from torch.autograd import Variable
parser = argparse.ArgumentParser()
parser.add_argument(
'--model', '-m', required=True, type=str, help='path to model')
parser.add_argument(
'--input', '-i', required=True, type=str, help='input image')
parser.add_argument(
'--output', '-o', required=True, type=str, help='output codes')
parser.add_argument('--cuda', '-g', action='store_true', help='enables cuda')
parser.add_argument(
'--iterations', type=int, default=16, help='unroll iterations')
args = parser.parse_args()
image = imread(args.input, mode='RGB')
image = torch.from_numpy(
np.expand_dims(
np.transpose(image.astype(np.float32) / 255.0, (2, 0, 1)), 0))
batch_size, input_channels, height, width = image.size()
assert height % 32 == 0 and width % 32 == 0
image = Variable(image, volatile=True)
import network
encoder = network.EncoderCell()
binarizer = network.Binarizer()
decoder = network.DecoderCell()
encoder.eval()
binarizer.eval()
decoder.eval()
encoder.load_state_dict(torch.load(args.model))
binarizer.load_state_dict(
torch.load(args.model.replace('encoder', 'binarizer')))
decoder.load_state_dict(torch.load(args.model.replace('encoder', 'decoder')))
encoder_h_1 = (Variable(
torch.zeros(batch_size, 256, height // 4, width // 4), volatile=True),
Variable(
torch.zeros(batch_size, 256, height // 4, width // 4),
volatile=True))
encoder_h_2 = (Variable(
torch.zeros(batch_size, 512, height // 8, width // 8), volatile=True),
Variable(
torch.zeros(batch_size, 512, height // 8, width // 8),
volatile=True))
encoder_h_3 = (Variable(
torch.zeros(batch_size, 512, height // 16, width // 16), volatile=True),
Variable(
torch.zeros(batch_size, 512, height // 16, width // 16),
volatile=True))
decoder_h_1 = (Variable(
torch.zeros(batch_size, 512, height // 16, width // 16), volatile=True),
Variable(
torch.zeros(batch_size, 512, height // 16, width // 16),
volatile=True))
decoder_h_2 = (Variable(
torch.zeros(batch_size, 512, height // 8, width // 8), volatile=True),
Variable(
torch.zeros(batch_size, 512, height // 8, width // 8),
volatile=True))
decoder_h_3 = (Variable(
torch.zeros(batch_size, 256, height // 4, width // 4), volatile=True),
Variable(
torch.zeros(batch_size, 256, height // 4, width // 4),
volatile=True))
decoder_h_4 = (Variable(
torch.zeros(batch_size, 128, height // 2, width // 2), volatile=True),
Variable(
torch.zeros(batch_size, 128, height // 2, width // 2),
volatile=True))
if args.cuda:
encoder = encoder.cuda()
binarizer = binarizer.cuda()
decoder = decoder.cuda()
image = image.cuda()
encoder_h_1 = (encoder_h_1[0].cuda(), encoder_h_1[1].cuda())
encoder_h_2 = (encoder_h_2[0].cuda(), encoder_h_2[1].cuda())
encoder_h_3 = (encoder_h_3[0].cuda(), encoder_h_3[1].cuda())
decoder_h_1 = (decoder_h_1[0].cuda(), decoder_h_1[1].cuda())
decoder_h_2 = (decoder_h_2[0].cuda(), decoder_h_2[1].cuda())
decoder_h_3 = (decoder_h_3[0].cuda(), decoder_h_3[1].cuda())
decoder_h_4 = (decoder_h_4[0].cuda(), decoder_h_4[1].cuda())
codes = []
res = image - 0.5
for iters in range(args.iterations):
encoded, encoder_h_1, encoder_h_2, encoder_h_3 = encoder(
res, encoder_h_1, encoder_h_2, encoder_h_3)
code = binarizer(encoded)
output, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4 = decoder(
code, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4)
res = res - output
codes.append(code.data.cpu().numpy())
print('Iter: {:02d}; Loss: {:.06f}'.format(iters, res.data.abs().mean()))
codes = (np.stack(codes).astype(np.int8) + 1) // 2
export = np.packbits(codes.reshape(-1))
np.savez_compressed(args.output, shape=codes.shape, codes=export)
================================================
FILE: functions/__init__.py
================================================
from .sign import Sign
================================================
FILE: functions/sign.py
================================================
import torch
from torch.autograd import Function
class Sign(Function):
"""
Variable Rate Image Compression with Recurrent Neural Networks
https://arxiv.org/abs/1511.06085
"""
def __init__(self):
super(Sign, self).__init__()
@staticmethod
def forward(ctx, input, is_training=True):
        # Apply stochastic quantization noise only during training
        if is_training:
            # With input in [-1, 1], emit +1 with probability (1 + input) / 2,
            # so the quantizer is unbiased: E[output] = input.
            prob = input.new(input.size()).uniform_()
            x = input.clone()
            x[(1 - input) / 2 <= prob] = 1
            x[(1 - input) / 2 > prob] = -1
            return x
        else:
            # hard sign at evaluation time
            return input.sign()
    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: sign() is flat almost everywhere, so the
        # gradient is passed through unchanged. None corresponds to the
        # `is_training` argument, which needs no gradient.
        return grad_output, None
================================================
FILE: metric.py
================================================
## some function borrowed from
## https://github.com/tensorflow/models/blob/master/compression/image_encoder/msssim.py
"""Python implementation of MS-SSIM.
Usage:
python msssim.py --original_image=original.png --compared_image=distorted.png
"""
import argparse
import numpy as np
from scipy import signal
from scipy.ndimage.filters import convolve
from PIL import Image
parser = argparse.ArgumentParser()
parser.add_argument('--metric', '-m', type=str, default='all', help='metric')
parser.add_argument(
'--original-image', '-o', type=str, required=True, help='original image')
parser.add_argument(
'--compared-image', '-c', type=str, required=True, help='compared image')
args = parser.parse_args()
def _FSpecialGauss(size, sigma):
"""Function to mimic the 'fspecial' gaussian MATLAB function."""
radius = size // 2
offset = 0.0
start, stop = -radius, radius + 1
if size % 2 == 0:
offset = 0.5
stop -= 1
x, y = np.mgrid[offset + start:stop, offset + start:stop]
assert len(x) == size
g = np.exp(-((x**2 + y**2) / (2.0 * sigma**2)))
return g / g.sum()
def _SSIMForMultiScale(img1,
img2,
max_val=255,
filter_size=11,
filter_sigma=1.5,
k1=0.01,
k2=0.03):
"""Return the Structural Similarity Map between `img1` and `img2`.
This function attempts to match the functionality of ssim_index_new.m by
Zhou Wang: http://www.cns.nyu.edu/~lcv/ssim/msssim.zip
Arguments:
img1: Numpy array holding the first RGB image batch.
img2: Numpy array holding the second RGB image batch.
    max_val: the dynamic range of the images (i.e., the difference between the
      maximum and the minimum allowed values).
filter_size: Size of blur kernel to use (will be reduced for small images).
filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced
for small images).
k1: Constant used to maintain stability in the SSIM calculation (0.01 in
the original paper).
k2: Constant used to maintain stability in the SSIM calculation (0.03 in
the original paper).
Returns:
Pair containing the mean SSIM and contrast sensitivity between `img1` and
`img2`.
Raises:
RuntimeError: If input images don't have the same shape or don't have four
dimensions: [batch_size, height, width, depth].
"""
if img1.shape != img2.shape:
raise RuntimeError(
'Input images must have the same shape (%s vs. %s).', img1.shape,
img2.shape)
if img1.ndim != 4:
raise RuntimeError('Input images must have four dimensions, not %d',
img1.ndim)
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
_, height, width, _ = img1.shape
# Filter size can't be larger than height or width of images.
size = min(filter_size, height, width)
# Scale down sigma if a smaller filter size is used.
sigma = size * filter_sigma / filter_size if filter_size else 0
if filter_size:
window = np.reshape(_FSpecialGauss(size, sigma), (1, size, size, 1))
mu1 = signal.fftconvolve(img1, window, mode='valid')
mu2 = signal.fftconvolve(img2, window, mode='valid')
sigma11 = signal.fftconvolve(img1 * img1, window, mode='valid')
sigma22 = signal.fftconvolve(img2 * img2, window, mode='valid')
sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid')
else:
# Empty blur kernel so no need to convolve.
mu1, mu2 = img1, img2
sigma11 = img1 * img1
sigma22 = img2 * img2
sigma12 = img1 * img2
mu11 = mu1 * mu1
mu22 = mu2 * mu2
mu12 = mu1 * mu2
sigma11 -= mu11
sigma22 -= mu22
sigma12 -= mu12
# Calculate intermediate values used by both ssim and cs_map.
c1 = (k1 * max_val)**2
c2 = (k2 * max_val)**2
v1 = 2.0 * sigma12 + c2
v2 = sigma11 + sigma22 + c2
ssim = np.mean((((2.0 * mu12 + c1) * v1) / ((mu11 + mu22 + c1) * v2)))
cs = np.mean(v1 / v2)
return ssim, cs
def MultiScaleSSIM(img1,
img2,
max_val=255,
filter_size=11,
filter_sigma=1.5,
k1=0.01,
k2=0.03,
weights=None):
"""Return the MS-SSIM score between `img1` and `img2`.
This function implements Multi-Scale Structural Similarity (MS-SSIM) Image
Quality Assessment according to Zhou Wang's paper, "Multi-scale structural
similarity for image quality assessment" (2003).
Link: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf
Author's MATLAB implementation:
http://www.cns.nyu.edu/~lcv/ssim/msssim.zip
Arguments:
img1: Numpy array holding the first RGB image batch.
img2: Numpy array holding the second RGB image batch.
    max_val: the dynamic range of the images (i.e., the difference between the
      maximum and the minimum allowed values).
filter_size: Size of blur kernel to use (will be reduced for small images).
filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced
for small images).
k1: Constant used to maintain stability in the SSIM calculation (0.01 in
the original paper).
k2: Constant used to maintain stability in the SSIM calculation (0.03 in
the original paper).
weights: List of weights for each level; if none, use five levels and the
weights from the original paper.
Returns:
MS-SSIM score between `img1` and `img2`.
Raises:
RuntimeError: If input images don't have the same shape or don't have four
dimensions: [batch_size, height, width, depth].
"""
if img1.shape != img2.shape:
raise RuntimeError(
'Input images must have the same shape (%s vs. %s).', img1.shape,
img2.shape)
if img1.ndim != 4:
raise RuntimeError('Input images must have four dimensions, not %d',
img1.ndim)
# Note: default weights don't sum to 1.0 but do match the paper / matlab code.
weights = np.array(weights if weights else
[0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
levels = weights.size
downsample_filter = np.ones((1, 2, 2, 1)) / 4.0
im1, im2 = [x.astype(np.float64) for x in [img1, img2]]
mssim = np.array([])
mcs = np.array([])
for _ in range(levels):
ssim, cs = _SSIMForMultiScale(
im1,
im2,
max_val=max_val,
filter_size=filter_size,
filter_sigma=filter_sigma,
k1=k1,
k2=k2)
mssim = np.append(mssim, ssim)
mcs = np.append(mcs, cs)
filtered = [
convolve(im, downsample_filter, mode='reflect')
for im in [im1, im2]
]
im1, im2 = [x[:, ::2, ::2, :] for x in filtered]
return (np.prod(mcs[0:levels - 1]**weights[0:levels - 1]) *
(mssim[levels - 1]**weights[levels - 1]))
def msssim(original, compared):
if isinstance(original, str):
original = np.array(Image.open(original).convert('RGB'), dtype=np.float32)
if isinstance(compared, str):
compared = np.array(Image.open(compared).convert('RGB'), dtype=np.float32)
original = original[None, ...] if original.ndim == 3 else original
compared = compared[None, ...] if compared.ndim == 3 else compared
return MultiScaleSSIM(original, compared, max_val=255)
def psnr(original, compared):
if isinstance(original, str):
original = np.array(Image.open(original).convert('RGB'), dtype=np.float32)
if isinstance(compared, str):
compared = np.array(Image.open(compared).convert('RGB'), dtype=np.float32)
    mse = np.mean(np.square(original - compared))
    if mse == 0:
        return 99.99
    return np.clip(10. * np.log10(255. * 255. / mse), 0., 99.99)
def main():
    if args.metric != 'psnr':
        print(msssim(args.original_image, args.compared_image), end=' ')
    if args.metric != 'ssim':
        print(psnr(args.original_image, args.compared_image), end='')
if __name__ == '__main__':
main()
================================================
FILE: modules/__init__.py
================================================
from .conv_rnn import ConvLSTMCell #, ConvLSTM
from .sign import Sign
================================================
FILE: modules/conv_rnn.py
================================================
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.autograd import Variable
from torch.nn.modules.utils import _pair
class ConvRNNCellBase(nn.Module):
def __repr__(self):
s = (
'{name}({input_channels}, {hidden_channels}, kernel_size={kernel_size}'
', stride={stride}')
if self.padding != (0, ) * len(self.padding):
s += ', padding={padding}'
if self.dilation != (1, ) * len(self.dilation):
s += ', dilation={dilation}'
s += ', hidden_kernel_size={hidden_kernel_size}'
s += ')'
return s.format(name=self.__class__.__name__, **self.__dict__)
class ConvLSTMCell(ConvRNNCellBase):
def __init__(self,
input_channels,
hidden_channels,
kernel_size=3,
stride=1,
padding=0,
dilation=1,
hidden_kernel_size=1,
bias=True):
super(ConvLSTMCell, self).__init__()
self.input_channels = input_channels
self.hidden_channels = hidden_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.hidden_kernel_size = _pair(hidden_kernel_size)
hidden_padding = _pair(hidden_kernel_size // 2)
gate_channels = 4 * self.hidden_channels
self.conv_ih = nn.Conv2d(
in_channels=self.input_channels,
out_channels=gate_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
bias=bias)
self.conv_hh = nn.Conv2d(
in_channels=self.hidden_channels,
out_channels=gate_channels,
kernel_size=hidden_kernel_size,
stride=1,
padding=hidden_padding,
dilation=1,
bias=bias)
self.reset_parameters()
def reset_parameters(self):
self.conv_ih.reset_parameters()
self.conv_hh.reset_parameters()
    def forward(self, input, hidden):
        hx, cx = hidden
        # One ConvLSTM step: the input and hidden convolutions together
        # produce all four gates, split along the channel dimension.
        gates = self.conv_ih(input) + self.conv_hh(hx)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        ingate = F.sigmoid(ingate)
        forgetgate = F.sigmoid(forgetgate)
        cellgate = F.tanh(cellgate)
        outgate = F.sigmoid(outgate)
        # standard LSTM cell and hidden updates
        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * F.tanh(cy)
        return hy, cy
================================================
FILE: modules/sign.py
================================================
import torch
import torch.nn as nn
from functions import Sign as SignFunction
class Sign(nn.Module):
    """Stochastic sign during training, deterministic sign at eval time."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return SignFunction.apply(x, self.training)
================================================
FILE: network.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules import ConvLSTMCell, Sign
class EncoderCell(nn.Module):
def __init__(self):
super(EncoderCell, self).__init__()
self.conv = nn.Conv2d(
3, 64, kernel_size=3, stride=2, padding=1, bias=False)
self.rnn1 = ConvLSTMCell(
64,
256,
kernel_size=3,
stride=2,
padding=1,
hidden_kernel_size=1,
bias=False)
self.rnn2 = ConvLSTMCell(
256,
512,
kernel_size=3,
stride=2,
padding=1,
hidden_kernel_size=1,
bias=False)
self.rnn3 = ConvLSTMCell(
512,
512,
kernel_size=3,
stride=2,
padding=1,
hidden_kernel_size=1,
bias=False)
def forward(self, input, hidden1, hidden2, hidden3):
x = self.conv(input)
hidden1 = self.rnn1(x, hidden1)
x = hidden1[0]
hidden2 = self.rnn2(x, hidden2)
x = hidden2[0]
hidden3 = self.rnn3(x, hidden3)
x = hidden3[0]
return x, hidden1, hidden2, hidden3
class Binarizer(nn.Module):
def __init__(self):
super(Binarizer, self).__init__()
self.conv = nn.Conv2d(512, 32, kernel_size=1, bias=False)
self.sign = Sign()
    def forward(self, input):
        feat = self.conv(input)
        # squash features into (-1, 1) before binarization
        x = F.tanh(feat)
        return self.sign(x)
class DecoderCell(nn.Module):
def __init__(self):
super(DecoderCell, self).__init__()
self.conv1 = nn.Conv2d(
32, 512, kernel_size=1, stride=1, padding=0, bias=False)
self.rnn1 = ConvLSTMCell(
512,
512,
kernel_size=3,
stride=1,
padding=1,
hidden_kernel_size=1,
bias=False)
self.rnn2 = ConvLSTMCell(
128,
512,
kernel_size=3,
stride=1,
padding=1,
hidden_kernel_size=1,
bias=False)
self.rnn3 = ConvLSTMCell(
128,
256,
kernel_size=3,
stride=1,
padding=1,
hidden_kernel_size=3,
bias=False)
self.rnn4 = ConvLSTMCell(
64,
128,
kernel_size=3,
stride=1,
padding=1,
hidden_kernel_size=3,
bias=False)
self.conv2 = nn.Conv2d(
32, 3, kernel_size=1, stride=1, padding=0, bias=False)
    def forward(self, input, hidden1, hidden2, hidden3, hidden4):
        # Each stage runs a ConvLSTM, then upsamples 2x with pixel_shuffle
        # (depth-to-space), which trades 4x channels for 2x resolution:
        # 512->128, 512->128, 256->64, 128->32 channels across the four stages.
        x = self.conv1(input)
        hidden1 = self.rnn1(x, hidden1)
        x = hidden1[0]
        x = F.pixel_shuffle(x, 2)
        hidden2 = self.rnn2(x, hidden2)
        x = hidden2[0]
        x = F.pixel_shuffle(x, 2)
        hidden3 = self.rnn3(x, hidden3)
        x = hidden3[0]
        x = F.pixel_shuffle(x, 2)
        hidden4 = self.rnn4(x, hidden4)
        x = hidden4[0]
        x = F.pixel_shuffle(x, 2)
        # map to an RGB residual in (-0.5, 0.5)
        x = F.tanh(self.conv2(x)) / 2
        return x, hidden1, hidden2, hidden3, hidden4
================================================
FILE: test/calc_ssim.sh
================================================
#!/bin/bash
LSTM=test/lstm_ssim.csv
JPEG=test/jpeg_ssim.csv
echo -n "" > $LSTM
for i in {01..24..1}; do
echo Processing test/decoded/kodim$i
for j in {00..15..1}; do
echo -n `python metric.py -m ssim -o test/images/kodim$i.png -c test/decoded/kodim$i/$j.png`', ' >> $LSTM
done
echo "" >> $LSTM
done
echo -n "" > $JPEG
for i in {01..24..1}; do
echo Processing test/jpeg/kodim$i
for j in {01..20..1}; do
echo -n `python metric.py -m ssim -o test/images/kodim$i.png -c test/jpeg/kodim$i/$j.jpg`', ' >> $JPEG
done
echo "" >> $JPEG
done
================================================
FILE: test/draw_rd.py
================================================
import os
import numpy as np
from scipy.misc import imread
import matplotlib.pyplot as plt
line = True
lstm_ssim = np.genfromtxt('test/lstm_ssim.csv', delimiter=',')
lstm_ssim = lstm_ssim[:, :-1]
if line:
    lstm_ssim = np.mean(lstm_ssim, axis=0)
    # each unroll iteration costs 32 bits per 16x16 block = 0.125 bpp
    lstm_bpp = np.arange(1, 17) * 24.0 / 192
    plt.plot(lstm_bpp, lstm_ssim, label='LSTM', marker='o')
else:
    lstm_bpp = np.stack([np.arange(1, 17) for _ in range(24)]) * 24.0 / 192
    plt.scatter(
        lstm_bpp.reshape(-1), lstm_ssim.reshape(-1), label='LSTM', marker='o')
jpeg_ssim = np.genfromtxt('test/jpeg_ssim.csv', delimiter=',')
jpeg_ssim = jpeg_ssim[:, :-1]
if line:
jpeg_ssim = np.mean(jpeg_ssim, axis=0)
jpeg_bpp = np.array([
    os.path.getsize('test/jpeg/kodim{:02d}/{:02d}.jpg'.format(i, q)) * 8.0 /
    (imread('test/jpeg/kodim{:02d}/{:02d}.jpg'.format(i, q)).size // 3)
    for i in range(1, 25) for q in range(1, 21)
]).reshape(24, 20)
if line:
jpeg_bpp = np.mean(jpeg_bpp, axis=0)
plt.plot(jpeg_bpp, jpeg_ssim, label='JPEG', marker='x')
else:
plt.scatter(
jpeg_bpp.reshape(-1), jpeg_ssim.reshape(-1), label='JPEG', marker='x')
plt.xlim(0., 2.)
plt.ylim(0.7, 1.0)
plt.xlabel('bit per pixel')
plt.ylabel('MS-SSIM')
plt.legend()
plt.show()
================================================
FILE: test/enc_dec.sh
================================================
#!/bin/bash
for i in {01..24..1}; do
echo Encoding test/images/kodim$i.png
mkdir -p test/codes
python encoder.py --model checkpoint/encoder_epoch_00000066.pth --input test/images/kodim$i.png --cuda --output test/codes/kodim$i --iterations 16
echo Decoding test/codes/kodim$i.npz
mkdir -p test/decoded/kodim$i
python decoder.py --model checkpoint/decoder_epoch_00000066.pth --input test/codes/kodim$i.npz --cuda --output test/decoded/kodim$i
done
================================================
FILE: test/get_kodak.sh
================================================
#!/bin/bash
mkdir -p test/images
for i in {01..24..1}; do
echo ${i}
wget http://r0k.us/graphics/kodak/kodak/kodim${i}.png -O test/images/kodim${i}.png
done
================================================
FILE: test/jpeg.sh
================================================
#!/bin/bash
for i in {01..24..1}; do
echo JPEG Encoding test/images/kodim$i.png
mkdir -p test/jpeg/kodim$i
for j in {1..20..1}; do
convert test/images/kodim$i.png -quality $(($j*5)) -sampling-factor 4:2:0 test/jpeg/kodim$i/`printf "%02d" $j`.jpg
done
done
================================================
FILE: test/jpeg_ssim.csv
================================================
0.818072219541, 0.915486738863, 0.941250388079, 0.954959594001, 0.964432174175, 0.969521944618, 0.974450024801, 0.977023476788, 0.979556769243, 0.981440100069, 0.98303158495, 0.984619828345, 0.986521991089, 0.988233765029, 0.990034353477, 0.991908912291, 0.993859786213, 0.995872272301, 0.997704734009, 0.999369880278,
0.666065535368, 0.833785271322, 0.862238250732, 0.894437174757, 0.913606004326, 0.927403225046, 0.936700087174, 0.942087207378, 0.947302906233, 0.952881599351, 0.957102998846, 0.961079845306, 0.964679737793, 0.969085669754, 0.972862674891, 0.977389816095, 0.981929271455, 0.986764889825, 0.992090856598, 0.997867382574,
0.810456808165, 0.885532082425, 0.923215832246, 0.942269281224, 0.951013313071, 0.960641959857, 0.965020974147, 0.968983036738, 0.972687587703, 0.97536435516, 0.977124722602, 0.979205659824, 0.981645499592, 0.983797485181, 0.985868845072, 0.988406175337, 0.990561443404, 0.992977803974, 0.995363893283, 0.998186608337,
0.736205708103, 0.864789258571, 0.903917595396, 0.928365043206, 0.940992130793, 0.950190769917, 0.957041193005, 0.961645055775, 0.965608843423, 0.968846156644, 0.971464980647, 0.974338971545, 0.976288817151, 0.979326195514, 0.981731800268, 0.985073759122, 0.988183892013, 0.991310684778, 0.99484060445, 0.998539576456,
0.839808316647, 0.914569680204, 0.944585553268, 0.958189541858, 0.965933250076, 0.971345068489, 0.975120264002, 0.978007417795, 0.980112049762, 0.982031704186, 0.983617064118, 0.985301594451, 0.986864130976, 0.988550054893, 0.990140431145, 0.991912697905, 0.993659191703, 0.995525236463, 0.997451078786, 0.999445634028,
0.774314920949, 0.875787337352, 0.923727573855, 0.943054472644, 0.95321690632, 0.961164793684, 0.967588086634, 0.971259154783, 0.974502697083, 0.976536008639, 0.979070907986, 0.980730423348, 0.982772471532, 0.985208074663, 0.987072532816, 0.989603098041, 0.991848356056, 0.993989496887, 0.996261532852, 0.998833407115,
0.853028224878, 0.924594399288, 0.944577324582, 0.962892871899, 0.968175881476, 0.973912836384, 0.977611559175, 0.980013739677, 0.981372066609, 0.983489296322, 0.984974095691, 0.986316785418, 0.987796973658, 0.989294643654, 0.990415707239, 0.991958140094, 0.993469699667, 0.99504535025, 0.996735156379, 0.998906782293,
0.880688294539, 0.933204609681, 0.953650276172, 0.964610058969, 0.971655363423, 0.976231891224, 0.979402609806, 0.981754260409, 0.983510269845, 0.98503201627, 0.986234509006, 0.987594582558, 0.988831108808, 0.990315455876, 0.991614978601, 0.993112910369, 0.994592787429, 0.996128854807, 0.997766030171, 0.999473104286,
0.853720880853, 0.912554857475, 0.930711615275, 0.950455648101, 0.960958503215, 0.96667403636, 0.971362349207, 0.974072622141, 0.975613226504, 0.978781040524, 0.979969048509, 0.981158076869, 0.982883137674, 0.984722620635, 0.98658900113, 0.988314796092, 0.99036007625, 0.992362984817, 0.994781216862, 0.998299761454,
0.81076220338, 0.893706653475, 0.927362493596, 0.943724930231, 0.955211693836, 0.962355553594, 0.96829254466, 0.97169265604, 0.9743439019, 0.976751528211, 0.978728976699, 0.980505251819, 0.982242719802, 0.98445680889, 0.986286665859, 0.988491766444, 0.990776417926, 0.992905056354, 0.99523724111, 0.998527390073,
0.791555686892, 0.881440991243, 0.920860393213, 0.942552828487, 0.953270936085, 0.959839333526, 0.966037446908, 0.969412253347, 0.972687657063, 0.975263831002, 0.977073908411, 0.979106240394, 0.981338251384, 0.984181227447, 0.985917501452, 0.988715661162, 0.991071452023, 0.993500123657, 0.996109324665, 0.999023620304,
0.801456772286, 0.878467768641, 0.908098341977, 0.933136250773, 0.946448137076, 0.955111533806, 0.963027394102, 0.967258044795, 0.970678461713, 0.974186408589, 0.975688326117, 0.977937704633, 0.980300465358, 0.982588089404, 0.985135222036, 0.987727956878, 0.990090763787, 0.992444149444, 0.995132471845, 0.998223031362,
0.797066626991, 0.891473428482, 0.927754609701, 0.944168588194, 0.953960755811, 0.960570142061, 0.96609929426, 0.969401647648, 0.97267245227, 0.975204503945, 0.977409332209, 0.979649966483, 0.981643846133, 0.984198624552, 0.986564293774, 0.989219718804, 0.991824162369, 0.994323089168, 0.996765571563, 0.999382383983,
0.785606294091, 0.88697381026, 0.922042917455, 0.942085547292, 0.952637067226, 0.95956056371, 0.965247607479, 0.968692477325, 0.971654628021, 0.974079838834, 0.976237478008, 0.978466192909, 0.980687087441, 0.983108513663, 0.985172089232, 0.987865648168, 0.990607337903, 0.993466132139, 0.996502037444, 0.999093888752,
0.791792877213, 0.87257820722, 0.910251232828, 0.931499697455, 0.943952143348, 0.951516025861, 0.957952121781, 0.962339434864, 0.966062146417, 0.968616514774, 0.971381676811, 0.973583176861, 0.976055057192, 0.979084758838, 0.98170749589, 0.984744675444, 0.987623582667, 0.990827225403, 0.994500953634, 0.998296977796,
0.77155359043, 0.869021329171, 0.916734846729, 0.939085193953, 0.951543198728, 0.959581932049, 0.965130869692, 0.970274846734, 0.973566829975, 0.976120453022, 0.978316404127, 0.980331703541, 0.982562198606, 0.984891459627, 0.986832139561, 0.98925186408, 0.991631382422, 0.993799878089, 0.996076708336, 0.998724170375,
0.834596697018, 0.910219914842, 0.938953487334, 0.954867094446, 0.965021373662, 0.971626167075, 0.975397860726, 0.978190765304, 0.980255255374, 0.982316415294, 0.983438406514, 0.985042910404, 0.986538091767, 0.988296632234, 0.989480478459, 0.991169147712, 0.992838525537, 0.994484549662, 0.996349006533, 0.998796129593,
0.789379799628, 0.883671741466, 0.918305025089, 0.937438007277, 0.948741402993, 0.955676286943, 0.961312308858, 0.965026524442, 0.96851536141, 0.971303982655, 0.973488277598, 0.975794458236, 0.978215624859, 0.980917182109, 0.983425193859, 0.986089845307, 0.988720428405, 0.991544643936, 0.994909928794, 0.998959191331,
0.798779556224, 0.879522751508, 0.920868515488, 0.940119021497, 0.951843311114, 0.959017895362, 0.965208194261, 0.969082003508, 0.97221027879, 0.974691498988, 0.976855260305, 0.978953094688, 0.981167317793, 0.983437790312, 0.985786929537, 0.988162887114, 0.990644384686, 0.993026404153, 0.995507443715, 0.998621645017,
0.879263930866, 0.922020986557, 0.946493775014, 0.958297320378, 0.964859488123, 0.970312211669, 0.97424567735, 0.976421680027, 0.978413032059, 0.979756913163, 0.981009201268, 0.982435429622, 0.98394030031, 0.98552902602, 0.987216102866, 0.988920029566, 0.990524992279, 0.992390079391, 0.994639444871, 0.998357673395,
0.827950174866, 0.894831372155, 0.9313264476, 0.94757593276, 0.956105682938, 0.962399480035, 0.967538440479, 0.971214741258, 0.973124877585, 0.976048275534, 0.977788559952, 0.979504668032, 0.981543513231, 0.983649526565, 0.985455734324, 0.987721369114, 0.990020663907, 0.992286543302, 0.994788654436, 0.998404057223,
0.754610909207, 0.860747233451, 0.899736485589, 0.925743580664, 0.939166571882, 0.948339482047, 0.954553566064, 0.958885994659, 0.962909621012, 0.966353236056, 0.969046685166, 0.971977721769, 0.974626509781, 0.977835303452, 0.980470725918, 0.983736275578, 0.986964243248, 0.990258297647, 0.993951518828, 0.998556710426,
0.784958874194, 0.874722011293, 0.915389887756, 0.934343596754, 0.948237654386, 0.95672974228, 0.963437883085, 0.96720024908, 0.970772553663, 0.973495160679, 0.975921421699, 0.978227890716, 0.980404390223, 0.983010440205, 0.985215187951, 0.987733337063, 0.98989913274, 0.992273798878, 0.994922548059, 0.99824153924,
0.822296667468, 0.89834966427, 0.933957434279, 0.951462994733, 0.95978880675, 0.965804629849, 0.970633032468, 0.973789346233, 0.976510124286, 0.978522727364, 0.980401981947, 0.982057448136, 0.984052101154, 0.986046587711, 0.98771061484, 0.989891431917, 0.992010784627, 0.994204644975, 0.996543052165, 0.99909102487,
================================================
FILE: test/lstm_ssim.csv
================================================
0.769036023469, 0.88985832751, 0.927751848949, 0.944983151872, 0.95575157289, 0.963951705505, 0.97091182084, 0.975998458878, 0.979411580032, 0.981903538902, 0.984059023874, 0.985847690249, 0.9873427695, 0.988557280396, 0.989808215992, 0.990955339337,
0.805634389938, 0.885551865434, 0.926222844478, 0.945195184339, 0.956706397486, 0.964635434939, 0.970561396161, 0.975371321751, 0.979008896093, 0.981623105097, 0.983889966323, 0.985693922222, 0.987117323848, 0.988116453261, 0.989126652794, 0.990035372754,
0.886382589545, 0.944226634293, 0.966099553101, 0.97574891786, 0.981699305186, 0.985303117183, 0.988158479953, 0.990353107687, 0.991742329762, 0.992736636125, 0.993586103345, 0.994247922172, 0.994823658832, 0.99523818175, 0.99566695018, 0.996024457998,
0.844943736432, 0.912763739137, 0.941348917006, 0.955485600323, 0.964352299005, 0.971086350909, 0.976031220021, 0.980033889346, 0.983037018013, 0.985100156288, 0.986891083059, 0.98828168163, 0.989371458473, 0.990204501149, 0.991008062343, 0.991775083665,
0.807296336647, 0.901032376198, 0.936630910374, 0.953867965935, 0.965072366127, 0.973137076713, 0.979042150279, 0.983166487096, 0.985988824175, 0.987817180773, 0.989464580939, 0.990820235888, 0.991892422931, 0.992712253547, 0.993501561405, 0.994148098694,
0.785345752053, 0.884711958245, 0.92776368098, 0.948715836445, 0.96019570665, 0.96605401155, 0.972573249337, 0.97720752253, 0.980241043367, 0.982484654905, 0.984849828368, 0.986701720731, 0.988188381493, 0.98922008024, 0.990398367023, 0.991365949801,
0.879784256538, 0.951585216892, 0.972381608615, 0.981022083203, 0.985810271759, 0.988693576841, 0.990952401157, 0.992588896661, 0.993784304761, 0.994579427454, 0.995314920728, 0.99586908212, 0.996291444393, 0.996577798766, 0.996861139444, 0.997104778181,
0.815477354243, 0.906116356367, 0.944359558351, 0.96028214433, 0.969452971301, 0.975073789021, 0.979816675642, 0.983303430092, 0.985631084073, 0.987351580806, 0.988992309204, 0.99024746499, 0.991323941253, 0.992077409918, 0.992875703722, 0.993551435751,
0.898231591406, 0.952767120478, 0.971324990566, 0.979139320561, 0.983235587546, 0.985925513104, 0.98813798392, 0.989849456611, 0.991087155229, 0.991997953486, 0.992918185114, 0.993614367759, 0.994164453638, 0.994560726328, 0.994937149117, 0.995284527261,
0.8749003714, 0.942176320572, 0.966895893079, 0.977493170796, 0.982875921764, 0.986274282021, 0.98850906544, 0.990252593185, 0.991524822651, 0.992438630258, 0.993323088693, 0.994034014192, 0.994607996745, 0.994996589457, 0.995366557151, 0.99569469237,
0.831185109757, 0.908127863885, 0.94129870614, 0.957395739465, 0.966623340753, 0.971948863236, 0.977120326656, 0.980795107721, 0.983424173894, 0.9853323917, 0.987235341184, 0.988695696085, 0.989908689415, 0.990740343679, 0.991648831536, 0.992402349259,
0.867455122658, 0.929286374978, 0.956373268572, 0.968857918522, 0.975855487208, 0.979544787408, 0.983269284681, 0.985944004, 0.987814957273, 0.989295675642, 0.990662162931, 0.991705980646, 0.992601345026, 0.99321774356, 0.993851078984, 0.994396313967,
0.741714560572, 0.852678337269, 0.89776940045, 0.921657297609, 0.937268311557, 0.948777493887, 0.958587922105, 0.966219484116, 0.970930915102, 0.973981680428, 0.976954171415, 0.979158752294, 0.981113945529, 0.982622981303, 0.984346679365, 0.985939206912,
0.806023887853, 0.898208741526, 0.934023153024, 0.951794401119, 0.962675974888, 0.969812260388, 0.975737969034, 0.980133270486, 0.983291193208, 0.985354416391, 0.98719947195, 0.98876036237, 0.990016818176, 0.990903992042, 0.991817776314, 0.992603235282,
0.885580139856, 0.938517703814, 0.959134612721, 0.968447444909, 0.974006513851, 0.978175881773, 0.981340248165, 0.983964849591, 0.985997455292, 0.987552480984, 0.988913071725, 0.989989662136, 0.990858978786, 0.991507864087, 0.992101253603, 0.992706639095,
0.843116295553, 0.918865830119, 0.948986302851, 0.963840566631, 0.971863394331, 0.976064780745, 0.980732356364, 0.98399885847, 0.986220002076, 0.987800243738, 0.989471739682, 0.990764010755, 0.991805901688, 0.992578822333, 0.993415281055, 0.994092268046,
0.88346061529, 0.946000867977, 0.96760658804, 0.976501600498, 0.981740918991, 0.985389590011, 0.988097110038, 0.990101000036, 0.991452711903, 0.992372729231, 0.993280593371, 0.993979273128, 0.994576920437, 0.994992603394, 0.99541165183, 0.995793186385,
0.801573520738, 0.894556913607, 0.930557148943, 0.947742213628, 0.958336710515, 0.966270887606, 0.972043805126, 0.976686743156, 0.979997380949, 0.982263084747, 0.984371624652, 0.986043574351, 0.987418653777, 0.988435017399, 0.98943820625, 0.990384899709,
0.824336788203, 0.914307147717, 0.947162910211, 0.961113167281, 0.969251988341, 0.974708549133, 0.979348243062, 0.982888751196, 0.985039998783, 0.986690243229, 0.988265858776, 0.989457451692, 0.990425556398, 0.991142353579, 0.991913402107, 0.992646966198,
0.910171904704, 0.952375683078, 0.967737628914, 0.975113825208, 0.979525760009, 0.982903358103, 0.985497414947, 0.987532990507, 0.989022084385, 0.990094546854, 0.991046427233, 0.991786419928, 0.992476722408, 0.992941886737, 0.9934275669, 0.993905147004,
0.864225726494, 0.929980029382, 0.955778654113, 0.96740495639, 0.974178985114, 0.978346408898, 0.982328012772, 0.985109106162, 0.987107132775, 0.988477856206, 0.989870184226, 0.990930157844, 0.991839380567, 0.992475550555, 0.993157023569, 0.993742594496,
0.803770468068, 0.894709100591, 0.93001725197, 0.946972008686, 0.957454439962, 0.965257644855, 0.971242168162, 0.97578798359, 0.979145612293, 0.981419222339, 0.983528566216, 0.985134373586, 0.986496893548, 0.98753676895, 0.988504824958, 0.989410139829,
0.887667125483, 0.950280631321, 0.971841626657, 0.980425626652, 0.984775140336, 0.987822408265, 0.989908020125, 0.991464951332, 0.992696046157, 0.993535144861, 0.99432064875, 0.994901414575, 0.995374559063, 0.995705134049, 0.996001978174, 0.996274517893,
0.828518548366, 0.907992823067, 0.940938577175, 0.956594052005, 0.966157200834, 0.972833437098, 0.977987634362, 0.981907039284, 0.984516045128, 0.986238836359, 0.987891164633, 0.989148044611, 0.990264907045, 0.991071889081, 0.991900734924, 0.992623228933,
================================================
FILE: train.py
================================================
import time
import os
import argparse
import numpy as np
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as LS
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.utils.data as data
from torchvision import transforms
parser = argparse.ArgumentParser()
parser.add_argument(
'--batch-size', '-N', type=int, default=32, help='batch size')
parser.add_argument(
'--train', '-f', required=True, type=str, help='folder of training images')
parser.add_argument(
'--max-epochs', '-e', type=int, default=200, help='max epochs')
parser.add_argument('--lr', type=float, default=0.0005, help='learning rate')
# parser.add_argument('--cuda', '-g', action='store_true', help='enables cuda')
parser.add_argument(
'--iterations', type=int, default=16, help='unroll iterations')
parser.add_argument('--checkpoint', type=int, help='epoch to resume from')
args = parser.parse_args()
## load 32x32 patches from images
import dataset
train_transform = transforms.Compose([
transforms.RandomCrop((32, 32)),
transforms.ToTensor(),
])
train_set = dataset.ImageFolder(root=args.train, transform=train_transform)
train_loader = data.DataLoader(
dataset=train_set, batch_size=args.batch_size, shuffle=True, num_workers=1)
print('total images: {}; total batches: {}'.format(
len(train_set), len(train_loader)))
## load networks on GPU
import network
encoder = network.EncoderCell().cuda()
binarizer = network.Binarizer().cuda()
decoder = network.DecoderCell().cuda()
solver = optim.Adam(
[
{
'params': encoder.parameters()
},
{
'params': binarizer.parameters()
},
{
'params': decoder.parameters()
},
],
lr=args.lr)
def resume(epoch=None):
if epoch is None:
s = 'iter'
epoch = 0
else:
s = 'epoch'
encoder.load_state_dict(
torch.load('checkpoint/encoder_{}_{:08d}.pth'.format(s, epoch)))
binarizer.load_state_dict(
torch.load('checkpoint/binarizer_{}_{:08d}.pth'.format(s, epoch)))
decoder.load_state_dict(
torch.load('checkpoint/decoder_{}_{:08d}.pth'.format(s, epoch)))
def save(index, epoch=True):
if not os.path.exists('checkpoint'):
os.mkdir('checkpoint')
if epoch:
s = 'epoch'
else:
s = 'iter'
torch.save(encoder.state_dict(), 'checkpoint/encoder_{}_{:08d}.pth'.format(
s, index))
torch.save(binarizer.state_dict(),
'checkpoint/binarizer_{}_{:08d}.pth'.format(s, index))
torch.save(decoder.state_dict(), 'checkpoint/decoder_{}_{:08d}.pth'.format(
s, index))
# resume()
scheduler = LS.MultiStepLR(solver, milestones=[3, 10, 20, 50, 100], gamma=0.5)
last_epoch = 0
if args.checkpoint:
resume(args.checkpoint)
last_epoch = args.checkpoint
scheduler.last_epoch = last_epoch - 1
for epoch in range(last_epoch + 1, args.max_epochs + 1):
scheduler.step()
for batch, data in enumerate(train_loader):
batch_t0 = time.time()
## init lstm state
encoder_h_1 = (Variable(torch.zeros(data.size(0), 256, 8, 8).cuda()),
Variable(torch.zeros(data.size(0), 256, 8, 8).cuda()))
encoder_h_2 = (Variable(torch.zeros(data.size(0), 512, 4, 4).cuda()),
Variable(torch.zeros(data.size(0), 512, 4, 4).cuda()))
encoder_h_3 = (Variable(torch.zeros(data.size(0), 512, 2, 2).cuda()),
Variable(torch.zeros(data.size(0), 512, 2, 2).cuda()))
decoder_h_1 = (Variable(torch.zeros(data.size(0), 512, 2, 2).cuda()),
Variable(torch.zeros(data.size(0), 512, 2, 2).cuda()))
decoder_h_2 = (Variable(torch.zeros(data.size(0), 512, 4, 4).cuda()),
Variable(torch.zeros(data.size(0), 512, 4, 4).cuda()))
decoder_h_3 = (Variable(torch.zeros(data.size(0), 256, 8, 8).cuda()),
Variable(torch.zeros(data.size(0), 256, 8, 8).cuda()))
decoder_h_4 = (Variable(torch.zeros(data.size(0), 128, 16, 16).cuda()),
Variable(torch.zeros(data.size(0), 128, 16, 16).cuda()))
patches = Variable(data.cuda())
solver.zero_grad()
losses = []
res = patches - 0.5
bp_t0 = time.time()
for _ in range(args.iterations):
encoded, encoder_h_1, encoder_h_2, encoder_h_3 = encoder(
res, encoder_h_1, encoder_h_2, encoder_h_3)
codes = binarizer(encoded)
output, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4 = decoder(
codes, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4)
res = res - output
losses.append(res.abs().mean())
bp_t1 = time.time()
loss = sum(losses) / args.iterations
loss.backward()
solver.step()
batch_t1 = time.time()
print(
'[TRAIN] Epoch[{}]({}/{}); Loss: {:.6f}; Backpropagation: {:.4f} sec; Batch: {:.4f} sec'.
format(epoch, batch + 1,
len(train_loader), loss.data[0], bp_t1 - bp_t0, batch_t1 -
batch_t0))
print(('{:.4f} ' * args.iterations +
'\n').format(* [l.data[0] for l in losses]))
index = (epoch - 1) * len(train_loader) + batch
## save checkpoint every 500 training steps
if index % 500 == 0:
save(0, False)
save(epoch)