Full Code of wzhouxiff/RestoreFormer for AI

main 294cf9521a86 cached
30 files
158.8 KB
39.9k tokens
172 symbols
1 requests
Download .txt
Repository: wzhouxiff/RestoreFormer
Branch: main
Commit: 294cf9521a86
Files: 30
Total size: 158.8 KB

Directory structure:
gitextract_istv2bdf/

├── .gitignore
├── LICENSE
├── README.md
├── RestoreFormer/
│   ├── data/
│   │   └── ffhq_degradation_dataset.py
│   ├── distributed/
│   │   ├── __init__.py
│   │   ├── distributed.py
│   │   └── launch.py
│   ├── models/
│   │   └── vqgan_v1.py
│   ├── modules/
│   │   ├── discriminator/
│   │   │   └── model.py
│   │   ├── losses/
│   │   │   ├── __init__.py
│   │   │   ├── lpips.py
│   │   │   └── vqperceptual.py
│   │   ├── util.py
│   │   └── vqvae/
│   │       ├── arcface_arch.py
│   │       ├── facial_component_discriminator.py
│   │       ├── utils.py
│   │       └── vqvae_arch.py
│   └── util.py
├── __init__.py
├── configs/
│   ├── HQ_Dictionary.yaml
│   └── RestoreFormer.yaml
├── main.py
├── restoreformer_requirement.txt
└── scripts/
    ├── metrics/
    │   ├── cal_fid.py
    │   ├── cal_identity_distance.py
    │   ├── cal_psnr_ssim.py
    │   └── run.sh
    ├── run.sh
    ├── test.py
    └── test.sh

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
data/FFHQ
scripts/data_synthetic
experiments/
scripts/run_clustre.sh
sftp-config.json
results/
# scripts/metrics


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# We have merged the code of RestoreFormer into our journal version, RestoreFormer++. Please feel free to access the resources from [https://github.com/wzhouxiff/RestoreFormerPlusPlus](https://github.com/wzhouxiff/RestoreFormerPlusPlus)

