[
  {
    "path": ".gitignore",
    "content": ".DS_Store\n.idea"
  },
  {
    "path": "README.md",
    "content": "# FSRCNN\n\nThis repository is an implementation of the [\"Accelerating the Super-Resolution Convolutional Neural Network\"](https://arxiv.org/abs/1608.00367).\n\n<center><img src=\"./thumbnails/fig1.png\"></center>\n\n## Differences from the original\n\n- Added the zero-padding\n- Used the Adam instead of the SGD\n\n## Requirements\n\n- PyTorch 1.0.0\n- Numpy 1.15.4\n- Pillow 5.4.1\n- h5py 2.8.0\n- tqdm 4.30.0\n\n## Train\n\nThe 91-image and Set5 datasets converted to HDF5 can be downloaded from the links below.\n\n| Dataset | Scale | Type | Link |\n|---------|-------|------|------|\n| 91-image | 2 | Train | [Download](https://www.dropbox.com/s/01z95js39kgw1qv/91-image_x2.h5?dl=0) |\n| 91-image | 3 | Train | [Download](https://www.dropbox.com/s/qx4swlt2j7u4twr/91-image_x3.h5?dl=0) |\n| 91-image | 4 | Train | [Download](https://www.dropbox.com/s/vobvi2nlymtvezb/91-image_x4.h5?dl=0) |\n| Set5 | 2 | Eval | [Download](https://www.dropbox.com/s/4kzqmtqzzo29l1x/Set5_x2.h5?dl=0) |\n| Set5 | 3 | Eval | [Download](https://www.dropbox.com/s/kyhbhyc5a0qcgnp/Set5_x3.h5?dl=0) |\n| Set5 | 4 | Eval | [Download](https://www.dropbox.com/s/ihtv1acd48cof14/Set5_x4.h5?dl=0) |\n\nOtherwise, you can use `prepare.py` to create a custom dataset.\n\n```bash\npython train.py --train-file \"BLAH_BLAH/91-image_x3.h5\" \\\n                --eval-file \"BLAH_BLAH/Set5_x3.h5\" \\\n                --outputs-dir \"BLAH_BLAH/outputs\" \\\n                --scale 3 \\\n                --lr 1e-3 \\\n                --batch-size 16 \\\n                --num-epochs 20 \\\n                --num-workers 8 \\\n                --seed 123                \n```\n\n## Test\n\nPre-trained weights can be downloaded from the links below.\n\n| Model | Scale | Link |\n|-------|-------|------|\n| FSRCNN(56,12,4) | 2 | [Download](https://www.dropbox.com/s/1k3dker6g7hz76s/fsrcnn_x2.pth?dl=0) |\n| FSRCNN(56,12,4) | 3 | [Download](https://www.dropbox.com/s/pm1ed2nyboulz5z/fsrcnn_x3.pth?dl=0) |\n| FSRCNN(56,12,4) 
| 4 | [Download](https://www.dropbox.com/s/vsvumpopupdpmmu/fsrcnn_x4.pth?dl=0) |\n\nThe results are stored in the same path as the query image.\n\n```bash\npython test.py --weights-file \"BLAH_BLAH/fsrcnn_x3.pth\" \\\n               --image-file \"data/butterfly_GT.bmp\" \\\n               --scale 3\n```\n\n## Results\n\nPSNR was calculated on the Y channel.\n\n### Set5\n\n| Eval. Mat | Scale | Paper | Ours (91-image) |\n|-----------|-------|-------|-----------------|\n| PSNR | 2 | 36.94 | 37.12 |\n| PSNR | 3 | 33.06 | 33.22 |\n| PSNR | 4 | 30.55 | 30.50 |\n\n<table>\n    <tr>\n        <td><center>Original</center></td>\n        <td><center>BICUBIC x3</center></td>\n        <td><center>FSRCNN x3 (34.66 dB)</center></td>\n    </tr>\n    <tr>\n    \t<td>\n    \t\t<center><img src=\"./data/lenna.bmp\"></center>\n    \t</td>\n    \t<td>\n    \t\t<center><img src=\"./data/lenna_bicubic_x3.bmp\"></center>\n    \t</td>\n    \t<td>\n    \t\t<center><img src=\"./data/lenna_fsrcnn_x3.bmp\"></center>\n    \t</td>\n    </tr>\n    <tr>\n        <td><center>Original</center></td>\n        <td><center>BICUBIC x3</center></td>\n        <td><center>FSRCNN x3 (28.55 dB)</center></td>\n    </tr>\n    <tr>\n    \t<td>\n    \t\t<center><img src=\"./data/butterfly_GT.bmp\"></center>\n    \t</td>\n    \t<td>\n    \t\t<center><img src=\"./data/butterfly_GT_bicubic_x3.bmp\"></center>\n    \t</td>\n    \t<td>\n    \t\t<center><img src=\"./data/butterfly_GT_fsrcnn_x3.bmp\"></center>\n    \t</td>\n    </tr>\n</table>\n"
  },
  {
    "path": "datasets.py",
    "content": "import h5py\nimport numpy as np\nfrom torch.utils.data import Dataset\n\n\nclass TrainDataset(Dataset):\n    def __init__(self, h5_file):\n        super(TrainDataset, self).__init__()\n        self.h5_file = h5_file\n\n    def __getitem__(self, idx):\n        with h5py.File(self.h5_file, 'r') as f:\n            return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)\n\n    def __len__(self):\n        with h5py.File(self.h5_file, 'r') as f:\n            return len(f['lr'])\n\n\nclass EvalDataset(Dataset):\n    def __init__(self, h5_file):\n        super(EvalDataset, self).__init__()\n        self.h5_file = h5_file\n\n    def __getitem__(self, idx):\n        with h5py.File(self.h5_file, 'r') as f:\n            return np.expand_dims(f['lr'][str(idx)][:, :] / 255., 0), np.expand_dims(f['hr'][str(idx)][:, :] / 255., 0)\n\n    def __len__(self):\n        with h5py.File(self.h5_file, 'r') as f:\n            return len(f['lr'])\n"
  },
  {
    "path": "models.py",
    "content": "import math\nfrom torch import nn\n\n\nclass FSRCNN(nn.Module):\n    def __init__(self, scale_factor, num_channels=1, d=56, s=12, m=4):\n        super(FSRCNN, self).__init__()\n        self.first_part = nn.Sequential(\n            nn.Conv2d(num_channels, d, kernel_size=5, padding=5//2),\n            nn.PReLU(d)\n        )\n        self.mid_part = [nn.Conv2d(d, s, kernel_size=1), nn.PReLU(s)]\n        for _ in range(m):\n            self.mid_part.extend([nn.Conv2d(s, s, kernel_size=3, padding=3//2), nn.PReLU(s)])\n        self.mid_part.extend([nn.Conv2d(s, d, kernel_size=1), nn.PReLU(d)])\n        self.mid_part = nn.Sequential(*self.mid_part)\n        self.last_part = nn.ConvTranspose2d(d, num_channels, kernel_size=9, stride=scale_factor, padding=9//2,\n                                            output_padding=scale_factor-1)\n\n        self._initialize_weights()\n\n    def _initialize_weights(self):\n        for m in self.first_part:\n            if isinstance(m, nn.Conv2d):\n                nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))\n                nn.init.zeros_(m.bias.data)\n        for m in self.mid_part:\n            if isinstance(m, nn.Conv2d):\n                nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))\n                nn.init.zeros_(m.bias.data)\n        nn.init.normal_(self.last_part.weight.data, mean=0.0, std=0.001)\n        nn.init.zeros_(self.last_part.bias.data)\n\n    def forward(self, x):\n        x = self.first_part(x)\n        x = self.mid_part(x)\n        x = self.last_part(x)\n        return x\n\n\n"
  },
  {
    "path": "prepare.py",
    "content": "import argparse\nimport glob\nimport h5py\nimport numpy as np\nimport PIL.Image as pil_image\nfrom utils import calc_patch_size, convert_rgb_to_y\n\n\n@calc_patch_size\ndef train(args):\n    h5_file = h5py.File(args.output_path, 'w')\n\n    lr_patches = []\n    hr_patches = []\n\n    for image_path in sorted(glob.glob('{}/*'.format(args.images_dir))):\n        hr = pil_image.open(image_path).convert('RGB')\n        hr_images = []\n\n        if args.with_aug:\n            for s in [1.0, 0.9, 0.8, 0.7, 0.6]:\n                for r in [0, 90, 180, 270]:\n                    tmp = hr.resize((int(hr.width * s), int(hr.height * s)), resample=pil_image.BICUBIC)\n                    tmp = tmp.rotate(r, expand=True)\n                    hr_images.append(tmp)\n        else:\n            hr_images.append(hr)\n\n        for hr in hr_images:\n            hr_width = (hr.width // args.scale) * args.scale\n            hr_height = (hr.height // args.scale) * args.scale\n            hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)\n            lr = hr.resize((hr.width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)\n            hr = np.array(hr).astype(np.float32)\n            lr = np.array(lr).astype(np.float32)\n            hr = convert_rgb_to_y(hr)\n            lr = convert_rgb_to_y(lr)\n\n            for i in range(0, lr.shape[0] - args.patch_size + 1, args.scale):\n                for j in range(0, lr.shape[1] - args.patch_size + 1, args.scale):\n                    lr_patches.append(lr[i:i+args.patch_size, j:j+args.patch_size])\n                    hr_patches.append(hr[i*args.scale:i*args.scale+args.patch_size*args.scale, j*args.scale:j*args.scale+args.patch_size*args.scale])\n\n    lr_patches = np.array(lr_patches)\n    hr_patches = np.array(hr_patches)\n\n    h5_file.create_dataset('lr', data=lr_patches)\n    h5_file.create_dataset('hr', data=hr_patches)\n\n    h5_file.close()\n\n\ndef eval(args):\n    h5_file = 
h5py.File(args.output_path, 'w')\n\n    lr_group = h5_file.create_group('lr')\n    hr_group = h5_file.create_group('hr')\n\n    for i, image_path in enumerate(sorted(glob.glob('{}/*'.format(args.images_dir)))):\n        hr = pil_image.open(image_path).convert('RGB')\n        hr_width = (hr.width // args.scale) * args.scale\n        hr_height = (hr.height // args.scale) * args.scale\n        hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)\n        lr = hr.resize((hr.width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)\n        hr = np.array(hr).astype(np.float32)\n        lr = np.array(lr).astype(np.float32)\n        hr = convert_rgb_to_y(hr)\n        lr = convert_rgb_to_y(lr)\n\n        lr_group.create_dataset(str(i), data=lr)\n        hr_group.create_dataset(str(i), data=hr)\n\n    h5_file.close()\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--images-dir', type=str, required=True)\n    parser.add_argument('--output-path', type=str, required=True)\n    parser.add_argument('--scale', type=int, default=2)\n    parser.add_argument('--with-aug', action='store_true')\n    parser.add_argument('--eval', action='store_true')\n    args = parser.parse_args()\n\n    if not args.eval:\n        train(args)\n    else:\n        eval(args)\n"
  },
  {
    "path": "test.py",
    "content": "import argparse\n\nimport torch\nimport torch.backends.cudnn as cudnn\nimport numpy as np\nimport PIL.Image as pil_image\n\nfrom models import FSRCNN\nfrom utils import convert_ycbcr_to_rgb, preprocess, calc_psnr\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--weights-file', type=str, required=True)\n    parser.add_argument('--image-file', type=str, required=True)\n    parser.add_argument('--scale', type=int, default=3)\n    args = parser.parse_args()\n\n    cudnn.benchmark = True\n    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n\n    model = FSRCNN(scale_factor=args.scale).to(device)\n\n    state_dict = model.state_dict()\n    for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():\n        if n in state_dict.keys():\n            state_dict[n].copy_(p)\n        else:\n            raise KeyError(n)\n\n    model.eval()\n\n    image = pil_image.open(args.image_file).convert('RGB')\n\n    image_width = (image.width // args.scale) * args.scale\n    image_height = (image.height // args.scale) * args.scale\n\n    hr = image.resize((image_width, image_height), resample=pil_image.BICUBIC)\n    lr = hr.resize((hr.width // args.scale, hr.height // args.scale), resample=pil_image.BICUBIC)\n    bicubic = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)\n    bicubic.save(args.image_file.replace('.', '_bicubic_x{}.'.format(args.scale)))\n\n    lr, _ = preprocess(lr, device)\n    hr, _ = preprocess(hr, device)\n    _, ycbcr = preprocess(bicubic, device)\n\n    with torch.no_grad():\n        preds = model(lr).clamp(0.0, 1.0)\n\n    psnr = calc_psnr(hr, preds)\n    print('PSNR: {:.2f}'.format(psnr))\n\n    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)\n\n    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])\n    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 
255.0).astype(np.uint8)\n    output = pil_image.fromarray(output)\n    output.save(args.image_file.replace('.', '_fsrcnn_x{}.'.format(args.scale)))\n"
  },
  {
    "path": "train.py",
    "content": "import argparse\nimport os\nimport copy\n\nimport torch\nfrom torch import nn\nimport torch.optim as optim\nimport torch.backends.cudnn as cudnn\nfrom torch.utils.data.dataloader import DataLoader\nfrom tqdm import tqdm\n\nfrom models import FSRCNN\nfrom datasets import TrainDataset, EvalDataset\nfrom utils import AverageMeter, calc_psnr\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--train-file', type=str, required=True)\n    parser.add_argument('--eval-file', type=str, required=True)\n    parser.add_argument('--outputs-dir', type=str, required=True)\n    parser.add_argument('--weights-file', type=str)\n    parser.add_argument('--scale', type=int, default=2)\n    parser.add_argument('--lr', type=float, default=1e-3)\n    parser.add_argument('--batch-size', type=int, default=16)\n    parser.add_argument('--num-epochs', type=int, default=20)\n    parser.add_argument('--num-workers', type=int, default=8)\n    parser.add_argument('--seed', type=int, default=123)\n    args = parser.parse_args()\n\n    args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))\n\n    if not os.path.exists(args.outputs_dir):\n        os.makedirs(args.outputs_dir)\n\n    cudnn.benchmark = True\n    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n\n    torch.manual_seed(args.seed)\n\n    model = FSRCNN(scale_factor=args.scale).to(device)\n    criterion = nn.MSELoss()\n    optimizer = optim.Adam([\n        {'params': model.first_part.parameters()},\n        {'params': model.mid_part.parameters()},\n        {'params': model.last_part.parameters(), 'lr': args.lr * 0.1}\n    ], lr=args.lr)\n\n    train_dataset = TrainDataset(args.train_file)\n    train_dataloader = DataLoader(dataset=train_dataset,\n                                  batch_size=args.batch_size,\n                                  shuffle=True,\n                                  num_workers=args.num_workers,\n             
                     pin_memory=True)\n    eval_dataset = EvalDataset(args.eval_file)\n    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)\n\n    best_weights = copy.deepcopy(model.state_dict())\n    best_epoch = 0\n    best_psnr = 0.0\n\n    for epoch in range(args.num_epochs):\n        model.train()\n        epoch_losses = AverageMeter()\n\n        with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size), ncols=80) as t:\n            t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1))\n\n            for data in train_dataloader:\n                inputs, labels = data\n\n                inputs = inputs.to(device)\n                labels = labels.to(device)\n\n                preds = model(inputs)\n\n                loss = criterion(preds, labels)\n\n                epoch_losses.update(loss.item(), len(inputs))\n\n                optimizer.zero_grad()\n                loss.backward()\n                optimizer.step()\n\n                t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))\n                t.update(len(inputs))\n\n        torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))\n\n        model.eval()\n        epoch_psnr = AverageMeter()\n\n        for data in eval_dataloader:\n            inputs, labels = data\n\n            inputs = inputs.to(device)\n            labels = labels.to(device)\n\n            with torch.no_grad():\n                preds = model(inputs).clamp(0.0, 1.0)\n\n            epoch_psnr.update(calc_psnr(preds, labels), len(inputs))\n\n        print('eval psnr: {:.2f}'.format(epoch_psnr.avg))\n\n        if epoch_psnr.avg > best_psnr:\n            best_epoch = epoch\n            best_psnr = epoch_psnr.avg\n            best_weights = copy.deepcopy(model.state_dict())\n\n    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))\n    torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))\n"
  },
  {
    "path": "utils.py",
    "content": "import torch\nimport numpy as np\n\n\ndef calc_patch_size(func):\n    def wrapper(args):\n        if args.scale == 2:\n            args.patch_size = 10\n        elif args.scale == 3:\n            args.patch_size = 7\n        elif args.scale == 4:\n            args.patch_size = 6\n        else:\n            raise Exception('Scale Error', args.scale)\n        return func(args)\n    return wrapper\n\n\ndef convert_rgb_to_y(img, dim_order='hwc'):\n    if dim_order == 'hwc':\n        return 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.\n    else:\n        return 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.\n\n\ndef convert_rgb_to_ycbcr(img, dim_order='hwc'):\n    if dim_order == 'hwc':\n        y = 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.\n        cb = 128. + (-37.945 * img[..., 0] - 74.494 * img[..., 1] + 112.439 * img[..., 2]) / 256.\n        cr = 128. + (112.439 * img[..., 0] - 94.154 * img[..., 1] - 18.285 * img[..., 2]) / 256.\n    else:\n        y = 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.\n        cb = 128. + (-37.945 * img[0] - 74.494 * img[1] + 112.439 * img[2]) / 256.\n        cr = 128. + (112.439 * img[0] - 94.154 * img[1] - 18.285 * img[2]) / 256.\n    return np.array([y, cb, cr]).transpose([1, 2, 0])\n\n\ndef convert_ycbcr_to_rgb(img, dim_order='hwc'):\n    if dim_order == 'hwc':\n        r = 298.082 * img[..., 0] / 256. + 408.583 * img[..., 2] / 256. - 222.921\n        g = 298.082 * img[..., 0] / 256. - 100.291 * img[..., 1] / 256. - 208.120 * img[..., 2] / 256. + 135.576\n        b = 298.082 * img[..., 0] / 256. + 516.412 * img[..., 1] / 256. - 276.836\n    else:\n        r = 298.082 * img[0] / 256. + 408.583 * img[2] / 256. - 222.921\n        g = 298.082 * img[0] / 256. - 100.291 * img[1] / 256. - 208.120 * img[2] / 256. + 135.576\n        b = 298.082 * img[0] / 256. + 516.412 * img[1] / 256. 
- 276.836\n    return np.array([r, g, b]).transpose([1, 2, 0])\n\n\ndef preprocess(img, device):\n    img = np.array(img).astype(np.float32)\n    ycbcr = convert_rgb_to_ycbcr(img)\n    x = ycbcr[..., 0]\n    x /= 255.\n    x = torch.from_numpy(x).to(device)\n    x = x.unsqueeze(0).unsqueeze(0)\n    return x, ycbcr\n\n\ndef calc_psnr(img1, img2):\n    return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))\n\n\nclass AverageMeter(object):\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n"
  }
]