Repository: wzhouxiff/RestoreFormer Branch: main Commit: 294cf9521a86 Files: 30 Total size: 158.8 KB Directory structure: gitextract_istv2bdf/ ├── .gitignore ├── LICENSE ├── README.md ├── RestoreFormer/ │ ├── data/ │ │ └── ffhq_degradation_dataset.py │ ├── distributed/ │ │ ├── __init__.py │ │ ├── distributed.py │ │ └── launch.py │ ├── models/ │ │ └── vqgan_v1.py │ ├── modules/ │ │ ├── discriminator/ │ │ │ └── model.py │ │ ├── losses/ │ │ │ ├── __init__.py │ │ │ ├── lpips.py │ │ │ └── vqperceptual.py │ │ ├── util.py │ │ └── vqvae/ │ │ ├── arcface_arch.py │ │ ├── facial_component_discriminator.py │ │ ├── utils.py │ │ └── vqvae_arch.py │ └── util.py ├── __init__.py ├── configs/ │ ├── HQ_Dictionary.yaml │ └── RestoreFormer.yaml ├── main.py ├── restoreformer_requirement.txt └── scripts/ ├── metrics/ │ ├── cal_fid.py │ ├── cal_identity_distance.py │ ├── cal_psnr_ssim.py │ └── run.sh ├── run.sh ├── test.py └── test.sh ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ data/FFHQ scripts/data_synthetic experiments/ scripts/run_clustre.sh sftp-config.json results/ # scripts/metrics ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # We have merged the code of RestoreFormer into our journal version, RestoreFormer++. Please feel free to access the resources from [https://github.com/wzhouxiff/RestoreFormerPlusPlus](https://github.com/wzhouxiff/RestoreFormerPlusPlus) # Updating - **20230915** Update an online demo [![Huggingface Gradio](https://img.shields.io/static/v1?label=Demo&message=Huggingface%20Gradio&color=orange)](https://huggingface.co/spaces/wzhouxiff/RestoreFormerPlusPlus) - **20230915** A more user-friendly and comprehensive inference method refer to our [RestoreFormer++](https://github.com/wzhouxiff/RestoreFormerPlusPlus) - **20230116** For convenience, we further upload the [test datasets](#testset), including CelebA (both HQ and LQ data), LFW-Test, CelebChild-Test, and Webphoto-Test, to OneDrive and BaiduYun. - **20221003** We provide the link of the [test datasets](#testset). - **20220924** We add the code for [**metrics**](#metrics) in scripts/metrics. ================================================ FILE: RestoreFormer/data/ffhq_degradation_dataset.py ================================================ import os import cv2 import math import numpy as np import random import os.path as osp import torch import torch.utils.data as data from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation, normalize) from basicsr.data import degradations as degradations from basicsr.data.data_util import paths_from_folder from basicsr.data.transforms import augment from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor from basicsr.utils.registry import DATASET_REGISTRY @DATASET_REGISTRY.register() class FFHQDegradationDataset(data.Dataset): def __init__(self, opt): super(FFHQDegradationDataset, self).__init__() self.opt = opt # file client (io backend) self.file_client = None self.io_backend_opt = opt['io_backend'] self.gt_folder = opt['dataroot_gt'] self.mean = opt['mean'] self.std = opt['std'] self.out_size = opt['out_size'] self.crop_components = opt.get('crop_components', False) # facial components self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1) if self.crop_components: self.components_list = torch.load(opt.get('component_path')) if self.io_backend_opt['type'] == 'lmdb': self.io_backend_opt['db_paths'] = self.gt_folder if not self.gt_folder.endswith('.lmdb'): raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}") with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: self.paths = [line.split('.')[0] for line in fin] else: self.paths = paths_from_folder(self.gt_folder) # degradations self.blur_kernel_size = opt['blur_kernel_size'] self.kernel_list = opt['kernel_list'] self.kernel_prob = opt['kernel_prob'] self.blur_sigma = opt['blur_sigma'] self.downsample_range = opt['downsample_range'] self.noise_range = opt['noise_range'] self.jpeg_range = opt['jpeg_range'] # color jitter self.color_jitter_prob = opt.get('color_jitter_prob') self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob') self.color_jitter_shift = opt.get('color_jitter_shift', 20) # to gray self.gray_prob = opt.get('gray_prob') logger = get_root_logger() logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, ' f'sigma: [{", ".join(map(str, self.blur_sigma))}]') logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]') logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]') logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]') if self.color_jitter_prob is not None: logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, ' f'shift: {self.color_jitter_shift}') if self.gray_prob is not None: logger.info(f'Use random gray. Prob: {self.gray_prob}') self.color_jitter_shift /= 255. @staticmethod def color_jitter(img, shift): jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32) img = img + jitter_val img = np.clip(img, 0, 1) return img @staticmethod def color_jitter_pt(img, brightness, contrast, saturation, hue): fn_idx = torch.randperm(4) for fn_id in fn_idx: if fn_id == 0 and brightness is not None: brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item() img = adjust_brightness(img, brightness_factor) if fn_id == 1 and contrast is not None: contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item() img = adjust_contrast(img, contrast_factor) if fn_id == 2 and saturation is not None: saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item() img = adjust_saturation(img, saturation_factor) if fn_id == 3 and hue is not None: hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item() img = adjust_hue(img, hue_factor) return img def get_component_coordinates(self, index, status): components_bbox = self.components_list[f'{index:08d}'] if status[0]: # hflip # exchange right and left eye tmp = components_bbox['left_eye'] components_bbox['left_eye'] = components_bbox['right_eye'] components_bbox['right_eye'] = tmp # modify the width coordinate components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0] components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0] components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0] # get coordinates locations = [] for part in ['left_eye', 'right_eye', 'mouth']: mean = components_bbox[part][0:2] half_len = components_bbox[part][2] if 'eye' in part: half_len *= self.eye_enlarge_ratio loc = np.hstack((mean - half_len + 1, mean + half_len)) loc = torch.from_numpy(loc).float() locations.append(loc) return locations def __getitem__(self, index): if self.file_client is None: self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) # load gt image gt_path = self.paths[index] img_bytes = self.file_client.get(gt_path) img_gt = imfrombytes(img_bytes, float32=True) # random horizontal flip img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True) h, w, _ = img_gt.shape if self.crop_components: locations = self.get_component_coordinates(index, status) loc_left_eye, loc_right_eye, loc_mouth = locations # ------------------------ generate lq image ------------------------ # # blur assert self.blur_kernel_size[0] < self.blur_kernel_size[1], 'Wrong blur kernel size range' cur_kernel_size = random.randint(self.blur_kernel_size[0],self.blur_kernel_size[1]) * 2 + 1 kernel = degradations.random_mixed_kernels( self.kernel_list, self.kernel_prob, cur_kernel_size, self.blur_sigma, self.blur_sigma, [-math.pi, math.pi], noise_range=None) img_lq = cv2.filter2D(img_gt, -1, kernel) # downsample scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1]) img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR) # noise if self.noise_range is not None: img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range) # jpeg compression if self.jpeg_range is not None: img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range) # resize to original size img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR) # random color jitter (only for lq) if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob): img_lq = self.color_jitter(img_lq, self.color_jitter_shift) # random to gray (only for lq) if self.gray_prob and np.random.uniform() < self.gray_prob: img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY) img_lq = np.tile(img_lq[:, :, None], [1, 1, 3]) if self.opt.get('gt_gray'): img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY) img_gt = np.tile(img_gt[:, :, None], [1, 1, 3]) # BGR to RGB, HWC to CHW, numpy to tensor img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) # random color jitter (pytorch version) (only for lq) if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob): brightness = self.opt.get('brightness', (0.5, 1.5)) contrast = self.opt.get('contrast', (0.5, 1.5)) saturation = self.opt.get('saturation', (0, 1.5)) hue = self.opt.get('hue', (-0.1, 0.1)) img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue) # round and clip img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255. # normalize normalize(img_gt, self.mean, self.std, inplace=True) normalize(img_lq, self.mean, self.std, inplace=True) return_dict = { 'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path } if self.crop_components: return_dict['loc_left_eye'] = loc_left_eye return_dict['loc_right_eye'] = loc_right_eye return_dict['loc_mouth'] = loc_mouth return return_dict def __len__(self): return len(self.paths) import argparse from omegaconf import OmegaConf import pdb from basicsr.utils import img2tensor, imwrite, tensor2img if __name__=='__main__': # pdb.set_trace() base='configs/RestoreFormer.yaml' opt = OmegaConf.load(base) dataset = FFHQDegradationDataset(opt['data']['params']['train']['params']) for i in range(100): sample = dataset.getitem(i) name = sample['gt_path'].split('/')[-1][:-4] gt = tensor2img(sample['gt']) imwrite(gt, +name+'_gt.png') lq = tensor2img(sample['lq']) imwrite(lq, name+'_lq_nojitter.png') ================================================ FILE: RestoreFormer/distributed/__init__.py ================================================ from .distributed import ( get_rank, get_local_rank, is_primary, synchronize, get_world_size, all_reduce, all_gather, reduce_dict, data_sampler, LOCAL_PROCESS_GROUP, ) from .launch import launch ================================================ FILE: RestoreFormer/distributed/distributed.py ================================================ import math import pickle import torch from torch import distributed as dist from torch.utils import data LOCAL_PROCESS_GROUP = None def is_primary(): return get_rank() == 0 def get_rank(): if not dist.is_available(): return 0 if not dist.is_initialized(): return 0 return dist.get_rank() def get_local_rank(): if not dist.is_available(): return 0 if not dist.is_initialized(): return 0 if LOCAL_PROCESS_GROUP is None: raise ValueError("tensorfn.distributed.LOCAL_PROCESS_GROUP is None") return dist.get_rank(group=LOCAL_PROCESS_GROUP) def synchronize(): if not dist.is_available(): return if not dist.is_initialized(): return world_size = dist.get_world_size() if world_size == 1: return dist.barrier() def get_world_size(): if not dist.is_available(): return 1 if not dist.is_initialized(): return 1 return dist.get_world_size() def all_reduce(tensor, op=dist.ReduceOp.SUM): world_size = get_world_size() if world_size == 1: return tensor dist.all_reduce(tensor, op=op) return tensor def all_gather(data): world_size = get_world_size() if world_size == 1: return [data] buffer = pickle.dumps(data) storage = torch.ByteStorage.from_buffer(buffer) tensor = torch.ByteTensor(storage).to("cuda") local_size = torch.IntTensor([tensor.numel()]).to("cuda") size_list = [torch.IntTensor([1]).to("cuda") for _ in range(world_size)] dist.all_gather(size_list, local_size) size_list = [int(size.item()) for size in size_list] max_size = max(size_list) tensor_list = [] for _ in size_list: tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) if local_size != max_size: padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") tensor = torch.cat((tensor, padding), 0) dist.all_gather(tensor_list, tensor) data_list = [] for size, tensor in zip(size_list, tensor_list): buffer = tensor.cpu().numpy().tobytes()[:size] data_list.append(pickle.loads(buffer)) return data_list def reduce_dict(input_dict, average=True): world_size = get_world_size() if world_size < 2: return input_dict with torch.no_grad(): keys = [] values = [] for k in sorted(input_dict.keys()): keys.append(k) values.append(input_dict[k]) values = torch.stack(values, 0) dist.reduce(values, dst=0) if dist.get_rank() == 0 and average: values /= world_size reduced_dict = {k: v for k, v in zip(keys, values)} return reduced_dict def data_sampler(dataset, shuffle, distributed): if distributed: return data.distributed.DistributedSampler(dataset, shuffle=shuffle) if shuffle: return data.RandomSampler(dataset) else: return data.SequentialSampler(dataset) ================================================ FILE: RestoreFormer/distributed/launch.py ================================================ import os import torch from torch import distributed as dist from torch import multiprocessing as mp from . import distributed as dist_fn def find_free_port(): import socket sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.bind(("", 0)) port = sock.getsockname()[1] sock.close() return port def launch(fn, n_gpu_per_machine, n_machine=1, machine_rank=0, dist_url=None, args=()): world_size = n_machine * n_gpu_per_machine if world_size > 1: if "OMP_NUM_THREADS" not in os.environ: os.environ["OMP_NUM_THREADS"] = "1" if dist_url == "auto": if n_machine != 1: raise ValueError('dist_url="auto" not supported in multi-machine jobs') port = find_free_port() dist_url = f"tcp://127.0.0.1:{port}" if n_machine > 1 and dist_url.startswith("file://"): raise ValueError( "file:// is not a reliable init method in multi-machine jobs. Prefer tcp://" ) mp.spawn( distributed_worker, nprocs=n_gpu_per_machine, args=(fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args), daemon=False, ) else: fn(*args) def distributed_worker( local_rank, fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args ): if not torch.cuda.is_available(): raise OSError("CUDA is not available. Please check your environments") global_rank = machine_rank * n_gpu_per_machine + local_rank try: dist.init_process_group( backend="NCCL", init_method=dist_url, world_size=world_size, rank=global_rank, ) except Exception: raise OSError("failed to initialize NCCL groups") dist_fn.synchronize() if n_gpu_per_machine > torch.cuda.device_count(): raise ValueError( f"specified n_gpu_per_machine larger than available device ({torch.cuda.device_count()})" ) torch.cuda.set_device(local_rank) if dist_fn.LOCAL_PROCESS_GROUP is not None: raise ValueError("torch.distributed.LOCAL_PROCESS_GROUP is not None") n_machine = world_size // n_gpu_per_machine for i in range(n_machine): ranks_on_i = list(range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine)) pg = dist.new_group(ranks_on_i) if i == machine_rank: dist_fn.distributed.LOCAL_PROCESS_GROUP = pg fn(*args) ================================================ FILE: RestoreFormer/models/vqgan_v1.py ================================================ import torch import torch.nn.functional as F import pytorch_lightning as pl from main import instantiate_from_config from RestoreFormer.modules.vqvae.utils import get_roi_regions class RestoreFormerModel(pl.LightningModule): def __init__(self, ddconfig, lossconfig, ckpt_path=None, ignore_keys=[], image_key="lq", colorize_nlabels=None, monitor=None, special_params_lr_scale=1.0, comp_params_lr_scale=1.0, schedule_step=[80000, 200000] ): super().__init__() self.image_key = image_key self.vqvae = instantiate_from_config(ddconfig) lossconfig['params']['distill_param']=ddconfig['params'] self.loss = instantiate_from_config(lossconfig) if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) if ('comp_weight' in lossconfig['params'] and lossconfig['params']['comp_weight']) or ('comp_style_weight' in lossconfig['params'] and lossconfig['params']['comp_style_weight']): self.use_facial_disc = True else: self.use_facial_disc = False self.fix_decoder = ddconfig['params']['fix_decoder'] self.disc_start = lossconfig['params']['disc_start'] self.special_params_lr_scale = special_params_lr_scale self.comp_params_lr_scale = comp_params_lr_scale self.schedule_step = schedule_step def init_from_ckpt(self, path, ignore_keys=list()): sd = torch.load(path, map_location="cpu")["state_dict"] keys = list(sd.keys()) for k in keys: for ik in ignore_keys: if k.startswith(ik): print("Deleting key {} from state_dict.".format(k)) del sd[k] state_dict = self.state_dict() require_keys = state_dict.keys() keys = sd.keys() un_pretrained_keys = [] for k in require_keys: if k not in keys: # miss 'vqvae.' if k[6:] in keys: state_dict[k] = sd[k[6:]] else: un_pretrained_keys.append(k) else: state_dict[k] = sd[k] # print(f'*************************************************') # print(f"Layers without pretraining: {un_pretrained_keys}") # print(f'*************************************************') self.load_state_dict(state_dict, strict=True) print(f"Restored from {path}") def forward(self, input): dec, diff, info, hs = self.vqvae(input) return dec, diff, info, hs def training_step(self, batch, batch_idx, optimizer_idx): x = batch[self.image_key] xrec, qloss, info, hs = self(x) if self.image_key != 'gt': x = batch['gt'] if self.use_facial_disc: loc_left_eyes = batch['loc_left_eye'] loc_right_eyes = batch['loc_right_eye'] loc_mouths = batch['loc_mouth'] face_ratio = xrec.shape[-1] / 512 components = get_roi_regions(x, xrec, loc_left_eyes, loc_right_eyes, loc_mouths, face_ratio) else: components = None if optimizer_idx == 0: # autoencode aeloss, log_dict_ae = self.loss(qloss, x, xrec, components, optimizer_idx, self.global_step, last_layer=self.get_last_layer(), split="train") self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) return aeloss if optimizer_idx == 1: # discriminator discloss, log_dict_disc = self.loss(qloss, x, xrec, components, optimizer_idx, self.global_step, last_layer=None, split="train") self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) return discloss if self.disc_start <= self.global_step: # left eye if optimizer_idx == 2: # discriminator disc_left_loss, log_dict_disc = self.loss(qloss, x, xrec, components, optimizer_idx, self.global_step, last_layer=None, split="train") self.log("train/disc_left_loss", disc_left_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True) self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) return disc_left_loss # right eye if optimizer_idx == 3: # discriminator disc_right_loss, log_dict_disc = self.loss(qloss, x, xrec, components, optimizer_idx, self.global_step, last_layer=None, split="train") self.log("train/disc_right_loss", disc_right_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True) self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) return disc_right_loss # mouth if optimizer_idx == 4: # discriminator disc_mouth_loss, log_dict_disc = self.loss(qloss, x, xrec, components, optimizer_idx, self.global_step, last_layer=None, split="train") self.log("train/disc_mouth_loss", disc_mouth_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True) self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) return disc_mouth_loss def validation_step(self, batch, batch_idx): x = batch[self.image_key] xrec, qloss, info, hs = self(x) if self.image_key != 'gt': x = batch['gt'] aeloss, log_dict_ae = self.loss(qloss, x, xrec, None, 0, self.global_step, last_layer=self.get_last_layer(), split="val") discloss, log_dict_disc = self.loss(qloss, x, xrec, None, 1, self.global_step, last_layer=None, split="val") rec_loss = log_dict_ae["val/rec_loss"] self.log("val/rec_loss", rec_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) self.log("val/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) self.log_dict(log_dict_ae) self.log_dict(log_dict_disc) return self.log_dict def configure_optimizers(self): lr = self.learning_rate normal_params = [] special_params = [] for name, param in self.vqvae.named_parameters(): if not param.requires_grad: continue if 'decoder' in name and 'attn' in name: special_params.append(param) else: normal_params.append(param) # print('special_params', special_params) opt_ae_params = [{'params': normal_params, 'lr': lr}, {'params': special_params, 'lr': lr*self.special_params_lr_scale}] opt_ae = torch.optim.Adam(opt_ae_params, betas=(0.5, 0.9)) opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)) optimizations = [opt_ae, opt_disc] s0 = torch.optim.lr_scheduler.MultiStepLR(opt_ae, milestones=self.schedule_step, gamma=0.1, verbose=True) s1 = torch.optim.lr_scheduler.MultiStepLR(opt_disc, milestones=self.schedule_step, gamma=0.1, verbose=True) schedules = [s0, s1] if self.use_facial_disc: opt_l = torch.optim.Adam(self.loss.net_d_left_eye.parameters(), lr=lr*self.comp_params_lr_scale, betas=(0.9, 0.99)) opt_r = torch.optim.Adam(self.loss.net_d_right_eye.parameters(), lr=lr*self.comp_params_lr_scale, betas=(0.9, 0.99)) opt_m = torch.optim.Adam(self.loss.net_d_mouth.parameters(), lr=lr*self.comp_params_lr_scale, betas=(0.9, 0.99)) optimizations += [opt_l, opt_r, opt_m] s2 = torch.optim.lr_scheduler.MultiStepLR(opt_l, milestones=self.schedule_step, gamma=0.1, verbose=True) s3 = torch.optim.lr_scheduler.MultiStepLR(opt_r, milestones=self.schedule_step, gamma=0.1, verbose=True) s4 = torch.optim.lr_scheduler.MultiStepLR(opt_m, milestones=self.schedule_step, gamma=0.1, verbose=True) schedules += [s2, s3, s4] return optimizations, schedules def get_last_layer(self): if self.fix_decoder: return self.vqvae.quant_conv.weight return self.vqvae.decoder.conv_out.weight def log_images(self, batch, **kwargs): log = dict() x = batch[self.image_key] x = x.to(self.device) xrec, _, _, _ = self(x) log["inputs"] = x log["reconstructions"] = xrec if self.image_key != 'gt': x = batch['gt'] log["gt"] = x return log ================================================ FILE: RestoreFormer/modules/discriminator/model.py ================================================ import functools import torch.nn as nn from RestoreFormer.modules.util import ActNorm def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: nn.init.normal_(m.weight.data, 0.0, 0.02) elif classname.find('BatchNorm') != -1: nn.init.normal_(m.weight.data, 1.0, 0.02) nn.init.constant_(m.bias.data, 0) class NLayerDiscriminator(nn.Module): """Defines a PatchGAN discriminator as in Pix2Pix --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py """ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): """Construct a PatchGAN discriminator Parameters: input_nc (int) -- the number of channels in input images ndf (int) -- the number of filters in the last conv layer n_layers (int) -- the number of conv layers in the discriminator norm_layer -- normalization layer """ super(NLayerDiscriminator, self).__init__() if not use_actnorm: norm_layer = nn.BatchNorm2d else: norm_layer = ActNorm if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters use_bias = norm_layer.func != nn.BatchNorm2d else: use_bias = norm_layer != nn.BatchNorm2d kw = 4 padw = 1 sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] nf_mult = 1 nf_mult_prev = 1 for n in range(1, n_layers): # gradually increase the number of filters nf_mult_prev = nf_mult nf_mult = min(2 ** n, 8) sequence += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True) ] nf_mult_prev = nf_mult nf_mult = min(2 ** n_layers, 8) sequence += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True) ] sequence += [ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map self.main = nn.Sequential(*sequence) def forward(self, input): """Standard forward.""" return self.main(input) class NLayerDiscriminator_v1(nn.Module): """Defines a PatchGAN discriminator as in Pix2Pix --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py """ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): """Construct a PatchGAN discriminator Parameters: input_nc (int) -- the number of channels in input images ndf (int) -- the number of filters in the last conv layer n_layers (int) -- the number of conv layers in the discriminator norm_layer -- normalization layer """ super(NLayerDiscriminator_v1, self).__init__() if not use_actnorm: norm_layer = nn.BatchNorm2d else: norm_layer = ActNorm if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters use_bias = norm_layer.func != nn.BatchNorm2d else: use_bias = norm_layer != nn.BatchNorm2d self.n_layers = n_layers kw = 4 padw = 1 # sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] self.head = nn.Sequential(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)) # self.head = nn.Sequential(nn.Conv2d(3, 64, kernel_size=4, stride=1, padding=1), nn.LeakyReLU(0.2, True)).cuda() nf_mult = 1 nf_mult_prev = 1 self.body = nn.ModuleList() for n in range(1, n_layers): # gradually increase the number of filters nf_mult_prev = nf_mult nf_mult = min(2 ** n, 8) self.body.append(nn.Sequential( nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True) )) nf_mult_prev = nf_mult nf_mult = min(2 ** n_layers, 8) self.beforlast = nn.Sequential( nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True) ) self.final = nn.Sequential( nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)) # output 1 channel prediction map # self.main = nn.Sequential(*sequence) def forward(self, input): """Standard forward.""" # return self.main(input) features = [] f = self.head(input) features.append(f) for i in range(self.n_layers-1): f = self.body[i](f) features.append(f) beforlastF = self.beforlast(f) final = self.final(beforlastF) return features, final ================================================ FILE: RestoreFormer/modules/losses/__init__.py ================================================ ================================================ FILE: RestoreFormer/modules/losses/lpips.py ================================================ """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" import torch import torch.nn as nn from torchvision import models from collections import namedtuple from RestoreFormer.util import get_ckpt_path class LPIPS(nn.Module): # Learned perceptual metric def __init__(self, use_dropout=True, style_weight=0.): super().__init__() self.scaling_layer = ScalingLayer() self.chns = [64, 128, 256, 512, 512] # vg16 features self.net = vgg16(pretrained=True, requires_grad=False) self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) self.load_from_pretrained() for param in self.parameters(): param.requires_grad = False self.style_weight = style_weight def load_from_pretrained(self, name="vgg_lpips"): ckpt = get_ckpt_path(name, "experiments/pretrained_models/lpips") self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) print("loaded pretrained LPIPS loss from {}".format(ckpt)) @classmethod def from_pretrained(cls, name="vgg_lpips"): if name is not "vgg_lpips": raise NotImplementedError model = cls() ckpt = get_ckpt_path(name) model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) return model def forward(self, input, target): in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) outs0, outs1 = self.net(in0_input), self.net(in1_input) feats0, feats1, diffs = {}, {}, {} lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] style_loss = torch.tensor([0.0]).to(input.device) for kk in range(len(self.chns)): feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 if self.style_weight > 0.: style_loss = style_loss + torch.mean((self._gram_mat(feats0[kk]) - self._gram_mat(feats1[kk])) ** 2) res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] val = res[0] for l in range(1, len(self.chns)): val += res[l] return val, style_loss * self.style_weight def _gram_mat(self, x): """Calculate Gram matrix. Args: x (torch.Tensor): Tensor with shape of (n, c, h, w). Returns: torch.Tensor: Gram matrix. """ n, c, h, w = x.size() features = x.view(n, c, w * h) features_t = features.transpose(1, 2) gram = features.bmm(features_t) / (c * h * w) return gram class ScalingLayer(nn.Module): def __init__(self): super(ScalingLayer, self).__init__() self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) def forward(self, inp): return (inp - self.shift) / self.scale class NetLinLayer(nn.Module): """ A single linear layer which does a 1x1 conv """ def __init__(self, chn_in, chn_out=1, use_dropout=False): super(NetLinLayer, self).__init__() layers = [nn.Dropout(), ] if (use_dropout) else [] layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] self.model = nn.Sequential(*layers) class vgg16(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True): super(vgg16, self).__init__() vgg_pretrained_features = models.vgg16(pretrained=pretrained).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() self.N_slices = 5 for x in range(4): self.slice1.add_module(str(x), vgg_pretrained_features[x]) for x in range(4, 9): self.slice2.add_module(str(x), vgg_pretrained_features[x]) for x in range(9, 16): self.slice3.add_module(str(x), vgg_pretrained_features[x]) for x in range(16, 23): self.slice4.add_module(str(x), vgg_pretrained_features[x]) for x in range(23, 30): self.slice5.add_module(str(x), vgg_pretrained_features[x]) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X): h = self.slice1(X) h_relu1_2 = h h = self.slice2(h) h_relu2_2 = h h = self.slice3(h) h_relu3_3 = h h = self.slice4(h) h_relu4_3 = h h = self.slice5(h) h_relu5_3 = h vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) return out def normalize_tensor(x,eps=1e-10): norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) return x/(norm_factor+eps) def spatial_average(x, keepdim=True): return x.mean([2,3],keepdim=keepdim) ================================================ FILE: RestoreFormer/modules/losses/vqperceptual.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from copy import deepcopy from RestoreFormer.modules.losses.lpips import LPIPS from RestoreFormer.modules.discriminator.model import NLayerDiscriminator, weights_init from RestoreFormer.modules.vqvae.facial_component_discriminator import FacialComponentDiscriminator from basicsr.losses.losses import GANLoss, L1Loss from RestoreFormer.modules.vqvae.arcface_arch import ResNetArcFace class DummyLoss(nn.Module): def __init__(self): super().__init__() def adopt_weight(weight, global_step, threshold=0, value=0.): if global_step < threshold: weight = value return weight def hinge_d_loss(logits_real, logits_fake): loss_real = torch.mean(F.relu(1. - logits_real)) loss_fake = torch.mean(F.relu(1. + logits_fake)) d_loss = 0.5 * (loss_real + loss_fake) return d_loss def vanilla_d_loss(logits_real, logits_fake): d_loss = 0.5 * ( torch.mean(torch.nn.functional.softplus(-logits_real)) + torch.mean(torch.nn.functional.softplus(logits_fake))) return d_loss class VQLPIPSWithDiscriminatorWithCompWithIdentity(nn.Module): def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, perceptual_weight=1.0, use_actnorm=False, disc_ndf=64, disc_loss="hinge", comp_weight=0.0, comp_style_weight=0.0, identity_weight=0.0, comp_disc_loss='vanilla', lpips_style_weight=0.0, identity_model_path=None, **ignore_kwargs): super().__init__() assert disc_loss in ["hinge", "vanilla"] self.codebook_weight = codebook_weight self.pixel_weight = pixelloss_weight self.perceptual_loss = LPIPS(style_weight=lpips_style_weight).eval() self.perceptual_weight = perceptual_weight self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm, ndf=disc_ndf ).apply(weights_init) if comp_weight > 0: self.net_d_left_eye = FacialComponentDiscriminator() self.net_d_right_eye = FacialComponentDiscriminator() self.net_d_mouth = FacialComponentDiscriminator() print(f'Use components discrimination') self.cri_component = GANLoss(gan_type=comp_disc_loss, real_label_val=1.0, fake_label_val=0.0, loss_weight=comp_weight) if comp_style_weight > 0.: self.cri_style = L1Loss(loss_weight=comp_style_weight, reduction='mean') if identity_weight > 0: self.identity = ResNetArcFace(block = 'IRBlock', layers = [2, 2, 2, 2], use_se = False) print(f'Use identity loss') if identity_model_path is not None: sd = torch.load(identity_model_path, map_location="cpu") for k, v in deepcopy(sd).items(): if k.startswith('module.'): sd[k[7:]] = v sd.pop(k) self.identity.load_state_dict(sd, strict=True) for param in self.identity.parameters(): param.requires_grad = False self.cri_identity = L1Loss(loss_weight=identity_weight, reduction='mean') self.discriminator_iter_start = disc_start if disc_loss == "hinge": self.disc_loss = hinge_d_loss elif disc_loss == "vanilla": self.disc_loss = vanilla_d_loss else: raise ValueError(f"Unknown GAN loss '{disc_loss}'.") print(f"VQLPIPSWithDiscriminatorWithCompWithIdentity running with {disc_loss} loss.") self.disc_factor = disc_factor self.discriminator_weight = disc_weight self.comp_weight = comp_weight self.comp_style_weight = comp_style_weight self.identity_weight = identity_weight self.lpips_style_weight = lpips_style_weight def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): if last_layer is not None: nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] else: nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() d_weight = d_weight * self.discriminator_weight return d_weight def _gram_mat(self, x): """Calculate Gram matrix. Args: x (torch.Tensor): Tensor with shape of (n, c, h, w). Returns: torch.Tensor: Gram matrix. """ n, c, h, w = x.size() features = x.view(n, c, w * h) features_t = features.transpose(1, 2) gram = features.bmm(features_t) / (c * h * w) return gram def gray_resize_for_identity(self, out, size=128): out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :]) out_gray = out_gray.unsqueeze(1) out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False) return out_gray def forward(self, codebook_loss, gts, reconstructions, components, optimizer_idx, global_step, last_layer=None, split="train"): # now the GAN part if optimizer_idx == 0: rec_loss = (torch.abs(gts.contiguous() - reconstructions.contiguous())) * self.pixel_weight if self.perceptual_weight > 0: p_loss, p_style_loss = self.perceptual_loss(gts.contiguous(), reconstructions.contiguous()) rec_loss = rec_loss + self.perceptual_weight * p_loss else: p_loss = torch.tensor([0.0]) p_style_loss = torch.tensor([0.0]) nll_loss = rec_loss #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] nll_loss = torch.mean(nll_loss) # generator update logits_fake = self.discriminator(reconstructions.contiguous()) g_loss = -torch.mean(logits_fake) try: d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) except RuntimeError: assert not self.training d_weight = torch.tensor(0.0) disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() + p_style_loss log = { "{}/quant_loss".format(split): codebook_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), "{}/rec_loss".format(split): rec_loss.detach().mean(), "{}/p_loss".format(split): p_loss.detach().mean(), "{}/p_style_loss".format(split): p_style_loss.detach().mean(), "{}/d_weight".format(split): d_weight.detach(), "{}/disc_factor".format(split): torch.tensor(disc_factor), "{}/g_loss".format(split): g_loss.detach().mean(), } if self.comp_weight > 0. and components is not None and self.discriminator_iter_start < global_step: fake_left_eye, fake_left_eye_feats = self.net_d_left_eye(components['left_eyes'], return_feats=True) comp_g_loss = self.cri_component(fake_left_eye, True, is_disc=False) loss = loss + comp_g_loss log["{}/g_left_loss".format(split)] = comp_g_loss.detach() fake_right_eye, fake_right_eye_feats = self.net_d_right_eye(components['right_eyes'], return_feats=True) comp_g_loss = self.cri_component(fake_right_eye, True, is_disc=False) loss = loss + comp_g_loss log["{}/g_right_loss".format(split)] = comp_g_loss.detach() fake_mouth, fake_mouth_feats = self.net_d_mouth(components['mouths'], return_feats=True) comp_g_loss = self.cri_component(fake_mouth, True, is_disc=False) loss = loss + comp_g_loss log["{}/g_mouth_loss".format(split)] = comp_g_loss.detach() if self.comp_style_weight > 0.: _, real_left_eye_feats = self.net_d_left_eye(components['left_eyes_gt'], return_feats=True) _, real_right_eye_feats = self.net_d_right_eye(components['right_eyes_gt'], return_feats=True) _, real_mouth_feats = self.net_d_mouth(components['mouths_gt'], return_feats=True) def _comp_style(feat, feat_gt, criterion): return criterion(self._gram_mat(feat[0]), self._gram_mat( feat_gt[0].detach())) * 0.5 + criterion(self._gram_mat( feat[1]), self._gram_mat(feat_gt[1].detach())) comp_style_loss = 0. comp_style_loss = comp_style_loss + _comp_style(fake_left_eye_feats, real_left_eye_feats, self.cri_style) comp_style_loss = comp_style_loss + _comp_style(fake_right_eye_feats, real_right_eye_feats, self.cri_style) comp_style_loss = comp_style_loss + _comp_style(fake_mouth_feats, real_mouth_feats, self.cri_style) loss = loss + comp_style_loss log["{}/comp_style_loss".format(split)] = comp_style_loss.detach() if self.identity_weight > 0. and self.discriminator_iter_start < global_step: self.identity.eval() out_gray = self.gray_resize_for_identity(reconstructions) gt_gray = self.gray_resize_for_identity(gts) identity_gt = self.identity(gt_gray).detach() identity_out = self.identity(out_gray) identity_loss = self.cri_identity(identity_out, identity_gt) loss = loss + identity_loss log["{}/identity_loss".format(split)] = identity_loss.detach() log["{}/total_loss".format(split)] = loss.clone().detach().mean() return loss, log if optimizer_idx == 1: # second pass for discriminator update logits_real = self.discriminator(gts.contiguous().detach()) logits_fake = self.discriminator(reconstructions.contiguous().detach()) disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), "{}/logits_real".format(split): logits_real.detach().mean(), "{}/logits_fake".format(split): logits_fake.detach().mean() } return d_loss, log # left eye if optimizer_idx == 2: # third pass for discriminator update disc_factor = adopt_weight(1.0, global_step, threshold=self.discriminator_iter_start) fake_d_pred, _ = self.net_d_left_eye(components['left_eyes'].detach()) real_d_pred, _ = self.net_d_left_eye(components['left_eyes_gt']) d_loss = self.cri_component(real_d_pred, True, is_disc=True) + self.cri_component(fake_d_pred, False, is_disc=True) log = {"{}/d_left_loss".format(split): d_loss.clone().detach().mean()} return d_loss, log # right eye if optimizer_idx == 3: # forth pass for discriminator update fake_d_pred, _ = self.net_d_right_eye(components['right_eyes'].detach()) real_d_pred, _ = self.net_d_right_eye(components['right_eyes_gt']) d_loss = self.cri_component(real_d_pred, True, is_disc=True) + self.cri_component(fake_d_pred, False, is_disc=True) log = {"{}/d_right_loss".format(split): d_loss.clone().detach().mean()} return d_loss, log # mouth if optimizer_idx == 4: # fifth pass for discriminator update fake_d_pred, _ = self.net_d_mouth(components['mouths'].detach()) real_d_pred, _ = self.net_d_mouth(components['mouths_gt']) d_loss = self.cri_component(real_d_pred, True, is_disc=True) + self.cri_component(fake_d_pred, False, is_disc=True) log = {"{}/d_mouth_loss".format(split): d_loss.clone().detach().mean()} return d_loss, log ================================================ FILE: RestoreFormer/modules/util.py ================================================ import torch import torch.nn as nn def count_params(model): total_params = sum(p.numel() for p in model.parameters()) return total_params class ActNorm(nn.Module): def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False): assert affine super().__init__() self.logdet = logdet self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) self.allow_reverse_init = allow_reverse_init self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) def initialize(self, input): with torch.no_grad(): flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) mean = ( flatten.mean(1) .unsqueeze(1) .unsqueeze(2) .unsqueeze(3) .permute(1, 0, 2, 3) ) std = ( flatten.std(1) .unsqueeze(1) .unsqueeze(2) .unsqueeze(3) .permute(1, 0, 2, 3) ) self.loc.data.copy_(-mean) self.scale.data.copy_(1 / (std + 1e-6)) def forward(self, input, reverse=False): if reverse: return self.reverse(input) if len(input.shape) == 2: input = input[:,:,None,None] squeeze = True else: squeeze = False _, _, height, width = input.shape if self.training and self.initialized.item() == 0: self.initialize(input) self.initialized.fill_(1) h = self.scale * (input + self.loc) if squeeze: h = h.squeeze(-1).squeeze(-1) if self.logdet: log_abs = torch.log(torch.abs(self.scale)) logdet = height*width*torch.sum(log_abs) logdet = logdet * torch.ones(input.shape[0]).to(input) return h, logdet return h def reverse(self, output): if self.training and self.initialized.item() == 0: if not self.allow_reverse_init: raise RuntimeError( "Initializing ActNorm in reverse direction is " "disabled by default. Use allow_reverse_init=True to enable." ) else: self.initialize(output) self.initialized.fill_(1) if len(output.shape) == 2: output = output[:,:,None,None] squeeze = True else: squeeze = False h = output / self.scale - self.loc if squeeze: h = h.squeeze(-1).squeeze(-1) return h class Attention2DConv(nn.Module): """to replace the convolutional architecture entirely""" def __init__(self): super().__init__() ================================================ FILE: RestoreFormer/modules/vqvae/arcface_arch.py ================================================ import torch.nn as nn from basicsr.utils.registry import ARCH_REGISTRY def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None): super(BasicBlock, self).__init__() self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = nn.BatchNorm2d(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = conv3x3(planes, planes) self.bn2 = nn.BatchNorm2d(planes) self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: residual = self.downsample(x) out += residual out = self.relu(out) return out class IRBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True): super(IRBlock, self).__init__() self.bn0 = nn.BatchNorm2d(inplanes) self.conv1 = conv3x3(inplanes, inplanes) self.bn1 = nn.BatchNorm2d(inplanes) self.prelu = nn.PReLU() self.conv2 = conv3x3(inplanes, planes, stride) self.bn2 = nn.BatchNorm2d(planes) self.downsample = downsample self.stride = stride self.use_se = use_se if self.use_se: self.se = SEBlock(planes) def forward(self, x): residual = x out = self.bn0(x) out = self.conv1(out) out = self.bn1(out) out = self.prelu(out) out = self.conv2(out) out = self.bn2(out) if self.use_se: out = self.se(out) if self.downsample is not None: residual = self.downsample(x) out += residual out = self.prelu(out) return out class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None): super(Bottleneck, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * self.expansion) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: residual = self.downsample(x) out += residual out = self.relu(out) return out class SEBlock(nn.Module): def __init__(self, channel, reduction=16): super(SEBlock, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Sequential( nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel), nn.Sigmoid()) def forward(self, x): b, c, _, _ = x.size() y = self.avg_pool(x).view(b, c) y = self.fc(y).view(b, c, 1, 1) return x * y @ARCH_REGISTRY.register() class ResNetArcFace(nn.Module): def __init__(self, block, layers, use_se=True): if block == 'IRBlock': block = IRBlock self.inplanes = 64 self.use_se = use_se super(ResNetArcFace, self).__init__() self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(64) self.prelu = nn.PReLU() self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer(block, 128, layers[1], stride=2) self.layer3 = self._make_layer(block, 256, layers[2], stride=2) self.layer4 = self._make_layer(block, 512, layers[3], stride=2) self.bn4 = nn.BatchNorm2d(512) self.dropout = nn.Dropout() self.fc5 = nn.Linear(512 * 8 * 8, 512) self.bn5 = nn.BatchNorm1d(512) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.xavier_normal_(m.weight) elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.xavier_normal_(m.weight) nn.init.constant_(m.bias, 0) def _make_layer(self, block, planes, blocks, stride=1): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(planes * block.expansion), ) layers = [] layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se)) self.inplanes = planes for _ in range(1, blocks): layers.append(block(self.inplanes, planes, use_se=self.use_se)) return nn.Sequential(*layers) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.prelu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.bn4(x) x = self.dropout(x) x = x.view(x.size(0), -1) x = self.fc5(x) x = self.bn5(x) return x ================================================ FILE: RestoreFormer/modules/vqvae/facial_component_discriminator.py ================================================ import math import random import torch from torch import nn from torch.nn import functional as F from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU, StyleGAN2Generator) from basicsr.ops.fused_act import FusedLeakyReLU from basicsr.utils.registry import ARCH_REGISTRY @ARCH_REGISTRY.register() class FacialComponentDiscriminator(nn.Module): def __init__(self): super(FacialComponentDiscriminator, self).__init__() self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True) self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True) self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True) self.conv4 = ConvLayer(128, 256, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True) self.conv5 = ConvLayer(256, 256, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True) self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False) def forward(self, x, return_feats=False): feat = self.conv1(x) feat = self.conv3(self.conv2(feat)) rlt_feats = [] if return_feats: rlt_feats.append(feat.clone()) feat = self.conv5(self.conv4(feat)) if return_feats: rlt_feats.append(feat.clone()) out = self.final_conv(feat) if return_feats: return out, rlt_feats else: return out, None ================================================ FILE: RestoreFormer/modules/vqvae/utils.py ================================================ from torchvision.ops import roi_align import torch def get_roi_regions(gt, output, loc_left_eyes, loc_right_eyes, loc_mouths, face_ratio=1, eye_out_size=80, mouth_out_size=120): # hard code eye_out_size *= face_ratio mouth_out_size *= face_ratio eye_out_size = int(eye_out_size) mouth_out_size = int(mouth_out_size) rois_eyes = [] rois_mouths = [] for b in range(loc_left_eyes.size(0)): # loop for batch size # left eye and right eye img_inds = loc_left_eyes.new_full((2, 1), b) bbox = torch.stack([loc_left_eyes[b, :], loc_right_eyes[b, :]], dim=0) # shape: (2, 4) rois = torch.cat([img_inds, bbox], dim=-1) # shape: (2, 5) rois_eyes.append(rois) # mouse img_inds = loc_left_eyes.new_full((1, 1), b) rois = torch.cat([img_inds, loc_mouths[b:b + 1, :]], dim=-1) # shape: (1, 5) rois_mouths.append(rois) rois_eyes = torch.cat(rois_eyes, 0) rois_mouths = torch.cat(rois_mouths, 0) # real images all_eyes = roi_align(gt, boxes=rois_eyes, output_size=eye_out_size) * face_ratio left_eyes_gt = all_eyes[0::2, :, :, :] right_eyes_gt = all_eyes[1::2, :, :, :] mouths_gt = roi_align(gt, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio # output all_eyes = roi_align(output, boxes=rois_eyes, output_size=eye_out_size) * face_ratio left_eyes = all_eyes[0::2, :, :, :] right_eyes = all_eyes[1::2, :, :, :] mouths = roi_align(output, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio return {'left_eyes_gt': left_eyes_gt, 'right_eyes_gt': right_eyes_gt, 'mouths_gt': mouths_gt, 'left_eyes': left_eyes, 'right_eyes': right_eyes, 'mouths': mouths} ================================================ FILE: RestoreFormer/modules/vqvae/vqvae_arch.py ================================================ import torch import torch.nn as nn import random import math import torch.nn.functional as F import numpy as np # from basicsr.utils.registry import ARCH_REGISTRY import torch.nn.utils.spectral_norm as SpectralNorm import RestoreFormer.distributed as dist_fn class VectorQuantizer(nn.Module): """ see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py ____________________________________________ Discretization bottleneck part of the VQ-VAE. Inputs: - n_e : number of embeddings - e_dim : dimension of embedding - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 _____________________________________________ """ def __init__(self, n_e, e_dim, beta): super(VectorQuantizer, self).__init__() self.n_e = n_e self.e_dim = e_dim self.beta = beta self.embedding = nn.Embedding(self.n_e, self.e_dim) self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) def forward(self, z): """ Inputs the output of the encoder network z and maps it to a discrete one-hot vector that is the index of the closest embedding vector e_j z (continuous) -> z_q (discrete) z.shape = (batch, channel, height, width) quantization pipeline: 1. get encoder input (B,C,H,W) 2. flatten input to (B*H*W,C) """ # reshape z -> (batch, height, width, channel) and flatten z = z.permute(0, 2, 3, 1).contiguous() z_flattened = z.view(-1, self.e_dim) # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ torch.sum(self.embedding.weight**2, dim=1) - 2 * \ torch.matmul(z_flattened, self.embedding.weight.t()) ## could possible replace this here # #\start... # find closest encodings min_value, min_encoding_indices = torch.min(d, dim=1) min_encoding_indices = min_encoding_indices.unsqueeze(1) min_encodings = torch.zeros( min_encoding_indices.shape[0], self.n_e).to(z) min_encodings.scatter_(1, min_encoding_indices, 1) # dtype min encodings: torch.float32 # min_encodings shape: torch.Size([2048, 512]) # min_encoding_indices.shape: torch.Size([2048, 1]) # get quantized latent vectors z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) #.........\end # with: # .........\start #min_encoding_indices = torch.argmin(d, dim=1) #z_q = self.embedding(min_encoding_indices) # ......\end......... (TODO) # compute loss for embedding loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ torch.mean((z_q - z.detach()) ** 2) # preserve gradients z_q = z + (z_q - z).detach() # perplexity e_mean = torch.mean(min_encodings, dim=0) perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) # reshape back to match original input shape z_q = z_q.permute(0, 3, 1, 2).contiguous() return z_q, loss, (perplexity, min_encodings, min_encoding_indices, d) def get_codebook_entry(self, indices, shape): # shape specifying (batch, height, width, channel) # TODO: check for more easy handling with nn.Embedding min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices) min_encodings.scatter_(1, indices[:,None], 1) # get quantized latent vectors z_q = torch.matmul(min_encodings.float(), self.embedding.weight) if shape is not None: z_q = z_q.view(shape) # reshape back to match original input shape z_q = z_q.permute(0, 3, 1, 2).contiguous() return z_q # pytorch_diffusion + derived encoder decoder def nonlinearity(x): # swish return x*torch.sigmoid(x) def Normalize(in_channels): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) class Upsample(nn.Module): def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv if self.with_conv: self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) def forward(self, x): x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") if self.with_conv: x = self.conv(x) return x class Downsample(nn.Module): def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) def forward(self, x): if self.with_conv: pad = (0,1,0,1) x = torch.nn.functional.pad(x, pad, mode="constant", value=0) x = self.conv(x) else: x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) return x class ResnetBlock(nn.Module): def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut self.norm1 = Normalize(in_channels) self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels > 0: self.temb_proj = torch.nn.Linear(temb_channels, out_channels) self.norm2 = Normalize(out_channels) self.dropout = torch.nn.Dropout(dropout) self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) else: self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) def forward(self, x, temb): h = x h = self.norm1(h) h = nonlinearity(h) h = self.conv1(h) if temb is not None: h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] h = self.norm2(h) h = nonlinearity(h) h = self.dropout(h) h = self.conv2(h) if self.in_channels != self.out_channels: if self.use_conv_shortcut: x = self.conv_shortcut(x) else: x = self.nin_shortcut(x) return x+h class MultiHeadAttnBlock(nn.Module): def __init__(self, in_channels, head_size=1): super().__init__() self.in_channels = in_channels self.head_size = head_size self.att_size = in_channels // head_size assert(in_channels % head_size == 0), 'The size of head should be divided by the number of channels.' self.norm1 = Normalize(in_channels) self.norm2 = Normalize(in_channels) self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.num = 0 def forward(self, x, y=None): h_ = x h_ = self.norm1(h_) if y is None: y = h_ else: y = self.norm2(y) q = self.q(y) k = self.k(h_) v = self.v(h_) # compute attention b,c,h,w = q.shape q = q.reshape(b, self.head_size, self.att_size ,h*w) q = q.permute(0, 3, 1, 2) # b, hw, head, att k = k.reshape(b, self.head_size, self.att_size ,h*w) k = k.permute(0, 3, 1, 2) v = v.reshape(b, self.head_size, self.att_size ,h*w) v = v.permute(0, 3, 1, 2) q = q.transpose(1, 2) v = v.transpose(1, 2) k = k.transpose(1, 2).transpose(2,3) scale = int(self.att_size)**(-0.5) q.mul_(scale) w_ = torch.matmul(q, k) w_ = F.softmax(w_, dim=3) w_ = w_.matmul(v) w_ = w_.transpose(1, 2).contiguous() # [b, h*w, head, att] w_ = w_.view(b, h, w, -1) w_ = w_.permute(0, 3, 1, 2) w_ = self.proj_out(w_) return x+w_ class MultiHeadEncoder(nn.Module): def __init__(self, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks=2, attn_resolutions=[16], dropout=0.0, resamp_with_conv=True, in_channels=3, resolution=512, z_channels=256, double_z=True, enable_mid=True, head_size=1, **ignore_kwargs): super().__init__() self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels self.enable_mid = enable_mid # downsampling self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) curr_res = resolution in_ch_mult = (1,)+tuple(ch_mult) self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() block_in = ch*in_ch_mult[i_level] block_out = ch*ch_mult[i_level] for i_block in range(self.num_res_blocks): block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: attn.append(MultiHeadAttnBlock(block_in, head_size)) down = nn.Module() down.block = block down.attn = attn if i_level != self.num_resolutions-1: down.downsample = Downsample(block_in, resamp_with_conv) curr_res = curr_res // 2 self.down.append(down) # middle if self.enable_mid: self.mid = nn.Module() self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size) self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) # end self.norm_out = Normalize(block_in) self.conv_out = torch.nn.Conv2d(block_in, 2*z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1) def forward(self, x): #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) hs = {} # timestep embedding temb = None # downsampling h = self.conv_in(x) hs['in'] = h for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): h = self.down[i_level].block[i_block](h, temb) if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) if i_level != self.num_resolutions-1: # hs.append(h) hs['block_'+str(i_level)] = h h = self.down[i_level].downsample(h) # middle # h = hs[-1] if self.enable_mid: h = self.mid.block_1(h, temb) hs['block_'+str(i_level)+'_atten'] = h h = self.mid.attn_1(h) h = self.mid.block_2(h, temb) hs['mid_atten'] = h # end h = self.norm_out(h) h = nonlinearity(h) h = self.conv_out(h) # hs.append(h) hs['out'] = h return hs class MultiHeadDecoder(nn.Module): def __init__(self, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks=2, attn_resolutions=16, dropout=0.0, resamp_with_conv=True, in_channels=3, resolution=512, z_channels=256, give_pre_end=False, enable_mid=True, head_size=1, **ignorekwargs): super().__init__() self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels self.give_pre_end = give_pre_end self.enable_mid = enable_mid # compute in_ch_mult, block_in and curr_res at lowest res in_ch_mult = (1,)+tuple(ch_mult) block_in = ch*ch_mult[self.num_resolutions-1] curr_res = resolution // 2**(self.num_resolutions-1) self.z_shape = (1,z_channels,curr_res,curr_res) print("Working with z of shape {} = {} dimensions.".format( self.z_shape, np.prod(self.z_shape))) # z to block_in self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) # middle if self.enable_mid: self.mid = nn.Module() self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size) self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() block_out = ch*ch_mult[i_level] for i_block in range(self.num_res_blocks+1): block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: attn.append(MultiHeadAttnBlock(block_in, head_size)) up = nn.Module() up.block = block up.attn = attn if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) def forward(self, z): #assert z.shape[1:] == self.z_shape[1:] self.last_z_shape = z.shape # timestep embedding temb = None # z to block_in h = self.conv_in(z) # middle if self.enable_mid: h = self.mid.block_1(h, temb) h = self.mid.attn_1(h) h = self.mid.block_2(h, temb) # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks+1): h = self.up[i_level].block[i_block](h, temb) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h) if i_level != 0: h = self.up[i_level].upsample(h) # end if self.give_pre_end: return h h = self.norm_out(h) h = nonlinearity(h) h = self.conv_out(h) return h class MultiHeadDecoderTransformer(nn.Module): def __init__(self, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks=2, attn_resolutions=16, dropout=0.0, resamp_with_conv=True, in_channels=3, resolution=512, z_channels=256, give_pre_end=False, enable_mid=True, head_size=1, **ignorekwargs): super().__init__() self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels self.give_pre_end = give_pre_end self.enable_mid = enable_mid # compute in_ch_mult, block_in and curr_res at lowest res in_ch_mult = (1,)+tuple(ch_mult) block_in = ch*ch_mult[self.num_resolutions-1] curr_res = resolution // 2**(self.num_resolutions-1) self.z_shape = (1,z_channels,curr_res,curr_res) print("Working with z of shape {} = {} dimensions.".format( self.z_shape, np.prod(self.z_shape))) # z to block_in self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) # middle if self.enable_mid: self.mid = nn.Module() self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size) self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() block_out = ch*ch_mult[i_level] for i_block in range(self.num_res_blocks+1): block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: attn.append(MultiHeadAttnBlock(block_in, head_size)) up = nn.Module() up.block = block up.attn = attn if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) def forward(self, z, hs): #assert z.shape[1:] == self.z_shape[1:] # self.last_z_shape = z.shape # timestep embedding temb = None # z to block_in h = self.conv_in(z) # middle if self.enable_mid: h = self.mid.block_1(h, temb) h = self.mid.attn_1(h, hs['mid_atten']) h = self.mid.block_2(h, temb) # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks+1): h = self.up[i_level].block[i_block](h, temb) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h, hs['block_'+str(i_level)+'_atten']) # hfeature = h.clone() if i_level != 0: h = self.up[i_level].upsample(h) # end if self.give_pre_end: return h h = self.norm_out(h) h = nonlinearity(h) h = self.conv_out(h) return h class VQVAEGAN(nn.Module): def __init__(self, n_embed=1024, embed_dim=256, ch=128, out_ch=3, ch_mult=(1,2,4,8), num_res_blocks=2, attn_resolutions=16, dropout=0.0, in_channels=3, resolution=512, z_channels=256, double_z=False, enable_mid=True, fix_decoder=False, fix_codebook=False, head_size=1, **ignore_kwargs): super(VQVAEGAN, self).__init__() self.encoder = MultiHeadEncoder(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks, attn_resolutions=attn_resolutions, dropout=dropout, in_channels=in_channels, resolution=resolution, z_channels=z_channels, double_z=double_z, enable_mid=enable_mid, head_size=head_size) self.decoder = MultiHeadDecoder(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks, attn_resolutions=attn_resolutions, dropout=dropout, in_channels=in_channels, resolution=resolution, z_channels=z_channels, enable_mid=enable_mid, head_size=head_size) self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25) self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1) self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1) if fix_decoder: for _, param in self.decoder.named_parameters(): param.requires_grad = False for _, param in self.post_quant_conv.named_parameters(): param.requires_grad = False for _, param in self.quantize.named_parameters(): param.requires_grad = False elif fix_codebook: for _, param in self.quantize.named_parameters(): param.requires_grad = False def encode(self, x): hs = self.encoder(x) h = self.quant_conv(hs['out']) quant, emb_loss, info = self.quantize(h) return quant, emb_loss, info, hs def decode(self, quant): quant = self.post_quant_conv(quant) dec = self.decoder(quant) return dec def forward(self, input): quant, diff, info, hs = self.encode(input) dec = self.decode(quant) return dec, diff, info, hs class VQVAEGANMultiHeadTransformer(nn.Module): def __init__(self, n_embed=1024, embed_dim=256, ch=128, out_ch=3, ch_mult=(1,2,4,8), num_res_blocks=2, attn_resolutions=16, dropout=0.0, in_channels=3, resolution=512, z_channels=256, double_z=False, enable_mid=True, fix_decoder=False, fix_codebook=False, fix_encoder=False, constrastive_learning_loss_weight=0.0, head_size=1): super(VQVAEGANMultiHeadTransformer, self).__init__() self.encoder = MultiHeadEncoder(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks, attn_resolutions=attn_resolutions, dropout=dropout, in_channels=in_channels, resolution=resolution, z_channels=z_channels, double_z=double_z, enable_mid=enable_mid, head_size=head_size) self.decoder = MultiHeadDecoderTransformer(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks, attn_resolutions=attn_resolutions, dropout=dropout, in_channels=in_channels, resolution=resolution, z_channels=z_channels, enable_mid=enable_mid, head_size=head_size) self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25) self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1) self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1) if fix_decoder: for _, param in self.decoder.named_parameters(): param.requires_grad = False for _, param in self.post_quant_conv.named_parameters(): param.requires_grad = False for _, param in self.quantize.named_parameters(): param.requires_grad = False elif fix_codebook: for _, param in self.quantize.named_parameters(): param.requires_grad = False if fix_encoder: for _, param in self.encoder.named_parameters(): param.requires_grad = False def encode(self, x): hs = self.encoder(x) h = self.quant_conv(hs['out']) quant, emb_loss, info = self.quantize(h) return quant, emb_loss, info, hs def decode(self, quant, hs): quant = self.post_quant_conv(quant) dec = self.decoder(quant, hs) return dec def forward(self, input): quant, diff, info, hs = self.encode(input) dec = self.decode(quant, hs) return dec, diff, info, hs ================================================ FILE: RestoreFormer/util.py ================================================ import os, hashlib import requests from tqdm import tqdm URL_MAP = { "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" } CKPT_MAP = { "vgg_lpips": "vgg.pth" } MD5_MAP = { "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" } def download(url, local_path, chunk_size=1024): os.makedirs(os.path.split(local_path)[0], exist_ok=True) with requests.get(url, stream=True) as r: total_size = int(r.headers.get("content-length", 0)) with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: with open(local_path, "wb") as f: for data in r.iter_content(chunk_size=chunk_size): if data: f.write(data) pbar.update(chunk_size) def md5_hash(path): with open(path, "rb") as f: content = f.read() return hashlib.md5(content).hexdigest() def get_ckpt_path(name, root, check=False): assert name in URL_MAP path = os.path.join(root, CKPT_MAP[name]) if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) download(URL_MAP[name], path) md5 = md5_hash(path) assert md5 == MD5_MAP[name], md5 return path class KeyNotFoundError(Exception): def __init__(self, cause, keys=None, visited=None): self.cause = cause self.keys = keys self.visited = visited messages = list() if keys is not None: messages.append("Key not found: {}".format(keys)) if visited is not None: messages.append("Visited: {}".format(visited)) messages.append("Cause:\n{}".format(cause)) message = "\n".join(messages) super().__init__(message) def retrieve( list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False ): """Given a nested list or dict return the desired value at key expanding callable nodes if necessary and :attr:`expand` is ``True``. The expansion is done in-place. Parameters ---------- list_or_dict : list or dict Possibly nested list or dictionary. key : str key/to/value, path like string describing all keys necessary to consider to get to the desired value. List indices can also be passed here. splitval : str String that defines the delimiter between keys of the different depth levels in `key`. default : obj Value returned if :attr:`key` is not found. expand : bool Whether to expand callable nodes on the path or not. Returns ------- The desired value or if :attr:`default` is not ``None`` and the :attr:`key` is not found returns ``default``. Raises ------ Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is ``None``. """ keys = key.split(splitval) success = True try: visited = [] parent = None last_key = None for key in keys: if callable(list_or_dict): if not expand: raise KeyNotFoundError( ValueError( "Trying to get past callable node with expand=False." ), keys=keys, visited=visited, ) list_or_dict = list_or_dict() parent[last_key] = list_or_dict last_key = key parent = list_or_dict try: if isinstance(list_or_dict, dict): list_or_dict = list_or_dict[key] else: list_or_dict = list_or_dict[int(key)] except (KeyError, IndexError, ValueError) as e: raise KeyNotFoundError(e, keys=keys, visited=visited) visited += [key] # final expansion of retrieved value if expand and callable(list_or_dict): list_or_dict = list_or_dict() parent[last_key] = list_or_dict except KeyNotFoundError as e: if default is None: raise e else: list_or_dict = default success = False if not pass_success: return list_or_dict else: return list_or_dict, success if __name__ == "__main__": config = {"keya": "a", "keyb": "b", "keyc": {"cc1": 1, "cc2": 2, } } from omegaconf import OmegaConf config = OmegaConf.create(config) print(config) retrieve(config, "keya") ================================================ FILE: __init__.py ================================================ ================================================ FILE: configs/HQ_Dictionary.yaml ================================================ model: base_learning_rate: 4.5e-6 target: RestoreFormer.models.vqgan_v1.RestoreFormerModel params: image_key: 'gt' schedule_step: [400000, 800000] # ignore_keys: ['vqvae.quantize.utility_counter'] ddconfig: target: RestoreFormer.modules.vqvae.vqvae_arch.VQVAEGAN params: embed_dim: 256 n_embed: 1024 double_z: False z_channels: 256 resolution: 512 in_channels: 3 out_ch: 3 ch: 64 ch_mult: [ 1,2,2,4,4,8] # num_down = len(ch_mult)-1 num_res_blocks: 2 attn_resolutions: [16] dropout: 0.0 enable_mid: True fix_decoder: False fix_codebook: False head_size: 8 lossconfig: target: RestoreFormer.modules.losses.vqperceptual.VQLPIPSWithDiscriminatorWithCompWithIdentity params: disc_conditional: False disc_in_channels: 3 disc_start: 30001 disc_weight: 0.8 codebook_weight: 1.0 use_actnorm: False data: target: main.DataModuleFromConfig params: batch_size: 4 num_workers: 8 train: target: basicsr.data.ffhq_dataset.FFHQDataset params: dataroot_gt: data/FFHQ/images512x512 io_backend: type: disk use_hflip: True mean: [0.5, 0.5, 0.5] std: [0.5, 0.5, 0.5] out_size: 512 validation: target: basicsr.data.ffhq_dataset.FFHQDataset params: dataroot_gt: data/FFHQ/images512x512 io_backend: type: disk use_hflip: False mean: [0.5, 0.5, 0.5] std: [0.5, 0.5, 0.5] out_size: 512 ================================================ FILE: configs/RestoreFormer.yaml ================================================ model: base_learning_rate: 4.5e-6 target: RestoreFormer.models.vqgan_v1.RestoreFormerModel params: image_key: 'lq' ckpt_path: 'YOUR TRAINED HD DICTIONARY MODEL' special_params_lr_scale: 10 comp_params_lr_scale: 10 schedule_step: [4000000, 8000000] ddconfig: target: RestoreFormer.modules.vqvae.vqvae_arch.VQVAEGANMultiHeadTransformer params: embed_dim: 256 n_embed: 1024 double_z: False z_channels: 256 resolution: 512 in_channels: 3 out_ch: 3 ch: 64 ch_mult: [ 1,2,2,4,4,8] # num_down = len(ch_mult)-1 num_res_blocks: 2 dropout: 0.0 attn_resolutions: [16] enable_mid: True fix_decoder: False fix_codebook: True fix_encoder: False head_size: 8 lossconfig: target: RestoreFormer.modules.losses.vqperceptual.VQLPIPSWithDiscriminatorWithCompWithIdentity params: disc_conditional: False disc_in_channels: 3 disc_start: 10001 disc_weight: 0.8 codebook_weight: 1.0 use_actnorm: False comp_weight: 1.5 comp_style_weight: 2e3 #2000.0 identity_weight: 3 #1.5 lpips_style_weight: 1e9 identity_model_path: experiments/pretrained_models/arcface_resnet18.pth data: target: main.DataModuleFromConfig params: batch_size: 4 num_workers: 8 train: target: RestoreFormer.data.ffhq_degradation_dataset.FFHQDegradationDataset params: dataroot_gt: data/FFHQ/images512x512 io_backend: type: disk use_hflip: True mean: [0.5, 0.5, 0.5] std: [0.5, 0.5, 0.5] out_size: 512 blur_kernel_size: [19,20] kernel_list: ['iso', 'aniso'] kernel_prob: [0.5, 0.5] blur_sigma: [0.1, 10] downsample_range: [0.8, 8] noise_range: [0, 20] jpeg_range: [60, 100] color_jitter_prob: ~ color_jitter_shift: 20 color_jitter_pt_prob: ~ gray_prob: ~ gt_gray: True crop_components: True component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth eye_enlarge_ratio: 1.4 validation: target: RestoreFormer.data.ffhq_degradation_dataset.FFHQDegradationDataset params: dataroot_gt: data/FFHQ/images512x512 io_backend: type: disk use_hflip: False mean: [0.5, 0.5, 0.5] std: [0.5, 0.5, 0.5] out_size: 512 blur_kernel_size: [19,20] kernel_list: ['iso', 'aniso'] kernel_prob: [0.5, 0.5] blur_sigma: [0.1, 10] downsample_range: [0.8, 8] noise_range: [0, 20] jpeg_range: [60, 100] # color jitter and gray color_jitter_prob: ~ color_jitter_shift: 20 color_jitter_pt_prob: ~ gray_prob: ~ gt_gray: True crop_components: False component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth eye_enlarge_ratio: 1.4 ================================================ FILE: main.py ================================================ import argparse, os, sys, datetime, glob, importlib from omegaconf import OmegaConf import numpy as np from PIL import Image import torch import torchvision from torch.utils.data import random_split, DataLoader, Dataset import pytorch_lightning as pl from pytorch_lightning import seed_everything from pytorch_lightning.trainer import Trainer from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor from pytorch_lightning.utilities.distributed import rank_zero_only import random def get_obj_from_str(string, reload=False): module, cls = string.rsplit(".", 1) if reload: module_imp = importlib.import_module(module) importlib.reload(module_imp) return getattr(importlib.import_module(module, package=None), cls) def get_parser(**parser_kwargs): def str2bool(v): if isinstance(v, bool): return v if v.lower() in ("yes", "true", "t", "y", "1"): return True elif v.lower() in ("no", "false", "f", "n", "0"): return False else: raise argparse.ArgumentTypeError("Boolean value expected.") parser = argparse.ArgumentParser(**parser_kwargs) parser.add_argument( "-n", "--name", type=str, const=True, default="", nargs="?", help="postfix for logdir", ) parser.add_argument( "-r", "--resume", type=str, const=True, default="", nargs="?", help="resume from logdir or checkpoint in logdir", ) parser.add_argument( "--pretrain", type=str, const=True, default="", nargs="?", help="pretrain with existed weights", ) parser.add_argument( "-b", "--base", nargs="*", metavar="base_config.yaml", help="paths to base configs. Loaded from left-to-right. " "Parameters can be overwritten or added with command-line options of the form `--key value`.", default=list(), ) parser.add_argument( "-t", "--train", type=str2bool, const=True, default=False, nargs="?", help="train", ) parser.add_argument( "--no-test", type=str2bool, const=True, default=False, nargs="?", help="disable test", ) parser.add_argument("-p", "--project", help="name of new or path to existing project") parser.add_argument( "-d", "--debug", type=str2bool, nargs="?", const=True, default=False, help="enable post-mortem debugging", ) parser.add_argument( "-s", "--seed", type=int, default=23, help="seed for seed_everything", ) parser.add_argument( "--random-seed", type=str2bool, nargs="?", const=True, default=False, help="enable post-mortem debugging", ) parser.add_argument( "-f", "--postfix", type=str, default="", help="post-postfix for default name", ) parser.add_argument( "--root-path", type=str, default="./", help="root path for saving checkpoints and logs" ) parser.add_argument( "--num-nodes", type=int, default=1, help="number of gpu nodes", ) return parser def nondefault_trainer_args(opt): parser = argparse.ArgumentParser() parser = Trainer.add_argparse_args(parser) args = parser.parse_args([]) return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) def instantiate_from_config(config): if not "target" in config: raise KeyError("Expected key `target` to instantiate.") if 'basicsr.data' in config["target"] or \ 'FFHQDegradationDataset' in config["target"]: return get_obj_from_str(config["target"])(config.get("params", dict())) return get_obj_from_str(config["target"])(**config.get("params", dict())) class WrappedDataset(Dataset): """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" def __init__(self, dataset): self.data = dataset def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx] class DataModuleFromConfig(pl.LightningDataModule): def __init__(self, batch_size, train=None, validation=None, test=None, wrap=False, num_workers=None): super().__init__() self.batch_size = batch_size self.dataset_configs = dict() self.num_workers = num_workers if num_workers is not None else batch_size*2 if train is not None: self.dataset_configs["train"] = train self.train_dataloader = self._train_dataloader if validation is not None: self.dataset_configs["validation"] = validation self.val_dataloader = self._val_dataloader if test is not None: self.dataset_configs["test"] = test self.test_dataloader = self._test_dataloader self.wrap = wrap def prepare_data(self): for data_cfg in self.dataset_configs.values(): instantiate_from_config(data_cfg) def setup(self, stage=None): self.datasets = dict( (k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs) if self.wrap: for k in self.datasets: self.datasets[k] = WrappedDataset(self.datasets[k]) def _train_dataloader(self): return DataLoader(self.datasets["train"], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True) def _val_dataloader(self): return DataLoader(self.datasets["validation"], batch_size=self.batch_size, num_workers=self.num_workers) def _test_dataloader(self): return DataLoader(self.datasets["test"], batch_size=self.batch_size, num_workers=self.num_workers) class SetupCallback(Callback): def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config): super().__init__() self.resume = resume self.now = now self.logdir = logdir self.ckptdir = ckptdir self.cfgdir = cfgdir self.config = config self.lightning_config = lightning_config def on_pretrain_routine_start(self, trainer, pl_module): if trainer.global_rank == 0: # import pdb # pdb.set_trace() # Create logdirs and save configs os.makedirs(self.logdir, exist_ok=True) os.makedirs(self.ckptdir, exist_ok=True) os.makedirs(self.cfgdir, exist_ok=True) print("Project config") print(self.config.pretty()) OmegaConf.save(self.config, os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))) print("Lightning config") print(self.lightning_config.pretty()) OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}), os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now))) class ImageLogger(Callback): def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True): super().__init__() self.batch_freq = batch_frequency self.max_images = max_images self.logger_log_images = { pl.loggers.WandbLogger: self._wandb, pl.loggers.TestTubeLogger: self._testtube, } self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)] if not increase_log_steps: self.log_steps = [self.batch_freq] self.clamp = clamp @rank_zero_only def _wandb(self, pl_module, images, batch_idx, split): raise ValueError("No way wandb") grids = dict() for k in images: grid = torchvision.utils.make_grid(images[k]) grids[f"{split}/{k}"] = wandb.Image(grid) pl_module.logger.experiment.log(grids) @rank_zero_only def _testtube(self, pl_module, images, batch_idx, split): for k in images: grid = torchvision.utils.make_grid(images[k]) grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w tag = f"{split}/{k}" pl_module.logger.experiment.add_image( tag, grid, global_step=pl_module.global_step) @rank_zero_only def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): root = os.path.join(save_dir, "images", split) for k in images: grid = torchvision.utils.make_grid(images[k], nrow=4) grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w grid = grid.transpose(0,1).transpose(1,2).squeeze(-1) grid = grid.numpy() grid = (grid*255).astype(np.uint8) filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format( k, global_step, current_epoch, batch_idx) path = os.path.join(root, filename) os.makedirs(os.path.split(path)[0], exist_ok=True) Image.fromarray(grid).save(path) def log_img(self, pl_module, batch, batch_idx, split="train"): if (self.check_frequency(batch_idx) and # batch_idx % self.batch_freq == 0 hasattr(pl_module, "log_images") and callable(pl_module.log_images) and self.max_images > 0): logger = type(pl_module.logger) is_train = pl_module.training if is_train: pl_module.eval() with torch.no_grad(): images = pl_module.log_images(batch, split=split) for k in images: N = min(images[k].shape[0], self.max_images) images[k] = images[k][:N] if isinstance(images[k], torch.Tensor): images[k] = images[k].detach().cpu() if self.clamp: images[k] = torch.clamp(images[k], -1., 1.) self.log_local(pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch, batch_idx) logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None) logger_log_images(pl_module, images, pl_module.global_step, split) if is_train: pl_module.train() def check_frequency(self, batch_idx): if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps): try: self.log_steps.pop(0) except IndexError: pass return True return False def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): self.log_img(pl_module, batch, batch_idx, split="train") def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): self.log_img(pl_module, batch, batch_idx, split="val") if __name__ == "__main__": # custom parser to specify config files, train, test and debug mode, # postfix, resume. # `--key value` arguments are interpreted as arguments to the trainer. # `nested.key=value` arguments are interpreted as config parameters. # configs are merged from left-to-right followed by command line parameters. # model: # base_learning_rate: float # target: path to lightning module # params: # key: value # data: # target: main.DataModuleFromConfig # params: # batch_size: int # wrap: bool # train: # target: path to train dataset # params: # key: value # validation: # target: path to validation dataset # params: # key: value # test: # target: path to test dataset # params: # key: value # lightning: (optional, has sane defaults and can be specified on cmdline) # trainer: # additional arguments to trainer # logger: # logger to instantiate # modelcheckpoint: # modelcheckpoint to instantiate # callbacks: # callback1: # target: importpath # params: # key: value now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") # add cwd for convenience and to make classes in this file available when # running as `python main.py` # (in particular `main.DataModuleFromConfig`) sys.path.append(os.getcwd()) parser = get_parser() parser = Trainer.add_argparse_args(parser) opt, unknown = parser.parse_known_args() if opt.name and opt.resume: raise ValueError( "-n/--name and -r/--resume cannot be specified both." "If you want to resume training in a new log folder, " "use -n/--name in combination with --resume_from_checkpoint" ) if opt.resume: if not os.path.exists(opt.resume): raise ValueError("Cannot find {}".format(opt.resume)) if os.path.isfile(opt.resume): paths = opt.resume.split("/") idx = len(paths)-paths[::-1].index("logs")+1 logdir = "/".join(paths[:idx]) ckpt = opt.resume else: assert os.path.isdir(opt.resume), opt.resume logdir = opt.resume.rstrip("/") ckpt = os.path.join(logdir, "checkpoints", "last.ckpt") opt.resume_from_checkpoint = ckpt base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml"))) opt.base = base_configs+opt.base _tmp = logdir.split("/") nowname = _tmp[_tmp.index("logs")+1]+opt.postfix logdir = os.path.join(opt.root_path, "logs", nowname) else: if opt.name: name = "_"+opt.name elif opt.base: cfg_fname = os.path.split(opt.base[0])[-1] cfg_name = os.path.splitext(cfg_fname)[0] name = "_"+cfg_name else: name = "" nowname = now+name+opt.postfix logdir = os.path.join(opt.root_path, "logs", nowname) if opt.random_seed: opt.seed = random.randint(1,100) logdir = logdir + '_seed' + str(opt.seed) ckptdir = os.path.join(logdir, "checkpoints") cfgdir = os.path.join(logdir, "configs") seed_everything(opt.seed) try: # init and save configs configs = [OmegaConf.load(cfg) for cfg in opt.base] cli = OmegaConf.from_dotlist(unknown) config = OmegaConf.merge(*configs, cli) lightning_config = config.pop("lightning", OmegaConf.create()) # merge trainer cli with config trainer_config = lightning_config.get("trainer", OmegaConf.create()) # default to ddp # trainer_config["distributed_backend"] = "ddp" trainer_config["accelerator"] = "ddp" # trainer_config["plugins"]="ddp_sharded" for k in nondefault_trainer_args(opt): trainer_config[k] = getattr(opt, k) if not "gpus" in trainer_config: del trainer_config["distributed_backend"] cpu = True else: gpuinfo = trainer_config["gpus"] print(f"Running on GPUs {gpuinfo}") cpu = False trainer_opt = argparse.Namespace(**trainer_config) lightning_config.trainer = trainer_config # model model = instantiate_from_config(config.model) # trainer and callbacks trainer_kwargs = dict() # trainer_kwargs['sync_batchnorm'] = True # default logger configs # NOTE wandb < 0.10.0 interferes with shutdown # wandb >= 0.10.0 seems to fix it but still interferes with pudb # debugging (wrongly sized pudb ui) # thus prefer testtube for now default_logger_cfgs = { "wandb": { "target": "pytorch_lightning.loggers.WandbLogger", "params": { "name": nowname, "save_dir": logdir, "offline": opt.debug, "id": nowname, } }, "testtube": { "target": "pytorch_lightning.loggers.TestTubeLogger", "params": { "name": "testtube", "save_dir": logdir, } }, } default_logger_cfg = default_logger_cfgs["testtube"] logger_cfg = lightning_config.logger or OmegaConf.create() logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) trainer_kwargs["logger"] = instantiate_from_config(logger_cfg) # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to # specify which metric is used to determine best models default_modelckpt_cfg = { "target": "pytorch_lightning.callbacks.ModelCheckpoint", "params": { "dirpath": ckptdir, "filename": "{epoch:06}", "verbose": True, "save_last": True, "period": 1 } } if hasattr(model, "monitor"): print(f"Monitoring {model.monitor} as checkpoint metric.") default_modelckpt_cfg["params"]["monitor"] = model.monitor default_modelckpt_cfg["params"]["save_top_k"] = 3 modelckpt_cfg = lightning_config.modelcheckpoint or OmegaConf.create() modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg) # add callback which sets up log directory default_callbacks_cfg = { "setup_callback": { "target": "main.SetupCallback", "params": { "resume": opt.resume, "now": now, "logdir": logdir, "ckptdir": ckptdir, "cfgdir": cfgdir, "config": config, "lightning_config": lightning_config, } }, "image_logger": { "target": "main.ImageLogger", "params": { "batch_frequency": 750, "max_images": 4, "clamp": True } }, "learning_rate_logger": { "target": "main.LearningRateMonitor", "params": { "logging_interval": "step", #"log_momentum": True } }, } callbacks_cfg = lightning_config.callbacks or OmegaConf.create() callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) # data data = instantiate_from_config(config.data) # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html # calling these ourselves should not be necessary but it is. # lightning still takes care of proper multiprocessing though data.prepare_data() data.setup() # configure learning rate bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate if not cpu: ngpu = len(lightning_config.trainer.gpus.strip(",").split(',')) else: ngpu = 1 accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches or 1 print(f"accumulate_grad_batches = {accumulate_grad_batches}") lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches model.learning_rate = accumulate_grad_batches * ngpu * bs * trainer.num_nodes * base_lr print("Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (num_nodes) * {} (batchsize) * {:.2e} (base_lr)".format( model.learning_rate, accumulate_grad_batches, ngpu, trainer.num_nodes, bs, base_lr)) # allow checkpointing via USR1 def melk(*args, **kwargs): # run all checkpoint hooks if trainer.global_rank == 0: print("Summoning checkpoint.") ckpt_path = os.path.join(ckptdir, "last.ckpt") trainer.save_checkpoint(ckpt_path) def divein(*args, **kwargs): if trainer.global_rank == 0: import pudb; pudb.set_trace() import signal signal.signal(signal.SIGUSR1, melk) signal.signal(signal.SIGUSR2, divein) # run if opt.train: try: trainer.fit(model, data) except Exception: melk() raise if not opt.no_test and not trainer.interrupted: trainer.test(model, data) except Exception: if opt.debug and trainer.global_rank==0: try: import pudb as debugger except ImportError: import pdb as debugger debugger.post_mortem() raise finally: # move newly created debug project to debug_runs if opt.debug and not opt.resume and trainer.global_rank==0: dst, name = os.path.split(logdir) dst = os.path.join(dst, "debug_runs", name) os.makedirs(os.path.split(dst)[0], exist_ok=True) os.rename(logdir, dst) ================================================ FILE: restoreformer_requirement.txt ================================================ Package Version Location ----------------------- ------------------- ------------------------------------------------------------------------------ absl-py 0.13.0 addict 2.4.0 aiohttp 3.7.4.post0 albumentations 0.4.3 antlr4-python3-runtime 4.8 astunparse 1.6.3 async-timeout 3.0.1 attrs 21.2.0 basicsr 1.3.3.4 cached-property 1.5.2 cachetools 4.2.2 certifi 2021.5.30 chardet 4.0.0 cycler 0.10.0 dlib 19.22.99 facexlib 0.1.3.1 flatbuffers 1.12 fsspec 2021.6.1 future 0.18.2 gast 0.4.0 google-auth 1.32.1 google-auth-oauthlib 0.4.4 google-pasta 0.2.0 grpcio 1.39.0 h5py 3.1.0 idna 2.10 imageio 2.9.0 imgaug 0.2.6 importlib-metadata 4.6.1 joblib 1.0.1 keras-nightly 2.7.0.dev2021072800 Keras-Preprocessing 1.1.2 kiwisolver 1.3.1 libclang 11.1.0 lmdb 1.2.1 Markdown 3.3.4 matplotlib 3.4.2 mkl-fft 1.3.0 mkl-random 1.2.1 mkl-service 2.3.0 multidict 5.1.0 networkx 2.6.1 numpy 1.19.5 oauthlib 3.1.1 olefile 0.46 omegaconf 2.0.0 opencv-python 4.5.2.54 opt-einsum 3.3.0 packaging 21.0 pandas 1.3.0 Pillow 8.3.1 pip 21.1.3 protobuf 3.17.3 pyasn1 0.4.8 pyasn1-modules 0.2.8 pyDeprecate 0.3.0 pyparsing 2.4.7 python-dateutil 2.8.1 pytorch-lightning 1.0.8 pytz 2021.1 PyWavelets 1.1.1 PyYAML 5.4.1 requests 2.25.1 requests-oauthlib 1.3.0 rsa 4.7.2 scikit-image 0.18.2 scikit-learn 0.24.2 scipy 1.7.0 setuptools 52.0.0.post20210125 six 1.15.0 sklearn 0.0 tb-nightly 2.6.0a20210728 tensorboard-data-server 0.6.1 tensorboard-plugin-wit 1.8.0 termcolor 1.1.0 test-tube 0.7.5 tf-estimator-nightly 2.7.0.dev2021072801 tf-nightly 2.7.0.dev20210728 threadpoolctl 2.2.0 tifffile 2021.7.2 torch 1.7.1 torchaudio 0.7.0a0+a853dff torchmetrics 0.4.1 torchvision 0.8.2 tqdm 4.61.2 typing-extensions 3.7.4.3 urllib3 1.26.6 Werkzeug 2.0.1 wheel 0.36.2 wrapt 1.12.1 yapf 0.31.0 yarl 1.6.3 zipp 3.5.0 ================================================ FILE: scripts/metrics/cal_fid.py ================================================ import os, sys import argparse import math import numpy as np import torch from torch.utils.data import DataLoader from basicsr.data import build_dataset from basicsr.metrics.fid import calculate_fid, extract_inception_features, load_patched_inception_v3 def calculate_fid_folder(): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') parser = argparse.ArgumentParser() parser.add_argument('folder', type=str, help='Path to the folder.') parser.add_argument('--fid_stats', type=str, help='Path to the dataset fid statistics.') parser.add_argument('--batch_size', type=int, default=64) parser.add_argument('--num_sample', type=int, default=50000) parser.add_argument('--num_workers', type=int, default=4) parser.add_argument('--backend', type=str, default='disk', help='io backend for dataset. Option: disk, lmdb') parser.add_argument('--save_name', type=str, default='fid', help='File name for saving results') args = parser.parse_args() # inception model inception = load_patched_inception_v3(device) # create dataset opt = {} opt['name'] = 'SingleImageDataset' opt['type'] = 'SingleImageDataset' opt['dataroot_lq'] = args.folder opt['io_backend'] = dict(type=args.backend) opt['mean'] = [0.5, 0.5, 0.5] opt['std'] = [0.5, 0.5, 0.5] dataset = build_dataset(opt) # create dataloader data_loader = DataLoader( dataset=dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, sampler=None, drop_last=False) args.num_sample = min(args.num_sample, len(dataset)) total_batch = math.ceil(args.num_sample / args.batch_size) def data_generator(data_loader, total_batch): for idx, data in enumerate(data_loader): if idx >= total_batch: break else: yield data['lq'] features = extract_inception_features(data_generator(data_loader, total_batch), inception, total_batch, device) features = features.numpy() total_len = features.shape[0] features = features[:args.num_sample] # print(f'Extracted {total_len} features, ' f'use the first {features.shape[0]} features to calculate stats.') sample_mean = np.mean(features, 0) sample_cov = np.cov(features, rowvar=False) # load the dataset stats stats = torch.load(args.fid_stats) real_mean = stats['mean'] real_cov = stats['cov'] # calculate FID metric fid = calculate_fid(sample_mean, sample_cov, real_mean, real_cov) fout=open(args.save_name, 'w') fout.write(str(fid)+'\n') fout.close() print(args.folder) print('fid:', fid) if __name__ == '__main__': calculate_fid_folder() ================================================ FILE: scripts/metrics/cal_identity_distance.py ================================================ import os, sys import torch import argparse import cv2 import numpy as np import glob import pdb import tqdm from copy import deepcopy import torch.nn.functional as F import math root_path = os.path.abspath(os.path.join(__file__, os.path.pardir, os.path.pardir, os.path.pardir)) sys.path.append(root_path) sys.path.append(os.path.join(root_path, 'RestoreFormer/modules/losses')) from RestoreFormer.modules.vqvae.arcface_arch import ResNetArcFace from basicsr.losses.losses import L1Loss, MSELoss def cosine_similarity(emb1, emb2): return np.arccos(np.dot(emb1, emb2) / ( np.linalg.norm(emb1) * np.linalg.norm(emb2))) def gray_resize_for_identity(out, size=128): out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :]) out_gray = out_gray.unsqueeze(1) out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False) return out_gray def calculate_identity_distance_folder(): parser = argparse.ArgumentParser() parser.add_argument('folder', type=str, help='Path to the folder') parser.add_argument('--gt_folder', type=str, help='Path to the GT') parser.add_argument('--save_name', type=str, default='niqe', help='File name for saving results') parser.add_argument('--need_post', type=int, default=0, help='0: the name of image does not include 00, 1: otherwise') args = parser.parse_args() fout = open(args.save_name, 'w') identity = ResNetArcFace(block = 'IRBlock', layers = [2, 2, 2, 2], use_se = False) identity_model_path = 'experiments/pretrained_models/arcface_resnet18.pth' sd = torch.load(identity_model_path, map_location="cpu") for k, v in deepcopy(sd).items(): if k.startswith('module.'): sd[k[7:]] = v sd.pop(k) identity.load_state_dict(sd, strict=True) identity.eval() for param in identity.parameters(): param.requires_grad = False identity = identity.cuda() gt_names = glob.glob(os.path.join(args.gt_folder, '*')) gt_names.sort() mean_dist = 0. for i in tqdm.tqdm(range(len(gt_names))): gt_name = gt_names[i].split('/')[-1][:-4] if args.need_post: img_name = os.path.join(args.folder,gt_name + '_00.png') else: img_name = os.path.join(args.folder,gt_name + '.png') if not os.path.exists(img_name): print(img_name, 'does not exist') continue img = cv2.imread(img_name) gt = cv2.imread(gt_names[i]) img = img.astype(np.float32) / 255. img = torch.FloatTensor(img).cuda() img = img.permute(2,0,1) img = img.unsqueeze(0) gt = gt.astype(np.float32) / 255. gt = torch.FloatTensor(gt).cuda() gt = gt.permute(2,0,1) gt = gt.unsqueeze(0) out_gray = gray_resize_for_identity(img) gt_gray = gray_resize_for_identity(gt) with torch.no_grad(): identity_gt = identity(gt_gray) identity_out = identity(out_gray) identity_gt = identity_gt.cpu().data.numpy().squeeze() identity_out = identity_out.cpu().data.numpy().squeeze() identity_loss = cosine_similarity(identity_gt, identity_out) fout.write(gt_name + ' ' + str(identity_loss) + '\n') mean_dist += identity_loss fout.write('Mean: ' + str(mean_dist / len(gt_names)) + '\n') fout.close() print('mean_dist:', mean_dist / len(gt_names)) if __name__ == '__main__': calculate_identity_distance_folder() ================================================ FILE: scripts/metrics/cal_psnr_ssim.py ================================================ import os, sys import argparse import cv2 import numpy as np import glob import pdb import tqdm import torch from basicsr.metrics.psnr_ssim import calculate_psnr, calculate_ssim root_path = os.path.abspath(os.path.join(__file__, os.path.pardir, os.path.pardir, os.path.pardir)) sys.path.append(root_path) sys.path.append(os.path.join(root_path, 'RestoreFormer/modules/losses')) from lpips import LPIPS def calculate_psnr_ssim_lpips_folder(): parser = argparse.ArgumentParser() parser.add_argument('folder', type=str, help='Path to the folder') parser.add_argument('--gt_folder', type=str, help='Path to the GT') parser.add_argument('--save_name', type=str, default='niqe', help='File name for saving results') parser.add_argument('--need_post', type=int, default=0, help='0: the name of image does not include 00, 1: otherwise') args = parser.parse_args() fout = open(args.save_name, 'w') fout.write('NAME\tPSNR\tSSIM\tLPIPS\n') H, W = 512, 512 gt_names = glob.glob(os.path.join(args.gt_folder, '*')) gt_names.sort() perceptual_loss = LPIPS().eval().cuda() mean_psnr = 0. mean_ssim = 0. mean_lpips = 0. mean_norm_lpips = 0. for i in tqdm.tqdm(range(len(gt_names))): gt_name = gt_names[i].split('/')[-1][:-4] if args.need_post: img_name = os.path.join(args.folder,gt_name + '_00.png') else: img_name = os.path.join(args.folder,gt_name + '.png') if not os.path.exists(img_name): print(img_name, 'does not exist') continue img = cv2.imread(img_name) gt = cv2.imread(gt_names[i]) cur_psnr = calculate_psnr(img, gt, 0) cur_ssim = calculate_ssim(img, gt, 0) # lpips: img = img.astype(np.float32) / 255. img = torch.FloatTensor(img).cuda() img = img.permute(2,0,1) img = img.unsqueeze(0) gt = gt.astype(np.float32) / 255. gt = torch.FloatTensor(gt).cuda() gt = gt.permute(2,0,1) gt = gt.unsqueeze(0) cur_lpips = perceptual_loss(img, gt) cur_lpips = cur_lpips[0].item() img = (img - 0.5) / 0.5 gt = (gt - 0.5) / 0.5 norm_lpips = perceptual_loss(img, gt) norm_lpips = norm_lpips[0].item() # print(cur_psnr, cur_ssim, cur_lpips, norm_lpips) fout.write(gt_name + '\t' + str(cur_psnr) + '\t' + str(cur_ssim) + '\t' + str(cur_lpips) + '\t' + str(norm_lpips) + '\n') mean_psnr += cur_psnr mean_ssim += cur_ssim mean_lpips += cur_lpips mean_norm_lpips += norm_lpips mean_psnr /= float(len(gt_names)) mean_ssim /= float(len(gt_names)) mean_lpips /= float(len(gt_names)) mean_norm_lpips /= float(len(gt_names)) fout.write(str(mean_psnr) + '\t' + str(mean_ssim) + '\t' + str(mean_lpips) + '\t' + str(mean_norm_lpips) + '\n') fout.close() print('psnr, ssim, lpips, norm_lpips:', mean_psnr, mean_ssim, mean_lpips, mean_norm_lpips) if __name__ == '__main__': calculate_psnr_ssim_lpips_folder() ================================================ FILE: scripts/metrics/run.sh ================================================ ### Journal ### root='results/' out_root='results/metrics' test_name='RestoreFormer' test_image=$test_name'/restored_faces' out_name=$test_name need_post=1 CelebAHQ_GT='YOUR_PATH' # FID python -u scripts/metrics/cal_fid.py \ $root'/'$test_image \ --fid_stats 'experiments/pretrained_models/inception_FFHQ_512-f7b384ab.pth' \ --save_name $out_root'/'$out_name'_fid.txt' \ if [ -d $CelebAHQ_GT ] then # PSRN SSIM LPIPS python -u scripts/metrics/cal_psnr_ssim.py \ $root'/'$test_image \ --gt_folder $CelebAHQ_GT \ --save_name $out_root'/'$out_name'_psnr_ssim_lpips.txt' \ --need_post $need_post \ # # # PSRN SSIM LPIPS python -u scripts/metrics/cal_identity_distance.py \ $root'/'$test_image \ --gt_folder $CelebAHQ_GT \ --save_name $out_root'/'$out_name'_id.txt' \ --need_post $need_post else echo 'The path of GT does not exist' fi ================================================ FILE: scripts/run.sh ================================================ export BASICSR_JIT=True conf_name='HQ_Dictionary' # conf_name='RestoreFormer' ROOT_PATH='' # The path for saving model and logs gpus='0,1,2,3' #P: pretrain SL: soft learning node_n=1 python -u main.py \ --root-path $ROOT_PATH \ --base 'configs/'$conf_name'.yaml' \ -t True \ --postfix $conf_name \ --gpus $gpus \ --num-nodes $node_n \ --random-seed True \ ================================================ FILE: scripts/test.py ================================================ import argparse, os, sys, glob, math, time import torch import numpy as np from omegaconf import OmegaConf from PIL import Image import pdb sys.path.append(os.getcwd()) from main import instantiate_from_config, DataModuleFromConfig from torch.utils.data import DataLoader from torch.utils.data.dataloader import default_collate from tqdm import trange, tqdm import cv2 from facexlib.utils.face_restoration_helper import FaceRestoreHelper from torchvision.transforms.functional import normalize from basicsr.utils import img2tensor, imwrite, tensor2img def restoration(model, face_helper, img_path, save_root, has_aligned=False, only_center_face=True, suffix=None, paste_back=False): # read image img_name = os.path.basename(img_path) # print(f'Processing {img_name} ...') basename, _ = os.path.splitext(img_name) input_img = cv2.imread(img_path, cv2.IMREAD_COLOR) face_helper.clean_all() if has_aligned: input_img = cv2.resize(input_img, (512, 512)) face_helper.cropped_faces = [input_img] else: face_helper.read_image(input_img) # get face landmarks for each face face_helper.get_face_landmarks_5(only_center_face=only_center_face, pad_blur=False) # align and warp each face save_crop_path = os.path.join(save_root, 'cropped_faces', img_name) face_helper.align_warp_face(save_crop_path) # face restoration for idx, cropped_face in enumerate(face_helper.cropped_faces): # prepare data cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) cropped_face_t = cropped_face_t.unsqueeze(0).to('cuda') try: with torch.no_grad(): output = model(cropped_face_t) restored_face = tensor2img(output[0].squeeze(0), rgb2bgr=True, min_max=(-1, 1)) except RuntimeError as error: print(f'\tFailed inference for GFPGAN: {error}.') restored_face = cropped_face restored_face = restored_face.astype('uint8') face_helper.add_restored_face(restored_face) if suffix is not None: save_face_name = f'{basename}_{idx:02d}_{suffix}.png' else: save_face_name = f'{basename}_{idx:02d}.png' save_restore_path = os.path.join(save_root, 'restored_faces', save_face_name) imwrite(restored_face, save_restore_path) if not has_aligned and paste_back: face_helper.get_inverse_affine(None) save_restore_path = os.path.join(save_root, 'restored_imgs', img_name) # paste each restored face to the input image face_helper.paste_faces_to_input_image(save_restore_path) def get_parser(): parser = argparse.ArgumentParser() parser.add_argument( "-r", "--resume", type=str, nargs="?", help="load from logdir or checkpoint in logdir", ) parser.add_argument( "-b", "--base", nargs="*", metavar="base_config.yaml", help="paths to base configs. Loaded from left-to-right. " "Parameters can be overwritten or added with command-line options of the form `--key value`.", default=list(), ) parser.add_argument( "-c", "--config", nargs="?", metavar="single_config.yaml", help="path to single config. If specified, base configs will be ignored " "(except for the last one if left unspecified).", const=True, default="", ) parser.add_argument( "--ignore_base_data", action="store_true", help="Ignore data specification from base configs. Useful if you want " "to specify a custom datasets on the command line.", ) parser.add_argument( "--outdir", required=True, type=str, help="Where to write outputs to.", ) parser.add_argument( "--top_k", type=int, default=100, help="Sample from among top-k predictions.", ) parser.add_argument( "--temperature", type=float, default=1.0, help="Sampling temperature.", ) parser.add_argument('--upscale_factor', type=int, default=1) parser.add_argument('--test_path', type=str, default='inputs/whole_imgs') parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces') parser.add_argument('--only_center_face', action='store_true') parser.add_argument('--aligned', action='store_true') parser.add_argument('--paste_back', action='store_true') return parser def load_model_from_config(config, sd, gpu=True, eval_mode=True): if "ckpt_path" in config.params: print("Deleting the restore-ckpt path from the config...") config.params.ckpt_path = None if "downsample_cond_size" in config.params: print("Deleting downsample-cond-size from the config and setting factor=0.5 instead...") config.params.downsample_cond_size = -1 config.params["downsample_cond_factor"] = 0.5 try: if "ckpt_path" in config.params.first_stage_config.params: config.params.first_stage_config.params.ckpt_path = None print("Deleting the first-stage restore-ckpt path from the config...") if "ckpt_path" in config.params.cond_stage_config.params: config.params.cond_stage_config.params.ckpt_path = None print("Deleting the cond-stage restore-ckpt path from the config...") except: pass model = instantiate_from_config(config) if sd is not None: keys = list(sd.keys()) state_dict = model.state_dict() require_keys = state_dict.keys() keys = sd.keys() un_pretrained_keys = [] for k in require_keys: if k not in keys: # miss 'vqvae.' if k[6:] in keys: state_dict[k] = sd[k[6:]] else: un_pretrained_keys.append(k) else: state_dict[k] = sd[k] # print(f'*************************************************') # print(f"Layers without pretraining: {un_pretrained_keys}") # print(f'*************************************************') model.load_state_dict(state_dict, strict=True) if gpu: model.cuda() if eval_mode: model.eval() return {"model": model} def load_model_and_dset(config, ckpt, gpu, eval_mode): # now load the specified checkpoint if ckpt: pl_sd = torch.load(ckpt, map_location="cpu") else: pl_sd = {"state_dict": None} model = load_model_from_config(config.model, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode)["model"] return model if __name__ == "__main__": sys.path.append(os.getcwd()) parser = get_parser() opt, unknown = parser.parse_known_args() ckpt = None if opt.resume: if not os.path.exists(opt.resume): raise ValueError("Cannot find {}".format(opt.resume)) if os.path.isfile(opt.resume): paths = opt.resume.split("/") try: idx = len(paths)-paths[::-1].index("logs")+1 except ValueError: idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt logdir = "/".join(paths[:idx]) ckpt = opt.resume else: assert os.path.isdir(opt.resume), opt.resume logdir = opt.resume.rstrip("/") ckpt = os.path.join(logdir, "checkpoints", "last.ckpt") print(f"logdir:{logdir}") base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml"))) opt.base = base_configs+opt.base if opt.config: if type(opt.config) == str: if not os.path.exists(opt.config): raise ValueError("Cannot find {}".format(opt.config)) if os.path.isfile(opt.config): opt.base = [opt.config] else: opt.base = sorted(glob.glob(os.path.join(opt.config, "*-project.yaml"))) else: opt.base = [opt.base[-1]] configs = [OmegaConf.load(cfg) for cfg in opt.base] cli = OmegaConf.from_dotlist(unknown) if opt.ignore_base_data: for config in configs: if hasattr(config, "data"): del config["data"] config = OmegaConf.merge(*configs, cli) print(config) gpu = True eval_mode = True show_config = False if show_config: print(OmegaConf.to_container(config)) model = load_model_and_dset(config, ckpt, gpu, eval_mode) outdir = opt.outdir os.makedirs(outdir, exist_ok=True) print("Writing samples to ", outdir) # initialize face helper face_helper = FaceRestoreHelper( opt.upscale_factor, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png') img_list = sorted(glob.glob(os.path.join(opt.test_path, '*'))) print('Results are in the <{}> folder.'.format(outdir)) for img_path in tqdm(img_list): restoration( model, face_helper, img_path, outdir, has_aligned=opt.aligned, only_center_face=opt.only_center_face, suffix=opt.suffix, paste_back=opt.paste_back) print('Test number: ', len(img_list)) print('Results are in the <{}> folder.'.format(outdir)) ================================================ FILE: scripts/test.sh ================================================ # # ### Good exp_name='RestoreFormer' root_path='experiments' out_root_path='results' align_test_path='data/test' tag='test' outdir=$out_root_path'/'$exp_name'_'$tag if [ ! -d $outdir ];then mkdir $outdir fi python -u scripts/test.py \ --outdir $outdir \ -r $root_path'/'$exp_name'/last.ckpt' \ -c 'configs/RestoreFormer.yaml' \ --test_path $align_test_path \ --aligned