# Updating
- **20230915** Update an online demo [![Huggingface Gradio](https://img.shields.io/static/v1?label=Demo&message=Huggingface%20Gradio&color=orange)](https://huggingface.co/spaces/wzhouxiff/RestoreFormerPlusPlus)
- **20230915** A more user-friendly and comprehensive inference method refer to our [RestoreFormer++](https://github.com/wzhouxiff/RestoreFormerPlusPlus)
- **20230116** For convenience, we further upload the [test datasets](#testset), including CelebA (both HQ and LQ data), LFW-Test, CelebChild-Test, and Webphoto-Test, to OneDrive and BaiduYun.
- **20221003** We provide the link of the [test datasets](#testset).
- **20220924** We add the code for [**metrics**](#metrics) in scripts/metrics.


<!--
# RestoreFormer

This repo includes the source code of the paper: "[RestoreFormer: High-Quality Blind Face Restoration from Undegraded Key-Value Pairs](https://openaccess.thecvf.com/content/CVPR2022/papers/Wang_RestoreFormer_High-Quality_Blind_Face_Restoration_From_Undegraded_Key-Value_Pairs_CVPR_2022_paper.pdf)" (CVPR 2022) by Zhouxia Wang, Jiawei Zhang, Runjian Chen, Wenping Wang, and Ping Luo.

![](assets/figure1.png)

**RestoreFormer** tends to explore fully-spatial attentions to model contextual information and surpasses existing works that use local operators. It has several benefits compared to prior arts. First, it incorporates a multi-head coross-attention layer to learn fully-spatial interations between corrupted queries and high-quality key-value pairs. Second, the key-value pairs in RestoreFormer are sampled from a reconstruction-oriented high-quality dictionary, whose elements are rich in high-quality facial features specifically aimed for face reconstruction.

-->

<!-- ![](assets/framework.png "Framework")-->

<!--

## Environment

- python>=3.7
- pytorch>=1.7.1
- pytorch-lightning==1.0.8
- omegaconf==2.0.0
- basicsr==1.3.3.4

**Warning** Different versions of pytorch-lightning and omegaconf may lead to errors or different results.

## Preparations of dataset and models

**Dataset**: 
- Training data: Both **HQ Dictionary** and **RestoreFormer** in our work are trained with **FFHQ** which attained from [FFHQ repository](https://github.com/NVlabs/ffhq-dataset). The original size of the images in FFHQ are 1024x1024. We resize them to 512x512 with bilinear interpolation in our work. Link this dataset to ./data/FFHQ/image512x512.
- <a id="testset">Test data</a>: 
   * CelebA-Test-HQ: [OneDrive](https://connecthkuhk-my.sharepoint.com/:u:/g/personal/wzhoux_connect_hku_hk/EY7P-MReZUZOngy3UGa5abUBJKel1IH5uYZLdwp2e2KvUw?e=rK0VWh); [BaiduYun](https://pan.baidu.com/s/1tMpxz8lIW50U8h00047GIw?pwd=mp9t)(code mp9t)
   * CelebA-Test-LQ: [OneDrive](https://connecthkuhk-my.sharepoint.com/:u:/g/personal/wzhoux_connect_hku_hk/EXULDOtX3qdKg9_--k-hbr4BumxOUAi19iQjZNz75S6pKA?e=Kghqri); [BaiduYun](https://pan.baidu.com/s/1y6ZcQPCLyggj9VB5MgoWyg?pwd=7s6h)(code 7s6h)
   * LFW-Test: [OneDrive](https://connecthkuhk-my.sharepoint.com/:u:/g/personal/wzhoux_connect_hku_hk/EZ7ibkhUuRxBjdd-MesczpgBfpLVfv-9uYVskLuZiYpBsg?e=xPNH26); [BaiduYun](https://pan.baidu.com/s/1UkfYLTViL8XVdZ-Ej-2G9g?pwd=7fhr)(code 7fhr). Note that it was align with dlib.
   * CelebChild: [OneDrive](https://connecthkuhk-my.sharepoint.com/:u:/g/personal/wzhoux_connect_hku_hk/ESK6vjLzDuJAsd-cfWrfl20BTeSD_w4uRNJREGfl3zGzJg?e=Tou7ft); [BaiduYun](https://pan.baidu.com/s/1pGCD4TkhtDsmp8emZd8smA?pwd=rq65)(code rq65)
   * WepPhoto-Test: [OneDrive](https://connecthkuhk-my.sharepoint.com/:u:/g/personal/wzhoux_connect_hku_hk/ER1-0eYKGkZIs-YEDhNW0xIBohCI5IEZyAS2PAvI81Stcg?e=TFJFGh); [BaiduYun](https://pan.baidu.com/s/1SjBfinSL1F-bbOpXiD0nlw?pwd=nren)(code nren)

**Model**: Both pretrained models used for training and the trained model of our RestoreFormer can be attained from [OneDrive](https://connecthkuhk-my.sharepoint.com/:u:/g/personal/wzhoux_connect_hku_hk/Eb73S2jXZIxNrrOFRnFKu2MBTe7kl4cMYYwwiudAmDNwYg?e=Xa4ZDf) or [BaiduYun](https://pan.baidu.com/s/1EO7_1dYyCuORpPNosQgogg?pwd=x6nn)(code x6nn). Link these models to ./experiments.

## Test
    sh scripts/test.sh

## Training
    sh scripts/run.sh

**Note**. 
- The first stage is to attain **HQ Dictionary** by setting `conf_name` in scripts/run.sh to 'HQ\_Dictionary'. 
- The second stage is blind face restoration. You need to add your trained HQ\_Dictionary model to `ckpt_path` in config/RestoreFormer.yaml and set `conf_name` in scripts/run.sh to 'RestoreFormer'.
- Our model is trained with 4 V100 GPUs.

## <a id="metrics">Metrics</a>
    sh scripts/metrics/run.sh
    
**Note**. 
- You need to add the path of CelebA-Test dataset in the script if you want get IDD, PSRN, SSIM, LIPIS.

## Citation
    @article{wang2022restoreformer,
      title={RestoreFormer: High-Quality Blind Face Restoration from Undegraded Key-Value Pairs},
      author={Wang, Zhouxia and Zhang, Jiawei and Chen, Runjian and Wang, Wenping and Luo, Ping},
      booktitle={The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
      year={2022}
    }

## Acknowledgement
We thank everyone who makes their code and models available, especially [Taming Transformer](https://github.com/CompVis/taming-transformers), [basicsr](https://github.com/XPixelGroup/BasicSR), and [GFPGAN](https://github.com/TencentARC/GFPGAN).

## Contact
For any question, feel free to email `wzhoux@connect.hku.hk` or `zhouzi1212@gmail.com`.

-->


================================================
FILE: RestoreFormer/data/ffhq_degradation_dataset.py
================================================
import os
import cv2
import math
import numpy as np
import random
import os.path as osp
import torch
import torch.utils.data as data
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
                                               normalize)

from basicsr.data import degradations as degradations
from basicsr.data.data_util import paths_from_folder
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class FFHQDegradationDataset(data.Dataset):

    def __init__(self, opt):
        super(FFHQDegradationDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.mean = opt['mean']
        self.std = opt['std']
        self.out_size = opt['out_size']

        self.crop_components = opt.get('crop_components', False)  # facial components
        self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)

        if self.crop_components:
            self.components_list = torch.load(opt.get('component_path'))

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            self.paths = paths_from_folder(self.gt_folder)

        # degradations
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']
        self.blur_sigma = opt['blur_sigma']
        self.downsample_range = opt['downsample_range']
        self.noise_range = opt['noise_range']
        self.jpeg_range = opt['jpeg_range']

        # color jitter
        self.color_jitter_prob = opt.get('color_jitter_prob')
        self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
        self.color_jitter_shift = opt.get('color_jitter_shift', 20)
        # to gray
        self.gray_prob = opt.get('gray_prob')

        logger = get_root_logger()
        logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, '
                    f'sigma: [{", ".join(map(str, self.blur_sigma))}]')
        logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
        logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
        logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')

        if self.color_jitter_prob is not None:
            logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, '
                        f'shift: {self.color_jitter_shift}')
        if self.gray_prob is not None:
            logger.info(f'Use random gray. Prob: {self.gray_prob}')

        self.color_jitter_shift /= 255.


    @staticmethod
    def color_jitter(img, shift):
        jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
        img = img + jitter_val
        img = np.clip(img, 0, 1)
        return img

    @staticmethod
    def color_jitter_pt(img, brightness, contrast, saturation, hue):
        fn_idx = torch.randperm(4)
        for fn_id in fn_idx:
            if fn_id == 0 and brightness is not None:
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img = adjust_brightness(img, brightness_factor)

            if fn_id == 1 and contrast is not None:
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img = adjust_contrast(img, contrast_factor)

            if fn_id == 2 and saturation is not None:
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img = adjust_saturation(img, saturation_factor)

            if fn_id == 3 and hue is not None:
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img = adjust_hue(img, hue_factor)
        return img

    def get_component_coordinates(self, index, status):
        components_bbox = self.components_list[f'{index:08d}']
        if status[0]:  # hflip
            # exchange right and left eye
            tmp = components_bbox['left_eye']
            components_bbox['left_eye'] = components_bbox['right_eye']
            components_bbox['right_eye'] = tmp
            # modify the width coordinate
            components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0]
            components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0]
            components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0]

        # get coordinates
        locations = []
        for part in ['left_eye', 'right_eye', 'mouth']:
            mean = components_bbox[part][0:2]
            half_len = components_bbox[part][2]
            if 'eye' in part:
                half_len *= self.eye_enlarge_ratio
            loc = np.hstack((mean - half_len + 1, mean + half_len))
            loc = torch.from_numpy(loc).float()
            locations.append(loc)
        return locations

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load gt image
        gt_path = self.paths[index]
        img_bytes = self.file_client.get(gt_path)
        img_gt = imfrombytes(img_bytes, float32=True)

        # random horizontal flip
        img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
        h, w, _ = img_gt.shape

        if self.crop_components:
            locations = self.get_component_coordinates(index, status)
            loc_left_eye, loc_right_eye, loc_mouth = locations

        # ------------------------ generate lq image ------------------------ #
        # blur
        assert self.blur_kernel_size[0] < self.blur_kernel_size[1], 'Wrong blur kernel size range'
        cur_kernel_size = random.randint(self.blur_kernel_size[0],self.blur_kernel_size[1]) * 2 + 1
        kernel = degradations.random_mixed_kernels(
            self.kernel_list,
            self.kernel_prob,
            cur_kernel_size,
            self.blur_sigma,
            self.blur_sigma, [-math.pi, math.pi],
            noise_range=None)
        img_lq = cv2.filter2D(img_gt, -1, kernel)
        # downsample
        scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
        img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
        # noise
        if self.noise_range is not None:
            img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
        # jpeg compression
        if self.jpeg_range is not None:
            img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)

        # resize to original size
        img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)

        # random color jitter (only for lq)
        if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
            img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
        # random to gray (only for lq)
        if self.gray_prob and np.random.uniform() < self.gray_prob:
            img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
            img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
            if self.opt.get('gt_gray'):
                img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
                img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)

        # random color jitter (pytorch version) (only for lq)
        if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
            brightness = self.opt.get('brightness', (0.5, 1.5))
            contrast = self.opt.get('contrast', (0.5, 1.5))
            saturation = self.opt.get('saturation', (0, 1.5))
            hue = self.opt.get('hue', (-0.1, 0.1))
            img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)

        # round and clip
        img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.

        # normalize
        normalize(img_gt, self.mean, self.std, inplace=True)
        normalize(img_lq, self.mean, self.std, inplace=True)

        return_dict = {
                'lq': img_lq,
                'gt': img_gt,
                'gt_path': gt_path
            }
        if self.crop_components:
            return_dict['loc_left_eye'] = loc_left_eye
            return_dict['loc_right_eye'] = loc_right_eye
            return_dict['loc_mouth'] = loc_mouth


        return return_dict

    def __len__(self):
        return len(self.paths)

import argparse
from omegaconf import OmegaConf
import pdb
from basicsr.utils import img2tensor, imwrite, tensor2img

if __name__=='__main__':
    # pdb.set_trace()
    base='configs/RestoreFormer.yaml'

    opt = OmegaConf.load(base)
    dataset = FFHQDegradationDataset(opt['data']['params']['train']['params'])

    for i in range(100):
        sample = dataset.getitem(i)
        name = sample['gt_path'].split('/')[-1][:-4]
        gt = tensor2img(sample['gt'])
        imwrite(gt, +name+'_gt.png')
        lq = tensor2img(sample['lq'])
        imwrite(lq, name+'_lq_nojitter.png')

================================================
FILE: RestoreFormer/distributed/__init__.py
================================================
from .distributed import (
    get_rank,
    get_local_rank,
    is_primary,
    synchronize,
    get_world_size,
    all_reduce,
    all_gather,
    reduce_dict,
    data_sampler,
    LOCAL_PROCESS_GROUP,
)
from .launch import launch


================================================
FILE: RestoreFormer/distributed/distributed.py
================================================
import math
import pickle

import torch
from torch import distributed as dist
from torch.utils import data


LOCAL_PROCESS_GROUP = None


def is_primary():
    return get_rank() == 0


def get_rank():
    if not dist.is_available():
        return 0

    if not dist.is_initialized():
        return 0

    return dist.get_rank()


def get_local_rank():
    if not dist.is_available():
        return 0

    if not dist.is_initialized():
        return 0

    if LOCAL_PROCESS_GROUP is None:
        raise ValueError("tensorfn.distributed.LOCAL_PROCESS_GROUP is None")

    return dist.get_rank(group=LOCAL_PROCESS_GROUP)


def synchronize():
    if not dist.is_available():
        return

    if not dist.is_initialized():
        return

    world_size = dist.get_world_size()

    if world_size == 1:
        return

    dist.barrier()


def get_world_size():
    if not dist.is_available():
        return 1

    if not dist.is_initialized():
        return 1

    return dist.get_world_size()


def all_reduce(tensor, op=dist.ReduceOp.SUM):
    world_size = get_world_size()

    if world_size == 1:
        return tensor

    dist.all_reduce(tensor, op=op)

    return tensor


def all_gather(data):
    world_size = get_world_size()

    if world_size == 1:
        return [data]

    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    local_size = torch.IntTensor([tensor.numel()]).to("cuda")
    size_list = [torch.IntTensor([1]).to("cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))

    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
        tensor = torch.cat((tensor, padding), 0)

    dist.all_gather(tensor_list, tensor)

    data_list = []

    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


def reduce_dict(input_dict, average=True):
    world_size = get_world_size()

    if world_size < 2:
        return input_dict

    with torch.no_grad():
        keys = []
        values = []

        for k in sorted(input_dict.keys()):
            keys.append(k)
            values.append(input_dict[k])

        values = torch.stack(values, 0)
        dist.reduce(values, dst=0)

        if dist.get_rank() == 0 and average:
            values /= world_size

        reduced_dict = {k: v for k, v in zip(keys, values)}

    return reduced_dict


def data_sampler(dataset, shuffle, distributed):
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)

    if shuffle:
        return data.RandomSampler(dataset)

    else:
        return data.SequentialSampler(dataset)


================================================
FILE: RestoreFormer/distributed/launch.py
================================================
import os

import torch
from torch import distributed as dist
from torch import multiprocessing as mp

from . import distributed as dist_fn


def find_free_port():
    import socket

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()

    return port


def launch(fn, n_gpu_per_machine, n_machine=1, machine_rank=0, dist_url=None, args=()):
    world_size = n_machine * n_gpu_per_machine

    if world_size > 1:
        if "OMP_NUM_THREADS" not in os.environ:
            os.environ["OMP_NUM_THREADS"] = "1"

        if dist_url == "auto":
            if n_machine != 1:
                raise ValueError('dist_url="auto" not supported in multi-machine jobs')

            port = find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"

        if n_machine > 1 and dist_url.startswith("file://"):
            raise ValueError(
                "file:// is not a reliable init method in multi-machine jobs. Prefer tcp://"
            )

        mp.spawn(
            distributed_worker,
            nprocs=n_gpu_per_machine,
            args=(fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args),
            daemon=False,
        )

    else:
        fn(*args)


def distributed_worker(
    local_rank, fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args
):
    if not torch.cuda.is_available():
        raise OSError("CUDA is not available. Please check your environments")

    global_rank = machine_rank * n_gpu_per_machine + local_rank

    try:
        dist.init_process_group(
            backend="NCCL",
            init_method=dist_url,
            world_size=world_size,
            rank=global_rank,
        )

    except Exception:
        raise OSError("failed to initialize NCCL groups")

    dist_fn.synchronize()

    if n_gpu_per_machine > torch.cuda.device_count():
        raise ValueError(
            f"specified n_gpu_per_machine larger than available device ({torch.cuda.device_count()})"
        )

    torch.cuda.set_device(local_rank)

    if dist_fn.LOCAL_PROCESS_GROUP is not None:
        raise ValueError("torch.distributed.LOCAL_PROCESS_GROUP is not None")

    n_machine = world_size // n_gpu_per_machine

    for i in range(n_machine):
        ranks_on_i = list(range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine))
        pg = dist.new_group(ranks_on_i)

        if i == machine_rank:
            dist_fn.distributed.LOCAL_PROCESS_GROUP = pg

    fn(*args)


================================================
FILE: RestoreFormer/models/vqgan_v1.py
================================================
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from main import instantiate_from_config

from RestoreFormer.modules.vqvae.utils import get_roi_regions

class RestoreFormerModel(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="lq",
                 colorize_nlabels=None,
                 monitor=None,
                 special_params_lr_scale=1.0,
                 comp_params_lr_scale=1.0,
                 schedule_step=[80000, 200000]
                 ):
        super().__init__()
        self.image_key = image_key
        self.vqvae = instantiate_from_config(ddconfig)

        lossconfig['params']['distill_param']=ddconfig['params']
        self.loss = instantiate_from_config(lossconfig)
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

        
        if ('comp_weight' in lossconfig['params'] and lossconfig['params']['comp_weight']) or ('comp_style_weight' in lossconfig['params'] and lossconfig['params']['comp_style_weight']):
            self.use_facial_disc = True
        else:
            self.use_facial_disc = False

        self.fix_decoder = ddconfig['params']['fix_decoder']
        
        self.disc_start = lossconfig['params']['disc_start']
        self.special_params_lr_scale = special_params_lr_scale
        self.comp_params_lr_scale = comp_params_lr_scale
        self.schedule_step = schedule_step

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())

        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]

        state_dict = self.state_dict()
        require_keys = state_dict.keys()
        keys = sd.keys()
        un_pretrained_keys = []
        for k in require_keys:
            if k not in keys: 
                # miss 'vqvae.'
                if k[6:] in keys:
                    state_dict[k] = sd[k[6:]]
                else:
                    un_pretrained_keys.append(k)
            else:
                state_dict[k] = sd[k]

        # print(f'*************************************************')
        # print(f"Layers without pretraining: {un_pretrained_keys}")
        # print(f'*************************************************')

        self.load_state_dict(state_dict, strict=True)
        print(f"Restored from {path}")

    def forward(self, input):
        dec, diff, info, hs = self.vqvae(input)
        return dec, diff, info, hs

    def training_step(self, batch, batch_idx, optimizer_idx):
        
        x = batch[self.image_key]
        xrec, qloss, info, hs = self(x)

        if self.image_key != 'gt':
            x = batch['gt']

        if self.use_facial_disc:
            loc_left_eyes = batch['loc_left_eye']
            loc_right_eyes = batch['loc_right_eye']
            loc_mouths = batch['loc_mouth']
            face_ratio = xrec.shape[-1] / 512
            components = get_roi_regions(x, xrec, loc_left_eyes, loc_right_eyes, loc_mouths, face_ratio)
        else:
            components = None

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, components, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")

            self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, components, optimizer_idx, self.global_step,
                                            last_layer=None, split="train")
            self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

        
        if self.disc_start <= self.global_step:

            # left eye
            if optimizer_idx == 2:
                # discriminator
                disc_left_loss, log_dict_disc = self.loss(qloss, x, xrec, components, optimizer_idx, self.global_step,
                                                last_layer=None, split="train")
                self.log("train/disc_left_loss", disc_left_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
                self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
                return disc_left_loss

            # right eye
            if optimizer_idx == 3:
                # discriminator
                disc_right_loss, log_dict_disc = self.loss(qloss, x, xrec, components, optimizer_idx, self.global_step,
                                                last_layer=None, split="train")
                self.log("train/disc_right_loss", disc_right_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
                self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
                return disc_right_loss

            # mouth
            if optimizer_idx == 4:
                # discriminator
                disc_mouth_loss, log_dict_disc = self.loss(qloss, x, xrec, components, optimizer_idx, self.global_step,
                                                last_layer=None, split="train")
                self.log("train/disc_mouth_loss", disc_mouth_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
                self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
                return disc_mouth_loss

    def validation_step(self, batch, batch_idx):
        x = batch[self.image_key]
        xrec, qloss, info, hs = self(x)

        if self.image_key != 'gt':
            x = batch['gt']

        aeloss, log_dict_ae = self.loss(qloss, x, xrec, None, 0, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(qloss, x, xrec, None, 1, self.global_step,
                                            last_layer=None, split="val")
        rec_loss = log_dict_ae["val/rec_loss"]
        self.log("val/rec_loss", rec_loss,
                   prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        self.log("val/aeloss", aeloss,
                   prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)

        return self.log_dict

    def configure_optimizers(self):
        lr = self.learning_rate

        normal_params = []
        special_params = []
        for name, param in self.vqvae.named_parameters():
            if not param.requires_grad:
                continue
            if 'decoder' in name and 'attn' in name:
                special_params.append(param)
            else:
                normal_params.append(param)
        # print('special_params', special_params)
        opt_ae_params = [{'params': normal_params, 'lr': lr},
                         {'params': special_params, 'lr': lr*self.special_params_lr_scale}]
        opt_ae = torch.optim.Adam(opt_ae_params, betas=(0.5, 0.9))


        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))

        optimizations = [opt_ae, opt_disc]

        s0 = torch.optim.lr_scheduler.MultiStepLR(opt_ae, milestones=self.schedule_step, gamma=0.1, verbose=True)
        s1 = torch.optim.lr_scheduler.MultiStepLR(opt_disc, milestones=self.schedule_step, gamma=0.1, verbose=True)
        schedules = [s0, s1]

        if self.use_facial_disc:
            opt_l = torch.optim.Adam(self.loss.net_d_left_eye.parameters(),
                                     lr=lr*self.comp_params_lr_scale, betas=(0.9, 0.99))
            opt_r = torch.optim.Adam(self.loss.net_d_right_eye.parameters(),
                                     lr=lr*self.comp_params_lr_scale, betas=(0.9, 0.99))
            opt_m = torch.optim.Adam(self.loss.net_d_mouth.parameters(),
                                     lr=lr*self.comp_params_lr_scale, betas=(0.9, 0.99))
            optimizations += [opt_l, opt_r, opt_m]
            
            s2 = torch.optim.lr_scheduler.MultiStepLR(opt_l, milestones=self.schedule_step, gamma=0.1, verbose=True)
            s3 = torch.optim.lr_scheduler.MultiStepLR(opt_r, milestones=self.schedule_step, gamma=0.1, verbose=True)
            s4 = torch.optim.lr_scheduler.MultiStepLR(opt_m, milestones=self.schedule_step, gamma=0.1, verbose=True)
            schedules += [s2, s3, s4]

        return optimizations, schedules

    def get_last_layer(self):
        if self.fix_decoder:
            return self.vqvae.quant_conv.weight
        return self.vqvae.decoder.conv_out.weight

    def log_images(self, batch, **kwargs):
        log = dict()
        x = batch[self.image_key]
        x = x.to(self.device)
        xrec, _, _, _ = self(x)
        log["inputs"] = x
        log["reconstructions"] = xrec

        if self.image_key != 'gt':
            x = batch['gt']
            log["gt"] = x
        return log


================================================
FILE: RestoreFormer/modules/discriminator/model.py
================================================
import functools
import torch.nn as nn


from RestoreFormer.modules.util import ActNorm


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
        --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """
    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if not use_actnorm:
            norm_layer = nn.BatchNorm2d
        else:
            norm_layer = ActNorm
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)

class NLayerDiscriminator_v1(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
        --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """
    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator_v1, self).__init__()
        if not use_actnorm:
            norm_layer = nn.BatchNorm2d
        else:
            norm_layer = ActNorm
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        self.n_layers = n_layers

        kw = 4
        padw = 1
        # sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        self.head = nn.Sequential(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True))
        # self.head = nn.Sequential(nn.Conv2d(3, 64, kernel_size=4, stride=1, padding=1), nn.LeakyReLU(0.2, True)).cuda()
        nf_mult = 1
        nf_mult_prev = 1
        self.body = nn.ModuleList()
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)

            self.body.append(nn.Sequential(
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ))

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        self.beforlast = nn.Sequential(
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        )

        self.final = nn.Sequential(
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw))  # output 1 channel prediction map
        # self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        # return self.main(input)
        
        features = []

        f = self.head(input)
        features.append(f) 

        for i in range(self.n_layers-1):
            f = self.body[i](f)
            features.append(f) 

        beforlastF = self.beforlast(f)
        final = self.final(beforlastF)

        return features, final



================================================
FILE: RestoreFormer/modules/losses/__init__.py
================================================




================================================
FILE: RestoreFormer/modules/losses/lpips.py
================================================
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""

import torch
import torch.nn as nn
from torchvision import models
from collections import namedtuple

from RestoreFormer.util import get_ckpt_path


class LPIPS(nn.Module):
    # Learned perceptual metric
    def __init__(self, use_dropout=True, style_weight=0.):
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # vg16 features
        self.net = vgg16(pretrained=True, requires_grad=False)
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        for param in self.parameters():
            param.requires_grad = False

        self.style_weight = style_weight

    def load_from_pretrained(self, name="vgg_lpips"):
        ckpt = get_ckpt_path(name, "experiments/pretrained_models/lpips")
        self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
        print("loaded pretrained LPIPS loss from {}".format(ckpt))

    @classmethod
    def from_pretrained(cls, name="vgg_lpips"):
        if name is not "vgg_lpips":
            raise NotImplementedError
        model = cls()
        ckpt = get_ckpt_path(name)
        model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
        return model

    def forward(self, input, target):
        in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
        feats0, feats1, diffs = {}, {}, {}
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        style_loss = torch.tensor([0.0]).to(input.device)
        for kk in range(len(self.chns)):
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
            if self.style_weight > 0.:
                style_loss = style_loss + torch.mean((self._gram_mat(feats0[kk]) - 
                             self._gram_mat(feats1[kk])) ** 2)

        res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
        val = res[0]
        for l in range(1, len(self.chns)):
            val += res[l]

        return val, style_loss * self.style_weight

    def _gram_mat(self, x):
        """Calculate Gram matrix.

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix.
        """
        n, c, h, w = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (c * h * w)
        return gram


class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class NetLinLayer(nn.Module):
    """ A single linear layer which does a 1x1 conv """
    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        layers = [nn.Dropout(), ] if (use_dropout) else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
        self.model = nn.Sequential(*layers)


class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
        return out


def normalize_tensor(x,eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
    return x/(norm_factor+eps)


def spatial_average(x, keepdim=True):
    return x.mean([2,3],keepdim=keepdim)



================================================
FILE: RestoreFormer/modules/losses/vqperceptual.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy

from RestoreFormer.modules.losses.lpips import LPIPS
from RestoreFormer.modules.discriminator.model import NLayerDiscriminator, weights_init
from RestoreFormer.modules.vqvae.facial_component_discriminator import FacialComponentDiscriminator
from basicsr.losses.losses import GANLoss, L1Loss
from RestoreFormer.modules.vqvae.arcface_arch import ResNetArcFace


class DummyLoss(nn.Module):
    def __init__(self):
        super().__init__()


def adopt_weight(weight, global_step, threshold=0, value=0.):
    if global_step < threshold:
        weight = value
    return weight


def hinge_d_loss(logits_real, logits_fake):
    loss_real = torch.mean(F.relu(1. - logits_real))
    loss_fake = torch.mean(F.relu(1. + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def vanilla_d_loss(logits_real, logits_fake):
    d_loss = 0.5 * (
        torch.mean(torch.nn.functional.softplus(-logits_real)) +
        torch.mean(torch.nn.functional.softplus(logits_fake)))
    return d_loss


class VQLPIPSWithDiscriminatorWithCompWithIdentity(nn.Module):
    def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, 
                 disc_ndf=64, disc_loss="hinge", comp_weight=0.0, comp_style_weight=0.0, 
                 identity_weight=0.0, comp_disc_loss='vanilla', lpips_style_weight=0.0,
                 identity_model_path=None, **ignore_kwargs):
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.codebook_weight = codebook_weight
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPIPS(style_weight=lpips_style_weight).eval()
        self.perceptual_weight = perceptual_weight

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm,
                                                 ndf=disc_ndf
                                                 ).apply(weights_init)
        if comp_weight > 0:
            self.net_d_left_eye = FacialComponentDiscriminator()
            self.net_d_right_eye = FacialComponentDiscriminator()
            self.net_d_mouth = FacialComponentDiscriminator()
            print(f'Use components discrimination')

            self.cri_component = GANLoss(gan_type=comp_disc_loss, 
                                         real_label_val=1.0, 
                                         fake_label_val=0.0, 
                                         loss_weight=comp_weight)

            if comp_style_weight > 0.:
                self.cri_style = L1Loss(loss_weight=comp_style_weight, reduction='mean')

        if identity_weight > 0:
            self.identity = ResNetArcFace(block = 'IRBlock', 
                                          layers = [2, 2, 2, 2],
                                          use_se = False)
            print(f'Use identity loss')
            if identity_model_path is not None:
                sd = torch.load(identity_model_path, map_location="cpu")
                for k, v in deepcopy(sd).items():
                    if k.startswith('module.'):
                        sd[k[7:]] = v
                        sd.pop(k)
                self.identity.load_state_dict(sd, strict=True)

            for param in self.identity.parameters():
                param.requires_grad = False

            self.cri_identity = L1Loss(loss_weight=identity_weight, reduction='mean')


        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminatorWithCompWithIdentity running with {disc_loss} loss.")
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.comp_weight = comp_weight
        self.comp_style_weight = comp_style_weight
        self.identity_weight = identity_weight
        self.lpips_style_weight = lpips_style_weight

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def _gram_mat(self, x):
        """Calculate Gram matrix.

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix.
        """
        n, c, h, w = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (c * h * w)
        return gram

    def gray_resize_for_identity(self, out, size=128):
        out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
        out_gray = out_gray.unsqueeze(1)
        out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
        return out_gray

    def forward(self, codebook_loss, gts, reconstructions, components, optimizer_idx,
                global_step, last_layer=None, split="train"):

        # now the GAN part
        if optimizer_idx == 0:
            rec_loss = (torch.abs(gts.contiguous() - reconstructions.contiguous())) * self.pixel_weight
            if self.perceptual_weight > 0:
                p_loss, p_style_loss = self.perceptual_loss(gts.contiguous(), reconstructions.contiguous())
                rec_loss = rec_loss + self.perceptual_weight * p_loss
            else:
                p_loss = torch.tensor([0.0])
                p_style_loss = torch.tensor([0.0])

            nll_loss = rec_loss
            #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
            nll_loss = torch.mean(nll_loss)

        
            # generator update
            
            logits_fake = self.discriminator(reconstructions.contiguous())
            g_loss = -torch.mean(logits_fake)

            try:
                d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
            except RuntimeError:
                assert not self.training
                d_weight = torch.tensor(0.0)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            
            loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() + p_style_loss

            log = {
                   "{}/quant_loss".format(split): codebook_loss.detach().mean(),
                   "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/p_loss".format(split): p_loss.detach().mean(),
                   "{}/p_style_loss".format(split): p_style_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }

            if self.comp_weight > 0. and components is not None and self.discriminator_iter_start < global_step:
                fake_left_eye, fake_left_eye_feats = self.net_d_left_eye(components['left_eyes'], return_feats=True)
                comp_g_loss = self.cri_component(fake_left_eye, True, is_disc=False)
                loss = loss + comp_g_loss 
                log["{}/g_left_loss".format(split)] = comp_g_loss.detach()

                fake_right_eye, fake_right_eye_feats = self.net_d_right_eye(components['right_eyes'], return_feats=True)
                comp_g_loss = self.cri_component(fake_right_eye, True, is_disc=False)
                loss = loss + comp_g_loss 
                log["{}/g_right_loss".format(split)] = comp_g_loss.detach()

                fake_mouth, fake_mouth_feats = self.net_d_mouth(components['mouths'], return_feats=True)
                comp_g_loss = self.cri_component(fake_mouth, True, is_disc=False)
                loss = loss + comp_g_loss 
                log["{}/g_mouth_loss".format(split)] = comp_g_loss.detach()

                if self.comp_style_weight > 0.:
                    _, real_left_eye_feats = self.net_d_left_eye(components['left_eyes_gt'], return_feats=True)
                    _, real_right_eye_feats = self.net_d_right_eye(components['right_eyes_gt'], return_feats=True)
                    _, real_mouth_feats = self.net_d_mouth(components['mouths_gt'], return_feats=True)

                    def _comp_style(feat, feat_gt, criterion):
                        return criterion(self._gram_mat(feat[0]), self._gram_mat(
                            feat_gt[0].detach())) * 0.5 + criterion(self._gram_mat(
                            feat[1]), self._gram_mat(feat_gt[1].detach()))

                    comp_style_loss = 0.
                    comp_style_loss = comp_style_loss + _comp_style(fake_left_eye_feats, real_left_eye_feats, self.cri_style)
                    comp_style_loss = comp_style_loss + _comp_style(fake_right_eye_feats, real_right_eye_feats, self.cri_style)
                    comp_style_loss = comp_style_loss + _comp_style(fake_mouth_feats, real_mouth_feats, self.cri_style)
                    loss = loss + comp_style_loss 
                    log["{}/comp_style_loss".format(split)] = comp_style_loss.detach()

            if self.identity_weight > 0. and self.discriminator_iter_start < global_step:
                self.identity.eval()
                out_gray = self.gray_resize_for_identity(reconstructions)
                gt_gray = self.gray_resize_for_identity(gts)
                
                identity_gt = self.identity(gt_gray).detach()
                identity_out = self.identity(out_gray)

                identity_loss = self.cri_identity(identity_out, identity_gt)
                loss = loss + identity_loss 
                log["{}/identity_loss".format(split)] = identity_loss.detach()

            log["{}/total_loss".format(split)] = loss.clone().detach().mean()

            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            
            logits_real = self.discriminator(gts.contiguous().detach())
            logits_fake = self.discriminator(reconstructions.contiguous().detach())

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log

        # left eye
        if optimizer_idx == 2:
            # third pass for discriminator update
            disc_factor = adopt_weight(1.0, global_step, threshold=self.discriminator_iter_start)
            fake_d_pred, _ = self.net_d_left_eye(components['left_eyes'].detach())
            real_d_pred, _ = self.net_d_left_eye(components['left_eyes_gt'])
            d_loss = self.cri_component(real_d_pred, True, is_disc=True) + self.cri_component(fake_d_pred, False, is_disc=True)

            log = {"{}/d_left_loss".format(split): d_loss.clone().detach().mean()}
            return d_loss, log

        # right eye
        if optimizer_idx == 3:
            # forth pass for discriminator update
            fake_d_pred, _ = self.net_d_right_eye(components['right_eyes'].detach())
            real_d_pred, _ = self.net_d_right_eye(components['right_eyes_gt'])
            d_loss = self.cri_component(real_d_pred, True, is_disc=True) + self.cri_component(fake_d_pred, False, is_disc=True)

            log = {"{}/d_right_loss".format(split): d_loss.clone().detach().mean()}
            return d_loss, log

        # mouth
        if optimizer_idx == 4:
            # fifth pass for discriminator update
            fake_d_pred, _ = self.net_d_mouth(components['mouths'].detach())
            real_d_pred, _ = self.net_d_mouth(components['mouths_gt'])
            d_loss = self.cri_component(real_d_pred, True, is_disc=True) + self.cri_component(fake_d_pred, False, is_disc=True)

            log = {"{}/d_mouth_loss".format(split): d_loss.clone().detach().mean()}
            return d_loss, log


================================================
FILE: RestoreFormer/modules/util.py
================================================
import torch
import torch.nn as nn


def count_params(model):
    total_params = sum(p.numel() for p in model.parameters())
    return total_params


class ActNorm(nn.Module):
    def __init__(self, num_features, logdet=False, affine=True,
                 allow_reverse_init=False):
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        with torch.no_grad():
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = (
                flatten.mean(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )
            std = (
                flatten.std(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        if reverse:
            return self.reverse(input)
        if len(input.shape) == 2:
            input = input[:,:,None,None]
            squeeze = True
        else:
            squeeze = False

        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height*width*torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            else:
                self.initialize(output)
                self.initialized.fill_(1)

        if len(output.shape) == 2:
            output = output[:,:,None,None]
            squeeze = True
        else:
            squeeze = False

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h


class Attention2DConv(nn.Module):
    """to replace the convolutional architecture entirely"""
    def __init__(self):
        super().__init__()


================================================
FILE: RestoreFormer/modules/vqvae/arcface_arch.py
================================================
import torch.nn as nn

from basicsr.utils.registry import ARCH_REGISTRY


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class IRBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        super(IRBlock, self).__init__()
        self.bn0 = nn.BatchNorm2d(inplanes)
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.prelu = nn.PReLU()
        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        if self.use_se:
            self.se = SEBlock(planes)

    def forward(self, x):
        residual = x
        out = self.bn0(x)
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.prelu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        if self.use_se:
            out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.prelu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class SEBlock(nn.Module):

    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
            nn.Sigmoid())

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


@ARCH_REGISTRY.register()
class ResNetArcFace(nn.Module):

    def __init__(self, block, layers, use_se=True):
        if block == 'IRBlock':
            block = IRBlock
        self.inplanes = 64
        self.use_se = use_se
        super(ResNetArcFace, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.bn4 = nn.BatchNorm2d(512)
        self.dropout = nn.Dropout()
        self.fc5 = nn.Linear(512 * 8 * 8, 512)
        self.bn5 = nn.BatchNorm1d(512)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, use_se=self.use_se))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.bn4(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = self.fc5(x)
        x = self.bn5(x)

        return x


================================================
FILE: RestoreFormer/modules/vqvae/facial_component_discriminator.py
================================================
import math
import random
import torch
from torch import nn
from torch.nn import functional as F

from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
                                          StyleGAN2Generator)
from basicsr.ops.fused_act import FusedLeakyReLU
from basicsr.utils.registry import ARCH_REGISTRY


@ARCH_REGISTRY.register()
class FacialComponentDiscriminator(nn.Module):

    def __init__(self):
        super(FacialComponentDiscriminator, self).__init__()

        self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
        self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
        self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
        self.conv4 = ConvLayer(128, 256, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
        self.conv5 = ConvLayer(256, 256, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
        self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)

    def forward(self, x, return_feats=False):
        feat = self.conv1(x)
        feat = self.conv3(self.conv2(feat))
        rlt_feats = []
        if return_feats:
            rlt_feats.append(feat.clone())
        feat = self.conv5(self.conv4(feat))
        if return_feats:
            rlt_feats.append(feat.clone())
        out = self.final_conv(feat)

        if return_feats:
            return out, rlt_feats
        else:
            return out, None


================================================
FILE: RestoreFormer/modules/vqvae/utils.py
================================================
from torchvision.ops import roi_align
import torch

def get_roi_regions(gt, output, loc_left_eyes, loc_right_eyes, loc_mouths,
                    face_ratio=1, eye_out_size=80, mouth_out_size=120):
    # hard code
    eye_out_size *= face_ratio
    mouth_out_size *= face_ratio

    eye_out_size = int(eye_out_size)
    mouth_out_size = int(mouth_out_size)

    rois_eyes = []
    rois_mouths = []
    for b in range(loc_left_eyes.size(0)):  # loop for batch size
        # left eye and right eye
        img_inds = loc_left_eyes.new_full((2, 1), b)
        bbox = torch.stack([loc_left_eyes[b, :], loc_right_eyes[b, :]], dim=0)  # shape: (2, 4)
        rois = torch.cat([img_inds, bbox], dim=-1)  # shape: (2, 5)
        rois_eyes.append(rois)
        # mouse
        img_inds = loc_left_eyes.new_full((1, 1), b)
        rois = torch.cat([img_inds, loc_mouths[b:b + 1, :]], dim=-1)  # shape: (1, 5)
        rois_mouths.append(rois)

    rois_eyes = torch.cat(rois_eyes, 0)
    rois_mouths = torch.cat(rois_mouths, 0)

    # real images
    all_eyes = roi_align(gt, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
    left_eyes_gt = all_eyes[0::2, :, :, :]
    right_eyes_gt = all_eyes[1::2, :, :, :]
    mouths_gt = roi_align(gt, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
    # output
    all_eyes = roi_align(output, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
    left_eyes = all_eyes[0::2, :, :, :]
    right_eyes = all_eyes[1::2, :, :, :]
    mouths = roi_align(output, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio

    return {'left_eyes_gt': left_eyes_gt, 'right_eyes_gt': right_eyes_gt, 'mouths_gt': mouths_gt, 
            'left_eyes': left_eyes, 'right_eyes': right_eyes, 'mouths': mouths}


================================================
FILE: RestoreFormer/modules/vqvae/vqvae_arch.py
================================================
import torch
import torch.nn as nn
import random
import math
import torch.nn.functional as F
import numpy as np
# from basicsr.utils.registry import ARCH_REGISTRY
import torch.nn.utils.spectral_norm as SpectralNorm
import RestoreFormer.distributed as dist_fn

class VectorQuantizer(nn.Module):
    """
    see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
    ____________________________________________
    Discretization bottleneck part of the VQ-VAE.
    Inputs:
    - n_e : number of embeddings
    - e_dim : dimension of embedding
    - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
    _____________________________________________
    """

    def __init__(self, n_e, e_dim, beta):
        super(VectorQuantizer, self).__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

    def forward(self, z):
        """
        Inputs the output of the encoder network z and maps it to a discrete
        one-hot vector that is the index of the closest embedding vector e_j
        z (continuous) -> z_q (discrete)
        z.shape = (batch, channel, height, width)
        quantization pipeline:
            1. get encoder input (B,C,H,W)
            2. flatten input to (B*H*W,C)
        """
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - 2 * \
            torch.matmul(z_flattened, self.embedding.weight.t())

        ## could possible replace this here
        # #\start...
        # find closest encodings

        min_value, min_encoding_indices = torch.min(d, dim=1)

        min_encoding_indices = min_encoding_indices.unsqueeze(1)

        min_encodings = torch.zeros(
            min_encoding_indices.shape[0], self.n_e).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # dtype min encodings: torch.float32
        # min_encodings shape: torch.Size([2048, 512])
        # min_encoding_indices.shape: torch.Size([2048, 1])

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
        #.........\end

        # with:
        # .........\start
        #min_encoding_indices = torch.argmin(d, dim=1)
        #z_q = self.embedding(min_encoding_indices)
        # ......\end......... (TODO)

        # compute loss for embedding
        loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
            torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # perplexity
        
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices, d)

    def get_codebook_entry(self, indices, shape):
        # shape specifying (batch, height, width, channel)
        # TODO: check for more easy handling with nn.Embedding
        min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
        min_encodings.scatter_(1, indices[:,None], 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)

        if shape is not None:
            z_q = z_q.view(shape)

            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q

# pytorch_diffusion + derived encoder decoder
def nonlinearity(x):
    # swish
    return x*torch.sigmoid(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=2,
                                        padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0,1,0,1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels,
                                             out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x+h


class MultiHeadAttnBlock(nn.Module):
    def __init__(self, in_channels, head_size=1):
        super().__init__()
        self.in_channels = in_channels
        self.head_size = head_size
        self.att_size = in_channels // head_size
        assert(in_channels % head_size == 0), 'The size of head should be divided by the number of channels.'

        self.norm1 = Normalize(in_channels)
        self.norm2 = Normalize(in_channels)

        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)
        self.num = 0

    def forward(self, x, y=None):
        h_ = x
        h_ = self.norm1(h_)
        if y is None:
            y = h_
        else:
            y = self.norm2(y)

        q = self.q(y)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b,c,h,w = q.shape
        q = q.reshape(b, self.head_size, self.att_size ,h*w) 
        q = q.permute(0, 3, 1, 2) # b, hw, head, att

        k = k.reshape(b, self.head_size, self.att_size ,h*w) 
        k = k.permute(0, 3, 1, 2)

        v = v.reshape(b, self.head_size, self.att_size ,h*w) 
        v = v.permute(0, 3, 1, 2)


        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        k = k.transpose(1, 2).transpose(2,3)

        scale = int(self.att_size)**(-0.5)
        q.mul_(scale)
        w_ = torch.matmul(q, k)
        w_ = F.softmax(w_, dim=3)

        w_ = w_.matmul(v)

        w_ = w_.transpose(1, 2).contiguous() # [b, h*w, head, att]
        w_ = w_.view(b, h, w, -1)
        w_ = w_.permute(0, 3, 1, 2)

        w_ = self.proj_out(w_)

        return x+w_


class MultiHeadEncoder(nn.Module):
    def __init__(self, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks=2,
                 attn_resolutions=[16], dropout=0.0, resamp_with_conv=True, in_channels=3,
                 resolution=512, z_channels=256, double_z=True, enable_mid=True,
                 head_size=1, **ignore_kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.enable_mid = enable_mid

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(MultiHeadAttnBlock(block_in, head_size))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        if self.enable_mid:
            self.mid = nn.Module()
            self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                           out_channels=block_in,
                                           temb_channels=self.temb_ch,
                                           dropout=dropout)
            self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
            self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                           out_channels=block_in,
                                           temb_channels=self.temb_ch,
                                           dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)


    def forward(self, x):
        #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)

        hs = {}
        # timestep embedding
        temb = None

        # downsampling
        h = self.conv_in(x)
        hs['in'] = h
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](h, temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)

            if i_level != self.num_resolutions-1:
                # hs.append(h)
                hs['block_'+str(i_level)] = h
                h = self.down[i_level].downsample(h)

        # middle
        # h = hs[-1]
        if self.enable_mid:
            h = self.mid.block_1(h, temb)
            hs['block_'+str(i_level)+'_atten'] = h
            h = self.mid.attn_1(h)
            h = self.mid.block_2(h, temb)
            hs['mid_atten'] = h

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        # hs.append(h)
        hs['out'] = h

        return hs

class MultiHeadDecoder(nn.Module):
    def __init__(self, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks=2,
                 attn_resolutions=16, dropout=0.0, resamp_with_conv=True, in_channels=3,
                 resolution=512, z_channels=256, give_pre_end=False, enable_mid=True,
                 head_size=1, **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.enable_mid = enable_mid

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1,z_channels,curr_res,curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        if self.enable_mid:
            self.mid = nn.Module()
            self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                           out_channels=block_in,
                                           temb_channels=self.temb_ch,
                                           dropout=dropout)
            self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
            self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                           out_channels=block_in,
                                           temb_channels=self.temb_ch,
                                           dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(MultiHeadAttnBlock(block_in, head_size))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, z):
        #assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        if self.enable_mid:
            h = self.mid.block_1(h, temb)
            h = self.mid.attn_1(h)
            h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

class MultiHeadDecoderTransformer(nn.Module):
    def __init__(self, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks=2,
                 attn_resolutions=16, dropout=0.0, resamp_with_conv=True, in_channels=3,
                 resolution=512, z_channels=256, give_pre_end=False, enable_mid=True,
                 head_size=1, **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.enable_mid = enable_mid

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1,z_channels,curr_res,curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        if self.enable_mid:
            self.mid = nn.Module()
            self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                           out_channels=block_in,
                                           temb_channels=self.temb_ch,
                                           dropout=dropout)
            self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
            self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                           out_channels=block_in,
                                           temb_channels=self.temb_ch,
                                           dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(MultiHeadAttnBlock(block_in, head_size))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, z, hs):
        #assert z.shape[1:] == self.z_shape[1:]
        # self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        if self.enable_mid:
            h = self.mid.block_1(h, temb)
            h = self.mid.attn_1(h, hs['mid_atten'])
            h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h, hs['block_'+str(i_level)+'_atten'])
                    # hfeature = h.clone()
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class VQVAEGAN(nn.Module):
    def __init__(self, n_embed=1024, embed_dim=256, ch=128, out_ch=3, ch_mult=(1,2,4,8), 
                 num_res_blocks=2, attn_resolutions=16, dropout=0.0, in_channels=3, 
                 resolution=512, z_channels=256, double_z=False, enable_mid=True, 
                 fix_decoder=False, fix_codebook=False, head_size=1, **ignore_kwargs):
        super(VQVAEGAN, self).__init__()

        self.encoder = MultiHeadEncoder(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,
                               attn_resolutions=attn_resolutions, dropout=dropout, in_channels=in_channels,
                               resolution=resolution, z_channels=z_channels, double_z=double_z, 
                               enable_mid=enable_mid, head_size=head_size)
        self.decoder = MultiHeadDecoder(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,
                               attn_resolutions=attn_resolutions, dropout=dropout, in_channels=in_channels,
                               resolution=resolution, z_channels=z_channels, enable_mid=enable_mid, head_size=head_size)

        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)

        self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

        if fix_decoder:
            for _, param in self.decoder.named_parameters():
                param.requires_grad = False
            for _, param in self.post_quant_conv.named_parameters():
                param.requires_grad = False
            for _, param in self.quantize.named_parameters():
                param.requires_grad = False
        elif fix_codebook:
            for _, param in self.quantize.named_parameters():
                param.requires_grad = False

    def encode(self, x):

        hs = self.encoder(x)
        h = self.quant_conv(hs['out'])
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info, hs

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)

        return dec

    def forward(self, input):
        quant, diff, info, hs = self.encode(input)
        dec = self.decode(quant)

        return dec, diff, info, hs

class VQVAEGANMultiHeadTransformer(nn.Module):
    def __init__(self, n_embed=1024, embed_dim=256, ch=128, out_ch=3, ch_mult=(1,2,4,8), 
                 num_res_blocks=2, attn_resolutions=16, dropout=0.0, in_channels=3, 
                 resolution=512, z_channels=256, double_z=False, enable_mid=True, 
                 fix_decoder=False, fix_codebook=False, fix_encoder=False, constrastive_learning_loss_weight=0.0,
                 head_size=1):
        super(VQVAEGANMultiHeadTransformer, self).__init__()

        self.encoder = MultiHeadEncoder(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,
                               attn_resolutions=attn_resolutions, dropout=dropout, in_channels=in_channels,
                               resolution=resolution, z_channels=z_channels, double_z=double_z, 
                               enable_mid=enable_mid, head_size=head_size)
        self.decoder = MultiHeadDecoderTransformer(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,
                               attn_resolutions=attn_resolutions, dropout=dropout, in_channels=in_channels,
                               resolution=resolution, z_channels=z_channels, enable_mid=enable_mid, head_size=head_size)

        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)

        self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

        if fix_decoder:
            for _, param in self.decoder.named_parameters():
                param.requires_grad = False
            for _, param in self.post_quant_conv.named_parameters():
                param.requires_grad = False
            for _, param in self.quantize.named_parameters():
                param.requires_grad = False
        elif fix_codebook:
            for _, param in self.quantize.named_parameters():
                param.requires_grad = False

        if fix_encoder:
            for _, param in self.encoder.named_parameters():
                param.requires_grad = False

    def encode(self, x):
        
        hs = self.encoder(x)
        h = self.quant_conv(hs['out'])
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info, hs

    def decode(self, quant, hs):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant, hs)

        return dec

    def forward(self, input):
        quant, diff, info, hs = self.encode(input)
        dec = self.decode(quant, hs)

        return dec, diff, info, hs

