Repository: wzhouxiff/RestoreFormer
Branch: main
Commit: 294cf9521a86
Files: 30
Total size: 158.8 KB
Directory structure:
gitextract_istv2bdf/
├── .gitignore
├── LICENSE
├── README.md
├── RestoreFormer/
│ ├── data/
│ │ └── ffhq_degradation_dataset.py
│ ├── distributed/
│ │ ├── __init__.py
│ │ ├── distributed.py
│ │ └── launch.py
│ ├── models/
│ │ └── vqgan_v1.py
│ ├── modules/
│ │ ├── discriminator/
│ │ │ └── model.py
│ │ ├── losses/
│ │ │ ├── __init__.py
│ │ │ ├── lpips.py
│ │ │ └── vqperceptual.py
│ │ ├── util.py
│ │ └── vqvae/
│ │ ├── arcface_arch.py
│ │ ├── facial_component_discriminator.py
│ │ ├── utils.py
│ │ └── vqvae_arch.py
│ └── util.py
├── __init__.py
├── configs/
│ ├── HQ_Dictionary.yaml
│ └── RestoreFormer.yaml
├── main.py
├── restoreformer_requirement.txt
└── scripts/
├── metrics/
│ ├── cal_fid.py
│ ├── cal_identity_distance.py
│ ├── cal_psnr_ssim.py
│ └── run.sh
├── run.sh
├── test.py
└── test.sh
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
data/FFHQ
scripts/data_synthetic
experiments/
scripts/run_clustre.sh
sftp-config.json
results/
# scripts/metrics
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# We have merged the code of RestoreFormer into our journal version, RestoreFormer++. Please feel free to access the resources from [https://github.com/wzhouxiff/RestoreFormerPlusPlus](https://github.com/wzhouxiff/RestoreFormerPlusPlus)
# Updating
- **20230915** Update an [online demo](https://huggingface.co/spaces/wzhouxiff/RestoreFormerPlusPlus)
- **20230915** For a more user-friendly and comprehensive inference method, please refer to our [RestoreFormer++](https://github.com/wzhouxiff/RestoreFormerPlusPlus)
- **20230116** For convenience, we further upload the [test datasets](#testset), including CelebA (both HQ and LQ data), LFW-Test, CelebChild-Test, and Webphoto-Test, to OneDrive and BaiduYun.
- **20221003** We provide the link of the [test datasets](#testset).
- **20220924** We add the code for [**metrics**](#metrics) in scripts/metrics.
<!--
# RestoreFormer
This repo includes the source code of the paper: "[RestoreFormer: High-Quality Blind Face Restoration from Undegraded Key-Value Pairs](https://openaccess.thecvf.com/content/CVPR2022/papers/Wang_RestoreFormer_High-Quality_Blind_Face_Restoration_From_Undegraded_Key-Value_Pairs_CVPR_2022_paper.pdf)" (CVPR 2022) by Zhouxia Wang, Jiawei Zhang, Runjian Chen, Wenping Wang, and Ping Luo.

**RestoreFormer** explores fully-spatial attention to model contextual information and surpasses existing works that use local operators. It has several benefits compared to prior art. First, it incorporates a multi-head cross-attention layer to learn fully-spatial interactions between corrupted queries and high-quality key-value pairs. Second, the key-value pairs in RestoreFormer are sampled from a reconstruction-oriented high-quality dictionary, whose elements are rich in high-quality facial features aimed specifically at face reconstruction.
-->
<!-- -->
<!--
## Environment
- python>=3.7
- pytorch>=1.7.1
- pytorch-lightning==1.0.8
- omegaconf==2.0.0
- basicsr==1.3.3.4
**Warning** Different versions of pytorch-lightning and omegaconf may lead to errors or different results.
## Preparations of dataset and models
**Dataset**:
- Training data: Both **HQ Dictionary** and **RestoreFormer** in our work are trained with **FFHQ**, which can be obtained from the [FFHQ repository](https://github.com/NVlabs/ffhq-dataset). The original size of the images in FFHQ is 1024x1024. We resize them to 512x512 with bilinear interpolation in our work. Link this dataset to ./data/FFHQ/image512x512.
- <a id="testset">Test data</a>:
* CelebA-Test-HQ: [OneDrive](https://connecthkuhk-my.sharepoint.com/:u:/g/personal/wzhoux_connect_hku_hk/EY7P-MReZUZOngy3UGa5abUBJKel1IH5uYZLdwp2e2KvUw?e=rK0VWh); [BaiduYun](https://pan.baidu.com/s/1tMpxz8lIW50U8h00047GIw?pwd=mp9t)(code mp9t)
* CelebA-Test-LQ: [OneDrive](https://connecthkuhk-my.sharepoint.com/:u:/g/personal/wzhoux_connect_hku_hk/EXULDOtX3qdKg9_--k-hbr4BumxOUAi19iQjZNz75S6pKA?e=Kghqri); [BaiduYun](https://pan.baidu.com/s/1y6ZcQPCLyggj9VB5MgoWyg?pwd=7s6h)(code 7s6h)
* LFW-Test: [OneDrive](https://connecthkuhk-my.sharepoint.com/:u:/g/personal/wzhoux_connect_hku_hk/EZ7ibkhUuRxBjdd-MesczpgBfpLVfv-9uYVskLuZiYpBsg?e=xPNH26); [BaiduYun](https://pan.baidu.com/s/1UkfYLTViL8XVdZ-Ej-2G9g?pwd=7fhr)(code 7fhr). Note that it was aligned with dlib.
* CelebChild: [OneDrive](https://connecthkuhk-my.sharepoint.com/:u:/g/personal/wzhoux_connect_hku_hk/ESK6vjLzDuJAsd-cfWrfl20BTeSD_w4uRNJREGfl3zGzJg?e=Tou7ft); [BaiduYun](https://pan.baidu.com/s/1pGCD4TkhtDsmp8emZd8smA?pwd=rq65)(code rq65)
* WebPhoto-Test: [OneDrive](https://connecthkuhk-my.sharepoint.com/:u:/g/personal/wzhoux_connect_hku_hk/ER1-0eYKGkZIs-YEDhNW0xIBohCI5IEZyAS2PAvI81Stcg?e=TFJFGh); [BaiduYun](https://pan.baidu.com/s/1SjBfinSL1F-bbOpXiD0nlw?pwd=nren)(code nren)
**Model**: Both the pretrained models used for training and the trained model of our RestoreFormer can be obtained from [OneDrive](https://connecthkuhk-my.sharepoint.com/:u:/g/personal/wzhoux_connect_hku_hk/Eb73S2jXZIxNrrOFRnFKu2MBTe7kl4cMYYwwiudAmDNwYg?e=Xa4ZDf) or [BaiduYun](https://pan.baidu.com/s/1EO7_1dYyCuORpPNosQgogg?pwd=x6nn)(code x6nn). Link these models to ./experiments.
## Test
sh scripts/test.sh
## Training
sh scripts/run.sh
**Note**.
- The first stage is to obtain the **HQ Dictionary** by setting `conf_name` in scripts/run.sh to 'HQ\_Dictionary'.
- The second stage is blind face restoration. You need to add your trained HQ\_Dictionary model to `ckpt_path` in config/RestoreFormer.yaml and set `conf_name` in scripts/run.sh to 'RestoreFormer'.
- Our model is trained with 4 V100 GPUs.
## <a id="metrics">Metrics</a>
sh scripts/metrics/run.sh
**Note**.
- You need to add the path of the CelebA-Test dataset in the script if you want to get IDD, PSNR, SSIM, and LPIPS.
## Citation
@inproceedings{wang2022restoreformer,
title={RestoreFormer: High-Quality Blind Face Restoration from Undegraded Key-Value Pairs},
author={Wang, Zhouxia and Zhang, Jiawei and Chen, Runjian and Wang, Wenping and Luo, Ping},
booktitle={The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2022}
}
## Acknowledgement
We thank everyone who makes their code and models available, especially [Taming Transformer](https://github.com/CompVis/taming-transformers), [basicsr](https://github.com/XPixelGroup/BasicSR), and [GFPGAN](https://github.com/TencentARC/GFPGAN).
## Contact
For any question, feel free to email `wzhoux@connect.hku.hk` or `zhouzi1212@gmail.com`.
-->
================================================
FILE: RestoreFormer/data/ffhq_degradation_dataset.py
================================================
import os
import cv2
import math
import numpy as np
import random
import os.path as osp
import torch
import torch.utils.data as data
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
normalize)
from basicsr.data import degradations as degradations
from basicsr.data.data_util import paths_from_folder
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class FFHQDegradationDataset(data.Dataset):
def __init__(self, opt):
super(FFHQDegradationDataset, self).__init__()
self.opt = opt
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.gt_folder = opt['dataroot_gt']
self.mean = opt['mean']
self.std = opt['std']
self.out_size = opt['out_size']
self.crop_components = opt.get('crop_components', False) # facial components
self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)
if self.crop_components:
self.components_list = torch.load(opt.get('component_path'))
if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = self.gt_folder
if not self.gt_folder.endswith('.lmdb'):
raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
self.paths = [line.split('.')[0] for line in fin]
else:
self.paths = paths_from_folder(self.gt_folder)
# degradations
self.blur_kernel_size = opt['blur_kernel_size']
self.kernel_list = opt['kernel_list']
self.kernel_prob = opt['kernel_prob']
self.blur_sigma = opt['blur_sigma']
self.downsample_range = opt['downsample_range']
self.noise_range = opt['noise_range']
self.jpeg_range = opt['jpeg_range']
# color jitter
self.color_jitter_prob = opt.get('color_jitter_prob')
self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
self.color_jitter_shift = opt.get('color_jitter_shift', 20)
# to gray
self.gray_prob = opt.get('gray_prob')
logger = get_root_logger()
logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, '
f'sigma: [{", ".join(map(str, self.blur_sigma))}]')
logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
if self.color_jitter_prob is not None:
logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, '
f'shift: {self.color_jitter_shift}')
if self.gray_prob is not None:
logger.info(f'Use random gray. Prob: {self.gray_prob}')
self.color_jitter_shift /= 255.
@staticmethod
def color_jitter(img, shift):
jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
img = img + jitter_val
img = np.clip(img, 0, 1)
return img
@staticmethod
def color_jitter_pt(img, brightness, contrast, saturation, hue):
fn_idx = torch.randperm(4)
for fn_id in fn_idx:
if fn_id == 0 and brightness is not None:
brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
img = adjust_brightness(img, brightness_factor)
if fn_id == 1 and contrast is not None:
contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
img = adjust_contrast(img, contrast_factor)
if fn_id == 2 and saturation is not None:
saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
img = adjust_saturation(img, saturation_factor)
if fn_id == 3 and hue is not None:
hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
img = adjust_hue(img, hue_factor)
return img
def get_component_coordinates(self, index, status):
components_bbox = self.components_list[f'{index:08d}']
if status[0]: # hflip
# exchange right and left eye
tmp = components_bbox['left_eye']
components_bbox['left_eye'] = components_bbox['right_eye']
components_bbox['right_eye'] = tmp
# modify the width coordinate
components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0]
components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0]
components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0]
# get coordinates
locations = []
for part in ['left_eye', 'right_eye', 'mouth']:
mean = components_bbox[part][0:2]
half_len = components_bbox[part][2]
if 'eye' in part:
half_len *= self.eye_enlarge_ratio
loc = np.hstack((mean - half_len + 1, mean + half_len))
loc = torch.from_numpy(loc).float()
locations.append(loc)
return locations
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
# load gt image
gt_path = self.paths[index]
img_bytes = self.file_client.get(gt_path)
img_gt = imfrombytes(img_bytes, float32=True)
# random horizontal flip
img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
h, w, _ = img_gt.shape
if self.crop_components:
locations = self.get_component_coordinates(index, status)
loc_left_eye, loc_right_eye, loc_mouth = locations
# ------------------------ generate lq image ------------------------ #
# blur
assert self.blur_kernel_size[0] < self.blur_kernel_size[1], 'Wrong blur kernel size range'
cur_kernel_size = random.randint(self.blur_kernel_size[0],self.blur_kernel_size[1]) * 2 + 1
kernel = degradations.random_mixed_kernels(
self.kernel_list,
self.kernel_prob,
cur_kernel_size,
self.blur_sigma,
self.blur_sigma, [-math.pi, math.pi],
noise_range=None)
img_lq = cv2.filter2D(img_gt, -1, kernel)
# downsample
scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
# noise
if self.noise_range is not None:
img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
# jpeg compression
if self.jpeg_range is not None:
img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)
# resize to original size
img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)
# random color jitter (only for lq)
if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
# random to gray (only for lq)
if self.gray_prob and np.random.uniform() < self.gray_prob:
img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
if self.opt.get('gt_gray'):
img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
# random color jitter (pytorch version) (only for lq)
if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
brightness = self.opt.get('brightness', (0.5, 1.5))
contrast = self.opt.get('contrast', (0.5, 1.5))
saturation = self.opt.get('saturation', (0, 1.5))
hue = self.opt.get('hue', (-0.1, 0.1))
img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)
# round and clip
img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.
# normalize
normalize(img_gt, self.mean, self.std, inplace=True)
normalize(img_lq, self.mean, self.std, inplace=True)
return_dict = {
'lq': img_lq,
'gt': img_gt,
'gt_path': gt_path
}
if self.crop_components:
return_dict['loc_left_eye'] = loc_left_eye
return_dict['loc_right_eye'] = loc_right_eye
return_dict['loc_mouth'] = loc_mouth
return return_dict
def __len__(self):
return len(self.paths)
import argparse
from omegaconf import OmegaConf
import pdb
from basicsr.utils import img2tensor, imwrite, tensor2img
if __name__=='__main__':
# pdb.set_trace()
base='configs/RestoreFormer.yaml'
opt = OmegaConf.load(base)
dataset = FFHQDegradationDataset(opt['data']['params']['train']['params'])
for i in range(100):
sample = dataset[i]
name = sample['gt_path'].split('/')[-1][:-4]
gt = tensor2img(sample['gt'])
imwrite(gt, name + '_gt.png')
lq = tensor2img(sample['lq'])
imwrite(lq, name+'_lq_nojitter.png')
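The `__getitem__` pipeline above degrades each HQ image in a fixed order: blur, downsample, Gaussian noise, JPEG compression, then resize back to the original resolution. A minimal NumPy-only sketch of that ordering (illustrative only: a 3x3 box blur and nearest-neighbour resampling stand in for the random mixed kernels and cv2 interpolation, and JPEG compression is skipped):

```python
import numpy as np

def degrade(img_gt, scale=2, noise_sigma=0.02, seed=0):
    """Toy version of the LQ synthesis order used by FFHQDegradationDataset.

    img_gt: float32 HxWx3 image with values in [0, 1].
    """
    rng = np.random.default_rng(seed)
    h, w, _ = img_gt.shape
    # 1) blur: 3x3 box filter as a stand-in for the random mixed kernel
    pad = np.pad(img_gt, ((1, 1), (1, 1), (0, 0)), mode='edge')
    img = sum(pad[dy:dy + h, dx:dx + w] for dy in range(3) for dx in range(3)) / 9.0
    # 2) downsample by `scale` (nearest neighbour)
    img = img[::scale, ::scale]
    # 3) additive Gaussian noise
    img = img + rng.normal(0.0, noise_sigma, img.shape)
    # 4) JPEG compression would be applied here; skipped in this sketch
    # 5) resize back to the original size (nearest neighbour)
    img = np.repeat(np.repeat(img, scale, axis=0), scale, axis=1)[:h, :w]
    return np.clip(img, 0.0, 1.0).astype(np.float32)

lq = degrade(np.full((8, 8, 3), 0.5, dtype=np.float32))
```

The output keeps the GT resolution, matching the dataset's final `cv2.resize` back to `(w, h)` before normalization.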
================================================
FILE: RestoreFormer/distributed/__init__.py
================================================
from .distributed import (
get_rank,
get_local_rank,
is_primary,
synchronize,
get_world_size,
all_reduce,
all_gather,
reduce_dict,
data_sampler,
LOCAL_PROCESS_GROUP,
)
from .launch import launch
================================================
FILE: RestoreFormer/distributed/distributed.py
================================================
import math
import pickle
import torch
from torch import distributed as dist
from torch.utils import data
LOCAL_PROCESS_GROUP = None
def is_primary():
return get_rank() == 0
def get_rank():
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
return dist.get_rank()
def get_local_rank():
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
if LOCAL_PROCESS_GROUP is None:
raise ValueError("RestoreFormer.distributed.LOCAL_PROCESS_GROUP is None")
return dist.get_rank(group=LOCAL_PROCESS_GROUP)
def synchronize():
if not dist.is_available():
return
if not dist.is_initialized():
return
world_size = dist.get_world_size()
if world_size == 1:
return
dist.barrier()
def get_world_size():
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size()
def all_reduce(tensor, op=dist.ReduceOp.SUM):
world_size = get_world_size()
if world_size == 1:
return tensor
dist.all_reduce(tensor, op=op)
return tensor
def all_gather(data):
world_size = get_world_size()
if world_size == 1:
return [data]
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to("cuda")
local_size = torch.IntTensor([tensor.numel()]).to("cuda")
size_list = [torch.IntTensor([1]).to("cuda") for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
tensor_list = []
for _ in size_list:
tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
if local_size != max_size:
padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
tensor = torch.cat((tensor, padding), 0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
def reduce_dict(input_dict, average=True):
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
keys = []
values = []
for k in sorted(input_dict.keys()):
keys.append(k)
values.append(input_dict[k])
values = torch.stack(values, 0)
dist.reduce(values, dst=0)
if dist.get_rank() == 0 and average:
values /= world_size
reduced_dict = {k: v for k, v in zip(keys, values)}
return reduced_dict
def data_sampler(dataset, shuffle, distributed):
if distributed:
return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
if shuffle:
return data.RandomSampler(dataset)
else:
return data.SequentialSampler(dataset)
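`all_gather` above handles arbitrary picklable objects by serializing each rank's payload to bytes, padding every buffer to the longest one (since `dist.all_gather` requires equal-sized tensors), gathering, then truncating each buffer back to its recorded size before unpickling. The pad-then-truncate round trip can be checked on a single process without any process group (a sketch; plain bytes stand in for the CUDA ByteTensors):

```python
import pickle

def pad_roundtrip(objs):
    """Mimic all_gather's pad-to-max serialization scheme on one process."""
    buffers = [pickle.dumps(o) for o in objs]
    sizes = [len(b) for b in buffers]
    max_size = max(sizes)
    # pad every buffer to the common maximum, as all_gather does
    padded = [b + b"\x00" * (max_size - len(b)) for b in buffers]
    # truncate back to the true size before unpickling
    return [pickle.loads(p[:s]) for p, s in zip(padded, sizes)]

out = pad_roundtrip([{"rank": 0}, list(range(100)), "hello"])
```

Truncation is what makes the padding harmless: `pickle.loads` only ever sees the original byte string.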
================================================
FILE: RestoreFormer/distributed/launch.py
================================================
import os
import torch
from torch import distributed as dist
from torch import multiprocessing as mp
from . import distributed as dist_fn
def find_free_port():
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("", 0))
port = sock.getsockname()[1]
sock.close()
return port
def launch(fn, n_gpu_per_machine, n_machine=1, machine_rank=0, dist_url=None, args=()):
world_size = n_machine * n_gpu_per_machine
if world_size > 1:
if "OMP_NUM_THREADS" not in os.environ:
os.environ["OMP_NUM_THREADS"] = "1"
if dist_url == "auto":
if n_machine != 1:
raise ValueError('dist_url="auto" not supported in multi-machine jobs')
port = find_free_port()
dist_url = f"tcp://127.0.0.1:{port}"
if n_machine > 1 and dist_url.startswith("file://"):
raise ValueError(
"file:// is not a reliable init method in multi-machine jobs. Prefer tcp://"
)
mp.spawn(
distributed_worker,
nprocs=n_gpu_per_machine,
args=(fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args),
daemon=False,
)
else:
fn(*args)
def distributed_worker(
local_rank, fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args
):
if not torch.cuda.is_available():
raise OSError("CUDA is not available. Please check your environment")
global_rank = machine_rank * n_gpu_per_machine + local_rank
try:
dist.init_process_group(
backend="NCCL",
init_method=dist_url,
world_size=world_size,
rank=global_rank,
)
except Exception:
raise OSError("failed to initialize NCCL groups")
dist_fn.synchronize()
if n_gpu_per_machine > torch.cuda.device_count():
raise ValueError(
f"specified n_gpu_per_machine larger than available device ({torch.cuda.device_count()})"
)
torch.cuda.set_device(local_rank)
if dist_fn.LOCAL_PROCESS_GROUP is not None:
raise ValueError("torch.distributed.LOCAL_PROCESS_GROUP is not None")
n_machine = world_size // n_gpu_per_machine
for i in range(n_machine):
ranks_on_i = list(range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine))
pg = dist.new_group(ranks_on_i)
if i == machine_rank:
dist_fn.LOCAL_PROCESS_GROUP = pg
fn(*args)
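`distributed_worker` above derives each process's global rank from its machine rank and local GPU index, and builds one process group per machine for `LOCAL_PROCESS_GROUP`. The bookkeeping can be reproduced without torch (a sketch; `rank_layout` is a hypothetical helper, not part of the repo):

```python
def rank_layout(n_machine, n_gpu_per_machine):
    """Reproduce launch.py's rank arithmetic: global rank per
    (machine_rank, local_rank) pair, plus the per-machine rank lists
    that back each machine's LOCAL_PROCESS_GROUP."""
    world_size = n_machine * n_gpu_per_machine
    global_rank = {
        (m, l): m * n_gpu_per_machine + l
        for m in range(n_machine)
        for l in range(n_gpu_per_machine)
    }
    groups = [
        list(range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine))
        for i in range(n_machine)
    ]
    return world_size, global_rank, groups

ws, gr, groups = rank_layout(n_machine=2, n_gpu_per_machine=4)
```

For example, with 2 machines of 4 GPUs, local rank 2 on machine 1 becomes global rank 6, and machine 1's group holds ranks 4 through 7.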
================================================
FILE: RestoreFormer/models/vqgan_v1.py
================================================
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from main import instantiate_from_config
from RestoreFormer.modules.vqvae.utils import get_roi_regions
class RestoreFormerModel(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
ckpt_path=None,
ignore_keys=[],
image_key="lq",
colorize_nlabels=None,
monitor=None,
special_params_lr_scale=1.0,
comp_params_lr_scale=1.0,
schedule_step=[80000, 200000]
):
super().__init__()
self.image_key = image_key
self.vqvae = instantiate_from_config(ddconfig)
lossconfig['params']['distill_param']=ddconfig['params']
self.loss = instantiate_from_config(lossconfig)
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
if ('comp_weight' in lossconfig['params'] and lossconfig['params']['comp_weight']) or ('comp_style_weight' in lossconfig['params'] and lossconfig['params']['comp_style_weight']):
self.use_facial_disc = True
else:
self.use_facial_disc = False
self.fix_decoder = ddconfig['params']['fix_decoder']
self.disc_start = lossconfig['params']['disc_start']
self.special_params_lr_scale = special_params_lr_scale
self.comp_params_lr_scale = comp_params_lr_scale
self.schedule_step = schedule_step
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
state_dict = self.state_dict()
require_keys = state_dict.keys()
keys = sd.keys()
un_pretrained_keys = []
for k in require_keys:
if k not in keys:
# checkpoint keys may be stored without the 6-character 'vqvae.' prefix
if k[6:] in keys:
state_dict[k] = sd[k[6:]]
else:
un_pretrained_keys.append(k)
else:
state_dict[k] = sd[k]
# print(f'*************************************************')
# print(f"Layers without pretraining: {un_pretrained_keys}")
# print(f'*************************************************')
self.load_state_dict(state_dict, strict=True)
print(f"Restored from {path}")
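A standalone sketch of the key remapping `init_from_ckpt` performs (the dictionary keys here are hypothetical, chosen only to show the 6-character `'vqvae.'` prefix stripping):

```python
# Sketch of init_from_ckpt's key matching: model keys look like
# 'vqvae.encoder.w', but some checkpoints store them without the
# 'vqvae.' prefix, so a second lookup drops the first 6 characters.
def remap_keys(model_keys, ckpt):
    out, missing = {}, []
    for k in model_keys:
        if k in ckpt:
            out[k] = ckpt[k]
        elif k[6:] in ckpt:          # retry without the 'vqvae.' prefix
            out[k] = ckpt[k[6:]]
        else:
            missing.append(k)        # layer stays randomly initialized
    return out, missing

loaded, missing = remap_keys(
    ['vqvae.encoder.w', 'vqvae.decoder.w'],
    {'encoder.w': 1, 'vqvae.decoder.w': 2},
)
```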
def forward(self, input):
dec, diff, info, hs = self.vqvae(input)
return dec, diff, info, hs
def training_step(self, batch, batch_idx, optimizer_idx):
x = batch[self.image_key]
xrec, qloss, info, hs = self(x)
if self.image_key != 'gt':
x = batch['gt']
if self.use_facial_disc:
loc_left_eyes = batch['loc_left_eye']
loc_right_eyes = batch['loc_right_eye']
loc_mouths = batch['loc_mouth']
face_ratio = xrec.shape[-1] / 512
components = get_roi_regions(x, xrec, loc_left_eyes, loc_right_eyes, loc_mouths, face_ratio)
else:
components = None
if optimizer_idx == 0:
# autoencode
aeloss, log_dict_ae = self.loss(qloss, x, xrec, components, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return aeloss
if optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(qloss, x, xrec, components, optimizer_idx, self.global_step,
last_layer=None, split="train")
self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return discloss
if self.disc_start <= self.global_step:
# left eye
if optimizer_idx == 2:
# discriminator
disc_left_loss, log_dict_disc = self.loss(qloss, x, xrec, components, optimizer_idx, self.global_step,
last_layer=None, split="train")
self.log("train/disc_left_loss", disc_left_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return disc_left_loss
# right eye
if optimizer_idx == 3:
# discriminator
disc_right_loss, log_dict_disc = self.loss(qloss, x, xrec, components, optimizer_idx, self.global_step,
last_layer=None, split="train")
self.log("train/disc_right_loss", disc_right_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return disc_right_loss
# mouth
if optimizer_idx == 4:
# discriminator
disc_mouth_loss, log_dict_disc = self.loss(qloss, x, xrec, components, optimizer_idx, self.global_step,
last_layer=None, split="train")
self.log("train/disc_mouth_loss", disc_mouth_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return disc_mouth_loss
def validation_step(self, batch, batch_idx):
x = batch[self.image_key]
xrec, qloss, info, hs = self(x)
if self.image_key != 'gt':
x = batch['gt']
aeloss, log_dict_ae = self.loss(qloss, x, xrec, None, 0, self.global_step,
last_layer=self.get_last_layer(), split="val")
discloss, log_dict_disc = self.loss(qloss, x, xrec, None, 1, self.global_step,
last_layer=None, split="val")
rec_loss = log_dict_ae["val/rec_loss"]
self.log("val/rec_loss", rec_loss,
prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
self.log("val/aeloss", aeloss,
prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return log_dict_ae
def configure_optimizers(self):
lr = self.learning_rate
normal_params = []
special_params = []
for name, param in self.vqvae.named_parameters():
if not param.requires_grad:
continue
if 'decoder' in name and 'attn' in name:
special_params.append(param)
else:
normal_params.append(param)
# print('special_params', special_params)
opt_ae_params = [{'params': normal_params, 'lr': lr},
{'params': special_params, 'lr': lr*self.special_params_lr_scale}]
opt_ae = torch.optim.Adam(opt_ae_params, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr, betas=(0.5, 0.9))
optimizations = [opt_ae, opt_disc]
s0 = torch.optim.lr_scheduler.MultiStepLR(opt_ae, milestones=self.schedule_step, gamma=0.1, verbose=True)
s1 = torch.optim.lr_scheduler.MultiStepLR(opt_disc, milestones=self.schedule_step, gamma=0.1, verbose=True)
schedules = [s0, s1]
if self.use_facial_disc:
opt_l = torch.optim.Adam(self.loss.net_d_left_eye.parameters(),
lr=lr*self.comp_params_lr_scale, betas=(0.9, 0.99))
opt_r = torch.optim.Adam(self.loss.net_d_right_eye.parameters(),
lr=lr*self.comp_params_lr_scale, betas=(0.9, 0.99))
opt_m = torch.optim.Adam(self.loss.net_d_mouth.parameters(),
lr=lr*self.comp_params_lr_scale, betas=(0.9, 0.99))
optimizations += [opt_l, opt_r, opt_m]
s2 = torch.optim.lr_scheduler.MultiStepLR(opt_l, milestones=self.schedule_step, gamma=0.1, verbose=True)
s3 = torch.optim.lr_scheduler.MultiStepLR(opt_r, milestones=self.schedule_step, gamma=0.1, verbose=True)
s4 = torch.optim.lr_scheduler.MultiStepLR(opt_m, milestones=self.schedule_step, gamma=0.1, verbose=True)
schedules += [s2, s3, s4]
return optimizations, schedules
def get_last_layer(self):
if self.fix_decoder:
return self.vqvae.quant_conv.weight
return self.vqvae.decoder.conv_out.weight
def log_images(self, batch, **kwargs):
log = dict()
x = batch[self.image_key]
x = x.to(self.device)
xrec, _, _, _ = self(x)
log["inputs"] = x
log["reconstructions"] = xrec
if self.image_key != 'gt':
x = batch['gt']
log["gt"] = x
return log
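The optimizer routing in `training_step` can be summarized as a small dispatch table; a sketch (the helper `active_loss` is illustrative, not part of the model):

```python
# Sketch of training_step's optimizer routing: index 0 trains the
# autoencoder, 1 the global discriminator, and 2-4 the facial-component
# discriminators, which are only active once global_step reaches disc_start.
def active_loss(optimizer_idx, global_step, disc_start):
    names = {0: 'aeloss', 1: 'discloss', 2: 'disc_left_loss',
             3: 'disc_right_loss', 4: 'disc_mouth_loss'}
    if optimizer_idx >= 2 and global_step < disc_start:
        return None                  # component discs not yet enabled
    return names[optimizer_idx]
```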
================================================
FILE: RestoreFormer/modules/discriminator/model.py
================================================
import functools
import torch.nn as nn
from RestoreFormer.modules.util import ActNorm
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
class NLayerDiscriminator(nn.Module):
"""Defines a PatchGAN discriminator as in Pix2Pix
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
"""
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
"""Construct a PatchGAN discriminator
Parameters:
input_nc (int) -- the number of channels in input images
ndf (int) -- the number of filters in the last conv layer
n_layers (int) -- the number of conv layers in the discriminator
norm_layer -- normalization layer
"""
super(NLayerDiscriminator, self).__init__()
if not use_actnorm:
norm_layer = nn.BatchNorm2d
else:
norm_layer = ActNorm
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func != nn.BatchNorm2d
else:
use_bias = norm_layer != nn.BatchNorm2d
kw = 4
padw = 1
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
nf_mult_prev = nf_mult
nf_mult = min(2 ** n, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
nf_mult_prev = nf_mult
nf_mult = min(2 ** n_layers, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
sequence += [
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
self.main = nn.Sequential(*sequence)
def forward(self, input):
"""Standard forward."""
return self.main(input)
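The PatchGAN's output map size can be checked by hand; a pure-arithmetic sketch for a 256x256 input with the defaults above (n_layers=3, kernel_size=4, padding=1; three stride-2 convs followed by two stride-1 convs):

```python
# Conv output-size formula: floor((in + 2*pad - kernel) / stride) + 1.
def conv_out(size, kernel=4, stride=1, pad=1):
    return (size + 2 * pad - kernel) // stride + 1

size = 256
for stride in (2, 2, 2, 1, 1):   # head, two body blocks, pre-final, final
    size = conv_out(size, stride=stride)
# Each location of the size x size output map scores one input patch.
```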
class NLayerDiscriminator_v1(nn.Module):
"""Defines a PatchGAN discriminator as in Pix2Pix
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
"""
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
"""Construct a PatchGAN discriminator
Parameters:
input_nc (int) -- the number of channels in input images
ndf (int) -- the number of filters in the last conv layer
n_layers (int) -- the number of conv layers in the discriminator
norm_layer -- normalization layer
"""
super(NLayerDiscriminator_v1, self).__init__()
if not use_actnorm:
norm_layer = nn.BatchNorm2d
else:
norm_layer = ActNorm
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func != nn.BatchNorm2d
else:
use_bias = norm_layer != nn.BatchNorm2d
self.n_layers = n_layers
kw = 4
padw = 1
# sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
self.head = nn.Sequential(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True))
# self.head = nn.Sequential(nn.Conv2d(3, 64, kernel_size=4, stride=1, padding=1), nn.LeakyReLU(0.2, True)).cuda()
nf_mult = 1
nf_mult_prev = 1
self.body = nn.ModuleList()
for n in range(1, n_layers): # gradually increase the number of filters
nf_mult_prev = nf_mult
nf_mult = min(2 ** n, 8)
self.body.append(nn.Sequential(
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
))
nf_mult_prev = nf_mult
nf_mult = min(2 ** n_layers, 8)
self.beforlast = nn.Sequential(
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
)
self.final = nn.Sequential(
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)) # output 1 channel prediction map
# self.main = nn.Sequential(*sequence)
def forward(self, input):
"""Standard forward."""
# return self.main(input)
features = []
f = self.head(input)
features.append(f)
for i in range(self.n_layers-1):
f = self.body[i](f)
features.append(f)
beforlastF = self.beforlast(f)
final = self.final(beforlastF)
return features, final
================================================
FILE: RestoreFormer/modules/losses/__init__.py
================================================
================================================
FILE: RestoreFormer/modules/losses/lpips.py
================================================
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
import torch
import torch.nn as nn
from torchvision import models
from collections import namedtuple
from RestoreFormer.util import get_ckpt_path
class LPIPS(nn.Module):
# Learned perceptual metric
def __init__(self, use_dropout=True, style_weight=0.):
super().__init__()
self.scaling_layer = ScalingLayer()
self.chns = [64, 128, 256, 512, 512] # vgg16 feature channels
self.net = vgg16(pretrained=True, requires_grad=False)
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
self.load_from_pretrained()
for param in self.parameters():
param.requires_grad = False
self.style_weight = style_weight
def load_from_pretrained(self, name="vgg_lpips"):
ckpt = get_ckpt_path(name, "experiments/pretrained_models/lpips")
self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
print("loaded pretrained LPIPS loss from {}".format(ckpt))
@classmethod
def from_pretrained(cls, name="vgg_lpips"):
if name != "vgg_lpips":
raise NotImplementedError
model = cls()
ckpt = get_ckpt_path(name)
model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
return model
def forward(self, input, target):
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
outs0, outs1 = self.net(in0_input), self.net(in1_input)
feats0, feats1, diffs = {}, {}, {}
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
style_loss = torch.tensor([0.0]).to(input.device)
for kk in range(len(self.chns)):
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
if self.style_weight > 0.:
style_loss = style_loss + torch.mean((self._gram_mat(feats0[kk]) -
self._gram_mat(feats1[kk])) ** 2)
res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
val = res[0]
for l in range(1, len(self.chns)):
val += res[l]
return val, style_loss * self.style_weight
def _gram_mat(self, x):
"""Calculate Gram matrix.
Args:
x (torch.Tensor): Tensor with shape of (n, c, h, w).
Returns:
torch.Tensor: Gram matrix.
"""
n, c, h, w = x.size()
features = x.view(n, c, w * h)
features_t = features.transpose(1, 2)
gram = features.bmm(features_t) / (c * h * w)
return gram
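A hand-checkable instance of the normalization used by `_gram_mat` (features flattened to `(n, c, h*w)`, then `F F^T / (c*h*w)`):

```python
import torch

# For an all-ones input with c=2, h=w=2: each entry of F F^T is h*w = 4,
# and dividing by c*h*w = 8 gives a Gram matrix of constant 0.5.
x = torch.ones(1, 2, 2, 2)
f = x.view(1, 2, 4)
gram = f.bmm(f.transpose(1, 2)) / (2 * 2 * 2)
```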
class ScalingLayer(nn.Module):
def __init__(self):
super(ScalingLayer, self).__init__()
self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
def forward(self, inp):
return (inp - self.shift) / self.scale
class NetLinLayer(nn.Module):
""" A single linear layer which does a 1x1 conv """
def __init__(self, chn_in, chn_out=1, use_dropout=False):
super(NetLinLayer, self).__init__()
layers = [nn.Dropout(), ] if (use_dropout) else []
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
self.model = nn.Sequential(*layers)
class vgg16(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(vgg16, self).__init__()
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(23, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
return out
def normalize_tensor(x,eps=1e-10):
norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
return x/(norm_factor+eps)
def spatial_average(x, keepdim=True):
return x.mean([2,3],keepdim=keepdim)
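`normalize_tensor` divides each spatial position by its channel-wise L2 norm, so the per-position channel norm of the result is approximately 1; a quick self-contained check:

```python
import torch

# Self-contained copy of normalize_tensor, verified on a random input:
# after normalization, the L2 norm over the channel dimension is ~1.
def normalize_tensor(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
    return x / (norm_factor + eps)

torch.manual_seed(0)
x = torch.randn(2, 8, 4, 4)
y = normalize_tensor(x)
norms = torch.sqrt((y ** 2).sum(dim=1))
```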
================================================
FILE: RestoreFormer/modules/losses/vqperceptual.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from RestoreFormer.modules.losses.lpips import LPIPS
from RestoreFormer.modules.discriminator.model import NLayerDiscriminator, weights_init
from RestoreFormer.modules.vqvae.facial_component_discriminator import FacialComponentDiscriminator
from basicsr.losses.losses import GANLoss, L1Loss
from RestoreFormer.modules.vqvae.arcface_arch import ResNetArcFace
class DummyLoss(nn.Module):
def __init__(self):
super().__init__()
def adopt_weight(weight, global_step, threshold=0, value=0.):
if global_step < threshold:
weight = value
return weight
def hinge_d_loss(logits_real, logits_fake):
loss_real = torch.mean(F.relu(1. - logits_real))
loss_fake = torch.mean(F.relu(1. + logits_fake))
d_loss = 0.5 * (loss_real + loss_fake)
return d_loss
def vanilla_d_loss(logits_real, logits_fake):
d_loss = 0.5 * (
torch.mean(torch.nn.functional.softplus(-logits_real)) +
torch.mean(torch.nn.functional.softplus(logits_fake)))
return d_loss
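The loss helpers above can be exercised on hand-picked logits; self-contained copies, with values chosen so the behavior is exact (a well-separated real/fake pair satisfies both hinge margins, and `adopt_weight` zeroes the weight before the warm-up threshold):

```python
import torch
import torch.nn.functional as F

def adopt_weight(weight, global_step, threshold=0, value=0.):
    return value if global_step < threshold else weight

def hinge_d_loss(logits_real, logits_fake):
    return 0.5 * (torch.mean(F.relu(1. - logits_real)) +
                  torch.mean(F.relu(1. + logits_fake)))

real = torch.tensor([2.0])
fake = torch.tensor([-2.0])
h = hinge_d_loss(real, fake)                      # both margins met
w_early = adopt_weight(0.8, global_step=500, threshold=1000)
w_late = adopt_weight(0.8, global_step=1500, threshold=1000)
```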
class VQLPIPSWithDiscriminatorWithCompWithIdentity(nn.Module):
def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
perceptual_weight=1.0, use_actnorm=False,
disc_ndf=64, disc_loss="hinge", comp_weight=0.0, comp_style_weight=0.0,
identity_weight=0.0, comp_disc_loss='vanilla', lpips_style_weight=0.0,
identity_model_path=None, **ignore_kwargs):
super().__init__()
assert disc_loss in ["hinge", "vanilla"]
self.codebook_weight = codebook_weight
self.pixel_weight = pixelloss_weight
self.perceptual_loss = LPIPS(style_weight=lpips_style_weight).eval()
self.perceptual_weight = perceptual_weight
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
n_layers=disc_num_layers,
use_actnorm=use_actnorm,
ndf=disc_ndf
).apply(weights_init)
if comp_weight > 0:
self.net_d_left_eye = FacialComponentDiscriminator()
self.net_d_right_eye = FacialComponentDiscriminator()
self.net_d_mouth = FacialComponentDiscriminator()
print('Using facial component discriminators')
self.cri_component = GANLoss(gan_type=comp_disc_loss,
real_label_val=1.0,
fake_label_val=0.0,
loss_weight=comp_weight)
if comp_style_weight > 0.:
self.cri_style = L1Loss(loss_weight=comp_style_weight, reduction='mean')
if identity_weight > 0:
self.identity = ResNetArcFace(block = 'IRBlock',
layers = [2, 2, 2, 2],
use_se = False)
print('Using identity loss')
if identity_model_path is not None:
sd = torch.load(identity_model_path, map_location="cpu")
for k, v in deepcopy(sd).items():
if k.startswith('module.'):
sd[k[7:]] = v
sd.pop(k)
self.identity.load_state_dict(sd, strict=True)
for param in self.identity.parameters():
param.requires_grad = False
self.cri_identity = L1Loss(loss_weight=identity_weight, reduction='mean')
self.discriminator_iter_start = disc_start
if disc_loss == "hinge":
self.disc_loss = hinge_d_loss
elif disc_loss == "vanilla":
self.disc_loss = vanilla_d_loss
else:
raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
print(f"VQLPIPSWithDiscriminatorWithCompWithIdentity running with {disc_loss} loss.")
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.comp_weight = comp_weight
self.comp_style_weight = comp_style_weight
self.identity_weight = identity_weight
self.lpips_style_weight = lpips_style_weight
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
else:
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
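A minimal sketch of the adaptive-weight computation above: the discriminator term is rescaled by the ratio of gradient norms at the last layer. With `nll = 2*sum(w)` and `g = sum(w)` the gradient-norm ratio, and hence `d_weight` (for `discriminator_weight = 1`), is ~2:

```python
import torch

# Toy "last layer" whose gradients are known in closed form:
# d(nll)/dw = 2 everywhere, d(g)/dw = 1 everywhere, so the norm ratio is 2.
w = torch.ones(3, requires_grad=True)
nll_loss = (2 * w).sum()
g_loss = w.sum()
nll_grads = torch.autograd.grad(nll_loss, w, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, w, retain_graph=True)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
```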
def _gram_mat(self, x):
"""Calculate Gram matrix.
Args:
x (torch.Tensor): Tensor with shape of (n, c, h, w).
Returns:
torch.Tensor: Gram matrix.
"""
n, c, h, w = x.size()
features = x.view(n, c, w * h)
features_t = features.transpose(1, 2)
gram = features.bmm(features_t) / (c * h * w)
return gram
def gray_resize_for_identity(self, out, size=128):
out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
out_gray = out_gray.unsqueeze(1)
out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
return out_gray
def forward(self, codebook_loss, gts, reconstructions, components, optimizer_idx,
global_step, last_layer=None, split="train"):
# now the GAN part
if optimizer_idx == 0:
rec_loss = (torch.abs(gts.contiguous() - reconstructions.contiguous())) * self.pixel_weight
if self.perceptual_weight > 0:
p_loss, p_style_loss = self.perceptual_loss(gts.contiguous(), reconstructions.contiguous())
rec_loss = rec_loss + self.perceptual_weight * p_loss
else:
p_loss = torch.tensor([0.0])
p_style_loss = torch.tensor([0.0])
nll_loss = rec_loss
#nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
nll_loss = torch.mean(nll_loss)
# generator update
logits_fake = self.discriminator(reconstructions.contiguous())
g_loss = -torch.mean(logits_fake)
try:
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
except RuntimeError:
assert not self.training
d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() + p_style_loss
log = {
"{}/quant_loss".format(split): codebook_loss.detach().mean(),
"{}/nll_loss".format(split): nll_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(),
"{}/p_loss".format(split): p_loss.detach().mean(),
"{}/p_style_loss".format(split): p_style_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
"{}/g_loss".format(split): g_loss.detach().mean(),
}
if self.comp_weight > 0. and components is not None and self.discriminator_iter_start < global_step:
fake_left_eye, fake_left_eye_feats = self.net_d_left_eye(components['left_eyes'], return_feats=True)
comp_g_loss = self.cri_component(fake_left_eye, True, is_disc=False)
loss = loss + comp_g_loss
log["{}/g_left_loss".format(split)] = comp_g_loss.detach()
fake_right_eye, fake_right_eye_feats = self.net_d_right_eye(components['right_eyes'], return_feats=True)
comp_g_loss = self.cri_component(fake_right_eye, True, is_disc=False)
loss = loss + comp_g_loss
log["{}/g_right_loss".format(split)] = comp_g_loss.detach()
fake_mouth, fake_mouth_feats = self.net_d_mouth(components['mouths'], return_feats=True)
comp_g_loss = self.cri_component(fake_mouth, True, is_disc=False)
loss = loss + comp_g_loss
log["{}/g_mouth_loss".format(split)] = comp_g_loss.detach()
if self.comp_style_weight > 0.:
_, real_left_eye_feats = self.net_d_left_eye(components['left_eyes_gt'], return_feats=True)
_, real_right_eye_feats = self.net_d_right_eye(components['right_eyes_gt'], return_feats=True)
_, real_mouth_feats = self.net_d_mouth(components['mouths_gt'], return_feats=True)
def _comp_style(feat, feat_gt, criterion):
return criterion(self._gram_mat(feat[0]), self._gram_mat(
feat_gt[0].detach())) * 0.5 + criterion(self._gram_mat(
feat[1]), self._gram_mat(feat_gt[1].detach()))
comp_style_loss = 0.
comp_style_loss = comp_style_loss + _comp_style(fake_left_eye_feats, real_left_eye_feats, self.cri_style)
comp_style_loss = comp_style_loss + _comp_style(fake_right_eye_feats, real_right_eye_feats, self.cri_style)
comp_style_loss = comp_style_loss + _comp_style(fake_mouth_feats, real_mouth_feats, self.cri_style)
loss = loss + comp_style_loss
log["{}/comp_style_loss".format(split)] = comp_style_loss.detach()
if self.identity_weight > 0. and self.discriminator_iter_start < global_step:
self.identity.eval()
out_gray = self.gray_resize_for_identity(reconstructions)
gt_gray = self.gray_resize_for_identity(gts)
identity_gt = self.identity(gt_gray).detach()
identity_out = self.identity(out_gray)
identity_loss = self.cri_identity(identity_out, identity_gt)
loss = loss + identity_loss
log["{}/identity_loss".format(split)] = identity_loss.detach()
log["{}/total_loss".format(split)] = loss.clone().detach().mean()
return loss, log
if optimizer_idx == 1:
# second pass for discriminator update
logits_real = self.discriminator(gts.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
"{}/logits_real".format(split): logits_real.detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean()
}
return d_loss, log
# left eye
if optimizer_idx == 2:
# third pass for discriminator update
fake_d_pred, _ = self.net_d_left_eye(components['left_eyes'].detach())
real_d_pred, _ = self.net_d_left_eye(components['left_eyes_gt'])
d_loss = self.cri_component(real_d_pred, True, is_disc=True) + self.cri_component(fake_d_pred, False, is_disc=True)
log = {"{}/d_left_loss".format(split): d_loss.clone().detach().mean()}
return d_loss, log
# right eye
if optimizer_idx == 3:
# fourth pass for discriminator update
fake_d_pred, _ = self.net_d_right_eye(components['right_eyes'].detach())
real_d_pred, _ = self.net_d_right_eye(components['right_eyes_gt'])
d_loss = self.cri_component(real_d_pred, True, is_disc=True) + self.cri_component(fake_d_pred, False, is_disc=True)
log = {"{}/d_right_loss".format(split): d_loss.clone().detach().mean()}
return d_loss, log
# mouth
if optimizer_idx == 4:
# fifth pass for discriminator update
fake_d_pred, _ = self.net_d_mouth(components['mouths'].detach())
real_d_pred, _ = self.net_d_mouth(components['mouths_gt'])
d_loss = self.cri_component(real_d_pred, True, is_disc=True) + self.cri_component(fake_d_pred, False, is_disc=True)
log = {"{}/d_mouth_loss".format(split): d_loss.clone().detach().mean()}
return d_loss, log
================================================
FILE: RestoreFormer/modules/util.py
================================================
import torch
import torch.nn as nn
def count_params(model):
total_params = sum(p.numel() for p in model.parameters())
return total_params
class ActNorm(nn.Module):
def __init__(self, num_features, logdet=False, affine=True,
allow_reverse_init=False):
assert affine
super().__init__()
self.logdet = logdet
self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
self.allow_reverse_init = allow_reverse_init
self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
def initialize(self, input):
with torch.no_grad():
flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
mean = (
flatten.mean(1)
.unsqueeze(1)
.unsqueeze(2)
.unsqueeze(3)
.permute(1, 0, 2, 3)
)
std = (
flatten.std(1)
.unsqueeze(1)
.unsqueeze(2)
.unsqueeze(3)
.permute(1, 0, 2, 3)
)
self.loc.data.copy_(-mean)
self.scale.data.copy_(1 / (std + 1e-6))
def forward(self, input, reverse=False):
if reverse:
return self.reverse(input)
if len(input.shape) == 2:
input = input[:,:,None,None]
squeeze = True
else:
squeeze = False
_, _, height, width = input.shape
if self.training and self.initialized.item() == 0:
self.initialize(input)
self.initialized.fill_(1)
h = self.scale * (input + self.loc)
if squeeze:
h = h.squeeze(-1).squeeze(-1)
if self.logdet:
log_abs = torch.log(torch.abs(self.scale))
logdet = height*width*torch.sum(log_abs)
logdet = logdet * torch.ones(input.shape[0]).to(input)
return h, logdet
return h
def reverse(self, output):
if self.training and self.initialized.item() == 0:
if not self.allow_reverse_init:
raise RuntimeError(
"Initializing ActNorm in reverse direction is "
"disabled by default. Use allow_reverse_init=True to enable."
)
else:
self.initialize(output)
self.initialized.fill_(1)
if len(output.shape) == 2:
output = output[:,:,None,None]
squeeze = True
else:
squeeze = False
h = output / self.scale - self.loc
if squeeze:
h = h.squeeze(-1).squeeze(-1)
return h
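Once initialized, the ActNorm affine pair above is exactly invertible: forward applies `h = scale * (x + loc)` and `reverse` undoes it with `x = h / scale - loc`. A standalone check of that round trip (plain tensors standing in for the module's parameters):

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 4, 8, 8)
loc = torch.randn(1, 4, 1, 1)
scale = torch.rand(1, 4, 1, 1) + 0.5   # keep scale away from zero

h = scale * (x + loc)                  # ActNorm.forward
x_back = h / scale - loc               # ActNorm.reverse
```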
class Attention2DConv(nn.Module):
"""to replace the convolutional architecture entirely"""
def __init__(self):
super().__init__()
================================================
FILE: RestoreFormer/modules/vqvae/arcface_arch.py
================================================
import torch.nn as nn
from basicsr.utils.registry import ARCH_REGISTRY
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class IRBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
super(IRBlock, self).__init__()
self.bn0 = nn.BatchNorm2d(inplanes)
self.conv1 = conv3x3(inplanes, inplanes)
self.bn1 = nn.BatchNorm2d(inplanes)
self.prelu = nn.PReLU()
self.conv2 = conv3x3(inplanes, planes, stride)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
self.use_se = use_se
if self.use_se:
self.se = SEBlock(planes)
def forward(self, x):
residual = x
out = self.bn0(x)
out = self.conv1(out)
out = self.bn1(out)
out = self.prelu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.use_se:
out = self.se(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.prelu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class SEBlock(nn.Module):
def __init__(self, channel, reduction=16):
super(SEBlock, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
nn.Sigmoid())
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y
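A minimal squeeze-and-excitation gating in the style of `SEBlock` above: global average pool, bottleneck MLP, sigmoid, then channel-wise rescaling. The output keeps the input shape and every gate lies strictly in (0, 1):

```python
import torch
import torch.nn as nn

channel, reduction = 16, 4
fc = nn.Sequential(
    nn.Linear(channel, channel // reduction), nn.PReLU(),
    nn.Linear(channel // reduction, channel), nn.Sigmoid())

x = torch.randn(2, channel, 8, 8)
y = nn.AdaptiveAvgPool2d(1)(x).view(2, channel)   # squeeze
gate = fc(y).view(2, channel, 1, 1)               # excitation
out = x * gate                                    # channel recalibration
```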
@ARCH_REGISTRY.register()
class ResNetArcFace(nn.Module):
def __init__(self, block, layers, use_se=True):
if block == 'IRBlock':
block = IRBlock
self.inplanes = 64
self.use_se = use_se
super(ResNetArcFace, self).__init__()
self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.prelu = nn.PReLU()
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.bn4 = nn.BatchNorm2d(512)
self.dropout = nn.Dropout()
self.fc5 = nn.Linear(512 * 8 * 8, 512)
self.bn5 = nn.BatchNorm1d(512)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.xavier_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0)
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
self.inplanes = planes
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, use_se=self.use_se))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.prelu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.bn4(x)
x = self.dropout(x)
x = x.view(x.size(0), -1)
x = self.fc5(x)
x = self.bn5(x)
return x
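`fc5` expects a flattened input of `512 * 8 * 8`, which pins down the spatial size at the end of `layer4`. Assuming the usual 128x128 grayscale ArcFace input (an assumption; the resolution is not stated in this file), the single 2x2 max-pool plus the three stride-2 stages in `layer2`-`layer4` reduce 128 to exactly 8. A sketch of that downsampling arithmetic:

```python
def out_size(size, stride):
    # stride-1 3x3 conv with padding 1 keeps the size; stride-2 stages halve it
    return size // stride

size = 128                   # assumed ArcFace input resolution (grayscale)
size = out_size(size, 1)     # conv1: stride 1
size = out_size(size, 2)     # 2x2 max-pool
for stride in (1, 2, 2, 2):  # first-block strides of layer1..layer4
    size = out_size(size, stride)
flat = 512 * size * size     # number of features entering fc5
```

With any other input resolution the `512 * 8 * 8` linear layer would raise a shape mismatch, which is a quick way to sanity-check the assumption.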
================================================
FILE: RestoreFormer/modules/vqvae/facial_component_discriminator.py
================================================
import torch
from torch import nn
from basicsr.archs.stylegan2_arch import ConvLayer
from basicsr.utils.registry import ARCH_REGISTRY
@ARCH_REGISTRY.register()
class FacialComponentDiscriminator(nn.Module):
def __init__(self):
super(FacialComponentDiscriminator, self).__init__()
self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
self.conv4 = ConvLayer(128, 256, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
self.conv5 = ConvLayer(256, 256, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)
def forward(self, x, return_feats=False):
feat = self.conv1(x)
feat = self.conv3(self.conv2(feat))
rlt_feats = []
if return_feats:
rlt_feats.append(feat.clone())
feat = self.conv5(self.conv4(feat))
if return_feats:
rlt_feats.append(feat.clone())
out = self.final_conv(feat)
if return_feats:
return out, rlt_feats
else:
return out, None
================================================
FILE: RestoreFormer/modules/vqvae/utils.py
================================================
from torchvision.ops import roi_align
import torch
def get_roi_regions(gt, output, loc_left_eyes, loc_right_eyes, loc_mouths,
face_ratio=1, eye_out_size=80, mouth_out_size=120):
# hard-coded base crop sizes, scaled by face_ratio
eye_out_size *= face_ratio
mouth_out_size *= face_ratio
eye_out_size = int(eye_out_size)
mouth_out_size = int(mouth_out_size)
rois_eyes = []
rois_mouths = []
for b in range(loc_left_eyes.size(0)): # loop for batch size
# left eye and right eye
img_inds = loc_left_eyes.new_full((2, 1), b)
bbox = torch.stack([loc_left_eyes[b, :], loc_right_eyes[b, :]], dim=0) # shape: (2, 4)
rois = torch.cat([img_inds, bbox], dim=-1) # shape: (2, 5)
rois_eyes.append(rois)
# mouth
img_inds = loc_left_eyes.new_full((1, 1), b)
rois = torch.cat([img_inds, loc_mouths[b:b + 1, :]], dim=-1) # shape: (1, 5)
rois_mouths.append(rois)
rois_eyes = torch.cat(rois_eyes, 0)
rois_mouths = torch.cat(rois_mouths, 0)
# real images
all_eyes = roi_align(gt, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
left_eyes_gt = all_eyes[0::2, :, :, :]
right_eyes_gt = all_eyes[1::2, :, :, :]
mouths_gt = roi_align(gt, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
# output
all_eyes = roi_align(output, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
left_eyes = all_eyes[0::2, :, :, :]
right_eyes = all_eyes[1::2, :, :, :]
mouths = roi_align(output, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
return {'left_eyes_gt': left_eyes_gt, 'right_eyes_gt': right_eyes_gt, 'mouths_gt': mouths_gt,
'left_eyes': left_eyes, 'right_eyes': right_eyes, 'mouths': mouths}
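`roi_align` expects its boxes as an (N, 5) tensor whose rows are `[batch_index, x1, y1, x2, y2]`. The loop above interleaves the left-eye and right-eye rows per image, which is why `all_eyes[0::2]` and `all_eyes[1::2]` later recover the two eyes. A stdlib sketch of that row layout (names are illustrative):

```python
def build_eye_rois(left_eyes, right_eyes):
    """Interleave per-image eye boxes into [batch_idx, x1, y1, x2, y2] rows."""
    rois = []
    for b, (le, re) in enumerate(zip(left_eyes, right_eyes)):
        rois.append([float(b)] + list(le))  # even rows: left eyes
        rois.append([float(b)] + list(re))  # odd rows: right eyes
    return rois

left = [[10, 20, 90, 100], [12, 22, 92, 102]]
right = [[110, 20, 190, 100], [112, 22, 192, 102]]
rois = build_eye_rois(left, right)
lefts = rois[0::2]   # mirrors all_eyes[0::2] above
rights = rois[1::2]  # mirrors all_eyes[1::2]
```

Each crop therefore carries its source image's batch index in column 0, so one `roi_align` call can extract regions from the whole batch at once.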
================================================
FILE: RestoreFormer/modules/vqvae/vqvae_arch.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# from basicsr.utils.registry import ARCH_REGISTRY
class VectorQuantizer(nn.Module):
"""
see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
____________________________________________
Discretization bottleneck part of the VQ-VAE.
Inputs:
- n_e : number of embeddings
- e_dim : dimension of embedding
- beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
_____________________________________________
"""
def __init__(self, n_e, e_dim, beta):
super(VectorQuantizer, self).__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta
self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
def forward(self, z):
"""
Inputs the output of the encoder network z and maps it to a discrete
one-hot vector that is the index of the closest embedding vector e_j
z (continuous) -> z_q (discrete)
z.shape = (batch, channel, height, width)
quantization pipeline:
1. get encoder input (B,C,H,W)
2. flatten input to (B*H*W,C)
"""
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.e_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
torch.sum(self.embedding.weight**2, dim=1) - 2 * \
torch.matmul(z_flattened, self.embedding.weight.t())
## could possibly replace this here
# #\start...
# find closest encodings
_, min_encoding_indices = torch.min(d, dim=1)
min_encoding_indices = min_encoding_indices.unsqueeze(1)
min_encodings = torch.zeros(
min_encoding_indices.shape[0], self.n_e).to(z)
min_encodings.scatter_(1, min_encoding_indices, 1)
# dtype min encodings: torch.float32
# min_encodings shape: torch.Size([2048, 512])
# min_encoding_indices.shape: torch.Size([2048, 1])
# get quantized latent vectors
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
#.........\end
# with:
# .........\start
#min_encoding_indices = torch.argmin(d, dim=1)
#z_q = self.embedding(min_encoding_indices)
# ......\end......... (TODO)
# compute loss for embedding
loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
torch.mean((z_q - z.detach()) ** 2)
# preserve gradients
z_q = z + (z_q - z).detach()
# perplexity
e_mean = torch.mean(min_encodings, dim=0)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q, loss, (perplexity, min_encodings, min_encoding_indices, d)
def get_codebook_entry(self, indices, shape):
# shape specifying (batch, height, width, channel)
# TODO: check for more easy handling with nn.Embedding
min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
min_encodings.scatter_(1, indices[:,None], 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
if shape is not None:
z_q = z_q.view(shape)
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q
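The distance computation in `forward` uses the algebraic expansion `||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e`, which lets a single matmul score every codebook entry at once. A stdlib sketch showing that the expanded form picks the same nearest code as the direct distance (toy codebook, illustrative values):

```python
def nearest_code(z, codebook):
    """Index of the nearest codebook entry via the expanded squared distance."""
    z_sq = sum(v * v for v in z)
    best, best_d = 0, float("inf")
    for i, e in enumerate(codebook):
        e_sq = sum(v * v for v in e)
        dot = sum(a * b for a, b in zip(z, e))
        d = z_sq + e_sq - 2.0 * dot  # equals sum((z_k - e_k)**2)
        if d < best_d:
            best, best_d = i, d
    return best

codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]]
idx = nearest_code([0.9, 1.2], codebook)

# cross-check against the direct squared distance
direct = min(range(len(codebook)),
             key=lambda i: sum((a - b) ** 2 for a, b in zip([0.9, 1.2], codebook[i])))
```

In the module itself this argmin is realized with `torch.min(d, dim=1)` followed by a one-hot `scatter_` and a matmul against `embedding.weight`.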
# pytorch_diffusion + derived encoder decoder
def nonlinearity(x):
# swish
return x*torch.sigmoid(x)
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=2,
padding=0)
def forward(self, x):
if self.with_conv:
pad = (0,1,0,1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
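The asymmetric pad `(0, 1, 0, 1)` in `Downsample` pads only the right and bottom edges, so the stride-2 3x3 conv halves an even input size exactly: `out = (H + 1 - 3) // 2 + 1 = H // 2`. A short sketch of that arithmetic for the 512-resolution configs:

```python
def downsample_size(h, pad_total=1, kernel=3, stride=2):
    # standard conv output formula, with 1 pixel of (right/bottom-only) padding
    return (h + pad_total - kernel) // stride + 1

sizes = [512]
while sizes[-1] > 16:
    sizes.append(downsample_size(sizes[-1]))
```

Symmetric padding of 1 on both sides would instead give `(H + 2 - 3) // 2 + 1`, which over-shoots by one for even inputs; hence the manual one-sided pad.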
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
dropout, temb_channels=512):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels,
out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x+h
class MultiHeadAttnBlock(nn.Module):
def __init__(self, in_channels, head_size=1):
super().__init__()
self.in_channels = in_channels
self.head_size = head_size
self.att_size = in_channels // head_size
assert in_channels % head_size == 0, 'in_channels must be divisible by head_size.'
self.norm1 = Normalize(in_channels)
self.norm2 = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.num = 0
def forward(self, x, y=None):
h_ = x
h_ = self.norm1(h_)
if y is None:
y = h_
else:
y = self.norm2(y)
q = self.q(y)
k = self.k(h_)
v = self.v(h_)
# compute attention
b,c,h,w = q.shape
q = q.reshape(b, self.head_size, self.att_size ,h*w)
q = q.permute(0, 3, 1, 2) # b, hw, head, att
k = k.reshape(b, self.head_size, self.att_size ,h*w)
k = k.permute(0, 3, 1, 2)
v = v.reshape(b, self.head_size, self.att_size ,h*w)
v = v.permute(0, 3, 1, 2)
q = q.transpose(1, 2)
v = v.transpose(1, 2)
k = k.transpose(1, 2).transpose(2,3)
scale = int(self.att_size)**(-0.5)
q.mul_(scale)
w_ = torch.matmul(q, k)
w_ = F.softmax(w_, dim=3)
w_ = w_.matmul(v)
w_ = w_.transpose(1, 2).contiguous() # [b, h*w, head, att]
w_ = w_.view(b, h, w, -1)
w_ = w_.permute(0, 3, 1, 2)
w_ = self.proj_out(w_)
return x+w_
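`MultiHeadAttnBlock.forward` reshapes q, k, v to `(batch, heads, h*w, att_size)` and applies scaled dot-product attention per head, with the `att_size**-0.5` scale folded into q. A stdlib sketch of the single-head core (a few "pixels" stand in for the `h*w` spatial tokens; names are illustrative):

```python
import math

def attention(q, k, v):
    """Scaled dot-product attention for one head; q, k, v are [tokens][dim]."""
    scale = len(k[0]) ** -0.5
    out, attn = [], []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)  # subtract the max for a numerically stable softmax
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        attn.append(weights)
        out.append([sum(w * vj[d] for w, vj in zip(weights, v))
                    for d in range(len(v[0]))])
    return out, attn

# one query attending over two key/value "pixels"
q = [[1.0, 0.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[10.0, 0.0], [0.0, 10.0]]
out, attn = attention(q, k, v)
```

Note the cross-attention twist in the block above: when `y` is given, the queries come from `y` (the decoder feature) while keys and values come from `x`, which is how the transformer decoder fuses encoder features with the quantized dictionary.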
class MultiHeadEncoder(nn.Module):
def __init__(self, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks=2,
attn_resolutions=[16], dropout=0.0, resamp_with_conv=True, in_channels=3,
resolution=512, z_channels=256, double_z=True, enable_mid=True,
head_size=1, **ignore_kwargs):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.enable_mid = enable_mid
# downsampling
self.conv_in = torch.nn.Conv2d(in_channels,
self.ch,
kernel_size=3,
stride=1,
padding=1)
curr_res = resolution
in_ch_mult = (1,)+tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch*in_ch_mult[i_level]
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(MultiHeadAttnBlock(block_in, head_size))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions-1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
if self.enable_mid:
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
2*z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
#assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
hs = {}
# timestep embedding
temb = None
# downsampling
h = self.conv_in(x)
hs['in'] = h
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h, temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions-1:
# hs.append(h)
hs['block_'+str(i_level)] = h
h = self.down[i_level].downsample(h)
# middle
# h = hs[-1]
if self.enable_mid:
h = self.mid.block_1(h, temb)
hs['block_'+str(i_level)+'_atten'] = h
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
hs['mid_atten'] = h
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
# hs.append(h)
hs['out'] = h
return hs
class MultiHeadDecoder(nn.Module):
def __init__(self, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks=2,
attn_resolutions=(16,), dropout=0.0, resamp_with_conv=True, in_channels=3,
resolution=512, z_channels=256, give_pre_end=False, enable_mid=True,
head_size=1, **ignorekwargs):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.enable_mid = enable_mid
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,)+tuple(ch_mult)
block_in = ch*ch_mult[self.num_resolutions-1]
curr_res = resolution // 2**(self.num_resolutions-1)
self.z_shape = (1,z_channels,curr_res,curr_res)
print("Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = torch.nn.Conv2d(z_channels,
block_in,
kernel_size=3,
stride=1,
padding=1)
# middle
if self.enable_mid:
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks+1):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(MultiHeadAttnBlock(block_in, head_size))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
out_ch,
kernel_size=3,
stride=1,
padding=1)
def forward(self, z):
#assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# middle
if self.enable_mid:
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class MultiHeadDecoderTransformer(nn.Module):
def __init__(self, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks=2,
attn_resolutions=(16,), dropout=0.0, resamp_with_conv=True, in_channels=3,
resolution=512, z_channels=256, give_pre_end=False, enable_mid=True,
head_size=1, **ignorekwargs):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.enable_mid = enable_mid
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,)+tuple(ch_mult)
block_in = ch*ch_mult[self.num_resolutions-1]
curr_res = resolution // 2**(self.num_resolutions-1)
self.z_shape = (1,z_channels,curr_res,curr_res)
print("Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = torch.nn.Conv2d(z_channels,
block_in,
kernel_size=3,
stride=1,
padding=1)
# middle
if self.enable_mid:
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks+1):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(MultiHeadAttnBlock(block_in, head_size))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
out_ch,
kernel_size=3,
stride=1,
padding=1)
def forward(self, z, hs):
#assert z.shape[1:] == self.z_shape[1:]
# self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# middle
if self.enable_mid:
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h, hs['mid_atten'])
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, hs['block_'+str(i_level)+'_atten'])
# hfeature = h.clone()
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class VQVAEGAN(nn.Module):
def __init__(self, n_embed=1024, embed_dim=256, ch=128, out_ch=3, ch_mult=(1,2,4,8),
num_res_blocks=2, attn_resolutions=(16,), dropout=0.0, in_channels=3,
resolution=512, z_channels=256, double_z=False, enable_mid=True,
fix_decoder=False, fix_codebook=False, head_size=1, **ignore_kwargs):
super(VQVAEGAN, self).__init__()
self.encoder = MultiHeadEncoder(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolutions, dropout=dropout, in_channels=in_channels,
resolution=resolution, z_channels=z_channels, double_z=double_z,
enable_mid=enable_mid, head_size=head_size)
self.decoder = MultiHeadDecoder(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolutions, dropout=dropout, in_channels=in_channels,
resolution=resolution, z_channels=z_channels, enable_mid=enable_mid, head_size=head_size)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)
self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
if fix_decoder:
for _, param in self.decoder.named_parameters():
param.requires_grad = False
for _, param in self.post_quant_conv.named_parameters():
param.requires_grad = False
for _, param in self.quantize.named_parameters():
param.requires_grad = False
elif fix_codebook:
for _, param in self.quantize.named_parameters():
param.requires_grad = False
def encode(self, x):
hs = self.encoder(x)
h = self.quant_conv(hs['out'])
quant, emb_loss, info = self.quantize(h)
return quant, emb_loss, info, hs
def decode(self, quant):
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
def forward(self, input):
quant, diff, info, hs = self.encode(input)
dec = self.decode(quant)
return dec, diff, info, hs
class VQVAEGANMultiHeadTransformer(nn.Module):
def __init__(self, n_embed=1024, embed_dim=256, ch=128, out_ch=3, ch_mult=(1,2,4,8),
num_res_blocks=2, attn_resolutions=(16,), dropout=0.0, in_channels=3,
resolution=512, z_channels=256, double_z=False, enable_mid=True,
fix_decoder=False, fix_codebook=False, fix_encoder=False, contrastive_learning_loss_weight=0.0,
head_size=1):
super(VQVAEGANMultiHeadTransformer, self).__init__()
self.encoder = MultiHeadEncoder(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolutions, dropout=dropout, in_channels=in_channels,
resolution=resolution, z_channels=z_channels, double_z=double_z,
enable_mid=enable_mid, head_size=head_size)
self.decoder = MultiHeadDecoderTransformer(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolutions, dropout=dropout, in_channels=in_channels,
resolution=resolution, z_channels=z_channels, enable_mid=enable_mid, head_size=head_size)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)
self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
if fix_decoder:
for _, param in self.decoder.named_parameters():
param.requires_grad = False
for _, param in self.post_quant_conv.named_parameters():
param.requires_grad = False
for _, param in self.quantize.named_parameters():
param.requires_grad = False
elif fix_codebook:
for _, param in self.quantize.named_parameters():
param.requires_grad = False
if fix_encoder:
for _, param in self.encoder.named_parameters():
param.requires_grad = False
def encode(self, x):
hs = self.encoder(x)
h = self.quant_conv(hs['out'])
quant, emb_loss, info = self.quantize(h)
return quant, emb_loss, info, hs
def decode(self, quant, hs):
quant = self.post_quant_conv(quant)
dec = self.decoder(quant, hs)
return dec
def forward(self, input):
quant, diff, info, hs = self.encode(input)
dec = self.decode(quant, hs)
return dec, diff, info, hs
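The encoder/decoder pair above is driven by `ch_mult` and `attn_resolutions`: the encoder halves the spatial resolution after every level except the last, and attaches `MultiHeadAttnBlock`s only at levels whose input resolution appears in `attn_resolutions`. With the configs' `resolution: 512`, `ch_mult: [1,2,2,4,4,8]`, and `attn_resolutions: [16]`, a sketch of that schedule:

```python
def encoder_resolutions(resolution, ch_mult):
    """Input resolution seen by each encoder level; halved after all but the last."""
    res, per_level = resolution, []
    for i in range(len(ch_mult)):
        per_level.append(res)
        if i != len(ch_mult) - 1:
            res //= 2
    return per_level

per_level = encoder_resolutions(512, [1, 2, 2, 4, 4, 8])
attn_levels = [i for i, r in enumerate(per_level) if r in [16]]  # attn_resolutions
```

Only the deepest level sees resolution 16, so attention (and the `block_*_atten` features consumed by `MultiHeadDecoderTransformer`) lives at the 16x16 bottleneck.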
================================================
FILE: RestoreFormer/util.py
================================================
import os, hashlib
import requests
from tqdm import tqdm
URL_MAP = {
"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}
CKPT_MAP = {
"vgg_lpips": "vgg.pth"
}
MD5_MAP = {
"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}
def download(url, local_path, chunk_size=1024):
os.makedirs(os.path.split(local_path)[0], exist_ok=True)
with requests.get(url, stream=True) as r:
total_size = int(r.headers.get("content-length", 0))
with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
with open(local_path, "wb") as f:
for data in r.iter_content(chunk_size=chunk_size):
if data:
f.write(data)
pbar.update(len(data))
def md5_hash(path):
with open(path, "rb") as f:
content = f.read()
return hashlib.md5(content).hexdigest()
def get_ckpt_path(name, root, check=False):
assert name in URL_MAP
path = os.path.join(root, CKPT_MAP[name])
if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
download(URL_MAP[name], path)
md5 = md5_hash(path)
assert md5 == MD5_MAP[name], md5
return path
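`get_ckpt_path` gates the downloaded checkpoint on an MD5 digest and re-downloads on mismatch. The same check, sketched on a throwaway temp file with only the standard library:

```python
import hashlib
import os
import tempfile

def md5_of(path):
    """MD5 hex digest of a file's contents, as in md5_hash() above."""
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()

fd, path = tempfile.mkstemp()
try:
    with os.fdopen(fd, "wb") as f:
        f.write(b"hello")
    digest = md5_of(path)
finally:
    os.remove(path)
```

Reading the whole file into memory is fine for a ~500 MB VGG checkpoint but would be worth chunking (`hashlib.md5()` with repeated `update()`) for much larger files.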
class KeyNotFoundError(Exception):
def __init__(self, cause, keys=None, visited=None):
self.cause = cause
self.keys = keys
self.visited = visited
messages = list()
if keys is not None:
messages.append("Key not found: {}".format(keys))
if visited is not None:
messages.append("Visited: {}".format(visited))
messages.append("Cause:\n{}".format(cause))
message = "\n".join(messages)
super().__init__(message)
def retrieve(
list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
"""Given a nested list or dict return the desired value at key expanding
callable nodes if necessary and :attr:`expand` is ``True``. The expansion
is done in-place.
Parameters
----------
list_or_dict : list or dict
Possibly nested list or dictionary.
key : str
key/to/value, path like string describing all keys necessary to
consider to get to the desired value. List indices can also be
passed here.
splitval : str
String that defines the delimiter between keys of the
different depth levels in `key`.
default : obj
Value returned if :attr:`key` is not found.
expand : bool
Whether to expand callable nodes on the path or not.
Returns
-------
The desired value or if :attr:`default` is not ``None`` and the
:attr:`key` is not found returns ``default``.
Raises
------
Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
``None``.
"""
keys = key.split(splitval)
success = True
try:
visited = []
parent = None
last_key = None
for key in keys:
if callable(list_or_dict):
if not expand:
raise KeyNotFoundError(
ValueError(
"Trying to get past callable node with expand=False."
),
keys=keys,
visited=visited,
)
list_or_dict = list_or_dict()
parent[last_key] = list_or_dict
last_key = key
parent = list_or_dict
try:
if isinstance(list_or_dict, dict):
list_or_dict = list_or_dict[key]
else:
list_or_dict = list_or_dict[int(key)]
except (KeyError, IndexError, ValueError) as e:
raise KeyNotFoundError(e, keys=keys, visited=visited)
visited += [key]
# final expansion of retrieved value
if expand and callable(list_or_dict):
list_or_dict = list_or_dict()
parent[last_key] = list_or_dict
except KeyNotFoundError as e:
if default is None:
raise e
else:
list_or_dict = default
success = False
if not pass_success:
return list_or_dict
else:
return list_or_dict, success
if __name__ == "__main__":
config = {"keya": "a",
"keyb": "b",
"keyc":
{"cc1": 1,
"cc2": 2,
}
}
from omegaconf import OmegaConf
config = OmegaConf.create(config)
print(config)
retrieve(config, "keya")
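Stripped of the callable-expansion machinery, `retrieve` is a walk over a `/`-separated path through nested dicts and lists. A minimal stdlib sketch of that core (simplified: no `expand`, no custom exception; names are illustrative):

```python
def simple_retrieve(obj, key, splitval="/", default=None):
    """Walk a nested dict/list by a path-like key; simplified retrieve()."""
    node = obj
    for part in key.split(splitval):
        try:
            # dict keys are used as-is; list indices are parsed as ints
            node = node[part] if isinstance(node, dict) else node[int(part)]
        except (KeyError, IndexError, ValueError):
            return default
    return node

cfg = {"keyc": {"cc2": 2, "items": [10, 20]}}
val = simple_retrieve(cfg, "keyc/cc2")
item = simple_retrieve(cfg, "keyc/items/1")
missing = simple_retrieve(cfg, "keyc/nope", default=-1)
```

The real `retrieve` additionally calls any callable node it passes through and writes the result back into the parent container, so lazily-defined config values are materialized in place.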
================================================
FILE: __init__.py
================================================
================================================
FILE: configs/HQ_Dictionary.yaml
================================================
model:
base_learning_rate: 4.5e-6
target: RestoreFormer.models.vqgan_v1.RestoreFormerModel
params:
image_key: 'gt'
schedule_step: [400000, 800000]
# ignore_keys: ['vqvae.quantize.utility_counter']
ddconfig:
target: RestoreFormer.modules.vqvae.vqvae_arch.VQVAEGAN
params:
embed_dim: 256
n_embed: 1024
double_z: False
z_channels: 256
resolution: 512
in_channels: 3
out_ch: 3
ch: 64
ch_mult: [ 1,2,2,4,4,8] # num_down = len(ch_mult)-1
num_res_blocks: 2
attn_resolutions: [16]
dropout: 0.0
enable_mid: True
fix_decoder: False
fix_codebook: False
head_size: 8
lossconfig:
target: RestoreFormer.modules.losses.vqperceptual.VQLPIPSWithDiscriminatorWithCompWithIdentity
params:
disc_conditional: False
disc_in_channels: 3
disc_start: 30001
disc_weight: 0.8
codebook_weight: 1.0
use_actnorm: False
data:
target: main.DataModuleFromConfig
params:
batch_size: 4
num_workers: 8
train:
target: basicsr.data.ffhq_dataset.FFHQDataset
params:
dataroot_gt: data/FFHQ/images512x512
io_backend:
type: disk
use_hflip: True
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
out_size: 512
validation:
target: basicsr.data.ffhq_dataset.FFHQDataset
params:
dataroot_gt: data/FFHQ/images512x512
io_backend:
type: disk
use_hflip: False
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
out_size: 512
================================================
FILE: configs/RestoreFormer.yaml
================================================
model:
base_learning_rate: 4.5e-6
target: RestoreFormer.models.vqgan_v1.RestoreFormerModel
params:
image_key: 'lq'
ckpt_path: 'YOUR TRAINED HD DICTIONARY MODEL'
special_params_lr_scale: 10
comp_params_lr_scale: 10
schedule_step: [4000000, 8000000]
ddconfig:
target: RestoreFormer.modules.vqvae.vqvae_arch.VQVAEGANMultiHeadTransformer
params:
embed_dim: 256
n_embed: 1024
double_z: False
z_channels: 256
resolution: 512
in_channels: 3
out_ch: 3
ch: 64
ch_mult: [ 1,2,2,4,4,8] # num_down = len(ch_mult)-1
num_res_blocks: 2
dropout: 0.0
attn_resolutions: [16]
enable_mid: True
fix_decoder: False
fix_codebook: True
fix_encoder: False
head_size: 8
lossconfig:
target: RestoreFormer.modules.losses.vqperceptual.VQLPIPSWithDiscriminatorWithCompWithIdentity
params:
disc_conditional: False
disc_in_channels: 3
disc_start: 10001
disc_weight: 0.8
codebook_weight: 1.0
use_actnorm: False
comp_weight: 1.5
comp_style_weight: 2e3 #2000.0
identity_weight: 3 #1.5
lpips_style_weight: 1e9
identity_model_path: experiments/pretrained_models/arcface_resnet18.pth
data:
target: main.DataModuleFromConfig
params:
batch_size: 4
num_workers: 8
train:
target: RestoreFormer.data.ffhq_degradation_dataset.FFHQDegradationDataset
params:
dataroot_gt: data/FFHQ/images512x512
io_backend:
type: disk
use_hflip: True
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
out_size: 512
blur_kernel_size: [19,20]
kernel_list: ['iso', 'aniso']
kernel_prob: [0.5, 0.5]
blur_sigma: [0.1, 10]
downsample_range: [0.8, 8]
noise_range: [0, 20]
jpeg_range: [60, 100]
color_jitter_prob: ~
color_jitter_shift: 20
color_jitter_pt_prob: ~
gray_prob: ~
gt_gray: True
crop_components: True
component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth
eye_enlarge_ratio: 1.4
validation:
target: RestoreFormer.data.ffhq_degradation_dataset.FFHQDegradationDataset
params:
dataroot_gt: data/FFHQ/images512x512
io_backend:
type: disk
use_hflip: False
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
out_size: 512
blur_kernel_size: [19,20]
kernel_list: ['iso', 'aniso']
kernel_prob: [0.5, 0.5]
blur_sigma: [0.1, 10]
downsample_range: [0.8, 8]
noise_range: [0, 20]
jpeg_range: [60, 100]
# color jitter and gray
color_jitter_prob: ~
color_jitter_shift: 20
color_jitter_pt_prob: ~
gray_prob: ~
gt_gray: True
crop_components: False
component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth
eye_enlarge_ratio: 1.4
================================================
FILE: main.py
================================================
import argparse, os, sys, datetime, glob, importlib
from omegaconf import OmegaConf
import numpy as np
from PIL import Image
import torch
import torchvision
from torch.utils.data import random_split, DataLoader, Dataset
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
from pytorch_lightning.utilities.distributed import rank_zero_only
import random
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
def get_parser(**parser_kwargs):
def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("Boolean value expected.")
parser = argparse.ArgumentParser(**parser_kwargs)
parser.add_argument(
"-n",
"--name",
type=str,
const=True,
default="",
nargs="?",
help="postfix for logdir",
)
parser.add_argument(
"-r",
"--resume",
type=str,
const=True,
default="",
nargs="?",
help="resume from logdir or checkpoint in logdir",
)
parser.add_argument(
"--pretrain",
type=str,
const=True,
default="",
nargs="?",
help="pretrain with existing weights",
)
parser.add_argument(
"-b",
"--base",
nargs="*",
metavar="base_config.yaml",
help="paths to base configs. Loaded from left-to-right. "
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
default=list(),
)
parser.add_argument(
"-t",
"--train",
type=str2bool,
const=True,
default=False,
nargs="?",
help="train",
)
parser.add_argument(
"--no-test",
type=str2bool,
const=True,
default=False,
nargs="?",
help="disable test",
)
parser.add_argument("-p", "--project", help="name of new or path to existing project")
parser.add_argument(
"-d",
"--debug",
type=str2bool,
nargs="?",
const=True,
default=False,
help="enable post-mortem debugging",
)
parser.add_argument(
"-s",
"--seed",
type=int,
default=23,
help="seed for seed_everything",
)
parser.add_argument(
"--random-seed",
type=str2bool,
nargs="?",
const=True,
default=False,
help="use a random seed (the chosen value is appended to the log directory name)",
)
parser.add_argument(
"-f",
"--postfix",
type=str,
default="",
help="post-postfix for default name",
)
parser.add_argument(
"--root-path",
type=str,
default="./",
help="root path for saving checkpoints and logs"
)
parser.add_argument(
"--num-nodes",
type=int,
default=1,
help="number of gpu nodes",
)
return parser
def nondefault_trainer_args(opt):
parser = argparse.ArgumentParser()
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args([])
return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
def instantiate_from_config(config):
if "target" not in config:
raise KeyError("Expected key `target` to instantiate.")
if 'basicsr.data' in config["target"] or \
'FFHQDegradationDataset' in config["target"]:
return get_obj_from_str(config["target"])(config.get("params", dict()))
return get_obj_from_str(config["target"])(**config.get("params", dict()))
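The `target`/`params` pattern above drives every `instantiate_from_config` call in the configs. A minimal standalone sketch of the same idea, resolved against a stdlib class (`collections.Counter` is only a stand-in for a real module path from the YAML files):

```python
import importlib

def get_obj_from_str(string):
    # Split "package.module.Class" into module path and class name.
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

# A config shaped like the YAML entries: a dotted import path plus kwargs.
config = {"target": "collections.Counter", "params": {"a": 2, "b": 1}}
obj = get_obj_from_str(config["target"])(**config.get("params", {}))
```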
class WrappedDataset(Dataset):
"""Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
def __init__(self, dataset):
self.data = dataset
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
class DataModuleFromConfig(pl.LightningDataModule):
def __init__(self, batch_size, train=None, validation=None, test=None,
wrap=False, num_workers=None):
super().__init__()
self.batch_size = batch_size
self.dataset_configs = dict()
self.num_workers = num_workers if num_workers is not None else batch_size*2
if train is not None:
self.dataset_configs["train"] = train
self.train_dataloader = self._train_dataloader
if validation is not None:
self.dataset_configs["validation"] = validation
self.val_dataloader = self._val_dataloader
if test is not None:
self.dataset_configs["test"] = test
self.test_dataloader = self._test_dataloader
self.wrap = wrap
def prepare_data(self):
for data_cfg in self.dataset_configs.values():
instantiate_from_config(data_cfg)
def setup(self, stage=None):
self.datasets = dict(
(k, instantiate_from_config(self.dataset_configs[k]))
for k in self.dataset_configs)
if self.wrap:
for k in self.datasets:
self.datasets[k] = WrappedDataset(self.datasets[k])
def _train_dataloader(self):
return DataLoader(self.datasets["train"], batch_size=self.batch_size,
num_workers=self.num_workers, shuffle=True)
def _val_dataloader(self):
return DataLoader(self.datasets["validation"],
batch_size=self.batch_size,
num_workers=self.num_workers)
def _test_dataloader(self):
return DataLoader(self.datasets["test"], batch_size=self.batch_size,
num_workers=self.num_workers)
class SetupCallback(Callback):
def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
super().__init__()
self.resume = resume
self.now = now
self.logdir = logdir
self.ckptdir = ckptdir
self.cfgdir = cfgdir
self.config = config
self.lightning_config = lightning_config
def on_pretrain_routine_start(self, trainer, pl_module):
if trainer.global_rank == 0:
# import pdb
# pdb.set_trace()
# Create logdirs and save configs
os.makedirs(self.logdir, exist_ok=True)
os.makedirs(self.ckptdir, exist_ok=True)
os.makedirs(self.cfgdir, exist_ok=True)
print("Project config")
print(self.config.pretty())
OmegaConf.save(self.config,
os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
print("Lightning config")
print(self.lightning_config.pretty())
OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
class ImageLogger(Callback):
def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True):
super().__init__()
self.batch_freq = batch_frequency
self.max_images = max_images
self.logger_log_images = {
pl.loggers.WandbLogger: self._wandb,
pl.loggers.TestTubeLogger: self._testtube,
}
self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
if not increase_log_steps:
self.log_steps = [self.batch_freq]
self.clamp = clamp
@rank_zero_only
def _wandb(self, pl_module, images, batch_idx, split):
raise ValueError("No way wandb")  # wandb logging is intentionally disabled; the code below is unreachable and would also need an 'import wandb'
grids = dict()
for k in images:
grid = torchvision.utils.make_grid(images[k])
grids[f"{split}/{k}"] = wandb.Image(grid)
pl_module.logger.experiment.log(grids)
@rank_zero_only
def _testtube(self, pl_module, images, batch_idx, split):
for k in images:
grid = torchvision.utils.make_grid(images[k])
grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w
tag = f"{split}/{k}"
pl_module.logger.experiment.add_image(
tag, grid,
global_step=pl_module.global_step)
@rank_zero_only
def log_local(self, save_dir, split, images,
global_step, current_epoch, batch_idx):
root = os.path.join(save_dir, "images", split)
for k in images:
grid = torchvision.utils.make_grid(images[k], nrow=4)
grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w
grid = grid.transpose(0,1).transpose(1,2).squeeze(-1)
grid = grid.numpy()
grid = (grid*255).astype(np.uint8)
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
k,
global_step,
current_epoch,
batch_idx)
path = os.path.join(root, filename)
os.makedirs(os.path.split(path)[0], exist_ok=True)
Image.fromarray(grid).save(path)
def log_img(self, pl_module, batch, batch_idx, split="train"):
if (self.check_frequency(batch_idx) and # batch_idx % self.batch_freq == 0
hasattr(pl_module, "log_images") and
callable(pl_module.log_images) and
self.max_images > 0):
logger = type(pl_module.logger)
is_train = pl_module.training
if is_train:
pl_module.eval()
with torch.no_grad():
images = pl_module.log_images(batch, split=split)
for k in images:
N = min(images[k].shape[0], self.max_images)
images[k] = images[k][:N]
if isinstance(images[k], torch.Tensor):
images[k] = images[k].detach().cpu()
if self.clamp:
images[k] = torch.clamp(images[k], -1., 1.)
self.log_local(pl_module.logger.save_dir, split, images,
pl_module.global_step, pl_module.current_epoch, batch_idx)
logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
logger_log_images(pl_module, images, pl_module.global_step, split)
if is_train:
pl_module.train()
def check_frequency(self, batch_idx):
if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps):
try:
self.log_steps.pop(0)
except IndexError:
pass
return True
return False
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self.log_img(pl_module, batch, batch_idx, split="train")
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self.log_img(pl_module, batch, batch_idx, split="val")
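The warm-up schedule built in `ImageLogger.__init__` is just powers of two up to `batch_freq`, so early batches are logged densely and later ones only every `batch_freq` steps. A torch-free sketch of `check_frequency`, using the default `batch_frequency=750` from `main.py`:

```python
import math

batch_freq = 750  # default batch_frequency in main.py's ImageLogger config
log_steps = [2 ** n for n in range(int(math.log2(batch_freq)) + 1)]

def check_frequency(batch_idx, log_steps, batch_freq):
    # Mirrors ImageLogger.check_frequency: log on multiples of batch_freq
    # plus the power-of-two warm-up indices.
    return batch_idx % batch_freq == 0 or batch_idx in log_steps
```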
if __name__ == "__main__":
# custom parser to specify config files, train, test and debug mode,
# postfix, resume.
# `--key value` arguments are interpreted as arguments to the trainer.
# `nested.key=value` arguments are interpreted as config parameters.
# configs are merged from left-to-right followed by command line parameters.
# model:
# base_learning_rate: float
# target: path to lightning module
# params:
# key: value
# data:
# target: main.DataModuleFromConfig
# params:
# batch_size: int
# wrap: bool
# train:
# target: path to train dataset
# params:
# key: value
# validation:
# target: path to validation dataset
# params:
# key: value
# test:
# target: path to test dataset
# params:
# key: value
# lightning: (optional, has sane defaults and can be specified on cmdline)
# trainer:
# additional arguments to trainer
# logger:
# logger to instantiate
# modelcheckpoint:
# modelcheckpoint to instantiate
# callbacks:
# callback1:
# target: importpath
# params:
# key: value
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
# add cwd for convenience and to make classes in this file available when
# running as `python main.py`
# (in particular `main.DataModuleFromConfig`)
sys.path.append(os.getcwd())
parser = get_parser()
parser = Trainer.add_argparse_args(parser)
opt, unknown = parser.parse_known_args()
if opt.name and opt.resume:
raise ValueError(
"-n/--name and -r/--resume cannot both be specified. "
"If you want to resume training in a new log folder, "
"use -n/--name in combination with --resume_from_checkpoint"
)
if opt.resume:
if not os.path.exists(opt.resume):
raise ValueError("Cannot find {}".format(opt.resume))
if os.path.isfile(opt.resume):
paths = opt.resume.split("/")
idx = len(paths)-paths[::-1].index("logs")+1
logdir = "/".join(paths[:idx])
ckpt = opt.resume
else:
assert os.path.isdir(opt.resume), opt.resume
logdir = opt.resume.rstrip("/")
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
opt.resume_from_checkpoint = ckpt
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
opt.base = base_configs+opt.base
_tmp = logdir.split("/")
nowname = _tmp[_tmp.index("logs")+1]+opt.postfix
logdir = os.path.join(opt.root_path, "logs", nowname)
else:
if opt.name:
name = "_"+opt.name
elif opt.base:
cfg_fname = os.path.split(opt.base[0])[-1]
cfg_name = os.path.splitext(cfg_fname)[0]
name = "_"+cfg_name
else:
name = ""
nowname = now+name+opt.postfix
logdir = os.path.join(opt.root_path, "logs", nowname)
if opt.random_seed:
opt.seed = random.randint(1,100)
logdir = logdir + '_seed' + str(opt.seed)
ckptdir = os.path.join(logdir, "checkpoints")
cfgdir = os.path.join(logdir, "configs")
seed_everything(opt.seed)
try:
# init and save configs
configs = [OmegaConf.load(cfg) for cfg in opt.base]
cli = OmegaConf.from_dotlist(unknown)
config = OmegaConf.merge(*configs, cli)
lightning_config = config.pop("lightning", OmegaConf.create())
# merge trainer cli with config
trainer_config = lightning_config.get("trainer", OmegaConf.create())
# default to ddp
# trainer_config["distributed_backend"] = "ddp"
trainer_config["accelerator"] = "ddp"
# trainer_config["plugins"]="ddp_sharded"
for k in nondefault_trainer_args(opt):
trainer_config[k] = getattr(opt, k)
if "gpus" not in trainer_config:
del trainer_config["accelerator"]  # drop the ddp setting when running on CPU
cpu = True
else:
gpuinfo = trainer_config["gpus"]
print(f"Running on GPUs {gpuinfo}")
cpu = False
trainer_opt = argparse.Namespace(**trainer_config)
lightning_config.trainer = trainer_config
# model
model = instantiate_from_config(config.model)
# trainer and callbacks
trainer_kwargs = dict()
# trainer_kwargs['sync_batchnorm'] = True
# default logger configs
# NOTE wandb < 0.10.0 interferes with shutdown
# wandb >= 0.10.0 seems to fix it but still interferes with pudb
# debugging (wrongly sized pudb ui)
# thus prefer testtube for now
default_logger_cfgs = {
"wandb": {
"target": "pytorch_lightning.loggers.WandbLogger",
"params": {
"name": nowname,
"save_dir": logdir,
"offline": opt.debug,
"id": nowname,
}
},
"testtube": {
"target": "pytorch_lightning.loggers.TestTubeLogger",
"params": {
"name": "testtube",
"save_dir": logdir,
}
},
}
default_logger_cfg = default_logger_cfgs["testtube"]
logger_cfg = lightning_config.logger or OmegaConf.create()
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
# modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
# specify which metric is used to determine best models
default_modelckpt_cfg = {
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
"params": {
"dirpath": ckptdir,
"filename": "{epoch:06}",
"verbose": True,
"save_last": True,
"period": 1
}
}
if hasattr(model, "monitor"):
print(f"Monitoring {model.monitor} as checkpoint metric.")
default_modelckpt_cfg["params"]["monitor"] = model.monitor
default_modelckpt_cfg["params"]["save_top_k"] = 3
modelckpt_cfg = lightning_config.modelcheckpoint or OmegaConf.create()
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
# add callback which sets up log directory
default_callbacks_cfg = {
"setup_callback": {
"target": "main.SetupCallback",
"params": {
"resume": opt.resume,
"now": now,
"logdir": logdir,
"ckptdir": ckptdir,
"cfgdir": cfgdir,
"config": config,
"lightning_config": lightning_config,
}
},
"image_logger": {
"target": "main.ImageLogger",
"params": {
"batch_frequency": 750,
"max_images": 4,
"clamp": True
}
},
"learning_rate_logger": {
"target": "main.LearningRateMonitor",
"params": {
"logging_interval": "step",
#"log_momentum": True
}
},
}
callbacks_cfg = lightning_config.callbacks or OmegaConf.create()
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
# data
data = instantiate_from_config(config.data)
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
# calling these ourselves should not be necessary but it is.
# lightning still takes care of proper multiprocessing though
data.prepare_data()
data.setup()
# configure learning rate
bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
if not cpu:
ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
else:
ngpu = 1
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches or 1
print(f"accumulate_grad_batches = {accumulate_grad_batches}")
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
model.learning_rate = accumulate_grad_batches * ngpu * bs * trainer.num_nodes * base_lr
print("Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (num_nodes) * {} (batchsize) * {:.2e} (base_lr)".format(
model.learning_rate, accumulate_grad_batches, ngpu, trainer.num_nodes, bs, base_lr))
# allow checkpointing via USR1
def melk(*args, **kwargs):
# run all checkpoint hooks
if trainer.global_rank == 0:
print("Summoning checkpoint.")
ckpt_path = os.path.join(ckptdir, "last.ckpt")
trainer.save_checkpoint(ckpt_path)
def divein(*args, **kwargs):
if trainer.global_rank == 0:
import pudb; pudb.set_trace()
import signal
signal.signal(signal.SIGUSR1, melk)
signal.signal(signal.SIGUSR2, divein)
# run
if opt.train:
try:
trainer.fit(model, data)
except Exception:
melk()
raise
if not opt.no_test and not trainer.interrupted:
trainer.test(model, data)
except Exception:
if opt.debug and trainer.global_rank==0:
try:
import pudb as debugger
except ImportError:
import pdb as debugger
debugger.post_mortem()
raise
finally:
# move newly created debug project to debug_runs
if opt.debug and not opt.resume and trainer.global_rank==0:
dst, name = os.path.split(logdir)
dst = os.path.join(dst, "debug_runs", name)
os.makedirs(os.path.split(dst)[0], exist_ok=True)
os.rename(logdir, dst)
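The effective learning rate computed near the end of `main.py` is a plain linear scaling of the config's `base_learning_rate`. A worked example with this repo's defaults (`batch_size: 4` from the data config, `gpus='0,1,2,3'` and one node from `scripts/run.sh`); the `base_lr` value here is illustrative only, since the actual value lives in the YAML:

```python
# Linear LR scaling as in main.py:
# lr = accumulate_grad_batches * ngpu * batch_size * num_nodes * base_lr
accumulate_grad_batches = 1
ngpu = 4          # gpus='0,1,2,3' in scripts/run.sh
batch_size = 4    # from the data config
num_nodes = 1
base_lr = 4.5e-6  # illustrative value only; read the real one from the config

learning_rate = accumulate_grad_batches * ngpu * batch_size * num_nodes * base_lr
```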
================================================
FILE: restoreformer_requirement.txt
================================================
Package Version Location
----------------------- ------------------- ------------------------------------------------------------------------------
absl-py 0.13.0
addict 2.4.0
aiohttp 3.7.4.post0
albumentations 0.4.3
antlr4-python3-runtime 4.8
astunparse 1.6.3
async-timeout 3.0.1
attrs 21.2.0
basicsr 1.3.3.4
cached-property 1.5.2
cachetools 4.2.2
certifi 2021.5.30
chardet 4.0.0
cycler 0.10.0
dlib 19.22.99
facexlib 0.1.3.1
flatbuffers 1.12
fsspec 2021.6.1
future 0.18.2
gast 0.4.0
google-auth 1.32.1
google-auth-oauthlib 0.4.4
google-pasta 0.2.0
grpcio 1.39.0
h5py 3.1.0
idna 2.10
imageio 2.9.0
imgaug 0.2.6
importlib-metadata 4.6.1
joblib 1.0.1
keras-nightly 2.7.0.dev2021072800
Keras-Preprocessing 1.1.2
kiwisolver 1.3.1
libclang 11.1.0
lmdb 1.2.1
Markdown 3.3.4
matplotlib 3.4.2
mkl-fft 1.3.0
mkl-random 1.2.1
mkl-service 2.3.0
multidict 5.1.0
networkx 2.6.1
numpy 1.19.5
oauthlib 3.1.1
olefile 0.46
omegaconf 2.0.0
opencv-python 4.5.2.54
opt-einsum 3.3.0
packaging 21.0
pandas 1.3.0
Pillow 8.3.1
pip 21.1.3
protobuf 3.17.3
pyasn1 0.4.8
pyasn1-modules 0.2.8
pyDeprecate 0.3.0
pyparsing 2.4.7
python-dateutil 2.8.1
pytorch-lightning 1.0.8
pytz 2021.1
PyWavelets 1.1.1
PyYAML 5.4.1
requests 2.25.1
requests-oauthlib 1.3.0
rsa 4.7.2
scikit-image 0.18.2
scikit-learn 0.24.2
scipy 1.7.0
setuptools 52.0.0.post20210125
six 1.15.0
sklearn 0.0
tb-nightly 2.6.0a20210728
tensorboard-data-server 0.6.1
tensorboard-plugin-wit 1.8.0
termcolor 1.1.0
test-tube 0.7.5
tf-estimator-nightly 2.7.0.dev2021072801
tf-nightly 2.7.0.dev20210728
threadpoolctl 2.2.0
tifffile 2021.7.2
torch 1.7.1
torchaudio 0.7.0a0+a853dff
torchmetrics 0.4.1
torchvision 0.8.2
tqdm 4.61.2
typing-extensions 3.7.4.3
urllib3 1.26.6
Werkzeug 2.0.1
wheel 0.36.2
wrapt 1.12.1
yapf 0.31.0
yarl 1.6.3
zipp 3.5.0
================================================
FILE: scripts/metrics/cal_fid.py
================================================
import os, sys
import argparse
import math
import numpy as np
import torch
from torch.utils.data import DataLoader
from basicsr.data import build_dataset
from basicsr.metrics.fid import calculate_fid, extract_inception_features, load_patched_inception_v3
def calculate_fid_folder():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='Path to the folder.')
parser.add_argument('--fid_stats', type=str, help='Path to the dataset fid statistics.')
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--num_sample', type=int, default=50000)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--backend', type=str, default='disk', help='io backend for dataset. Option: disk, lmdb')
parser.add_argument('--save_name', type=str, default='fid', help='File name for saving results')
args = parser.parse_args()
# inception model
inception = load_patched_inception_v3(device)
# create dataset
opt = {}
opt['name'] = 'SingleImageDataset'
opt['type'] = 'SingleImageDataset'
opt['dataroot_lq'] = args.folder
opt['io_backend'] = dict(type=args.backend)
opt['mean'] = [0.5, 0.5, 0.5]
opt['std'] = [0.5, 0.5, 0.5]
dataset = build_dataset(opt)
# create dataloader
data_loader = DataLoader(
dataset=dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
sampler=None,
drop_last=False)
args.num_sample = min(args.num_sample, len(dataset))
total_batch = math.ceil(args.num_sample / args.batch_size)
def data_generator(data_loader, total_batch):
for idx, data in enumerate(data_loader):
if idx >= total_batch:
break
else:
yield data['lq']
features = extract_inception_features(data_generator(data_loader, total_batch), inception, total_batch, device)
features = features.numpy()
total_len = features.shape[0]
features = features[:args.num_sample]
# print(f'Extracted {total_len} features, ' f'use the first {features.shape[0]} features to calculate stats.')
sample_mean = np.mean(features, 0)
sample_cov = np.cov(features, rowvar=False)
# load the dataset stats
stats = torch.load(args.fid_stats)
real_mean = stats['mean']
real_cov = stats['cov']
# calculate FID metric
fid = calculate_fid(sample_mean, sample_cov, real_mean, real_cov)
fout=open(args.save_name, 'w')
fout.write(str(fid)+'\n')
fout.close()
print(args.folder)
print('fid:', fid)
if __name__ == '__main__':
calculate_fid_folder()
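`calculate_fid` above computes the Fréchet distance between two Gaussians fitted to Inception features. In one dimension the closed form reduces to `(m1 - m2)^2 + v1 + v2 - 2*sqrt(v1*v2)`; a pure-Python sanity check of that special case (the real metric uses full mean vectors and covariance matrices):

```python
import math

def frechet_distance_1d(m1, v1, m2, v2):
    # 1-D special case of the Frechet/FID formula:
    # ||m1 - m2||^2 + Tr(C1 + C2 - 2 * (C1 C2)^(1/2))
    return (m1 - m2) ** 2 + v1 + v2 - 2.0 * math.sqrt(v1 * v2)
```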
================================================
FILE: scripts/metrics/cal_identity_distance.py
================================================
import os, sys
import torch
import argparse
import cv2
import numpy as np
import glob
import pdb
import tqdm
from copy import deepcopy
import torch.nn.functional as F
import math
root_path = os.path.abspath(os.path.join(__file__, os.path.pardir, os.path.pardir, os.path.pardir))
sys.path.append(root_path)
sys.path.append(os.path.join(root_path, 'RestoreFormer/modules/losses'))
from RestoreFormer.modules.vqvae.arcface_arch import ResNetArcFace
from basicsr.losses.losses import L1Loss, MSELoss
def cosine_similarity(emb1, emb2):
# Note: despite the name, this returns the angle (arccos of the cosine similarity) in radians, i.e. an angular distance.
return np.arccos(np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2)))
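Despite its name, `cosine_similarity` above returns an angular distance: the arccos of the cosine similarity, in radians (0 for identical directions, pi/2 for orthogonal vectors). A numpy-free check of the same computation:

```python
import math

def angular_distance(a, b):
    # arccos of the cosine similarity -> angle in radians.
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return math.acos(dot / (na * nb))
```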
def gray_resize_for_identity(out, size=128):
out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
out_gray = out_gray.unsqueeze(1)
out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
return out_gray
def calculate_identity_distance_folder():
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='Path to the folder')
parser.add_argument('--gt_folder', type=str, help='Path to the GT')
parser.add_argument('--save_name', type=str, default='niqe', help='File name for saving results')
parser.add_argument('--need_post', type=int, default=0, help='1 if restored file names carry the _00 face index suffix, 0 otherwise')
args = parser.parse_args()
fout = open(args.save_name, 'w')
identity = ResNetArcFace(block = 'IRBlock',
layers = [2, 2, 2, 2],
use_se = False)
identity_model_path = 'experiments/pretrained_models/arcface_resnet18.pth'
sd = torch.load(identity_model_path, map_location="cpu")
for k, v in deepcopy(sd).items():
if k.startswith('module.'):
sd[k[7:]] = v
sd.pop(k)
identity.load_state_dict(sd, strict=True)
identity.eval()
for param in identity.parameters():
param.requires_grad = False
identity = identity.cuda()
gt_names = glob.glob(os.path.join(args.gt_folder, '*'))
gt_names.sort()
mean_dist = 0.
for i in tqdm.tqdm(range(len(gt_names))):
gt_name = gt_names[i].split('/')[-1][:-4]
if args.need_post:
img_name = os.path.join(args.folder,gt_name + '_00.png')
else:
img_name = os.path.join(args.folder,gt_name + '.png')
if not os.path.exists(img_name):
print(img_name, 'does not exist')
continue
img = cv2.imread(img_name)
gt = cv2.imread(gt_names[i])
img = img.astype(np.float32) / 255.
img = torch.FloatTensor(img).cuda()
img = img.permute(2,0,1)
img = img.unsqueeze(0)
gt = gt.astype(np.float32) / 255.
gt = torch.FloatTensor(gt).cuda()
gt = gt.permute(2,0,1)
gt = gt.unsqueeze(0)
out_gray = gray_resize_for_identity(img)
gt_gray = gray_resize_for_identity(gt)
with torch.no_grad():
identity_gt = identity(gt_gray)
identity_out = identity(out_gray)
identity_gt = identity_gt.cpu().data.numpy().squeeze()
identity_out = identity_out.cpu().data.numpy().squeeze()
identity_loss = cosine_similarity(identity_gt, identity_out)
fout.write(gt_name + ' ' + str(identity_loss) + '\n')
mean_dist += identity_loss
fout.write('Mean: ' + str(mean_dist / len(gt_names)) + '\n')
fout.close()
print('mean_dist:', mean_dist / len(gt_names))
if __name__ == '__main__':
calculate_identity_distance_folder()
================================================
FILE: scripts/metrics/cal_psnr_ssim.py
================================================
import os, sys
import argparse
import cv2
import numpy as np
import glob
import pdb
import tqdm
import torch
from basicsr.metrics.psnr_ssim import calculate_psnr, calculate_ssim
root_path = os.path.abspath(os.path.join(__file__, os.path.pardir, os.path.pardir, os.path.pardir))
sys.path.append(root_path)
sys.path.append(os.path.join(root_path, 'RestoreFormer/modules/losses'))
from lpips import LPIPS
def calculate_psnr_ssim_lpips_folder():
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='Path to the folder')
parser.add_argument('--gt_folder', type=str, help='Path to the GT')
parser.add_argument('--save_name', type=str, default='niqe', help='File name for saving results')
parser.add_argument('--need_post', type=int, default=0, help='1 if restored file names carry the _00 face index suffix, 0 otherwise')
args = parser.parse_args()
fout = open(args.save_name, 'w')
fout.write('NAME\tPSNR\tSSIM\tLPIPS\tNORM_LPIPS\n')
H, W = 512, 512
gt_names = glob.glob(os.path.join(args.gt_folder, '*'))
gt_names.sort()
perceptual_loss = LPIPS().eval().cuda()
mean_psnr = 0.
mean_ssim = 0.
mean_lpips = 0.
mean_norm_lpips = 0.
for i in tqdm.tqdm(range(len(gt_names))):
gt_name = gt_names[i].split('/')[-1][:-4]
if args.need_post:
img_name = os.path.join(args.folder,gt_name + '_00.png')
else:
img_name = os.path.join(args.folder,gt_name + '.png')
if not os.path.exists(img_name):
print(img_name, 'does not exist')
continue
img = cv2.imread(img_name)
gt = cv2.imread(gt_names[i])
cur_psnr = calculate_psnr(img, gt, 0)
cur_ssim = calculate_ssim(img, gt, 0)
# lpips:
img = img.astype(np.float32) / 255.
img = torch.FloatTensor(img).cuda()
img = img.permute(2,0,1)
img = img.unsqueeze(0)
gt = gt.astype(np.float32) / 255.
gt = torch.FloatTensor(gt).cuda()
gt = gt.permute(2,0,1)
gt = gt.unsqueeze(0)
cur_lpips = perceptual_loss(img, gt)
cur_lpips = cur_lpips[0].item()
img = (img - 0.5) / 0.5
gt = (gt - 0.5) / 0.5
norm_lpips = perceptual_loss(img, gt)
norm_lpips = norm_lpips[0].item()
# print(cur_psnr, cur_ssim, cur_lpips, norm_lpips)
fout.write(gt_name + '\t' + str(cur_psnr) + '\t' + str(cur_ssim) + '\t' + str(cur_lpips) + '\t' + str(norm_lpips) + '\n')
mean_psnr += cur_psnr
mean_ssim += cur_ssim
mean_lpips += cur_lpips
mean_norm_lpips += norm_lpips
mean_psnr /= float(len(gt_names))
mean_ssim /= float(len(gt_names))
mean_lpips /= float(len(gt_names))
mean_norm_lpips /= float(len(gt_names))
fout.write(str(mean_psnr) + '\t' + str(mean_ssim) + '\t' + str(mean_lpips) + '\t' + str(mean_norm_lpips) + '\n')
fout.close()
print('psnr, ssim, lpips, norm_lpips:', mean_psnr, mean_ssim, mean_lpips, mean_norm_lpips)
if __name__ == '__main__':
calculate_psnr_ssim_lpips_folder()
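`calculate_psnr` from basicsr follows the standard `10 * log10(MAX^2 / MSE)` definition for 8-bit images. A simplified pure-Python analogue for flat pixel lists (the basicsr implementation additionally handles border cropping and color-space conversion):

```python
import math

def psnr(img1, img2, max_val=255.0):
    # Peak signal-to-noise ratio between two equally sized pixel sequences.
    mse = sum((a - b) ** 2 for a, b in zip(img1, img2)) / len(img1)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)
```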
================================================
FILE: scripts/metrics/run.sh
================================================
### Journal ###
root='results/'
out_root='results/metrics'
test_name='RestoreFormer'
test_image=$test_name'/restored_faces'
out_name=$test_name
need_post=1
CelebAHQ_GT='YOUR_PATH'
# FID
python -u scripts/metrics/cal_fid.py \
$root'/'$test_image \
--fid_stats 'experiments/pretrained_models/inception_FFHQ_512-f7b384ab.pth' \
--save_name $out_root'/'$out_name'_fid.txt'
if [ -d $CelebAHQ_GT ]
then
# PSNR SSIM LPIPS
python -u scripts/metrics/cal_psnr_ssim.py \
$root'/'$test_image \
--gt_folder $CelebAHQ_GT \
--save_name $out_root'/'$out_name'_psnr_ssim_lpips.txt' \
--need_post $need_post
# Identity distance
python -u scripts/metrics/cal_identity_distance.py \
$root'/'$test_image \
--gt_folder $CelebAHQ_GT \
--save_name $out_root'/'$out_name'_id.txt' \
--need_post $need_post
else
echo 'The path of GT does not exist'
fi
================================================
FILE: scripts/run.sh
================================================
export BASICSR_JIT=True
conf_name='HQ_Dictionary'
# conf_name='RestoreFormer'
ROOT_PATH='' # The path for saving model and logs
gpus='0,1,2,3'
#P: pretrain SL: soft learning
node_n=1
python -u main.py \
--root-path $ROOT_PATH \
--base 'configs/'$conf_name'.yaml' \
-t True \
--postfix $conf_name \
--gpus $gpus \
--num-nodes $node_n \
--random-seed True
================================================
FILE: scripts/test.py
================================================
import argparse, os, sys, glob, math, time
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
import pdb
sys.path.append(os.getcwd())
from main import instantiate_from_config, DataModuleFromConfig
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from tqdm import trange, tqdm
import cv2
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from torchvision.transforms.functional import normalize
from basicsr.utils import img2tensor, imwrite, tensor2img
def restoration(model,
face_helper,
img_path,
save_root,
has_aligned=False,
only_center_face=True,
suffix=None,
paste_back=False):
# read image
img_name = os.path.basename(img_path)
# print(f'Processing {img_name} ...')
basename, _ = os.path.splitext(img_name)
input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
face_helper.clean_all()
if has_aligned:
input_img = cv2.resize(input_img, (512, 512))
face_helper.cropped_faces = [input_img]
else:
face_helper.read_image(input_img)
# get face landmarks for each face
face_helper.get_face_landmarks_5(only_center_face=only_center_face, pad_blur=False)
# align and warp each face
save_crop_path = os.path.join(save_root, 'cropped_faces', img_name)
face_helper.align_warp_face(save_crop_path)
# face restoration
for idx, cropped_face in enumerate(face_helper.cropped_faces):
# prepare data
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to('cuda')
try:
with torch.no_grad():
output = model(cropped_face_t)
restored_face = tensor2img(output[0].squeeze(0), rgb2bgr=True, min_max=(-1, 1))
except RuntimeError as error:
print(f'\tFailed inference for RestoreFormer: {error}.')
restored_face = cropped_face
restored_face = restored_face.astype('uint8')
face_helper.add_restored_face(restored_face)
if suffix is not None:
save_face_name = f'{basename}_{idx:02d}_{suffix}.png'
else:
save_face_name = f'{basename}_{idx:02d}.png'
save_restore_path = os.path.join(save_root, 'restored_faces', save_face_name)
imwrite(restored_face, save_restore_path)
if not has_aligned and paste_back:
face_helper.get_inverse_affine(None)
save_restore_path = os.path.join(save_root, 'restored_imgs', img_name)
# paste each restored face to the input image
face_helper.paste_faces_to_input_image(save_restore_path)
def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        nargs="?",
        help="load from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left to right. "
        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list(),
    )
    parser.add_argument(
        "-c",
        "--config",
        nargs="?",
        metavar="single_config.yaml",
        help="path to a single config. If specified, base configs are ignored "
        "(except for the last one if left unspecified).",
        const=True,
        default="",
    )
    parser.add_argument(
        "--ignore_base_data",
        action="store_true",
        help="Ignore data specification from base configs. Useful if you want "
        "to specify custom datasets on the command line.",
    )
    parser.add_argument(
        "--outdir",
        required=True,
        type=str,
        help="Where to write outputs to.",
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=100,
        help="Sample from among top-k predictions.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="Sampling temperature.",
    )
    parser.add_argument('--upscale_factor', type=int, default=1)
    parser.add_argument('--test_path', type=str, default='inputs/whole_imgs')
    parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
    parser.add_argument('--only_center_face', action='store_true')
    parser.add_argument('--aligned', action='store_true')
    parser.add_argument('--paste_back', action='store_true')
    return parser
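This parser is consumed with `parse_known_args()` in the `__main__` block below it, so flags it does not define pass through untouched and later become OmegaConf dotlist overrides. A minimal sketch of that split, using a hypothetical override key:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--outdir", required=True, type=str)
parser.add_argument("--top_k", type=int, default=100)

# flags the parser knows go into `opt`; everything else stays in `unknown`
opt, unknown = parser.parse_known_args(
    ["--outdir", "results", "model.params.embed_dim=256"])
```

Here `opt.outdir` is `"results"`, `opt.top_k` keeps its default of 100, and `unknown` is `["model.params.embed_dim=256"]`, ready for `OmegaConf.from_dotlist`.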
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    if "ckpt_path" in config.params:
        print("Deleting the restore-ckpt path from the config...")
        config.params.ckpt_path = None
    if "downsample_cond_size" in config.params:
        print("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
        config.params.downsample_cond_size = -1
        config.params["downsample_cond_factor"] = 0.5
    try:
        if "ckpt_path" in config.params.first_stage_config.params:
            config.params.first_stage_config.params.ckpt_path = None
            print("Deleting the first-stage restore-ckpt path from the config...")
        if "ckpt_path" in config.params.cond_stage_config.params:
            config.params.cond_stage_config.params.ckpt_path = None
            print("Deleting the cond-stage restore-ckpt path from the config...")
    except Exception:
        pass
    model = instantiate_from_config(config)
    if sd is not None:
        state_dict = model.state_dict()
        require_keys = state_dict.keys()
        keys = sd.keys()
        un_pretrained_keys = []
        for k in require_keys:
            if k not in keys:
                # the checkpoint may store keys without the 'vqvae.' prefix
                if k[6:] in keys:
                    state_dict[k] = sd[k[6:]]
                else:
                    un_pretrained_keys.append(k)
            else:
                state_dict[k] = sd[k]
        # print(f"Layers without pretraining: {un_pretrained_keys}")
        model.load_state_dict(state_dict, strict=True)
    if gpu:
        model.cuda()
    if eval_mode:
        model.eval()
    return {"model": model}
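The loading loop above tolerates checkpoints whose keys lack the `vqvae.` prefix: when a required key `k` is missing, it retries with `k[6:]` (the key minus the 6-character prefix). The same remapping on plain dicts, with toy key names for illustration:

```python
def remap_state_dict(required, ckpt, prefix="vqvae."):
    """Fill `required` keys from `ckpt`, retrying without the prefix."""
    out, missing = {}, []
    for k in required:
        if k in ckpt:
            out[k] = ckpt[k]
        elif k.startswith(prefix) and k[len(prefix):] in ckpt:
            out[k] = ckpt[k[len(prefix):]]   # checkpoint stored unprefixed keys
        else:
            missing.append(k)                # stays at its random initialization
    return out, missing

out, missing = remap_state_dict(
    ["vqvae.encoder.w", "loss.scale"],
    {"encoder.w": 1, "other": 2})
```

With these toy inputs, `vqvae.encoder.w` is filled from the unprefixed `encoder.w`, while `loss.scale` lands in the missing list, mirroring the script's `un_pretrained_keys`.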
def load_model_and_dset(config, ckpt, gpu, eval_mode):
    # now load the specified checkpoint
    if ckpt:
        pl_sd = torch.load(ckpt, map_location="cpu")
    else:
        pl_sd = {"state_dict": None}
    model = load_model_from_config(config.model,
                                   pl_sd["state_dict"],
                                   gpu=gpu,
                                   eval_mode=eval_mode)["model"]
    return model
if __name__ == "__main__":
    sys.path.append(os.getcwd())

    parser = get_parser()
    opt, unknown = parser.parse_known_args()

    ckpt = None
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            try:
                idx = len(paths) - paths[::-1].index("logs") + 1
            except ValueError:
                idx = -2  # take a guess: path/to/logdir/checkpoints/model.ckpt
            logdir = "/".join(paths[:idx])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
        print(f"logdir: {logdir}")
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
        opt.base = base_configs + opt.base

    if opt.config:
        if isinstance(opt.config, str):
            if not os.path.exists(opt.config):
                raise ValueError("Cannot find {}".format(opt.config))
            if os.path.isfile(opt.config):
                opt.base = [opt.config]
            else:
                opt.base = sorted(glob.glob(os.path.join(opt.config, "*-project.yaml")))
        else:
            opt.base = [opt.base[-1]]

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    if opt.ignore_base_data:
        for config in configs:
            if hasattr(config, "data"):
                del config["data"]
    config = OmegaConf.merge(*configs, cli)
    print(config)

    gpu = True
    eval_mode = True
    show_config = False
    if show_config:
        print(OmegaConf.to_container(config))

    model = load_model_and_dset(config, ckpt, gpu, eval_mode)

    outdir = opt.outdir
    os.makedirs(outdir, exist_ok=True)
    print("Writing samples to", outdir)

    # initialize face helper
    face_helper = FaceRestoreHelper(
        opt.upscale_factor, face_size=512, crop_ratio=(1, 1),
        det_model='retinaface_resnet50', save_ext='png')

    img_list = sorted(glob.glob(os.path.join(opt.test_path, '*')))
    print('Results are in the <{}> folder.'.format(outdir))
    for img_path in tqdm(img_list):
        restoration(
            model,
            face_helper,
            img_path,
            outdir,
            has_aligned=opt.aligned,
            only_center_face=opt.only_center_face,
            suffix=opt.suffix,
            paste_back=opt.paste_back)

    print('Test number:', len(img_list))
    print('Results are in the <{}> folder.'.format(outdir))
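When `--resume` points at a checkpoint file, the script derives the run directory by locating the last `logs` path component and keeping one more segment after it; with no `logs` component it falls back to dropping the final two segments. A standalone sketch of that index arithmetic, using hypothetical paths:

```python
def logdir_from_ckpt(ckpt_path: str) -> str:
    """Recover the logdir that contains a checkpoint path."""
    paths = ckpt_path.split("/")
    try:
        # index one past the last "logs" component (the run-name segment)
        idx = len(paths) - paths[::-1].index("logs") + 1
    except ValueError:
        idx = -2  # fallback guess: path/to/logdir/checkpoints/model.ckpt
    return "/".join(paths[:idx])

d1 = logdir_from_ckpt("exp/logs/2021-run/checkpoints/last.ckpt")
d2 = logdir_from_ckpt("myexp/run1/checkpoints/last.ckpt")
```

`d1` resolves to `exp/logs/2021-run` (the run directory under `logs`), while `d2`, lacking a `logs` segment, falls back to `myexp/run1`.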
================================================
FILE: scripts/test.sh
================================================
# # ### Good
exp_name='RestoreFormer'

root_path='experiments'
out_root_path='results'
align_test_path='data/test'
tag='test'

outdir=$out_root_path'/'$exp_name'_'$tag

if [ ! -d "$outdir" ]; then
    mkdir -p "$outdir"
fi

python -u scripts/test.py \
    --outdir "$outdir" \
    -r $root_path'/'$exp_name'/last.ckpt' \
    -c 'configs/RestoreFormer.yaml' \
    --test_path "$align_test_path" \
    --aligned
================================================
SYMBOL INDEX (172 symbols across 18 files)
================================================
FILE: RestoreFormer/data/ffhq_degradation_dataset.py
class FFHQDegradationDataset (line 20) | class FFHQDegradationDataset(data.Dataset):
method __init__ (line 22) | def __init__(self, opt):
method color_jitter (line 82) | def color_jitter(img, shift):
method color_jitter_pt (line 89) | def color_jitter_pt(img, brightness, contrast, saturation, hue):
method get_component_coordinates (line 109) | def get_component_coordinates(self, index, status):
method __getitem__ (line 133) | def __getitem__(self, index):
method __len__ (line 217) | def __len__(self):
FILE: RestoreFormer/distributed/distributed.py
function is_primary (line 12) | def is_primary():
function get_rank (line 16) | def get_rank():
function get_local_rank (line 26) | def get_local_rank():
function synchronize (line 39) | def synchronize():
function get_world_size (line 54) | def get_world_size():
function all_reduce (line 64) | def all_reduce(tensor, op=dist.ReduceOp.SUM):
function all_gather (line 75) | def all_gather(data):
function reduce_dict (line 110) | def reduce_dict(input_dict, average=True):
function data_sampler (line 135) | def data_sampler(dataset, shuffle, distributed):
FILE: RestoreFormer/distributed/launch.py
function find_free_port (line 10) | def find_free_port():
function launch (line 22) | def launch(fn, n_gpu_per_machine, n_machine=1, machine_rank=0, dist_url=...
function distributed_worker (line 52) | def distributed_worker(
FILE: RestoreFormer/models/vqgan_v1.py
class RestoreFormerModel (line 8) | class RestoreFormerModel(pl.LightningModule):
method __init__ (line 9) | def __init__(self,
method init_from_ckpt (line 43) | def init_from_ckpt(self, path, ignore_keys=list()):
method forward (line 74) | def forward(self, input):
method training_step (line 78) | def training_step(self, batch, batch_idx, optimizer_idx):
method validation_step (line 142) | def validation_step(self, batch, batch_idx):
method configure_optimizers (line 164) | def configure_optimizers(self):
method get_last_layer (line 207) | def get_last_layer(self):
method log_images (line 212) | def log_images(self, batch, **kwargs):
FILE: RestoreFormer/modules/discriminator/model.py
function weights_init (line 8) | def weights_init(m):
class NLayerDiscriminator (line 17) | class NLayerDiscriminator(nn.Module):
method __init__ (line 21) | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
method forward (line 65) | def forward(self, input):
class NLayerDiscriminator_v1 (line 69) | class NLayerDiscriminator_v1(nn.Module):
method __init__ (line 73) | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
method forward (line 123) | def forward(self, input):
FILE: RestoreFormer/modules/losses/lpips.py
class LPIPS (line 11) | class LPIPS(nn.Module):
method __init__ (line 13) | def __init__(self, use_dropout=True, style_weight=0.):
method load_from_pretrained (line 29) | def load_from_pretrained(self, name="vgg_lpips"):
method from_pretrained (line 35) | def from_pretrained(cls, name="vgg_lpips"):
method forward (line 43) | def forward(self, input, target):
method _gram_mat (line 63) | def _gram_mat(self, x):
class ScalingLayer (line 79) | class ScalingLayer(nn.Module):
method __init__ (line 80) | def __init__(self):
method forward (line 85) | def forward(self, inp):
class NetLinLayer (line 89) | class NetLinLayer(nn.Module):
method __init__ (line 91) | def __init__(self, chn_in, chn_out=1, use_dropout=False):
class vgg16 (line 98) | class vgg16(torch.nn.Module):
method __init__ (line 99) | def __init__(self, requires_grad=False, pretrained=True):
method forward (line 122) | def forward(self, X):
function normalize_tensor (line 138) | def normalize_tensor(x,eps=1e-10):
function spatial_average (line 143) | def spatial_average(x, keepdim=True):
FILE: RestoreFormer/modules/losses/vqperceptual.py
class DummyLoss (line 13) | class DummyLoss(nn.Module):
method __init__ (line 14) | def __init__(self):
function adopt_weight (line 18) | def adopt_weight(weight, global_step, threshold=0, value=0.):
function hinge_d_loss (line 24) | def hinge_d_loss(logits_real, logits_fake):
function vanilla_d_loss (line 31) | def vanilla_d_loss(logits_real, logits_fake):
class VQLPIPSWithDiscriminatorWithCompWithIdentity (line 38) | class VQLPIPSWithDiscriminatorWithCompWithIdentity(nn.Module):
method __init__ (line 39) | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
method calculate_adaptive_weight (line 105) | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
method _gram_mat (line 118) | def _gram_mat(self, x):
method gray_resize_for_identity (line 133) | def gray_resize_for_identity(self, out, size=128):
method forward (line 139) | def forward(self, codebook_loss, gts, reconstructions, components, opt...
FILE: RestoreFormer/modules/util.py
function count_params (line 5) | def count_params(model):
class ActNorm (line 10) | class ActNorm(nn.Module):
method __init__ (line 11) | def __init__(self, num_features, logdet=False, affine=True,
method initialize (line 22) | def initialize(self, input):
method forward (line 43) | def forward(self, input, reverse=False):
method reverse (line 71) | def reverse(self, output):
class Attention2DConv (line 95) | class Attention2DConv(nn.Module):
method __init__ (line 97) | def __init__(self):
FILE: RestoreFormer/modules/vqvae/arcface_arch.py
function conv3x3 (line 6) | def conv3x3(in_planes, out_planes, stride=1):
class BasicBlock (line 11) | class BasicBlock(nn.Module):
method __init__ (line 14) | def __init__(self, inplanes, planes, stride=1, downsample=None):
method forward (line 24) | def forward(self, x):
class IRBlock (line 43) | class IRBlock(nn.Module):
method __init__ (line 46) | def __init__(self, inplanes, planes, stride=1, downsample=None, use_se...
method forward (line 60) | def forward(self, x):
class Bottleneck (line 81) | class Bottleneck(nn.Module):
method __init__ (line 84) | def __init__(self, inplanes, planes, stride=1, downsample=None):
method forward (line 96) | def forward(self, x):
class SEBlock (line 119) | class SEBlock(nn.Module):
method __init__ (line 121) | def __init__(self, channel, reduction=16):
method forward (line 128) | def forward(self, x):
class ResNetArcFace (line 136) | class ResNetArcFace(nn.Module):
method __init__ (line 138) | def __init__(self, block, layers, use_se=True):
method _make_layer (line 167) | def _make_layer(self, block, planes, blocks, stride=1):
method forward (line 182) | def forward(self, x):
FILE: RestoreFormer/modules/vqvae/facial_component_discriminator.py
class FacialComponentDiscriminator (line 14) | class FacialComponentDiscriminator(nn.Module):
method __init__ (line 16) | def __init__(self):
method forward (line 26) | def forward(self, x, return_feats=False):
FILE: RestoreFormer/modules/vqvae/utils.py
function get_roi_regions (line 4) | def get_roi_regions(gt, output, loc_left_eyes, loc_right_eyes, loc_mouths,
FILE: RestoreFormer/modules/vqvae/vqvae_arch.py
class VectorQuantizer (line 11) | class VectorQuantizer(nn.Module):
method __init__ (line 23) | def __init__(self, n_e, e_dim, beta):
method forward (line 32) | def forward(self, z):
method get_codebook_entry (line 94) | def get_codebook_entry(self, indices, shape):
function nonlinearity (line 112) | def nonlinearity(x):
function Normalize (line 117) | def Normalize(in_channels):
class Upsample (line 121) | class Upsample(nn.Module):
method __init__ (line 122) | def __init__(self, in_channels, with_conv):
method forward (line 132) | def forward(self, x):
class Downsample (line 139) | class Downsample(nn.Module):
method __init__ (line 140) | def __init__(self, in_channels, with_conv):
method forward (line 151) | def forward(self, x):
class ResnetBlock (line 161) | class ResnetBlock(nn.Module):
method __init__ (line 162) | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=Fa...
method forward (line 200) | def forward(self, x, temb):
class MultiHeadAttnBlock (line 223) | class MultiHeadAttnBlock(nn.Module):
method __init__ (line 224) | def __init__(self, in_channels, head_size=1):
method forward (line 256) | def forward(self, x, y=None):
class MultiHeadEncoder (line 300) | class MultiHeadEncoder(nn.Module):
method __init__ (line 301) | def __init__(self, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks=2,
method forward (line 367) | def forward(self, x):
class MultiHeadDecoder (line 406) | class MultiHeadDecoder(nn.Module):
method __init__ (line 407) | def __init__(self, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks=2,
method forward (line 479) | def forward(self, z):
class MultiHeadDecoderTransformer (line 513) | class MultiHeadDecoderTransformer(nn.Module):
method __init__ (line 514) | def __init__(self, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks=2,
method forward (line 586) | def forward(self, z, hs):
class VQVAEGAN (line 622) | class VQVAEGAN(nn.Module):
method __init__ (line 623) | def __init__(self, n_embed=1024, embed_dim=256, ch=128, out_ch=3, ch_m...
method encode (line 653) | def encode(self, x):
method decode (line 660) | def decode(self, quant):
method forward (line 666) | def forward(self, input):
class VQVAEGANMultiHeadTransformer (line 672) | class VQVAEGANMultiHeadTransformer(nn.Module):
method __init__ (line 673) | def __init__(self, n_embed=1024, embed_dim=256, ch=128, out_ch=3, ch_m...
method encode (line 708) | def encode(self, x):
method decode (line 715) | def decode(self, quant, hs):
method forward (line 721) | def forward(self, input):
FILE: RestoreFormer/util.py
function download (line 18) | def download(url, local_path, chunk_size=1024):
function md5_hash (line 30) | def md5_hash(path):
function get_ckpt_path (line 36) | def get_ckpt_path(name, root, check=False):
class KeyNotFoundError (line 47) | class KeyNotFoundError(Exception):
method __init__ (line 48) | def __init__(self, cause, keys=None, visited=None):
function retrieve (line 62) | def retrieve(
FILE: main.py
function get_obj_from_str (line 15) | def get_obj_from_str(string, reload=False):
function get_parser (line 23) | def get_parser(**parser_kwargs):
function nondefault_trainer_args (line 137) | def nondefault_trainer_args(opt):
function instantiate_from_config (line 144) | def instantiate_from_config(config):
class WrappedDataset (line 153) | class WrappedDataset(Dataset):
method __init__ (line 155) | def __init__(self, dataset):
method __len__ (line 158) | def __len__(self):
method __getitem__ (line 161) | def __getitem__(self, idx):
class DataModuleFromConfig (line 165) | class DataModuleFromConfig(pl.LightningDataModule):
method __init__ (line 166) | def __init__(self, batch_size, train=None, validation=None, test=None,
method prepare_data (line 183) | def prepare_data(self):
method setup (line 187) | def setup(self, stage=None):
method _train_dataloader (line 195) | def _train_dataloader(self):
method _val_dataloader (line 199) | def _val_dataloader(self):
method _test_dataloader (line 204) | def _test_dataloader(self):
class SetupCallback (line 209) | class SetupCallback(Callback):
method __init__ (line 210) | def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, light...
method on_pretrain_routine_start (line 220) | def on_pretrain_routine_start(self, trainer, pl_module):
class ImageLogger (line 240) | class ImageLogger(Callback):
method __init__ (line 241) | def __init__(self, batch_frequency, max_images, clamp=True, increase_l...
method _wandb (line 255) | def _wandb(self, pl_module, images, batch_idx, split):
method _testtube (line 264) | def _testtube(self, pl_module, images, batch_idx, split):
method log_local (line 275) | def log_local(self, save_dir, split, images,
method log_img (line 294) | def log_img(self, pl_module, batch, batch_idx, split="train"):
method check_frequency (line 325) | def check_frequency(self, batch_idx):
method on_train_batch_end (line 334) | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch...
method on_validation_batch_end (line 337) | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, ...
function melk (line 578) | def melk(*args, **kwargs):
function divein (line 585) | def divein(*args, **kwargs):
FILE: scripts/metrics/cal_fid.py
function calculate_fid_folder (line 12) | def calculate_fid_folder():
FILE: scripts/metrics/cal_identity_distance.py
function cosine_similarity (line 21) | def cosine_similarity(emb1, emb2):
function gray_resize_for_identity (line 25) | def gray_resize_for_identity(out, size=128):
function calculate_identity_distance_folder (line 31) | def calculate_identity_distance_folder():
FILE: scripts/metrics/cal_psnr_ssim.py
function calculate_psnr_ssim_lpips_folder (line 18) | def calculate_psnr_ssim_lpips_folder():
FILE: scripts/test.py
function restoration (line 22) | def restoration(model,
function get_parser (line 80) | def get_parser():
function load_model_from_config (line 142) | def load_model_from_config(config, sd, gpu=True, eval_mode=True):
function load_model_and_dset (line 191) | def load_model_and_dset(config, ckpt, gpu, eval_mode):
About this extraction
This document contains the full source code of the wzhouxiff/RestoreFormer GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 30 files (158.8 KB, approximately 39.9k tokens) and a symbol index of 172 functions, classes, methods, constants, and types. Extracted by GitExtract, built by Nikandr Surkov.