[
  {
    "path": ".gitignore",
    "content": "/data\n/results\n*.dat\n*.pth\n*.pt\n\n# *.txt\n# !/expe/*/*.txt\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n.idea/\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n.vscode/settings.json\n"
  },
  {
    "path": "LICENSE",
    "content": "Attribution-NonCommercial-ShareAlike 4.0 International\n\n=======================================================================\n\nCreative Commons Corporation (\"Creative Commons\") is not a law firm and\ndoes not provide legal services or legal advice. Distribution of\nCreative Commons public licenses does not create a lawyer-client or\nother relationship. Creative Commons makes its licenses and related\ninformation available on an \"as-is\" basis. Creative Commons gives no\nwarranties regarding its licenses, any material licensed under their\nterms and conditions, or any related information. Creative Commons\ndisclaims all liability for damages resulting from their use to the\nfullest extent possible.\n\nUsing Creative Commons Public Licenses\n\nCreative Commons public licenses provide a standard set of terms and\nconditions that creators and other rights holders may use to share\noriginal works of authorship and other material subject to copyright\nand certain other rights specified in the public license below. The\nfollowing considerations are for informational purposes only, are not\nexhaustive, and do not form part of our licenses.\n\n     Considerations for licensors: Our public licenses are\n     intended for use by those authorized to give the public\n     permission to use material in ways otherwise restricted by\n     copyright and certain other rights. Our licenses are\n     irrevocable. Licensors should read and understand the terms\n     and conditions of the license they choose before applying it.\n     Licensors should also secure all rights necessary before\n     applying our licenses so that the public can reuse the\n     material as expected. Licensors should clearly mark any\n     material not subject to the license. This includes other CC-\n     licensed material, or material used under an exception or\n     limitation to copyright. More considerations for licensors:\n    wiki.creativecommons.org/Considerations_for_licensors\n\n     Considerations for the public: By using one of our public\n     licenses, a licensor grants the public permission to use the\n     licensed material under specified terms and conditions. If\n     the licensor's permission is not necessary for any reason--for\n     example, because of any applicable exception or limitation to\n     copyright--then that use is not regulated by the license. Our\n     licenses grant only permissions under copyright and certain\n     other rights that a licensor has authority to grant. Use of\n     the licensed material may still be restricted for other\n     reasons, including because others have copyright or other\n     rights in the material. A licensor may make special requests,\n     such as asking that all changes be marked or described.\n     Although not required by our licenses, you are encouraged to\n     respect those requests where reasonable. More considerations\n     for the public:\n    wiki.creativecommons.org/Considerations_for_licensees\n\n=======================================================================\n\nCreative Commons Attribution-NonCommercial-ShareAlike 4.0 International\nPublic License\n\nBy exercising the Licensed Rights (defined below), You accept and agree\nto be bound by the terms and conditions of this Creative Commons\nAttribution-NonCommercial-ShareAlike 4.0 International Public License\n(\"Public License\"). 
To the extent this Public License may be\ninterpreted as a contract, You are granted the Licensed Rights in\nconsideration of Your acceptance of these terms and conditions, and the\nLicensor grants You such rights in consideration of benefits the\nLicensor receives from making the Licensed Material available under\nthese terms and conditions.\n\n\nSection 1 -- Definitions.\n\n  a. Adapted Material means material subject to Copyright and Similar\n     Rights that is derived from or based upon the Licensed Material\n     and in which the Licensed Material is translated, altered,\n     arranged, transformed, or otherwise modified in a manner requiring\n     permission under the Copyright and Similar Rights held by the\n     Licensor. For purposes of this Public License, where the Licensed\n     Material is a musical work, performance, or sound recording,\n     Adapted Material is always produced where the Licensed Material is\n     synched in timed relation with a moving image.\n\n  b. Adapter's License means the license You apply to Your Copyright\n     and Similar Rights in Your contributions to Adapted Material in\n     accordance with the terms and conditions of this Public License.\n\n  c. BY-NC-SA Compatible License means a license listed at\n     creativecommons.org/compatiblelicenses, approved by Creative\n     Commons as essentially the equivalent of this Public License.\n\n  d. Copyright and Similar Rights means copyright and/or similar rights\n     closely related to copyright including, without limitation,\n     performance, broadcast, sound recording, and Sui Generis Database\n     Rights, without regard to how the rights are labeled or\n     categorized. For purposes of this Public License, the rights\n     specified in Section 2(b)(1)-(2) are not Copyright and Similar\n     Rights.\n\n  e. Effective Technological Measures means those measures that, in the\n     absence of proper authority, may not be circumvented under laws\n     fulfilling obligations under Article 11 of the WIPO Copyright\n     Treaty adopted on December 20, 1996, and/or similar international\n     agreements.\n\n  f. Exceptions and Limitations means fair use, fair dealing, and/or\n     any other exception or limitation to Copyright and Similar Rights\n     that applies to Your use of the Licensed Material.\n\n  g. License Elements means the license attributes listed in the name\n     of a Creative Commons Public License. The License Elements of this\n     Public License are Attribution, NonCommercial, and ShareAlike.\n\n  h. Licensed Material means the artistic or literary work, database,\n     or other material to which the Licensor applied this Public\n     License.\n\n  i. Licensed Rights means the rights granted to You subject to the\n     terms and conditions of this Public License, which are limited to\n     all Copyright and Similar Rights that apply to Your use of the\n     Licensed Material and that the Licensor has authority to license.\n\n  j. Licensor means the individual(s) or entity(ies) granting rights\n     under this Public License.\n\n  k. NonCommercial means not primarily intended for or directed towards\n     commercial advantage or monetary compensation. For purposes of\n     this Public License, the exchange of the Licensed Material for\n     other material subject to Copyright and Similar Rights by digital\n     file-sharing or similar means is NonCommercial provided there is\n     no payment of monetary compensation in connection with the\n     exchange.\n\n  l. 
Share means to provide material to the public by any means or\n     process that requires permission under the Licensed Rights, such\n     as reproduction, public display, public performance, distribution,\n     dissemination, communication, or importation, and to make material\n     available to the public including in ways that members of the\n     public may access the material from a place and at a time\n     individually chosen by them.\n\n  m. Sui Generis Database Rights means rights other than copyright\n     resulting from Directive 96/9/EC of the European Parliament and of\n     the Council of 11 March 1996 on the legal protection of databases,\n     as amended and/or succeeded, as well as other essentially\n     equivalent rights anywhere in the world.\n\n  n. You means the individual or entity exercising the Licensed Rights\n     under this Public License. Your has a corresponding meaning.\n\n\nSection 2 -- Scope.\n\n  a. License grant.\n\n       1. Subject to the terms and conditions of this Public License,\n          the Licensor hereby grants You a worldwide, royalty-free,\n          non-sublicensable, non-exclusive, irrevocable license to\n          exercise the Licensed Rights in the Licensed Material to:\n\n            a. reproduce and Share the Licensed Material, in whole or\n               in part, for NonCommercial purposes only; and\n\n            b. produce, reproduce, and Share Adapted Material for\n               NonCommercial purposes only.\n\n       2. Exceptions and Limitations. For the avoidance of doubt, where\n          Exceptions and Limitations apply to Your use, this Public\n          License does not apply, and You do not need to comply with\n          its terms and conditions.\n\n       3. Term. The term of this Public License is specified in Section\n          6(a).\n\n       4. Media and formats; technical modifications allowed. The\n          Licensor authorizes You to exercise the Licensed Rights in\n          all media and formats whether now known or hereafter created,\n          and to make technical modifications necessary to do so. The\n          Licensor waives and/or agrees not to assert any right or\n          authority to forbid You from making technical modifications\n          necessary to exercise the Licensed Rights, including\n          technical modifications necessary to circumvent Effective\n          Technological Measures. For purposes of this Public License,\n          simply making modifications authorized by this Section 2(a)\n          (4) never produces Adapted Material.\n\n       5. Downstream recipients.\n\n            a. Offer from the Licensor -- Licensed Material. Every\n               recipient of the Licensed Material automatically\n               receives an offer from the Licensor to exercise the\n               Licensed Rights under the terms and conditions of this\n               Public License.\n\n            b. Additional offer from the Licensor -- Adapted Material.\n               Every recipient of Adapted Material from You\n               automatically receives an offer from the Licensor to\n               exercise the Licensed Rights in the Adapted Material\n               under the conditions of the Adapter's License You apply.\n\n            c. No downstream restrictions. 
You may not offer or impose\n               any additional or different terms or conditions on, or\n               apply any Effective Technological Measures to, the\n               Licensed Material if doing so restricts exercise of the\n               Licensed Rights by any recipient of the Licensed\n               Material.\n\n       6. No endorsement. Nothing in this Public License constitutes or\n          may be construed as permission to assert or imply that You\n          are, or that Your use of the Licensed Material is, connected\n          with, or sponsored, endorsed, or granted official status by,\n          the Licensor or others designated to receive attribution as\n          provided in Section 3(a)(1)(A)(i).\n\n  b. Other rights.\n\n       1. Moral rights, such as the right of integrity, are not\n          licensed under this Public License, nor are publicity,\n          privacy, and/or other similar personality rights; however, to\n          the extent possible, the Licensor waives and/or agrees not to\n          assert any such rights held by the Licensor to the limited\n          extent necessary to allow You to exercise the Licensed\n          Rights, but not otherwise.\n\n       2. Patent and trademark rights are not licensed under this\n          Public License.\n\n       3. To the extent possible, the Licensor waives any right to\n          collect royalties from You for the exercise of the Licensed\n          Rights, whether directly or through a collecting society\n          under any voluntary or waivable statutory or compulsory\n          licensing scheme. In all other cases the Licensor expressly\n          reserves any right to collect such royalties, including when\n          the Licensed Material is used other than for NonCommercial\n          purposes.\n\n\nSection 3 -- License Conditions.\n\nYour exercise of the Licensed Rights is expressly made subject to the\nfollowing conditions.\n\n  a. Attribution.\n\n       1. If You Share the Licensed Material (including in modified\n          form), You must:\n\n            a. retain the following if it is supplied by the Licensor\n               with the Licensed Material:\n\n                 i. identification of the creator(s) of the Licensed\n                    Material and any others designated to receive\n                    attribution, in any reasonable manner requested by\n                    the Licensor (including by pseudonym if\n                    designated);\n\n                ii. a copyright notice;\n\n               iii. a notice that refers to this Public License;\n\n                iv. a notice that refers to the disclaimer of\n                    warranties;\n\n                 v. a URI or hyperlink to the Licensed Material to the\n                    extent reasonably practicable;\n\n            b. indicate if You modified the Licensed Material and\n               retain an indication of any previous modifications; and\n\n            c. indicate the Licensed Material is licensed under this\n               Public License, and include the text of, or the URI or\n               hyperlink to, this Public License.\n\n       2. You may satisfy the conditions in Section 3(a)(1) in any\n          reasonable manner based on the medium, means, and context in\n          which You Share the Licensed Material. For example, it may be\n          reasonable to satisfy the conditions by providing a URI or\n          hyperlink to a resource that includes the required\n          information.\n       3. 
If requested by the Licensor, You must remove any of the\n          information required by Section 3(a)(1)(A) to the extent\n          reasonably practicable.\n\n  b. ShareAlike.\n\n     In addition to the conditions in Section 3(a), if You Share\n     Adapted Material You produce, the following conditions also apply.\n\n       1. The Adapter's License You apply must be a Creative Commons\n          license with the same License Elements, this version or\n          later, or a BY-NC-SA Compatible License.\n\n       2. You must include the text of, or the URI or hyperlink to, the\n          Adapter's License You apply. You may satisfy this condition\n          in any reasonable manner based on the medium, means, and\n          context in which You Share Adapted Material.\n\n       3. You may not offer or impose any additional or different terms\n          or conditions on, or apply any Effective Technological\n          Measures to, Adapted Material that restrict exercise of the\n          rights granted under the Adapter's License You apply.\n\n\nSection 4 -- Sui Generis Database Rights.\n\nWhere the Licensed Rights include Sui Generis Database Rights that\napply to Your use of the Licensed Material:\n\n  a. for the avoidance of doubt, Section 2(a)(1) grants You the right\n     to extract, reuse, reproduce, and Share all or a substantial\n     portion of the contents of the database for NonCommercial purposes\n     only;\n\n  b. if You include all or a substantial portion of the database\n     contents in a database in which You have Sui Generis Database\n     Rights, then the database in which You have Sui Generis Database\n     Rights (but not its individual contents) is Adapted Material,\n     including for purposes of Section 3(b); and\n\n  c. You must comply with the conditions in Section 3(a) if You Share\n     all or a substantial portion of the contents of the database.\n\nFor the avoidance of doubt, this Section 4 supplements and does not\nreplace Your obligations under this Public License where the Licensed\nRights include other Copyright and Similar Rights.\n\n\nSection 5 -- Disclaimer of Warranties and Limitation of Liability.\n\n  a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE\n     EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS\n     AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF\n     ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,\n     IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,\n     WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR\n     PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,\n     ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT\n     KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT\n     ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.\n\n  b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE\n     TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,\n     NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,\n     INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,\n     COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR\n     USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN\n     ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR\n     DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR\n     IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.\n\n  c. 
The disclaimer of warranties and limitation of liability provided\n     above shall be interpreted in a manner that, to the extent\n     possible, most closely approximates an absolute disclaimer and\n     waiver of all liability.\n\n\nSection 6 -- Term and Termination.\n\n  a. This Public License applies for the term of the Copyright and\n     Similar Rights licensed here. However, if You fail to comply with\n     this Public License, then Your rights under this Public License\n     terminate automatically.\n\n  b. Where Your right to use the Licensed Material has terminated under\n     Section 6(a), it reinstates:\n\n       1. automatically as of the date the violation is cured, provided\n          it is cured within 30 days of Your discovery of the\n          violation; or\n\n       2. upon express reinstatement by the Licensor.\n\n     For the avoidance of doubt, this Section 6(b) does not affect any\n     right the Licensor may have to seek remedies for Your violations\n     of this Public License.\n\n  c. For the avoidance of doubt, the Licensor may also offer the\n     Licensed Material under separate terms or conditions or stop\n     distributing the Licensed Material at any time; however, doing so\n     will not terminate this Public License.\n\n  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public\n     License.\n\n\nSection 7 -- Other Terms and Conditions.\n\n  a. The Licensor shall not be bound by any additional or different\n     terms or conditions communicated by You unless expressly agreed.\n\n  b. Any arrangements, understandings, or agreements regarding the\n     Licensed Material not stated herein are separate from and\n     independent of the terms and conditions of this Public License.\n\n\nSection 8 -- Interpretation.\n\n  a. For the avoidance of doubt, this Public License does not, and\n     shall not be interpreted to, reduce, limit, restrict, or impose\n     conditions on any use of the Licensed Material that could lawfully\n     be made without permission under this Public License.\n\n  b. To the extent possible, if any provision of this Public License is\n     deemed unenforceable, it shall be automatically reformed to the\n     minimum extent necessary to make it enforceable. If the provision\n     cannot be reformed, it shall be severed from this Public License\n     without affecting the enforceability of the remaining terms and\n     conditions.\n\n  c. No term or condition of this Public License will be waived and no\n     failure to comply consented to unless expressly agreed to by the\n     Licensor.\n\n  d. Nothing in this Public License constitutes or may be interpreted\n     as a limitation upon, or waiver of, any privileges and immunities\n     that apply to the Licensor or You, including from the legal\n     processes of any jurisdiction or authority.\n\n=======================================================================\n\nCreative Commons is not a party to its public\nlicenses. Notwithstanding, Creative Commons may elect to apply one of\nits public licenses to material it publishes and in those instances\nwill be considered the “Licensor.” The text of the Creative Commons\npublic licenses is dedicated to the public domain under the CC0 Public\nDomain Dedication. 
Except for the limited purpose of indicating that\nmaterial is shared under a Creative Commons public license or as\notherwise permitted by the Creative Commons policies published at\ncreativecommons.org/policies, Creative Commons does not authorize the\nuse of the trademark \"Creative Commons\" or any other trademark or logo\nof Creative Commons without its prior written consent including,\nwithout limitation, in connection with any unauthorized modifications\nto any of its public licenses or any other arrangements,\nunderstandings, or agreements concerning use of licensed material. For\nthe avoidance of doubt, this paragraph does not form part of the\npublic licenses.\n\nCreative Commons may be contacted at creativecommons.org."
  },
  {
    "path": "README.md",
    "content": "# EleGANt: Exquisite and Locally Editable GAN for Makeup Transfer\n\n[![CC BY-NC-SA 4.0][cc-by-nc-sa-shield]][cc-by-nc-sa]\n\nOfficial [PyTorch](https://pytorch.org/) implementation of ECCV 2022 paper \"[EleGANt: Exquisite and Locally Editable GAN for Makeup Transfer](https://arxiv.org/abs/2207.09840)\"\n\n*Chenyu Yang, Wanrong He, Yingqing Xu, and Yang Gao*.\n\n![teaser](assets/figs/teaser.png)\n\n## Getting Started\n\n- [Installation](assets/docs/install.md)\n- [Prepare Dataset & Checkpoints](assets/docs/prepare.md)\n\n## Test\n\nTo test our model, download the [weights](https://drive.google.com/drive/folders/1xzIS3Dfmsssxkk9OhhAS4svrZSPfQYRe?usp=sharing) of the trained model and run\n\n```bash\npython scripts/demo.py\n```\n\nExamples of makeup transfer results can be seen [here](assets/images/examples/).\n\n## Train\n\nTo train a model from scratch, run\n\n```bash\npython scripts/train.py\n```\n\n## Customized Transfer\n\nhttps://user-images.githubusercontent.com/61506577/180593092-ccadddff-76be-4b7b-921e-0d3b4cc27d9b.mp4\n\nThis is our demo of customized makeup editing. The interactive system is built upon [Streamlit](https://github.com/streamlit/streamlit) and the interface in `./training/inference.py`.\n\n**Controllable makeup transfer.**\n\n![control](assets/figs/control.png 'controllable makeup transfer')\n\n**Local makeup editing.**\n\n![edit](assets/figs/edit.png 'local makeup editing')\n\n## Citation\n\nIf this work is helpful for your research, please consider citing the following BibTeX entry.\n\n```text\n@article{yang2022elegant,\n  title={EleGANt: Exquisite and Locally Editable GAN for Makeup Transfer},\n  author={Yang, Chenyu and He, Wanrong and Xu, Yingqing and Gao, Yang}\n  journal={arXiv preprint arXiv:2207.09840},\n  year={2022}\n}\n```\n\n## Acknowledgement\n\nSome of the codes are build upon [PSGAN](https://github.com/wtjiang98/PSGAN) and [aster.Pytorch](https://github.com/ayumiymk/aster.pytorch).\n\n## License\n\nThis work is licensed under a\n[Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License][cc-by-nc-sa].\n\n[![CC BY-NC-SA 4.0][cc-by-nc-sa-image]][cc-by-nc-sa]\n\n[cc-by-nc-sa]: http://creativecommons.org/licenses/by-nc-sa/4.0/\n[cc-by-nc-sa-image]: https://licensebuttons.net/l/by-nc-sa/4.0/88x31.png\n[cc-by-nc-sa-shield]: https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg\n"
  },
  {
    "path": "assets/docs/install.md",
    "content": "# Installation Instructions\r\n\r\nThis code was tested on Ubuntu 20.04 with CUDA 11.1.\r\n\r\n**a. Create a conda virtual environment and activate it.**\r\n\r\n```bash\r\nconda create -n elegant python=3.8\r\nconda activate elegant\r\n```\r\n\r\n**b. Install PyTorch and torchvision following the [official instructions](https://pytorch.org/).**\r\n\r\n```bash\r\npip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html\r\n```\r\n\r\n**c. Install other required libaries.**\r\n\r\n```bash\r\npip install opencv-python matplotlib dlib fvcore\r\n```\r\n"
  },
  {
    "path": "assets/docs/prepare.md",
    "content": "# Preparation Instructions\r\n\r\nClone this repository and prepare the dataset and weights through the following steps:\r\n\r\n**a. Prepare model weights for face detection.**\r\n\r\nDownload the weights of [dlib](https://github.com/davisking/dlib) face detector of 68 landmarks [here](http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2). Unzip it and move it to the directory `./faceutils/dlibutils`.\r\n\r\nDownload the weights of BiSeNet ([PyTorch implementation](https://github.com/zllrunning/face-parsing.PyTorch)) for face parsing [here](https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812). Rename it as `resnet.pth` and move it to the directory `./faceutils/mask`.\r\n\r\n**b. Prepare Makeup Transfer (MT) dataset.**\r\n\r\nDownload raw data of the MT Dataset [here](https://github.com/wtjiang98/PSGAN) and unzip it into sub directory `./data`.\r\n\r\nRun the following command to preprocess data:\r\n\r\n```bash\r\npython training/preprocess.py\r\n```\r\n\r\nYour data directory should look like:\r\n\r\n```text\r\ndata\r\n└── MT-Dataset\r\n    ├── images\r\n    │   ├── makeup\r\n    │   └── non-makeup\r\n    ├── segs\r\n    │   ├── makeup\r\n    │   └── non-makeup\r\n    ├── lms\r\n    │   ├── makeup\r\n    │   └── non-makeup\r\n    ├── makeup.txt\r\n    ├── non-makeup.txt\r\n    └── ...\r\n```\r\n\r\n**c. Download weights of trained EleGANt.**\r\n\r\nThe weights of our trained model can be download [here](https://drive.google.com/drive/folders/1xzIS3Dfmsssxkk9OhhAS4svrZSPfQYRe?usp=sharing). Put it under the directory `./ckpts`.\r\n"
  },
  {
    "path": "concern/__init__.py",
    "content": "from .image import load_image\n"
  },
  {
    "path": "concern/image.py",
    "content": "import numpy as np\nimport cv2\nfrom io import BytesIO\n\n\ndef load_image(path):\n    with path.open(\"rb\") as reader:\n        data = np.fromstring(reader.read(), dtype=np.uint8)\n        img = cv2.imdecode(data, cv2.IMREAD_COLOR)\n        if img is None:\n            return\n        img = img[..., ::-1]\n    return img\n\ndef resize_by_max(image, max_side=512, force=False):\n    h, w = image.shape[:2]\n    if max(h, w) < max_side and not force:\n        return image\n    ratio = max(h, w) / max_side\n\n    w = int(w / ratio + 0.5)\n    h = int(h / ratio + 0.5)\n    return cv2.resize(image, (w, h))\n\ndef image2buffer(image):\n    is_success, buffer = cv2.imencode(\".jpg\", image)\n    if not is_success:\n        return None\n    return BytesIO(buffer)\n"
  },
  {
    "path": "concern/track.py",
    "content": "import time\n\nimport torch\n\n\nclass Track:\n    def __init__(self):\n        self.log_point = time.time()\n        self.enable_track = False\n\n    def track(self, mark):\n        if not self.enable_track:\n            return\n\n        if torch.cuda.is_available():\n            torch.cuda.synchronize()\n            print(\"{} memory:\".format(mark), torch.cuda.memory_allocated() / 1024 / 1024, \"M\")\n        print(\"{} time cost:\".format(mark), time.time() - self.log_point)\n        self.log_point = time.time()\n"
  },
  {
    "path": "concern/visualize.py",
    "content": "import numpy as np\nimport cv2\n\n\ndef channel_first(image, format):\n    return image.transpose(\n        format.index(\"C\"), format.index(\"H\"), format.index(\"W\"))\n\ndef mask2image(mask:np.array, format=\"HWC\"):\n    H, W = mask.shape\n\n    canvas = np.zeros((H, W, 3), dtype=np.uint8)\n    for i in range(int(mask.max())):\n        color = np.random.rand(1, 1, 3) * 255\n        canvas += (mask == i)[:, :, None] * color.astype(np.uint8)\n    return canvas\n\ndef draw_points(image, points, color=(255, 0, 0)):\n    for point in points:\n        print(int(point[1]), int(point[0]))\n        image = cv2.circle(image, (int(point[1]), int(point[0])), 3, color)\n\n    if hasattr(image, \"get\"):\n        return image.get()\n    return image\n"
  },
  {
    "path": "faceutils/__init__.py",
    "content": "#!/usr/bin/python\n# -*- encoding: utf-8 -*-\n#from . import faceplusplus as fpp\nfrom . import dlibutils as dlib\nfrom . import mask\n"
  },
  {
    "path": "faceutils/dlibutils/__init__.py",
    "content": "#!/usr/bin/python\n# -*- encoding: utf-8 -*-\nfrom .main import detect, crop, landmarks, crop_from_array\n"
  },
  {
    "path": "faceutils/dlibutils/main.py",
    "content": "#!/usr/bin/python\n# -*- encoding: utf-8 -*-\nimport os.path as osp\n\nimport numpy as np\nfrom PIL import Image\nimport dlib\nimport cv2\nfrom concern.image import resize_by_max\n\ndetector = dlib.get_frontal_face_detector()\npredictor = dlib.shape_predictor(osp.split(osp.realpath(__file__))[0] + '/shape_predictor_68_face_landmarks.dat')\n\n\ndef detect(image: Image) -> 'faces':\n    image = np.asarray(image)\n    h, w = image.shape[:2]\n    image = resize_by_max(image, 361)\n    actual_h, actual_w = image.shape[:2]\n    faces_on_small = detector(image, 1)\n    faces = dlib.rectangles()\n    for face in faces_on_small:\n        faces.append(\n            dlib.rectangle(\n                int(face.left() / actual_w * w + 0.5),\n                int(face.top() / actual_h * h + 0.5),\n                int(face.right() / actual_w * w + 0.5),\n                int(face.bottom() / actual_h * h  + 0.5)\n            )\n        )\n    return faces\n\ndef crop(image: Image, face, up_ratio, down_ratio, width_ratio) -> (Image, 'face'):\n    width, height = image.size\n    face_height = face.height()\n    face_width = face.width()\n    delta_up = up_ratio * face_height\n    delta_down = down_ratio * face_height\n    delta_width = width_ratio * width\n\n    img_left = int(max(0, face.left() - delta_width))\n    img_top = int(max(0, face.top() - delta_up))\n    img_right = int(min(width, face.right() + delta_width))\n    img_bottom = int(min(height, face.bottom() + delta_down))\n    image = image.crop((img_left, img_top, img_right, img_bottom))\n    face = dlib.rectangle(face.left() - img_left, face.top() - img_top,\n                        face.right() - img_left, face.bottom() - img_top)\n    face_expand = dlib.rectangle(img_left, img_top, img_right, img_bottom)\n    center = face_expand.center()\n    width, height = image.size\n    # import ipdb; ipdb.set_trace()\n    crop_left = img_left\n    crop_top = img_top\n    crop_right = img_right\n    crop_bottom = img_bottom\n    if width > height:\n        left = int(center.x - height / 2)\n        right = int(center.x + height / 2)\n        if left < 0:\n            left, right = 0, height\n        elif right > width:\n            left, right = width - height, width\n        image = image.crop((left, 0, right, height))\n        face = dlib.rectangle(face.left() - left, face.top(),\n                              face.right() - left, face.bottom())\n        crop_left += left\n        crop_right = crop_left + height\n    elif width < height:\n        top = int(center.y - width / 2)\n        bottom = int(center.y + width / 2)\n        if top < 0:\n            top, bottom = 0, width\n        elif bottom > height:\n            top, bottom = height - width, height\n        image = image.crop((0, top, width, bottom))\n        face = dlib.rectangle(face.left(), face.top() - top,\n                              face.right(), face.bottom() - top)\n        crop_top += top\n        crop_bottom = crop_top + width\n    crop_face = dlib.rectangle(crop_left, crop_top, crop_right, crop_bottom)\n    return image, face, crop_face\n\n\ndef crop_by_image_size(image: Image, face) -> (Image, 'face'):\n    center = face.center()\n    width, height = image.size\n    if width > height:\n        left = int(center.x - height / 2)\n        right = int(center.x + height / 2)\n        if left < 0:\n            left, right = 0, height\n        elif right > width:\n            left, right = width - height, width\n        image = image.crop((left, 0, right, height))\n        
face = dlib.rectangle(face.left() - left, face.top(),\n                              face.right() - left, face.bottom())\n    elif width < height:\n        top = int(center.y - width / 2)\n        bottom = int(center.y + width / 2)\n        if top < 0:\n            top, bottom = 0, width\n        elif bottom > height:\n            top, bottom = height - width, height\n        image = image.crop((0, top, width, bottom))\n        face = dlib.rectangle(face.left(), face.top() - top, \n                              face.right(), face.bottom() - top)\n    return image, face\n\n\ndef landmarks(image: Image, face):\n    shape = predictor(np.asarray(image), face).parts()\n    return np.array([[p.y, p.x] for p in shape])\n\ndef crop_from_array(image: np.array, face) -> (np.array, 'face'):\n    ratio = 0.20 / 0.85 # delta_size / face_size\n    height, width = image.shape[:2]\n    face_height = face.height()\n    face_width = face.width()\n    delta_height = ratio * face_height\n    delta_width = ratio * width\n\n    img_left = int(max(0, face.left() - delta_width))\n    img_top = int(max(0, face.top() - delta_height))\n    img_right = int(min(width, face.right() + delta_width))\n    img_bottom = int(min(height, face.bottom() + delta_height))\n    image = image[img_top:img_bottom, img_left:img_right]\n    face = dlib.rectangle(face.left() - img_left, face.top() - img_top,\n                        face.right() - img_left, face.bottom() - img_top)\n    center = face.center()\n    height, width = image.shape[:2]\n    if width > height:\n        left = int(center.x - height / 2)\n        right = int(center.x + height / 2)\n        if left < 0:\n            left, right = 0, height\n        elif right > width:\n            left, right = width - height, width\n        image = image[0:height, left:right]\n        face = dlib.rectangle(face.left() - left, face.top(),\n                              face.right() - left, face.bottom())\n    elif width < height:\n        top = int(center.y - width / 2)\n        bottom = int(center.y + width / 2)\n        if top < 0:\n            top, bottom = 0, width\n        elif bottom > height:\n            top, bottom = height - width, height\n        image = image[top:bottom, 0:width]\n        face = dlib.rectangle(face.left(), face.top() - top,\n                              face.right(), face.bottom() - top)\n    return image, face\n\n"
  },
  {
    "path": "faceutils/mask/__init__.py",
    "content": "#!/usr/bin/python\n# -*- encoding: utf-8 -*-\nfrom .main import FaceParser\n"
  },
  {
    "path": "faceutils/mask/main.py",
    "content": "#!/usr/bin/python\n# -*- encoding: utf-8 -*-\nimport os.path as osp\n\nimport numpy as np\nimport cv2\nfrom PIL import Image\nimport torch\nimport torchvision.transforms as transforms\n\nfrom .model import BiSeNet\n\n\nclass FaceParser:\n    def __init__(self, device=\"cpu\"):\n        mapper = [0, 1, 2, 3, 4, 5, 0, 11, 12, 0, 6, 8, 7, 9, 13, 0, 0, 10, 0]\n        self.device = device\n        self.dic = torch.tensor(mapper, device=device).unsqueeze(1)\n        save_pth = osp.split(osp.realpath(__file__))[0] + '/resnet.pth'\n\n        net = BiSeNet(n_classes=19)\n        net.load_state_dict(torch.load(save_pth, map_location=device))\n        self.net = net.to(device).eval()\n        self.to_tensor = transforms.Compose([\n            transforms.ToTensor(),\n            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n        ])\n\n\n    def parse(self, image: Image):\n        assert image.shape[:2] == (512, 512)\n        with torch.no_grad():\n            image = self.to_tensor(image).to(self.device)\n            image = torch.unsqueeze(image, 0)\n            out = self.net(image)[0]\n            parsing = out.squeeze(0).argmax(0)\n        parsing = torch.nn.functional.embedding(parsing, self.dic)\n        return parsing.float().squeeze(2)\n\n"
  },
  {
    "path": "faceutils/mask/model.py",
    "content": "#!/usr/bin/python\n# -*- encoding: utf-8 -*-\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision\n\nfrom .resnet import Resnet18\n\n\nclass ConvBNReLU(nn.Module):\n    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):\n        super(ConvBNReLU, self).__init__()\n        self.conv = nn.Conv2d(in_chan,\n                out_chan,\n                kernel_size = ks,\n                stride = stride,\n                padding = padding,\n                bias = False)\n        self.bn = nn.BatchNorm2d(out_chan)\n        self.init_weight()\n\n    def forward(self, x):\n        x = self.conv(x)\n        x = F.relu(self.bn(x))\n        return x\n\n    def init_weight(self):\n        for ly in self.children():\n            if isinstance(ly, nn.Conv2d):\n                nn.init.kaiming_normal_(ly.weight, a=1)\n                if not ly.bias is None: nn.init.constant_(ly.bias, 0)\n\nclass BiSeNetOutput(nn.Module):\n    def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):\n        super(BiSeNetOutput, self).__init__()\n        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)\n        self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)\n        self.init_weight()\n\n    def forward(self, x):\n        x = self.conv(x)\n        x = self.conv_out(x)\n        return x\n\n    def init_weight(self):\n        for ly in self.children():\n            if isinstance(ly, nn.Conv2d):\n                nn.init.kaiming_normal_(ly.weight, a=1)\n                if not ly.bias is None: nn.init.constant_(ly.bias, 0)\n\n    def get_params(self):\n        wd_params, nowd_params = [], []\n        for name, module in self.named_modules():\n            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):\n                wd_params.append(module.weight)\n                if not module.bias is None:\n                    nowd_params.append(module.bias)\n            elif isinstance(module, nn.BatchNorm2d):\n                nowd_params += list(module.parameters())\n        return wd_params, nowd_params\n\n\nclass AttentionRefinementModule(nn.Module):\n    def __init__(self, in_chan, out_chan, *args, **kwargs):\n        super(AttentionRefinementModule, self).__init__()\n        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)\n        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)\n        self.bn_atten = nn.BatchNorm2d(out_chan)\n        self.sigmoid_atten = nn.Sigmoid()\n        self.init_weight()\n\n    def forward(self, x):\n        feat = self.conv(x)\n        atten = F.avg_pool2d(feat, feat.size()[2:])\n        atten = self.conv_atten(atten)\n        atten = self.bn_atten(atten)\n        atten = self.sigmoid_atten(atten)\n        out = torch.mul(feat, atten)\n        return out\n\n    def init_weight(self):\n        for ly in self.children():\n            if isinstance(ly, nn.Conv2d):\n                nn.init.kaiming_normal_(ly.weight, a=1)\n                if not ly.bias is None: nn.init.constant_(ly.bias, 0)\n\n\nclass ContextPath(nn.Module):\n    def __init__(self, *args, **kwargs):\n        super(ContextPath, self).__init__()\n        self.resnet = Resnet18()\n        self.arm16 = AttentionRefinementModule(256, 128)\n        self.arm32 = AttentionRefinementModule(512, 128)\n        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)\n        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)\n  
      self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)\n\n        self.init_weight()\n\n    def forward(self, x):\n        H0, W0 = x.size()[2:]\n        feat8, feat16, feat32 = self.resnet(x)\n        H8, W8 = feat8.size()[2:]\n        H16, W16 = feat16.size()[2:]\n        H32, W32 = feat32.size()[2:]\n\n        avg = F.avg_pool2d(feat32, feat32.size()[2:])\n        avg = self.conv_avg(avg)\n        avg_up = F.interpolate(avg, (H32, W32), mode='nearest')\n\n        feat32_arm = self.arm32(feat32)\n        feat32_sum = feat32_arm + avg_up\n        feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')\n        feat32_up = self.conv_head32(feat32_up)\n\n        feat16_arm = self.arm16(feat16)\n        feat16_sum = feat16_arm + feat32_up\n        feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')\n        feat16_up = self.conv_head16(feat16_up)\n\n        return feat8, feat16_up, feat32_up  # x8, x8, x16\n\n    def init_weight(self):\n        for ly in self.children():\n            if isinstance(ly, nn.Conv2d):\n                nn.init.kaiming_normal_(ly.weight, a=1)\n                if not ly.bias is None: nn.init.constant_(ly.bias, 0)\n\n    def get_params(self):\n        wd_params, nowd_params = [], []\n        for name, module in self.named_modules():\n            if isinstance(module, (nn.Linear, nn.Conv2d)):\n                wd_params.append(module.weight)\n                if not module.bias is None:\n                    nowd_params.append(module.bias)\n            elif isinstance(module, nn.BatchNorm2d):\n                nowd_params += list(module.parameters())\n        return wd_params, nowd_params\n\n\n### This is not used, since I replace this with the resnet feature with the same size\nclass SpatialPath(nn.Module):\n    def __init__(self, *args, **kwargs):\n        super(SpatialPath, self).__init__()\n        self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)\n        self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)\n        self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)\n        self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)\n        self.init_weight()\n\n    def forward(self, x):\n        feat = self.conv1(x)\n        feat = self.conv2(feat)\n        feat = self.conv3(feat)\n        feat = self.conv_out(feat)\n        return feat\n\n    def init_weight(self):\n        for ly in self.children():\n            if isinstance(ly, nn.Conv2d):\n                nn.init.kaiming_normal_(ly.weight, a=1)\n                if not ly.bias is None: nn.init.constant_(ly.bias, 0)\n\n    def get_params(self):\n        wd_params, nowd_params = [], []\n        for name, module in self.named_modules():\n            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):\n                wd_params.append(module.weight)\n                if not module.bias is None:\n                    nowd_params.append(module.bias)\n            elif isinstance(module, nn.BatchNorm2d):\n                nowd_params += list(module.parameters())\n        return wd_params, nowd_params\n\n\nclass FeatureFusionModule(nn.Module):\n    def __init__(self, in_chan, out_chan, *args, **kwargs):\n        super(FeatureFusionModule, self).__init__()\n        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)\n        self.conv1 = nn.Conv2d(out_chan,\n                out_chan//4,\n                kernel_size = 1,\n                stride = 1,\n                padding = 0,\n                bias = False)\n        self.conv2 
= nn.Conv2d(out_chan//4,\n                out_chan,\n                kernel_size = 1,\n                stride = 1,\n                padding = 0,\n                bias = False)\n        self.relu = nn.ReLU(inplace=True)\n        self.sigmoid = nn.Sigmoid()\n        self.init_weight()\n\n    def forward(self, fsp, fcp):\n        fcat = torch.cat([fsp, fcp], dim=1)\n        feat = self.convblk(fcat)\n        atten = F.avg_pool2d(feat, feat.size()[2:])\n        atten = self.conv1(atten)\n        atten = self.relu(atten)\n        atten = self.conv2(atten)\n        atten = self.sigmoid(atten)\n        feat_atten = torch.mul(feat, atten)\n        feat_out = feat_atten + feat\n        return feat_out\n\n    def init_weight(self):\n        for ly in self.children():\n            if isinstance(ly, nn.Conv2d):\n                nn.init.kaiming_normal_(ly.weight, a=1)\n                if not ly.bias is None: nn.init.constant_(ly.bias, 0)\n\n    def get_params(self):\n        wd_params, nowd_params = [], []\n        for name, module in self.named_modules():\n            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):\n                wd_params.append(module.weight)\n                if not module.bias is None:\n                    nowd_params.append(module.bias)\n            elif isinstance(module, nn.BatchNorm2d):\n                nowd_params += list(module.parameters())\n        return wd_params, nowd_params\n\n\nclass BiSeNet(nn.Module):\n    def __init__(self, n_classes, *args, **kwargs):\n        super(BiSeNet, self).__init__()\n        self.cp = ContextPath()\n        ## here self.sp is deleted\n        self.ffm = FeatureFusionModule(256, 256)\n        self.conv_out = BiSeNetOutput(256, 256, n_classes)\n        self.conv_out16 = BiSeNetOutput(128, 64, n_classes)\n        self.conv_out32 = BiSeNetOutput(128, 64, n_classes)\n        # self.init_weight()\n\n    def forward(self, x):\n        H, W = x.size()[2:]\n        feat_res8, feat_cp8, feat_cp16 = self.cp(x)  # here return res3b1 feature\n        feat_sp = feat_res8  # use res3b1 feature to replace spatial path feature\n        feat_fuse = self.ffm(feat_sp, feat_cp8)\n\n        feat_out = self.conv_out(feat_fuse)\n        feat_out16 = self.conv_out16(feat_cp8)\n        feat_out32 = self.conv_out32(feat_cp16)\n\n        feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)\n        feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)\n        feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)\n        return feat_out, feat_out16, feat_out32\n\n    def init_weight(self):\n        for ly in self.children():\n            if isinstance(ly, nn.Conv2d):\n                nn.init.kaiming_normal_(ly.weight, a=1)\n                if not ly.bias is None: nn.init.constant_(ly.bias, 0)\n\n    def get_params(self):\n        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []\n        for name, child in self.named_children():\n            child_wd_params, child_nowd_params = child.get_params()\n            if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):\n                lr_mul_wd_params += child_wd_params\n                lr_mul_nowd_params += child_nowd_params\n            else:\n                wd_params += child_wd_params\n                nowd_params += child_nowd_params\n        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params\n\n\nif __name__ == \"__main__\":\n    net = 
BiSeNet(19)\n    net.cuda()\n    net.eval()\n    in_ten = torch.randn(16, 3, 640, 480).cuda()\n    out, out16, out32 = net(in_ten)\n    print(out.shape)\n\n    net.get_params()\n"
  },
  {
    "path": "faceutils/mask/resnet.py",
    "content": "#!/usr/bin/python\n# -*- encoding: utf-8 -*-\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.model_zoo as modelzoo\n\nresnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'\n\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=1, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    def __init__(self, in_chan, out_chan, stride=1):\n        super(BasicBlock, self).__init__()\n        self.conv1 = conv3x3(in_chan, out_chan, stride)\n        self.bn1 = nn.BatchNorm2d(out_chan)\n        self.conv2 = conv3x3(out_chan, out_chan)\n        self.bn2 = nn.BatchNorm2d(out_chan)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = None\n        if in_chan != out_chan or stride != 1:\n            self.downsample = nn.Sequential(\n                nn.Conv2d(in_chan, out_chan,\n                          kernel_size=1, stride=stride, bias=False),\n                nn.BatchNorm2d(out_chan),\n                )\n\n    def forward(self, x):\n        residual = self.conv1(x)\n        residual = F.relu(self.bn1(residual))\n        residual = self.conv2(residual)\n        residual = self.bn2(residual)\n\n        shortcut = x\n        if self.downsample is not None:\n            shortcut = self.downsample(x)\n\n        out = shortcut + residual\n        out = self.relu(out)\n        return out\n\n\ndef create_layer_basic(in_chan, out_chan, bnum, stride=1):\n    layers = [BasicBlock(in_chan, out_chan, stride=stride)]\n    for i in range(bnum-1):\n        layers.append(BasicBlock(out_chan, out_chan, stride=1))\n    return nn.Sequential(*layers)\n\n\nclass Resnet18(nn.Module):\n    def __init__(self):\n        super(Resnet18, self).__init__()\n        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,\n                               bias=False)\n        self.bn1 = nn.BatchNorm2d(64)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)\n        self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)\n        self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)\n        self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)\n        # self.init_weight()\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = F.relu(self.bn1(x))\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        feat8 = self.layer2(x) # 1/8\n        feat16 = self.layer3(feat8) # 1/16\n        feat32 = self.layer4(feat16) # 1/32\n        return feat8, feat16, feat32\n\n    def init_weight(self):\n        state_dict = modelzoo.load_url(resnet18_url)\n        self_state_dict = self.state_dict()\n        for k, v in state_dict.items():\n            if 'fc' in k: continue\n            self_state_dict.update({k: v})\n        self.load_state_dict(self_state_dict)\n\n    def get_params(self):\n        wd_params, nowd_params = [], []\n        for name, module in self.named_modules():\n            if isinstance(module, (nn.Linear, nn.Conv2d)):\n                wd_params.append(module.weight)\n                if not module.bias is None:\n                    nowd_params.append(module.bias)\n            elif isinstance(module,  nn.BatchNorm2d):\n                nowd_params += list(module.parameters())\n        return wd_params, nowd_params\n\n\nif __name__ == \"__main__\":\n    net = 
Resnet18()\n    x = torch.randn(16, 3, 224, 224)\n    out = net(x)\n    print(out[0].size())\n    print(out[1].size())\n    print(out[2].size())\n    net.get_params()\n"
  },
  {
    "path": "models/__init__.py",
    "content": ""
  },
  {
    "path": "models/elegant.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .modules.module_base import ResidualBlock_IN, Downsample, Upsample, PositionalEmbedding, MergeBlock\nfrom .modules.module_attn import Attention_apply, FeedForwardLayer, MultiheadAttention \nfrom .modules.sow_attention import SowAttention\nfrom .modules.tps_transform import tps_spatial_transform\n\n\nclass Generator(nn.ModuleDict):\n    \"\"\"Generator. Encoder-Decoder Architecture.\"\"\"\n    def __init__(self, conv_dim=64, image_size=256, num_layer_e=2, num_layer_d=1, window_size=16, use_ff=False,\n                 merge_mode='conv', num_head=1, double_encoder=False, **unused):\n        super(Generator, self).__init__()\n\n        # -------------------------- Encoder --------------------------\n\n        layers = nn.Conv2d(3, conv_dim, kernel_size=7, stride=1, padding=3, bias=False)\n        self.add_module('in_conv', layers)\n\n        # Down-Sampling & Bottleneck\n        curr_dim = conv_dim; feature_size = image_size\n        for i in range(2):\n            layers = Downsample(curr_dim, curr_dim * 2, affine=True)\n            self.add_module('down_{:d}'.format(i+1), layers)\n            curr_dim = curr_dim * 2; feature_size = feature_size // 2\n\n            self.add_module('e_bottleneck_{:d}'.format(i+1), \n                nn.Sequential(*[ResidualBlock_IN(curr_dim, curr_dim, affine=True) for j in range(num_layer_e)])\n            )\n\n        ### second encoder\n        self.double_encoder = double_encoder\n        if self.double_encoder:\n            layers = nn.Conv2d(3, conv_dim, kernel_size=7, stride=1, padding=3, bias=False)\n            self.add_module('in_conv_s', layers)\n\n            # Down-Sampling & Bottleneck\n            curr_dim = conv_dim; feature_size = image_size\n            for i in range(2):\n                layers = Downsample(curr_dim, curr_dim * 2, affine=True)\n                self.add_module('down_{:d}_s'.format(i+1), layers)\n                curr_dim = curr_dim * 2; feature_size = feature_size // 2\n\n                self.add_module('e_bottleneck_{:d}_s'.format(i+1), \n                    nn.Sequential(*[ResidualBlock_IN(curr_dim, curr_dim, affine=True) for j in range(num_layer_e)])\n                )\n\n        # --------------------------- Transfer ----------------------------\n        curr_dim = conv_dim; feature_size = image_size\n        self.use_ff = use_ff\n        for i in range(2):\n            curr_dim = curr_dim * 2; feature_size = feature_size // 2\n            self.add_module('embedding_{:d}'.format(i+1), PositionalEmbedding(\n                embedding_dim=136,\n                feature_size=feature_size,\n                max_size=image_size,\n                embedding_type='l2_norm'\n            ))\n            if i < 1:\n                self.add_module('attention_extract_{:d}'.format(i+1), SowAttention(\n                    window_size=window_size,\n                    in_channels=curr_dim + 136,\n                    proj_channels=curr_dim + 136,\n                    value_channels=curr_dim,\n                    out_channels=curr_dim,\n                    num_heads=num_head\n                ))\n            else:\n                self.add_module('attention_extract_{:d}'.format(i+1), MultiheadAttention(\n                    in_channels=curr_dim + 136,\n                    proj_channels=curr_dim + 136,\n                    value_channels=curr_dim,\n                    out_channels=curr_dim,\n                    num_heads=num_head\n                ))\n      
          \n            if use_ff:\n                self.add_module('feedforward_{:d}'.format(i+1), FeedForwardLayer(curr_dim, curr_dim))\n            self.add_module('attention_apply_{:d}'.format(i+1), Attention_apply(curr_dim))           \n\n        # --------------------------- Decoder ----------------------------\n\n        # Bottleneck & Up-Sampling & Merge\n        for i in range(2):\n            self.add_module('d_bottleneck_{:d}'.format(i+1), \n                nn.Sequential(*[ResidualBlock_IN(curr_dim, curr_dim, affine=True) for j in range(num_layer_d)])\n            )            \n            layers = Upsample(curr_dim, curr_dim // 2, affine=True)\n            self.add_module('up_{:d}'.format(i+1), layers)\n            curr_dim = curr_dim // 2\n            if i < 1:\n                self.add_module('merge_{:d}'.format(i+1), MergeBlock(merge_mode, curr_dim))\n\n        layers = nn.Sequential(\n            nn.InstanceNorm2d(curr_dim, affine=True),\n            nn.LeakyReLU(0.2, inplace=True),\n            nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False),\n        )\n        self.add_module('out_conv', layers)\n\n\n    def get_transfer_input(self, image, mask, diff, lms, is_reference=False):\n        feature_size = image.shape[2]; scale_factor = 1.0\n        fea_list, mask_list, diff_list, lms_list = [], [], [], []\n\n        # input conv\n        if self.double_encoder and is_reference:\n            fea = self['in_conv_s'](image)\n        else:\n            fea = self['in_conv'](image)\n\n        # down-sampling & bottleneck\n        for i in range(2):\n            if self.double_encoder and is_reference:\n                fea = self['down_{:d}_s'.format(i+1)](fea)\n                fea_ = self['e_bottleneck_{:d}_s'.format(i+1)](fea)\n            else:\n                fea = self['down_{:d}'.format(i+1)](fea)\n                fea_ = self['e_bottleneck_{:d}'.format(i+1)](fea)\n            fea_list.append(fea_)\n            \n            feature_size = feature_size // 2; scale_factor = scale_factor * 0.5\n            mask_ = F.interpolate(mask, feature_size, mode='nearest')\n            mask_list.append(mask_)\n\n            diff_ = self['embedding_{:d}'.format(i+1)](diff, mask)\n            diff_list.append(diff_)\n            \n            lms_ = lms * scale_factor\n            lms_list.append(lms_)\n            \n        return [fea_list, mask_list, diff_list, lms_list]\n\n\n    def get_transfer_output(self, fea_c_list, mask_c_list, diff_c_list, lms_c_list,\n                            fea_s_list, mask_s_list, diff_s_list, lms_s_list):\n        attn_out_list = []\n        for i in range(2):\n            feature_size = fea_c_list[i].shape[2]\n\n            # align\n            if i == 0:\n                fea_s_ = self.tps_align(feature_size, lms_s_list[i], lms_c_list[i], fea_s_list[i])\n                mask_s_ = self.tps_align(feature_size, lms_s_list[i], lms_c_list[i], mask_s_list[i], 'nearest')\n                diff_s_ = self.tps_align(feature_size, lms_s_list[i], lms_c_list[i], diff_s_list[i], 'nearest')\n            else:\n                fea_s_ = fea_s_list[i]\n                mask_s_ = mask_s_list[i]\n                diff_s_ = diff_s_list[i]\n\n            # transfer\n            input_q = torch.cat((fea_c_list[i], diff_c_list[i]), dim=1)\n            input_k = torch.cat((fea_s_, diff_s_), dim=1)\n            attn_out = self['attention_extract_{:d}'.format(i+1)](input_q, input_k, fea_s_, mask_c_list[i], mask_s_)\n            if self.use_ff:\n                
attn_out = self['feedforward_{:d}'.format(i+1)](attn_out)\n            attn_out_list.append(attn_out)\n        \n        return attn_out_list\n\n    \n    def decode(self, fea_c_list, attn_out_list):\n        # apply\n        for i in range(2): \n            fea_c_ = self['attention_apply_{:d}'.format(i+1)](fea_c_list[i], attn_out_list[i])\n            fea_c_ = self['d_bottleneck_{:d}'.format(2-i)](fea_c_)\n            fea_c_list[i] = fea_c_\n\n        # up-sampling & merge\n        fea_c = fea_c_list[1]\n        for i in range(2):\n            fea_c = self['up_{:d}'.format(i+1)](fea_c)\n            if i < 1:  \n                fea_c = self['merge_{:d}'.format(i+1)](fea_c_list[0], fea_c)\n\n        fea_c = self['out_conv'](fea_c)\n        return fea_c\n\n    \n    def forward(self, c, s, mask_c, mask_s, diff_c, diff_s, lms_c, lms_s):\n        \"\"\"\n        c: content, stands for source image. shape: (b, c, h, w)\n        s: style, stands for reference image. shape: (b, c, h, w)\n        mask_c: (b, c', h, w)\n        diff: (b, d, h, w)\n        lms: (b, K, 2)\n        \"\"\"\n        transfer_input_c = self.get_transfer_input(c, mask_c, diff_c, lms_c)\n        transfer_input_s = self.get_transfer_input(s, mask_s, diff_s, lms_s, True)\n        attn_out_list = self.get_transfer_output(*transfer_input_c, *transfer_input_s)\n        fea_c = self.decode(transfer_input_c[0], attn_out_list)\n        return fea_c\n\n\n    def tps_align(self, feature_size, lms_s, lms_c, fea_s, sample_mode='bilinear'):\n        '''\n        fea: (B, C, H, W), lms: (B, K, 2)\n        '''\n        fea_out = []\n        for l_s, l_c, f_s in zip(lms_s, lms_c, fea_s):\n            l_c = torch.flip(l_c, dims=[1]) / (feature_size - 1)\n            l_s = (torch.flip(l_s, dims=[1]) / (feature_size - 1)).unsqueeze(0)\n            f_s = f_s.unsqueeze(0) # (1, C, H, W)\n            fea_trans, _ = tps_spatial_transform(feature_size, feature_size, l_c, f_s, l_s, sample_mode)\n            fea_out.append(fea_trans)\n        return torch.cat(fea_out, dim=0)\n"
  },
  {
    "path": "models/loss.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .modules.histogram_matching import histogram_matching\nfrom .modules.pseudo_gt import fine_align, expand_area, mask_blur\n\n\nclass GANLoss(nn.Module):\n    \"\"\"Define different GAN objectives.\n    The GANLoss class abstracts away the need to create the target label tensor\n    that has the same size as the input.\n    \"\"\"\n\n    def __init__(self, gan_mode='lsgan', target_real_label=1.0, target_fake_label=0.0):\n        \"\"\" Initialize the GANLoss class.\n        Parameters:\n            gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.\n            target_real_label (bool) - - label for a real image\n            target_fake_label (bool) - - label of a fake image\n        Note: Do not use sigmoid as the last layer of Discriminator.\n        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.\n        \"\"\"\n        super(GANLoss, self).__init__()\n        self.register_buffer('real_label', torch.tensor(target_real_label))\n        self.register_buffer('fake_label', torch.tensor(target_fake_label))\n        self.gan_mode = gan_mode\n        if gan_mode == 'lsgan':\n            self.loss = nn.MSELoss()\n        elif gan_mode == 'vanilla':\n            self.loss = nn.BCEWithLogitsLoss()\n        else:\n            raise NotImplementedError('gan mode %s not implemented' % gan_mode)\n\n    def forward(self, prediction, target_is_real):\n        \"\"\"Calculate loss given Discriminator's output and grount truth labels.\n        Parameters:\n            prediction (tensor) - - tpyically the prediction output from a discriminator\n            target_is_real (bool) - - if the ground truth label is for real images or fake images\n        Returns:\n            the calculated loss.\n        \"\"\"\n        if target_is_real:\n            target_tensor = self.real_label\n        else:\n            target_tensor = self.fake_label\n        target_tensor = target_tensor.expand_as(prediction).to(prediction.device)\n        \n        loss = self.loss(prediction, target_tensor)\n        return loss\n\n\ndef norm(x: torch.Tensor):\n    return x * 2 - 1\n\ndef de_norm(x: torch.Tensor):\n    out = (x + 1) / 2\n    return out.clamp(0, 1)\n\ndef masked_his_match(image_s, image_r, mask_s, mask_r):\n    '''\n    image: (3, h, w)\n    mask: (1, h, w)\n    '''\n    index_tmp = torch.nonzero(mask_s)\n    x_A_index = index_tmp[:, 1]\n    y_A_index = index_tmp[:, 2]\n    index_tmp = torch.nonzero(mask_r)\n    x_B_index = index_tmp[:, 1]\n    y_B_index = index_tmp[:, 2]\n\n    image_s = (de_norm(image_s) * 255) #[-1, 1] -> [0, 255]\n    image_r = (de_norm(image_r) * 255)\n    \n    source_masked = image_s * mask_s\n    target_masked = image_r * mask_r\n    \n    source_match = histogram_matching(\n                source_masked, target_masked,\n                [x_A_index, y_A_index, x_B_index, y_B_index])\n    source_match = source_match.to(image_s.device)\n    \n    return norm(source_match / 255) #[0, 255] -> [-1, 1]\n\n\ndef generate_pgt(image_s, image_r, mask_s, mask_r, lms_s, lms_r, margins, blend_alphas, img_size=None):\n        \"\"\"\n        input_data: (3, h, w)\n        mask: (c, h, w), lip, skin, left eye, right eye\n        \"\"\"\n        if img_size is None:\n            img_size = image_s.shape[1]\n        pgt = image_s.detach().clone()\n\n        # skin match\n        skin_match = masked_his_match(image_s, image_r, mask_s[1:2], 
mask_r[1:2])\n        pgt = (1 - mask_s[1:2]) * pgt + mask_s[1:2] * skin_match\n\n        # lip match\n        lip_match = masked_his_match(image_s, image_r, mask_s[0:1], mask_r[0:1])\n        pgt = (1 - mask_s[0:1]) * pgt + mask_s[0:1] * lip_match\n\n        # eye match\n        mask_s_eye = expand_area(mask_s[2:4].sum(dim=0, keepdim=True), margins['eye']) * mask_s[1:2]\n        mask_r_eye = expand_area(mask_r[2:4].sum(dim=0, keepdim=True), margins['eye']) * mask_r[1:2]\n        eye_match = masked_his_match(image_s, image_r, mask_s_eye, mask_r_eye)\n        mask_s_eye_blur = mask_blur(mask_s_eye, blur_size=5, mode='valid')\n        pgt = (1 - mask_s_eye_blur) * pgt + mask_s_eye_blur * eye_match\n\n        # tps align\n        pgt = fine_align(img_size, lms_r, lms_s, image_r, pgt, mask_r, mask_s, margins, blend_alphas)\n        return pgt\n\n\nclass LinearAnnealingFn():\n    \"\"\"\n    define the linear annealing function with milestones\n    \"\"\"\n    def __init__(self, milestones, f_values):\n        assert len(milestones) == len(f_values)\n        self.milestones = milestones\n        self.f_values = f_values\n        \n    def __call__(self, t:int):\n        if t < self.milestones[0]:\n            return self.f_values[0]\n        elif t >= self.milestones[-1]:\n            return self.f_values[-1]\n        else:\n            for r in range(len(self.milestones) - 1):\n                if self.milestones[r] <= t < self.milestones[r+1]:\n                    return ((t - self.milestones[r]) * self.f_values[r+1] \\\n                            + (self.milestones[r+1] - t) * self.f_values[r]) \\\n                            / (self.milestones[r+1] - self.milestones[r])\n\n\nclass ComposePGT(nn.Module):\n    def __init__(self, margins, skin_alpha, eye_alpha, lip_alpha):\n        super(ComposePGT, self).__init__()\n        self.margins = margins\n        self.blend_alphas = {\n            'skin':skin_alpha,\n            'eye':eye_alpha,\n            'lip':lip_alpha\n        }\n\n    @torch.no_grad()\n    def forward(self, sources, targets, mask_srcs, mask_tars, lms_srcs, lms_tars):\n        pgts = []\n        for source, target, mask_src, mask_tar, lms_src, lms_tar in\\\n            zip(sources, targets, mask_srcs, mask_tars, lms_srcs, lms_tars):\n            pgt = generate_pgt(source, target, mask_src, mask_tar, lms_src, lms_tar, \n                               self.margins, self.blend_alphas)\n            pgts.append(pgt)\n        pgts = torch.stack(pgts, dim=0)\n        return pgts   \n\nclass AnnealingComposePGT(nn.Module):\n    def __init__(self, margins,\n            skin_alpha_milestones, skin_alpha_values,\n            eye_alpha_milestones, eye_alpha_values,\n            lip_alpha_milestones, lip_alpha_values\n        ):\n        super(AnnealingComposePGT, self).__init__()\n        self.margins = margins\n        self.skin_alpha_fn = LinearAnnealingFn(skin_alpha_milestones, skin_alpha_values)\n        self.eye_alpha_fn = LinearAnnealingFn(eye_alpha_milestones, eye_alpha_values)\n        self.lip_alpha_fn = LinearAnnealingFn(lip_alpha_milestones, lip_alpha_values)\n        \n        self.t = 0\n        self.blend_alphas = {}\n        self.step()\n\n    def step(self):\n        self.t += 1\n        self.blend_alphas['skin'] = self.skin_alpha_fn(self.t)\n        self.blend_alphas['eye'] = self.eye_alpha_fn(self.t)\n        self.blend_alphas['lip'] = self.lip_alpha_fn(self.t)\n\n    @torch.no_grad()\n    def forward(self, sources, targets, mask_srcs, mask_tars, lms_srcs, lms_tars):\n    
    pgts = []\n        for source, target, mask_src, mask_tar, lms_src, lms_tar in\\\n            zip(sources, targets, mask_srcs, mask_tars, lms_srcs, lms_tars):\n            pgt = generate_pgt(source, target, mask_src, mask_tar, lms_src, lms_tar,\n                               self.margins, self.blend_alphas)\n            pgts.append(pgt)\n        pgts = torch.stack(pgts, dim=0)\n        return pgts   \n\n\nclass MakeupLoss(nn.Module):\n    \"\"\"\n    Define the makeup loss w.r.t pseudo ground truth\n    \"\"\"\n    def __init__(self):\n        super(MakeupLoss, self).__init__()\n\n    def forward(self, x, target, mask=None):\n        if mask is None:\n            return F.l1_loss(x, target)\n        else:\n            return F.l1_loss(x * mask, target * mask)"
  },
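  {
    "path": "examples/sketch_loss.py",
    "content": "'''\nIllustrative sketch, not part of the original repo: exercises LinearAnnealingFn\nand MakeupLoss from models/loss.py on dummy tensors. The file name and the\nsample milestones/values below are assumptions for demonstration only.\n'''\nimport sys\nsys.path.append('.')\n\nimport torch\nfrom models.loss import LinearAnnealingFn, MakeupLoss\n\n# Anneal a blend alpha from 1.0 down to 0.25 between iterations 100 and 200.\nalpha_fn = LinearAnnealingFn(milestones=[100, 200], f_values=[1.0, 0.25])\nfor t in (0, 100, 150, 200, 500):\n    print(t, alpha_fn(t))  # 1.0, 1.0, 0.625, 0.25, 0.25\n\n# Masked L1 makeup loss between a generated image and a pseudo ground truth.\ncriterion = MakeupLoss()\nfake = torch.rand(2, 3, 64, 64)\npgt = torch.rand(2, 3, 64, 64)\nmask = (torch.rand(2, 1, 64, 64) > 0.5).float()\nprint(criterion(fake, pgt, mask).item())"
  },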
  {
    "path": "models/model.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torchvision.models import VGG as TVGG\nfrom torchvision.models.vgg import load_state_dict_from_url, model_urls, cfgs\n\nfrom .modules.spectral_norm import spectral_norm as SpectralNorm\nfrom .elegant import Generator\n\n\ndef get_generator(config):\n    kwargs = {\n        'conv_dim':config.MODEL.G_CONV_DIM,\n        'image_size':config.DATA.IMG_SIZE,\n        'num_head':config.MODEL.NUM_HEAD,\n        'double_encoder':config.MODEL.DOUBLE_E,\n        'use_ff':config.MODEL.USE_FF,\n        'num_layer_e':config.MODEL.NUM_LAYER_E,\n        'num_layer_d':config.MODEL.NUM_LAYER_D,\n        'window_size':config.MODEL.WINDOW_SIZE,\n        'merge_mode':config.MODEL.MERGE_MODE\n    }\n    G = Generator(**kwargs)\n    return G\n\n\ndef get_discriminator(config):\n    kwargs = {\n        'input_channel': 3,\n        'conv_dim':config.MODEL.D_CONV_DIM,\n        'num_layers':config.MODEL.D_REPEAT_NUM,\n        'norm':config.MODEL.D_TYPE\n    }\n    D = Discriminator(**kwargs)\n    return D\n\n\nclass Discriminator(nn.Module):\n    \"\"\"Discriminator. PatchGAN.\"\"\"\n    def __init__(self, input_channel=3, conv_dim=64, num_layers=3, norm='SN', **unused):\n        super(Discriminator, self).__init__()\n\n        layers = []\n        if norm=='SN':\n            layers.append(SpectralNorm(nn.Conv2d(input_channel, conv_dim, kernel_size=4, stride=2, padding=1)))\n        else:\n            layers.append(nn.Conv2d(input_channel, conv_dim, kernel_size=4, stride=2, padding=1))\n        layers.append(nn.LeakyReLU(0.01, inplace=True))\n\n        curr_dim = conv_dim\n        for i in range(1, num_layers):\n            if norm=='SN':\n                layers.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1)))\n            else:\n                layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))\n            layers.append(nn.LeakyReLU(0.01, inplace=True))\n            curr_dim = curr_dim * 2\n\n        #k_size = int(image_size / np.power(2, repeat_num))\n        if norm=='SN':\n            layers.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=1, padding=1)))\n        else:\n            layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=1, padding=1))\n        layers.append(nn.LeakyReLU(0.01, inplace=True))\n        curr_dim = curr_dim * 2\n\n        self.main = nn.Sequential(*layers)\n        if norm=='SN':\n            self.conv1 = SpectralNorm(nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False))\n        else:\n            self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False)\n\n    def forward(self, x):\n        h = self.main(x)\n        out_makeup = self.conv1(h)\n        return out_makeup\n\n\nclass VGG(TVGG):\n    def forward(self, x):\n        x = self.features(x)\n        return x\n\n\ndef make_layers(cfg, batch_norm=False):\n    layers = []\n    in_channels = 3\n    for v in cfg:\n        if v == 'M':\n            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]\n        else:\n            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)\n            if batch_norm:\n                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]\n            else:\n                layers += [conv2d, nn.ReLU(inplace=True)]\n            in_channels = v\n    return nn.Sequential(*layers)\n\n\ndef _vgg(arch, cfg, batch_norm, pretrained, progress, 
**kwargs):\n    if pretrained:\n        kwargs['init_weights'] = False\n    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)\n    if pretrained:\n        state_dict = load_state_dict_from_url(model_urls[arch],\n                                              progress=progress)\n        model.load_state_dict(state_dict)\n    return model\n\n\ndef vgg16(pretrained=False, progress=True, **kwargs):\n    r\"\"\"VGG 16-layer model (configuration \"D\")\n    `\"Very Deep Convolutional Networks For Large-Scale Image Recognition\" <https://arxiv.org/pdf/1409.1556.pdf>`_\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)"
  },
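  {
    "path": "examples/sketch_model.py",
    "content": "'''\nIllustrative sketch, not part of the original repo: builds the PatchGAN\ndiscriminator and the VGG-16 feature extractor from models/model.py and checks\ntheir output shapes on random inputs. File name and sizes are assumptions.\n'''\nimport sys\nsys.path.append('.')\n\nimport torch\nfrom models.model import Discriminator, vgg16\n\n# One realism score per patch: a 256x256 input yields a (2, 1, 30, 30) map.\nD = Discriminator(input_channel=3, conv_dim=64, num_layers=3, norm='SN')\nprint(D(torch.randn(2, 3, 256, 256)).shape)\n\n# VGG-16 backbone whose forward returns conv features (pretrained=False avoids a download here).\nvgg = vgg16(pretrained=False).eval()\nwith torch.no_grad():\n    print(vgg(torch.randn(1, 3, 224, 224)).shape)  # (1, 512, 7, 7)"
  },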
  {
    "path": "models/modules/__init__.py",
    "content": ""
  },
  {
    "path": "models/modules/histogram_matching.py",
    "content": "import copy\nimport torch\n\ndef cal_hist(image):\n    \"\"\"\n        cal cumulative hist for channel list\n    \"\"\"\n    hists = []\n    for i in range(0, 3):\n        channel = image[i]\n        # channel = image[i, :, :]\n        channel = torch.from_numpy(channel)\n        # hist, _ = np.histogram(channel, bins=256, range=(0,255))\n        hist = torch.histc(channel, bins=256, min=0, max=256)\n        hist = hist.numpy()\n        # refHist=hist.view(256,1)\n        sum = hist.sum()\n        pdf = [v / (sum + 1e-10) for v in hist]\n        for i in range(1, 256):\n            pdf[i] = pdf[i - 1] + pdf[i]\n        hists.append(pdf)\n    return hists\n\n\ndef cal_trans(ref, adj):\n    \"\"\"\n        calculate transfer function\n        algorithm refering to wiki item: Histogram matching\n    \"\"\"\n    table = list(range(0, 256))\n    for i in list(range(1, 256)):\n        for j in list(range(1, 256)):\n            if ref[i] >= adj[j - 1] and ref[i] <= adj[j]:\n                table[i] = j\n                break\n    table[255] = 255\n    return table\n\ndef histogram_matching(dstImg, refImg, index):\n    \"\"\"\n        perform histogram matching\n        dstImg is transformed to have the same the histogram with refImg's\n        index[0], index[1]: the index of pixels that need to be transformed in dstImg\n        index[2], index[3]: the index of pixels that to compute histogram in refImg\n    \"\"\"\n    index = [x.cpu().numpy() for x in index]\n    dstImg = dstImg.detach().cpu().numpy()\n    refImg = refImg.detach().cpu().numpy()\n    dst_align = [dstImg[i, index[0], index[1]] for i in range(0, 3)]\n    ref_align = [refImg[i, index[2], index[3]] for i in range(0, 3)]\n    hist_ref = cal_hist(ref_align)\n    hist_dst = cal_hist(dst_align)\n    tables = [cal_trans(hist_dst[i], hist_ref[i]) for i in range(0, 3)]\n\n    mid = copy.deepcopy(dst_align)\n    for i in range(0, 3):\n        for k in range(0, len(index[0])):\n            dst_align[i][k] = tables[i][int(mid[i][k])]\n\n    for i in range(0, 3):\n        dstImg[i, index[0], index[1]] = dst_align[i]\n\n    dstImg = torch.FloatTensor(dstImg)\n    return dstImg"
  },
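  {
    "path": "examples/sketch_histogram_matching.py",
    "content": "'''\nIllustrative sketch, not part of the original repo: runs histogram_matching\nfrom models/modules/histogram_matching.py on two random pseudo-images inside a\nrectangular region. The region coordinates below are arbitrary demo choices.\n'''\nimport sys\nsys.path.append('.')\n\nimport torch\nfrom models.modules.histogram_matching import histogram_matching\n\nh = w = 32\nsrc = torch.randint(0, 256, (3, h, w)).float()  # values in [0, 255], shape (3, H, W)\nref = torch.randint(0, 256, (3, h, w)).float()\n\n# Match the upper-left 16x16 patch of src to the same patch of ref.\ncoords = torch.arange(16)\nrows = coords.repeat_interleave(16)  # row indices of the patch pixels\ncols = coords.repeat(16)             # column indices of the patch pixels\nindex = [rows, cols, rows, cols]     # [rows_dst, cols_dst, rows_ref, cols_ref]\n\nmatched = histogram_matching(src, ref, index)\nprint(matched.shape)  # torch.Size([3, 32, 32])"
  },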
  {
    "path": "models/modules/module_attn.py",
    "content": "import math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass MultiheadAttention_weight(nn.Module):\n    def __init__(self, feature_dim, proj_dim, num_heads=1, dropout=0.0, bias=True):\n        super(MultiheadAttention_weight, self).__init__()\n        self.feature_dim = feature_dim\n        self.proj_dim = proj_dim\n        self.num_heads = num_heads\n        self.dropout = nn.Dropout(dropout)\n        self.head_dim = proj_dim // num_heads\n        assert self.head_dim * num_heads == self.proj_dim, \"embed_dim must be divisible by num_heads\"\n        self.scaling = self.head_dim ** -0.5\n\n        self.q_proj = nn.Linear(feature_dim, proj_dim, bias=bias)\n        self.k_proj = nn.Linear(feature_dim, proj_dim, bias=bias)\n\n    def forward(self, fea_c, fea_s, mask_c, mask_s):\n        '''\n        fea_c: (b, d, h, w)\n        mask_c: (b, c, h, w)\n        '''\n        bsz, dim, h, w = fea_c.shape; mask_channel = mask_c.shape[1]\n\n        fea_c = fea_c.view(bsz, dim, h*w).transpose(1, 2) # (b, HW, d)\n        fea_s = fea_s.view(bsz, dim, h*w).transpose(1, 2)\n        with torch.no_grad():\n            if mask_c.shape[2] != h:\n                mask_c = F.interpolate(mask_c, size=(h, w)) \n                mask_s = F.interpolate(mask_s, size=(h, w)) \n            mask_c = mask_c.view(bsz, mask_channel, -1, h*w) # (b, m_c, 1, HW)\n            mask_s = mask_s.view(bsz, mask_channel, -1, h*w)\n            mask_attn = torch.matmul(mask_c.transpose(-2, -1), mask_s) # (b, m_c, HW, HW)\n            mask_attn = torch.sum(mask_attn, dim=1, keepdim=True).clamp_(0, 1) # (b, 1, HW, HW)\n            mask_sum = torch.sum(mask_attn, dim=-1, keepdim=True)\n            mask_attn += (mask_sum == 0).float()\n            mask_attn = mask_attn.masked_fill_(mask_attn == 0, float('-inf')).masked_fill_(mask_attn == 1, float(0.0))\n\n        query = self.q_proj(fea_c) # (b, HW, D)\n        key = self.k_proj(fea_s) # (b, HW, D)\n        query = query.view(bsz, h*w, self.num_heads, self.head_dim).transpose(1, 2) # (b, h, HW, D)\n        key = key.view(bsz, h*w, self.num_heads, self.head_dim).transpose(1, 2)\n\n        weights = torch.matmul(query, key.transpose(-1, -2)) # (b, h, HW, HW)\n        weights = weights * self.scaling\n        weights = weights + mask_attn.detach()\n        weights = self.dropout(F.softmax(weights, dim=-1))\n        weights = weights * (1 - (mask_sum == 0).float().detach())\n        return weights \n\n\nclass MultiheadAttention_value(nn.Module):\n    def __init__(self, feature_dim, proj_dim, num_heads=1, bias=True):\n        super(MultiheadAttention_value, self).__init__()\n        self.feature_dim = feature_dim\n        self.proj_dim = proj_dim\n        self.num_heads = num_heads\n        self.head_dim = proj_dim // num_heads\n        assert self.head_dim * num_heads == self.proj_dim, \"embed_dim must be divisible by num_heads\"\n        \n        self.v_proj = nn.Linear(feature_dim, proj_dim, bias=bias)\n\n    def forward(self, weights, fea):\n        '''\n        weights: (b, h, HW. 
HW)\n        fea: (b, d, H, W)\n        '''\n        bsz, dim, h, w = fea.shape\n        fea = fea.view(bsz, dim, h*w).transpose(1, 2) #(b, HW, D)\n        value = self.v_proj(fea)\n        value = value.view(bsz, h*w, self.num_heads, self.head_dim).transpose(1, 2) #(b, h, HW, D)\n\n        out = torch.matmul(weights, value)\n        out = out.transpose(1, 2).contiguous().view(bsz, h*w, self.proj_dim) # (b, HW, D)\n        out = out.transpose(1, 2).view(bsz, self.proj_dim, h, w) #(b, d, H, W)\n        return out\n\n\nclass MultiheadAttention(nn.Module):\n    def __init__(self, in_channels, proj_channels, value_channels, out_channels, num_heads=1, dropout=0.0, bias=True):\n        super(MultiheadAttention, self).__init__()\n        self.weight = MultiheadAttention_weight(in_channels, proj_channels, num_heads, dropout, bias)\n        self.value = MultiheadAttention_value(value_channels, out_channels, num_heads, bias)\n\n    def forward(self, fea_q, fea_k, fea_v, mask_q, mask_k):\n        '''\n        fea: (b, d, h, w)\n        mask: (b, c, h, w)\n        '''\n        weights = self.weight(fea_q, fea_k, mask_q, mask_k)\n        return self.value(weights, fea_v)\n\n\nclass FeedForwardLayer(nn.Module):\n    def __init__(self, feature_dim, ff_dim, dropout=0.0):\n        super(FeedForwardLayer, self).__init__()\n        self.main = nn.Sequential(\n            nn.LeakyReLU(0.2, inplace=True),\n            nn.Dropout(p=dropout, inplace=True),\n            nn.Conv2d(feature_dim, ff_dim, kernel_size=1),\n            nn.LeakyReLU(0.2, inplace=True),\n            nn.Conv2d(ff_dim, feature_dim, kernel_size=1)\n        )\n\n    def forward(self, x):\n        return self.main(x)\n\n\nclass Attention_apply(nn.Module):\n    def __init__(self, feature_dim, normalize=True):\n        super(Attention_apply, self).__init__()\n        self.normalize = normalize\n        if normalize:\n            self.norm = nn.InstanceNorm2d(feature_dim, affine=False)\n        self.actv = nn.LeakyReLU(0.2, inplace=True)\n        self.conv = nn.Conv2d(feature_dim, feature_dim, kernel_size=3, stride=1, padding=1, bias=False)\n\n    def forward(self, x, attn_out):\n        if self.normalize:\n            x = self.norm(x) \n        x = x * (1 + attn_out)\n        return self.conv(self.actv(x))\n"
  },
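  {
    "path": "examples/sketch_module_attn.py",
    "content": "'''\nIllustrative shape-check sketch, not part of the original repo, for the global\ncross-attention blocks in models/modules/module_attn.py. The channel sizes,\nnumber of regions and feature resolution are assumptions.\n'''\nimport sys\nsys.path.append('.')\n\nimport torch\nimport torch.nn.functional as F\nfrom models.modules.module_attn import MultiheadAttention, Attention_apply\n\nb, h, w = 2, 16, 16\nfea_q = torch.randn(b, 96, h, w)  # e.g. a content feature concatenated with its embedding\nfea_k = torch.randn(b, 96, h, w)  # reference feature concatenated with its embedding\nfea_v = torch.randn(b, 64, h, w)  # reference feature whose makeup is transferred\n\n# One-hot region masks: each pixel belongs to exactly one of 4 regions.\nmask_q = F.one_hot(torch.randint(0, 4, (b, h, w)), 4).permute(0, 3, 1, 2).float()\nmask_k = F.one_hot(torch.randint(0, 4, (b, h, w)), 4).permute(0, 3, 1, 2).float()\n\nattn = MultiheadAttention(in_channels=96, proj_channels=32,\n                          value_channels=64, out_channels=64, num_heads=4)\nattn_out = attn(fea_q, fea_k, fea_v, mask_q, mask_k)\nprint(attn_out.shape)  # (2, 64, 16, 16)\n\n# Inject the transferred feature back into the content branch.\napply_block = Attention_apply(64)\nprint(apply_block(torch.randn(b, 64, h, w), attn_out).shape)  # (2, 64, 16, 16)"
  },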
  {
    "path": "models/modules/module_base.py",
    "content": "import math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass ResidualBlock(nn.Module):\n    \"\"\"Residual Block.\"\"\"\n    def __init__(self, dim_in, dim_out):            \n        super(ResidualBlock, self).__init__()\n        self.main = nn.Sequential(\n            nn.LeakyReLU(0.2, inplace=True),\n            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),\n            nn.LeakyReLU(0.2, inplace=True),\n            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False)\n        )\n        self.skip = nn.Identity() if dim_in == dim_out else nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=False)\n\n    def forward(self, x):\n        x = self.skip(x) + self.main(x)\n        return x / math.sqrt(2)\n\n\nclass ResidualBlock_IN(nn.Module):\n    \"\"\"Residual Block with InstanceNorm.\"\"\"\n    def __init__(self, dim_in, dim_out, affine=False):            \n        super(ResidualBlock_IN, self).__init__()\n        self.main = nn.Sequential(\n            nn.InstanceNorm2d(dim_in, affine=affine),\n            nn.LeakyReLU(0.2, inplace=True),\n            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),\n            nn.InstanceNorm2d(dim_out, affine=affine),\n            nn.LeakyReLU(0.2, inplace=True),\n            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),\n        )\n        self.skip = nn.Identity() if dim_in == dim_out else nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=False)\n\n    def forward(self, x):\n        x = self.skip(x) + self.main(x)\n        return x / math.sqrt(2)\n\n\nclass ResidualBlock_Downsample(nn.Module):\n    \"\"\"Residual Block with InstanceNorm.\"\"\"\n    def __init__(self, dim_in, dim_out, affine=False):            \n        super(ResidualBlock_Downsample, self).__init__()\n        self.main = nn.Sequential(\n            nn.InstanceNorm2d(dim_in, affine=affine),\n            nn.LeakyReLU(0.2, inplace=True),\n            nn.Conv2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, bias=False)    \n        )\n        if dim_in == dim_out:\n            self.skip = nn.Identity()\n        else:\n            self.skip = nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1, bias=False)\n\n    def forward(self, x):\n        skip = F.interpolate(self.skip(x), scale_factor=0.5, mode='bilinear', align_corners=False, recompute_scale_factor=True)\n        res = self.main(x)\n        x = skip + res\n        return x / math.sqrt(2)\n\n\nclass Downsample(nn.Module):\n    \"\"\"Residual Block with InstanceNorm.\"\"\"\n    def __init__(self, dim_in, dim_out, affine=False):            \n        super(Downsample, self).__init__()\n        self.main = nn.Sequential(\n            nn.InstanceNorm2d(dim_in, affine=affine),\n            nn.LeakyReLU(0.2, inplace=True),\n            nn.Conv2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, bias=False)           \n        )\n\n    def forward(self, x):\n        return self.main(x)\n\n\nclass ResidualBlock_Upsample(nn.Module):\n    \"\"\"Residual Block with InstanceNorm.\"\"\"\n    def __init__(self, dim_in, dim_out, normalize=True, affine=False):            \n        super(ResidualBlock_Upsample, self).__init__()\n        if normalize:\n            self.main = nn.Sequential(\n                nn.InstanceNorm2d(dim_in, affine=affine),\n                nn.LeakyReLU(0.2, inplace=True),\n                nn.ConvTranspose2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, 
bias=False)\n            )\n        else:\n            self.main = nn.Sequential(\n                nn.LeakyReLU(0.2, inplace=True),\n                nn.ConvTranspose2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, bias=False)\n            )\n        if dim_in == dim_out:\n            self.skip = nn.Identity()\n        else:\n            self.skip = nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1, bias=False)\n\n    def forward(self, x):\n        skip = F.interpolate(self.skip(x), scale_factor=2, mode='bilinear', align_corners=False)\n        res = self.main(x)\n        x = skip + res\n        return x / math.sqrt(2)\n\n\nclass Upsample(nn.Module):\n    \"\"\"Residual Block with InstanceNorm.\"\"\"\n    def __init__(self, dim_in, dim_out, normalize=True, affine=False):            \n        super(Upsample, self).__init__()\n        if normalize:\n            self.main = nn.Sequential(\n                nn.InstanceNorm2d(dim_in, affine=affine),\n                nn.LeakyReLU(0.2, inplace=True),\n                nn.ConvTranspose2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, bias=False)\n            )\n        else:\n            self.main = nn.Sequential(\n                nn.LeakyReLU(0.2, inplace=True),\n                nn.ConvTranspose2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, bias=False)\n            )\n\n    def forward(self, x):\n        return self.main(x)\n\n\nclass PositionalEmbedding(nn.Module):\n    def __init__(self, embedding_dim=136, feature_size=64, max_size=None, embedding_type='l2_norm'):\n        super(PositionalEmbedding, self).__init__()\n        self.embedding_dim = embedding_dim\n        self.feature_size = feature_size\n        self.max_size = max_size\n        assert embedding_type in ['l2_norm', 'uniform', 'sin']\n        self.embedding_type = embedding_type\n\n    @torch.no_grad()\n    def forward(self, diff, mask):\n        '''\n        diff: (b, d, h, w)\n        mask: (b, 3, h, w)\n        return: (b, d, h, w)\n        '''\n        bsz, init_dim, init_size, _ = diff.shape\n        assert self.embedding_dim >= init_dim\n        diff = F.interpolate(diff, self.feature_size) # (b, d, h, w)\n        mask = F.interpolate(mask, size=self.feature_size)\n        mask = torch.sum(mask, dim=1, keepdim=True) # (b, 1, h, w)\n        diff = diff * mask\n        \n        if self.embedding_type == 'l2_norm':\n            norm = torch.norm(diff, dim=1, keepdim=True)\n            norm = (norm == 0) + norm\n            diff = diff / norm\n        elif self.embedding_type == 'uniform':\n            diff = diff / self.max_size\n        elif self.embedding_type == 'sin':\n            diff = torch.sin(diff * math.pi / (2 * self.max_size))\n        \n        if self.embedding_dim > init_dim:\n            zero_shape = (bsz, self.embedding_dim - init_dim, self.feature_size, self.feature_size)\n            zero_padding = torch.zeros(zero_shape, device=diff.device)\n            diff = torch.cat((diff, zero_padding), dim=1)\n\n        diff = diff.detach(); diff.requires_grad = False\n        return diff\n\nclass MergeBlock(nn.Module):\n    def __init__(self, merge_mode, feature_dim, normalize=True):\n        super(MergeBlock, self).__init__()\n        assert merge_mode in ['conv', 'add', 'affine']\n        self.merge_mode = merge_mode\n        if merge_mode == 'affine':\n            self.norm = nn.LayerNorm(feature_dim, elementwise_affine=False) if normalize else nn.Identity()\n        else:\n            self.norm = nn.InstanceNorm2d(feature_dim, affine=False) if 
normalize else nn.Identity()\n        self.norm_r = nn.InstanceNorm2d(feature_dim, affine=False) if normalize else nn.Identity()\n        self.actv = nn.LeakyReLU(0.2, inplace=True)\n        if merge_mode == 'conv':\n            self.conv = nn.Conv2d(2 * feature_dim, feature_dim, kernel_size=3, stride=1, padding=1, bias=False)\n        else:\n            self.conv = nn.Conv2d(feature_dim, feature_dim, kernel_size=3, stride=1, padding=1, bias=False)\n\n    def forward(self, fea_s, fea_r):\n        if self.merge_mode == 'conv':\n            fea_s = self.norm(fea_s)\n            fea_r = self.norm_r(fea_r)\n            fea_s = torch.cat((fea_s, fea_r), dim=1)\n        elif self.merge_mode == 'add':\n            fea_s = self.norm(fea_s)\n            fea_r = self.norm_r(fea_r)\n            fea_s = (fea_s + fea_r) / math.sqrt(2)\n        elif self.merge_mode == 'affine':\n            fea_s = fea_s.permute(0, 2, 3, 1)\n            fea_s = self.norm(fea_s)\n            fea_s = fea_s.permute(0, 3, 1, 2)\n            fea_s = fea_s * (1 + fea_r)\n        return self.conv(self.actv(fea_s))"
  },
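  {
    "path": "examples/sketch_module_base.py",
    "content": "'''\nIllustrative shape-check sketch, not part of the original repo, for the basic\nbuilding blocks in models/modules/module_base.py. All sizes are arbitrary\nassumptions chosen for the demo.\n'''\nimport sys\nsys.path.append('.')\n\nimport torch\nfrom models.modules.module_base import (\n    ResidualBlock_IN, Downsample, Upsample, PositionalEmbedding, MergeBlock)\n\nx = torch.randn(2, 64, 64, 64)  # (B, C, H, W)\nprint(ResidualBlock_IN(64, 64, affine=True)(x).shape)  # (2, 64, 64, 64)\nprint(Downsample(64, 128, affine=True)(x).shape)       # (2, 128, 32, 32)\nprint(Upsample(64, 32, affine=True)(x).shape)          # (2, 32, 128, 128)\n\n# Landmark-difference embedding (default embedding_dim 136 = 68 landmarks x 2 coordinates).\ndiff = torch.randn(2, 136, 64, 64)\nmask = torch.rand(2, 3, 64, 64)\nemb = PositionalEmbedding(embedding_dim=136, feature_size=32)(diff, mask)\nprint(emb.shape)  # (2, 136, 32, 32)\n\n# Merge a skip feature with an up-sampled feature.\nmerge = MergeBlock('conv', 64)\nprint(merge(x, torch.randn(2, 64, 64, 64)).shape)  # (2, 64, 64, 64)"
  },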
  {
    "path": "models/modules/pseudo_gt.py",
    "content": "import cv2\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torchvision.transforms import functional\n\nfrom models.modules.tps_transform import tps_sampler, tps_spatial_transform\n\n\ndef expand_area(mask:torch.Tensor, margin:int):\n    '''\n    mask: (C, H, W) or (N, C, H, W)\n    '''\n    kernel = np.zeros((margin * 2 + 1, margin * 2 + 1), dtype=np.uint8)\n    kernel = cv2.circle(kernel, (margin, margin), margin, (255, 0, 0), -1)\n    kernel = torch.FloatTensor((kernel > 0)).unsqueeze(0).unsqueeze(0).to(mask.device)\n    ndim = mask.ndimension()\n    if ndim == 3:\n        mask = mask.unsqueeze(0)\n    expanded_mask = torch.zeros_like(mask)\n    for i in range(mask.shape[1]):\n        expanded_mask[:,i:i+1,:,:] = F.conv2d(mask[:,i:i+1,:,:], kernel, padding=margin)\n    if ndim == 3:\n        expanded_mask = expanded_mask.squeeze(0)\n    return (expanded_mask > 0).float()\n\ndef mask_blur(mask:torch.Tensor, blur_size=3, mode='smooth'):\n    \"\"\"Blur the edge of mask so that the compose image have smooth transition\n    Args:\n        mask (torch.Tensor): [C, H, W]\n        blur_size (int): size of blur kernel. Defaults to 3.\n        mode (str) Defaults to 'smooth'.\n    Returns:\n        torch.Tensor: blurred mask\n    \"\"\"\n    #kernel = torch.ones((1, 1, blur_size * 2 + 1, blur_size * 2 + 1)).to(mask.device)\n    kernel = np.zeros((blur_size * 2 + 1, blur_size * 2 + 1), dtype=np.uint8)\n    kernel = cv2.circle(kernel, (blur_size, blur_size), blur_size, (255, 0, 0), -1)\n    kernel = torch.FloatTensor((kernel > 0)).unsqueeze(0).unsqueeze(0).to(mask.device)\n    kernel = kernel / torch.sum(kernel)\n    ndim = mask.ndimension()\n    if ndim == 3:\n        mask = mask.unsqueeze(0)\n    mask_blur = torch.zeros_like(mask)\n    for i in range(mask.shape[1]):\n        mask_blur[:,i:i+1,:,:] = F.conv2d(mask[:,i:i+1,:,:], kernel, padding=blur_size)\n    if mode == 'valid':\n        mask_blur = (mask_blur.clamp(0.5, 1) - 0.5) * 2 * mask\n    if ndim == 3:\n        mask_blur = mask_blur.squeeze(0)\n    return mask_blur.clamp(0, 1)\n\ndef mask_blend(mask, blend_alpha, mask_bound=None, blur_size=3, blend_mode='smooth'):\n    if blur_size > 0:\n        mask = mask_blur(mask, blur_size, blend_mode)\n    mask = mask * blend_alpha\n    if mask_bound is None:\n        return mask\n    else:\n        return mask * mask_bound\n\n\ndef tps_align(img_size, lms_r, lms_s, image_r, image_s=None, \n              mask_r = None, mask_s=None, sample_mode='bilinear'):\n    '''\n    image: (C, H, W), lms: (K, 2), mask:(1, H, W)\n    '''\n    lms_s = torch.flip(lms_s, dims=[1]) / (img_size - 1)\n    lms_r = (torch.flip(lms_r, dims=[1]) / (img_size - 1)).unsqueeze(0)\n    image_r = image_r.unsqueeze(0)\n    image_trans, _ = tps_spatial_transform(img_size, img_size, lms_s, image_r, lms_r, sample_mode)\n    if mask_r is not None:\n        mask_r_trans, _ = tps_spatial_transform(img_size, img_size, lms_s, mask_r.unsqueeze(0), \n                                                lms_r, 'nearest')\n    if image_s is not None:\n        mask_compose = torch.ones((1, img_size, img_size), device=lms_r.device)\n        if mask_s is not None:\n            mask_compose *= mask_s\n        if mask_r is not None:\n            mask_compose *= mask_r_trans.squeeze(0)\n        return image_s * (1 - mask_compose) + image_trans.squeeze(0) * mask_compose\n    else:\n        return image_trans.squeeze(0)\n\ndef tps_blend(blend_alpha, img_size, lms_r, lms_s, image_r, 
image_s=None, mask_r = None, mask_s=None, \n              mask_s_bound=None, blur_size=7, sample_mode='bilinear', blend_mode='smooth'):\n    '''\n    image: (C, H, W), lms: (K, 2), mask:(1, H, W)\n    '''\n    lms_s = torch.flip(lms_s, dims=[1]) / (img_size - 1)\n    lms_r = (torch.flip(lms_r, dims=[1]) / (img_size - 1)).unsqueeze(0)\n    image_r = image_r.unsqueeze(0)\n    image_trans, _ = tps_spatial_transform(img_size, img_size, lms_s, image_r, lms_r, sample_mode)\n    if mask_r is not None:\n        mask_r_trans, _ = tps_spatial_transform(img_size, img_size, lms_s, mask_r.unsqueeze(0), \n                                                lms_r, 'nearest')\n    if image_s is not None:\n        mask_compose = torch.ones((1, img_size, img_size), device=lms_r.device)\n        if mask_s is not None:\n            mask_compose *= mask_s\n        if mask_r is not None:\n            mask_compose *= mask_r_trans.squeeze(0)\n        mask_compose = mask_blend(mask_compose, blend_alpha, mask_s_bound, blur_size, blend_mode)\n        return image_s * (1 - mask_compose) + image_trans.squeeze(0) * mask_compose\n    else:\n        return image_trans.squeeze(0)\n\n\ndef fine_align(img_size, lms_r, lms_s, image_r, image_s, mask_r, mask_s, margins, blend_alphas):\n    '''\n    image: (C, H, W), lms: (K, 2)\n    mask: (C, H, W), lip, face, left eye, right eye\n    margins: dictionary, blend_alphas: dictionary\n    '''\n    # skin align\n    image_s = tps_blend(blend_alphas['skin'], img_size, lms_r[:60], lms_s[:60], image_r, image_s, \n                        mask_r[1:2], mask_s[1:2], mask_s[1:2], blur_size=8, blend_mode='valid')\n\n    # lip align\n    mask_s_lip = expand_area(mask_s[0:1], margins['lip'])\n    mask_r_lip = expand_area(mask_r[0:1], margins['lip'])\n    image_s = tps_blend(blend_alphas['lip'], img_size, lms_r[48:], lms_s[48:], image_r, image_s, \n                        mask_r_lip, mask_s_lip, mask_s[0:1], blur_size=3)\n\n    # left eye align\n    mask_s_eye = expand_area(mask_s[2:3], margins['eye'])\n    mask_r_eye = expand_area(mask_r[2:3], margins['eye']) * mask_r[1:2]\n    image_s = tps_blend(blend_alphas['eye'], img_size, \n                        torch.cat((lms_r[14:17], lms_r[22:27], lms_r[27:31], lms_r[42:48]), dim=0), \n                        torch.cat((lms_s[14:17], lms_s[22:27], lms_s[27:31], lms_s[42:48]), dim=0), \n                        image_r, image_s, mask_r_eye, mask_s_eye, mask_s[1:2], \n                        blur_size=5, sample_mode='nearest')\n\n    # right eye align\n    mask_s_eye = expand_area(mask_s[3:4], margins['eye'])\n    mask_r_eye = expand_area(mask_r[3:4], margins['eye']) * mask_r[1:2]\n    image_s = tps_blend(blend_alphas['eye'], img_size, \n                        torch.cat((lms_r[0:3], lms_r[17:22], lms_r[27:31], lms_r[36:42]), dim=0), \n                        torch.cat((lms_s[0:3], lms_s[17:22], lms_s[27:31], lms_s[36:42]), dim=0), \n                        image_r, image_s, mask_r_eye, mask_s_eye, mask_s[1:2], \n                        blur_size=5, sample_mode='nearest')\n\n    return image_s\n\n\nif __name__ == \"__main__\":\n    pass"
  },
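  {
    "path": "examples/sketch_pseudo_gt_masks.py",
    "content": "'''\nIllustrative sketch, not part of the original repo: dilates and blurs a toy\nbinary mask with the helpers in models/modules/pseudo_gt.py (cv2 is required,\nas in the module itself). The square region below is a demo assumption.\n'''\nimport sys\nsys.path.append('.')\n\nimport torch\nfrom models.modules.pseudo_gt import expand_area, mask_blur\n\nmask = torch.zeros(1, 64, 64)\nmask[:, 24:40, 24:40] = 1.0  # a square stand-in for a lip/eye region\n\nexpanded = expand_area(mask, margin=4)                    # grow the region by a 4-pixel radius\nblurred = mask_blur(expanded, blur_size=3, mode='valid')  # soften its boundary\nprint(mask.sum().item(), expanded.sum().item(), blurred.max().item())"
  },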
  {
    "path": "models/modules/sow_attention.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass WindowAttention(nn.Module):\n    def __init__(self, window_size, in_channels, proj_channels, value_channels, out_channels, \n                 num_heads=1, dropout=0.0, bias=True, weighted_output=True):\n        super(WindowAttention, self).__init__()\n        assert window_size % 2 == 0\n        self.window_size = window_size\n        self.weighted_output = weighted_output\n        window_weight = self.generate_window_weight()\n        self.register_buffer('window_weight', window_weight)\n\n        self.num_heads = num_heads\n        self.dropout = nn.Dropout(dropout)\n        self.in_channels = in_channels\n        self.proj_channels = proj_channels\n        head_dim = proj_channels // num_heads\n        assert head_dim * num_heads == self.proj_channels, \"embed_dim must be divisible by num_heads\"\n        self.scaling = head_dim ** -0.5\n\n        self.q_proj = nn.Conv2d(in_channels, proj_channels, kernel_size=1, bias=bias)\n        self.k_proj = nn.Conv2d(in_channels, proj_channels, kernel_size=1, bias=bias)\n\n        self.value_channels = value_channels\n        self.out_channels = out_channels\n        assert out_channels // num_heads * num_heads == self.out_channels\n        self.v_proj = nn.Conv2d(value_channels, out_channels, kernel_size=1, bias=bias)\n\n    @torch.no_grad()\n    def generate_window_weight(self):\n        yc = torch.arange(self.window_size // 2).unsqueeze(1).repeat(1, self.window_size // 2)\n        xc = torch.arange(self.window_size // 2).unsqueeze(0).repeat(self.window_size // 2, 1)\n        window_weight = xc * yc / (self.window_size // 2 - 1) ** 2\n        window_weight = torch.cat((window_weight, torch.flip(window_weight, dims=[0])), dim=0)\n        window_weight = torch.cat((window_weight, torch.flip(window_weight, dims=[1])), dim=1)\n        return window_weight.view(-1)   \n\n    def make_window(self, x: torch.Tensor):\n        \"\"\"\n        input: (B, C, H, W)\n        output: (B, h, H/S, W/S, S*S, C/h)\n        \"\"\"\n        bsz, dim, h, w = x.shape\n        x = x.view(bsz, self.num_heads, dim // self.num_heads, h // self.window_size, self.window_size, \n                   w // self.window_size, self.window_size)\n        x = x.transpose(4, 5).contiguous().view(bsz, self.num_heads, dim // self.num_heads, \n                                                h // self.window_size, w // self.window_size, self.window_size**2)\n        x = x.permute(0, 1, 3, 4, 5, 2)\n        return x\n\n    def demake_window(self, x: torch.Tensor):\n        \"\"\"\n        input: (B, h, H/S, W/S, S*S, C/h)\n        output: (B, C, H, W)\n        \"\"\"\n        bsz, _, h_s, w_s, _, dim_h = x.shape\n        x = x.permute(0, 1, 5, 2, 3, 4).contiguous()\n        #print(x.shape)\n        x = x.view(bsz, dim_h * self.num_heads, h_s, w_s, self.window_size, self.window_size)\n        #print(x.shape)\n        x = x.transpose(3, 4).contiguous().view(bsz, dim_h * self.num_heads, \n                                                h_s * self.window_size, w_s * self.window_size)\n        #print(x.shape)\n        return x\n\n    @torch.no_grad()\n    def make_mask_window(self, mask: torch.Tensor):\n        \"\"\"\n        input: (B, C, H, W)\n        output: (B, 1, H/S, W/S, S*S, C)\n        \"\"\"\n        bsz, mask_channel, h, w = mask.shape\n        mask = mask.view(bsz, 1, mask_channel, h // self.window_size, self.window_size, \n                         w // self.window_size, 
self.window_size)\n        mask = mask.transpose(4, 5).contiguous().view(bsz, 1, mask_channel, \n                                                      h // self.window_size, w // self.window_size, self.window_size**2)\n        mask = mask.permute(0, 1, 3, 4, 5, 2)\n        return mask\n    \n    def forward(self, fea_q, fea_k, fea_v, mask_q=None, mask_k=None):\n        '''\n        fea: (b, d, h, w)\n        mask: (b, c, h, w)\n        '''\n        query = self.q_proj(fea_q) # (B, D, H, W)\n        key = self.k_proj(fea_k)\n        value = self.v_proj(fea_v)\n        query = self.make_window(query) # (B, h, H/S, W/S, S*S, D/h)\n        key = self.make_window(key)\n        value = self.make_window(value)\n        \n        weights = torch.matmul(query, key.transpose(-1, -2)) # (B, h, H/S, W/S, S*S, S*S)\n        weights = weights * self.scaling\n        if mask_q is not None and mask_k is not None:\n            mask_q = self.make_mask_window(mask_q) # (B, 1, H/S, W/S, S*S, C)\n            mask_k = self.make_mask_window(mask_k)\n            with torch.no_grad():\n                mask_attn = torch.matmul(mask_q, mask_k.transpose(-1, -2))\n                mask_sum = torch.sum(mask_attn, dim=-1, keepdim=True)\n                mask_attn += (mask_sum == 0).float()\n                mask_attn = mask_attn.masked_fill_(mask_attn == 0, float('-inf')).masked_fill_(mask_attn == 1, float(0.0))\n            weights += mask_attn        \n\n        weights = self.dropout(F.softmax(weights, dim=-1))\n        if mask_q is not None and mask_k is not None:\n            weights = weights * (1 - (mask_sum == 0).float().detach())\n\n        out = torch.matmul(weights, value) # (B, h, H/S, W/S, S*S, D/h)\n        if self.weighted_output:\n            window_weight = self.window_weight.view(1, 1, 1, 1, self.window_size ** 2, 1)\n            out = out * window_weight\n        out = self.demake_window(out) #(B, D, H, W)\n        return out\n\nclass SowAttention(nn.Module):\n    def __init__(self, window_size, in_channels, proj_channels, value_channels, out_channels, \n                 num_heads=1, dropout=0.0, bias=True):\n        super(SowAttention, self).__init__()\n        assert window_size % 2 == 0\n        self.window_size = window_size\n        self.pad = nn.ZeroPad2d(window_size // 2)\n        self.window_attention = WindowAttention(window_size, in_channels, proj_channels, value_channels,\n                                            out_channels, num_heads, dropout, bias)\n\n    def forward(self, fea_q, fea_k, fea_v, mask_q=None, mask_k=None):\n        '''\n        fea: (b, d, h, w)\n        mask: (b, c, h, w)\n        '''\n        out_0 = self.window_attention(fea_q, fea_k, fea_v, mask_q, mask_k)\n        \n        fea_q = self.pad(fea_q)\n        fea_k = self.pad(fea_k)\n        fea_v = self.pad(fea_v)\n        if mask_q is not None and mask_k is not None:\n            mask_q = self.pad(mask_q)\n            mask_k = self.pad(mask_k)\n        else:\n            mask_q = None; mask_k = None\n        \n        out_1 = self.window_attention(fea_q, fea_k, fea_v, mask_q, mask_k)\n        out_1 = out_1[:, :, self.window_size//2:-self.window_size//2, self.window_size//2:-self.window_size//2]\n        \n        if mask_q is not None and mask_k is not None:\n            out_2 = self.window_attention(\n                fea_q[:, :, :, self.window_size//2:-self.window_size//2],\n                fea_k[:, :, :, self.window_size//2:-self.window_size//2],\n                fea_v[:, :, :, 
self.window_size//2:-self.window_size//2],\n                mask_q[:, :, :, self.window_size//2:-self.window_size//2],\n                mask_k[:, :, :, self.window_size//2:-self.window_size//2]\n            )\n        else:\n            out_2 = self.window_attention(\n                fea_q[:, :, :, self.window_size//2:-self.window_size//2],\n                fea_k[:, :, :, self.window_size//2:-self.window_size//2],\n                fea_v[:, :, :, self.window_size//2:-self.window_size//2],\n            )\n        out_2 = out_2[:, :, self.window_size//2:-self.window_size//2, :]\n\n        if mask_q is not None and mask_k is not None:\n            out_3 = self.window_attention(\n                fea_q[:, :, self.window_size//2:-self.window_size//2, :],\n                fea_k[:, :, self.window_size//2:-self.window_size//2, :],\n                fea_v[:, :, self.window_size//2:-self.window_size//2, :],\n                mask_q[:, :, self.window_size//2:-self.window_size//2, :],\n                mask_k[:, :, self.window_size//2:-self.window_size//2, :]\n            )\n        else:\n            out_3 = self.window_attention(\n                fea_q[:, :, self.window_size//2:-self.window_size//2, :],\n                fea_k[:, :, self.window_size//2:-self.window_size//2, :],\n                fea_v[:, :, self.window_size//2:-self.window_size//2, :],\n            )\n        out_3 = out_3[:, :, :, self.window_size//2:-self.window_size//2]\n\n        out = out_0 + out_1 + out_2 + out_3\n        return out\n\nclass StridedwindowAttention(nn.Module):\n    def __init__(self, stride, in_channels, proj_channels, value_channels, out_channels, \n                 num_heads=1, dropout=0.0, bias=True):\n        super(StridedwindowAttention, self).__init__()\n        self.stride = stride\n        self.num_heads = num_heads\n        self.dropout = nn.Dropout(dropout)\n        self.in_channels = in_channels\n        self.proj_channels = proj_channels\n        head_dim = proj_channels // num_heads\n        assert head_dim * num_heads == self.proj_channels, \"embed_dim must be divisible by num_heads\"\n        self.scaling = head_dim ** -0.5\n\n        self.q_proj = nn.Conv2d(in_channels, proj_channels, kernel_size=1, bias=bias)\n        self.k_proj = nn.Conv2d(in_channels, proj_channels, kernel_size=1, bias=bias)\n\n        self.value_channels = value_channels\n        self.out_channels = out_channels\n        assert out_channels // num_heads * num_heads == self.out_channels\n        self.v_proj = nn.Conv2d(value_channels, out_channels, kernel_size=1, bias=bias)\n\n    def make_window(self, x: torch.Tensor):\n        \"\"\"\n        input: (B, C, H, W)\n        output: (B, h, S(h), S(w), H/S * W/S, C/h)\n        \"\"\"\n        bsz, dim, h, w = x.shape\n        assert h % self.stride == 0 and w % self.stride == 0\n        \n        x = x.view(bsz, self.num_heads, dim // self.num_heads, h // self.stride, self.stride, \n                   w // self.stride, self.stride) # (B, h, C/h, H/S, S(h), W/S, S(w))\n        x = x.permute(0, 1, 4, 6, 3, 5, 2).contiguous() # (B, h, S(h), S(w), H/S, W/S, C/h)\n        x = x.view(bsz, self.num_heads, self.stride, self.stride,  \n                   h // self.stride * w // self.stride, dim // self.num_heads)\n        return x\n\n    def demake_window(self, x: torch.Tensor, h, w):\n        \"\"\"\n        input: (B, h, S(h), S(w), H/S * W/S, C/h)\n        output: (B, C, H, W)\n        \"\"\"\n        bsz, _, _, _, _, dim_h = x.shape\n        x = x.view(bsz, self.num_heads, self.stride, 
self.stride,  \n                   h // self.stride, w // self.stride, dim_h) # (B, h, S(h), S(w), H/S, W/S, C/h)\n        x = x.permute(0, 1, 6, 4, 2, 5, 3).contiguous() # (B, h, C/h, H/S, S(h), W/S, S(w))\n        x = x.view(bsz, dim_h * self.num_heads, h, w)\n        return x\n\n    @torch.no_grad()\n    def make_mask_window(self, mask: torch.Tensor):\n        \"\"\"\n        input: (B, C, H, W)\n        output: (B, 1, S(h), S(w), H/S * W/S, C)\n        \"\"\"\n        bsz, mask_channel, h, w = mask.shape\n        assert h % self.stride == 0 and w % self.stride == 0\n\n        mask = mask.view(bsz, 1, mask_channel, h // self.stride, self.stride, w // self.stride, self.stride)\n        mask = mask.permute(0, 1, 4, 6, 3, 5, 2).contiguous()\n        mask = mask.view(bsz, 1, self.stride, self.stride, h // self.stride * w // self.stride, mask_channel)\n        return mask\n    \n    def forward(self, fea_q, fea_k, fea_v, mask_q=None, mask_k=None):\n        '''\n        fea: (b, d, h, w)\n        mask: (b, c, h, w)\n        '''\n        bsz, _, h, w = fea_q.shape\n        \n        query = self.q_proj(fea_q) # (B, D, H, W)\n        key = self.k_proj(fea_k)\n        value = self.v_proj(fea_v)\n        query = self.make_window(query) # (B, h, S(h), S(w), H/S * W/S, C/h)\n        key = self.make_window(key)\n        value = self.make_window(value)\n        \n        weights = torch.matmul(query, key.transpose(-1, -2)) # (B, h, S(h), S(w), H/S * W/S, H/S * W/S)\n        weights = weights * self.scaling\n        if mask_q is not None and mask_k is not None:\n            mask_q = self.make_mask_window(mask_q) # (B, 1, S(h), S(w), H/S * W/S, C)\n            mask_k = self.make_mask_window(mask_k)\n            with torch.no_grad():\n                mask_attn = torch.matmul(mask_q, mask_k.transpose(-1, -2))\n                mask_sum = torch.sum(mask_attn, dim=-1, keepdim=True)\n                mask_attn += (mask_sum == 0).float()\n                mask_attn = mask_attn.masked_fill_(mask_attn == 0, float('-inf')).masked_fill_(mask_attn == 1, float(0.0))\n            weights += mask_attn        \n\n        weights = self.dropout(F.softmax(weights, dim=-1))\n        if mask_q is not None and mask_k is not None:\n            weights = weights * (1 - (mask_sum == 0).float().detach())\n\n        out = torch.matmul(weights, value) # (B, h, S(h), S(w), H/S * W/S, D/h)\n        out = self.demake_window(out, h, w) #(B, D, H, W)\n        return out\n    "
  },
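  {
    "path": "examples/sketch_sow_attention.py",
    "content": "'''\nIllustrative shape-check sketch, not part of the original repo, for the\nshifted-overlapped-window attention in models/modules/sow_attention.py.\nWindow size, channel sizes and resolution are assumptions; window_size must be\neven and must divide the feature height and width.\n'''\nimport sys\nsys.path.append('.')\n\nimport torch\nimport torch.nn.functional as F\nfrom models.modules.sow_attention import SowAttention, StridedwindowAttention\n\nb, h, w = 2, 32, 32\nfea_q = torch.randn(b, 64, h, w)\nfea_k = torch.randn(b, 64, h, w)\nfea_v = torch.randn(b, 64, h, w)\nmask_q = F.one_hot(torch.randint(0, 4, (b, h, w)), 4).permute(0, 3, 1, 2).float()\nmask_k = F.one_hot(torch.randint(0, 4, (b, h, w)), 4).permute(0, 3, 1, 2).float()\n\nsow = SowAttention(window_size=8, in_channels=64, proj_channels=32,\n                   value_channels=64, out_channels=64, num_heads=4)\nprint(sow(fea_q, fea_k, fea_v, mask_q, mask_k).shape)  # (2, 64, 32, 32)\n\n# Strided variant: each window gathers one pixel from every stride x stride cell.\nswa = StridedwindowAttention(stride=8, in_channels=64, proj_channels=32,\n                             value_channels=64, out_channels=64, num_heads=4)\nprint(swa(fea_q, fea_k, fea_v, mask_q, mask_k).shape)  # (2, 64, 32, 32)"
  },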
  {
    "path": "models/modules/spectral_norm.py",
    "content": "import torch\nfrom torch.nn import Parameter\n\ndef l2normalize(v, eps=1e-12):\n    return v / (v.norm() + eps)\n\nclass SpectralNorm(object):\n    def __init__(self):\n        self.name = \"weight\"\n        #print(self.name)\n        self.power_iterations = 1\n\n    def compute_weight(self, module):\n        u = getattr(module, self.name + \"_u\")\n        v = getattr(module, self.name + \"_v\")\n        w = getattr(module, self.name + \"_bar\")\n\n        height = w.data.shape[0]\n        for _ in range(self.power_iterations):\n            v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))\n            u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))\n        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))\n        sigma = u.dot(w.view(height, -1).mv(v))\n        return w / sigma.expand_as(w)\n\n    @staticmethod\n    def apply(module):\n        name = \"weight\"\n        fn = SpectralNorm()\n\n        try:\n            u = getattr(module, name + \"_u\")\n            v = getattr(module, name + \"_v\")\n            w = getattr(module, name + \"_bar\")\n        except AttributeError:\n            w = getattr(module, name)\n            height = w.data.shape[0]\n            width = w.view(height, -1).data.shape[1]\n            u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)\n            v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)\n            w_bar = Parameter(w.data)\n\n            #del module._parameters[name]\n\n            module.register_parameter(name + \"_u\", u)\n            module.register_parameter(name + \"_v\", v)\n            module.register_parameter(name + \"_bar\", w_bar)\n\n        # remove w from parameter list\n        del module._parameters[name]\n\n        setattr(module, name, fn.compute_weight(module))\n\n        # recompute weight before every forward()\n        module.register_forward_pre_hook(fn)\n\n        return fn\n\n    def remove(self, module):\n        weight = self.compute_weight(module)\n        delattr(module, self.name)\n        del module._parameters[self.name + '_u']\n        del module._parameters[self.name + '_v']\n        del module._parameters[self.name + '_bar']\n        module.register_parameter(self.name, Parameter(weight.data))\n\n    def __call__(self, module, inputs):\n        setattr(module, self.name, self.compute_weight(module))\n\ndef spectral_norm(module):\n    SpectralNorm.apply(module)\n    return module\n\ndef remove_spectral_norm(module):\n    name = 'weight'\n    for k, hook in module._forward_pre_hooks.items():\n        if isinstance(hook, SpectralNorm) and hook.name == name:\n            hook.remove(module)\n            del module._forward_pre_hooks[k]\n            return module\n\n    raise ValueError(\"spectral_norm of '{}' not found in {}\"\n                     .format(name, module))"
  },
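  {
    "path": "examples/sketch_spectral_norm.py",
    "content": "'''\nIllustrative sketch, not part of the original repo: wraps a convolution with\nthe custom spectral normalization in models/modules/spectral_norm.py, as the\ndiscriminator in models/model.py does.\n'''\nimport sys\nsys.path.append('.')\n\nimport torch\nimport torch.nn as nn\nfrom models.modules.spectral_norm import spectral_norm, remove_spectral_norm\n\nconv = spectral_norm(nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1))\nx = torch.randn(2, 3, 64, 64)\nprint(conv(x).shape)  # (2, 64, 32, 32); the weight is re-normalized by a forward pre-hook\n\nremove_spectral_norm(conv)  # restore a plain, unconstrained weight parameter\nprint(conv(x).shape)"
  },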
  {
    "path": "models/modules/tps_transform.py",
    "content": "from __future__ import absolute_import\n\nimport numpy as np\nimport itertools\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n# TF32 is not enough, require FP32\n# Disable automatic TF32 since Pytorch 1.7\ntorch.backends.cuda.matmul.allow_tf32 = False\ntorch.backends.cudnn.allow_tf32 = False\n\ndef grid_sample(input, grid, mode='bilinear', canvas=None):\n    output = F.grid_sample(input, grid, mode=mode, align_corners=True)\n    if canvas is None:\n        return output\n    else:\n        input_mask = input.data.new(input.size()).fill_(1)\n        output_mask = F.grid_sample(input_mask, grid, mode='nearest', align_corners=True)\n        padded_output = output * output_mask + canvas * (1 - output_mask)\n        return padded_output\n\n\n# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2\ndef compute_partial_repr(input_points, control_points):\n    N = input_points.size(0)\n    M = control_points.size(0)\n    pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2)\n    # original implementation, very slow\n    # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance\n    pairwise_diff_square = pairwise_diff * pairwise_diff\n    pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1]\n    repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)\n    #repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist + 1e-8)\n    # fix numerical error for 0 * log(0), substitute all nan with 0\n    mask = repr_matrix != repr_matrix\n    repr_matrix.masked_fill_(mask, 0)\n    return repr_matrix\n\n\n# compute \\Delta_c^-1\ndef bulid_delta_inverse(target_control_points):\n    '''\n    target_control_points: (N, 2)\n    '''\n    N = target_control_points.shape[0]\n    forward_kernel = torch.zeros(N + 3, N + 3).to(target_control_points.device)\n    target_control_partial_repr = compute_partial_repr(target_control_points, target_control_points)\n    forward_kernel[:N, :N].copy_(target_control_partial_repr)\n    forward_kernel[:N, -3].fill_(1)\n    forward_kernel[-3, :N].fill_(1)\n    forward_kernel[:N, -2:].copy_(target_control_points)\n    forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))\n    # compute inverse matrix\n    inverse_kernel = torch.inverse(forward_kernel)\n    return inverse_kernel\n\n\n# create target coordinate matrix\ndef build_target_coordinate_matrix(target_height, target_width, target_control_points):\n    '''\n    target_control_points: (N, 2)\n    '''\n    HW = target_height * target_width\n    target_coordinate = list(itertools.product(range(target_height), range(target_width)))\n    target_coordinate = torch.Tensor(target_coordinate).to(target_control_points.device) # HW x 2\n    Y, X = target_coordinate.split(1, dim = 1)\n    Y = Y / (target_height - 1)\n    X = X / (target_width - 1)\n    target_coordinate = torch.cat([X, Y], dim = 1) # convert from (y, x) to (x, y)\n    target_coordinate_partial_repr = compute_partial_repr(target_coordinate, target_control_points)\n    target_coordinate_repr = torch.cat([\n        target_coordinate_partial_repr, \n        torch.ones((HW, 1), device=target_control_points.device), \n        target_coordinate], dim = 1)\n    return target_coordinate_repr\n\n\ndef tps_sampler(target_height, target_width, inverse_kernel, target_coordinate_repr,\n                source, source_control_points, sample_mode='bilinear'):\n    '''\n    inverse_kernel: \\Delta_C^-1\n    target_coordinate_repr: \\hat{p}\n    source: (B, C, H, 
W)\n    source_control_points: (B, N, 2)\n    '''\n    batch_size = source.shape[0]\n    Y = torch.cat([source_control_points, torch.zeros((batch_size, 3, 2), device=source.device)], dim=1)\n    mapping_matrix = torch.matmul(inverse_kernel, Y)\n    source_coordinate = torch.matmul(target_coordinate_repr, mapping_matrix)\n\n    grid = source_coordinate.view(-1, target_height, target_width, 2)\n    grid = torch.clamp(grid, 0, 1) # the source_control_points may be out of [0, 1].\n    # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]\n    grid = 2.0 * grid - 1.0\n    output_maps = grid_sample(source, grid, mode=sample_mode, canvas=None)\n    return output_maps, source_coordinate\n\n\ndef tps_spatial_transform(target_height, target_width, target_control_points, \n                          source, source_control_points, sample_mode='bilinear'):\n    '''\n    target_control_points: (N, 2)\n    source: (B, C, H, W)\n    source_control_points: (B, N, 2)\n    '''\n    inverse_kernel = bulid_delta_inverse(target_control_points)\n    target_coordinate_repr = build_target_coordinate_matrix(target_height, target_width, target_control_points)\n    \n    return tps_sampler(target_height, target_width, inverse_kernel, target_coordinate_repr, \n                       source, source_control_points, sample_mode)\n\n\nclass TPSSpatialTransformer(nn.Module):\n\n    def __init__(self, target_height, target_width, target_control_points):\n        super(TPSSpatialTransformer, self).__init__()\n        self.target_height, self.target_width = target_height, target_width\n        self.num_control_points = target_control_points.shape[0]\n    \n        # create padded kernel matrix\n        inverse_kernel = bulid_delta_inverse(target_control_points)\n    \n        # create target coordinate matrix\n        target_coordinate_repr = build_target_coordinate_matrix(target_height, target_width, target_control_points)\n    \n        # register precomputed matrices\n        self.register_buffer('inverse_kernel', inverse_kernel)\n        #self.register_buffer('padding_matrix', torch.zeros(3, 2))\n        self.register_buffer('target_coordinate_repr', target_coordinate_repr)\n        self.register_buffer('target_control_points', target_control_points)\n    \n    def forward(self, source, source_control_points):\n        assert source_control_points.ndimension() == 3\n        assert source_control_points.size(1) == self.num_control_points\n        assert source_control_points.size(2) == 2\n        \n        return tps_sampler(self.target_height, self.target_width,\n                           self.inverse_kernel, self.target_coordinate_repr,\n                           source, source_control_points)\n "
  },
  {
    "path": "scripts/demo.py",
    "content": "import os\nimport sys\nimport argparse\nimport numpy as np\nimport cv2\nimport torch\nfrom PIL import Image\nsys.path.append('.')\n\nfrom training.config import get_config\nfrom training.inference import Inference\nfrom training.utils import create_logger, print_args\n\ndef main(config, args):\n    logger = create_logger(args.save_folder, args.name, 'info', console=True)\n    print_args(args, logger)\n    logger.info(config)\n\n    inference = Inference(config, args, args.load_path)\n\n    n_imgname = sorted(os.listdir(args.source_dir))\n    m_imgname = sorted(os.listdir(args.reference_dir))\n    \n    for i, (imga_name, imgb_name) in enumerate(zip(n_imgname, m_imgname)):\n        imgA = Image.open(os.path.join(args.source_dir, imga_name)).convert('RGB')\n        imgB = Image.open(os.path.join(args.reference_dir, imgb_name)).convert('RGB')\n\n        result = inference.transfer(imgA, imgB, postprocess=True) \n        if result is None:\n            continue\n        imgA = np.array(imgA); imgB = np.array(imgB)\n        h, w, _ = imgA.shape\n        result = result.resize((h, w)); result = np.array(result)\n        vis_image = np.hstack((imgA, imgB, result))\n        save_path = os.path.join(args.save_folder, f\"result_{i}.png\")\n        Image.fromarray(vis_image.astype(np.uint8)).save(save_path)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\"argument for training\")\n    parser.add_argument(\"--name\", type=str, default='demo')\n    parser.add_argument(\"--save_path\", type=str, default='result', help=\"path to save model\")\n    parser.add_argument(\"--load_path\", type=str, help=\"folder to load model\", \n                        default='ckpts/sow_pyramid_a5_e3d2_remapped.pth')\n\n    parser.add_argument(\"--source-dir\", type=str, default=\"assets/images/non-makeup\")\n    parser.add_argument(\"--reference-dir\", type=str, default=\"assets/images/makeup\")\n    parser.add_argument(\"--gpu\", default='0', type=str, help=\"GPU id to use.\")\n\n    args = parser.parse_args()\n    args.gpu = 'cuda:' + args.gpu\n    args.device = torch.device(args.gpu)\n\n    args.save_folder = os.path.join(args.save_path, args.name)\n    if not os.path.exists(args.save_folder):\n        os.makedirs(args.save_folder)\n    \n    config = get_config()\n    main(config, args)"
  },
  {
    "path": "scripts/train.py",
    "content": "import os\nimport sys\nimport argparse\nimport torch\nfrom torch.utils.data import DataLoader\nsys.path.append('.')\n\nfrom training.config import get_config\nfrom training.dataset import MakeupDataset\nfrom training.solver import Solver\nfrom training.utils import create_logger, print_args\n\n\ndef main(config, args):\n    logger = create_logger(args.save_folder, args.name, 'info', console=True)\n    print_args(args, logger)\n    logger.info(config)\n    \n    dataset = MakeupDataset(config)\n    data_loader = DataLoader(dataset, batch_size=config.DATA.BATCH_SIZE, num_workers=config.DATA.NUM_WORKERS, shuffle=True)\n    \n    solver = Solver(config, args, logger)\n    solver.train(data_loader)\n    \n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\"argument for training\")\n    parser.add_argument(\"--name\", type=str, default='elegant')\n    parser.add_argument(\"--save_path\", type=str, default='results', help=\"path to save model\")\n    parser.add_argument(\"--load_folder\", type=str, help=\"path to load model\", \n                        default=None)\n    parser.add_argument(\"--keepon\", default=False, action=\"store_true\", help='keep on training')\n\n    parser.add_argument(\"--gpu\", default='0', type=str, help=\"GPU id to use.\")\n\n    args = parser.parse_args()\n    config = get_config()\n    \n    #args.gpu = 'cuda:' + args.gpu\n    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu\n    #args.device = torch.device(args.gpu)\n    args.device = torch.device('cuda:0')\n\n    args.save_folder = os.path.join(args.save_path, args.name)\n    if not os.path.exists(args.save_folder):\n        os.makedirs(args.save_folder)    \n    \n    main(config, args)"
  },
  {
    "path": "training/__init__.py",
    "content": ""
  },
  {
    "path": "training/config.py",
    "content": "from fvcore.common.config import CfgNode\n\n\"\"\"\nThis file defines default options of configurations.\nIt will be further merged by yaml files and options from\nthe command-line.\nNote that *any* hyper-parameters should be firstly defined\nhere to enable yaml and command-line configuration.\n\"\"\"\n\n_C = CfgNode()\n\n# Logging and saving\n_C.LOG = CfgNode()\n_C.LOG.SAVE_FREQ = 10\n_C.LOG.VIS_FREQ = 1\n\n# Data settings\n_C.DATA = CfgNode()\n_C.DATA.PATH = './data/MT-Dataset'\n_C.DATA.NUM_WORKERS = 4\n_C.DATA.BATCH_SIZE = 1\n_C.DATA.IMG_SIZE = 256\n\n# Training hyper-parameters\n_C.TRAINING = CfgNode()\n_C.TRAINING.G_LR = 2e-4\n_C.TRAINING.D_LR = 2e-4\n_C.TRAINING.BETA1 = 0.5\n_C.TRAINING.BETA2 = 0.999\n_C.TRAINING.NUM_EPOCHS = 50\n_C.TRAINING.LR_DECAY_FACTOR = 5e-2\n_C.TRAINING.DOUBLE_D = False\n\n# Loss weights\n_C.LOSS = CfgNode()\n_C.LOSS.LAMBDA_A = 10.0\n_C.LOSS.LAMBDA_B = 10.0\n_C.LOSS.LAMBDA_IDT = 0.5\n_C.LOSS.LAMBDA_REC = 10\n_C.LOSS.LAMBDA_MAKEUP = 100\n_C.LOSS.LAMBDA_SKIN = 0.1\n_C.LOSS.LAMBDA_EYE = 1.5\n_C.LOSS.LAMBDA_LIP = 1\n_C.LOSS.LAMBDA_MAKEUP_LIP = _C.LOSS.LAMBDA_MAKEUP * _C.LOSS.LAMBDA_LIP\n_C.LOSS.LAMBDA_MAKEUP_SKIN = _C.LOSS.LAMBDA_MAKEUP * _C.LOSS.LAMBDA_SKIN\n_C.LOSS.LAMBDA_MAKEUP_EYE = _C.LOSS.LAMBDA_MAKEUP * _C.LOSS.LAMBDA_EYE\n_C.LOSS.LAMBDA_VGG = 5e-3\n\n# Model structure\n_C.MODEL = CfgNode()\n_C.MODEL.D_TYPE = 'SN'\n_C.MODEL.D_REPEAT_NUM = 3\n_C.MODEL.D_CONV_DIM = 64\n_C.MODEL.G_CONV_DIM = 64\n_C.MODEL.NUM_HEAD = 1\n_C.MODEL.DOUBLE_E = False\n_C.MODEL.USE_FF = False\n_C.MODEL.NUM_LAYER_E = 3\n_C.MODEL.NUM_LAYER_D = 2\n_C.MODEL.WINDOW_SIZE = 16\n_C.MODEL.MERGE_MODE = 'conv'\n\n# Preprocessing\n_C.PREPROCESS = CfgNode()\n_C.PREPROCESS.UP_RATIO = 0.6 / 0.85  # delta_size / face_size\n_C.PREPROCESS.DOWN_RATIO = 0.2 / 0.85  # delta_size / face_size\n_C.PREPROCESS.WIDTH_RATIO = 0.2 / 0.85  # delta_size / face_size\n_C.PREPROCESS.LIP_CLASS = [7, 9]\n_C.PREPROCESS.FACE_CLASS = [1, 6]\n_C.PREPROCESS.EYEBROW_CLASS = [2, 3]\n_C.PREPROCESS.EYE_CLASS = [4, 5]\n_C.PREPROCESS.LANDMARK_POINTS = 68\n\n# Pseudo ground truth\n_C.PGT = CfgNode()\n_C.PGT.EYE_MARGIN = 12\n_C.PGT.LIP_MARGIN = 4\n_C.PGT.ANNEALING = True\n_C.PGT.SKIN_ALPHA = 0.3\n_C.PGT.SKIN_ALPHA_MILESTONES = (0, 12, 24, 50)\n_C.PGT.SKIN_ALPHA_VALUES = (0.2, 0.4, 0.3, 0.2)\n_C.PGT.EYE_ALPHA = 0.8\n_C.PGT.EYE_ALPHA_MILESTONES = (0, 12, 24, 50)\n_C.PGT.EYE_ALPHA_VALUES = (0.6, 0.8, 0.6, 0.4)\n_C.PGT.LIP_ALPHA = 0.1\n_C.PGT.LIP_ALPHA_MILESTONES = (0, 12, 24, 50)\n_C.PGT.LIP_ALPHA_VALUES = (0.05, 0.2, 0.1, 0.0)\n\n# Postprocessing\n_C.POSTPROCESS = CfgNode()\n_C.POSTPROCESS.WILL_DENOISE = False\n\ndef get_config()->CfgNode:\n    return _C\n"
  },
  {
    "path": "training/dataset.py",
    "content": "import os\nfrom PIL import Image\nimport torch\nfrom torch.utils.data import Dataset, DataLoader\n\nfrom training.config import get_config\nfrom training.preprocess import PreProcess\n\nclass MakeupDataset(Dataset):\n    def __init__(self, config=None):\n        super(MakeupDataset, self).__init__()\n        if config is None:\n            config = get_config()\n        self.root = config.DATA.PATH\n        with open(os.path.join(config.DATA.PATH, 'makeup.txt'), 'r') as f:\n            self.makeup_names = [name.strip() for name in f.readlines()]\n        with open(os.path.join(config.DATA.PATH, 'non-makeup.txt'), 'r') as f:\n            self.non_makeup_names = [name.strip() for name in f.readlines()]\n        self.preprocessor = PreProcess(config, need_parser=False)\n        self.img_size = config.DATA.IMG_SIZE\n\n    def load_from_file(self, img_name):\n        image = Image.open(os.path.join(self.root, 'images', img_name)).convert('RGB')\n        mask = self.preprocessor.load_mask(os.path.join(self.root, 'segs', img_name))\n        base_name = os.path.splitext(img_name)[0]\n        lms = self.preprocessor.load_lms(os.path.join(self.root, 'lms', f'{base_name}.npy'))\n        return self.preprocessor.process(image, mask, lms)\n    \n    def __len__(self):\n        return max(len(self.makeup_names), len(self.non_makeup_names))\n\n    def __getitem__(self, index):\n        idx_s = torch.randint(0, len(self.non_makeup_names), (1, )).item()\n        idx_r = torch.randint(0, len(self.makeup_names), (1, )).item()\n        name_s = self.non_makeup_names[idx_s]\n        name_r = self.makeup_names[idx_r]\n        source = self.load_from_file(name_s)\n        reference = self.load_from_file(name_r)\n        return source, reference\n\ndef get_loader(config):\n    dataset = MakeupDataset(config)\n    dataloader = DataLoader(dataset=dataset,\n                            batch_size=config.DATA.BATCH_SIZE,\n                            num_workers=config.DATA.NUM_WORKERS)\n    return dataloader\n\n\nif __name__ == \"__main__\":\n    dataset = MakeupDataset()\n    dataloader = DataLoader(dataset, batch_size=1, num_workers=16)\n    for e in range(10):\n        for i, (point_s, point_r) in enumerate(dataloader):\n            pass"
  },
  {
    "path": "training/inference.py",
    "content": "from typing import List\nimport numpy as np\nimport cv2\nfrom PIL import Image\nimport torch\nimport torch.nn.functional as F\nfrom torchvision.transforms import ToPILImage\n\nfrom training.solver import Solver\nfrom training.preprocess import PreProcess\nfrom models.modules.pseudo_gt import expand_area, mask_blend\n\nclass InputSample:\n    def __init__(self, inputs, apply_mask=None):\n        self.inputs = inputs\n        self.transfer_input = None\n        self.attn_out_list = None\n        self.apply_mask = apply_mask\n\n    def clear(self):\n        self.transfer_input = None\n        self.attn_out_list = None\n\n\nclass Inference:\n    \"\"\"\n    An inference wrapper for makeup transfer.\n    It takes two image `source` and `reference` in,\n    and transfers the makeup of reference to source.\n    \"\"\"\n    def __init__(self, config, args, model_path=\"G.pth\"):\n\n        self.device = args.device\n        self.solver = Solver(config, args, inference=model_path)\n        self.preprocess = PreProcess(config, args.device)\n        self.denoise = config.POSTPROCESS.WILL_DENOISE\n        self.img_size = config.DATA.IMG_SIZE\n        # TODO: can be a hyper-parameter\n        self.eyeblur = {'margin': 12, 'blur_size':7}\n\n    def prepare_input(self, *data_inputs):\n        \"\"\"\n        data_inputs: List[image, mask, diff, lms]\n        \"\"\"\n        inputs = []\n        for i in range(len(data_inputs)):\n            inputs.append(data_inputs[i].to(self.device).unsqueeze(0))\n        # prepare mask\n        inputs[1] = torch.cat((inputs[1][:,0:1], inputs[1][:,1:].sum(dim=1, keepdim=True)), dim=1)\n        return inputs\n\n    def postprocess(self, source, crop_face, result):\n        if crop_face is not None:\n            source = source.crop(\n                (crop_face.left(), crop_face.top(), crop_face.right(), crop_face.bottom()))\n        source = np.array(source)\n        result = np.array(result)\n\n        height, width = source.shape[:2]\n        small_source = cv2.resize(source, (self.img_size, self.img_size))\n        laplacian_diff = source.astype(\n            np.float) - cv2.resize(small_source, (width, height)).astype(np.float)\n        result = (cv2.resize(result, (width, height)) +\n                  laplacian_diff).round().clip(0, 255)\n\n        result = result.astype(np.uint8)\n\n        if self.denoise:\n            result = cv2.fastNlMeansDenoisingColored(result)\n        result = Image.fromarray(result).convert('RGB')\n        return result\n\n    \n    def generate_source_sample(self, source_input):\n        \"\"\"\n        source_input: List[image, mask, diff, lms]\n        \"\"\"\n        source_input = self.prepare_input(*source_input)\n        return InputSample(source_input)\n\n    def generate_reference_sample(self, reference_input, apply_mask=None, \n                                  source_mask=None, mask_area=None, saturation=1.0):\n        \"\"\"\n        all the operations on the mask, e.g., partial mask, saturation, \n        should be finally defined in apply_mask\n        \"\"\"\n        if source_mask is not None and mask_area is not None:\n            apply_mask = self.generate_partial_mask(source_mask, mask_area, saturation)\n            apply_mask = apply_mask.unsqueeze(0).to(self.device)\n        reference_input = self.prepare_input(*reference_input)\n        \n        if apply_mask is None:\n            apply_mask = torch.ones(1, 1, self.img_size, self.img_size).to(self.device)\n        return InputSample(reference_input, 
apply_mask)\n\n\n    def generate_partial_mask(self, source_mask, mask_area='full', saturation=1.0):\n        \"\"\"\n        source_mask: (C, H, W), lip, face, left eye, right eye\n        return: apply_mask: (1, H, W)\n        \"\"\"\n        assert mask_area in ['full', 'skin', 'lip', 'eye']\n        if mask_area == 'full':\n            return torch.sum(source_mask[0:2], dim=0, keepdim=True) * saturation\n        elif mask_area == 'lip':\n            return source_mask[0:1] * saturation\n        elif mask_area == 'skin':\n            mask_l_eye = expand_area(source_mask[2:3], self.eyeblur['margin']) #* source_mask[1:2]\n            mask_r_eye = expand_area(source_mask[3:4], self.eyeblur['margin']) #* source_mask[1:2]\n            mask_eye = mask_l_eye + mask_r_eye\n            #mask_eye = mask_blend(mask_eye, 1.0, source_mask[1:2], blur_size=self.eyeblur['blur_size'])\n            mask_eye = mask_blend(mask_eye, 1.0, blur_size=self.eyeblur['blur_size'])\n            return source_mask[1:2] * (1 - mask_eye) * saturation\n        elif mask_area == 'eye':\n            mask_l_eye = expand_area(source_mask[2:3], self.eyeblur['margin']) #* source_mask[1:2]\n            mask_r_eye = expand_area(source_mask[3:4], self.eyeblur['margin']) #* source_mask[1:2]\n            mask_eye = mask_l_eye + mask_r_eye\n            #mask_eye = mask_blend(mask_eye, saturation, source_mask[1:2], blur_size=self.eyeblur['blur_size'])\n            mask_eye = mask_blend(mask_eye, saturation, blur_size=self.eyeblur['blur_size'])\n            return mask_eye\n  \n\n    @torch.no_grad()\n    def interface_transfer(self, source_sample: InputSample, reference_samples: List[InputSample]):\n        \"\"\"\n        Input: a source sample and multiple reference samples\n        Return: PIL.Image, the fused result\n        \"\"\"\n        # encode source\n        if source_sample.transfer_input is None:\n            source_sample.transfer_input = self.solver.G.get_transfer_input(*source_sample.inputs)\n        \n        # encode references\n        for r_sample in reference_samples:\n            if r_sample.transfer_input is None:\n                r_sample.transfer_input = self.solver.G.get_transfer_input(*r_sample.inputs, True)\n\n        # self attention\n        if source_sample.attn_out_list is None:\n            source_sample.attn_out_list = self.solver.G.get_transfer_output(\n                    *source_sample.transfer_input, *source_sample.transfer_input\n                )\n        \n        # full transfer for each reference\n        for r_sample in reference_samples:\n            if r_sample.attn_out_list is None:\n                r_sample.attn_out_list = self.solver.G.get_transfer_output(\n                    *source_sample.transfer_input, *r_sample.transfer_input\n                )\n\n        # fusion\n        # if the apply_mask is changed without changing source and references,\n        # only the following steps are required\n        fused_attn_out_list = []\n        for i in range(len(source_sample.attn_out_list)):\n            init_attn_out = torch.zeros_like(source_sample.attn_out_list[i], device=self.device)\n            fused_attn_out_list.append(init_attn_out)\n        apply_mask_sum = torch.zeros((1, 1, self.img_size, self.img_size), device=self.device)\n        \n        for r_sample in reference_samples:\n            if r_sample.apply_mask is not None:\n                apply_mask_sum += r_sample.apply_mask\n                for i in range(len(source_sample.attn_out_list)):\n                    
feature_size = r_sample.attn_out_list[i].shape[2]\n                    apply_mask = F.interpolate(r_sample.apply_mask, feature_size, mode='nearest')\n                    fused_attn_out_list[i] += apply_mask * r_sample.attn_out_list[i]\n\n        # self as reference\n        source_apply_mask = 1 - apply_mask_sum.clamp(0, 1)\n        for i in range(len(source_sample.attn_out_list)):\n            feature_size = source_sample.attn_out_list[i].shape[2]\n            apply_mask = F.interpolate(source_apply_mask, feature_size, mode='nearest')\n            fused_attn_out_list[i] += apply_mask * source_sample.attn_out_list[i]\n\n        # decode\n        result = self.solver.G.decode(\n            source_sample.transfer_input[0], fused_attn_out_list\n        )\n        result = self.solver.de_norm(result).squeeze(0)\n        result = ToPILImage()(result.cpu())\n        return result\n\n    \n    def transfer(self, source: Image, reference: Image, postprocess=True):\n        \"\"\"\n        Args:\n            source (Image): The image where makeup will be transfered to.\n            reference (Image): Image containing targeted makeup.\n        Return:\n            Image: Transfered image.\n        \"\"\"\n        source_input, face, crop_face = self.preprocess(source)\n        reference_input, _, _ = self.preprocess(reference)\n        if not (source_input and reference_input):\n            return None\n\n        #source_sample = self.generate_source_sample(source_input)\n        #reference_samples = [self.generate_reference_sample(reference_input)]\n        #result = self.interface_transfer(source_sample, reference_samples)\n        source_input = self.prepare_input(*source_input)\n        reference_input = self.prepare_input(*reference_input)\n        result = self.solver.test(*source_input, *reference_input)\n        \n        if not postprocess:\n            return result\n        else:\n            return self.postprocess(source, crop_face, result)\n\n    def joint_transfer(self, source: Image, reference_lip: Image, reference_skin: Image,\n                       reference_eye: Image, postprocess=True):\n        source_input, face, crop_face = self.preprocess(source)\n        lip_input, _, _ = self.preprocess(reference_lip)\n        skin_input, _, _ = self.preprocess(reference_skin)\n        eye_input, _, _ = self.preprocess(reference_eye)\n        if not (source_input and lip_input and skin_input and eye_input):\n            return None\n\n        source_mask = source_input[1]\n        source_sample = self.generate_source_sample(source_input)\n        reference_samples = [\n            self.generate_reference_sample(lip_input, source_mask=source_mask, mask_area='lip'),\n            self.generate_reference_sample(skin_input, source_mask=source_mask, mask_area='skin'),\n            self.generate_reference_sample(eye_input, source_mask=source_mask, mask_area='eye')\n        ]\n        \n        result = self.interface_transfer(source_sample, reference_samples)\n        \n        if not postprocess:\n            return result\n        else:\n            return self.postprocess(source, crop_face, result)"
  },
  {
    "path": "training/preprocess.py",
    "content": "import os\nimport sys\nimport cv2\nfrom PIL import Image\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torchvision import transforms\nfrom torchvision.transforms import functional\nsys.path.append('.')\n\nimport faceutils as futils\nfrom training.config import get_config\n\nclass PreProcess:\n\n    def __init__(self, config, need_parser=True, device='cpu'):\n        self.img_size = config.DATA.IMG_SIZE   \n        self.device = device\n\n        xs, ys = np.meshgrid(\n            np.linspace(\n                0, self.img_size - 1,\n                self.img_size\n            ),\n            np.linspace(\n                0, self.img_size - 1,\n                self.img_size\n            )\n        )\n        xs = xs[None].repeat(config.PREPROCESS.LANDMARK_POINTS, axis=0)\n        ys = ys[None].repeat(config.PREPROCESS.LANDMARK_POINTS, axis=0)\n        fix = np.concatenate([ys, xs], axis=0) \n        self.fix = torch.Tensor(fix) #(136, h, w)\n        if need_parser:\n            self.face_parse = futils.mask.FaceParser(device=device)\n\n        self.up_ratio    = config.PREPROCESS.UP_RATIO\n        self.down_ratio  = config.PREPROCESS.DOWN_RATIO\n        self.width_ratio = config.PREPROCESS.WIDTH_RATIO\n        self.lip_class   = config.PREPROCESS.LIP_CLASS\n        self.face_class  = config.PREPROCESS.FACE_CLASS\n        self.eyebrow_class  = config.PREPROCESS.EYEBROW_CLASS\n        self.eye_class  = config.PREPROCESS.EYE_CLASS\n\n        self.transform = transforms.Compose([\n            transforms.Resize(config.DATA.IMG_SIZE),\n            transforms.ToTensor(),\n            transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])])\n    \n    ############################## Mask Process ##############################\n    # mask attribute: 0:background 1:face 2:left-eyebrow 3:right-eyebrow 4:left-eye 5: right-eye 6: nose\n    # 7: upper-lip 8: teeth 9: under-lip 10:hair 11: left-ear 12: right-ear 13: neck\n    def mask_process(self, mask: torch.Tensor):\n        '''\n        mask: (1, h, w)\n        '''        \n        mask_lip = (mask == self.lip_class[0]).float() + (mask == self.lip_class[1]).float()\n        mask_face = (mask == self.face_class[0]).float() + (mask == self.face_class[1]).float()\n\n        #mask_eyebrow_left = (mask == self.eyebrow_class[0]).float()\n        #mask_eyebrow_right = (mask == self.eyebrow_class[1]).float()\n        mask_face += (mask == self.eyebrow_class[0]).float()\n        mask_face += (mask == self.eyebrow_class[1]).float()\n\n        mask_eye_left = (mask == self.eye_class[0]).float()\n        mask_eye_right = (mask == self.eye_class[1]).float()\n\n        #mask_list = [mask_lip, mask_face, mask_eyebrow_left, mask_eyebrow_right, mask_eye_left, mask_eye_right]\n        mask_list = [mask_lip, mask_face, mask_eye_left, mask_eye_right]\n        mask_aug = torch.cat(mask_list, 0) # (C, H, W)\n        return mask_aug      \n\n    def save_mask(self, mask: torch.Tensor, path):\n        assert mask.shape[0] == 1\n        mask = mask.squeeze(0).numpy().astype(np.uint8)\n        mask = Image.fromarray(mask)\n        mask.save(path)\n\n    def load_mask(self, path):\n        mask = np.array(Image.open(path).convert('L'))\n        mask = torch.FloatTensor(mask).unsqueeze(0)\n        mask = functional.resize(mask, self.img_size, transforms.InterpolationMode.NEAREST)\n        return mask\n    \n    ############################## Landmarks Process ##############################\n    def lms_process(self, image:Image):\n        face 
= futils.dlib.detect(image)\n        # face: rectangles, List of rectangles of face region: [(left, top), (right, bottom)]\n        if not face:\n            return None\n        face = face[0]\n        lms = futils.dlib.landmarks(image, face) * self.img_size / image.width # scale to fit self.img_size\n        # lms: ndarray, the positions of the 68 key points, (68, 2)\n        lms = torch.IntTensor(lms.round()).clamp_max_(self.img_size - 1)\n        # distinguish upper and lower lips \n        lms[61:64,0] -= 1; lms[65:68,0] += 1\n        for i in range(3):\n            if torch.sum(torch.abs(lms[61+i] - lms[67-i])) == 0:\n                lms[61+i,0] -= 1;  lms[67-i,0] += 1\n        # double check\n        '''for i in range(48, 67):\n            for j in range(i+1, 68):\n                if torch.sum(torch.abs(lms[i] - lms[j])) == 0:\n                    lms[i,0] -= 1; lms[j,0] += 1'''\n        return lms       \n    \n    def diff_process(self, lms: torch.Tensor, normalize=False):\n        '''\n        lms:(68, 2)\n        '''\n        lms = lms.transpose(1, 0).reshape(-1, 1, 1) # (136, 1, 1)\n        diff = self.fix - lms # (136, h, w)\n\n        if normalize:\n            norm = torch.norm(diff, dim=0, keepdim=True).repeat(diff.shape[0], 1, 1)\n            norm = torch.where(norm == 0, torch.tensor(1e10), norm)\n            diff /= norm\n        return diff\n\n    def save_lms(self, lms: torch.Tensor, path):\n        lms = lms.numpy()\n        np.save(path, lms)\n    \n    def load_lms(self, path):\n        lms = np.load(path)\n        return torch.IntTensor(lms)\n\n    ############################## Compose Process ##############################\n    def preprocess(self, image: Image, is_crop=True):\n        '''\n        return: image: Image, (H, W), mask: tensor, (1, H, W)\n        '''\n        face = futils.dlib.detect(image)\n        # face: rectangles, List of rectangles of face region: [(left, top), (right, bottom)]\n        if not face:\n            return None, None, None\n\n        face_on_image = face[0]\n        if is_crop:\n            image, face, crop_face = futils.dlib.crop(\n                image, face_on_image, self.up_ratio, self.down_ratio, self.width_ratio)\n        else:\n            face = face[0]; crop_face = None\n        # image: Image, cropped face\n        # face: the same as above\n        # crop face: rectangle, face region in cropped face\n        np_image = np.array(image) # (h', w', 3)\n\n        mask = self.face_parse.parse(cv2.resize(np_image, (512, 512))).cpu()\n        # obtain face parsing result\n        # mask: Tensor, (512, 512)\n        mask = F.interpolate(\n            mask.view(1, 1, 512, 512),\n            (self.img_size, self.img_size),\n            mode=\"nearest\").squeeze(0).long() #(1, H, W)\n\n        lms = futils.dlib.landmarks(image, face) * self.img_size / image.width # scale to fit self.img_size\n        # lms: ndarray, the positions of the 68 key points, (68, 2)\n        lms = torch.IntTensor(lms.round()).clamp_max_(self.img_size - 1)\n        # distinguish upper and lower lips \n        lms[61:64,0] -= 1; lms[65:68,0] += 1\n        for i in range(3):\n            if torch.sum(torch.abs(lms[61+i] - lms[67-i])) == 0:\n                lms[61+i,0] -= 1;  lms[67-i,0] += 1\n\n        image = image.resize((self.img_size, self.img_size), Image.LANCZOS)  # ANTIALIAS was removed in Pillow 10; LANCZOS is equivalent\n        return [image, mask, lms], face_on_image, crop_face\n    \n    def process(self, image: Image, mask: torch.Tensor, lms: torch.Tensor):\n        image = self.transform(image)\n        mask = 
self.mask_process(mask)\n        diff = self.diff_process(lms)\n        return [image, mask, diff, lms]\n    \n    def __call__(self, image:Image, is_crop=True):\n        source, face_on_image, crop_face = self.preprocess(image, is_crop)\n        if source is None:\n            return None, None, None\n        return self.process(*source), face_on_image, crop_face\n\n\nif __name__ == \"__main__\":\n    config = get_config()\n    preprocessor = PreProcess(config, device='cuda:0')\n    if not os.path.exists(os.path.join(config.DATA.PATH, 'lms')):\n        os.makedirs(os.path.join(config.DATA.PATH, 'lms', 'makeup'))\n        os.makedirs(os.path.join(config.DATA.PATH, 'lms', 'non-makeup'))\n    \n    # process makeup images\n    print(\"Processing makeup images...\")\n    with open(os.path.join(config.DATA.PATH, 'makeup.txt'), 'r') as f:\n        for line in f.readlines():\n            img_name = line.strip()\n            raw_image = Image.open(os.path.join(config.DATA.PATH, 'images', img_name)).convert('RGB')\n            lms = preprocessor.lms_process(raw_image)\n            if lms is not None:\n                base_name = os.path.splitext(img_name)[0]\n                preprocessor.save_lms(lms, os.path.join(config.DATA.PATH, 'lms', f'{base_name}.npy'))\n    print(\"Done.\")\n\n    # process non-makeup images\n    print(\"Processing non-makeup images...\")\n    with open(os.path.join(config.DATA.PATH, 'non-makeup.txt'), 'r') as f:\n        for line in f.readlines():\n            img_name = line.strip()\n            raw_image = Image.open(os.path.join(config.DATA.PATH, 'images', img_name)).convert('RGB')\n            lms = preprocessor.lms_process(raw_image)\n            if lms is not None:\n                base_name = os.path.splitext(img_name)[0]\n                preprocessor.save_lms(lms, os.path.join(config.DATA.PATH, 'lms', f'{base_name}.npy'))\n    print(\"Done.\")\n    "
  },
  {
    "path": "training/solver.py",
    "content": "import os\nimport time\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torchvision.transforms import ToPILImage\nfrom torchvision.utils import save_image, make_grid\nimport torch.nn.init as init\nfrom tqdm import tqdm\n\nfrom models.modules.pseudo_gt import expand_area\nfrom models.model import get_discriminator, get_generator, vgg16\nfrom models.loss import GANLoss, MakeupLoss, ComposePGT, AnnealingComposePGT\n\nfrom training.utils import plot_curves\n\nclass Solver():\n    def __init__(self, config, args, logger=None, inference=False):\n        self.G = get_generator(config)\n        if inference:\n            self.G.load_state_dict(torch.load(inference, map_location=args.device))\n            self.G = self.G.to(args.device).eval()\n            return\n        self.double_d = config.TRAINING.DOUBLE_D\n        self.D_A = get_discriminator(config)\n        if self.double_d:\n            self.D_B = get_discriminator(config)\n        \n        self.load_folder = args.load_folder\n        self.save_folder = args.save_folder\n        self.vis_folder = os.path.join(args.save_folder, 'visualization')\n        if not os.path.exists(self.vis_folder):\n            os.makedirs(self.vis_folder)\n        self.vis_freq = config.LOG.VIS_FREQ\n        self.save_freq = config.LOG.SAVE_FREQ\n\n        # Data & PGT\n        self.img_size = config.DATA.IMG_SIZE\n        self.margins = {'eye':config.PGT.EYE_MARGIN,\n                        'lip':config.PGT.LIP_MARGIN}\n        self.pgt_annealing = config.PGT.ANNEALING\n        if self.pgt_annealing:\n            self.pgt_maker = AnnealingComposePGT(self.margins, \n                config.PGT.SKIN_ALPHA_MILESTONES, config.PGT.SKIN_ALPHA_VALUES,\n                config.PGT.EYE_ALPHA_MILESTONES, config.PGT.EYE_ALPHA_VALUES,\n                config.PGT.LIP_ALPHA_MILESTONES, config.PGT.LIP_ALPHA_VALUES\n            )\n        else:\n            self.pgt_maker = ComposePGT(self.margins, \n                config.PGT.SKIN_ALPHA,\n                config.PGT.EYE_ALPHA,\n                config.PGT.LIP_ALPHA\n            )\n        self.pgt_maker.eval()\n\n        # Hyper-param\n        self.num_epochs = config.TRAINING.NUM_EPOCHS\n        self.g_lr = config.TRAINING.G_LR\n        self.d_lr = config.TRAINING.D_LR\n        self.beta1 = config.TRAINING.BETA1\n        self.beta2 = config.TRAINING.BETA2\n        self.lr_decay_factor = config.TRAINING.LR_DECAY_FACTOR\n\n        # Loss param\n        self.lambda_idt      = config.LOSS.LAMBDA_IDT\n        self.lambda_A        = config.LOSS.LAMBDA_A\n        self.lambda_B        = config.LOSS.LAMBDA_B\n        self.lambda_lip  = config.LOSS.LAMBDA_MAKEUP_LIP\n        self.lambda_skin = config.LOSS.LAMBDA_MAKEUP_SKIN\n        self.lambda_eye  = config.LOSS.LAMBDA_MAKEUP_EYE\n        self.lambda_vgg      = config.LOSS.LAMBDA_VGG\n\n        self.device = args.device\n        self.keepon = args.keepon\n        self.logger = logger\n        self.loss_logger = {\n            'D-A-loss_real':[],\n            'D-A-loss_fake':[],\n            'D-B-loss_real':[],\n            'D-B-loss_fake':[],\n            'G-A-loss-adv':[],\n            'G-B-loss-adv':[],\n            'G-loss-idt':[],\n            'G-loss-img-rec':[],\n            'G-loss-vgg-rec':[],\n            'G-loss-rec':[],\n            'G-loss-skin-pgt':[],\n            'G-loss-eye-pgt':[],\n            'G-loss-lip-pgt':[],\n            'G-loss-pgt':[],\n            'G-loss':[],\n            'D-A-loss':[],\n         
   'D-B-loss':[]\n        }\n\n        self.build_model()\n        super(Solver, self).__init__()\n\n    def print_network(self, model, name):\n        num_params = 0\n        for p in model.parameters():\n            num_params += p.numel()\n        if self.logger is not None:\n            self.logger.info('{:s}, the number of parameters: {:d}'.format(name, num_params))\n        else:\n            print('{:s}, the number of parameters: {:d}'.format(name, num_params))\n    \n    # For generator\n    def weights_init_xavier(self, m):\n        classname = m.__class__.__name__\n        if classname.find('Conv') != -1:\n            init.xavier_normal_(m.weight.data, gain=1.0)\n        elif classname.find('Linear') != -1:\n            init.xavier_normal_(m.weight.data, gain=1.0)\n\n    def build_model(self):\n        self.G.apply(self.weights_init_xavier)\n        self.D_A.apply(self.weights_init_xavier)\n        if self.double_d:\n            self.D_B.apply(self.weights_init_xavier)\n        if self.keepon:\n            self.load_checkpoint()\n        \n        self.criterionL1 = torch.nn.L1Loss()\n        self.criterionL2 = torch.nn.MSELoss()\n        self.criterionGAN = GANLoss(gan_mode='lsgan')\n        self.criterionPGT = MakeupLoss()\n        self.vgg = vgg16(pretrained=True)\n\n        # Optimizers\n        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])\n        self.d_A_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D_A.parameters()), self.d_lr, [self.beta1, self.beta2])\n        if self.double_d:\n            self.d_B_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D_B.parameters()), self.d_lr, [self.beta1, self.beta2])\n        self.g_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.g_optimizer, \n                    T_max=self.num_epochs, eta_min=self.g_lr * self.lr_decay_factor)\n        self.d_A_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.d_A_optimizer, \n                    T_max=self.num_epochs, eta_min=self.d_lr * self.lr_decay_factor)\n        if self.double_d:\n            self.d_B_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.d_B_optimizer, \n                    T_max=self.num_epochs, eta_min=self.d_lr * self.lr_decay_factor)\n\n        # Print networks\n        self.print_network(self.G, 'G')\n        self.print_network(self.D_A, 'D_A')\n        if self.double_d: self.print_network(self.D_B, 'D_B')\n\n        self.G.to(self.device)\n        self.vgg.to(self.device)\n        self.D_A.to(self.device)\n        if self.double_d: self.D_B.to(self.device)\n\n    def train(self, data_loader):\n        self.len_dataset = len(data_loader)\n        \n        for self.epoch in range(1, self.num_epochs + 1):\n            self.start_time = time.time()\n            loss_tmp = self.get_loss_tmp()\n            self.G.train(); self.D_A.train(); \n            if self.double_d: self.D_B.train()\n            losses_G = []; losses_D_A = []; losses_D_B = []\n            \n            with tqdm(data_loader, desc=\"training\") as pbar:\n                for step, (source, reference) in enumerate(pbar):\n                    # image, mask, diff, lms\n                    image_s, image_r = source[0].to(self.device), reference[0].to(self.device) # (b, c, h, w)\n                    mask_s_full, mask_r_full = source[1].to(self.device), reference[1].to(self.device) # (b, c', h, w) \n                    diff_s, diff_r = source[2].to(self.device), reference[2].to(self.device) # (b, 
136, h, w)\n                    lms_s, lms_r = source[3].to(self.device), reference[3].to(self.device) # (b, K, 2)\n\n                    # process input mask\n                    mask_s = torch.cat((mask_s_full[:,0:1], mask_s_full[:,1:].sum(dim=1, keepdim=True)), dim=1)\n                    mask_r = torch.cat((mask_r_full[:,0:1], mask_r_full[:,1:].sum(dim=1, keepdim=True)), dim=1)\n                    #mask_s = mask_s_full[:,:2]; mask_r = mask_r_full[:,:2]\n\n                    # ================= Generate ================== #\n                    fake_A = self.G(image_s, image_r, mask_s, mask_r, diff_s, diff_r, lms_s, lms_r)\n                    fake_B = self.G(image_r, image_s, mask_r, mask_s, diff_r, diff_s, lms_r, lms_s)\n\n                    # generate pseudo ground truth\n                    pgt_A = self.pgt_maker(image_s, image_r, mask_s_full, mask_r_full, lms_s, lms_r)\n                    pgt_B = self.pgt_maker(image_r, image_s, mask_r_full, mask_s_full, lms_r, lms_s)\n                    \n                    # ================== Train D ================== #\n                    # training D_A, D_A aims to distinguish class B\n                    # Real\n                    out = self.D_A(image_r)\n                    d_loss_real = self.criterionGAN(out, True)\n                    # Fake\n                    out = self.D_A(fake_A.detach())\n                    d_loss_fake =  self.criterionGAN(out, False)\n\n                    # Backward + Optimize\n                    d_loss = (d_loss_real + d_loss_fake) * 0.5\n                    self.d_A_optimizer.zero_grad()\n                    d_loss.backward()\n                    self.d_A_optimizer.step()                   \n\n                    # Logging\n                    loss_tmp['D-A-loss_real'] += d_loss_real.item()\n                    loss_tmp['D-A-loss_fake'] += d_loss_fake.item()\n                    losses_D_A.append(d_loss.item())\n\n                    # training D_B, D_B aims to distinguish class A\n                    # Real\n                    if self.double_d:\n                        out = self.D_B(image_s)\n                    else:\n                        out = self.D_A(image_s)\n                    d_loss_real = self.criterionGAN(out, True)\n                    # Fake\n                    if self.double_d:\n                        out = self.D_B(fake_B.detach())\n                    else:\n                        out = self.D_A(fake_B.detach())\n                    d_loss_fake =  self.criterionGAN(out, False)\n\n                    # Backward + Optimize\n                    d_loss = (d_loss_real+ d_loss_fake) * 0.5\n                    if self.double_d:\n                        self.d_B_optimizer.zero_grad()\n                        d_loss.backward()\n                        self.d_B_optimizer.step()\n                    else:\n                        self.d_A_optimizer.zero_grad()\n                        d_loss.backward()\n                        self.d_A_optimizer.step()\n\n                    # Logging\n                    loss_tmp['D-B-loss_real'] += d_loss_real.item()\n                    loss_tmp['D-B-loss_fake'] += d_loss_fake.item()\n                    losses_D_B.append(d_loss.item())\n\n                    # ================== Train G ================== #\n                    \n                    # G should be identity if ref_B or org_A is fed\n                    idt_A = self.G(image_s, image_s, mask_s, mask_s, diff_s, diff_s, lms_s, lms_s)\n                    idt_B = self.G(image_r, image_r, 
mask_r, mask_r, diff_r, diff_r, lms_r, lms_r)\n                    loss_idt_A = self.criterionL1(idt_A, image_s) * self.lambda_A * self.lambda_idt\n                    loss_idt_B = self.criterionL1(idt_B, image_r) * self.lambda_B * self.lambda_idt\n                    # loss_idt\n                    loss_idt = (loss_idt_A + loss_idt_B) * 0.5\n\n                    # GAN loss D_A(G_A(A))\n                    pred_fake = self.D_A(fake_A)\n                    g_A_loss_adv = self.criterionGAN(pred_fake, True)\n\n                    # GAN loss D_B(G_B(B))\n                    if self.double_d:\n                        pred_fake = self.D_B(fake_B)\n                    else:\n                        pred_fake = self.D_A(fake_B)\n                    g_B_loss_adv = self.criterionGAN(pred_fake, True)\n                    \n                    # Makeup loss\n                    g_A_loss_pgt = 0; g_B_loss_pgt = 0\n                    \n                    g_A_lip_loss_pgt = self.criterionPGT(fake_A, pgt_A, mask_s_full[:,0:1]) * self.lambda_lip\n                    g_B_lip_loss_pgt = self.criterionPGT(fake_B, pgt_B, mask_r_full[:,0:1]) * self.lambda_lip\n                    g_A_loss_pgt += g_A_lip_loss_pgt\n                    g_B_loss_pgt += g_B_lip_loss_pgt\n\n                    mask_s_eye = expand_area(mask_s_full[:,2:4].sum(dim=1, keepdim=True), self.margins['eye'])\n                    mask_r_eye = expand_area(mask_r_full[:,2:4].sum(dim=1, keepdim=True), self.margins['eye'])\n                    mask_s_eye = mask_s_eye * mask_s_full[:,1:2]\n                    mask_r_eye = mask_r_eye * mask_r_full[:,1:2]\n                    g_A_eye_loss_pgt = self.criterionPGT(fake_A, pgt_A, mask_s_eye) * self.lambda_eye\n                    g_B_eye_loss_pgt = self.criterionPGT(fake_B, pgt_B, mask_r_eye) * self.lambda_eye\n                    g_A_loss_pgt += g_A_eye_loss_pgt\n                    g_B_loss_pgt += g_B_eye_loss_pgt\n                    \n                    mask_s_skin = mask_s_full[:,1:2] * (1 - mask_s_eye)\n                    mask_r_skin = mask_r_full[:,1:2] * (1 - mask_r_eye)\n                    g_A_skin_loss_pgt = self.criterionPGT(fake_A, pgt_A, mask_s_skin) * self.lambda_skin\n                    g_B_skin_loss_pgt = self.criterionPGT(fake_B, pgt_B, mask_r_skin) * self.lambda_skin\n                    g_A_loss_pgt += g_A_skin_loss_pgt\n                    g_B_loss_pgt += g_B_skin_loss_pgt\n                    \n                    # cycle loss\n                    rec_A = self.G(fake_A, image_s, mask_s, mask_s, diff_s, diff_s, lms_s, lms_s)\n                    rec_B = self.G(fake_B, image_r, mask_r, mask_r, diff_r, diff_r, lms_r, lms_r)\n\n                    # cycle loss v2\n                    # rec_A = self.G(fake_A, fake_B, mask_s, mask_r, diff_s, diff_r, lms_s, lms_r)\n                    # rec_B = self.G(fake_B, fake_A, mask_r, mask_s, diff_r, diff_s, lms_r, lms_s)\n\n                    g_loss_rec_A = self.criterionL1(rec_A, image_s) * self.lambda_A\n                    g_loss_rec_B = self.criterionL1(rec_B, image_r) * self.lambda_B\n\n                    # vgg loss\n                    vgg_s = self.vgg(image_s).detach()\n                    vgg_fake_A = self.vgg(fake_A)\n                    g_loss_A_vgg = self.criterionL2(vgg_fake_A, vgg_s) * self.lambda_A * self.lambda_vgg\n\n                    vgg_r = self.vgg(image_r).detach()\n                    vgg_fake_B = self.vgg(fake_B)\n                    g_loss_B_vgg = self.criterionL2(vgg_fake_B, vgg_r) * self.lambda_B * 
self.lambda_vgg\n\n                    loss_rec = (g_loss_rec_A + g_loss_rec_B + g_loss_A_vgg + g_loss_B_vgg) * 0.5\n\n                    # Combined loss\n                    g_loss = g_A_loss_adv + g_B_loss_adv + loss_rec + loss_idt + g_A_loss_pgt + g_B_loss_pgt\n\n                    self.g_optimizer.zero_grad()\n                    g_loss.backward()\n                    self.g_optimizer.step()\n\n                    # Logging\n                    loss_tmp['G-A-loss-adv'] += g_A_loss_adv.item()\n                    loss_tmp['G-B-loss-adv'] += g_B_loss_adv.item()\n                    loss_tmp['G-loss-idt'] += loss_idt.item()\n                    loss_tmp['G-loss-img-rec'] += (g_loss_rec_A + g_loss_rec_B).item() * 0.5\n                    loss_tmp['G-loss-vgg-rec'] += (g_loss_A_vgg + g_loss_B_vgg).item() * 0.5\n                    loss_tmp['G-loss-rec'] += loss_rec.item()\n                    loss_tmp['G-loss-skin-pgt'] += (g_A_skin_loss_pgt + g_B_skin_loss_pgt).item()\n                    loss_tmp['G-loss-eye-pgt'] += (g_A_eye_loss_pgt + g_B_eye_loss_pgt).item()\n                    loss_tmp['G-loss-lip-pgt'] += (g_A_lip_loss_pgt + g_B_lip_loss_pgt).item()\n                    loss_tmp['G-loss-pgt'] += (g_A_loss_pgt + g_B_loss_pgt).item()\n                    losses_G.append(g_loss.item())\n                    pbar.set_description(\"Epoch: %d, Step: %d, Loss_G: %0.4f, Loss_A: %0.4f, Loss_B: %0.4f\" % \\\n                                (self.epoch, step + 1, np.mean(losses_G), np.mean(losses_D_A), np.mean(losses_D_B)))\n\n            self.end_time = time.time()\n            for k, v in loss_tmp.items():\n                loss_tmp[k] = v / self.len_dataset  \n            loss_tmp['G-loss'] = np.mean(losses_G)\n            loss_tmp['D-A-loss'] = np.mean(losses_D_A)\n            loss_tmp['D-B-loss'] = np.mean(losses_D_B)\n            self.log_loss(loss_tmp)\n            self.plot_loss()\n\n            # Decay learning rate\n            self.g_scheduler.step()\n            self.d_A_scheduler.step()\n            if self.double_d:\n                self.d_B_scheduler.step()\n\n            if self.pgt_annealing:\n                self.pgt_maker.step()\n\n            #save the images\n            if (self.epoch) % self.vis_freq == 0:\n                self.vis_train([image_s.detach().cpu(), \n                                image_r.detach().cpu(), \n                                fake_A.detach().cpu(), \n                                pgt_A.detach().cpu()])\n            #                   rec_A.detach().cpu()])\n\n            # Save model checkpoints\n            if (self.epoch) % self.save_freq == 0:\n                self.save_models()\n   \n\n    def get_loss_tmp(self):\n        loss_tmp = {\n            'D-A-loss_real':0.0,\n            'D-A-loss_fake':0.0,\n            'D-B-loss_real':0.0,\n            'D-B-loss_fake':0.0,\n            'G-A-loss-adv':0.0,\n            'G-B-loss-adv':0.0,\n            'G-loss-idt':0.0,\n            'G-loss-img-rec':0.0,\n            'G-loss-vgg-rec':0.0,\n            'G-loss-rec':0.0,\n            'G-loss-skin-pgt':0.0,\n            'G-loss-eye-pgt':0.0,\n            'G-loss-lip-pgt':0.0,\n            'G-loss-pgt':0.0,\n        }\n        return loss_tmp\n\n    def log_loss(self, loss_tmp):\n        if self.logger is not None:\n            self.logger.info('\\n' + '='*40 + '\\nEpoch {:d}, time {:.2f} s'\n                            .format(self.epoch, self.end_time - self.start_time))\n        else:\n            print('\\n' + '='*40 + '\\nEpoch {:d}, time {:d} 
s'\n                    .format(self.epoch, self.end_time - self.start_time))\n        for k, v in loss_tmp.items():\n            self.loss_logger[k].append(v)\n            if self.logger is not None:\n                self.logger.info('{:s}\\t{:.6f}'.format(k, v))  \n            else:\n                print('{:s}\\t{:.6f}'.format(k, v))  \n        if self.logger is not None:\n            self.logger.info('='*40)  \n        else:\n            print('='*40)\n\n    def plot_loss(self):\n        G_losses = []; G_names = []\n        D_A_losses = []; D_A_names = []\n        D_B_losses = []; D_B_names = []\n        D_P_losses = []; D_P_names = []\n        for k, v in self.loss_logger.items():\n            if 'G' in k:\n                G_names.append(k); G_losses.append(v)\n            elif 'D-A' in k:\n                D_A_names.append(k); D_A_losses.append(v)\n            elif 'D-B' in k:\n                D_B_names.append(k); D_B_losses.append(v)\n            elif 'D-P' in k:\n                D_P_names.append(k); D_P_losses.append(v)\n        plot_curves(self.save_folder, 'G_loss', G_losses, G_names, ylabel='Loss')\n        plot_curves(self.save_folder, 'D-A_loss', D_A_losses, D_A_names, ylabel='Loss')\n        plot_curves(self.save_folder, 'D-B_loss', D_B_losses, D_B_names, ylabel='Loss')\n\n    def load_checkpoint(self):\n        G_path = os.path.join(self.load_folder, 'G.pth')\n        if os.path.exists(G_path):\n            self.G.load_state_dict(torch.load(G_path, map_location=self.device))\n            print('loaded trained generator {}..!'.format(G_path))\n        D_A_path = os.path.join(self.load_folder, 'D_A.pth')\n        if os.path.exists(D_A_path):\n            self.D_A.load_state_dict(torch.load(D_A_path, map_location=self.device))\n            print('loaded trained discriminator A {}..!'.format(D_A_path))\n\n        if self.double_d:\n            D_B_path = os.path.join(self.load_folder, 'D_B.pth')\n            if os.path.exists(D_B_path):\n                self.D_B.load_state_dict(torch.load(D_B_path, map_location=self.device))\n                print('loaded trained discriminator B {}..!'.format(D_B_path))\n    \n    def save_models(self):\n        save_dir = os.path.join(self.save_folder, 'epoch_{:d}'.format(self.epoch))\n        if not os.path.exists(save_dir):\n            os.makedirs(save_dir)\n        torch.save(self.G.state_dict(), os.path.join(save_dir, 'G.pth'))\n        torch.save(self.D_A.state_dict(), os.path.join(save_dir, 'D_A.pth'))\n        if self.double_d:\n            torch.save(self.D_B.state_dict(), os.path.join(save_dir, 'D_B.pth'))\n\n    def de_norm(self, x):\n        out = (x + 1) / 2\n        return out.clamp(0, 1)\n    \n    def vis_train(self, img_train_batch):\n        # saving training results\n        img_train_batch = torch.cat(img_train_batch, dim=3)\n        save_path = os.path.join(self.vis_folder, 'epoch_{:d}_fake.png'.format(self.epoch))\n        vis_image = make_grid(self.de_norm(img_train_batch), 1)\n        save_image(vis_image, save_path) #, normalize=True)\n\n    def generate(self, image_A, image_B, mask_A=None, mask_B=None, \n                 diff_A=None, diff_B=None, lms_A=None, lms_B=None):\n        \"\"\"image_A is content, image_B is style\"\"\"\n        with torch.no_grad():\n            res = self.G(image_A, image_B, mask_A, mask_B, diff_A, diff_B, lms_A, lms_B)\n        return res\n\n    def test(self, image_A, mask_A, diff_A, lms_A, image_B, mask_B, diff_B, lms_B):        \n        with torch.no_grad():\n            fake_A = 
self.generate(image_A, image_B, mask_A, mask_B, diff_A, diff_B, lms_A, lms_B)\n        fake_A = self.de_norm(fake_A)\n        fake_A = fake_A.squeeze(0)\n        return ToPILImage()(fake_A.cpu())"
  },
  {
    "path": "training/utils.py",
    "content": "import os\nimport logging\nimport numpy as np\nimport matplotlib.pyplot as plt\n\ndef create_logger(save_path='', file_type='', level='debug', console=True):\n    if level == 'debug':\n        _level = logging.DEBUG\n    elif level == 'info':\n        _level = logging.INFO\n\n    logger = logging.getLogger()\n    logger.setLevel(_level)\n\n    if console:\n        cs = logging.StreamHandler()\n        cs.setLevel(_level)\n        logger.addHandler(cs)\n\n    if save_path != '':\n        file_name = os.path.join(save_path, file_type + '_log.txt')\n        fh = logging.FileHandler(file_name, mode='w')\n        fh.setLevel(_level)\n\n        logger.addHandler(fh)\n\n    return logger\n\ndef print_args(args, logger=None):\n    for k, v in vars(args).items():\n        if logger is not None:\n            logger.info('{:<16} : {}'.format(k, v))\n        else:\n            print('{:<16} : {}'.format(k, v))\n\ndef plot_single_curve(path, name, point, freq=1, xlabel='Epoch',ylabel=None):\n    \n    x = (np.arange(len(point)) + 1) * freq\n    plt.plot(x, point, color='purple')\n    plt.xlabel(xlabel)\n    if ylabel is None:\n        ylabel = name\n    plt.ylabel(ylabel)\n    plt.savefig(os.path.join(path, name + '.png'))\n    plt.close()\n\ndef plot_curves(path, name, point_list, curve_names=None, freq=1, xlabel='Epoch',ylabel=None):\n    if curve_names is None:\n        curve_names = [''] * len(point_list)\n    else:\n        assert len(point_list) == len(curve_names)\n\n    x = (np.arange(len(point_list[0])) + 1) * freq\n    if len(point_list) <= 10:\n        cmap = plt.get_cmap('tab10')\n    else:\n        cmap = plt.get_cmap('tab20')\n    for i, (point, curve_name) in enumerate(zip(point_list, curve_names)):\n        assert len(point) == len(x)\n        plt.plot(x, point, color=cmap(i), label=curve_name)\n        \n    plt.xlabel(xlabel)\n    if ylabel is not None:\n        plt.ylabel(ylabel)\n    plt.legend()\n    plt.savefig(os.path.join(path, name + '.png'))\n    plt.close()"
  }
]