================================================
FILE: RestoreFormer/util.py
================================================
import os, hashlib
import requests
from tqdm import tqdm

URL_MAP = {
    "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}

CKPT_MAP = {
    "vgg_lpips": "vgg.pth"
}

MD5_MAP = {
    "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}


def download(url, local_path, chunk_size=1024):
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        pbar.update(chunk_size)


def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()


def get_ckpt_path(name, root, check=False):
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path


class KeyNotFoundError(Exception):
    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited
        messages = list()
        if keys is not None:
            messages.append("Key not found: {}".format(keys))
        if visited is not None:
            messages.append("Visited: {}".format(visited))
        messages.append("Cause:\n{}".format(cause))
        message = "\n".join(messages)
        super().__init__(message)


def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place.

    Parameters
    ----------
        list_or_dict : list or dict
            Possibly nested list or dictionary.
        key : str
            key/to/value, path like string describing all keys necessary to
            consider to get to the desired value. List indices can also be
            passed here.
        splitval : str
            String that defines the delimiter between keys of the
            different depth levels in `key`.
        default : obj
            Value returned if :attr:`key` is not found.
        expand : bool
            Whether to expand callable nodes on the path or not.

    Returns
    -------
        The desired value or if :attr:`default` is not ``None`` and the
        :attr:`key` is not found returns ``default``.

    Raises
    ------
        Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
        ``None``.
    """

    keys = key.split(splitval)

    success = True
    try:
        visited = []
        parent = None
        last_key = None
        for key in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                list_or_dict = list_or_dict()
                parent[last_key] = list_or_dict

            last_key = key
            parent = list_or_dict

            try:
                if isinstance(list_or_dict, dict):
                    list_or_dict = list_or_dict[key]
                else:
                    list_or_dict = list_or_dict[int(key)]
            except (KeyError, IndexError, ValueError) as e:
                raise KeyNotFoundError(e, keys=keys, visited=visited)

            visited += [key]
        # final expansion of retrieved value
        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            parent[last_key] = list_or_dict
    except KeyNotFoundError as e:
        if default is None:
            raise e
        else:
            list_or_dict = default
            success = False

    if not pass_success:
        return list_or_dict
    else:
        return list_or_dict, success


if __name__ == "__main__":
    config = {"keya": "a",
              "keyb": "b",
              "keyc":
                  {"cc1": 1,
                   "cc2": 2,
                   }
              }
    from omegaconf import OmegaConf
    config = OmegaConf.create(config)
    print(config)
    retrieve(config, "keya")



================================================
FILE: __init__.py
================================================


================================================
FILE: configs/HQ_Dictionary.yaml
================================================
model:
  base_learning_rate: 4.5e-6
  target: RestoreFormer.models.vqgan_v1.RestoreFormerModel
  params:
    image_key: 'gt'
    schedule_step: [400000, 800000]
    # ignore_keys: ['vqvae.quantize.utility_counter']
    ddconfig:
      target: RestoreFormer.modules.vqvae.vqvae_arch.VQVAEGAN
      params:
        embed_dim: 256
        n_embed: 1024
        double_z: False
        z_channels: 256
        resolution: 512
        in_channels: 3
        out_ch: 3
        ch: 64
        ch_mult: [ 1,2,2,4,4,8]  # num_down = len(ch_mult)-1
        num_res_blocks: 2
        attn_resolutions: [16]
        dropout: 0.0
        enable_mid: True
        fix_decoder: False
        fix_codebook: False
        head_size: 8

    lossconfig:
      target: RestoreFormer.modules.losses.vqperceptual.VQLPIPSWithDiscriminatorWithCompWithIdentity
      params:
        disc_conditional: False
        disc_in_channels: 3
        disc_start: 30001
        disc_weight: 0.8
        codebook_weight: 1.0
        use_actnorm: False

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 4
    num_workers: 8
    train:
      target: basicsr.data.ffhq_dataset.FFHQDataset
      params:
        dataroot_gt: data/FFHQ/images512x512
        io_backend:
          type: disk
        use_hflip: True
        mean: [0.5, 0.5, 0.5]
        std: [0.5, 0.5, 0.5]
        out_size: 512
    validation:
      target: basicsr.data.ffhq_dataset.FFHQDataset
      params:
        dataroot_gt: data/FFHQ/images512x512
        io_backend:
          type: disk
        use_hflip: False
        mean: [0.5, 0.5, 0.5]
        std: [0.5, 0.5, 0.5]
        out_size: 512


================================================
FILE: configs/RestoreFormer.yaml
================================================
model:
  base_learning_rate: 4.5e-6 
  target: RestoreFormer.models.vqgan_v1.RestoreFormerModel
  params:
    image_key: 'lq'
    ckpt_path: 'YOUR TRAINED HD DICTIONARY MODEL'
    special_params_lr_scale: 10
    comp_params_lr_scale: 10
    schedule_step: [4000000, 8000000]
    ddconfig:
      target: RestoreFormer.modules.vqvae.vqvae_arch.VQVAEGANMultiHeadTransformer
      params:
        embed_dim: 256
        n_embed: 1024
        double_z: False
        z_channels: 256
        resolution: 512
        in_channels: 3  
        out_ch: 3
        ch: 64
        ch_mult: [ 1,2,2,4,4,8]  # num_down = len(ch_mult)-1
        num_res_blocks: 2
        dropout: 0.0
        attn_resolutions: [16]
        enable_mid: True

        fix_decoder: False
        fix_codebook: True
        fix_encoder: False
        head_size: 8

    lossconfig:
      target: RestoreFormer.modules.losses.vqperceptual.VQLPIPSWithDiscriminatorWithCompWithIdentity
      params:
        disc_conditional: False
        disc_in_channels: 3
        disc_start: 10001
        disc_weight: 0.8
        codebook_weight: 1.0
        use_actnorm: False
        comp_weight: 1.5
        comp_style_weight: 2e3 #2000.0
        identity_weight: 3 #1.5
        lpips_style_weight: 1e9
        identity_model_path: experiments/pretrained_models/arcface_resnet18.pth

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 4
    num_workers: 8
    train:
      target: RestoreFormer.data.ffhq_degradation_dataset.FFHQDegradationDataset
      params:
        dataroot_gt: data/FFHQ/images512x512
        io_backend:
          type: disk
        use_hflip: True
        mean: [0.5, 0.5, 0.5]
        std: [0.5, 0.5, 0.5]
        out_size: 512

        blur_kernel_size: [19,20]
        kernel_list: ['iso', 'aniso']
        kernel_prob: [0.5, 0.5]
        blur_sigma: [0.1, 10]
        downsample_range: [0.8, 8]
        noise_range: [0, 20]
        jpeg_range: [60, 100]

        color_jitter_prob: ~
        color_jitter_shift: 20
        color_jitter_pt_prob: ~
        gray_prob: ~
        gt_gray: True

        crop_components: True
        component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth
        eye_enlarge_ratio: 1.4


    validation:
      target: RestoreFormer.data.ffhq_degradation_dataset.FFHQDegradationDataset
      params:
        dataroot_gt: data/FFHQ/images512x512
        io_backend:
          type: disk
        use_hflip: False
        mean: [0.5, 0.5, 0.5]
        std: [0.5, 0.5, 0.5]
        out_size: 512

        blur_kernel_size: [19,20]
        kernel_list: ['iso', 'aniso']
        kernel_prob: [0.5, 0.5]
        blur_sigma: [0.1, 10]
        downsample_range: [0.8, 8]
        noise_range: [0, 20]
        jpeg_range: [60, 100]

        # color jitter and gray
        color_jitter_prob: ~
        color_jitter_shift: 20
        color_jitter_pt_prob: ~
        gray_prob: ~
        gt_gray: True

        crop_components: False
        component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth
        eye_enlarge_ratio: 1.4


================================================
FILE: main.py
================================================
import argparse, os, sys, datetime, glob, importlib
from omegaconf import OmegaConf
import numpy as np
from PIL import Image
import torch
import torchvision
from torch.utils.data import random_split, DataLoader, Dataset
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
from pytorch_lightning.utilities.distributed import rank_zero_only
import random

def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def get_parser(**parser_kwargs):
    def str2bool(v):
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        elif v.lower() in ("no", "false", "f", "n", "0"):
            return False
        else:
            raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument(
        "-n",
        "--name",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="postfix for logdir",
    )
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="resume from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "--pretrain",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="pretrain with existed weights",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list(),
    )
    parser.add_argument(
        "-t",
        "--train",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="train",
    )
    parser.add_argument(
        "--no-test",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="disable test",
    )
    parser.add_argument("-p", "--project", help="name of new or path to existing project")
    parser.add_argument(
        "-d",
        "--debug",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="enable post-mortem debugging",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=23,
        help="seed for seed_everything",
    )
    parser.add_argument(
        "--random-seed",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="enable post-mortem debugging",
    )
    parser.add_argument(
        "-f",
        "--postfix",
        type=str,
        default="",
        help="post-postfix for default name",
    )

    parser.add_argument(
        "--root-path",
        type=str,
        default="./",
        help="root path for saving checkpoints and logs"
    )
    parser.add_argument(
        "--num-nodes",
        type=int,
        default=1,
        help="number of gpu nodes",
    )
    

    return parser


def nondefault_trainer_args(opt):
    parser = argparse.ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args([])
    return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))


def instantiate_from_config(config):
    if not "target" in config:
        raise KeyError("Expected key `target` to instantiate.")
    if 'basicsr.data' in config["target"] or \
        'FFHQDegradationDataset' in config["target"]:
        return get_obj_from_str(config["target"])(config.get("params", dict()))
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


class WrappedDataset(Dataset):
    """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
    def __init__(self, dataset):
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class DataModuleFromConfig(pl.LightningDataModule):
    def __init__(self, batch_size, train=None, validation=None, test=None,
                 wrap=False, num_workers=None):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size*2
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = self._val_dataloader
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = self._test_dataloader
        self.wrap = wrap

    def prepare_data(self):
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        self.datasets = dict(
            (k, instantiate_from_config(self.dataset_configs[k]))
            for k in self.dataset_configs)
        if self.wrap:
            for k in self.datasets:
                self.datasets[k] = WrappedDataset(self.datasets[k])

    def _train_dataloader(self):
        return DataLoader(self.datasets["train"], batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=True)

    def _val_dataloader(self):
        return DataLoader(self.datasets["validation"],
                          batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def _test_dataloader(self):
        return DataLoader(self.datasets["test"], batch_size=self.batch_size,
                          num_workers=self.num_workers)


class SetupCallback(Callback):
    def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.lightning_config = lightning_config

    def on_pretrain_routine_start(self, trainer, pl_module):
        if trainer.global_rank == 0:
            # import pdb
            # pdb.set_trace()
            # Create logdirs and save configs
            os.makedirs(self.logdir, exist_ok=True)
            os.makedirs(self.ckptdir, exist_ok=True)
            os.makedirs(self.cfgdir, exist_ok=True)

            print("Project config")
            print(self.config.pretty())
            OmegaConf.save(self.config,
                           os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

            print("Lightning config")
            print(self.lightning_config.pretty())
            OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
                           os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))


class ImageLogger(Callback):
    def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True):
        super().__init__()
        self.batch_freq = batch_frequency
        self.max_images = max_images
        self.logger_log_images = {
            pl.loggers.WandbLogger: self._wandb,
            pl.loggers.TestTubeLogger: self._testtube,
        }
        self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp

    @rank_zero_only
    def _wandb(self, pl_module, images, batch_idx, split):
        raise ValueError("No way wandb")
        grids = dict()
        for k in images:
            grid = torchvision.utils.make_grid(images[k])
            grids[f"{split}/{k}"] = wandb.Image(grid)
        pl_module.logger.experiment.log(grids)

    @rank_zero_only
    def _testtube(self, pl_module, images, batch_idx, split):
        for k in images:
            grid = torchvision.utils.make_grid(images[k])
            grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w

            tag = f"{split}/{k}"
            pl_module.logger.experiment.add_image(
                tag, grid,
                global_step=pl_module.global_step)

    @rank_zero_only
    def log_local(self, save_dir, split, images,
                  global_step, current_epoch, batch_idx):
        root = os.path.join(save_dir, "images", split)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)

            grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w
            grid = grid.transpose(0,1).transpose(1,2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid*255).astype(np.uint8)
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
                k,
                global_step,
                current_epoch,
                batch_idx)
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            Image.fromarray(grid).save(path)

    def log_img(self, pl_module, batch, batch_idx, split="train"):
        if (self.check_frequency(batch_idx) and  # batch_idx % self.batch_freq == 0
                hasattr(pl_module, "log_images") and
                callable(pl_module.log_images) and
                self.max_images > 0):
            logger = type(pl_module.logger)

            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            with torch.no_grad():
                images = pl_module.log_images(batch, split=split)

            for k in images:
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().cpu()
                    if self.clamp:
                        images[k] = torch.clamp(images[k], -1., 1.)

            self.log_local(pl_module.logger.save_dir, split, images,
                           pl_module.global_step, pl_module.current_epoch, batch_idx)

            logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
            logger_log_images(pl_module, images, pl_module.global_step, split)

            if is_train:
                pl_module.train()

    def check_frequency(self, batch_idx):
        if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps):
            try:
                self.log_steps.pop(0)
            except IndexError:
                pass
            return True
        return False

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self.log_img(pl_module, batch, batch_idx, split="train")

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self.log_img(pl_module, batch, batch_idx, split="val")



