Repository: yjn870/FSRCNN-pytorch
Branch: master
Commit: 3eaab297f4b7
Files: 8
Total size: 17.3 KB
Directory structure:
gitextract_64n4vj46/
├── .gitignore
├── README.md
├── datasets.py
├── models.py
├── prepare.py
├── test.py
├── train.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
.DS_Store
.idea
================================================
FILE: README.md
================================================
# FSRCNN
This repository is an implementation of ["Accelerating the Super-Resolution Convolutional Neural Network"](https://arxiv.org/abs/1608.00367).
<center><img src="./thumbnails/fig1.png"></center>
## Differences from the original
- Added zero-padding to the convolutional layers
- Used Adam instead of SGD
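
The effect of the zero-padding can be checked with the standard convolution output-size arithmetic. The plain-Python sketch below is illustrative only. The helpers `conv2d_out` and `feature_map_size` are hypothetical and not part of this repository. The kernel sizes (5, 1, four 3s, 1) and the `padding=kernel_size//2` rule come from `models.py`.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a 2-D convolution with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


# Conv2d kernel sizes of FSRCNN(d, s, m=4) in models.py, in order:
# feature extraction 5x5, shrinking 1x1, four 3x3 mapping layers, expanding 1x1.
KERNELS = [5, 1, 3, 3, 3, 3, 1]


def feature_map_size(size: int, zero_pad: bool) -> int:
    """Height (or width) of the feature map that enters the deconvolution layer."""
    for k in KERNELS:
        size = conv2d_out(size, k, padding=k // 2 if zero_pad else 0)
    return size


print(feature_map_size(33, zero_pad=True))   # 33
print(feature_map_size(33, zero_pad=False))  # 21
```

With `padding=k//2`, every convolution preserves the feature-map size, so only the final deconvolution changes the resolution. Without padding, the 5x5 layer and the four 3x3 layers would trim 4 + 4 × 2 = 12 pixels per dimension.
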
## Requirements
- PyTorch 1.0.0
- Numpy 1.15.4
- Pillow 5.4.1
- h5py 2.8.0
- tqdm 4.30.0
## Train
The 91-image and Set5 datasets, converted to HDF5, can be downloaded from the links below.
| Dataset | Scale | Type | Link |
|---------|-------|------|------|
| 91-image | 2 | Train | [Download](https://www.dropbox.com/s/01z95js39kgw1qv/91-image_x2.h5?dl=0) |
| 91-image | 3 | Train | [Download](https://www.dropbox.com/s/qx4swlt2j7u4twr/91-image_x3.h5?dl=0) |
| 91-image | 4 | Train | [Download](https://www.dropbox.com/s/vobvi2nlymtvezb/91-image_x4.h5?dl=0) |
| Set5 | 2 | Eval | [Download](https://www.dropbox.com/s/4kzqmtqzzo29l1x/Set5_x2.h5?dl=0) |
| Set5 | 3 | Eval | [Download](https://www.dropbox.com/s/kyhbhyc5a0qcgnp/Set5_x3.h5?dl=0) |
| Set5 | 4 | Eval | [Download](https://www.dropbox.com/s/ihtv1acd48cof14/Set5_x4.h5?dl=0) |

Otherwise, you can use `prepare.py` to create a custom dataset.
```bash
python train.py --train-file "BLAH_BLAH/91-image_x3.h5" \
--eval-file "BLAH_BLAH/Set5_x3.h5" \
--outputs-dir "BLAH_BLAH/outputs" \
--scale 3 \
--lr 1e-3 \
--batch-size 16 \
--num-epochs 20 \
--num-workers 8 \
--seed 123
```
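
When you build a custom dataset, `train()` in `prepare.py` cuts every (optionally augmented) image into aligned LR/HR patch pairs:

- LR patches are `patch_size` pixels square. `calc_patch_size` sets this to 10, 7 or 6 for scales 2, 3 and 4.
- HR patches are `scale` times larger.
- The window advances by `scale` LR pixels.

The NumPy sketch below reproduces that loop on synthetic data. `extract_patches` is a hypothetical helper. Plain decimation (`hr[::3, ::3]`) stands in for the bicubic downscaling that `prepare.py` performs with Pillow.

```python
import numpy as np


def extract_patches(lr, hr, scale, patch_size):
    """Cut aligned LR/HR patch pairs with the same loops as prepare.py's train().

    LR patches are patch_size x patch_size, HR patches are `scale` times larger,
    and the LR window advances by `scale` pixels in each direction.
    """
    lr_patches, hr_patches = [], []
    for i in range(0, lr.shape[0] - patch_size + 1, scale):
        for j in range(0, lr.shape[1] - patch_size + 1, scale):
            lr_patches.append(lr[i:i + patch_size, j:j + patch_size])
            hr_patches.append(hr[i * scale:(i + patch_size) * scale,
                                 j * scale:(j + patch_size) * scale])
    return np.array(lr_patches), np.array(hr_patches)


# Synthetic Y-channel image for scale 3 (calc_patch_size picks patch_size=7).
rng = np.random.default_rng(0)
hr = rng.random((60, 45)).astype(np.float32)
lr = hr[::3, ::3]  # assumption: decimation instead of bicubic, giving 20 x 15
lr_p, hr_p = extract_patches(lr, hr, scale=3, patch_size=7)
print(lr_p.shape, hr_p.shape)  # (15, 7, 7) (15, 21, 21)
```

Each HR patch covers the same image region as its LR patch. This is what lets `TrainDataset` return the two as a matched input/target pair.
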
## Test
Pre-trained weights can be downloaded from the links below.
| Model | Scale | Link |
|-------|-------|------|
| FSRCNN(56,12,4) | 2 | [Download](https://www.dropbox.com/s/1k3dker6g7hz76s/fsrcnn_x2.pth?dl=0) |
| FSRCNN(56,12,4) | 3 | [Download](https://www.dropbox.com/s/pm1ed2nyboulz5z/fsrcnn_x3.pth?dl=0) |
| FSRCNN(56,12,4) | 4 | [Download](https://www.dropbox.com/s/vsvumpopupdpmmu/fsrcnn_x4.pth?dl=0) |
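
The name FSRCNN(56,12,4) gives the model's (d, s, m):

- d = 56 feature channels
- s = 12 shrunk channels
- m = 4 mapping layers

These match the defaults of `FSRCNN.__init__` in `models.py`. As a sanity check, the plain-Python sketch below counts the learnable parameters of that layout. `fsrcnn_params` is a hypothetical helper, not part of the repository. The scale factor only sets the deconvolution stride, so it does not change the count.

```python
def fsrcnn_params(d=56, s=12, m=4, num_channels=1, include_bias_prelu=True):
    """Count learnable parameters of the FSRCNN layout defined in models.py."""
    def conv(c_in, c_out, k):
        # Weight tensor c_in * c_out * k * k, plus one bias per output channel.
        return c_in * c_out * k * k + (c_out if include_bias_prelu else 0)

    def prelu(c):
        # nn.PReLU(c) learns one slope per channel.
        return c if include_bias_prelu else 0

    total = conv(num_channels, d, 5) + prelu(d)   # feature extraction
    total += conv(d, s, 1) + prelu(s)             # shrinking
    total += m * (conv(s, s, 3) + prelu(s))       # non-linear mapping
    total += conv(s, d, 1) + prelu(d)             # expanding
    total += conv(d, num_channels, 9)             # ConvTranspose2d (deconvolution)
    return total


print(fsrcnn_params())                          # 12809
print(fsrcnn_params(include_bias_prelu=False))  # 12464 (weights only)
```

With PyTorch installed, `sum(p.numel() for p in FSRCNN(scale_factor=3).parameters())` should give the same total as `fsrcnn_params()`, because it also counts biases and PReLU slopes.
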

The outputs are saved next to the input image, with `_bicubic_x{scale}` and `_fsrcnn_x{scale}` appended to its file name.
```bash
python test.py --weights-file "BLAH_BLAH/fsrcnn_x3.pth" \
--image-file "data/butterfly_GT.bmp" \
--scale 3
```
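
`test.py` first resizes the image so both sides are multiples of `--scale`, then downscales it by that factor. The deconvolution layer in `models.py` uses `kernel_size=9`, `padding=9//2` and `output_padding=scale_factor-1`. With these settings, its output is exactly `scale` times its input, so the prediction has the same size as the HR image passed to `calc_psnr`.

The plain-Python check below uses the output-size formula from the PyTorch `ConvTranspose2d` documentation. `conv_transpose2d_out` is a hypothetical helper.

```python
def conv_transpose2d_out(size, kernel, stride, padding=0, output_padding=0, dilation=1):
    """Output size given by the nn.ConvTranspose2d shape formula."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel - 1) + output_padding + 1)


for scale in (2, 3, 4):
    out = conv_transpose2d_out(100, 9, stride=scale, padding=9 // 2,
                               output_padding=scale - 1)
    print(f"x{scale}: 100 -> {out}")  # 200, 300, 400
```

Algebraically, (n - 1)s - 8 + 8 + (s - 1) + 1 = ns for any input size n. The padding and `output_padding` settings therefore have to be chosen together.
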
## Results
PSNR was calculated on the Y channel.
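
The evaluation converts RGB to Y with the coefficients in `utils.convert_rgb_to_y`, scales intensities to [0, 1], and has `calc_psnr` compute `10 * log10(1 / MSE)`. In `test.py`, only the Y channel is super-resolved; Cb and Cr are taken from the bicubic upscale.

The NumPy sketch below shows that metric, assuming HxWx3 RGB arrays in [0, 255]. `rgb_to_y` and `psnr_y` are illustrative names, not repository functions.

```python
import numpy as np


def rgb_to_y(img):
    """Y channel of an HxWx3 RGB array in [0, 255], with utils.convert_rgb_to_y's coefficients."""
    img = np.asarray(img, dtype=np.float64)
    return 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.


def psnr_y(rgb_ref, rgb_test):
    """PSNR in dB between Y channels scaled to [0, 1], as calc_psnr expects."""
    mse = np.mean((rgb_to_y(rgb_ref) / 255. - rgb_to_y(rgb_test) / 255.) ** 2)
    return 10. * np.log10(1. / mse)


rng = np.random.default_rng(0)
reference = rng.integers(0, 246, size=(32, 32, 3)).astype(np.float64)
brighter = reference + 10.0  # uniform +10 on R, G and B stays within [0, 255]
print(f"{psnr_y(reference, brighter):.2f} dB")
```
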
### Set5
| Eval. Metric | Scale | Paper | Ours (91-image) |
|-----------|-------|-------|-----------------|
| PSNR | 2 | 36.94 | 37.12 |
| PSNR | 3 | 33.06 | 33.22 |
| PSNR | 4 | 30.55 | 30.50 |
<table>
<tr>
<td><center>Original</center></td>
<td><center>BICUBIC x3</center></td>
<td><center>FSRCNN x3 (34.66 dB)</center></td>
</tr>
<tr>
<td>
<center><img src="./data/lenna.bmp"></center>
</td>
<td>
<center><img src="./data/lenna_bicubic_x3.bmp"></center>
</td>
<td>
<center><img src="./data/lenna_fsrcnn_x3.bmp"></center>
</td>
</tr>
<tr>
<td><center>Original</center></td>
<td><center>BICUBIC x3</center></td>
<td><center>FSRCNN x3 (28.55 dB)</center></td>
</tr>
<tr>
<td>
<center><img src="./data/butterfly_GT.bmp"></center>
</td>
<td>
<center><img src="./data/butterfly_GT_bicubic_x3.bmp"></center>
</td>
<td>
<center><img src="./data/butterfly_GT_fsrcnn_x3.bmp"></center>
</td>
</tr>
</table>
================================================
FILE: datasets.py
================================================
import h5py
import numpy as np
from torch.utils.data import Dataset


class TrainDataset(Dataset):
    def __init__(self, h5_file):
        super(TrainDataset, self).__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        with h5py.File(self.h5_file, 'r') as f:
            return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])


class EvalDataset(Dataset):
    def __init__(self, h5_file):
        super(EvalDataset, self).__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        with h5py.File(self.h5_file, 'r') as f:
            return np.expand_dims(f['lr'][str(idx)][:, :] / 255., 0), np.expand_dims(f['hr'][str(idx)][:, :] / 255., 0)

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])
================================================
FILE: models.py
================================================
import math
from torch import nn


class FSRCNN(nn.Module):
    def __init__(self, scale_factor, num_channels=1, d=56, s=12, m=4):
        super(FSRCNN, self).__init__()
        self.first_part = nn.Sequential(
            nn.Conv2d(num_channels, d, kernel_size=5, padding=5//2),
            nn.PReLU(d)
        )
        self.mid_part = [nn.Conv2d(d, s, kernel_size=1), nn.PReLU(s)]
        for _ in range(m):
            self.mid_part.extend([nn.Conv2d(s, s, kernel_size=3, padding=3//2), nn.PReLU(s)])
        self.mid_part.extend([nn.Conv2d(s, d, kernel_size=1), nn.PReLU(d)])
        self.mid_part = nn.Sequential(*self.mid_part)
        self.last_part = nn.ConvTranspose2d(d, num_channels, kernel_size=9, stride=scale_factor, padding=9//2,
                                            output_padding=scale_factor-1)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.first_part:
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
                nn.init.zeros_(m.bias.data)
        for m in self.mid_part:
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
                nn.init.zeros_(m.bias.data)
        nn.init.normal_(self.last_part.weight.data, mean=0.0, std=0.001)
        nn.init.zeros_(self.last_part.bias.data)

    def forward(self, x):
        x = self.first_part(x)
        x = self.mid_part(x)
        x = self.last_part(x)
        return x
================================================
FILE: prepare.py
================================================
import argparse
import glob
import h5py
import numpy as np
import PIL.Image as pil_image
from utils import calc_patch_size, convert_rgb_to_y


@calc_patch_size
def train(args):
    h5_file = h5py.File(args.output_path, 'w')

    lr_patches = []
    hr_patches = []

    for image_path in sorted(glob.glob('{}/*'.format(args.images_dir))):
        hr = pil_image.open(image_path).convert('RGB')
        hr_images = []

        if args.with_aug:
            for s in [1.0, 0.9, 0.8, 0.7, 0.6]:
                for r in [0, 90, 180, 270]:
                    tmp = hr.resize((int(hr.width * s), int(hr.height * s)), resample=pil_image.BICUBIC)
                    tmp = tmp.rotate(r, expand=True)
                    hr_images.append(tmp)
        else:
            hr_images.append(hr)

        for hr in hr_images:
            hr_width = (hr.width // args.scale) * args.scale
            hr_height = (hr.height // args.scale) * args.scale
            hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
            lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)
            hr = np.array(hr).astype(np.float32)
            lr = np.array(lr).astype(np.float32)
            hr = convert_rgb_to_y(hr)
            lr = convert_rgb_to_y(lr)

            for i in range(0, lr.shape[0] - args.patch_size + 1, args.scale):
                for j in range(0, lr.shape[1] - args.patch_size + 1, args.scale):
                    lr_patches.append(lr[i:i+args.patch_size, j:j+args.patch_size])
                    hr_patches.append(hr[i*args.scale:i*args.scale+args.patch_size*args.scale, j*args.scale:j*args.scale+args.patch_size*args.scale])

    lr_patches = np.array(lr_patches)
    hr_patches = np.array(hr_patches)

    h5_file.create_dataset('lr', data=lr_patches)
    h5_file.create_dataset('hr', data=hr_patches)

    h5_file.close()


def eval(args):
    h5_file = h5py.File(args.output_path, 'w')

    lr_group = h5_file.create_group('lr')
    hr_group = h5_file.create_group('hr')

    for i, image_path in enumerate(sorted(glob.glob('{}/*'.format(args.images_dir)))):
        hr = pil_image.open(image_path).convert('RGB')
        hr_width = (hr.width // args.scale) * args.scale
        hr_height = (hr.height // args.scale) * args.scale
        hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
        lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)
        hr = np.array(hr).astype(np.float32)
        lr = np.array(lr).astype(np.float32)
        hr = convert_rgb_to_y(hr)
        lr = convert_rgb_to_y(lr)

        lr_group.create_dataset(str(i), data=lr)
        hr_group.create_dataset(str(i), data=hr)

    h5_file.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--images-dir', type=str, required=True)
    parser.add_argument('--output-path', type=str, required=True)
    parser.add_argument('--scale', type=int, default=2)
    parser.add_argument('--with-aug', action='store_true')
    parser.add_argument('--eval', action='store_true')
    args = parser.parse_args()

    if not args.eval:
        train(args)
    else:
        eval(args)
================================================
FILE: test.py
================================================
import argparse
import os

import torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_image

from models import FSRCNN
from utils import convert_ycbcr_to_rgb, preprocess, calc_psnr


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights-file', type=str, required=True)
    parser.add_argument('--image-file', type=str, required=True)
    parser.add_argument('--scale', type=int, default=3)
    args = parser.parse_args()

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = FSRCNN(scale_factor=args.scale).to(device)

    # Copy the saved weights into the model, rejecting unknown parameter names.
    state_dict = model.state_dict()
    for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model.eval()

    # Split off only the extension, so dots elsewhere in the path (e.g. "./data/...") are untouched.
    base, ext = os.path.splitext(args.image_file)

    image = pil_image.open(args.image_file).convert('RGB')

    # Resize to dimensions divisible by the scale, then build the LR input and the bicubic baseline.
    image_width = (image.width // args.scale) * args.scale
    image_height = (image.height // args.scale) * args.scale
    hr = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
    lr = hr.resize((hr.width // args.scale, hr.height // args.scale), resample=pil_image.BICUBIC)
    bicubic = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
    bicubic.save('{}_bicubic_x{}{}'.format(base, args.scale, ext))

    lr, _ = preprocess(lr, device)
    hr, _ = preprocess(hr, device)
    _, ycbcr = preprocess(bicubic, device)

    with torch.no_grad():
        preds = model(lr).clamp(0.0, 1.0)

    psnr = calc_psnr(hr, preds)
    print('PSNR: {:.2f}'.format(psnr))

    # Combine the super-resolved Y channel with the bicubic Cb/Cr channels.
    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    output = pil_image.fromarray(output)
    output.save('{}_fsrcnn_x{}{}'.format(base, args.scale, ext))
================================================
FILE: train.py
================================================
import argparse
import os
import copy

import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from models import FSRCNN
from datasets import TrainDataset, EvalDataset
from utils import AverageMeter, calc_psnr


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--train-file', type=str, required=True)
    parser.add_argument('--eval-file', type=str, required=True)
    parser.add_argument('--outputs-dir', type=str, required=True)
    parser.add_argument('--weights-file', type=str)
    parser.add_argument('--scale', type=int, default=2)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--num-epochs', type=int, default=20)
    parser.add_argument('--num-workers', type=int, default=8)
    parser.add_argument('--seed', type=int, default=123)
    args = parser.parse_args()

    args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))

    if not os.path.exists(args.outputs_dir):
        os.makedirs(args.outputs_dir)

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    torch.manual_seed(args.seed)

    model = FSRCNN(scale_factor=args.scale).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam([
        {'params': model.first_part.parameters()},
        {'params': model.mid_part.parameters()},
        {'params': model.last_part.parameters(), 'lr': args.lr * 0.1}
    ], lr=args.lr)

    train_dataset = TrainDataset(args.train_file)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    eval_dataset = EvalDataset(args.eval_file)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

    best_weights = copy.deepcopy(model.state_dict())
    best_epoch = 0
    best_psnr = 0.0

    for epoch in range(args.num_epochs):
        model.train()
        epoch_losses = AverageMeter()

        with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size), ncols=80) as t:
            t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1))

            for data in train_dataloader:
                inputs, labels = data

                inputs = inputs.to(device)
                labels = labels.to(device)

                preds = model(inputs)
                loss = criterion(preds, labels)
                epoch_losses.update(loss.item(), len(inputs))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                t.update(len(inputs))

        torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))

        model.eval()
        epoch_psnr = AverageMeter()

        for data in eval_dataloader:
            inputs, labels = data

            inputs = inputs.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                preds = model(inputs).clamp(0.0, 1.0)

            epoch_psnr.update(calc_psnr(preds, labels), len(inputs))

        print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

        if epoch_psnr.avg > best_psnr:
            best_epoch = epoch
            best_psnr = epoch_psnr.avg
            best_weights = copy.deepcopy(model.state_dict())

    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
    torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))
================================================
FILE: utils.py
================================================
import torch
import numpy as np


def calc_patch_size(func):
    def wrapper(args):
        if args.scale == 2:
            args.patch_size = 10
        elif args.scale == 3:
            args.patch_size = 7
        elif args.scale == 4:
            args.patch_size = 6
        else:
            raise Exception('Scale Error', args.scale)
        return func(args)
    return wrapper


# Note: the usual 8-bit ITU-R BT.601 coefficient for R is 65.738; the original
# code uses 64.738 (here and in convert_rgb_to_ycbcr), which is kept as-is.
def convert_rgb_to_y(img, dim_order='hwc'):
    if dim_order == 'hwc':
        return 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.
    else:
        return 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.


# Always returns an H x W x 3 array, whatever dim_order the input uses.
def convert_rgb_to_ycbcr(img, dim_order='hwc'):
    if dim_order == 'hwc':
        y = 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.
        cb = 128. + (-37.945 * img[..., 0] - 74.494 * img[..., 1] + 112.439 * img[..., 2]) / 256.
        cr = 128. + (112.439 * img[..., 0] - 94.154 * img[..., 1] - 18.285 * img[..., 2]) / 256.
    else:
        y = 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.
        cb = 128. + (-37.945 * img[0] - 74.494 * img[1] + 112.439 * img[2]) / 256.
        cr = 128. + (112.439 * img[0] - 94.154 * img[1] - 18.285 * img[2]) / 256.
    return np.array([y, cb, cr]).transpose([1, 2, 0])


# Always returns an H x W x 3 array, whatever dim_order the input uses.
def convert_ycbcr_to_rgb(img, dim_order='hwc'):
    if dim_order == 'hwc':
        r = 298.082 * img[..., 0] / 256. + 408.583 * img[..., 2] / 256. - 222.921
        g = 298.082 * img[..., 0] / 256. - 100.291 * img[..., 1] / 256. - 208.120 * img[..., 2] / 256. + 135.576
        b = 298.082 * img[..., 0] / 256. + 516.412 * img[..., 1] / 256. - 276.836
    else:
        r = 298.082 * img[0] / 256. + 408.583 * img[2] / 256. - 222.921
        g = 298.082 * img[0] / 256. - 100.291 * img[1] / 256. - 208.120 * img[2] / 256. + 135.576
        b = 298.082 * img[0] / 256. + 516.412 * img[1] / 256. - 276.836
    return np.array([r, g, b]).transpose([1, 2, 0])


def preprocess(img, device):
    img = np.array(img).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(img)
    x = ycbcr[..., 0]
    # x is a view, so this also rescales ycbcr[..., 0]; callers only use Cb/Cr from ycbcr.
    x /= 255.
    x = torch.from_numpy(x).to(device)
    x = x.unsqueeze(0).unsqueeze(0)
    return x, ycbcr


def calc_psnr(img1, img2):
    return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))


class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
SYMBOL INDEX (24 symbols across 4 files)
FILE: datasets.py
class TrainDataset (line 6) | class TrainDataset(Dataset):
method __init__ (line 7) | def __init__(self, h5_file):
method __getitem__ (line 11) | def __getitem__(self, idx):
method __len__ (line 15) | def __len__(self):
class EvalDataset (line 20) | class EvalDataset(Dataset):
method __init__ (line 21) | def __init__(self, h5_file):
method __getitem__ (line 25) | def __getitem__(self, idx):
method __len__ (line 29) | def __len__(self):
FILE: models.py
class FSRCNN (line 5) | class FSRCNN(nn.Module):
method __init__ (line 6) | def __init__(self, scale_factor, num_channels=1, d=56, s=12, m=4):
method _initialize_weights (line 22) | def _initialize_weights(self):
method forward (line 34) | def forward(self, x):
FILE: prepare.py
function train (line 10) | def train(args):
function eval (line 53) | def eval(args):
FILE: utils.py
function calc_patch_size (line 5) | def calc_patch_size(func):
function convert_rgb_to_y (line 19) | def convert_rgb_to_y(img, dim_order='hwc'):
function convert_rgb_to_ycbcr (line 26) | def convert_rgb_to_ycbcr(img, dim_order='hwc'):
function convert_ycbcr_to_rgb (line 38) | def convert_ycbcr_to_rgb(img, dim_order='hwc'):
function preprocess (line 50) | def preprocess(img, device):
function calc_psnr (line 60) | def calc_psnr(img1, img2):
class AverageMeter (line 64) | class AverageMeter(object):
method __init__ (line 65) | def __init__(self):
method reset (line 68) | def reset(self):
method update (line 74) | def update(self, val, n=1):
About this extraction
This page contains the full source code of the yjn870/FSRCNN-pytorch GitHub repository, extracted and formatted as plain text. The extraction includes 8 files (17.3 KB), approximately 5.3k tokens, and a symbol index with 24 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract. Built by Nikandr Surkov.