Repository: yjn870/FSRCNN-pytorch Branch: master Commit: 3eaab297f4b7 Files: 8 Total size: 17.3 KB Directory structure: gitextract_64n4vj46/ ├── .gitignore ├── README.md ├── datasets.py ├── models.py ├── prepare.py ├── test.py ├── train.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ .DS_Store .idea ================================================ FILE: README.md ================================================ # FSRCNN This repository is implementation of the ["Accelerating the Super-Resolution Convolutional Neural Network"](https://arxiv.org/abs/1608.00367).
## Differences from the original - Added the zero-padding - Used the Adam instead of the SGD ## Requirements - PyTorch 1.0.0 - Numpy 1.15.4 - Pillow 5.4.1 - h5py 2.8.0 - tqdm 4.30.0 ## Train The 91-image, Set5 dataset converted to HDF5 can be downloaded from the links below. | Dataset | Scale | Type | Link | |---------|-------|------|------| | 91-image | 2 | Train | [Download](https://www.dropbox.com/s/01z95js39kgw1qv/91-image_x2.h5?dl=0) | | 91-image | 3 | Train | [Download](https://www.dropbox.com/s/qx4swlt2j7u4twr/91-image_x3.h5?dl=0) | | 91-image | 4 | Train | [Download](https://www.dropbox.com/s/vobvi2nlymtvezb/91-image_x4.h5?dl=0) | | Set5 | 2 | Eval | [Download](https://www.dropbox.com/s/4kzqmtqzzo29l1x/Set5_x2.h5?dl=0) | | Set5 | 3 | Eval | [Download](https://www.dropbox.com/s/kyhbhyc5a0qcgnp/Set5_x3.h5?dl=0) | | Set5 | 4 | Eval | [Download](https://www.dropbox.com/s/ihtv1acd48cof14/Set5_x4.h5?dl=0) | Otherwise, you can use `prepare.py` to create custom dataset. ```bash python train.py --train-file "BLAH_BLAH/91-image_x3.h5" \ --eval-file "BLAH_BLAH/Set5_x3.h5" \ --outputs-dir "BLAH_BLAH/outputs" \ --scale 3 \ --lr 1e-3 \ --batch-size 16 \ --num-epochs 20 \ --num-workers 8 \ --seed 123 ``` ## Test Pre-trained weights can be downloaded from the links below. | Model | Scale | Link | |-------|-------|------| | FSRCNN(56,12,4) | 2 | [Download](https://www.dropbox.com/s/1k3dker6g7hz76s/fsrcnn_x2.pth?dl=0) | | FSRCNN(56,12,4) | 3 | [Download](https://www.dropbox.com/s/pm1ed2nyboulz5z/fsrcnn_x3.pth?dl=0) | | FSRCNN(56,12,4) | 4 | [Download](https://www.dropbox.com/s/vsvumpopupdpmmu/fsrcnn_x4.pth?dl=0) | The results are stored in the same path as the query image. ```bash python test.py --weights-file "BLAH_BLAH/fsrcnn_x3.pth" \ --image-file "data/butterfly_GT.bmp" \ --scale 3 ``` ## Results PSNR was calculated on the Y channel. ### Set5 | Eval. 
Mat | Scale | Paper | Ours (91-image) | |-----------|-------|-------|-----------------| | PSNR | 2 | 36.94 | 37.12 | | PSNR | 3 | 33.06 | 33.22 | | PSNR | 4 | 30.55 | 30.50 |
Original
BICUBIC x3
FSRCNN x3 (34.66 dB)
Original
BICUBIC x3
FSRCNN x3 (28.55 dB)
# ================================================
# FILE: datasets.py
# ================================================
import h5py
import numpy as np
from torch.utils.data import Dataset


class TrainDataset(Dataset):
    """Training patches stored as two parallel HDF5 datasets ('lr', 'hr').

    Each item is a pair of (1, H, W) float arrays scaled to [0, 1].
    """

    def __init__(self, h5_file):
        super(TrainDataset, self).__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        # Re-open the file on every access so the dataset works with
        # DataLoader(num_workers > 0); h5py handles are not fork-safe.
        with h5py.File(self.h5_file, 'r') as f:
            return (np.expand_dims(f['lr'][idx] / 255., 0),
                    np.expand_dims(f['hr'][idx] / 255., 0))

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])


class EvalDataset(Dataset):
    """Evaluation images stored as HDF5 groups ('lr', 'hr') keyed by string index.

    Images have varying sizes, hence one dataset per image rather than a
    single stacked array.
    """

    def __init__(self, h5_file):
        super(EvalDataset, self).__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        with h5py.File(self.h5_file, 'r') as f:
            return (np.expand_dims(f['lr'][str(idx)][:, :] / 255., 0),
                    np.expand_dims(f['hr'][str(idx)][:, :] / 255., 0))

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])