if __name__ == "__main__":
    # custom parser to specify config files, train, test and debug mode,
    # postfix, resume.
    # `--key value` arguments are interpreted as arguments to the trainer.
    # `nested.key=value` arguments are interpreted as config parameters.
    # configs are merged from left-to-right followed by command line parameters.

    # model:
    #   base_learning_rate: float
    #   target: path to lightning module
    #   params:
    #       key: value
    # data:
    #   target: main.DataModuleFromConfig
    #   params:
    #      batch_size: int
    #      wrap: bool
    #      train:
    #          target: path to train dataset
    #          params:
    #              key: value
    #      validation:
    #          target: path to validation dataset
    #          params:
    #              key: value
    #      test:
    #          target: path to test dataset
    #          params:
    #              key: value
    # lightning: (optional, has sane defaults and can be specified on cmdline)
    #   trainer:
    #       additional arguments to trainer
    #   logger:
    #       logger to instantiate
    #   modelcheckpoint:
    #       modelcheckpoint to instantiate
    #   callbacks:
    #       callback1:
    #           target: importpath
    #           params:
    #               key: value
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    # add cwd for convenience and to make classes in this file available when
    # running as `python main.py`
    # (in particular `main.DataModuleFromConfig`)
    sys.path.append(os.getcwd())

    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)

    opt, unknown = parser.parse_known_args()
    if opt.name and opt.resume:
        raise ValueError(
            "-n/--name and -r/--resume cannot be specified both."
            "If you want to resume training in a new log folder, "
            "use -n/--name in combination with --resume_from_checkpoint"
        )
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            idx = len(paths)-paths[::-1].index("logs")+1
            logdir = "/".join(paths[:idx])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")

        opt.resume_from_checkpoint = ckpt
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs+opt.base
        _tmp = logdir.split("/")
        nowname = _tmp[_tmp.index("logs")+1]+opt.postfix
        logdir = os.path.join(opt.root_path, "logs", nowname)
    else:
        if opt.name:
            name = "_"+opt.name
        elif opt.base:
            cfg_fname = os.path.split(opt.base[0])[-1]
            cfg_name = os.path.splitext(cfg_fname)[0]
            name = "_"+cfg_name
        else:
            name = ""
        nowname = now+name+opt.postfix
        logdir = os.path.join(opt.root_path, "logs", nowname)

    if opt.random_seed:
        opt.seed = random.randint(1,100)
    logdir = logdir + '_seed' + str(opt.seed)
    
    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")

    seed_everything(opt.seed)

    try:
        # init and save configs
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)
        lightning_config = config.pop("lightning", OmegaConf.create())
        # merge trainer cli with config
        trainer_config = lightning_config.get("trainer", OmegaConf.create())
        # default to ddp
        # trainer_config["distributed_backend"] = "ddp"
        trainer_config["accelerator"] = "ddp"
        # trainer_config["plugins"]="ddp_sharded"
        for k in nondefault_trainer_args(opt):
            trainer_config[k] = getattr(opt, k)
        if not "gpus" in trainer_config:
            del trainer_config["distributed_backend"]
            cpu = True
        else:
            gpuinfo = trainer_config["gpus"]
            print(f"Running on GPUs {gpuinfo}")
            cpu = False
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config

        # model
        model = instantiate_from_config(config.model)

        # trainer and callbacks
        trainer_kwargs = dict()
        # trainer_kwargs['sync_batchnorm'] = True
        
        # default logger configs
        # NOTE wandb < 0.10.0 interferes with shutdown
        # wandb >= 0.10.0 seems to fix it but still interferes with pudb
        # debugging (wrongly sized pudb ui)
        # thus prefer testtube for now
        default_logger_cfgs = {
            "wandb": {
                "target": "pytorch_lightning.loggers.WandbLogger",
                "params": {
                    "name": nowname,
                    "save_dir": logdir,
                    "offline": opt.debug,
                    "id": nowname,
                }
            },
            "testtube": {
                "target": "pytorch_lightning.loggers.TestTubeLogger",
                "params": {
                    "name": "testtube",
                    "save_dir": logdir,
                }
            },
        }
        default_logger_cfg = default_logger_cfgs["testtube"]
        logger_cfg = lightning_config.logger or OmegaConf.create()
        logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
        trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

        # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
        # specify which metric is used to determine best models
        default_modelckpt_cfg = {
            "target": "pytorch_lightning.callbacks.ModelCheckpoint",
            "params": {
                "dirpath": ckptdir,
                "filename": "{epoch:06}",
                "verbose": True,
                "save_last": True,
                "period": 1
            }
        }
        if hasattr(model, "monitor"):
            print(f"Monitoring {model.monitor} as checkpoint metric.")
            default_modelckpt_cfg["params"]["monitor"] = model.monitor
            default_modelckpt_cfg["params"]["save_top_k"] = 3

        modelckpt_cfg = lightning_config.modelcheckpoint or OmegaConf.create()
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
        trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)

        # add callback which sets up log directory
        default_callbacks_cfg = {
            "setup_callback": {
                "target": "main.SetupCallback",
                "params": {
                    "resume": opt.resume,
                    "now": now,
                    "logdir": logdir,
                    "ckptdir": ckptdir,
                    "cfgdir": cfgdir,
                    "config": config,
                    "lightning_config": lightning_config,
                }
            },
            "image_logger": {
                "target": "main.ImageLogger",
                "params": {
                    "batch_frequency": 750,
                    "max_images": 4,
                    "clamp": True
                }
            },
            "learning_rate_logger": {
                "target": "main.LearningRateMonitor",
                "params": {
                    "logging_interval": "step",
                    #"log_momentum": True
                }
            },
        }
        callbacks_cfg = lightning_config.callbacks or OmegaConf.create()
        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
        trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]

        trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)

        # data
        data = instantiate_from_config(config.data)
        # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
        # calling these ourselves should not be necessary but it is.
        # lightning still takes care of proper multiprocessing though
        data.prepare_data()
        data.setup()

        # configure learning rate
        bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
        if not cpu:
            ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
        else:
            ngpu = 1
        accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches or 1
        print(f"accumulate_grad_batches = {accumulate_grad_batches}")
        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
        model.learning_rate = accumulate_grad_batches * ngpu * bs * trainer.num_nodes * base_lr
        print("Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (num_nodes) * {} (batchsize) * {:.2e} (base_lr)".format(
            model.learning_rate, accumulate_grad_batches, ngpu, trainer.num_nodes, bs, base_lr))

        # allow checkpointing via USR1
        def melk(*args, **kwargs):
            # run all checkpoint hooks
            if trainer.global_rank == 0:
                print("Summoning checkpoint.")
                ckpt_path = os.path.join(ckptdir, "last.ckpt")
                trainer.save_checkpoint(ckpt_path)

        def divein(*args, **kwargs):
            if trainer.global_rank == 0:
                import pudb; pudb.set_trace()

        import signal
        signal.signal(signal.SIGUSR1, melk)
        signal.signal(signal.SIGUSR2, divein)

        # run
        if opt.train:
            try:
                trainer.fit(model, data)
            except Exception:
                melk()
                raise
        if not opt.no_test and not trainer.interrupted:
            trainer.test(model, data)
    except Exception:
        if opt.debug and trainer.global_rank==0:
            try:
                import pudb as debugger
            except ImportError:
                import pdb as debugger
            debugger.post_mortem()
        raise
    finally:
        # move newly created debug project to debug_runs
        if opt.debug and not opt.resume and trainer.global_rank==0:
            dst, name = os.path.split(logdir)
            dst = os.path.join(dst, "debug_runs", name)
            os.makedirs(os.path.split(dst)[0], exist_ok=True)
            os.rename(logdir, dst)


