[
  {
    "path": ".gitignore",
    "content": "data/FFHQ\nscripts/data_synthetic\nexperiments/\nscripts/run_clustre.sh\nsftp-config.json\nresults/\n# scripts/metrics\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# We have merged the code of RestoreFormer into our journal version, RestoreFormer++. Please feel free to access the resources from [https://github.com/wzhouxiff/RestoreFormerPlusPlus](https://github.com/wzhouxiff/RestoreFormerPlusPlus)\n\n# Updating\n- **20230915** Update an online demo [![Huggingface Gradio](https://img.shields.io/static/v1?label=Demo&message=Huggingface%20Gradio&color=orange)](https://huggingface.co/spaces/wzhouxiff/RestoreFormerPlusPlus)\n- **20230915** A more user-friendly and comprehensive inference method refer to our [RestoreFormer++](https://github.com/wzhouxiff/RestoreFormerPlusPlus)\n- **20230116** For convenience, we further upload the [test datasets](#testset), including CelebA (both HQ and LQ data), LFW-Test, CelebChild-Test, and Webphoto-Test, to OneDrive and BaiduYun.\n- **20221003** We provide the link of the [test datasets](#testset).\n- **20220924** We add the code for [**metrics**](#metrics) in scripts/metrics.\n\n\n<!--\n# RestoreFormer\n\nThis repo includes the source code of the paper: \"[RestoreFormer: High-Quality Blind Face Restoration from Undegraded Key-Value Pairs](https://openaccess.thecvf.com/content/CVPR2022/papers/Wang_RestoreFormer_High-Quality_Blind_Face_Restoration_From_Undegraded_Key-Value_Pairs_CVPR_2022_paper.pdf)\" (CVPR 2022) by Zhouxia Wang, Jiawei Zhang, Runjian Chen, Wenping Wang, and Ping Luo.\n\n![](assets/figure1.png)\n\n**RestoreFormer** tends to explore fully-spatial attentions to model contextual information and surpasses existing works that use local operators. It has several benefits compared to prior arts. First, it incorporates a multi-head coross-attention layer to learn fully-spatial interations between corrupted queries and high-quality key-value pairs. Second, the key-value pairs in RestoreFormer are sampled from a reconstruction-oriented high-quality dictionary, whose elements are rich in high-quality facial features specifically aimed for face reconstruction.\n\n-->\n\n<!-- ![](assets/framework.png \"Framework\")-->\n\n<!--\n\n## Environment\n\n- python>=3.7\n- pytorch>=1.7.1\n- pytorch-lightning==1.0.8\n- omegaconf==2.0.0\n- basicsr==1.3.3.4\n\n**Warning** Different versions of pytorch-lightning and omegaconf may lead to errors or different results.\n\n## Preparations of dataset and models\n\n**Dataset**: \n- Training data: Both **HQ Dictionary** and **RestoreFormer** in our work are trained with **FFHQ** which attained from [FFHQ repository](https://github.com/NVlabs/ffhq-dataset). The original size of the images in FFHQ are 1024x1024. We resize them to 512x512 with bilinear interpolation in our work. Link this dataset to ./data/FFHQ/image512x512.\n- <a id=\"testset\">Test data</a>: \n   * CelebA-Test-HQ: [OneDrive](https://connecthkuhk-my.sharepoint.com/:u:/g/personal/wzhoux_connect_hku_hk/EY7P-MReZUZOngy3UGa5abUBJKel1IH5uYZLdwp2e2KvUw?e=rK0VWh); [BaiduYun](https://pan.baidu.com/s/1tMpxz8lIW50U8h00047GIw?pwd=mp9t)(code mp9t)\n   * CelebA-Test-LQ: [OneDrive](https://connecthkuhk-my.sharepoint.com/:u:/g/personal/wzhoux_connect_hku_hk/EXULDOtX3qdKg9_--k-hbr4BumxOUAi19iQjZNz75S6pKA?e=Kghqri); [BaiduYun](https://pan.baidu.com/s/1y6ZcQPCLyggj9VB5MgoWyg?pwd=7s6h)(code 7s6h)\n   * LFW-Test: [OneDrive](https://connecthkuhk-my.sharepoint.com/:u:/g/personal/wzhoux_connect_hku_hk/EZ7ibkhUuRxBjdd-MesczpgBfpLVfv-9uYVskLuZiYpBsg?e=xPNH26); [BaiduYun](https://pan.baidu.com/s/1UkfYLTViL8XVdZ-Ej-2G9g?pwd=7fhr)(code 7fhr). 
\n## <a id=\"metrics\">Metrics</a>\n    sh scripts/metrics/run.sh\n    \n**Note**. \n- You need to add the path of the CelebA-Test dataset in the script if you want to get IDD, PSNR, SSIM, LPIPS.\n\n## Citation\n    @inproceedings{wang2022restoreformer,\n      title={RestoreFormer: High-Quality Blind Face Restoration from Undegraded Key-Value Pairs},\n      author={Wang, Zhouxia and Zhang, Jiawei and Chen, Runjian and Wang, Wenping and Luo, Ping},\n      booktitle={The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},\n      year={2022}\n    }\n\n## Acknowledgement\nWe thank everyone who makes their code and models available, especially [Taming Transformer](https://github.com/CompVis/taming-transformers), [basicsr](https://github.com/XPixelGroup/BasicSR), and [GFPGAN](https://github.com/TencentARC/GFPGAN).\n\n## Contact\nFor any question, feel free to email `wzhoux@connect.hku.hk` or `zhouzi1212@gmail.com`.\n\n-->\n"
  },
  {
    "path": "RestoreFormer/data/ffhq_degradation_dataset.py",
    "content": "import os\nimport cv2\nimport math\nimport numpy as np\nimport random\nimport os.path as osp\nimport torch\nimport torch.utils.data as data\nfrom torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,\n                                               normalize)\n\nfrom basicsr.data import degradations as degradations\nfrom basicsr.data.data_util import paths_from_folder\nfrom basicsr.data.transforms import augment\nfrom basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor\nfrom basicsr.utils.registry import DATASET_REGISTRY\n\n\n@DATASET_REGISTRY.register()\nclass FFHQDegradationDataset(data.Dataset):\n\n    def __init__(self, opt):\n        super(FFHQDegradationDataset, self).__init__()\n        self.opt = opt\n        # file client (io backend)\n        self.file_client = None\n        self.io_backend_opt = opt['io_backend']\n\n        self.gt_folder = opt['dataroot_gt']\n        self.mean = opt['mean']\n        self.std = opt['std']\n        self.out_size = opt['out_size']\n\n        self.crop_components = opt.get('crop_components', False)  # facial components\n        self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)\n\n        if self.crop_components:\n            self.components_list = torch.load(opt.get('component_path'))\n\n        if self.io_backend_opt['type'] == 'lmdb':\n            self.io_backend_opt['db_paths'] = self.gt_folder\n            if not self.gt_folder.endswith('.lmdb'):\n                raise ValueError(f\"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}\")\n            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:\n                self.paths = [line.split('.')[0] for line in fin]\n        else:\n            self.paths = paths_from_folder(self.gt_folder)\n\n        # degradations\n        self.blur_kernel_size = opt['blur_kernel_size']\n        self.kernel_list = opt['kernel_list']\n        self.kernel_prob = opt['kernel_prob']\n        self.blur_sigma = opt['blur_sigma']\n        self.downsample_range = opt['downsample_range']\n        self.noise_range = opt['noise_range']\n        self.jpeg_range = opt['jpeg_range']\n\n        # color jitter\n        self.color_jitter_prob = opt.get('color_jitter_prob')\n        self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')\n        self.color_jitter_shift = opt.get('color_jitter_shift', 20)\n        # to gray\n        self.gray_prob = opt.get('gray_prob')\n\n        logger = get_root_logger()\n        logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, '\n                    f'sigma: [{\", \".join(map(str, self.blur_sigma))}]')\n        logger.info(f'Downsample: downsample_range [{\", \".join(map(str, self.downsample_range))}]')\n        logger.info(f'Noise: [{\", \".join(map(str, self.noise_range))}]')\n        logger.info(f'JPEG compression: [{\", \".join(map(str, self.jpeg_range))}]')\n\n        if self.color_jitter_prob is not None:\n            logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, '\n                        f'shift: {self.color_jitter_shift}')\n        if self.gray_prob is not None:\n            logger.info(f'Use random gray. 
Prob: {self.gray_prob}')\n\n        self.color_jitter_shift /= 255.\n\n\n    @staticmethod\n    def color_jitter(img, shift):\n        jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)\n        img = img + jitter_val\n        img = np.clip(img, 0, 1)\n        return img\n\n    @staticmethod\n    def color_jitter_pt(img, brightness, contrast, saturation, hue):\n        fn_idx = torch.randperm(4)\n        for fn_id in fn_idx:\n            if fn_id == 0 and brightness is not None:\n                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()\n                img = adjust_brightness(img, brightness_factor)\n\n            if fn_id == 1 and contrast is not None:\n                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()\n                img = adjust_contrast(img, contrast_factor)\n\n            if fn_id == 2 and saturation is not None:\n                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()\n                img = adjust_saturation(img, saturation_factor)\n\n            if fn_id == 3 and hue is not None:\n                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()\n                img = adjust_hue(img, hue_factor)\n        return img\n\n    def get_component_coordinates(self, index, status):\n        components_bbox = self.components_list[f'{index:08d}']\n        if status[0]:  # hflip\n            # exchange right and left eye\n            tmp = components_bbox['left_eye']\n            components_bbox['left_eye'] = components_bbox['right_eye']\n            components_bbox['right_eye'] = tmp\n            # modify the width coordinate\n            components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0]\n            components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0]\n            components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0]\n\n        # get coordinates\n        locations = []\n        for part in ['left_eye', 'right_eye', 'mouth']:\n            mean = components_bbox[part][0:2]\n            half_len = components_bbox[part][2]\n            if 'eye' in part:\n                half_len *= self.eye_enlarge_ratio\n            loc = np.hstack((mean - half_len + 1, mean + half_len))\n            loc = torch.from_numpy(loc).float()\n            locations.append(loc)\n        return locations\n\n    def __getitem__(self, index):\n        if self.file_client is None:\n            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)\n\n        # load gt image\n        gt_path = self.paths[index]\n        img_bytes = self.file_client.get(gt_path)\n        img_gt = imfrombytes(img_bytes, float32=True)\n\n        # random horizontal flip\n        img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)\n        h, w, _ = img_gt.shape\n\n        if self.crop_components:\n            locations = self.get_component_coordinates(index, status)\n            loc_left_eye, loc_right_eye, loc_mouth = locations\n\n        # ------------------------ generate lq image ------------------------ #\n        # blur\n        assert self.blur_kernel_size[0] < self.blur_kernel_size[1], 'Wrong blur kernel size range'\n        cur_kernel_size = random.randint(self.blur_kernel_size[0],self.blur_kernel_size[1]) * 2 + 1\n        kernel = degradations.random_mixed_kernels(\n            self.kernel_list,\n            
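# kernel_prob lists the per-type sampling probability for each entry in kernel_list\n            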
self.kernel_prob,\n            cur_kernel_size,\n            self.blur_sigma,\n            self.blur_sigma, [-math.pi, math.pi],\n            noise_range=None)\n        img_lq = cv2.filter2D(img_gt, -1, kernel)\n        # downsample\n        scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])\n        img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)\n        # noise\n        if self.noise_range is not None:\n            img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)\n        # jpeg compression\n        if self.jpeg_range is not None:\n            img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)\n\n        # resize back to the original size\n        img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)\n\n        # random color jitter (only for lq)\n        if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):\n            img_lq = self.color_jitter(img_lq, self.color_jitter_shift)\n        # random to gray (only for lq)\n        if self.gray_prob and np.random.uniform() < self.gray_prob:\n            img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)\n            img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])\n            if self.opt.get('gt_gray'):\n                img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)\n                img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])\n\n        # BGR to RGB, HWC to CHW, numpy to tensor\n        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)\n\n        # random color jitter (pytorch version) (only for lq)\n        if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):\n            brightness = self.opt.get('brightness', (0.5, 1.5))\n            contrast = self.opt.get('contrast', (0.5, 1.5))\n            saturation = self.opt.get('saturation', (0, 1.5))\n            hue = self.opt.get('hue', (-0.1, 0.1))\n            img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)\n\n        # round and clip\n        img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.\n\n        # normalize\n        normalize(img_gt, self.mean, self.std, inplace=True)\n        normalize(img_lq, self.mean, self.std, inplace=True)\n\n        return_dict = {\n                'lq': img_lq,\n                'gt': img_gt,\n                'gt_path': gt_path\n            }\n        if self.crop_components:\n            return_dict['loc_left_eye'] = loc_left_eye\n            return_dict['loc_right_eye'] = loc_right_eye\n            return_dict['loc_mouth'] = loc_mouth\n\n        return return_dict\n\n    def __len__(self):\n        return len(self.paths)\n\nimport argparse\nfrom omegaconf import OmegaConf\nimport pdb\nfrom basicsr.utils import img2tensor, imwrite, tensor2img\n\nif __name__ == '__main__':\n    # pdb.set_trace()\n    base = 'configs/RestoreFormer.yaml'\n\n    opt = OmegaConf.load(base)\n    dataset = FFHQDegradationDataset(opt['data']['params']['train']['params'])\n\n    for i in range(100):\n        sample = dataset[i]\n        name = sample['gt_path'].split('/')[-1][:-4]\n        gt = tensor2img(sample['gt'])\n        imwrite(gt, name + '_gt.png')\n        lq = tensor2img(sample['lq'])\n        imwrite(lq, name + '_lq_nojitter.png')"
  },
  {
    "path": "RestoreFormer/distributed/__init__.py",
    "content": "from .distributed import (\n    get_rank,\n    get_local_rank,\n    is_primary,\n    synchronize,\n    get_world_size,\n    all_reduce,\n    all_gather,\n    reduce_dict,\n    data_sampler,\n    LOCAL_PROCESS_GROUP,\n)\nfrom .launch import launch\n"
  },
  {
    "path": "RestoreFormer/distributed/distributed.py",
    "content": "import math\nimport pickle\n\nimport torch\nfrom torch import distributed as dist\nfrom torch.utils import data\n\n\nLOCAL_PROCESS_GROUP = None\n\n\ndef is_primary():\n    return get_rank() == 0\n\n\ndef get_rank():\n    if not dist.is_available():\n        return 0\n\n    if not dist.is_initialized():\n        return 0\n\n    return dist.get_rank()\n\n\ndef get_local_rank():\n    if not dist.is_available():\n        return 0\n\n    if not dist.is_initialized():\n        return 0\n\n    if LOCAL_PROCESS_GROUP is None:\n        raise ValueError(\"tensorfn.distributed.LOCAL_PROCESS_GROUP is None\")\n\n    return dist.get_rank(group=LOCAL_PROCESS_GROUP)\n\n\ndef synchronize():\n    if not dist.is_available():\n        return\n\n    if not dist.is_initialized():\n        return\n\n    world_size = dist.get_world_size()\n\n    if world_size == 1:\n        return\n\n    dist.barrier()\n\n\ndef get_world_size():\n    if not dist.is_available():\n        return 1\n\n    if not dist.is_initialized():\n        return 1\n\n    return dist.get_world_size()\n\n\ndef all_reduce(tensor, op=dist.ReduceOp.SUM):\n    world_size = get_world_size()\n\n    if world_size == 1:\n        return tensor\n\n    dist.all_reduce(tensor, op=op)\n\n    return tensor\n\n\ndef all_gather(data):\n    world_size = get_world_size()\n\n    if world_size == 1:\n        return [data]\n\n    buffer = pickle.dumps(data)\n    storage = torch.ByteStorage.from_buffer(buffer)\n    tensor = torch.ByteTensor(storage).to(\"cuda\")\n\n    local_size = torch.IntTensor([tensor.numel()]).to(\"cuda\")\n    size_list = [torch.IntTensor([1]).to(\"cuda\") for _ in range(world_size)]\n    dist.all_gather(size_list, local_size)\n    size_list = [int(size.item()) for size in size_list]\n    max_size = max(size_list)\n\n    tensor_list = []\n    for _ in size_list:\n        tensor_list.append(torch.ByteTensor(size=(max_size,)).to(\"cuda\"))\n\n    if local_size != max_size:\n        padding = torch.ByteTensor(size=(max_size - local_size,)).to(\"cuda\")\n        tensor = torch.cat((tensor, padding), 0)\n\n    dist.all_gather(tensor_list, tensor)\n\n    data_list = []\n\n    for size, tensor in zip(size_list, tensor_list):\n        buffer = tensor.cpu().numpy().tobytes()[:size]\n        data_list.append(pickle.loads(buffer))\n\n    return data_list\n\n\ndef reduce_dict(input_dict, average=True):\n    world_size = get_world_size()\n\n    if world_size < 2:\n        return input_dict\n\n    with torch.no_grad():\n        keys = []\n        values = []\n\n        for k in sorted(input_dict.keys()):\n            keys.append(k)\n            values.append(input_dict[k])\n\n        values = torch.stack(values, 0)\n        dist.reduce(values, dst=0)\n\n        if dist.get_rank() == 0 and average:\n            values /= world_size\n\n        reduced_dict = {k: v for k, v in zip(keys, values)}\n\n    return reduced_dict\n\n\ndef data_sampler(dataset, shuffle, distributed):\n    if distributed:\n        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)\n\n    if shuffle:\n        return data.RandomSampler(dataset)\n\n    else:\n        return data.SequentialSampler(dataset)\n"
  },
  {
    "path": "RestoreFormer/distributed/launch.py",
    "content": "import os\n\nimport torch\nfrom torch import distributed as dist\nfrom torch import multiprocessing as mp\n\nfrom . import distributed as dist_fn\n\n\ndef find_free_port():\n    import socket\n\n    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n\n    sock.bind((\"\", 0))\n    port = sock.getsockname()[1]\n    sock.close()\n\n    return port\n\n\ndef launch(fn, n_gpu_per_machine, n_machine=1, machine_rank=0, dist_url=None, args=()):\n    world_size = n_machine * n_gpu_per_machine\n\n    if world_size > 1:\n        if \"OMP_NUM_THREADS\" not in os.environ:\n            os.environ[\"OMP_NUM_THREADS\"] = \"1\"\n\n        if dist_url == \"auto\":\n            if n_machine != 1:\n                raise ValueError('dist_url=\"auto\" not supported in multi-machine jobs')\n\n            port = find_free_port()\n            dist_url = f\"tcp://127.0.0.1:{port}\"\n\n        if n_machine > 1 and dist_url.startswith(\"file://\"):\n            raise ValueError(\n                \"file:// is not a reliable init method in multi-machine jobs. Prefer tcp://\"\n            )\n\n        mp.spawn(\n            distributed_worker,\n            nprocs=n_gpu_per_machine,\n            args=(fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args),\n            daemon=False,\n        )\n\n    else:\n        fn(*args)\n\n\ndef distributed_worker(\n    local_rank, fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args\n):\n    if not torch.cuda.is_available():\n        raise OSError(\"CUDA is not available. Please check your environments\")\n\n    global_rank = machine_rank * n_gpu_per_machine + local_rank\n\n    try:\n        dist.init_process_group(\n            backend=\"NCCL\",\n            init_method=dist_url,\n            world_size=world_size,\n            rank=global_rank,\n        )\n\n    except Exception:\n        raise OSError(\"failed to initialize NCCL groups\")\n\n    dist_fn.synchronize()\n\n    if n_gpu_per_machine > torch.cuda.device_count():\n        raise ValueError(\n            f\"specified n_gpu_per_machine larger than available device ({torch.cuda.device_count()})\"\n        )\n\n    torch.cuda.set_device(local_rank)\n\n    if dist_fn.LOCAL_PROCESS_GROUP is not None:\n        raise ValueError(\"torch.distributed.LOCAL_PROCESS_GROUP is not None\")\n\n    n_machine = world_size // n_gpu_per_machine\n\n    for i in range(n_machine):\n        ranks_on_i = list(range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine))\n        pg = dist.new_group(ranks_on_i)\n\n        if i == machine_rank:\n            dist_fn.distributed.LOCAL_PROCESS_GROUP = pg\n\n    fn(*args)\n"
  },
  {
    "path": "RestoreFormer/models/vqgan_v1.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport pytorch_lightning as pl\nfrom main import instantiate_from_config\n\nfrom RestoreFormer.modules.vqvae.utils import get_roi_regions\n\nclass RestoreFormerModel(pl.LightningModule):\n    def __init__(self,\n                 ddconfig,\n                 lossconfig,\n                 ckpt_path=None,\n                 ignore_keys=[],\n                 image_key=\"lq\",\n                 colorize_nlabels=None,\n                 monitor=None,\n                 special_params_lr_scale=1.0,\n                 comp_params_lr_scale=1.0,\n                 schedule_step=[80000, 200000]\n                 ):\n        super().__init__()\n        self.image_key = image_key\n        self.vqvae = instantiate_from_config(ddconfig)\n\n        lossconfig['params']['distill_param']=ddconfig['params']\n        self.loss = instantiate_from_config(lossconfig)\n        if ckpt_path is not None:\n            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)\n\n        \n        if ('comp_weight' in lossconfig['params'] and lossconfig['params']['comp_weight']) or ('comp_style_weight' in lossconfig['params'] and lossconfig['params']['comp_style_weight']):\n            self.use_facial_disc = True\n        else:\n            self.use_facial_disc = False\n\n        self.fix_decoder = ddconfig['params']['fix_decoder']\n        \n        self.disc_start = lossconfig['params']['disc_start']\n        self.special_params_lr_scale = special_params_lr_scale\n        self.comp_params_lr_scale = comp_params_lr_scale\n        self.schedule_step = schedule_step\n\n    def init_from_ckpt(self, path, ignore_keys=list()):\n        sd = torch.load(path, map_location=\"cpu\")[\"state_dict\"]\n        keys = list(sd.keys())\n\n        for k in keys:\n            for ik in ignore_keys:\n                if k.startswith(ik):\n                    print(\"Deleting key {} from state_dict.\".format(k))\n                    del sd[k]\n\n        state_dict = self.state_dict()\n        require_keys = state_dict.keys()\n        keys = sd.keys()\n        un_pretrained_keys = []\n        for k in require_keys:\n            if k not in keys: \n                # miss 'vqvae.'\n                if k[6:] in keys:\n                    state_dict[k] = sd[k[6:]]\n                else:\n                    un_pretrained_keys.append(k)\n            else:\n                state_dict[k] = sd[k]\n\n        # print(f'*************************************************')\n        # print(f\"Layers without pretraining: {un_pretrained_keys}\")\n        # print(f'*************************************************')\n\n        self.load_state_dict(state_dict, strict=True)\n        print(f\"Restored from {path}\")\n\n    def forward(self, input):\n        dec, diff, info, hs = self.vqvae(input)\n        return dec, diff, info, hs\n\n    def training_step(self, batch, batch_idx, optimizer_idx):\n        \n        x = batch[self.image_key]\n        xrec, qloss, info, hs = self(x)\n\n        if self.image_key != 'gt':\n            x = batch['gt']\n\n        if self.use_facial_disc:\n            loc_left_eyes = batch['loc_left_eye']\n            loc_right_eyes = batch['loc_right_eye']\n            loc_mouths = batch['loc_mouth']\n            face_ratio = xrec.shape[-1] / 512\n            components = get_roi_regions(x, xrec, loc_left_eyes, loc_right_eyes, loc_mouths, face_ratio)\n        else:\n            components = None\n\n        if optimizer_idx == 0:\n            # autoencode\n            aeloss, 
log_dict_ae = self.loss(qloss, x, xrec, components, optimizer_idx, self.global_step,\n                                            last_layer=self.get_last_layer(), split=\"train\")\n\n            self.log(\"train/aeloss\", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)\n            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)\n            return aeloss\n\n        if optimizer_idx == 1:\n            # discriminator\n            discloss, log_dict_disc = self.loss(qloss, x, xrec, components, optimizer_idx, self.global_step,\n                                            last_layer=None, split=\"train\")\n            self.log(\"train/discloss\", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)\n            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)\n            return discloss\n\n        \n        if self.disc_start <= self.global_step:\n\n            # left eye\n            if optimizer_idx == 2:\n                # discriminator\n                disc_left_loss, log_dict_disc = self.loss(qloss, x, xrec, components, optimizer_idx, self.global_step,\n                                                last_layer=None, split=\"train\")\n                self.log(\"train/disc_left_loss\", disc_left_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)\n                self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)\n                return disc_left_loss\n\n            # right eye\n            if optimizer_idx == 3:\n                # discriminator\n                disc_right_loss, log_dict_disc = self.loss(qloss, x, xrec, components, optimizer_idx, self.global_step,\n                                                last_layer=None, split=\"train\")\n                self.log(\"train/disc_right_loss\", disc_right_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)\n                self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)\n                return disc_right_loss\n\n            # mouth\n            if optimizer_idx == 4:\n                # discriminator\n                disc_mouth_loss, log_dict_disc = self.loss(qloss, x, xrec, components, optimizer_idx, self.global_step,\n                                                last_layer=None, split=\"train\")\n                self.log(\"train/disc_mouth_loss\", disc_mouth_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)\n                self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)\n                return disc_mouth_loss\n\n    def validation_step(self, batch, batch_idx):\n        x = batch[self.image_key]\n        xrec, qloss, info, hs = self(x)\n\n        if self.image_key != 'gt':\n            x = batch['gt']\n\n        aeloss, log_dict_ae = self.loss(qloss, x, xrec, None, 0, self.global_step,\n                                            last_layer=self.get_last_layer(), split=\"val\")\n\n        discloss, log_dict_disc = self.loss(qloss, x, xrec, None, 1, self.global_step,\n                                            last_layer=None, split=\"val\")\n        rec_loss = log_dict_ae[\"val/rec_loss\"]\n        self.log(\"val/rec_loss\", rec_loss,\n                   prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)\n        self.log(\"val/aeloss\", aeloss,\n                   prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)\n        
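# also log the full metric dictionaries collected by the generator and discriminator passes\n        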
self.log_dict(log_dict_ae)\n        self.log_dict(log_dict_disc)\n\n        return self.log_dict\n\n    def configure_optimizers(self):\n        lr = self.learning_rate\n\n        normal_params = []\n        special_params = []\n        for name, param in self.vqvae.named_parameters():\n            if not param.requires_grad:\n                continue\n            if 'decoder' in name and 'attn' in name:\n                special_params.append(param)\n            else:\n                normal_params.append(param)\n        # print('special_params', special_params)\n        opt_ae_params = [{'params': normal_params, 'lr': lr},\n                         {'params': special_params, 'lr': lr*self.special_params_lr_scale}]\n        opt_ae = torch.optim.Adam(opt_ae_params, betas=(0.5, 0.9))\n\n\n        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),\n                                    lr=lr, betas=(0.5, 0.9))\n\n        optimizations = [opt_ae, opt_disc]\n\n        s0 = torch.optim.lr_scheduler.MultiStepLR(opt_ae, milestones=self.schedule_step, gamma=0.1, verbose=True)\n        s1 = torch.optim.lr_scheduler.MultiStepLR(opt_disc, milestones=self.schedule_step, gamma=0.1, verbose=True)\n        schedules = [s0, s1]\n\n        if self.use_facial_disc:\n            opt_l = torch.optim.Adam(self.loss.net_d_left_eye.parameters(),\n                                     lr=lr*self.comp_params_lr_scale, betas=(0.9, 0.99))\n            opt_r = torch.optim.Adam(self.loss.net_d_right_eye.parameters(),\n                                     lr=lr*self.comp_params_lr_scale, betas=(0.9, 0.99))\n            opt_m = torch.optim.Adam(self.loss.net_d_mouth.parameters(),\n                                     lr=lr*self.comp_params_lr_scale, betas=(0.9, 0.99))\n            optimizations += [opt_l, opt_r, opt_m]\n            \n            s2 = torch.optim.lr_scheduler.MultiStepLR(opt_l, milestones=self.schedule_step, gamma=0.1, verbose=True)\n            s3 = torch.optim.lr_scheduler.MultiStepLR(opt_r, milestones=self.schedule_step, gamma=0.1, verbose=True)\n            s4 = torch.optim.lr_scheduler.MultiStepLR(opt_m, milestones=self.schedule_step, gamma=0.1, verbose=True)\n            schedules += [s2, s3, s4]\n\n        return optimizations, schedules\n\n    def get_last_layer(self):\n        if self.fix_decoder:\n            return self.vqvae.quant_conv.weight\n        return self.vqvae.decoder.conv_out.weight\n\n    def log_images(self, batch, **kwargs):\n        log = dict()\n        x = batch[self.image_key]\n        x = x.to(self.device)\n        xrec, _, _, _ = self(x)\n        log[\"inputs\"] = x\n        log[\"reconstructions\"] = xrec\n\n        if self.image_key != 'gt':\n            x = batch['gt']\n            log[\"gt\"] = x\n        return log\n"
  },
  {
    "path": "RestoreFormer/modules/discriminator/model.py",
    "content": "import functools\nimport torch.nn as nn\n\n\nfrom RestoreFormer.modules.util import ActNorm\n\n\ndef weights_init(m):\n    classname = m.__class__.__name__\n    if classname.find('Conv') != -1:\n        nn.init.normal_(m.weight.data, 0.0, 0.02)\n    elif classname.find('BatchNorm') != -1:\n        nn.init.normal_(m.weight.data, 1.0, 0.02)\n        nn.init.constant_(m.bias.data, 0)\n\n\nclass NLayerDiscriminator(nn.Module):\n    \"\"\"Defines a PatchGAN discriminator as in Pix2Pix\n        --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py\n    \"\"\"\n    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):\n        \"\"\"Construct a PatchGAN discriminator\n        Parameters:\n            input_nc (int)  -- the number of channels in input images\n            ndf (int)       -- the number of filters in the last conv layer\n            n_layers (int)  -- the number of conv layers in the discriminator\n            norm_layer      -- normalization layer\n        \"\"\"\n        super(NLayerDiscriminator, self).__init__()\n        if not use_actnorm:\n            norm_layer = nn.BatchNorm2d\n        else:\n            norm_layer = ActNorm\n        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters\n            use_bias = norm_layer.func != nn.BatchNorm2d\n        else:\n            use_bias = norm_layer != nn.BatchNorm2d\n\n        kw = 4\n        padw = 1\n        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]\n        nf_mult = 1\n        nf_mult_prev = 1\n        for n in range(1, n_layers):  # gradually increase the number of filters\n            nf_mult_prev = nf_mult\n            nf_mult = min(2 ** n, 8)\n            sequence += [\n                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),\n                norm_layer(ndf * nf_mult),\n                nn.LeakyReLU(0.2, True)\n            ]\n\n        nf_mult_prev = nf_mult\n        nf_mult = min(2 ** n_layers, 8)\n        sequence += [\n            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),\n            norm_layer(ndf * nf_mult),\n            nn.LeakyReLU(0.2, True)\n        ]\n\n        sequence += [\n            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map\n        self.main = nn.Sequential(*sequence)\n\n    def forward(self, input):\n        \"\"\"Standard forward.\"\"\"\n        return self.main(input)\n\nclass NLayerDiscriminator_v1(nn.Module):\n    \"\"\"Defines a PatchGAN discriminator as in Pix2Pix\n        --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py\n    \"\"\"\n    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):\n        \"\"\"Construct a PatchGAN discriminator\n        Parameters:\n            input_nc (int)  -- the number of channels in input images\n            ndf (int)       -- the number of filters in the last conv layer\n            n_layers (int)  -- the number of conv layers in the discriminator\n            norm_layer      -- normalization layer\n        \"\"\"\n        super(NLayerDiscriminator_v1, self).__init__()\n        if not use_actnorm:\n            norm_layer = nn.BatchNorm2d\n        else:\n            norm_layer = ActNorm\n        if type(norm_layer) == functools.partial:  # no need 
to use bias as BatchNorm2d has affine parameters\n            use_bias = norm_layer.func != nn.BatchNorm2d\n        else:\n            use_bias = norm_layer != nn.BatchNorm2d\n\n        self.n_layers = n_layers\n\n        kw = 4\n        padw = 1\n        # sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]\n        self.head = nn.Sequential(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True))\n        # self.head = nn.Sequential(nn.Conv2d(3, 64, kernel_size=4, stride=1, padding=1), nn.LeakyReLU(0.2, True)).cuda()\n        nf_mult = 1\n        nf_mult_prev = 1\n        self.body = nn.ModuleList()\n        for n in range(1, n_layers):  # gradually increase the number of filters\n            nf_mult_prev = nf_mult\n            nf_mult = min(2 ** n, 8)\n\n            self.body.append(nn.Sequential(\n                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),\n                norm_layer(ndf * nf_mult),\n                nn.LeakyReLU(0.2, True)\n            ))\n\n        nf_mult_prev = nf_mult\n        nf_mult = min(2 ** n_layers, 8)\n        self.beforlast = nn.Sequential(\n            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),\n            norm_layer(ndf * nf_mult),\n            nn.LeakyReLU(0.2, True)\n        )\n\n        self.final = nn.Sequential(\n            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw))  # output 1 channel prediction map\n        # self.main = nn.Sequential(*sequence)\n\n    def forward(self, input):\n        \"\"\"Standard forward.\"\"\"\n        # return self.main(input)\n        \n        features = []\n\n        f = self.head(input)\n        features.append(f) \n\n        for i in range(self.n_layers-1):\n            f = self.body[i](f)\n            features.append(f) \n\n        beforlastF = self.beforlast(f)\n        final = self.final(beforlastF)\n\n        return features, final\n\n"
  },
  {
    "path": "RestoreFormer/modules/losses/__init__.py",
    "content": "\n\n"
  },
  {
    "path": "RestoreFormer/modules/losses/lpips.py",
    "content": "\"\"\"Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models\"\"\"\n\nimport torch\nimport torch.nn as nn\nfrom torchvision import models\nfrom collections import namedtuple\n\nfrom RestoreFormer.util import get_ckpt_path\n\n\nclass LPIPS(nn.Module):\n    # Learned perceptual metric\n    def __init__(self, use_dropout=True, style_weight=0.):\n        super().__init__()\n        self.scaling_layer = ScalingLayer()\n        self.chns = [64, 128, 256, 512, 512]  # vg16 features\n        self.net = vgg16(pretrained=True, requires_grad=False)\n        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)\n        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)\n        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)\n        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)\n        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)\n        self.load_from_pretrained()\n        for param in self.parameters():\n            param.requires_grad = False\n\n        self.style_weight = style_weight\n\n    def load_from_pretrained(self, name=\"vgg_lpips\"):\n        ckpt = get_ckpt_path(name, \"experiments/pretrained_models/lpips\")\n        self.load_state_dict(torch.load(ckpt, map_location=torch.device(\"cpu\")), strict=False)\n        print(\"loaded pretrained LPIPS loss from {}\".format(ckpt))\n\n    @classmethod\n    def from_pretrained(cls, name=\"vgg_lpips\"):\n        if name is not \"vgg_lpips\":\n            raise NotImplementedError\n        model = cls()\n        ckpt = get_ckpt_path(name)\n        model.load_state_dict(torch.load(ckpt, map_location=torch.device(\"cpu\")), strict=False)\n        return model\n\n    def forward(self, input, target):\n        in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))\n        outs0, outs1 = self.net(in0_input), self.net(in1_input)\n        feats0, feats1, diffs = {}, {}, {}\n        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]\n        style_loss = torch.tensor([0.0]).to(input.device)\n        for kk in range(len(self.chns)):\n            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])\n            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2\n            if self.style_weight > 0.:\n                style_loss = style_loss + torch.mean((self._gram_mat(feats0[kk]) - \n                             self._gram_mat(feats1[kk])) ** 2)\n\n        res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]\n        val = res[0]\n        for l in range(1, len(self.chns)):\n            val += res[l]\n\n        return val, style_loss * self.style_weight\n\n    def _gram_mat(self, x):\n        \"\"\"Calculate Gram matrix.\n\n        Args:\n            x (torch.Tensor): Tensor with shape of (n, c, h, w).\n\n        Returns:\n            torch.Tensor: Gram matrix.\n        \"\"\"\n        n, c, h, w = x.size()\n        features = x.view(n, c, w * h)\n        features_t = features.transpose(1, 2)\n        gram = features.bmm(features_t) / (c * h * w)\n        return gram\n\n\nclass ScalingLayer(nn.Module):\n    def __init__(self):\n        super(ScalingLayer, self).__init__()\n        self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])\n        self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])\n\n    def forward(self, inp):\n        return (inp - self.shift) / 
self.scale\n\n\nclass NetLinLayer(nn.Module):\n    \"\"\" A single linear layer which does a 1x1 conv \"\"\"\n    def __init__(self, chn_in, chn_out=1, use_dropout=False):\n        super(NetLinLayer, self).__init__()\n        layers = [nn.Dropout(), ] if (use_dropout) else []\n        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]\n        self.model = nn.Sequential(*layers)\n\n\nclass vgg16(torch.nn.Module):\n    def __init__(self, requires_grad=False, pretrained=True):\n        super(vgg16, self).__init__()\n        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        self.N_slices = 5\n        for x in range(4):\n            self.slice1.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(4, 9):\n            self.slice2.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(9, 16):\n            self.slice3.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(16, 23):\n            self.slice4.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(23, 30):\n            self.slice5.add_module(str(x), vgg_pretrained_features[x])\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        h = self.slice1(X)\n        h_relu1_2 = h\n        h = self.slice2(h)\n        h_relu2_2 = h\n        h = self.slice3(h)\n        h_relu3_3 = h\n        h = self.slice4(h)\n        h_relu4_3 = h\n        h = self.slice5(h)\n        h_relu5_3 = h\n        vgg_outputs = namedtuple(\"VggOutputs\", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])\n        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)\n        return out\n\n\ndef normalize_tensor(x,eps=1e-10):\n    norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))\n    return x/(norm_factor+eps)\n\n\ndef spatial_average(x, keepdim=True):\n    return x.mean([2,3],keepdim=keepdim)\n\n"
  },
  {
    "path": "RestoreFormer/modules/losses/vqperceptual.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom copy import deepcopy\n\nfrom RestoreFormer.modules.losses.lpips import LPIPS\nfrom RestoreFormer.modules.discriminator.model import NLayerDiscriminator, weights_init\nfrom RestoreFormer.modules.vqvae.facial_component_discriminator import FacialComponentDiscriminator\nfrom basicsr.losses.losses import GANLoss, L1Loss\nfrom RestoreFormer.modules.vqvae.arcface_arch import ResNetArcFace\n\n\nclass DummyLoss(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n\ndef adopt_weight(weight, global_step, threshold=0, value=0.):\n    if global_step < threshold:\n        weight = value\n    return weight\n\n\ndef hinge_d_loss(logits_real, logits_fake):\n    loss_real = torch.mean(F.relu(1. - logits_real))\n    loss_fake = torch.mean(F.relu(1. + logits_fake))\n    d_loss = 0.5 * (loss_real + loss_fake)\n    return d_loss\n\n\ndef vanilla_d_loss(logits_real, logits_fake):\n    d_loss = 0.5 * (\n        torch.mean(torch.nn.functional.softplus(-logits_real)) +\n        torch.mean(torch.nn.functional.softplus(logits_fake)))\n    return d_loss\n\n\nclass VQLPIPSWithDiscriminatorWithCompWithIdentity(nn.Module):\n    def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,\n                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,\n                 perceptual_weight=1.0, use_actnorm=False, \n                 disc_ndf=64, disc_loss=\"hinge\", comp_weight=0.0, comp_style_weight=0.0, \n                 identity_weight=0.0, comp_disc_loss='vanilla', lpips_style_weight=0.0,\n                 identity_model_path=None, **ignore_kwargs):\n        super().__init__()\n        assert disc_loss in [\"hinge\", \"vanilla\"]\n        self.codebook_weight = codebook_weight\n        self.pixel_weight = pixelloss_weight\n        self.perceptual_loss = LPIPS(style_weight=lpips_style_weight).eval()\n        self.perceptual_weight = perceptual_weight\n\n        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,\n                                                 n_layers=disc_num_layers,\n                                                 use_actnorm=use_actnorm,\n                                                 ndf=disc_ndf\n                                                 ).apply(weights_init)\n        if comp_weight > 0:\n            self.net_d_left_eye = FacialComponentDiscriminator()\n            self.net_d_right_eye = FacialComponentDiscriminator()\n            self.net_d_mouth = FacialComponentDiscriminator()\n            print(f'Use components discrimination')\n\n            self.cri_component = GANLoss(gan_type=comp_disc_loss, \n                                         real_label_val=1.0, \n                                         fake_label_val=0.0, \n                                         loss_weight=comp_weight)\n\n            if comp_style_weight > 0.:\n                self.cri_style = L1Loss(loss_weight=comp_style_weight, reduction='mean')\n\n        if identity_weight > 0:\n            self.identity = ResNetArcFace(block = 'IRBlock', \n                                          layers = [2, 2, 2, 2],\n                                          use_se = False)\n            print(f'Use identity loss')\n            if identity_model_path is not None:\n                sd = torch.load(identity_model_path, map_location=\"cpu\")\n                for k, v in deepcopy(sd).items():\n                    if k.startswith('module.'):\n                        
sd[k[7:]] = v\n                        sd.pop(k)\n                self.identity.load_state_dict(sd, strict=True)\n\n            for param in self.identity.parameters():\n                param.requires_grad = False\n\n            self.cri_identity = L1Loss(loss_weight=identity_weight, reduction='mean')\n\n\n        self.discriminator_iter_start = disc_start\n        if disc_loss == \"hinge\":\n            self.disc_loss = hinge_d_loss\n        elif disc_loss == \"vanilla\":\n            self.disc_loss = vanilla_d_loss\n        else:\n            raise ValueError(f\"Unknown GAN loss '{disc_loss}'.\")\n        print(f\"VQLPIPSWithDiscriminatorWithCompWithIdentity running with {disc_loss} loss.\")\n        self.disc_factor = disc_factor\n        self.discriminator_weight = disc_weight\n        self.comp_weight = comp_weight\n        self.comp_style_weight = comp_style_weight\n        self.identity_weight = identity_weight\n        self.lpips_style_weight = lpips_style_weight\n\n    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):\n        if last_layer is not None:\n            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]\n            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]\n        else:\n            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]\n            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]\n\n        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)\n        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()\n        d_weight = d_weight * self.discriminator_weight\n        return d_weight\n\n    def _gram_mat(self, x):\n        \"\"\"Calculate Gram matrix.\n\n        Args:\n            x (torch.Tensor): Tensor with shape of (n, c, h, w).\n\n        Returns:\n            torch.Tensor: Gram matrix.\n        \"\"\"\n        n, c, h, w = x.size()\n        features = x.view(n, c, w * h)\n        features_t = features.transpose(1, 2)\n        gram = features.bmm(features_t) / (c * h * w)\n        return gram\n\n    def gray_resize_for_identity(self, out, size=128):\n        out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])\n        out_gray = out_gray.unsqueeze(1)\n        out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)\n        return out_gray\n\n    def forward(self, codebook_loss, gts, reconstructions, components, optimizer_idx,\n                global_step, last_layer=None, split=\"train\"):\n\n        # now the GAN part\n        if optimizer_idx == 0:\n            rec_loss = (torch.abs(gts.contiguous() - reconstructions.contiguous())) * self.pixel_weight\n            if self.perceptual_weight > 0:\n                p_loss, p_style_loss = self.perceptual_loss(gts.contiguous(), reconstructions.contiguous())\n                rec_loss = rec_loss + self.perceptual_weight * p_loss\n            else:\n                p_loss = torch.tensor([0.0])\n                p_style_loss = torch.tensor([0.0])\n\n            nll_loss = rec_loss\n            #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]\n            nll_loss = torch.mean(nll_loss)\n\n        \n            # generator update\n            \n            logits_fake = self.discriminator(reconstructions.contiguous())\n            g_loss = -torch.mean(logits_fake)\n\n            try:\n                d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)\n    
        except RuntimeError:\n                assert not self.training\n                d_weight = torch.tensor(0.0)\n\n            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)\n            \n            loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() + p_style_loss\n\n            log = {\n                   \"{}/quant_loss\".format(split): codebook_loss.detach().mean(),\n                   \"{}/nll_loss\".format(split): nll_loss.detach().mean(),\n                   \"{}/rec_loss\".format(split): rec_loss.detach().mean(),\n                   \"{}/p_loss\".format(split): p_loss.detach().mean(),\n                   \"{}/p_style_loss\".format(split): p_style_loss.detach().mean(),\n                   \"{}/d_weight\".format(split): d_weight.detach(),\n                   \"{}/disc_factor\".format(split): torch.tensor(disc_factor),\n                   \"{}/g_loss\".format(split): g_loss.detach().mean(),\n                   }\n\n            if self.comp_weight > 0. and components is not None and self.discriminator_iter_start < global_step:\n                fake_left_eye, fake_left_eye_feats = self.net_d_left_eye(components['left_eyes'], return_feats=True)\n                comp_g_loss = self.cri_component(fake_left_eye, True, is_disc=False)\n                loss = loss + comp_g_loss \n                log[\"{}/g_left_loss\".format(split)] = comp_g_loss.detach()\n\n                fake_right_eye, fake_right_eye_feats = self.net_d_right_eye(components['right_eyes'], return_feats=True)\n                comp_g_loss = self.cri_component(fake_right_eye, True, is_disc=False)\n                loss = loss + comp_g_loss \n                log[\"{}/g_right_loss\".format(split)] = comp_g_loss.detach()\n\n                fake_mouth, fake_mouth_feats = self.net_d_mouth(components['mouths'], return_feats=True)\n                comp_g_loss = self.cri_component(fake_mouth, True, is_disc=False)\n                loss = loss + comp_g_loss \n                log[\"{}/g_mouth_loss\".format(split)] = comp_g_loss.detach()\n\n                if self.comp_style_weight > 0.:\n                    _, real_left_eye_feats = self.net_d_left_eye(components['left_eyes_gt'], return_feats=True)\n                    _, real_right_eye_feats = self.net_d_right_eye(components['right_eyes_gt'], return_feats=True)\n                    _, real_mouth_feats = self.net_d_mouth(components['mouths_gt'], return_feats=True)\n\n                    def _comp_style(feat, feat_gt, criterion):\n                        return criterion(self._gram_mat(feat[0]), self._gram_mat(\n                            feat_gt[0].detach())) * 0.5 + criterion(self._gram_mat(\n                            feat[1]), self._gram_mat(feat_gt[1].detach()))\n\n                    comp_style_loss = 0.\n                    comp_style_loss = comp_style_loss + _comp_style(fake_left_eye_feats, real_left_eye_feats, self.cri_style)\n                    comp_style_loss = comp_style_loss + _comp_style(fake_right_eye_feats, real_right_eye_feats, self.cri_style)\n                    comp_style_loss = comp_style_loss + _comp_style(fake_mouth_feats, real_mouth_feats, self.cri_style)\n                    loss = loss + comp_style_loss \n                    log[\"{}/comp_style_loss\".format(split)] = comp_style_loss.detach()\n\n            if self.identity_weight > 0. 
and self.discriminator_iter_start < global_step:\n                self.identity.eval()\n                out_gray = self.gray_resize_for_identity(reconstructions)\n                gt_gray = self.gray_resize_for_identity(gts)\n                \n                identity_gt = self.identity(gt_gray).detach()\n                identity_out = self.identity(out_gray)\n\n                identity_loss = self.cri_identity(identity_out, identity_gt)\n                loss = loss + identity_loss \n                log[\"{}/identity_loss\".format(split)] = identity_loss.detach()\n\n            log[\"{}/total_loss\".format(split)] = loss.clone().detach().mean()\n\n            return loss, log\n\n        if optimizer_idx == 1:\n            # second pass for discriminator update\n            \n            logits_real = self.discriminator(gts.contiguous().detach())\n            logits_fake = self.discriminator(reconstructions.contiguous().detach())\n\n            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)\n            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)\n\n            log = {\"{}/disc_loss\".format(split): d_loss.clone().detach().mean(),\n                   \"{}/logits_real\".format(split): logits_real.detach().mean(),\n                   \"{}/logits_fake\".format(split): logits_fake.detach().mean()\n                   }\n            return d_loss, log\n\n        # left eye\n        if optimizer_idx == 2:\n            # third pass for discriminator update\n            fake_d_pred, _ = self.net_d_left_eye(components['left_eyes'].detach())\n            real_d_pred, _ = self.net_d_left_eye(components['left_eyes_gt'])\n            d_loss = self.cri_component(real_d_pred, True, is_disc=True) + self.cri_component(fake_d_pred, False, is_disc=True)\n\n            log = {\"{}/d_left_loss\".format(split): d_loss.clone().detach().mean()}\n            return d_loss, log\n\n        # right eye\n        if optimizer_idx == 3:\n            # fourth pass for discriminator update\n            fake_d_pred, _ = self.net_d_right_eye(components['right_eyes'].detach())\n            real_d_pred, _ = self.net_d_right_eye(components['right_eyes_gt'])\n            d_loss = self.cri_component(real_d_pred, True, is_disc=True) + self.cri_component(fake_d_pred, False, is_disc=True)\n\n            log = {\"{}/d_right_loss\".format(split): d_loss.clone().detach().mean()}\n            return d_loss, log\n\n        # mouth\n        if optimizer_idx == 4:\n            # fifth pass for discriminator update\n            fake_d_pred, _ = self.net_d_mouth(components['mouths'].detach())\n            real_d_pred, _ = self.net_d_mouth(components['mouths_gt'])\n            d_loss = self.cri_component(real_d_pred, True, is_disc=True) + self.cri_component(fake_d_pred, False, is_disc=True)\n\n            log = {\"{}/d_mouth_loss\".format(split): d_loss.clone().detach().mean()}\n            return d_loss, log\n"
  },
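  {
    "path": "examples/vqperceptual_helpers_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original RestoreFormer code: exercises the\nstandalone helpers in RestoreFormer/modules/losses/vqperceptual.py. `adopt_weight`\nzeroes the adversarial weight until `disc_start` is reached, and `hinge_d_loss` /\n`vanilla_d_loss` are the two discriminator objectives selected via `disc_loss`.\nThe repo (and its basicsr dependency) is assumed to be importable.\"\"\"\n\nimport torch\n\nfrom RestoreFormer.modules.losses.vqperceptual import (adopt_weight, hinge_d_loss,\n                                                       vanilla_d_loss)\n\n# The discriminator only contributes after the warm-up threshold.\nassert adopt_weight(0.8, global_step=100, threshold=30001) == 0.0\nassert adopt_weight(0.8, global_step=30001, threshold=30001) == 0.8\n\n# Dummy logits: real samples should score high, fakes low.\nlogits_real = torch.tensor([2.0, 1.5])\nlogits_fake = torch.tensor([-1.0, -0.5])\n\nprint('hinge d-loss:', hinge_d_loss(logits_real, logits_fake).item())\nprint('vanilla d-loss:', vanilla_d_loss(logits_real, logits_fake).item())\n"
  },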
  {
    "path": "RestoreFormer/modules/util.py",
    "content": "import torch\nimport torch.nn as nn\n\n\ndef count_params(model):\n    total_params = sum(p.numel() for p in model.parameters())\n    return total_params\n\n\nclass ActNorm(nn.Module):\n    def __init__(self, num_features, logdet=False, affine=True,\n                 allow_reverse_init=False):\n        assert affine\n        super().__init__()\n        self.logdet = logdet\n        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))\n        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))\n        self.allow_reverse_init = allow_reverse_init\n\n        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))\n\n    def initialize(self, input):\n        with torch.no_grad():\n            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)\n            mean = (\n                flatten.mean(1)\n                .unsqueeze(1)\n                .unsqueeze(2)\n                .unsqueeze(3)\n                .permute(1, 0, 2, 3)\n            )\n            std = (\n                flatten.std(1)\n                .unsqueeze(1)\n                .unsqueeze(2)\n                .unsqueeze(3)\n                .permute(1, 0, 2, 3)\n            )\n\n            self.loc.data.copy_(-mean)\n            self.scale.data.copy_(1 / (std + 1e-6))\n\n    def forward(self, input, reverse=False):\n        if reverse:\n            return self.reverse(input)\n        if len(input.shape) == 2:\n            input = input[:,:,None,None]\n            squeeze = True\n        else:\n            squeeze = False\n\n        _, _, height, width = input.shape\n\n        if self.training and self.initialized.item() == 0:\n            self.initialize(input)\n            self.initialized.fill_(1)\n\n        h = self.scale * (input + self.loc)\n\n        if squeeze:\n            h = h.squeeze(-1).squeeze(-1)\n\n        if self.logdet:\n            log_abs = torch.log(torch.abs(self.scale))\n            logdet = height*width*torch.sum(log_abs)\n            logdet = logdet * torch.ones(input.shape[0]).to(input)\n            return h, logdet\n\n        return h\n\n    def reverse(self, output):\n        if self.training and self.initialized.item() == 0:\n            if not self.allow_reverse_init:\n                raise RuntimeError(\n                    \"Initializing ActNorm in reverse direction is \"\n                    \"disabled by default. Use allow_reverse_init=True to enable.\"\n                )\n            else:\n                self.initialize(output)\n                self.initialized.fill_(1)\n\n        if len(output.shape) == 2:\n            output = output[:,:,None,None]\n            squeeze = True\n        else:\n            squeeze = False\n\n        h = output / self.scale - self.loc\n\n        if squeeze:\n            h = h.squeeze(-1).squeeze(-1)\n        return h\n\n\nclass Attention2DConv(nn.Module):\n    \"\"\"to replace the convolutional architecture entirely\"\"\"\n    def __init__(self):\n        super().__init__()\n"
  },
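  {
    "path": "examples/actnorm_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original RestoreFormer code: shows the\ndata-dependent initialisation of ActNorm from RestoreFormer/modules/util.py.\nThe first training-mode forward pass fits `loc`/`scale` so the output is roughly\nzero-mean and unit-variance per channel, and `reverse=True` inverts the map.\"\"\"\n\nimport torch\n\nfrom RestoreFormer.modules.util import ActNorm\n\ntorch.manual_seed(0)\nlayer = ActNorm(num_features=8)\nx = torch.randn(4, 8, 16, 16) * 3.0 + 1.0\n\nlayer.train()\ny = layer(x)  # the first call triggers initialize()\nflat = y.permute(1, 0, 2, 3).reshape(8, -1)\nprint('per-channel mean (~0):', flat.mean(1))\nprint('per-channel std  (~1):', flat.std(1))\n\n# reverse() undoes the affine map up to numerical error\nx_rec = layer(y, reverse=True)\nprint('max reconstruction error:', (x - x_rec).abs().max().item())\n"
  },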
  {
    "path": "RestoreFormer/modules/vqvae/arcface_arch.py",
    "content": "import torch.nn as nn\n\nfrom basicsr.utils.registry import ARCH_REGISTRY\n\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(BasicBlock, self).__init__()\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass IRBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):\n        super(IRBlock, self).__init__()\n        self.bn0 = nn.BatchNorm2d(inplanes)\n        self.conv1 = conv3x3(inplanes, inplanes)\n        self.bn1 = nn.BatchNorm2d(inplanes)\n        self.prelu = nn.PReLU()\n        self.conv2 = conv3x3(inplanes, planes, stride)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.downsample = downsample\n        self.stride = stride\n        self.use_se = use_se\n        if self.use_se:\n            self.se = SEBlock(planes)\n\n    def forward(self, x):\n        residual = x\n        out = self.bn0(x)\n        out = self.conv1(out)\n        out = self.bn1(out)\n        out = self.prelu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        if self.use_se:\n            out = self.se(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.prelu(out)\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(Bottleneck, self).__init__()\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)\n        self.bn3 = nn.BatchNorm2d(planes * self.expansion)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass SEBlock(nn.Module):\n\n    def __init__(self, channel, reduction=16):\n        super(SEBlock, self).__init__()\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.fc = nn.Sequential(\n        
    nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),\n            nn.Sigmoid())\n\n    def forward(self, x):\n        b, c, _, _ = x.size()\n        y = self.avg_pool(x).view(b, c)\n        y = self.fc(y).view(b, c, 1, 1)\n        return x * y\n\n\n@ARCH_REGISTRY.register()\nclass ResNetArcFace(nn.Module):\n\n    def __init__(self, block, layers, use_se=True):\n        if block == 'IRBlock':\n            block = IRBlock\n        self.inplanes = 64\n        self.use_se = use_se\n        super(ResNetArcFace, self).__init__()\n        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(64)\n        self.prelu = nn.PReLU()\n        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n        self.bn4 = nn.BatchNorm2d(512)\n        self.dropout = nn.Dropout()\n        self.fc5 = nn.Linear(512 * 8 * 8, 512)\n        self.bn5 = nn.BatchNorm1d(512)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.xavier_normal_(m.weight)\n            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.xavier_normal_(m.weight)\n                nn.init.constant_(m.bias, 0)\n\n    def _make_layer(self, block, planes, blocks, stride=1):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),\n                nn.BatchNorm2d(planes * block.expansion),\n            )\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))\n        self.inplanes = planes\n        for _ in range(1, blocks):\n            layers.append(block(self.inplanes, planes, use_se=self.use_se))\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.prelu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n        x = self.bn4(x)\n        x = self.dropout(x)\n        x = x.view(x.size(0), -1)\n        x = self.fc5(x)\n        x = self.bn5(x)\n\n        return x\n"
  },
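  {
    "path": "examples/arcface_embedding_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original RestoreFormer code: runs the\nArcFace backbone from RestoreFormer/modules/vqvae/arcface_arch.py on a random\ngrayscale crop. The identity loss feeds it 1x128x128 inputs (see\ngray_resize_for_identity in vqperceptual.py), which matches fc5's 512*8*8 input\nsize. Weights are random here, so only the embedding shape is checked.\"\"\"\n\nimport torch\n\nfrom RestoreFormer.modules.vqvae.arcface_arch import ResNetArcFace\n\n# The same configuration vqperceptual.py uses for its identity loss.\nnet = ResNetArcFace(block='IRBlock', layers=[2, 2, 2, 2], use_se=False)\nnet.eval()  # eval mode so the BatchNorm layers accept a single sample\n\nwith torch.no_grad():\n    gray = torch.rand(1, 1, 128, 128)\n    emb = net(gray)\nprint(emb.shape)  # torch.Size([1, 512])\n"
  },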
  {
    "path": "RestoreFormer/modules/vqvae/facial_component_discriminator.py",
    "content": "import math\nimport random\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,\n                                          StyleGAN2Generator)\nfrom basicsr.ops.fused_act import FusedLeakyReLU\nfrom basicsr.utils.registry import ARCH_REGISTRY\n\n\n@ARCH_REGISTRY.register()\nclass FacialComponentDiscriminator(nn.Module):\n\n    def __init__(self):\n        super(FacialComponentDiscriminator, self).__init__()\n\n        self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)\n        self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)\n        self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)\n        self.conv4 = ConvLayer(128, 256, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)\n        self.conv5 = ConvLayer(256, 256, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)\n        self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)\n\n    def forward(self, x, return_feats=False):\n        feat = self.conv1(x)\n        feat = self.conv3(self.conv2(feat))\n        rlt_feats = []\n        if return_feats:\n            rlt_feats.append(feat.clone())\n        feat = self.conv5(self.conv4(feat))\n        if return_feats:\n            rlt_feats.append(feat.clone())\n        out = self.final_conv(feat)\n\n        if return_feats:\n            return out, rlt_feats\n        else:\n            return out, None\n"
  },
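  {
    "path": "examples/facial_component_discriminator_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original RestoreFormer code: shape-checks\nFacialComponentDiscriminator on 80x80 eye crops, the size get_roi_regions()\nproduces. Assumes a basicsr install whose StyleGAN2 fused ops are usable (they\nmay require the compiled CUDA extension).\"\"\"\n\nimport torch\n\nfrom RestoreFormer.modules.vqvae.facial_component_discriminator import FacialComponentDiscriminator\n\ndisc = FacialComponentDiscriminator()\neyes = torch.rand(2, 3, 80, 80)\n\nscore, feats = disc(eyes, return_feats=True)\nprint(score.shape)  # (2, 1, 20, 20) after two downsampling stages\nprint([tuple(f.shape) for f in feats])  # 128ch at 40x40, 256ch at 20x20\n"
  },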
  {
    "path": "RestoreFormer/modules/vqvae/utils.py",
    "content": "from torchvision.ops import roi_align\nimport torch\n\ndef get_roi_regions(gt, output, loc_left_eyes, loc_right_eyes, loc_mouths,\n                    face_ratio=1, eye_out_size=80, mouth_out_size=120):\n    # hard code\n    eye_out_size *= face_ratio\n    mouth_out_size *= face_ratio\n\n    eye_out_size = int(eye_out_size)\n    mouth_out_size = int(mouth_out_size)\n\n    rois_eyes = []\n    rois_mouths = []\n    for b in range(loc_left_eyes.size(0)):  # loop for batch size\n        # left eye and right eye\n        img_inds = loc_left_eyes.new_full((2, 1), b)\n        bbox = torch.stack([loc_left_eyes[b, :], loc_right_eyes[b, :]], dim=0)  # shape: (2, 4)\n        rois = torch.cat([img_inds, bbox], dim=-1)  # shape: (2, 5)\n        rois_eyes.append(rois)\n        # mouse\n        img_inds = loc_left_eyes.new_full((1, 1), b)\n        rois = torch.cat([img_inds, loc_mouths[b:b + 1, :]], dim=-1)  # shape: (1, 5)\n        rois_mouths.append(rois)\n\n    rois_eyes = torch.cat(rois_eyes, 0)\n    rois_mouths = torch.cat(rois_mouths, 0)\n\n    # real images\n    all_eyes = roi_align(gt, boxes=rois_eyes, output_size=eye_out_size) * face_ratio\n    left_eyes_gt = all_eyes[0::2, :, :, :]\n    right_eyes_gt = all_eyes[1::2, :, :, :]\n    mouths_gt = roi_align(gt, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio\n    # output\n    all_eyes = roi_align(output, boxes=rois_eyes, output_size=eye_out_size) * face_ratio\n    left_eyes = all_eyes[0::2, :, :, :]\n    right_eyes = all_eyes[1::2, :, :, :]\n    mouths = roi_align(output, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio\n\n    return {'left_eyes_gt': left_eyes_gt, 'right_eyes_gt': right_eyes_gt, 'mouths_gt': mouths_gt, \n            'left_eyes': left_eyes, 'right_eyes': right_eyes, 'mouths': mouths}\n"
  },
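  {
    "path": "examples/get_roi_regions_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original RestoreFormer code: calls\nget_roi_regions() from RestoreFormer/modules/vqvae/utils.py with dummy boxes.\nEach loc tensor holds one (x1, y1, x2, y2) box per image, prefixed internally\nwith the batch index as torchvision's roi_align expects; crops come back at\n80x80 (eyes) and 120x120 (mouths) for face_ratio=1. The box values below are\nmade up.\"\"\"\n\nimport torch\n\nfrom RestoreFormer.modules.vqvae.utils import get_roi_regions\n\nb = 2\ngt = torch.rand(b, 3, 512, 512)\noutput = torch.rand(b, 3, 512, 512)\n\nloc_left_eyes = torch.tensor([[120., 180., 220., 260.]]).repeat(b, 1)\nloc_right_eyes = torch.tensor([[290., 180., 390., 260.]]).repeat(b, 1)\nloc_mouths = torch.tensor([[180., 330., 330., 430.]]).repeat(b, 1)\n\nrois = get_roi_regions(gt, output, loc_left_eyes, loc_right_eyes, loc_mouths)\nfor name, crop in rois.items():\n    print(name, tuple(crop.shape))  # eyes: (2, 3, 80, 80), mouths: (2, 3, 120, 120)\n"
  },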
  {
    "path": "RestoreFormer/modules/vqvae/vqvae_arch.py",
    "content": "import torch\nimport torch.nn as nn\nimport random\nimport math\nimport torch.nn.functional as F\nimport numpy as np\n# from basicsr.utils.registry import ARCH_REGISTRY\nimport torch.nn.utils.spectral_norm as SpectralNorm\nimport RestoreFormer.distributed as dist_fn\n\nclass VectorQuantizer(nn.Module):\n    \"\"\"\n    see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py\n    ____________________________________________\n    Discretization bottleneck part of the VQ-VAE.\n    Inputs:\n    - n_e : number of embeddings\n    - e_dim : dimension of embedding\n    - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2\n    _____________________________________________\n    \"\"\"\n\n    def __init__(self, n_e, e_dim, beta):\n        super(VectorQuantizer, self).__init__()\n        self.n_e = n_e\n        self.e_dim = e_dim\n        self.beta = beta\n\n        self.embedding = nn.Embedding(self.n_e, self.e_dim)\n        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)\n\n    def forward(self, z):\n        \"\"\"\n        Inputs the output of the encoder network z and maps it to a discrete\n        one-hot vector that is the index of the closest embedding vector e_j\n        z (continuous) -> z_q (discrete)\n        z.shape = (batch, channel, height, width)\n        quantization pipeline:\n            1. get encoder input (B,C,H,W)\n            2. flatten input to (B*H*W,C)\n        \"\"\"\n        # reshape z -> (batch, height, width, channel) and flatten\n        z = z.permute(0, 2, 3, 1).contiguous()\n        z_flattened = z.view(-1, self.e_dim)\n        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z\n\n        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \\\n            torch.sum(self.embedding.weight**2, dim=1) - 2 * \\\n            torch.matmul(z_flattened, self.embedding.weight.t())\n\n        ## could possible replace this here\n        # #\\start...\n        # find closest encodings\n\n        min_value, min_encoding_indices = torch.min(d, dim=1)\n\n        min_encoding_indices = min_encoding_indices.unsqueeze(1)\n\n        min_encodings = torch.zeros(\n            min_encoding_indices.shape[0], self.n_e).to(z)\n        min_encodings.scatter_(1, min_encoding_indices, 1)\n\n        # dtype min encodings: torch.float32\n        # min_encodings shape: torch.Size([2048, 512])\n        # min_encoding_indices.shape: torch.Size([2048, 1])\n\n        # get quantized latent vectors\n        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)\n        #.........\\end\n\n        # with:\n        # .........\\start\n        #min_encoding_indices = torch.argmin(d, dim=1)\n        #z_q = self.embedding(min_encoding_indices)\n        # ......\\end......... 
(TODO)\n\n        # compute loss for embedding\n        loss = torch.mean((z_q.detach()-z)**2) + self.beta * \\\n            torch.mean((z_q - z.detach()) ** 2)\n\n        # preserve gradients\n        z_q = z + (z_q - z).detach()\n\n        # perplexity\n        \n        e_mean = torch.mean(min_encodings, dim=0)\n        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))\n\n        # reshape back to match original input shape\n        z_q = z_q.permute(0, 3, 1, 2).contiguous()\n\n        return z_q, loss, (perplexity, min_encodings, min_encoding_indices, d)\n\n    def get_codebook_entry(self, indices, shape):\n        # shape specifying (batch, height, width, channel)\n        # TODO: check for more easy handling with nn.Embedding\n        min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)\n        min_encodings.scatter_(1, indices[:,None], 1)\n\n        # get quantized latent vectors\n        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)\n\n        if shape is not None:\n            z_q = z_q.view(shape)\n\n            # reshape back to match original input shape\n            z_q = z_q.permute(0, 3, 1, 2).contiguous()\n\n        return z_q\n\n# pytorch_diffusion + derived encoder decoder\ndef nonlinearity(x):\n    # swish\n    return x*torch.sigmoid(x)\n\n\ndef Normalize(in_channels):\n    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)\n\n\nclass Upsample(nn.Module):\n    def __init__(self, in_channels, with_conv):\n        super().__init__()\n        self.with_conv = with_conv\n        if self.with_conv:\n            self.conv = torch.nn.Conv2d(in_channels,\n                                        in_channels,\n                                        kernel_size=3,\n                                        stride=1,\n                                        padding=1)\n\n    def forward(self, x):\n        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode=\"nearest\")\n        if self.with_conv:\n            x = self.conv(x)\n        return x\n\n\nclass Downsample(nn.Module):\n    def __init__(self, in_channels, with_conv):\n        super().__init__()\n        self.with_conv = with_conv\n        if self.with_conv:\n            # no asymmetric padding in torch conv, must do it ourselves\n            self.conv = torch.nn.Conv2d(in_channels,\n                                        in_channels,\n                                        kernel_size=3,\n                                        stride=2,\n                                        padding=0)\n\n    def forward(self, x):\n        if self.with_conv:\n            pad = (0,1,0,1)\n            x = torch.nn.functional.pad(x, pad, mode=\"constant\", value=0)\n            x = self.conv(x)\n        else:\n            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)\n        return x\n\n\nclass ResnetBlock(nn.Module):\n    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,\n                 dropout, temb_channels=512):\n        super().__init__()\n        self.in_channels = in_channels\n        out_channels = in_channels if out_channels is None else out_channels\n        self.out_channels = out_channels\n        self.use_conv_shortcut = conv_shortcut\n\n        self.norm1 = Normalize(in_channels)\n        self.conv1 = torch.nn.Conv2d(in_channels,\n                                     out_channels,\n                                     kernel_size=3,\n                                     
stride=1,\n                                     padding=1)\n        if temb_channels > 0:\n            self.temb_proj = torch.nn.Linear(temb_channels,\n                                             out_channels)\n        self.norm2 = Normalize(out_channels)\n        self.dropout = torch.nn.Dropout(dropout)\n        self.conv2 = torch.nn.Conv2d(out_channels,\n                                     out_channels,\n                                     kernel_size=3,\n                                     stride=1,\n                                     padding=1)\n        if self.in_channels != self.out_channels:\n            if self.use_conv_shortcut:\n                self.conv_shortcut = torch.nn.Conv2d(in_channels,\n                                                     out_channels,\n                                                     kernel_size=3,\n                                                     stride=1,\n                                                     padding=1)\n            else:\n                self.nin_shortcut = torch.nn.Conv2d(in_channels,\n                                                    out_channels,\n                                                    kernel_size=1,\n                                                    stride=1,\n                                                    padding=0)\n\n    def forward(self, x, temb):\n        h = x\n        h = self.norm1(h)\n        h = nonlinearity(h)\n        h = self.conv1(h)\n\n        if temb is not None:\n            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]\n\n        h = self.norm2(h)\n        h = nonlinearity(h)\n        h = self.dropout(h)\n        h = self.conv2(h)\n\n        if self.in_channels != self.out_channels:\n            if self.use_conv_shortcut:\n                x = self.conv_shortcut(x)\n            else:\n                x = self.nin_shortcut(x)\n\n        return x+h\n\n\nclass MultiHeadAttnBlock(nn.Module):\n    def __init__(self, in_channels, head_size=1):\n        super().__init__()\n        self.in_channels = in_channels\n        self.head_size = head_size\n        self.att_size = in_channels // head_size\n        assert(in_channels % head_size == 0), 'in_channels must be divisible by head_size.'\n\n        self.norm1 = Normalize(in_channels)\n        self.norm2 = Normalize(in_channels)\n\n        self.q = torch.nn.Conv2d(in_channels,\n                                 in_channels,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n        self.k = torch.nn.Conv2d(in_channels,\n                                 in_channels,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n        self.v = torch.nn.Conv2d(in_channels,\n                                 in_channels,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n        self.proj_out = torch.nn.Conv2d(in_channels,\n                                        in_channels,\n                                        kernel_size=1,\n                                        stride=1,\n                                        padding=0)\n        self.num = 0\n\n    def forward(self, x, y=None):\n        h_ = x\n        h_ = self.norm1(h_)\n        if y is None:\n            y = h_\n        else:\n            y = self.norm2(y)\n\n        q = self.q(y)\n        k = 
self.k(h_)\n        v = self.v(h_)\n\n        # compute attention\n        b,c,h,w = q.shape\n        q = q.reshape(b, self.head_size, self.att_size ,h*w) \n        q = q.permute(0, 3, 1, 2) # b, hw, head, att\n\n        k = k.reshape(b, self.head_size, self.att_size ,h*w) \n        k = k.permute(0, 3, 1, 2)\n\n        v = v.reshape(b, self.head_size, self.att_size ,h*w) \n        v = v.permute(0, 3, 1, 2)\n\n\n        q = q.transpose(1, 2)\n        v = v.transpose(1, 2)\n        k = k.transpose(1, 2).transpose(2,3)\n\n        scale = int(self.att_size)**(-0.5)\n        q.mul_(scale)\n        w_ = torch.matmul(q, k)\n        w_ = F.softmax(w_, dim=3)\n\n        w_ = w_.matmul(v)\n\n        w_ = w_.transpose(1, 2).contiguous() # [b, h*w, head, att]\n        w_ = w_.view(b, h, w, -1)\n        w_ = w_.permute(0, 3, 1, 2)\n\n        w_ = self.proj_out(w_)\n\n        return x+w_\n\n\nclass MultiHeadEncoder(nn.Module):\n    def __init__(self, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks=2,\n                 attn_resolutions=[16], dropout=0.0, resamp_with_conv=True, in_channels=3,\n                 resolution=512, z_channels=256, double_z=True, enable_mid=True,\n                 head_size=1, **ignore_kwargs):\n        super().__init__()\n        self.ch = ch\n        self.temb_ch = 0\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        self.resolution = resolution\n        self.in_channels = in_channels\n        self.enable_mid = enable_mid\n\n        # downsampling\n        self.conv_in = torch.nn.Conv2d(in_channels,\n                                       self.ch,\n                                       kernel_size=3,\n                                       stride=1,\n                                       padding=1)\n\n        curr_res = resolution\n        in_ch_mult = (1,)+tuple(ch_mult)\n        self.down = nn.ModuleList()\n        for i_level in range(self.num_resolutions):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_in = ch*in_ch_mult[i_level]\n            block_out = ch*ch_mult[i_level]\n            for i_block in range(self.num_res_blocks):\n                block.append(ResnetBlock(in_channels=block_in,\n                                         out_channels=block_out,\n                                         temb_channels=self.temb_ch,\n                                         dropout=dropout))\n                block_in = block_out\n                if curr_res in attn_resolutions:\n                    attn.append(MultiHeadAttnBlock(block_in, head_size))\n            down = nn.Module()\n            down.block = block\n            down.attn = attn\n            if i_level != self.num_resolutions-1:\n                down.downsample = Downsample(block_in, resamp_with_conv)\n                curr_res = curr_res // 2\n            self.down.append(down)\n\n        # middle\n        if self.enable_mid:\n            self.mid = nn.Module()\n            self.mid.block_1 = ResnetBlock(in_channels=block_in,\n                                           out_channels=block_in,\n                                           temb_channels=self.temb_ch,\n                                           dropout=dropout)\n            self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)\n            self.mid.block_2 = ResnetBlock(in_channels=block_in,\n                                           out_channels=block_in,\n                                           temb_channels=self.temb_ch,\n                                
           dropout=dropout)\n\n        # end\n        self.norm_out = Normalize(block_in)\n        self.conv_out = torch.nn.Conv2d(block_in,\n                                        2*z_channels if double_z else z_channels,\n                                        kernel_size=3,\n                                        stride=1,\n                                        padding=1)\n\n\n    def forward(self, x):\n        #assert x.shape[2] == x.shape[3] == self.resolution, \"{}, {}, {}\".format(x.shape[2], x.shape[3], self.resolution)\n\n        hs = {}\n        # timestep embedding\n        temb = None\n\n        # downsampling\n        h = self.conv_in(x)\n        hs['in'] = h\n        for i_level in range(self.num_resolutions):\n            for i_block in range(self.num_res_blocks):\n                h = self.down[i_level].block[i_block](h, temb)\n                if len(self.down[i_level].attn) > 0:\n                    h = self.down[i_level].attn[i_block](h)\n\n            if i_level != self.num_resolutions-1:\n                # hs.append(h)\n                hs['block_'+str(i_level)] = h\n                h = self.down[i_level].downsample(h)\n\n        # middle\n        # h = hs[-1]\n        if self.enable_mid:\n            h = self.mid.block_1(h, temb)\n            hs['block_'+str(i_level)+'_atten'] = h\n            h = self.mid.attn_1(h)\n            h = self.mid.block_2(h, temb)\n            hs['mid_atten'] = h\n\n        # end\n        h = self.norm_out(h)\n        h = nonlinearity(h)\n        h = self.conv_out(h)\n        # hs.append(h)\n        hs['out'] = h\n\n        return hs\n\nclass MultiHeadDecoder(nn.Module):\n    def __init__(self, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks=2,\n                 attn_resolutions=16, dropout=0.0, resamp_with_conv=True, in_channels=3,\n                 resolution=512, z_channels=256, give_pre_end=False, enable_mid=True,\n                 head_size=1, **ignorekwargs):\n        super().__init__()\n        self.ch = ch\n        self.temb_ch = 0\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        self.resolution = resolution\n        self.in_channels = in_channels\n        self.give_pre_end = give_pre_end\n        self.enable_mid = enable_mid\n\n        # compute in_ch_mult, block_in and curr_res at lowest res\n        in_ch_mult = (1,)+tuple(ch_mult)\n        block_in = ch*ch_mult[self.num_resolutions-1]\n        curr_res = resolution // 2**(self.num_resolutions-1)\n        self.z_shape = (1,z_channels,curr_res,curr_res)\n        print(\"Working with z of shape {} = {} dimensions.\".format(\n            self.z_shape, np.prod(self.z_shape)))\n\n        # z to block_in\n        self.conv_in = torch.nn.Conv2d(z_channels,\n                                       block_in,\n                                       kernel_size=3,\n                                       stride=1,\n                                       padding=1)\n\n        # middle\n        if self.enable_mid:\n            self.mid = nn.Module()\n            self.mid.block_1 = ResnetBlock(in_channels=block_in,\n                                           out_channels=block_in,\n                                           temb_channels=self.temb_ch,\n                                           dropout=dropout)\n            self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)\n            self.mid.block_2 = ResnetBlock(in_channels=block_in,\n                                           out_channels=block_in,\n                                  
         temb_channels=self.temb_ch,\n                                           dropout=dropout)\n\n        # upsampling\n        self.up = nn.ModuleList()\n        for i_level in reversed(range(self.num_resolutions)):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_out = ch*ch_mult[i_level]\n            for i_block in range(self.num_res_blocks+1):\n                block.append(ResnetBlock(in_channels=block_in,\n                                         out_channels=block_out,\n                                         temb_channels=self.temb_ch,\n                                         dropout=dropout))\n                block_in = block_out\n                if curr_res in attn_resolutions:\n                    attn.append(MultiHeadAttnBlock(block_in, head_size))\n            up = nn.Module()\n            up.block = block\n            up.attn = attn\n            if i_level != 0:\n                up.upsample = Upsample(block_in, resamp_with_conv)\n                curr_res = curr_res * 2\n            self.up.insert(0, up) # prepend to get consistent order\n\n        # end\n        self.norm_out = Normalize(block_in)\n        self.conv_out = torch.nn.Conv2d(block_in,\n                                        out_ch,\n                                        kernel_size=3,\n                                        stride=1,\n                                        padding=1)\n\n    def forward(self, z):\n        #assert z.shape[1:] == self.z_shape[1:]\n        self.last_z_shape = z.shape\n\n        # timestep embedding\n        temb = None\n\n        # z to block_in\n        h = self.conv_in(z)\n\n        # middle\n        if self.enable_mid:\n            h = self.mid.block_1(h, temb)\n            h = self.mid.attn_1(h)\n            h = self.mid.block_2(h, temb)\n\n        # upsampling\n        for i_level in reversed(range(self.num_resolutions)):\n            for i_block in range(self.num_res_blocks+1):\n                h = self.up[i_level].block[i_block](h, temb)\n                if len(self.up[i_level].attn) > 0:\n                    h = self.up[i_level].attn[i_block](h)\n            if i_level != 0:\n                h = self.up[i_level].upsample(h)\n\n        # end\n        if self.give_pre_end:\n            return h\n\n        h = self.norm_out(h)\n        h = nonlinearity(h)\n        h = self.conv_out(h)\n        return h\n\nclass MultiHeadDecoderTransformer(nn.Module):\n    def __init__(self, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks=2,\n                 attn_resolutions=16, dropout=0.0, resamp_with_conv=True, in_channels=3,\n                 resolution=512, z_channels=256, give_pre_end=False, enable_mid=True,\n                 head_size=1, **ignorekwargs):\n        super().__init__()\n        self.ch = ch\n        self.temb_ch = 0\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        self.resolution = resolution\n        self.in_channels = in_channels\n        self.give_pre_end = give_pre_end\n        self.enable_mid = enable_mid\n\n        # compute in_ch_mult, block_in and curr_res at lowest res\n        in_ch_mult = (1,)+tuple(ch_mult)\n        block_in = ch*ch_mult[self.num_resolutions-1]\n        curr_res = resolution // 2**(self.num_resolutions-1)\n        self.z_shape = (1,z_channels,curr_res,curr_res)\n        print(\"Working with z of shape {} = {} dimensions.\".format(\n            self.z_shape, np.prod(self.z_shape)))\n\n        # z to block_in\n        self.conv_in = 
torch.nn.Conv2d(z_channels,\n                                       block_in,\n                                       kernel_size=3,\n                                       stride=1,\n                                       padding=1)\n\n        # middle\n        if self.enable_mid:\n            self.mid = nn.Module()\n            self.mid.block_1 = ResnetBlock(in_channels=block_in,\n                                           out_channels=block_in,\n                                           temb_channels=self.temb_ch,\n                                           dropout=dropout)\n            self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)\n            self.mid.block_2 = ResnetBlock(in_channels=block_in,\n                                           out_channels=block_in,\n                                           temb_channels=self.temb_ch,\n                                           dropout=dropout)\n\n        # upsampling\n        self.up = nn.ModuleList()\n        for i_level in reversed(range(self.num_resolutions)):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_out = ch*ch_mult[i_level]\n            for i_block in range(self.num_res_blocks+1):\n                block.append(ResnetBlock(in_channels=block_in,\n                                         out_channels=block_out,\n                                         temb_channels=self.temb_ch,\n                                         dropout=dropout))\n                block_in = block_out\n                if curr_res in attn_resolutions:\n                    attn.append(MultiHeadAttnBlock(block_in, head_size))\n            up = nn.Module()\n            up.block = block\n            up.attn = attn\n            if i_level != 0:\n                up.upsample = Upsample(block_in, resamp_with_conv)\n                curr_res = curr_res * 2\n            self.up.insert(0, up) # prepend to get consistent order\n\n        # end\n        self.norm_out = Normalize(block_in)\n        self.conv_out = torch.nn.Conv2d(block_in,\n                                        out_ch,\n                                        kernel_size=3,\n                                        stride=1,\n                                        padding=1)\n\n    def forward(self, z, hs):\n        #assert z.shape[1:] == self.z_shape[1:]\n        # self.last_z_shape = z.shape\n\n        # timestep embedding\n        temb = None\n\n        # z to block_in\n        h = self.conv_in(z)\n\n        # middle\n        if self.enable_mid:\n            h = self.mid.block_1(h, temb)\n            h = self.mid.attn_1(h, hs['mid_atten'])\n            h = self.mid.block_2(h, temb)\n\n        # upsampling\n        for i_level in reversed(range(self.num_resolutions)):\n            for i_block in range(self.num_res_blocks+1):\n                h = self.up[i_level].block[i_block](h, temb)\n                if len(self.up[i_level].attn) > 0:\n                    h = self.up[i_level].attn[i_block](h, hs['block_'+str(i_level)+'_atten'])\n                    # hfeature = h.clone()\n            if i_level != 0:\n                h = self.up[i_level].upsample(h)\n\n        # end\n        if self.give_pre_end:\n            return h\n\n        h = self.norm_out(h)\n        h = nonlinearity(h)\n        h = self.conv_out(h)\n        return h\n\n\nclass VQVAEGAN(nn.Module):\n    def __init__(self, n_embed=1024, embed_dim=256, ch=128, out_ch=3, ch_mult=(1,2,4,8), \n                 num_res_blocks=2, attn_resolutions=16, dropout=0.0, in_channels=3, \n           
      resolution=512, z_channels=256, double_z=False, enable_mid=True, \n                 fix_decoder=False, fix_codebook=False, head_size=1, **ignore_kwargs):\n        super(VQVAEGAN, self).__init__()\n\n        self.encoder = MultiHeadEncoder(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,\n                               attn_resolutions=attn_resolutions, dropout=dropout, in_channels=in_channels,\n                               resolution=resolution, z_channels=z_channels, double_z=double_z, \n                               enable_mid=enable_mid, head_size=head_size)\n        self.decoder = MultiHeadDecoder(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,\n                               attn_resolutions=attn_resolutions, dropout=dropout, in_channels=in_channels,\n                               resolution=resolution, z_channels=z_channels, enable_mid=enable_mid, head_size=head_size)\n\n        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)\n\n        self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)\n        self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)\n\n        if fix_decoder:\n            for _, param in self.decoder.named_parameters():\n                param.requires_grad = False\n            for _, param in self.post_quant_conv.named_parameters():\n                param.requires_grad = False\n            for _, param in self.quantize.named_parameters():\n                param.requires_grad = False\n        elif fix_codebook:\n            for _, param in self.quantize.named_parameters():\n                param.requires_grad = False\n\n    def encode(self, x):\n\n        hs = self.encoder(x)\n        h = self.quant_conv(hs['out'])\n        quant, emb_loss, info = self.quantize(h)\n        return quant, emb_loss, info, hs\n\n    def decode(self, quant):\n        quant = self.post_quant_conv(quant)\n        dec = self.decoder(quant)\n\n        return dec\n\n    def forward(self, input):\n        quant, diff, info, hs = self.encode(input)\n        dec = self.decode(quant)\n\n        return dec, diff, info, hs\n\nclass VQVAEGANMultiHeadTransformer(nn.Module):\n    def __init__(self, n_embed=1024, embed_dim=256, ch=128, out_ch=3, ch_mult=(1,2,4,8), \n                 num_res_blocks=2, attn_resolutions=16, dropout=0.0, in_channels=3, \n                 resolution=512, z_channels=256, double_z=False, enable_mid=True, \n                 fix_decoder=False, fix_codebook=False, fix_encoder=False, constrastive_learning_loss_weight=0.0,\n                 head_size=1):\n        super(VQVAEGANMultiHeadTransformer, self).__init__()\n\n        self.encoder = MultiHeadEncoder(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,\n                               attn_resolutions=attn_resolutions, dropout=dropout, in_channels=in_channels,\n                               resolution=resolution, z_channels=z_channels, double_z=double_z, \n                               enable_mid=enable_mid, head_size=head_size)\n        self.decoder = MultiHeadDecoderTransformer(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,\n                               attn_resolutions=attn_resolutions, dropout=dropout, in_channels=in_channels,\n                               resolution=resolution, z_channels=z_channels, enable_mid=enable_mid, head_size=head_size)\n\n        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)\n\n        self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)\n     
   self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)\n\n        if fix_decoder:\n            for _, param in self.decoder.named_parameters():\n                param.requires_grad = False\n            for _, param in self.post_quant_conv.named_parameters():\n                param.requires_grad = False\n            for _, param in self.quantize.named_parameters():\n                param.requires_grad = False\n        elif fix_codebook:\n            for _, param in self.quantize.named_parameters():\n                param.requires_grad = False\n\n        if fix_encoder:\n            for _, param in self.encoder.named_parameters():\n                param.requires_grad = False\n\n    def encode(self, x):\n        \n        hs = self.encoder(x)\n        h = self.quant_conv(hs['out'])\n        quant, emb_loss, info = self.quantize(h)\n        return quant, emb_loss, info, hs\n\n    def decode(self, quant, hs):\n        quant = self.post_quant_conv(quant)\n        dec = self.decoder(quant, hs)\n\n        return dec\n\n    def forward(self, input):\n        quant, diff, info, hs = self.encode(input)\n        dec = self.decode(quant, hs)\n\n        return dec, diff, info, hs"
  },
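  {
    "path": "examples/vector_quantizer_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original RestoreFormer code: pushes a\nrandom latent map through the VectorQuantizer from\nRestoreFormer/modules/vqvae/vqvae_arch.py. The straight-through estimator keeps\nz_q differentiable w.r.t. z, the loss combines the codebook and commitment\nterms, and the perplexity measures codebook usage.\"\"\"\n\nimport torch\n\nfrom RestoreFormer.modules.vqvae.vqvae_arch import VectorQuantizer\n\nquantizer = VectorQuantizer(n_e=1024, e_dim=64, beta=0.25)\nz = torch.randn(2, 64, 16, 16, requires_grad=True)\n\nz_q, loss, (perplexity, _, indices, _) = quantizer(z)\nprint(z_q.shape)          # same as z: (2, 64, 16, 16)\nprint(loss.item())        # codebook term + beta * commitment term\nprint(perplexity.item())  # approaches n_e when all codes are used\nprint(indices.shape)      # (2*16*16, 1) nearest-codebook-entry indices\n\n# straight-through: gradients flow back to the encoder side\nz_q.sum().backward()\nprint(z.grad.shape)\n"
  },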
  {
    "path": "RestoreFormer/util.py",
    "content": "import os, hashlib\nimport requests\nfrom tqdm import tqdm\n\nURL_MAP = {\n    \"vgg_lpips\": \"https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1\"\n}\n\nCKPT_MAP = {\n    \"vgg_lpips\": \"vgg.pth\"\n}\n\nMD5_MAP = {\n    \"vgg_lpips\": \"d507d7349b931f0638a25a48a722f98a\"\n}\n\n\ndef download(url, local_path, chunk_size=1024):\n    os.makedirs(os.path.split(local_path)[0], exist_ok=True)\n    with requests.get(url, stream=True) as r:\n        total_size = int(r.headers.get(\"content-length\", 0))\n        with tqdm(total=total_size, unit=\"B\", unit_scale=True) as pbar:\n            with open(local_path, \"wb\") as f:\n                for data in r.iter_content(chunk_size=chunk_size):\n                    if data:\n                        f.write(data)\n                        pbar.update(chunk_size)\n\n\ndef md5_hash(path):\n    with open(path, \"rb\") as f:\n        content = f.read()\n    return hashlib.md5(content).hexdigest()\n\n\ndef get_ckpt_path(name, root, check=False):\n    assert name in URL_MAP\n    path = os.path.join(root, CKPT_MAP[name])\n    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):\n        print(\"Downloading {} model from {} to {}\".format(name, URL_MAP[name], path))\n        download(URL_MAP[name], path)\n        md5 = md5_hash(path)\n        assert md5 == MD5_MAP[name], md5\n    return path\n\n\nclass KeyNotFoundError(Exception):\n    def __init__(self, cause, keys=None, visited=None):\n        self.cause = cause\n        self.keys = keys\n        self.visited = visited\n        messages = list()\n        if keys is not None:\n            messages.append(\"Key not found: {}\".format(keys))\n        if visited is not None:\n            messages.append(\"Visited: {}\".format(visited))\n        messages.append(\"Cause:\\n{}\".format(cause))\n        message = \"\\n\".join(messages)\n        super().__init__(message)\n\n\ndef retrieve(\n    list_or_dict, key, splitval=\"/\", default=None, expand=True, pass_success=False\n):\n    \"\"\"Given a nested list or dict return the desired value at key expanding\n    callable nodes if necessary and :attr:`expand` is ``True``. The expansion\n    is done in-place.\n\n    Parameters\n    ----------\n        list_or_dict : list or dict\n            Possibly nested list or dictionary.\n        key : str\n            key/to/value, path like string describing all keys necessary to\n            consider to get to the desired value. 
List indices can also be\n            passed here.\n        splitval : str\n            String that defines the delimiter between keys of the\n            different depth levels in `key`.\n        default : obj\n            Value returned if :attr:`key` is not found.\n        expand : bool\n            Whether to expand callable nodes on the path or not.\n\n    Returns\n    -------\n        The desired value or if :attr:`default` is not ``None`` and the\n        :attr:`key` is not found returns ``default``.\n\n    Raises\n    ------\n        Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is\n        ``None``.\n    \"\"\"\n\n    keys = key.split(splitval)\n\n    success = True\n    try:\n        visited = []\n        parent = None\n        last_key = None\n        for key in keys:\n            if callable(list_or_dict):\n                if not expand:\n                    raise KeyNotFoundError(\n                        ValueError(\n                            \"Trying to get past callable node with expand=False.\"\n                        ),\n                        keys=keys,\n                        visited=visited,\n                    )\n                list_or_dict = list_or_dict()\n                parent[last_key] = list_or_dict\n\n            last_key = key\n            parent = list_or_dict\n\n            try:\n                if isinstance(list_or_dict, dict):\n                    list_or_dict = list_or_dict[key]\n                else:\n                    list_or_dict = list_or_dict[int(key)]\n            except (KeyError, IndexError, ValueError) as e:\n                raise KeyNotFoundError(e, keys=keys, visited=visited)\n\n            visited += [key]\n        # final expansion of retrieved value\n        if expand and callable(list_or_dict):\n            list_or_dict = list_or_dict()\n            parent[last_key] = list_or_dict\n    except KeyNotFoundError as e:\n        if default is None:\n            raise e\n        else:\n            list_or_dict = default\n            success = False\n\n    if not pass_success:\n        return list_or_dict\n    else:\n        return list_or_dict, success\n\n\nif __name__ == \"__main__\":\n    config = {\"keya\": \"a\",\n              \"keyb\": \"b\",\n              \"keyc\":\n                  {\"cc1\": 1,\n                   \"cc2\": 2,\n                   }\n              }\n    from omegaconf import OmegaConf\n    config = OmegaConf.create(config)\n    print(config)\n    retrieve(config, \"keya\")\n\n"
  },
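  {
    "path": "examples/retrieve_usage.py",
    "content": "# Illustrative usage sketch for the retrieve() helper above. NOTE: this file\n# and the import path below are hypothetical additions for documentation;\n# point the import at wherever util.py lives in your checkout. Similarly,\n# get_ckpt_path(\"vgg_lpips\", \"experiments/pretrained_models\") would download\n# and md5-check the LPIPS VGG weights into that folder.\nfrom RestoreFormer.modules.util import retrieve  # assumed module location\n\nconfig = {\"keya\": \"a\", \"keyc\": {\"cc1\": 1, \"cc2\": 2}}\n\n# Path-like keys walk nested dicts/lists; \"/\" is the default delimiter.\nassert retrieve(config, \"keyc/cc1\") == 1\n\n# A non-None default suppresses KeyNotFoundError for missing keys.\nassert retrieve(config, \"keyc/missing\", default=-1) == -1\n\n# pass_success=True additionally reports whether the lookup succeeded.\nvalue, ok = retrieve(config, \"keya\", pass_success=True)\nassert value == \"a\" and ok\n"
  },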
  {
    "path": "__init__.py",
    "content": ""
  },
  {
    "path": "configs/HQ_Dictionary.yaml",
    "content": "model:\n  base_learning_rate: 4.5e-6\n  target: RestoreFormer.models.vqgan_v1.RestoreFormerModel\n  params:\n    image_key: 'gt'\n    schedule_step: [400000, 800000]\n    # ignore_keys: ['vqvae.quantize.utility_counter']\n    ddconfig:\n      target: RestoreFormer.modules.vqvae.vqvae_arch.VQVAEGAN\n      params:\n        embed_dim: 256\n        n_embed: 1024\n        double_z: False\n        z_channels: 256\n        resolution: 512\n        in_channels: 3\n        out_ch: 3\n        ch: 64\n        ch_mult: [ 1,2,2,4,4,8]  # num_down = len(ch_mult)-1\n        num_res_blocks: 2\n        attn_resolutions: [16]\n        dropout: 0.0\n        enable_mid: True\n        fix_decoder: False\n        fix_codebook: False\n        head_size: 8\n\n    lossconfig:\n      target: RestoreFormer.modules.losses.vqperceptual.VQLPIPSWithDiscriminatorWithCompWithIdentity\n      params:\n        disc_conditional: False\n        disc_in_channels: 3\n        disc_start: 30001\n        disc_weight: 0.8\n        codebook_weight: 1.0\n        use_actnorm: False\n\ndata:\n  target: main.DataModuleFromConfig\n  params:\n    batch_size: 4\n    num_workers: 8\n    train:\n      target: basicsr.data.ffhq_dataset.FFHQDataset\n      params:\n        dataroot_gt: data/FFHQ/images512x512\n        io_backend:\n          type: disk\n        use_hflip: True\n        mean: [0.5, 0.5, 0.5]\n        std: [0.5, 0.5, 0.5]\n        out_size: 512\n    validation:\n      target: basicsr.data.ffhq_dataset.FFHQDataset\n      params:\n        dataroot_gt: data/FFHQ/images512x512\n        io_backend:\n          type: disk\n        use_hflip: False\n        mean: [0.5, 0.5, 0.5]\n        std: [0.5, 0.5, 0.5]\n        out_size: 512\n"
  },
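  {
    "path": "examples/check_downsampling.py",
    "content": "# Illustrative sanity check for the geometry in configs/HQ_Dictionary.yaml\n# (hypothetical helper file, not part of the original release). With\n# ch_mult [1,2,2,4,4,8] the encoder has num_down = len(ch_mult) - 1 = 5\n# downsampling stages, so a 512px input reaches 512 / 2**5 = 16px, which is\n# exactly the resolution where attn_resolutions places the attention blocks.\nresolution = 512\nch_mult = [1, 2, 2, 4, 4, 8]\nnum_down = len(ch_mult) - 1\nbottleneck = resolution // 2 ** num_down\nassert bottleneck == 16  # matches attn_resolutions: [16]\nprint(f\"{num_down} downsampling stages: {resolution}px -> {bottleneck}px\")\n"
  },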
  {
    "path": "configs/RestoreFormer.yaml",
    "content": "model:\n  base_learning_rate: 4.5e-6 \n  target: RestoreFormer.models.vqgan_v1.RestoreFormerModel\n  params:\n    image_key: 'lq'\n    ckpt_path: 'YOUR TRAINED HD DICTIONARY MODEL'\n    special_params_lr_scale: 10\n    comp_params_lr_scale: 10\n    schedule_step: [4000000, 8000000]\n    ddconfig:\n      target: RestoreFormer.modules.vqvae.vqvae_arch.VQVAEGANMultiHeadTransformer\n      params:\n        embed_dim: 256\n        n_embed: 1024\n        double_z: False\n        z_channels: 256\n        resolution: 512\n        in_channels: 3  \n        out_ch: 3\n        ch: 64\n        ch_mult: [ 1,2,2,4,4,8]  # num_down = len(ch_mult)-1\n        num_res_blocks: 2\n        dropout: 0.0\n        attn_resolutions: [16]\n        enable_mid: True\n\n        fix_decoder: False\n        fix_codebook: True\n        fix_encoder: False\n        head_size: 8\n\n    lossconfig:\n      target: RestoreFormer.modules.losses.vqperceptual.VQLPIPSWithDiscriminatorWithCompWithIdentity\n      params:\n        disc_conditional: False\n        disc_in_channels: 3\n        disc_start: 10001\n        disc_weight: 0.8\n        codebook_weight: 1.0\n        use_actnorm: False\n        comp_weight: 1.5\n        comp_style_weight: 2e3 #2000.0\n        identity_weight: 3 #1.5\n        lpips_style_weight: 1e9\n        identity_model_path: experiments/pretrained_models/arcface_resnet18.pth\n\ndata:\n  target: main.DataModuleFromConfig\n  params:\n    batch_size: 4\n    num_workers: 8\n    train:\n      target: RestoreFormer.data.ffhq_degradation_dataset.FFHQDegradationDataset\n      params:\n        dataroot_gt: data/FFHQ/images512x512\n        io_backend:\n          type: disk\n        use_hflip: True\n        mean: [0.5, 0.5, 0.5]\n        std: [0.5, 0.5, 0.5]\n        out_size: 512\n\n        blur_kernel_size: [19,20]\n        kernel_list: ['iso', 'aniso']\n        kernel_prob: [0.5, 0.5]\n        blur_sigma: [0.1, 10]\n        downsample_range: [0.8, 8]\n        noise_range: [0, 20]\n        jpeg_range: [60, 100]\n\n        color_jitter_prob: ~\n        color_jitter_shift: 20\n        color_jitter_pt_prob: ~\n        gray_prob: ~\n        gt_gray: True\n\n        crop_components: True\n        component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth\n        eye_enlarge_ratio: 1.4\n\n\n    validation:\n      target: RestoreFormer.data.ffhq_degradation_dataset.FFHQDegradationDataset\n      params:\n        dataroot_gt: data/FFHQ/images512x512\n        io_backend:\n          type: disk\n        use_hflip: False\n        mean: [0.5, 0.5, 0.5]\n        std: [0.5, 0.5, 0.5]\n        out_size: 512\n\n        blur_kernel_size: [19,20]\n        kernel_list: ['iso', 'aniso']\n        kernel_prob: [0.5, 0.5]\n        blur_sigma: [0.1, 10]\n        downsample_range: [0.8, 8]\n        noise_range: [0, 20]\n        jpeg_range: [60, 100]\n\n        # color jitter and gray\n        color_jitter_prob: ~\n        color_jitter_shift: 20\n        color_jitter_pt_prob: ~\n        gray_prob: ~\n        gt_gray: True\n\n        crop_components: False\n        component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth\n        eye_enlarge_ratio: 1.4\n"
  },
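  {
    "path": "examples/load_config_sketch.py",
    "content": "# Minimal sketch of how main.py consumes the YAML configs above (hypothetical\n# helper file; the actual logic is instantiate_from_config/get_obj_from_str in\n# main.py, which also special-cases the basicsr dataset classes): the dotted\n# `target` string is resolved with importlib and `params` is passed as\n# keyword arguments.\nimport importlib\n\nfrom omegaconf import OmegaConf\n\nconfig = OmegaConf.load(\"configs/RestoreFormer.yaml\")\ntarget = config.model.target  # RestoreFormer.models.vqgan_v1.RestoreFormerModel\nmodule_name, cls_name = target.rsplit(\".\", 1)\ncls = getattr(importlib.import_module(module_name), cls_name)\nprint(module_name, cls_name)\n# model = cls(**config.model.params)  # instantiate inside the repo environment\n"
  },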
  {
    "path": "main.py",
    "content": "import argparse, os, sys, datetime, glob, importlib\nfrom omegaconf import OmegaConf\nimport numpy as np\nfrom PIL import Image\nimport torch\nimport torchvision\nfrom torch.utils.data import random_split, DataLoader, Dataset\nimport pytorch_lightning as pl\nfrom pytorch_lightning import seed_everything\nfrom pytorch_lightning.trainer import Trainer\nfrom pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor\nfrom pytorch_lightning.utilities.distributed import rank_zero_only\nimport random\n\ndef get_obj_from_str(string, reload=False):\n    module, cls = string.rsplit(\".\", 1)\n    if reload:\n        module_imp = importlib.import_module(module)\n        importlib.reload(module_imp)\n    return getattr(importlib.import_module(module, package=None), cls)\n\n\ndef get_parser(**parser_kwargs):\n    def str2bool(v):\n        if isinstance(v, bool):\n            return v\n        if v.lower() in (\"yes\", \"true\", \"t\", \"y\", \"1\"):\n            return True\n        elif v.lower() in (\"no\", \"false\", \"f\", \"n\", \"0\"):\n            return False\n        else:\n            raise argparse.ArgumentTypeError(\"Boolean value expected.\")\n\n    parser = argparse.ArgumentParser(**parser_kwargs)\n    parser.add_argument(\n        \"-n\",\n        \"--name\",\n        type=str,\n        const=True,\n        default=\"\",\n        nargs=\"?\",\n        help=\"postfix for logdir\",\n    )\n    parser.add_argument(\n        \"-r\",\n        \"--resume\",\n        type=str,\n        const=True,\n        default=\"\",\n        nargs=\"?\",\n        help=\"resume from logdir or checkpoint in logdir\",\n    )\n    parser.add_argument(\n        \"--pretrain\",\n        type=str,\n        const=True,\n        default=\"\",\n        nargs=\"?\",\n        help=\"pretrain with existed weights\",\n    )\n    parser.add_argument(\n        \"-b\",\n        \"--base\",\n        nargs=\"*\",\n        metavar=\"base_config.yaml\",\n        help=\"paths to base configs. Loaded from left-to-right. 
\"\n        \"Parameters can be overwritten or added with command-line options of the form `--key value`.\",\n        default=list(),\n    )\n    parser.add_argument(\n        \"-t\",\n        \"--train\",\n        type=str2bool,\n        const=True,\n        default=False,\n        nargs=\"?\",\n        help=\"train\",\n    )\n    parser.add_argument(\n        \"--no-test\",\n        type=str2bool,\n        const=True,\n        default=False,\n        nargs=\"?\",\n        help=\"disable test\",\n    )\n    parser.add_argument(\"-p\", \"--project\", help=\"name of new or path to existing project\")\n    parser.add_argument(\n        \"-d\",\n        \"--debug\",\n        type=str2bool,\n        nargs=\"?\",\n        const=True,\n        default=False,\n        help=\"enable post-mortem debugging\",\n    )\n    parser.add_argument(\n        \"-s\",\n        \"--seed\",\n        type=int,\n        default=23,\n        help=\"seed for seed_everything\",\n    )\n    parser.add_argument(\n        \"--random-seed\",\n        type=str2bool,\n        nargs=\"?\",\n        const=True,\n        default=False,\n        help=\"enable post-mortem debugging\",\n    )\n    parser.add_argument(\n        \"-f\",\n        \"--postfix\",\n        type=str,\n        default=\"\",\n        help=\"post-postfix for default name\",\n    )\n\n    parser.add_argument(\n        \"--root-path\",\n        type=str,\n        default=\"./\",\n        help=\"root path for saving checkpoints and logs\"\n    )\n    parser.add_argument(\n        \"--num-nodes\",\n        type=int,\n        default=1,\n        help=\"number of gpu nodes\",\n    )\n    \n\n    return parser\n\n\ndef nondefault_trainer_args(opt):\n    parser = argparse.ArgumentParser()\n    parser = Trainer.add_argparse_args(parser)\n    args = parser.parse_args([])\n    return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))\n\n\ndef instantiate_from_config(config):\n    if not \"target\" in config:\n        raise KeyError(\"Expected key `target` to instantiate.\")\n    if 'basicsr.data' in config[\"target\"] or \\\n        'FFHQDegradationDataset' in config[\"target\"]:\n        return get_obj_from_str(config[\"target\"])(config.get(\"params\", dict()))\n    return get_obj_from_str(config[\"target\"])(**config.get(\"params\", dict()))\n\n\nclass WrappedDataset(Dataset):\n    \"\"\"Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset\"\"\"\n    def __init__(self, dataset):\n        self.data = dataset\n\n    def __len__(self):\n        return len(self.data)\n\n    def __getitem__(self, idx):\n        return self.data[idx]\n\n\nclass DataModuleFromConfig(pl.LightningDataModule):\n    def __init__(self, batch_size, train=None, validation=None, test=None,\n                 wrap=False, num_workers=None):\n        super().__init__()\n        self.batch_size = batch_size\n        self.dataset_configs = dict()\n        self.num_workers = num_workers if num_workers is not None else batch_size*2\n        if train is not None:\n            self.dataset_configs[\"train\"] = train\n            self.train_dataloader = self._train_dataloader\n        if validation is not None:\n            self.dataset_configs[\"validation\"] = validation\n            self.val_dataloader = self._val_dataloader\n        if test is not None:\n            self.dataset_configs[\"test\"] = test\n            self.test_dataloader = self._test_dataloader\n        self.wrap = wrap\n\n    def prepare_data(self):\n        for data_cfg in 
self.dataset_configs.values():\n            instantiate_from_config(data_cfg)\n\n    def setup(self, stage=None):\n        self.datasets = dict(\n            (k, instantiate_from_config(self.dataset_configs[k]))\n            for k in self.dataset_configs)\n        if self.wrap:\n            for k in self.datasets:\n                self.datasets[k] = WrappedDataset(self.datasets[k])\n\n    def _train_dataloader(self):\n        return DataLoader(self.datasets[\"train\"], batch_size=self.batch_size,\n                          num_workers=self.num_workers, shuffle=True)\n\n    def _val_dataloader(self):\n        return DataLoader(self.datasets[\"validation\"],\n                          batch_size=self.batch_size,\n                          num_workers=self.num_workers)\n\n    def _test_dataloader(self):\n        return DataLoader(self.datasets[\"test\"], batch_size=self.batch_size,\n                          num_workers=self.num_workers)\n\n\nclass SetupCallback(Callback):\n    def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):\n        super().__init__()\n        self.resume = resume\n        self.now = now\n        self.logdir = logdir\n        self.ckptdir = ckptdir\n        self.cfgdir = cfgdir\n        self.config = config\n        self.lightning_config = lightning_config\n\n    def on_pretrain_routine_start(self, trainer, pl_module):\n        if trainer.global_rank == 0:\n            # import pdb\n            # pdb.set_trace()\n            # Create logdirs and save configs\n            os.makedirs(self.logdir, exist_ok=True)\n            os.makedirs(self.ckptdir, exist_ok=True)\n            os.makedirs(self.cfgdir, exist_ok=True)\n\n            print(\"Project config\")\n            print(self.config.pretty())\n            OmegaConf.save(self.config,\n                           os.path.join(self.cfgdir, \"{}-project.yaml\".format(self.now)))\n\n            print(\"Lightning config\")\n            print(self.lightning_config.pretty())\n            OmegaConf.save(OmegaConf.create({\"lightning\": self.lightning_config}),\n                           os.path.join(self.cfgdir, \"{}-lightning.yaml\".format(self.now)))\n\n\nclass ImageLogger(Callback):\n    def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True):\n        super().__init__()\n        self.batch_freq = batch_frequency\n        self.max_images = max_images\n        self.logger_log_images = {\n            pl.loggers.WandbLogger: self._wandb,\n            pl.loggers.TestTubeLogger: self._testtube,\n        }\n        self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]\n        if not increase_log_steps:\n            self.log_steps = [self.batch_freq]\n        self.clamp = clamp\n\n    @rank_zero_only\n    def _wandb(self, pl_module, images, batch_idx, split):\n        raise ValueError(\"No way wandb\")\n        grids = dict()\n        for k in images:\n            grid = torchvision.utils.make_grid(images[k])\n            grids[f\"{split}/{k}\"] = wandb.Image(grid)\n        pl_module.logger.experiment.log(grids)\n\n    @rank_zero_only\n    def _testtube(self, pl_module, images, batch_idx, split):\n        for k in images:\n            grid = torchvision.utils.make_grid(images[k])\n            grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w\n\n            tag = f\"{split}/{k}\"\n            pl_module.logger.experiment.add_image(\n                tag, grid,\n                global_step=pl_module.global_step)\n\n    @rank_zero_only\n    def log_local(self, 
save_dir, split, images,\n                  global_step, current_epoch, batch_idx):\n        root = os.path.join(save_dir, \"images\", split)\n        for k in images:\n            grid = torchvision.utils.make_grid(images[k], nrow=4)\n\n            grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w\n            grid = grid.transpose(0,1).transpose(1,2).squeeze(-1)\n            grid = grid.numpy()\n            grid = (grid*255).astype(np.uint8)\n            filename = \"{}_gs-{:06}_e-{:06}_b-{:06}.png\".format(\n                k,\n                global_step,\n                current_epoch,\n                batch_idx)\n            path = os.path.join(root, filename)\n            os.makedirs(os.path.split(path)[0], exist_ok=True)\n            Image.fromarray(grid).save(path)\n\n    def log_img(self, pl_module, batch, batch_idx, split=\"train\"):\n        if (self.check_frequency(batch_idx) and  # batch_idx % self.batch_freq == 0\n                hasattr(pl_module, \"log_images\") and\n                callable(pl_module.log_images) and\n                self.max_images > 0):\n            logger = type(pl_module.logger)\n\n            is_train = pl_module.training\n            if is_train:\n                pl_module.eval()\n\n            with torch.no_grad():\n                images = pl_module.log_images(batch, split=split)\n\n            for k in images:\n                N = min(images[k].shape[0], self.max_images)\n                images[k] = images[k][:N]\n                if isinstance(images[k], torch.Tensor):\n                    images[k] = images[k].detach().cpu()\n                    if self.clamp:\n                        images[k] = torch.clamp(images[k], -1., 1.)\n\n            self.log_local(pl_module.logger.save_dir, split, images,\n                           pl_module.global_step, pl_module.current_epoch, batch_idx)\n\n            logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)\n            logger_log_images(pl_module, images, pl_module.global_step, split)\n\n            if is_train:\n                pl_module.train()\n\n    def check_frequency(self, batch_idx):\n        if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps):\n            try:\n                self.log_steps.pop(0)\n            except IndexError:\n                pass\n            return True\n        return False\n\n    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):\n        self.log_img(pl_module, batch, batch_idx, split=\"train\")\n\n    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):\n        self.log_img(pl_module, batch, batch_idx, split=\"val\")\n\n\n\nif __name__ == \"__main__\":\n    # custom parser to specify config files, train, test and debug mode,\n    # postfix, resume.\n    # `--key value` arguments are interpreted as arguments to the trainer.\n    # `nested.key=value` arguments are interpreted as config parameters.\n    # configs are merged from left-to-right followed by command line parameters.\n\n    # model:\n    #   base_learning_rate: float\n    #   target: path to lightning module\n    #   params:\n    #       key: value\n    # data:\n    #   target: main.DataModuleFromConfig\n    #   params:\n    #      batch_size: int\n    #      wrap: bool\n    #      train:\n    #          target: path to train dataset\n    #          params:\n    #              key: value\n    #      validation:\n    #          target: path to validation dataset\n    #          
params:\n    #              key: value\n    #      test:\n    #          target: path to test dataset\n    #          params:\n    #              key: value\n    # lightning: (optional, has sane defaults and can be specified on cmdline)\n    #   trainer:\n    #       additional arguments to trainer\n    #   logger:\n    #       logger to instantiate\n    #   modelcheckpoint:\n    #       modelcheckpoint to instantiate\n    #   callbacks:\n    #       callback1:\n    #           target: importpath\n    #           params:\n    #               key: value\n    now = datetime.datetime.now().strftime(\"%Y-%m-%dT%H-%M-%S\")\n\n    # add cwd for convenience and to make classes in this file available when\n    # running as `python main.py`\n    # (in particular `main.DataModuleFromConfig`)\n    sys.path.append(os.getcwd())\n\n    parser = get_parser()\n    parser = Trainer.add_argparse_args(parser)\n\n    opt, unknown = parser.parse_known_args()\n    if opt.name and opt.resume:\n        raise ValueError(\n            \"-n/--name and -r/--resume cannot be specified both.\"\n            \"If you want to resume training in a new log folder, \"\n            \"use -n/--name in combination with --resume_from_checkpoint\"\n        )\n    if opt.resume:\n        if not os.path.exists(opt.resume):\n            raise ValueError(\"Cannot find {}\".format(opt.resume))\n        if os.path.isfile(opt.resume):\n            paths = opt.resume.split(\"/\")\n            idx = len(paths)-paths[::-1].index(\"logs\")+1\n            logdir = \"/\".join(paths[:idx])\n            ckpt = opt.resume\n        else:\n            assert os.path.isdir(opt.resume), opt.resume\n            logdir = opt.resume.rstrip(\"/\")\n            ckpt = os.path.join(logdir, \"checkpoints\", \"last.ckpt\")\n\n        opt.resume_from_checkpoint = ckpt\n        base_configs = sorted(glob.glob(os.path.join(logdir, \"configs/*.yaml\")))\n        opt.base = base_configs+opt.base\n        _tmp = logdir.split(\"/\")\n        nowname = _tmp[_tmp.index(\"logs\")+1]+opt.postfix\n        logdir = os.path.join(opt.root_path, \"logs\", nowname)\n    else:\n        if opt.name:\n            name = \"_\"+opt.name\n        elif opt.base:\n            cfg_fname = os.path.split(opt.base[0])[-1]\n            cfg_name = os.path.splitext(cfg_fname)[0]\n            name = \"_\"+cfg_name\n        else:\n            name = \"\"\n        nowname = now+name+opt.postfix\n        logdir = os.path.join(opt.root_path, \"logs\", nowname)\n\n    if opt.random_seed:\n        opt.seed = random.randint(1,100)\n    logdir = logdir + '_seed' + str(opt.seed)\n    \n    ckptdir = os.path.join(logdir, \"checkpoints\")\n    cfgdir = os.path.join(logdir, \"configs\")\n\n    seed_everything(opt.seed)\n\n    try:\n        # init and save configs\n        configs = [OmegaConf.load(cfg) for cfg in opt.base]\n        cli = OmegaConf.from_dotlist(unknown)\n        config = OmegaConf.merge(*configs, cli)\n        lightning_config = config.pop(\"lightning\", OmegaConf.create())\n        # merge trainer cli with config\n        trainer_config = lightning_config.get(\"trainer\", OmegaConf.create())\n        # default to ddp\n        # trainer_config[\"distributed_backend\"] = \"ddp\"\n        trainer_config[\"accelerator\"] = \"ddp\"\n        # trainer_config[\"plugins\"]=\"ddp_sharded\"\n        for k in nondefault_trainer_args(opt):\n            trainer_config[k] = getattr(opt, k)\n        if not \"gpus\" in trainer_config:\n            del trainer_config[\"distributed_backend\"]\n         
   cpu = True\n        else:\n            gpuinfo = trainer_config[\"gpus\"]\n            print(f\"Running on GPUs {gpuinfo}\")\n            cpu = False\n        trainer_opt = argparse.Namespace(**trainer_config)\n        lightning_config.trainer = trainer_config\n\n        # model\n        model = instantiate_from_config(config.model)\n\n        # trainer and callbacks\n        trainer_kwargs = dict()\n        # trainer_kwargs['sync_batchnorm'] = True\n        \n        # default logger configs\n        # NOTE wandb < 0.10.0 interferes with shutdown\n        # wandb >= 0.10.0 seems to fix it but still interferes with pudb\n        # debugging (wrongly sized pudb ui)\n        # thus prefer testtube for now\n        default_logger_cfgs = {\n            \"wandb\": {\n                \"target\": \"pytorch_lightning.loggers.WandbLogger\",\n                \"params\": {\n                    \"name\": nowname,\n                    \"save_dir\": logdir,\n                    \"offline\": opt.debug,\n                    \"id\": nowname,\n                }\n            },\n            \"testtube\": {\n                \"target\": \"pytorch_lightning.loggers.TestTubeLogger\",\n                \"params\": {\n                    \"name\": \"testtube\",\n                    \"save_dir\": logdir,\n                }\n            },\n        }\n        default_logger_cfg = default_logger_cfgs[\"testtube\"]\n        logger_cfg = lightning_config.logger or OmegaConf.create()\n        logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)\n        trainer_kwargs[\"logger\"] = instantiate_from_config(logger_cfg)\n\n        # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to\n        # specify which metric is used to determine best models\n        default_modelckpt_cfg = {\n            \"target\": \"pytorch_lightning.callbacks.ModelCheckpoint\",\n            \"params\": {\n                \"dirpath\": ckptdir,\n                \"filename\": \"{epoch:06}\",\n                \"verbose\": True,\n                \"save_last\": True,\n                \"period\": 1\n            }\n        }\n        if hasattr(model, \"monitor\"):\n            print(f\"Monitoring {model.monitor} as checkpoint metric.\")\n            default_modelckpt_cfg[\"params\"][\"monitor\"] = model.monitor\n            default_modelckpt_cfg[\"params\"][\"save_top_k\"] = 3\n\n        modelckpt_cfg = lightning_config.modelcheckpoint or OmegaConf.create()\n        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)\n        trainer_kwargs[\"checkpoint_callback\"] = instantiate_from_config(modelckpt_cfg)\n\n        # add callback which sets up log directory\n        default_callbacks_cfg = {\n            \"setup_callback\": {\n                \"target\": \"main.SetupCallback\",\n                \"params\": {\n                    \"resume\": opt.resume,\n                    \"now\": now,\n                    \"logdir\": logdir,\n                    \"ckptdir\": ckptdir,\n                    \"cfgdir\": cfgdir,\n                    \"config\": config,\n                    \"lightning_config\": lightning_config,\n                }\n            },\n            \"image_logger\": {\n                \"target\": \"main.ImageLogger\",\n                \"params\": {\n                    \"batch_frequency\": 750,\n                    \"max_images\": 4,\n                    \"clamp\": True\n                }\n            },\n            \"learning_rate_logger\": {\n                \"target\": 
\"main.LearningRateMonitor\",\n                \"params\": {\n                    \"logging_interval\": \"step\",\n                    #\"log_momentum\": True\n                }\n            },\n        }\n        callbacks_cfg = lightning_config.callbacks or OmegaConf.create()\n        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)\n        trainer_kwargs[\"callbacks\"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]\n\n        trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)\n\n        # data\n        data = instantiate_from_config(config.data)\n        # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html\n        # calling these ourselves should not be necessary but it is.\n        # lightning still takes care of proper multiprocessing though\n        data.prepare_data()\n        data.setup()\n\n        # configure learning rate\n        bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate\n        if not cpu:\n            ngpu = len(lightning_config.trainer.gpus.strip(\",\").split(','))\n        else:\n            ngpu = 1\n        accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches or 1\n        print(f\"accumulate_grad_batches = {accumulate_grad_batches}\")\n        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches\n        model.learning_rate = accumulate_grad_batches * ngpu * bs * trainer.num_nodes * base_lr\n        print(\"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (num_nodes) * {} (batchsize) * {:.2e} (base_lr)\".format(\n            model.learning_rate, accumulate_grad_batches, ngpu, trainer.num_nodes, bs, base_lr))\n\n        # allow checkpointing via USR1\n        def melk(*args, **kwargs):\n            # run all checkpoint hooks\n            if trainer.global_rank == 0:\n                print(\"Summoning checkpoint.\")\n                ckpt_path = os.path.join(ckptdir, \"last.ckpt\")\n                trainer.save_checkpoint(ckpt_path)\n\n        def divein(*args, **kwargs):\n            if trainer.global_rank == 0:\n                import pudb; pudb.set_trace()\n\n        import signal\n        signal.signal(signal.SIGUSR1, melk)\n        signal.signal(signal.SIGUSR2, divein)\n\n        # run\n        if opt.train:\n            try:\n                trainer.fit(model, data)\n            except Exception:\n                melk()\n                raise\n        if not opt.no_test and not trainer.interrupted:\n            trainer.test(model, data)\n    except Exception:\n        if opt.debug and trainer.global_rank==0:\n            try:\n                import pudb as debugger\n            except ImportError:\n                import pdb as debugger\n            debugger.post_mortem()\n        raise\n    finally:\n        # move newly created debug project to debug_runs\n        if opt.debug and not opt.resume and trainer.global_rank==0:\n            dst, name = os.path.split(logdir)\n            dst = os.path.join(dst, \"debug_runs\", name)\n            os.makedirs(os.path.split(dst)[0], exist_ok=True)\n            os.rename(logdir, dst)\n"
  },
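  {
    "path": "examples/lr_scaling_sketch.py",
    "content": "# Illustrative check of the linear learning-rate scaling rule in main.py\n# (hypothetical helper file): model.learning_rate is set to\n# accumulate_grad_batches * ngpu * bs * num_nodes * base_lr. With the defaults\n# from scripts/run.sh and configs/HQ_Dictionary.yaml:\nbase_lr = 4.5e-6             # model.base_learning_rate\nbs = 4                       # data.params.batch_size\nngpu = 4                     # gpus='0,1,2,3'\nnum_nodes = 1\naccumulate_grad_batches = 1\nlr = accumulate_grad_batches * ngpu * bs * num_nodes * base_lr\nprint(f\"effective lr = {lr:.2e}\")  # 7.20e-05\n"
  },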
  {
    "path": "restoreformer_requirement.txt",
    "content": "Package                 Version             Location\n----------------------- ------------------- ------------------------------------------------------------------------------\nabsl-py                 0.13.0\naddict                  2.4.0\naiohttp                 3.7.4.post0\nalbumentations          0.4.3\nantlr4-python3-runtime  4.8\nastunparse              1.6.3\nasync-timeout           3.0.1\nattrs                   21.2.0\nbasicsr                 1.3.3.4\ncached-property         1.5.2\ncachetools              4.2.2\ncertifi                 2021.5.30\nchardet                 4.0.0\ncycler                  0.10.0\ndlib                    19.22.99\nfacexlib                0.1.3.1\nflatbuffers             1.12\nfsspec                  2021.6.1\nfuture                  0.18.2\ngast                    0.4.0\ngoogle-auth             1.32.1\ngoogle-auth-oauthlib    0.4.4\ngoogle-pasta            0.2.0\ngrpcio                  1.39.0\nh5py                    3.1.0\nidna                    2.10\nimageio                 2.9.0\nimgaug                  0.2.6\nimportlib-metadata      4.6.1\njoblib                  1.0.1\nkeras-nightly           2.7.0.dev2021072800\nKeras-Preprocessing     1.1.2\nkiwisolver              1.3.1\nlibclang                11.1.0\nlmdb                    1.2.1\nMarkdown                3.3.4\nmatplotlib              3.4.2\nmkl-fft                 1.3.0\nmkl-random              1.2.1\nmkl-service             2.3.0\nmultidict               5.1.0\nnetworkx                2.6.1\nnumpy                   1.19.5\noauthlib                3.1.1\nolefile                 0.46\nomegaconf               2.0.0\nopencv-python           4.5.2.54\nopt-einsum              3.3.0\npackaging               21.0\npandas                  1.3.0\nPillow                  8.3.1\npip                     21.1.3\nprotobuf                3.17.3\npyasn1                  0.4.8\npyasn1-modules          0.2.8\npyDeprecate             0.3.0\npyparsing               2.4.7\npython-dateutil         2.8.1\npytorch-lightning       1.0.8\npytz                    2021.1\nPyWavelets              1.1.1\nPyYAML                  5.4.1\nrequests                2.25.1\nrequests-oauthlib       1.3.0\nrsa                     4.7.2\nscikit-image            0.18.2\nscikit-learn            0.24.2\nscipy                   1.7.0\nsetuptools              52.0.0.post20210125\nsix                     1.15.0\nsklearn                 0.0\ntb-nightly              2.6.0a20210728\ntensorboard-data-server 0.6.1\ntensorboard-plugin-wit  1.8.0\ntermcolor               1.1.0\ntest-tube               0.7.5\ntf-estimator-nightly    2.7.0.dev2021072801\ntf-nightly              2.7.0.dev20210728\nthreadpoolctl           2.2.0\ntifffile                2021.7.2\ntorch                   1.7.1\ntorchaudio              0.7.0a0+a853dff\ntorchmetrics            0.4.1\ntorchvision             0.8.2\ntqdm                    4.61.2\ntyping-extensions       3.7.4.3\nurllib3                 1.26.6\nWerkzeug                2.0.1\nwheel                   0.36.2\nwrapt                   1.12.1\nyapf                    0.31.0\nyarl                    1.6.3\nzipp                    3.5.0\n"
  },
  {
    "path": "scripts/metrics/cal_fid.py",
    "content": "import os, sys\nimport argparse\nimport math\nimport numpy as np\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom basicsr.data import build_dataset\nfrom basicsr.metrics.fid import calculate_fid, extract_inception_features, load_patched_inception_v3\n\n\ndef calculate_fid_folder():\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('folder', type=str, help='Path to the folder.')\n    parser.add_argument('--fid_stats', type=str, help='Path to the dataset fid statistics.')\n    parser.add_argument('--batch_size', type=int, default=64)\n    parser.add_argument('--num_sample', type=int, default=50000)\n    parser.add_argument('--num_workers', type=int, default=4)\n    parser.add_argument('--backend', type=str, default='disk', help='io backend for dataset. Option: disk, lmdb')\n    parser.add_argument('--save_name', type=str, default='fid', help='File name for saving results')\n    args = parser.parse_args()\n\n    # inception model\n    inception = load_patched_inception_v3(device)\n\n    # create dataset\n    opt = {}\n    opt['name'] = 'SingleImageDataset'\n    opt['type'] = 'SingleImageDataset'\n    opt['dataroot_lq'] = args.folder\n    opt['io_backend'] = dict(type=args.backend)\n    opt['mean'] = [0.5, 0.5, 0.5]\n    opt['std'] = [0.5, 0.5, 0.5]\n    dataset = build_dataset(opt)\n\n    # create dataloader\n    data_loader = DataLoader(\n        dataset=dataset,\n        batch_size=args.batch_size,\n        shuffle=False,\n        num_workers=args.num_workers,\n        sampler=None,\n        drop_last=False)\n    args.num_sample = min(args.num_sample, len(dataset))\n    total_batch = math.ceil(args.num_sample / args.batch_size)\n\n    def data_generator(data_loader, total_batch):\n        for idx, data in enumerate(data_loader):\n            if idx >= total_batch:\n                break\n            else:\n                yield data['lq']\n\n    features = extract_inception_features(data_generator(data_loader, total_batch), inception, total_batch, device)\n    features = features.numpy()\n    total_len = features.shape[0]\n    features = features[:args.num_sample]\n    # print(f'Extracted {total_len} features, ' f'use the first {features.shape[0]} features to calculate stats.')\n\n    sample_mean = np.mean(features, 0)\n    sample_cov = np.cov(features, rowvar=False)\n\n    # load the dataset stats\n    stats = torch.load(args.fid_stats)\n    real_mean = stats['mean']\n    real_cov = stats['cov']\n\n    # calculate FID metric\n    fid = calculate_fid(sample_mean, sample_cov, real_mean, real_cov)\n\n    fout=open(args.save_name, 'w')\n    fout.write(str(fid)+'\\n')\n    fout.close()\n\n    print(args.folder)\n    print('fid:', fid)\n\n\nif __name__ == '__main__':\n    calculate_fid_folder()\n"
  },
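  {
    "path": "examples/frechet_distance_sketch.py",
    "content": "# Reference sketch of the Frechet distance that scripts/metrics/cal_fid.py\n# delegates to basicsr.metrics.fid.calculate_fid (hypothetical helper file;\n# this is the standard formula, not the repo's implementation):\n# FID = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2 * sqrtm(C1 @ C2)).\nimport numpy as np\nfrom scipy import linalg\n\n\ndef frechet_distance(mu1, cov1, mu2, cov2, eps=1e-6):\n    diff = mu1 - mu2\n    # sqrtm can pick up tiny imaginary noise; keep only the real part.\n    covmean, _ = linalg.sqrtm(cov1 @ cov2, disp=False)\n    if not np.isfinite(covmean).all():\n        offset = np.eye(cov1.shape[0]) * eps\n        covmean, _ = linalg.sqrtm((cov1 + offset) @ (cov2 + offset), disp=False)\n    return diff @ diff + np.trace(cov1 + cov2 - 2 * covmean.real)\n\n\n# The stats are computed exactly as in cal_fid.py:\n# mu = features.mean(0); cov = np.cov(features, rowvar=False)\n"
  },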
  {
    "path": "scripts/metrics/cal_identity_distance.py",
    "content": "import os, sys\nimport torch\nimport argparse\nimport cv2\nimport numpy as np\nimport glob\nimport pdb\nimport tqdm\nfrom copy import deepcopy\nimport torch.nn.functional as F\nimport math\n\n\nroot_path = os.path.abspath(os.path.join(__file__, os.path.pardir, os.path.pardir, os.path.pardir))\nsys.path.append(root_path)\nsys.path.append(os.path.join(root_path, 'RestoreFormer/modules/losses'))\n\nfrom RestoreFormer.modules.vqvae.arcface_arch import ResNetArcFace\nfrom basicsr.losses.losses import L1Loss, MSELoss\n\ndef cosine_similarity(emb1, emb2):\n    return np.arccos(np.dot(emb1, emb2) / ( np.linalg.norm(emb1) * np.linalg.norm(emb2)))\n\n\ndef gray_resize_for_identity(out, size=128):\n    out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])\n    out_gray = out_gray.unsqueeze(1)\n    out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)\n    return out_gray\n\ndef calculate_identity_distance_folder():\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument('folder', type=str, help='Path to the folder')\n    parser.add_argument('--gt_folder', type=str, help='Path to the GT')\n    parser.add_argument('--save_name', type=str, default='niqe', help='File name for saving results')\n    parser.add_argument('--need_post', type=int, default=0, help='0: the name of image does not include 00, 1: otherwise')\n\n    args = parser.parse_args()\n\n    fout = open(args.save_name, 'w')\n\n    identity = ResNetArcFace(block = 'IRBlock', \n                                  layers = [2, 2, 2, 2],\n                                  use_se = False)\n    identity_model_path = 'experiments/pretrained_models/arcface_resnet18.pth'\n    \n    sd = torch.load(identity_model_path, map_location=\"cpu\")\n    for k, v in deepcopy(sd).items():\n        if k.startswith('module.'):\n            sd[k[7:]] = v\n            sd.pop(k)\n    identity.load_state_dict(sd, strict=True)\n    identity.eval()\n\n    for param in identity.parameters():\n        param.requires_grad = False\n\n    identity = identity.cuda()\n\n    gt_names = glob.glob(os.path.join(args.gt_folder, '*'))\n    gt_names.sort()\n    \n    mean_dist = 0.\n    for i in tqdm.tqdm(range(len(gt_names))):\n        gt_name = gt_names[i].split('/')[-1][:-4]\n        if args.need_post:\n            img_name = os.path.join(args.folder,gt_name + '_00.png')\n        else:\n            img_name = os.path.join(args.folder,gt_name + '.png')\n        if not os.path.exists(img_name):\n            print(img_name, 'does not exist')\n            continue\n\n        img = cv2.imread(img_name)\n        gt = cv2.imread(gt_names[i])\n\n        img = img.astype(np.float32) / 255.\n        img = torch.FloatTensor(img).cuda()\n        img = img.permute(2,0,1)\n        img = img.unsqueeze(0)\n\n        gt = gt.astype(np.float32) / 255.\n        gt = torch.FloatTensor(gt).cuda()\n        gt = gt.permute(2,0,1)\n        gt = gt.unsqueeze(0)\n\n        out_gray = gray_resize_for_identity(img)\n        gt_gray = gray_resize_for_identity(gt)\n\n        with torch.no_grad():\n            identity_gt = identity(gt_gray)\n            identity_out = identity(out_gray)\n\n        identity_gt = identity_gt.cpu().data.numpy().squeeze()\n        identity_out = identity_out.cpu().data.numpy().squeeze()\n        identity_loss = cosine_similarity(identity_gt, identity_out)\n\n        fout.write(gt_name + ' ' + str(identity_loss) + '\\n')\n        mean_dist += identity_loss\n\n    fout.write('Mean: ' + 
str(mean_dist / len(gt_names)) + '\\n')\n    fout.close()\n    print('mean_dist:', mean_dist / len(gt_names))\n\nif __name__ == '__main__':\n    calculate_identity_distance_folder()"
  },
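  {
    "path": "examples/identity_distance_sketch.py",
    "content": "# Illustrative check of the angular identity distance used in\n# scripts/metrics/cal_identity_distance.py (hypothetical helper file): its\n# cosine_similarity() returns the arccos of the cosine similarity, i.e. the\n# angle between two ArcFace embeddings in radians (0 = identical direction).\nimport numpy as np\n\n\ndef identity_distance(emb1, emb2):\n    cos = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2))\n    return np.arccos(np.clip(cos, -1.0, 1.0))  # clip guards fp round-off\n\n\ne1 = np.array([1.0, 0.0])\ne2 = np.array([0.0, 1.0])\nassert identity_distance(e1, e1) == 0.0                  # same embedding -> 0\nassert np.isclose(identity_distance(e1, e2), np.pi / 2)  # orthogonal -> pi/2\n"
  },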
  {
    "path": "scripts/metrics/cal_psnr_ssim.py",
    "content": "import os, sys\nimport argparse\nimport cv2\nimport numpy as np\nimport glob\nimport pdb\nimport tqdm\nimport torch\n\nfrom basicsr.metrics.psnr_ssim import calculate_psnr, calculate_ssim\n\nroot_path = os.path.abspath(os.path.join(__file__, os.path.pardir, os.path.pardir, os.path.pardir))\nsys.path.append(root_path)\nsys.path.append(os.path.join(root_path, 'RestoreFormer/modules/losses'))\n\nfrom lpips import LPIPS\n\ndef calculate_psnr_ssim_lpips_folder():\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument('folder', type=str, help='Path to the folder')\n    parser.add_argument('--gt_folder', type=str, help='Path to the GT')\n    parser.add_argument('--save_name', type=str, default='niqe', help='File name for saving results')\n    parser.add_argument('--need_post', type=int, default=0, help='0: the name of image does not include 00, 1: otherwise')\n\n    args = parser.parse_args()\n\n    fout = open(args.save_name, 'w')\n    fout.write('NAME\\tPSNR\\tSSIM\\tLPIPS\\n')\n\n    H, W = 512, 512\n\n    gt_names = glob.glob(os.path.join(args.gt_folder, '*'))\n    gt_names.sort()\n\n    perceptual_loss = LPIPS().eval().cuda()\n\n    mean_psnr = 0.\n    mean_ssim = 0.\n    mean_lpips = 0.\n    mean_norm_lpips = 0.\n\n    for i in tqdm.tqdm(range(len(gt_names))):\n        gt_name = gt_names[i].split('/')[-1][:-4]\n\n        if args.need_post:\n            img_name = os.path.join(args.folder,gt_name + '_00.png')\n        else:\n            img_name = os.path.join(args.folder,gt_name + '.png')\n\n        if not os.path.exists(img_name):\n            print(img_name, 'does not exist')\n            continue\n\n        img = cv2.imread(img_name)\n        gt = cv2.imread(gt_names[i])\n\n        cur_psnr = calculate_psnr(img, gt, 0)\n        cur_ssim = calculate_ssim(img, gt, 0)\n\n        # lpips:\n        img = img.astype(np.float32) / 255.\n        img = torch.FloatTensor(img).cuda()\n        img = img.permute(2,0,1)\n        img = img.unsqueeze(0)\n\n        gt = gt.astype(np.float32) / 255.\n        gt = torch.FloatTensor(gt).cuda()\n        gt = gt.permute(2,0,1)\n        gt = gt.unsqueeze(0)\n\n        cur_lpips = perceptual_loss(img, gt)\n        cur_lpips = cur_lpips[0].item()\n\n        img = (img - 0.5) / 0.5\n        gt = (gt - 0.5) / 0.5\n\n        norm_lpips = perceptual_loss(img, gt)\n        norm_lpips = norm_lpips[0].item()\n\n        # print(cur_psnr, cur_ssim, cur_lpips, norm_lpips)\n\n        fout.write(gt_name + '\\t' + str(cur_psnr) + '\\t' + str(cur_ssim) + '\\t' + str(cur_lpips) + '\\t' + str(norm_lpips) + '\\n')\n\n        mean_psnr += cur_psnr\n        mean_ssim += cur_ssim\n        mean_lpips += cur_lpips\n        mean_norm_lpips += norm_lpips\n\n    mean_psnr /= float(len(gt_names))\n    mean_ssim /= float(len(gt_names))\n    mean_lpips /= float(len(gt_names))\n    mean_norm_lpips /= float(len(gt_names))\n\n    fout.write(str(mean_psnr) + '\\t' + str(mean_ssim) + '\\t' + str(mean_lpips) + '\\t' + str(mean_norm_lpips) + '\\n')\n    fout.close()\n\n    print('psnr, ssim, lpips, norm_lpips:', mean_psnr, mean_ssim, mean_lpips, mean_norm_lpips)\n\nif __name__ == '__main__':\n    calculate_psnr_ssim_lpips_folder()"
  },
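  {
    "path": "examples/psnr_sketch.py",
    "content": "# Reference sketch of the PSNR that scripts/metrics/cal_psnr_ssim.py delegates\n# to basicsr.metrics.psnr_ssim.calculate_psnr with crop_border=0 (hypothetical\n# helper file; standard formula for 8-bit images):\n# PSNR = 10 * log10(255^2 / MSE).\nimport numpy as np\n\n\ndef psnr(img, gt):\n    mse = np.mean((img.astype(np.float64) - gt.astype(np.float64)) ** 2)\n    return float(\"inf\") if mse == 0 else 10.0 * np.log10(255.0 ** 2 / mse)\n\n\na = np.zeros((8, 8), dtype=np.uint8)\nb = np.full((8, 8), 10, dtype=np.uint8)\nprint(psnr(a, b))  # 10*log10(65025/100) ~= 28.13 dB\n"
  },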
  {
    "path": "scripts/metrics/run.sh",
    "content": "\n### Journal ###\nroot='results/'\nout_root='results/metrics'\n\ntest_name='RestoreFormer'\n\ntest_image=$test_name'/restored_faces'\nout_name=$test_name\nneed_post=1\n\nCelebAHQ_GT='YOUR_PATH'\n\n# FID\npython -u scripts/metrics/cal_fid.py \\\n$root'/'$test_image \\\n--fid_stats 'experiments/pretrained_models/inception_FFHQ_512-f7b384ab.pth' \\\n--save_name $out_root'/'$out_name'_fid.txt' \\\n\nif [ -d $CelebAHQ_GT ]\nthen\n    # PSRN SSIM LPIPS\n    python -u scripts/metrics/cal_psnr_ssim.py \\\n    $root'/'$test_image \\\n    --gt_folder $CelebAHQ_GT \\\n    --save_name $out_root'/'$out_name'_psnr_ssim_lpips.txt' \\\n    --need_post $need_post \\\n\n    # # # PSRN SSIM LPIPS\n    python -u scripts/metrics/cal_identity_distance.py  \\\n    $root'/'$test_image \\\n    --gt_folder $CelebAHQ_GT \\\n    --save_name $out_root'/'$out_name'_id.txt' \\\n    --need_post $need_post\nelse\n    echo 'The path of GT does not exist'\nfi"
  },
  {
    "path": "scripts/run.sh",
    "content": "export BASICSR_JIT=True\n\nconf_name='HQ_Dictionary'\n# conf_name='RestoreFormer'\n\nROOT_PATH='' # The path for saving model and logs\n\ngpus='0,1,2,3'\n\n#P: pretrain SL: soft learning\nnode_n=1\n\npython -u main.py \\\n--root-path $ROOT_PATH \\\n--base 'configs/'$conf_name'.yaml' \\\n-t True \\\n--postfix $conf_name \\\n--gpus $gpus \\\n--num-nodes $node_n \\\n--random-seed True \\\n"
  },
  {
    "path": "scripts/test.py",
    "content": "import argparse, os, sys, glob, math, time\nimport torch\nimport numpy as np\nfrom omegaconf import OmegaConf\nfrom PIL import Image\nimport pdb\n\nsys.path.append(os.getcwd())\n\nfrom main import instantiate_from_config, DataModuleFromConfig\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data.dataloader import default_collate\nfrom tqdm import trange, tqdm\n\nimport cv2\nfrom facexlib.utils.face_restoration_helper import FaceRestoreHelper\nfrom torchvision.transforms.functional import normalize\n\nfrom basicsr.utils import img2tensor, imwrite, tensor2img\n\n\ndef restoration(model,\n                face_helper,\n                img_path,\n                save_root,\n                has_aligned=False,\n                only_center_face=True,\n                suffix=None,\n                paste_back=False):\n    # read image\n    img_name = os.path.basename(img_path)\n    # print(f'Processing {img_name} ...')\n    basename, _ = os.path.splitext(img_name)\n    input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)\n    face_helper.clean_all()\n\n    if has_aligned:\n        input_img = cv2.resize(input_img, (512, 512))\n        face_helper.cropped_faces = [input_img]\n    else:\n        face_helper.read_image(input_img)\n        # get face landmarks for each face\n        face_helper.get_face_landmarks_5(only_center_face=only_center_face, pad_blur=False)\n        # align and warp each face\n        save_crop_path = os.path.join(save_root, 'cropped_faces', img_name)\n        face_helper.align_warp_face(save_crop_path)\n\n    # face restoration\n    for idx, cropped_face in enumerate(face_helper.cropped_faces):\n        # prepare data\n        cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)\n        normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)\n        cropped_face_t = cropped_face_t.unsqueeze(0).to('cuda')\n\n        try:\n            with torch.no_grad():\n                output = model(cropped_face_t)\n                restored_face = tensor2img(output[0].squeeze(0), rgb2bgr=True, min_max=(-1, 1))\n        except RuntimeError as error:\n            print(f'\\tFailed inference for GFPGAN: {error}.')\n            restored_face = cropped_face\n\n        restored_face = restored_face.astype('uint8')\n        face_helper.add_restored_face(restored_face)\n\n        if suffix is not None:\n            save_face_name = f'{basename}_{idx:02d}_{suffix}.png'\n        else:\n            save_face_name = f'{basename}_{idx:02d}.png'\n        save_restore_path = os.path.join(save_root, 'restored_faces', save_face_name)\n        imwrite(restored_face, save_restore_path)\n\n\n    if not has_aligned and paste_back:\n        face_helper.get_inverse_affine(None)\n        save_restore_path = os.path.join(save_root, 'restored_imgs', img_name)\n        # paste each restored face to the input image\n        face_helper.paste_faces_to_input_image(save_restore_path)\n\ndef get_parser():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"-r\",\n        \"--resume\",\n        type=str,\n        nargs=\"?\",\n        help=\"load from logdir or checkpoint in logdir\",\n    )\n    parser.add_argument(\n        \"-b\",\n        \"--base\",\n        nargs=\"*\",\n        metavar=\"base_config.yaml\",\n        help=\"paths to base configs. Loaded from left-to-right. 
\"\n        \"Parameters can be overwritten or added with command-line options of the form `--key value`.\",\n        default=list(),\n    )\n    parser.add_argument(\n        \"-c\",\n        \"--config\",\n        nargs=\"?\",\n        metavar=\"single_config.yaml\",\n        help=\"path to single config. If specified, base configs will be ignored \"\n        \"(except for the last one if left unspecified).\",\n        const=True,\n        default=\"\",\n    )\n    parser.add_argument(\n        \"--ignore_base_data\",\n        action=\"store_true\",\n        help=\"Ignore data specification from base configs. Useful if you want \"\n        \"to specify a custom datasets on the command line.\",\n    )\n    parser.add_argument(\n        \"--outdir\",\n        required=True,\n        type=str,\n        help=\"Where to write outputs to.\",\n    )\n    parser.add_argument(\n        \"--top_k\",\n        type=int,\n        default=100,\n        help=\"Sample from among top-k predictions.\",\n    )\n    parser.add_argument(\n        \"--temperature\",\n        type=float,\n        default=1.0,\n        help=\"Sampling temperature.\",\n    )\n    parser.add_argument('--upscale_factor', type=int, default=1)\n    parser.add_argument('--test_path', type=str, default='inputs/whole_imgs')\n    parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')\n    parser.add_argument('--only_center_face', action='store_true')\n    parser.add_argument('--aligned', action='store_true')\n    parser.add_argument('--paste_back', action='store_true')\n\n    return parser\n\n\ndef load_model_from_config(config, sd, gpu=True, eval_mode=True):\n    if \"ckpt_path\" in config.params:\n        print(\"Deleting the restore-ckpt path from the config...\")\n        config.params.ckpt_path = None\n    if \"downsample_cond_size\" in config.params:\n        print(\"Deleting downsample-cond-size from the config and setting factor=0.5 instead...\")\n        config.params.downsample_cond_size = -1\n        config.params[\"downsample_cond_factor\"] = 0.5\n    try:\n        if \"ckpt_path\" in config.params.first_stage_config.params:\n            config.params.first_stage_config.params.ckpt_path = None\n            print(\"Deleting the first-stage restore-ckpt path from the config...\")\n        if \"ckpt_path\" in config.params.cond_stage_config.params:\n            config.params.cond_stage_config.params.ckpt_path = None\n            print(\"Deleting the cond-stage restore-ckpt path from the config...\")\n    except:\n        pass\n\n    model = instantiate_from_config(config)\n    if sd is not None:\n        keys = list(sd.keys())\n\n        state_dict = model.state_dict()\n        require_keys = state_dict.keys()\n        keys = sd.keys()\n        un_pretrained_keys = []\n        for k in require_keys:\n            if k not in keys: \n                # miss 'vqvae.'\n                if k[6:] in keys:\n                    state_dict[k] = sd[k[6:]]\n                else:\n                    un_pretrained_keys.append(k)\n            else:\n                state_dict[k] = sd[k]\n\n        # print(f'*************************************************')\n        # print(f\"Layers without pretraining: {un_pretrained_keys}\")\n        # print(f'*************************************************')\n\n        model.load_state_dict(state_dict, strict=True)\n\n    if gpu:\n        model.cuda()\n    if eval_mode:\n        model.eval()\n    return {\"model\": model}\n\n\ndef load_model_and_dset(config, ckpt, 
gpu, eval_mode):\n\n    # now load the specified checkpoint\n    if ckpt:\n        pl_sd = torch.load(ckpt, map_location=\"cpu\")\n    else:\n        pl_sd = {\"state_dict\": None}\n\n    model = load_model_from_config(config.model,\n                                   pl_sd[\"state_dict\"],\n                                   gpu=gpu,\n                                   eval_mode=eval_mode)[\"model\"]\n    return model\n\nif __name__ == \"__main__\":\n    sys.path.append(os.getcwd())\n\n    parser = get_parser()\n\n    opt, unknown = parser.parse_known_args()\n\n    ckpt = None\n    if opt.resume:\n        if not os.path.exists(opt.resume):\n            raise ValueError(\"Cannot find {}\".format(opt.resume))\n        if os.path.isfile(opt.resume):\n            paths = opt.resume.split(\"/\")\n            try:\n                idx = len(paths)-paths[::-1].index(\"logs\")+1\n            except ValueError:\n                idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt\n            logdir = \"/\".join(paths[:idx])\n            ckpt = opt.resume\n        else:\n            assert os.path.isdir(opt.resume), opt.resume\n            logdir = opt.resume.rstrip(\"/\")\n            ckpt = os.path.join(logdir, \"checkpoints\", \"last.ckpt\")\n        print(f\"logdir:{logdir}\")\n        base_configs = sorted(glob.glob(os.path.join(logdir, \"configs/*-project.yaml\")))\n        opt.base = base_configs+opt.base\n\n    if opt.config:\n        if type(opt.config) == str:\n            if not os.path.exists(opt.config):\n                raise ValueError(\"Cannot find {}\".format(opt.config))\n            if os.path.isfile(opt.config):\n                opt.base = [opt.config]\n            else:\n                opt.base = sorted(glob.glob(os.path.join(opt.config, \"*-project.yaml\")))\n        else:\n            opt.base = [opt.base[-1]]\n\n    configs = [OmegaConf.load(cfg) for cfg in opt.base]\n    cli = OmegaConf.from_dotlist(unknown)\n    if opt.ignore_base_data:\n        for config in configs:\n            if hasattr(config, \"data\"): del config[\"data\"]\n    config = OmegaConf.merge(*configs, cli)\n    \n    print(config)\n    gpu = True\n    eval_mode = True\n    show_config = False\n    if show_config:\n        print(OmegaConf.to_container(config))\n\n    model = load_model_and_dset(config, ckpt, gpu, eval_mode)\n    \n    outdir = opt.outdir\n    os.makedirs(outdir, exist_ok=True)\n    print(\"Writing samples to \", outdir)\n\n    # initialize face helper\n    face_helper = FaceRestoreHelper(\n        opt.upscale_factor, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png')\n\n    img_list = sorted(glob.glob(os.path.join(opt.test_path, '*')))\n\n    print('Results are in the <{}> folder.'.format(outdir))\n    \n    for img_path in tqdm(img_list):\n        restoration(\n                model,\n                face_helper,\n                img_path,\n                outdir,\n                has_aligned=opt.aligned,\n                only_center_face=opt.only_center_face,\n                suffix=opt.suffix,\n                paste_back=opt.paste_back)\n\n    print('Test number: ', len(img_list))\n    print('Results are in the <{}> folder.'.format(outdir))\n"
  },
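  {
    "path": "examples/value_range_sketch.py",
    "content": "# Illustrative round-trip of the value ranges used in scripts/test.py\n# (hypothetical helper file): inputs are scaled to [0,1], normalized to [-1,1]\n# with mean=std=0.5, and the network output is mapped back to [0,255], which\n# is what tensor2img's min_max=(-1, 1) does (modulo BGR/RGB channel order).\nimport numpy as np\n\nimg = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)\nt = img.astype(np.float32) / 255.0   # [0, 1]\nt = (t - 0.5) / 0.5                  # [-1, 1], matches normalize(...) in test.py\nrestored = (t + 1.0) / 2.0           # back to [0, 1]\nrestored = (restored * 255.0).round().astype(np.uint8)\nassert np.array_equal(img, restored)\n"
  },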
  {
    "path": "scripts/test.sh",
    "content": "# # ### Good\nexp_name='RestoreFormer'\n\nroot_path='experiments'\nout_root_path='results'\nalign_test_path='data/test'\ntag='test'\n\noutdir=$out_root_path'/'$exp_name'_'$tag\n\nif [ ! -d $outdir ];then\n    mkdir $outdir\nfi\n\npython -u scripts/test.py \\\n--outdir $outdir \\\n-r $root_path'/'$exp_name'/last.ckpt' \\\n-c 'configs/RestoreFormer.yaml' \\\n--test_path $align_test_path \\\n--aligned\n\n"
  }
]