# ================================================
# FILE: models.py
# ================================================
import math

from torch import nn


class FSRCNN(nn.Module):
    """FSRCNN super-resolution network (Dong et al., 2016).

    Structure: feature extraction (5x5 conv) -> shrinking (1x1) ->
    ``m`` mapping layers (3x3) -> expanding (1x1) -> deconvolution upscale.

    Args:
        scale_factor: upscaling ratio produced by the transposed convolution.
        num_channels: input/output channels (1 = luminance only).
        d: feature dimension, s: shrunk dimension, m: number of mapping layers.
    """

    def __init__(self, scale_factor, num_channels=1, d=56, s=12, m=4):
        super(FSRCNN, self).__init__()
        self.first_part = nn.Sequential(
            nn.Conv2d(num_channels, d, kernel_size=5, padding=5 // 2),
            nn.PReLU(d)
        )
        self.mid_part = [nn.Conv2d(d, s, kernel_size=1), nn.PReLU(s)]
        for _ in range(m):
            self.mid_part.extend([nn.Conv2d(s, s, kernel_size=3, padding=3 // 2), nn.PReLU(s)])
        self.mid_part.extend([nn.Conv2d(s, d, kernel_size=1), nn.PReLU(d)])
        self.mid_part = nn.Sequential(*self.mid_part)
        # output_padding = scale - 1 makes the output exactly scale x input size.
        self.last_part = nn.ConvTranspose2d(d, num_channels, kernel_size=9,
                                            stride=scale_factor, padding=9 // 2,
                                            output_padding=scale_factor - 1)
        self._initialize_weights()

    def _initialize_weights(self):
        # He-style init scaled by fan-out (out_channels * kernel area) for the
        # conv layers; near-zero Gaussian for the deconvolution, as in the paper.
        for layer in self.first_part:
            if isinstance(layer, nn.Conv2d):
                nn.init.normal_(layer.weight.data, mean=0.0,
                                std=math.sqrt(2 / (layer.out_channels * layer.weight.data[0][0].numel())))
                nn.init.zeros_(layer.bias.data)
        for layer in self.mid_part:
            if isinstance(layer, nn.Conv2d):
                nn.init.normal_(layer.weight.data, mean=0.0,
                                std=math.sqrt(2 / (layer.out_channels * layer.weight.data[0][0].numel())))
                nn.init.zeros_(layer.bias.data)
        nn.init.normal_(self.last_part.weight.data, mean=0.0, std=0.001)
        nn.init.zeros_(self.last_part.bias.data)

    def forward(self, x):
        x = self.first_part(x)
        x = self.mid_part(x)
        x = self.last_part(x)
        return x


# ================================================
# FILE: prepare.py
# ================================================
import argparse
import glob

import h5py
import numpy as np
import PIL.Image as pil_image

from utils import calc_patch_size, convert_rgb_to_y


@calc_patch_size
def train(args):
    """Build the training HDF5 file: stacked LR/HR luminance patches.

    The decorator sets ``args.patch_size`` from ``args.scale``.
    """
    h5_file = h5py.File(args.output_path, 'w')

    lr_patches = []
    hr_patches = []

    for image_path in sorted(glob.glob('{}/*'.format(args.images_dir))):
        hr = pil_image.open(image_path).convert('RGB')
        hr_images = []

        if args.with_aug:
            # Augment with 5 scales x 4 rotations as in the original paper.
            for s in [1.0, 0.9, 0.8, 0.7, 0.6]:
                for r in [0, 90, 180, 270]:
                    tmp = hr.resize((int(hr.width * s), int(hr.height * s)),
                                    resample=pil_image.BICUBIC)
                    tmp = tmp.rotate(r, expand=True)
                    hr_images.append(tmp)
        else:
            hr_images.append(hr)

        for hr in hr_images:
            # Crop HR so its size is an exact multiple of the scale.
            hr_width = (hr.width // args.scale) * args.scale
            hr_height = (hr.height // args.scale) * args.scale
            hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
            lr = hr.resize((hr.width // args.scale, hr_height // args.scale),
                           resample=pil_image.BICUBIC)
            hr = np.array(hr).astype(np.float32)
            lr = np.array(lr).astype(np.float32)
            hr = convert_rgb_to_y(hr)
            lr = convert_rgb_to_y(lr)

            for i in range(0, lr.shape[0] - args.patch_size + 1, args.scale):
                for j in range(0, lr.shape[1] - args.patch_size + 1, args.scale):
                    lr_patches.append(lr[i:i + args.patch_size, j:j + args.patch_size])
                    hr_patches.append(hr[i * args.scale:i * args.scale + args.patch_size * args.scale,
                                         j * args.scale:j * args.scale + args.patch_size * args.scale])

    lr_patches = np.array(lr_patches)
    hr_patches = np.array(hr_patches)

    h5_file.create_dataset('lr', data=lr_patches)
    h5_file.create_dataset('hr', data=hr_patches)

    h5_file.close()


def evaluate(args):
    """Build the evaluation HDF5 file: one LR/HR luminance image per group key.

    Renamed from ``eval`` to avoid shadowing the builtin.
    """
    h5_file = h5py.File(args.output_path, 'w')

    lr_group = h5_file.create_group('lr')
    hr_group = h5_file.create_group('hr')

    for i, image_path in enumerate(sorted(glob.glob('{}/*'.format(args.images_dir)))):
        hr = pil_image.open(image_path).convert('RGB')
        hr_width = (hr.width // args.scale) * args.scale
        hr_height = (hr.height // args.scale) * args.scale
        hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
        lr = hr.resize((hr.width // args.scale, hr_height // args.scale),
                       resample=pil_image.BICUBIC)
        hr = np.array(hr).astype(np.float32)
        lr = np.array(lr).astype(np.float32)
        hr = convert_rgb_to_y(hr)
        lr = convert_rgb_to_y(lr)

        lr_group.create_dataset(str(i), data=lr)
        hr_group.create_dataset(str(i), data=hr)

    h5_file.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--images-dir', type=str, required=True)
    parser.add_argument('--output-path', type=str, required=True)
    parser.add_argument('--scale', type=int, default=2)
    parser.add_argument('--with-aug', action='store_true')
    parser.add_argument('--eval', action='store_true')
    args = parser.parse_args()

    if not args.eval:
        train(args)
    else:
        evaluate(args)


# ================================================
# FILE: test.py
# ================================================
import argparse
import os

import torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_image

from models import FSRCNN
from utils import convert_ycbcr_to_rgb, preprocess, calc_psnr

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights-file', type=str, required=True)
    parser.add_argument('--image-file', type=str, required=True)
    parser.add_argument('--scale', type=int, default=3)
    args = parser.parse_args()

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = FSRCNN(scale_factor=args.scale).to(device)

    # Copy the checkpoint tensors into the freshly built model, failing loudly
    # on any unexpected key.
    state_dict = model.state_dict()
    for n, p in torch.load(args.weights_file,
                           map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model.eval()

    image = pil_image.open(args.image_file).convert('RGB')

    image_width = (image.width // args.scale) * args.scale
    image_height = (image.height // args.scale) * args.scale

    hr = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
    lr = hr.resize((hr.width // args.scale, hr.height // args.scale),
                   resample=pil_image.BICUBIC)
    bicubic = lr.resize((lr.width * args.scale, lr.height * args.scale),
                        resample=pil_image.BICUBIC)

    # Build output paths with splitext instead of str.replace('.', ...), which
    # corrupted any input path containing extra dots (e.g. './data/a.b.bmp').
    base, ext = os.path.splitext(args.image_file)
    bicubic.save('{}_bicubic_x{}{}'.format(base, args.scale, ext))

    lr, _ = preprocess(lr, device)
    hr, _ = preprocess(hr, device)
    _, ycbcr = preprocess(bicubic, device)

    with torch.no_grad():
        preds = model(lr).clamp(0.0, 1.0)

    psnr = calc_psnr(hr, preds)
    print('PSNR: {:.2f}'.format(psnr))

    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    # Recombine the predicted Y channel with the bicubic Cb/Cr channels.
    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    output = pil_image.fromarray(output)
    output.save('{}_fsrcnn_x{}{}'.format(base, args.scale, ext))


# ================================================
# FILE: train.py
# ================================================
import argparse
import os
import copy

import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from models import FSRCNN
from datasets import TrainDataset, EvalDataset
from utils import AverageMeter, calc_psnr

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--train-file', type=str, required=True)
    parser.add_argument('--eval-file', type=str, required=True)
    parser.add_argument('--outputs-dir', type=str, required=True)
    parser.add_argument('--weights-file', type=str)
    parser.add_argument('--scale', type=int, default=2)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--num-epochs', type=int, default=20)
    parser.add_argument('--num-workers', type=int, default=8)
    parser.add_argument('--seed', type=int, default=123)
    args = parser.parse_args()

    args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))

    if not os.path.exists(args.outputs_dir):
        os.makedirs(args.outputs_dir)

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    torch.manual_seed(args.seed)

    model = FSRCNN(scale_factor=args.scale).to(device)
    criterion = nn.MSELoss()
    # Deconvolution layer gets a 10x smaller learning rate, as in the paper.
    optimizer = optim.Adam([
        {'params': model.first_part.parameters()},
        {'params': model.mid_part.parameters()},
        {'params': model.last_part.parameters(), 'lr': args.lr * 0.1}
    ], lr=args.lr)

    train_dataset = TrainDataset(args.train_file)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    eval_dataset = EvalDataset(args.eval_file)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

    best_weights = copy.deepcopy(model.state_dict())
    best_epoch = 0
    best_psnr = 0.0

    for epoch in range(args.num_epochs):
        model.train()
        epoch_losses = AverageMeter()

        # Progress bar total excludes the last partial batch on purpose.
        with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size),
                  ncols=80) as t:
            t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1))

            for data in train_dataloader:
                inputs, labels = data

                inputs = inputs.to(device)
                labels = labels.to(device)

                preds = model(inputs)

                loss = criterion(preds, labels)

                epoch_losses.update(loss.item(), len(inputs))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                t.update(len(inputs))

        torch.save(model.state_dict(),
                   os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))

        model.eval()
        epoch_psnr = AverageMeter()

        for data in eval_dataloader:
            inputs, labels = data

            inputs = inputs.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                preds = model(inputs).clamp(0.0, 1.0)

            epoch_psnr.update(calc_psnr(preds, labels), len(inputs))

        print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

        if epoch_psnr.avg > best_psnr:
            best_epoch = epoch
            best_psnr = epoch_psnr.avg
            best_weights = copy.deepcopy(model.state_dict())

    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
    torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))


# ================================================
# FILE: utils.py
# ================================================
import functools

import torch
import numpy as np


def calc_patch_size(func):
    """Decorator: set ``args.patch_size`` from ``args.scale`` before calling.

    Raises for unsupported scales (only 2, 3, 4 are defined).
    """
    @functools.wraps(func)
    def wrapper(args):
        if args.scale == 2:
            args.patch_size = 10
        elif args.scale == 3:
            args.patch_size = 7
        elif args.scale == 4:
            args.patch_size = 6
        else:
            raise Exception('Scale Error', args.scale)
        return func(args)
    return wrapper


def convert_rgb_to_y(img, dim_order='hwc'):
    """Return the ITU-R BT.601 luminance (Y) channel of an RGB image in [16, 235]."""
    if dim_order == 'hwc':
        return 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.
    else:
        return 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.


def convert_rgb_to_ycbcr(img, dim_order='hwc'):
    """Convert an RGB image to YCbCr (BT.601 studio-swing), returned as HWC."""
    if dim_order == 'hwc':
        y = 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.
        cb = 128. + (-37.945 * img[..., 0] - 74.494 * img[..., 1] + 112.439 * img[..., 2]) / 256.
        cr = 128. + (112.439 * img[..., 0] - 94.154 * img[..., 1] - 18.285 * img[..., 2]) / 256.
    else:
        y = 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.
        cb = 128. + (-37.945 * img[0] - 74.494 * img[1] + 112.439 * img[2]) / 256.
        cr = 128. + (112.439 * img[0] - 94.154 * img[1] - 18.285 * img[2]) / 256.
    return np.array([y, cb, cr]).transpose([1, 2, 0])


def convert_ycbcr_to_rgb(img, dim_order='hwc'):
    """Convert a YCbCr image (BT.601 studio-swing) back to RGB, returned as HWC."""
    if dim_order == 'hwc':
        r = 298.082 * img[..., 0] / 256. + 408.583 * img[..., 2] / 256. - 222.921
        g = 298.082 * img[..., 0] / 256. - 100.291 * img[..., 1] / 256. - 208.120 * img[..., 2] / 256. + 135.576
        b = 298.082 * img[..., 0] / 256. + 516.412 * img[..., 1] / 256. - 276.836
    else:
        r = 298.082 * img[0] / 256. + 408.583 * img[2] / 256. - 222.921
        g = 298.082 * img[0] / 256. - 100.291 * img[1] / 256. - 208.120 * img[2] / 256. + 135.576
        b = 298.082 * img[0] / 256. + 516.412 * img[1] / 256. - 276.836
    return np.array([r, g, b]).transpose([1, 2, 0])


def preprocess(img, device):
    """Return (normalized Y tensor of shape (1, 1, H, W), full YCbCr array)."""
    img = np.array(img).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(img)
    x = ycbcr[..., 0]
    x /= 255.
    x = torch.from_numpy(x).to(device)
    x = x.unsqueeze(0).unsqueeze(0)
    return x, ycbcr


def calc_psnr(img1, img2):
    """PSNR in dB for images scaled to [0, 1] (peak value 1.0)."""
    return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))


class AverageMeter(object):
    """Tracks the latest value plus a running (weighted) sum, count, and average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count