================================================
FILE: restoreformer_requirement.txt
================================================
Package                 Version             Location
----------------------- ------------------- ------------------------------------------------------------------------------
absl-py                 0.13.0
addict                  2.4.0
aiohttp                 3.7.4.post0
albumentations          0.4.3
antlr4-python3-runtime  4.8
astunparse              1.6.3
async-timeout           3.0.1
attrs                   21.2.0
basicsr                 1.3.3.4
cached-property         1.5.2
cachetools              4.2.2
certifi                 2021.5.30
chardet                 4.0.0
cycler                  0.10.0
dlib                    19.22.99
facexlib                0.1.3.1
flatbuffers             1.12
fsspec                  2021.6.1
future                  0.18.2
gast                    0.4.0
google-auth             1.32.1
google-auth-oauthlib    0.4.4
google-pasta            0.2.0
grpcio                  1.39.0
h5py                    3.1.0
idna                    2.10
imageio                 2.9.0
imgaug                  0.2.6
importlib-metadata      4.6.1
joblib                  1.0.1
keras-nightly           2.7.0.dev2021072800
Keras-Preprocessing     1.1.2
kiwisolver              1.3.1
libclang                11.1.0
lmdb                    1.2.1
Markdown                3.3.4
matplotlib              3.4.2
mkl-fft                 1.3.0
mkl-random              1.2.1
mkl-service             2.3.0
multidict               5.1.0
networkx                2.6.1
numpy                   1.19.5
oauthlib                3.1.1
olefile                 0.46
omegaconf               2.0.0
opencv-python           4.5.2.54
opt-einsum              3.3.0
packaging               21.0
pandas                  1.3.0
Pillow                  8.3.1
pip                     21.1.3
protobuf                3.17.3
pyasn1                  0.4.8
pyasn1-modules          0.2.8
pyDeprecate             0.3.0
pyparsing               2.4.7
python-dateutil         2.8.1
pytorch-lightning       1.0.8
pytz                    2021.1
PyWavelets              1.1.1
PyYAML                  5.4.1
requests                2.25.1
requests-oauthlib       1.3.0
rsa                     4.7.2
scikit-image            0.18.2
scikit-learn            0.24.2
scipy                   1.7.0
setuptools              52.0.0.post20210125
six                     1.15.0
sklearn                 0.0
tb-nightly              2.6.0a20210728
tensorboard-data-server 0.6.1
tensorboard-plugin-wit  1.8.0
termcolor               1.1.0
test-tube               0.7.5
tf-estimator-nightly    2.7.0.dev2021072801
tf-nightly              2.7.0.dev20210728
threadpoolctl           2.2.0
tifffile                2021.7.2
torch                   1.7.1
torchaudio              0.7.0a0+a853dff
torchmetrics            0.4.1
torchvision             0.8.2
tqdm                    4.61.2
typing-extensions       3.7.4.3
urllib3                 1.26.6
Werkzeug                2.0.1
wheel                   0.36.2
wrapt                   1.12.1
yapf                    0.31.0
yarl                    1.6.3
zipp                    3.5.0


================================================
FILE: scripts/metrics/cal_fid.py
================================================
import os, sys
import argparse
import math
import numpy as np
import torch
from torch.utils.data import DataLoader

from basicsr.data import build_dataset
from basicsr.metrics.fid import calculate_fid, extract_inception_features, load_patched_inception_v3


def calculate_fid_folder():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    parser = argparse.ArgumentParser()
    parser.add_argument('folder', type=str, help='Path to the folder.')
    parser.add_argument('--fid_stats', type=str, help='Path to the dataset fid statistics.')
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_sample', type=int, default=50000)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--backend', type=str, default='disk', help='io backend for dataset. Option: disk, lmdb')
    parser.add_argument('--save_name', type=str, default='fid', help='File name for saving results')
    args = parser.parse_args()

    # inception model
    inception = load_patched_inception_v3(device)

    # create dataset
    opt = {}
    opt['name'] = 'SingleImageDataset'
    opt['type'] = 'SingleImageDataset'
    opt['dataroot_lq'] = args.folder
    opt['io_backend'] = dict(type=args.backend)
    opt['mean'] = [0.5, 0.5, 0.5]
    opt['std'] = [0.5, 0.5, 0.5]
    dataset = build_dataset(opt)

    # create dataloader
    data_loader = DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        sampler=None,
        drop_last=False)
    args.num_sample = min(args.num_sample, len(dataset))
    total_batch = math.ceil(args.num_sample / args.batch_size)

    def data_generator(data_loader, total_batch):
        for idx, data in enumerate(data_loader):
            if idx >= total_batch:
                break
            else:
                yield data['lq']

    features = extract_inception_features(data_generator(data_loader, total_batch), inception, total_batch, device)
    features = features.numpy()
    total_len = features.shape[0]
    features = features[:args.num_sample]
    # print(f'Extracted {total_len} features, ' f'use the first {features.shape[0]} features to calculate stats.')

    sample_mean = np.mean(features, 0)
    sample_cov = np.cov(features, rowvar=False)

    # load the dataset stats
    stats = torch.load(args.fid_stats)
    real_mean = stats['mean']
    real_cov = stats['cov']

    # calculate FID metric
    fid = calculate_fid(sample_mean, sample_cov, real_mean, real_cov)

    fout=open(args.save_name, 'w')
    fout.write(str(fid)+'\n')
    fout.close()

    print(args.folder)
    print('fid:', fid)


if __name__ == '__main__':
    calculate_fid_folder()


================================================
FILE: scripts/metrics/cal_identity_distance.py
================================================
import os, sys
import torch
import argparse
import cv2
import numpy as np
import glob
import pdb
import tqdm
from copy import deepcopy
import torch.nn.functional as F
import math


root_path = os.path.abspath(os.path.join(__file__, os.path.pardir, os.path.pardir, os.path.pardir))
sys.path.append(root_path)
sys.path.append(os.path.join(root_path, 'RestoreFormer/modules/losses'))

from RestoreFormer.modules.vqvae.arcface_arch import ResNetArcFace
from basicsr.losses.losses import L1Loss, MSELoss

def cosine_similarity(emb1, emb2):
    return np.arccos(np.dot(emb1, emb2) / ( np.linalg.norm(emb1) * np.linalg.norm(emb2)))


def gray_resize_for_identity(out, size=128):
    out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
    out_gray = out_gray.unsqueeze(1)
    out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
    return out_gray

def calculate_identity_distance_folder():
    parser = argparse.ArgumentParser()

    parser.add_argument('folder', type=str, help='Path to the folder')
    parser.add_argument('--gt_folder', type=str, help='Path to the GT')
    parser.add_argument('--save_name', type=str, default='niqe', help='File name for saving results')
    parser.add_argument('--need_post', type=int, default=0, help='0: the name of image does not include 00, 1: otherwise')

    args = parser.parse_args()

    fout = open(args.save_name, 'w')

    identity = ResNetArcFace(block = 'IRBlock', 
                                  layers = [2, 2, 2, 2],
                                  use_se = False)
    identity_model_path = 'experiments/pretrained_models/arcface_resnet18.pth'
    
    sd = torch.load(identity_model_path, map_location="cpu")
    for k, v in deepcopy(sd).items():
        if k.startswith('module.'):
            sd[k[7:]] = v
            sd.pop(k)
    identity.load_state_dict(sd, strict=True)
    identity.eval()

    for param in identity.parameters():
        param.requires_grad = False

    identity = identity.cuda()

    gt_names = glob.glob(os.path.join(args.gt_folder, '*'))
    gt_names.sort()
    
    mean_dist = 0.
    for i in tqdm.tqdm(range(len(gt_names))):
        gt_name = gt_names[i].split('/')[-1][:-4]
        if args.need_post:
            img_name = os.path.join(args.folder,gt_name + '_00.png')
        else:
            img_name = os.path.join(args.folder,gt_name + '.png')
        if not os.path.exists(img_name):
            print(img_name, 'does not exist')
            continue

        img = cv2.imread(img_name)
        gt = cv2.imread(gt_names[i])

        img = img.astype(np.float32) / 255.
        img = torch.FloatTensor(img).cuda()
        img = img.permute(2,0,1)
        img = img.unsqueeze(0)

        gt = gt.astype(np.float32) / 255.
        gt = torch.FloatTensor(gt).cuda()
        gt = gt.permute(2,0,1)
        gt = gt.unsqueeze(0)

        out_gray = gray_resize_for_identity(img)
        gt_gray = gray_resize_for_identity(gt)

        with torch.no_grad():
            identity_gt = identity(gt_gray)
            identity_out = identity(out_gray)

        identity_gt = identity_gt.cpu().data.numpy().squeeze()
        identity_out = identity_out.cpu().data.numpy().squeeze()
        identity_loss = cosine_similarity(identity_gt, identity_out)

        fout.write(gt_name + ' ' + str(identity_loss) + '\n')
        mean_dist += identity_loss

    fout.write('Mean: ' + str(mean_dist / len(gt_names)) + '\n')
    fout.close()
    print('mean_dist:', mean_dist / len(gt_names))

if __name__ == '__main__':
    calculate_identity_distance_folder()

================================================
FILE: scripts/metrics/cal_psnr_ssim.py
================================================
import os, sys
import argparse
import cv2
import numpy as np
import glob
import pdb
import tqdm
import torch

from basicsr.metrics.psnr_ssim import calculate_psnr, calculate_ssim

root_path = os.path.abspath(os.path.join(__file__, os.path.pardir, os.path.pardir, os.path.pardir))
sys.path.append(root_path)
sys.path.append(os.path.join(root_path, 'RestoreFormer/modules/losses'))

from lpips import LPIPS

def calculate_psnr_ssim_lpips_folder():
    parser = argparse.ArgumentParser()

    parser.add_argument('folder', type=str, help='Path to the folder')
    parser.add_argument('--gt_folder', type=str, help='Path to the GT')
    parser.add_argument('--save_name', type=str, default='niqe', help='File name for saving results')
    parser.add_argument('--need_post', type=int, default=0, help='0: the name of image does not include 00, 1: otherwise')

    args = parser.parse_args()

    fout = open(args.save_name, 'w')
    fout.write('NAME\tPSNR\tSSIM\tLPIPS\n')

    H, W = 512, 512

    gt_names = glob.glob(os.path.join(args.gt_folder, '*'))
    gt_names.sort()

    perceptual_loss = LPIPS().eval().cuda()

    mean_psnr = 0.
    mean_ssim = 0.
    mean_lpips = 0.
    mean_norm_lpips = 0.

    for i in tqdm.tqdm(range(len(gt_names))):
        gt_name = gt_names[i].split('/')[-1][:-4]

        if args.need_post:
            img_name = os.path.join(args.folder,gt_name + '_00.png')
        else:
            img_name = os.path.join(args.folder,gt_name + '.png')

        if not os.path.exists(img_name):
            print(img_name, 'does not exist')
            continue

        img = cv2.imread(img_name)
        gt = cv2.imread(gt_names[i])

        cur_psnr = calculate_psnr(img, gt, 0)
        cur_ssim = calculate_ssim(img, gt, 0)

        # lpips:
        img = img.astype(np.float32) / 255.
        img = torch.FloatTensor(img).cuda()
        img = img.permute(2,0,1)
        img = img.unsqueeze(0)

        gt = gt.astype(np.float32) / 255.
        gt = torch.FloatTensor(gt).cuda()
        gt = gt.permute(2,0,1)
        gt = gt.unsqueeze(0)

        cur_lpips = perceptual_loss(img, gt)
        cur_lpips = cur_lpips[0].item()

        img = (img - 0.5) / 0.5
        gt = (gt - 0.5) / 0.5

        norm_lpips = perceptual_loss(img, gt)
        norm_lpips = norm_lpips[0].item()

        # print(cur_psnr, cur_ssim, cur_lpips, norm_lpips)

        fout.write(gt_name + '\t' + str(cur_psnr) + '\t' + str(cur_ssim) + '\t' + str(cur_lpips) + '\t' + str(norm_lpips) + '\n')

        mean_psnr += cur_psnr
        mean_ssim += cur_ssim
        mean_lpips += cur_lpips
        mean_norm_lpips += norm_lpips

    mean_psnr /= float(len(gt_names))
    mean_ssim /= float(len(gt_names))
    mean_lpips /= float(len(gt_names))
    mean_norm_lpips /= float(len(gt_names))

    fout.write(str(mean_psnr) + '\t' + str(mean_ssim) + '\t' + str(mean_lpips) + '\t' + str(mean_norm_lpips) + '\n')
    fout.close()

    print('psnr, ssim, lpips, norm_lpips:', mean_psnr, mean_ssim, mean_lpips, mean_norm_lpips)

if __name__ == '__main__':
    calculate_psnr_ssim_lpips_folder()

================================================
FILE: scripts/metrics/run.sh
================================================

### Journal ###
root='results/'
out_root='results/metrics'

test_name='RestoreFormer'

test_image=$test_name'/restored_faces'
out_name=$test_name
need_post=1

CelebAHQ_GT='YOUR_PATH'

# FID
python -u scripts/metrics/cal_fid.py \
$root'/'$test_image \
--fid_stats 'experiments/pretrained_models/inception_FFHQ_512-f7b384ab.pth' \
--save_name $out_root'/'$out_name'_fid.txt' \

if [ -d $CelebAHQ_GT ]
then
    # PSRN SSIM LPIPS
    python -u scripts/metrics/cal_psnr_ssim.py \
    $root'/'$test_image \
    --gt_folder $CelebAHQ_GT \
    --save_name $out_root'/'$out_name'_psnr_ssim_lpips.txt' \
    --need_post $need_post \

    # # # PSRN SSIM LPIPS
    python -u scripts/metrics/cal_identity_distance.py  \
    $root'/'$test_image \
    --gt_folder $CelebAHQ_GT \
    --save_name $out_root'/'$out_name'_id.txt' \
    --need_post $need_post
else
    echo 'The path of GT does not exist'
fi

================================================
FILE: scripts/run.sh
================================================
export BASICSR_JIT=True

conf_name='HQ_Dictionary'
# conf_name='RestoreFormer'

ROOT_PATH='' # The path for saving model and logs

gpus='0,1,2,3'

#P: pretrain SL: soft learning
node_n=1

python -u main.py \
--root-path $ROOT_PATH \
--base 'configs/'$conf_name'.yaml' \
-t True \
--postfix $conf_name \
--gpus $gpus \
--num-nodes $node_n \
--random-seed True \


================================================
FILE: scripts/test.py
================================================
import argparse, os, sys, glob, math, time
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
import pdb

sys.path.append(os.getcwd())

from main import instantiate_from_config, DataModuleFromConfig
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from tqdm import trange, tqdm

import cv2
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from torchvision.transforms.functional import normalize

from basicsr.utils import img2tensor, imwrite, tensor2img


def restoration(model,
                face_helper,
                img_path,
                save_root,
                has_aligned=False,
                only_center_face=True,
                suffix=None,
                paste_back=False):
    # read image
    img_name = os.path.basename(img_path)
    # print(f'Processing {img_name} ...')
    basename, _ = os.path.splitext(img_name)
    input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    face_helper.clean_all()

    if has_aligned:
        input_img = cv2.resize(input_img, (512, 512))
        face_helper.cropped_faces = [input_img]
    else:
        face_helper.read_image(input_img)
        # get face landmarks for each face
        face_helper.get_face_landmarks_5(only_center_face=only_center_face, pad_blur=False)
        # align and warp each face
        save_crop_path = os.path.join(save_root, 'cropped_faces', img_name)
        face_helper.align_warp_face(save_crop_path)

    # face restoration
    for idx, cropped_face in enumerate(face_helper.cropped_faces):
        # prepare data
        cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
        normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
        cropped_face_t = cropped_face_t.unsqueeze(0).to('cuda')

        try:
            with torch.no_grad():
                output = model(cropped_face_t)
                restored_face = tensor2img(output[0].squeeze(0), rgb2bgr=True, min_max=(-1, 1))
        except RuntimeError as error:
            print(f'\tFailed inference for GFPGAN: {error}.')
            restored_face = cropped_face

        restored_face = restored_face.astype('uint8')
        face_helper.add_restored_face(restored_face)

        if suffix is not None:
            save_face_name = f'{basename}_{idx:02d}_{suffix}.png'
        else:
            save_face_name = f'{basename}_{idx:02d}.png'
        save_restore_path = os.path.join(save_root, 'restored_faces', save_face_name)
        imwrite(restored_face, save_restore_path)


    if not has_aligned and paste_back:
        face_helper.get_inverse_affine(None)
        save_restore_path = os.path.join(save_root, 'restored_imgs', img_name)
        # paste each restored face to the input image
        face_helper.paste_faces_to_input_image(save_restore_path)

def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        nargs="?",
        help="load from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list(),
    )
    parser.add_argument(
        "-c",
        "--config",
        nargs="?",
        metavar="single_config.yaml",
        help="path to single config. If specified, base configs will be ignored "
        "(except for the last one if left unspecified).",
        const=True,
        default="",
    )
    parser.add_argument(
        "--ignore_base_data",
        action="store_true",
        help="Ignore data specification from base configs. Useful if you want "
        "to specify a custom datasets on the command line.",
    )
    parser.add_argument(
        "--outdir",
        required=True,
        type=str,
        help="Where to write outputs to.",
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=100,
        help="Sample from among top-k predictions.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="Sampling temperature.",
    )
    parser.add_argument('--upscale_factor', type=int, default=1)
    parser.add_argument('--test_path', type=str, default='inputs/whole_imgs')
    parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
    parser.add_argument('--only_center_face', action='store_true')
    parser.add_argument('--aligned', action='store_true')
    parser.add_argument('--paste_back', action='store_true')

    return parser


def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    if "ckpt_path" in config.params:
        print("Deleting the restore-ckpt path from the config...")
        config.params.ckpt_path = None
    if "downsample_cond_size" in config.params:
        print("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
        config.params.downsample_cond_size = -1
        config.params["downsample_cond_factor"] = 0.5
    try:
        if "ckpt_path" in config.params.first_stage_config.params:
            config.params.first_stage_config.params.ckpt_path = None
            print("Deleting the first-stage restore-ckpt path from the config...")
        if "ckpt_path" in config.params.cond_stage_config.params:
            config.params.cond_stage_config.params.ckpt_path = None
            print("Deleting the cond-stage restore-ckpt path from the config...")
    except:
        pass

    model = instantiate_from_config(config)
    if sd is not None:
        keys = list(sd.keys())

        state_dict = model.state_dict()
        require_keys = state_dict.keys()
        keys = sd.keys()
        un_pretrained_keys = []
        for k in require_keys:
            if k not in keys: 
                # miss 'vqvae.'
                if k[6:] in keys:
                    state_dict[k] = sd[k[6:]]
                else:
                    un_pretrained_keys.append(k)
            else:
                state_dict[k] = sd[k]

        # print(f'*************************************************')
        # print(f"Layers without pretraining: {un_pretrained_keys}")
        # print(f'*************************************************')

        model.load_state_dict(state_dict, strict=True)

    if gpu:
        model.cuda()
    if eval_mode:
        model.eval()
    return {"model": model}


def load_model_and_dset(config, ckpt, gpu, eval_mode):

    # now load the specified checkpoint
    if ckpt:
        pl_sd = torch.load(ckpt, map_location="cpu")
    else:
        pl_sd = {"state_dict": None}

    model = load_model_from_config(config.model,
                                   pl_sd["state_dict"],
                                   gpu=gpu,
                                   eval_mode=eval_mode)["model"]
    return model

if __name__ == "__main__":
    sys.path.append(os.getcwd())

    parser = get_parser()

    opt, unknown = parser.parse_known_args()

    ckpt = None
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            try:
                idx = len(paths)-paths[::-1].index("logs")+1
            except ValueError:
                idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
            logdir = "/".join(paths[:idx])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
        print(f"logdir:{logdir}")
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
        opt.base = base_configs+opt.base

    if opt.config:
        if type(opt.config) == str:
            if not os.path.exists(opt.config):
                raise ValueError("Cannot find {}".format(opt.config))
            if os.path.isfile(opt.config):
                opt.base = [opt.config]
            else:
                opt.base = sorted(glob.glob(os.path.join(opt.config, "*-project.yaml")))
        else:
            opt.base = [opt.base[-1]]

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    if opt.ignore_base_data:
        for config in configs:
            if hasattr(config, "data"): del config["data"]
    config = OmegaConf.merge(*configs, cli)
    
    print(config)
    gpu = True
    eval_mode = True
    show_config = False
    if show_config:
        print(OmegaConf.to_container(config))

    model = load_model_and_dset(config, ckpt, gpu, eval_mode)
    
    outdir = opt.outdir
    os.makedirs(outdir, exist_ok=True)
    print("Writing samples to ", outdir)

    # initialize face helper
    face_helper = FaceRestoreHelper(
        opt.upscale_factor, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png')

    img_list = sorted(glob.glob(os.path.join(opt.test_path, '*')))

    print('Results are in the <{}> folder.'.format(outdir))
    
    for img_path in tqdm(img_list):
        restoration(
                model,
                face_helper,
                img_path,
                outdir,
                has_aligned=opt.aligned,
                only_center_face=opt.only_center_face,
                suffix=opt.suffix,
                paste_back=opt.paste_back)

    print('Test number: ', len(img_list))
    print('Results are in the <{}> folder.'.format(outdir))


================================================
FILE: scripts/test.sh
================================================
# # ### Good
exp_name='RestoreFormer'

root_path='experiments'
out_root_path='results'
align_test_path='data/test'
tag='test'

outdir=$out_root_path'/'$exp_name'_'$tag

if [ ! -d $outdir ];then
    mkdir $outdir
fi

python -u scripts/test.py \
--outdir $outdir \
-r $root_path'/'$exp_name'/last.ckpt' \
-c 'configs/RestoreFormer.yaml' \
--test_path $align_test_path \
--aligned

Download .txt
gitextract_istv2bdf/

├── .gitignore
├── LICENSE
├── README.md
├── RestoreFormer/
│   ├── data/
│   │   └── ffhq_degradation_dataset.py
│   ├── distributed/
│   │   ├── __init__.py
│   │   ├── distributed.py
│   │   └── launch.py
│   ├── models/
│   │   └── vqgan_v1.py
│   ├── modules/
│   │   ├── discriminator/
│   │   │   └── model.py
│   │   ├── losses/
│   │   │   ├── __init__.py
│   │   │   ├── lpips.py
│   │   │   └── vqperceptual.py
│   │   ├── util.py
│   │   └── vqvae/
│   │       ├── arcface_arch.py
│   │       ├── facial_component_discriminator.py
│   │       ├── utils.py
│   │       └── vqvae_arch.py
│   └── util.py
├── __init__.py
├── configs/
│   ├── HQ_Dictionary.yaml
│   └── RestoreFormer.yaml
├── main.py
├── restoreformer_requirement.txt
└── scripts/
    ├── metrics/
    │   ├── cal_fid.py
    │   ├── cal_identity_distance.py
    │   ├── cal_psnr_ssim.py
    │   └── run.sh
    ├── run.sh
    ├── test.py
    └── test.sh
Download .txt
SYMBOL INDEX (172 symbols across 18 files)

FILE: RestoreFormer/data/ffhq_degradation_dataset.py
  class FFHQDegradationDataset (line 20) | class FFHQDegradationDataset(data.Dataset):
    method __init__ (line 22) | def __init__(self, opt):
    method color_jitter (line 82) | def color_jitter(img, shift):
    method color_jitter_pt (line 89) | def color_jitter_pt(img, brightness, contrast, saturation, hue):
    method get_component_coordinates (line 109) | def get_component_coordinates(self, index, status):
    method __getitem__ (line 133) | def __getitem__(self, index):
    method __len__ (line 217) | def __len__(self):

FILE: RestoreFormer/distributed/distributed.py
  function is_primary (line 12) | def is_primary():
  function get_rank (line 16) | def get_rank():
  function get_local_rank (line 26) | def get_local_rank():
  function synchronize (line 39) | def synchronize():
  function get_world_size (line 54) | def get_world_size():
  function all_reduce (line 64) | def all_reduce(tensor, op=dist.ReduceOp.SUM):
  function all_gather (line 75) | def all_gather(data):
  function reduce_dict (line 110) | def reduce_dict(input_dict, average=True):
  function data_sampler (line 135) | def data_sampler(dataset, shuffle, distributed):

FILE: RestoreFormer/distributed/launch.py
  function find_free_port (line 10) | def find_free_port():
  function launch (line 22) | def launch(fn, n_gpu_per_machine, n_machine=1, machine_rank=0, dist_url=...
  function distributed_worker (line 52) | def distributed_worker(

FILE: RestoreFormer/models/vqgan_v1.py
  class RestoreFormerModel (line 8) | class RestoreFormerModel(pl.LightningModule):
    method __init__ (line 9) | def __init__(self,
    method init_from_ckpt (line 43) | def init_from_ckpt(self, path, ignore_keys=list()):
    method forward (line 74) | def forward(self, input):
    method training_step (line 78) | def training_step(self, batch, batch_idx, optimizer_idx):
    method validation_step (line 142) | def validation_step(self, batch, batch_idx):
    method configure_optimizers (line 164) | def configure_optimizers(self):
    method get_last_layer (line 207) | def get_last_layer(self):
    method log_images (line 212) | def log_images(self, batch, **kwargs):

FILE: RestoreFormer/modules/discriminator/model.py
  function weights_init (line 8) | def weights_init(m):
  class NLayerDiscriminator (line 17) | class NLayerDiscriminator(nn.Module):
    method __init__ (line 21) | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
    method forward (line 65) | def forward(self, input):
  class NLayerDiscriminator_v1 (line 69) | class NLayerDiscriminator_v1(nn.Module):
    method __init__ (line 73) | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
    method forward (line 123) | def forward(self, input):

FILE: RestoreFormer/modules/losses/lpips.py
  class LPIPS (line 11) | class LPIPS(nn.Module):
    method __init__ (line 13) | def __init__(self, use_dropout=True, style_weight=0.):
    method load_from_pretrained (line 29) | def load_from_pretrained(self, name="vgg_lpips"):
    method from_pretrained (line 35) | def from_pretrained(cls, name="vgg_lpips"):
    method forward (line 43) | def forward(self, input, target):
    method _gram_mat (line 63) | def _gram_mat(self, x):
  class ScalingLayer (line 79) | class ScalingLayer(nn.Module):
    method __init__ (line 80) | def __init__(self):
    method forward (line 85) | def forward(self, inp):
  class NetLinLayer (line 89) | class NetLinLayer(nn.Module):
    method __init__ (line 91) | def __init__(self, chn_in, chn_out=1, use_dropout=False):
  class vgg16 (line 98) | class vgg16(torch.nn.Module):
    method __init__ (line 99) | def __init__(self, requires_grad=False, pretrained=True):
    method forward (line 122) | def forward(self, X):
  function normalize_tensor (line 138) | def normalize_tensor(x,eps=1e-10):
  function spatial_average (line 143) | def spatial_average(x, keepdim=True):

FILE: RestoreFormer/modules/losses/vqperceptual.py
  class DummyLoss (line 13) | class DummyLoss(nn.Module):
    method __init__ (line 14) | def __init__(self):
  function adopt_weight (line 18) | def adopt_weight(weight, global_step, threshold=0, value=0.):
  function hinge_d_loss (line 24) | def hinge_d_loss(logits_real, logits_fake):
  function vanilla_d_loss (line 31) | def vanilla_d_loss(logits_real, logits_fake):
  class VQLPIPSWithDiscriminatorWithCompWithIdentity (line 38) | class VQLPIPSWithDiscriminatorWithCompWithIdentity(nn.Module):
    method __init__ (line 39) | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
    method calculate_adaptive_weight (line 105) | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
    method _gram_mat (line 118) | def _gram_mat(self, x):
    method gray_resize_for_identity (line 133) | def gray_resize_for_identity(self, out, size=128):
    method forward (line 139) | def forward(self, codebook_loss, gts, reconstructions, components, opt...

FILE: RestoreFormer/modules/util.py
  function count_params (line 5) | def count_params(model):
  class ActNorm (line 10) | class ActNorm(nn.Module):
    method __init__ (line 11) | def __init__(self, num_features, logdet=False, affine=True,
    method initialize (line 22) | def initialize(self, input):
    method forward (line 43) | def forward(self, input, reverse=False):
    method reverse (line 71) | def reverse(self, output):
  class Attention2DConv (line 95) | class Attention2DConv(nn.Module):
    method __init__ (line 97) | def __init__(self):

FILE: RestoreFormer/modules/vqvae/arcface_arch.py
  function conv3x3 (line 6) | def conv3x3(in_planes, out_planes, stride=1):
  class BasicBlock (line 11) | class BasicBlock(nn.Module):
    method __init__ (line 14) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 24) | def forward(self, x):
  class IRBlock (line 43) | class IRBlock(nn.Module):
    method __init__ (line 46) | def __init__(self, inplanes, planes, stride=1, downsample=None, use_se...
    method forward (line 60) | def forward(self, x):
  class Bottleneck (line 81) | class Bottleneck(nn.Module):
    method __init__ (line 84) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 96) | def forward(self, x):
  class SEBlock (line 119) | class SEBlock(nn.Module):
    method __init__ (line 121) | def __init__(self, channel, reduction=16):
    method forward (line 128) | def forward(self, x):
  class ResNetArcFace (line 136) | class ResNetArcFace(nn.Module):
    method __init__ (line 138) | def __init__(self, block, layers, use_se=True):
    method _make_layer (line 167) | def _make_layer(self, block, planes, blocks, stride=1):
    method forward (line 182) | def forward(self, x):

FILE: RestoreFormer/modules/vqvae/facial_component_discriminator.py
  class FacialComponentDiscriminator (line 14) | class FacialComponentDiscriminator(nn.Module):
    method __init__ (line 16) | def __init__(self):
    method forward (line 26) | def forward(self, x, return_feats=False):

FILE: RestoreFormer/modules/vqvae/utils.py
  function get_roi_regions (line 4) | def get_roi_regions(gt, output, loc_left_eyes, loc_right_eyes, loc_mouths,

FILE: RestoreFormer/modules/vqvae/vqvae_arch.py
  class VectorQuantizer (line 11) | class VectorQuantizer(nn.Module):
    method __init__ (line 23) | def __init__(self, n_e, e_dim, beta):
    method forward (line 32) | def forward(self, z):
    method get_codebook_entry (line 94) | def get_codebook_entry(self, indices, shape):
  function nonlinearity (line 112) | def nonlinearity(x):
  function Normalize (line 117) | def Normalize(in_channels):
  class Upsample (line 121) | class Upsample(nn.Module):
    method __init__ (line 122) | def __init__(self, in_channels, with_conv):
    method forward (line 132) | def forward(self, x):
  class Downsample (line 139) | class Downsample(nn.Module):
    method __init__ (line 140) | def __init__(self, in_channels, with_conv):
    method forward (line 151) | def forward(self, x):
  class ResnetBlock (line 161) | class ResnetBlock(nn.Module):
    method __init__ (line 162) | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=Fa...
    method forward (line 200) | def forward(self, x, temb):
  class MultiHeadAttnBlock (line 223) | class MultiHeadAttnBlock(nn.Module):
    method __init__ (line 224) | def __init__(self, in_channels, head_size=1):
    method forward (line 256) | def forward(self, x, y=None):
  class MultiHeadEncoder (line 300) | class MultiHeadEncoder(nn.Module):
    method __init__ (line 301) | def __init__(self, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks=2,
    method forward (line 367) | def forward(self, x):
  class MultiHeadDecoder (line 406) | class MultiHeadDecoder(nn.Module):
    method __init__ (line 407) | def __init__(self, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks=2,
    method forward (line 479) | def forward(self, z):
  class MultiHeadDecoderTransformer (line 513) | class MultiHeadDecoderTransformer(nn.Module):
    method __init__ (line 514) | def __init__(self, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks=2,
    method forward (line 586) | def forward(self, z, hs):
  class VQVAEGAN (line 622) | class VQVAEGAN(nn.Module):
    method __init__ (line 623) | def __init__(self, n_embed=1024, embed_dim=256, ch=128, out_ch=3, ch_m...
    method encode (line 653) | def encode(self, x):
    method decode (line 660) | def decode(self, quant):
    method forward (line 666) | def forward(self, input):
  class VQVAEGANMultiHeadTransformer (line 672) | class VQVAEGANMultiHeadTransformer(nn.Module):
    method __init__ (line 673) | def __init__(self, n_embed=1024, embed_dim=256, ch=128, out_ch=3, ch_m...
    method encode (line 708) | def encode(self, x):
    method decode (line 715) | def decode(self, quant, hs):
    method forward (line 721) | def forward(self, input):

FILE: RestoreFormer/util.py
  function download (line 18) | def download(url, local_path, chunk_size=1024):
  function md5_hash (line 30) | def md5_hash(path):
  function get_ckpt_path (line 36) | def get_ckpt_path(name, root, check=False):
  class KeyNotFoundError (line 47) | class KeyNotFoundError(Exception):
    method __init__ (line 48) | def __init__(self, cause, keys=None, visited=None):
  function retrieve (line 62) | def retrieve(

FILE: main.py
  function get_obj_from_str (line 15) | def get_obj_from_str(string, reload=False):
  function get_parser (line 23) | def get_parser(**parser_kwargs):
  function nondefault_trainer_args (line 137) | def nondefault_trainer_args(opt):
  function instantiate_from_config (line 144) | def instantiate_from_config(config):
  class WrappedDataset (line 153) | class WrappedDataset(Dataset):
    method __init__ (line 155) | def __init__(self, dataset):
    method __len__ (line 158) | def __len__(self):
    method __getitem__ (line 161) | def __getitem__(self, idx):
  class DataModuleFromConfig (line 165) | class DataModuleFromConfig(pl.LightningDataModule):
    method __init__ (line 166) | def __init__(self, batch_size, train=None, validation=None, test=None,
    method prepare_data (line 183) | def prepare_data(self):
    method setup (line 187) | def setup(self, stage=None):
    method _train_dataloader (line 195) | def _train_dataloader(self):
    method _val_dataloader (line 199) | def _val_dataloader(self):
    method _test_dataloader (line 204) | def _test_dataloader(self):
  class SetupCallback (line 209) | class SetupCallback(Callback):
    method __init__ (line 210) | def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, light...
    method on_pretrain_routine_start (line 220) | def on_pretrain_routine_start(self, trainer, pl_module):
  class ImageLogger (line 240) | class ImageLogger(Callback):
    method __init__ (line 241) | def __init__(self, batch_frequency, max_images, clamp=True, increase_l...
    method _wandb (line 255) | def _wandb(self, pl_module, images, batch_idx, split):
    method _testtube (line 264) | def _testtube(self, pl_module, images, batch_idx, split):
    method log_local (line 275) | def log_local(self, save_dir, split, images,
    method log_img (line 294) | def log_img(self, pl_module, batch, batch_idx, split="train"):
    method check_frequency (line 325) | def check_frequency(self, batch_idx):
    method on_train_batch_end (line 334) | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch...
    method on_validation_batch_end (line 337) | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, ...
  function melk (line 578) | def melk(*args, **kwargs):
  function divein (line 585) | def divein(*args, **kwargs):

FILE: scripts/metrics/cal_fid.py
  function calculate_fid_folder (line 12) | def calculate_fid_folder():

FILE: scripts/metrics/cal_identity_distance.py
  function cosine_similarity (line 21) | def cosine_similarity(emb1, emb2):
  function gray_resize_for_identity (line 25) | def gray_resize_for_identity(out, size=128):
  function calculate_identity_distance_folder (line 31) | def calculate_identity_distance_folder():

FILE: scripts/metrics/cal_psnr_ssim.py
  function calculate_psnr_ssim_lpips_folder (line 18) | def calculate_psnr_ssim_lpips_folder():

FILE: scripts/test.py
  function restoration (line 22) | def restoration(model,
  function get_parser (line 80) | def get_parser():
  function load_model_from_config (line 142) | def load_model_from_config(config, sd, gpu=True, eval_mode=True):
  function load_model_and_dset (line 191) | def load_model_and_dset(config, ckpt, gpu, eval_mode):
Condensed preview — 30 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (170K chars).
[
  {
    "path": ".gitignore",
    "chars": 113,
    "preview": "data/FFHQ\nscripts/data_synthetic\nexperiments/\nscripts/run_clustre.sh\nsftp-config.json\nresults/\n# scripts/metrics\n"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 5640,
    "preview": "# We have merged the code of RestoreFormer into our journal version, RestoreFormer++. Please feel free to access the res"
  },
  {
    "path": "RestoreFormer/data/ffhq_degradation_dataset.py",
    "chars": 9897,
    "preview": "import os\nimport cv2\nimport math\nimport numpy as np\nimport random\nimport os.path as osp\nimport torch\nimport torch.utils."
  },
  {
    "path": "RestoreFormer/distributed/__init__.py",
    "chars": 235,
    "preview": "from .distributed import (\n    get_rank,\n    get_local_rank,\n    is_primary,\n    synchronize,\n    get_world_size,\n    al"
  },
  {
    "path": "RestoreFormer/distributed/distributed.py",
    "chars": 3025,
    "preview": "import math\nimport pickle\n\nimport torch\nfrom torch import distributed as dist\nfrom torch.utils import data\n\n\nLOCAL_PROCE"
  },
  {
    "path": "RestoreFormer/distributed/launch.py",
    "chars": 2513,
    "preview": "import os\n\nimport torch\nfrom torch import distributed as dist\nfrom torch import multiprocessing as mp\n\nfrom . import dis"
  },
  {
    "path": "RestoreFormer/models/vqgan_v1.py",
    "chars": 9607,
    "preview": "import torch\nimport torch.nn.functional as F\nimport pytorch_lightning as pl\nfrom main import instantiate_from_config\n\nfr"
  },
  {
    "path": "RestoreFormer/modules/discriminator/model.py",
    "chars": 5430,
    "preview": "import functools\nimport torch.nn as nn\n\n\nfrom RestoreFormer.modules.util import ActNorm\n\n\ndef weights_init(m):\n    class"
  },
  {
    "path": "RestoreFormer/modules/losses/__init__.py",
    "chars": 2,
    "preview": "\n\n"
  },
  {
    "path": "RestoreFormer/modules/losses/lpips.py",
    "chars": 5583,
    "preview": "\"\"\"Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models\"\"\"\n\nimport torch\nimport torc"
  },
  {
    "path": "RestoreFormer/modules/losses/vqperceptual.py",
    "chars": 13251,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom copy import deepcopy\n\nfrom RestoreFormer.modules"
  },
  {
    "path": "RestoreFormer/modules/util.py",
    "chars": 2900,
    "preview": "import torch\nimport torch.nn as nn\n\n\ndef count_params(model):\n    total_params = sum(p.numel() for p in model.parameters"
  },
  {
    "path": "RestoreFormer/modules/vqvae/arcface_arch.py",
    "chars": 6109,
    "preview": "import torch.nn as nn\n\nfrom basicsr.utils.registry import ARCH_REGISTRY\n\n\ndef conv3x3(in_planes, out_planes, stride=1):\n"
  },
  {
    "path": "RestoreFormer/modules/vqvae/facial_component_discriminator.py",
    "chars": 1646,
    "preview": "import math\nimport random\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom basicsr.archs.sty"
  },
  {
    "path": "RestoreFormer/modules/vqvae/utils.py",
    "chars": 1754,
    "preview": "from torchvision.ops import roi_align\nimport torch\n\ndef get_roi_regions(gt, output, loc_left_eyes, loc_right_eyes, loc_m"
  },
  {
    "path": "RestoreFormer/modules/vqvae/vqvae_arch.py",
    "chars": 28124,
    "preview": "import torch\nimport torch.nn as nn\nimport random\nimport math\nimport torch.nn.functional as F\nimport numpy as np\n# from b"
  },
  {
    "path": "RestoreFormer/util.py",
    "chars": 4777,
    "preview": "import os, hashlib\nimport requests\nfrom tqdm import tqdm\n\nURL_MAP = {\n    \"vgg_lpips\": \"https://heibox.uni-heidelberg.de"
  },
  {
    "path": "__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "configs/HQ_Dictionary.yaml",
    "chars": 1649,
    "preview": "model:\n  base_learning_rate: 4.5e-6\n  target: RestoreFormer.models.vqgan_v1.RestoreFormerModel\n  params:\n    image_key: "
  },
  {
    "path": "configs/RestoreFormer.yaml",
    "chars": 3087,
    "preview": "model:\n  base_learning_rate: 4.5e-6 \n  target: RestoreFormer.models.vqgan_v1.RestoreFormerModel\n  params:\n    image_key:"
  },
  {
    "path": "main.py",
    "chars": 21984,
    "preview": "import argparse, os, sys, datetime, glob, importlib\nfrom omegaconf import OmegaConf\nimport numpy as np\nfrom PIL import I"
  },
  {
    "path": "restoreformer_requirement.txt",
    "chars": 3079,
    "preview": "Package                 Version             Location\n----------------------- ------------------- -----------------------"
  },
  {
    "path": "scripts/metrics/cal_fid.py",
    "chars": 2762,
    "preview": "import os, sys\nimport argparse\nimport math\nimport numpy as np\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom"
  },
  {
    "path": "scripts/metrics/cal_identity_distance.py",
    "chars": 3608,
    "preview": "import os, sys\nimport torch\nimport argparse\nimport cv2\nimport numpy as np\nimport glob\nimport pdb\nimport tqdm\nfrom copy i"
  },
  {
    "path": "scripts/metrics/cal_psnr_ssim.py",
    "chars": 3080,
    "preview": "import os, sys\nimport argparse\nimport cv2\nimport numpy as np\nimport glob\nimport pdb\nimport tqdm\nimport torch\n\nfrom basic"
  },
  {
    "path": "scripts/metrics/run.sh",
    "chars": 890,
    "preview": "\n### Journal ###\nroot='results/'\nout_root='results/metrics'\n\ntest_name='RestoreFormer'\n\ntest_image=$test_name'/restored_"
  },
  {
    "path": "scripts/run.sh",
    "chars": 361,
    "preview": "export BASICSR_JIT=True\n\nconf_name='HQ_Dictionary'\n# conf_name='RestoreFormer'\n\nROOT_PATH='' # The path for saving model"
  },
  {
    "path": "scripts/test.py",
    "chars": 9786,
    "preview": "import argparse, os, sys, glob, math, time\nimport torch\nimport numpy as np\nfrom omegaconf import OmegaConf\nfrom PIL impo"
  },
  {
    "path": "scripts/test.sh",
    "chars": 379,
    "preview": "# # ### Good\nexp_name='RestoreFormer'\n\nroot_path='experiments'\nout_root_path='results'\nalign_test_path='data/test'\ntag='"
  }
]

About this extraction

This page contains the full source code of the wzhouxiff/RestoreFormer GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 30 files (158.8 KB), approximately 39.9k tokens, and a symbol index with 172 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!