Full Code of Chenyu-Yang-2000/EleGANt

Repository: Chenyu-Yang-2000/EleGANt
Branch: main
Commit: d033d3398751
Files: 37
Total size: 160.8 KB

Directory structure:
EleGANt/

├── .gitignore
├── LICENSE
├── README.md
├── assets/
│   └── docs/
│       ├── install.md
│       └── prepare.md
├── concern/
│   ├── __init__.py
│   ├── image.py
│   ├── track.py
│   └── visualize.py
├── faceutils/
│   ├── __init__.py
│   ├── dlibutils/
│   │   ├── __init__.py
│   │   └── main.py
│   └── mask/
│       ├── __init__.py
│       ├── main.py
│       ├── model.py
│       └── resnet.py
├── models/
│   ├── __init__.py
│   ├── elegant.py
│   ├── loss.py
│   ├── model.py
│   └── modules/
│       ├── __init__.py
│       ├── histogram_matching.py
│       ├── module_attn.py
│       ├── module_base.py
│       ├── pseudo_gt.py
│       ├── sow_attention.py
│       ├── spectral_norm.py
│       └── tps_transform.py
├── scripts/
│   ├── demo.py
│   └── train.py
└── training/
    ├── __init__.py
    ├── config.py
    ├── dataset.py
    ├── inference.py
    ├── preprocess.py
    ├── solver.py
    └── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
/data
/results
*.dat
*.pth
*.pt

# *.txt
# !/expe/*/*.txt
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so



# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
.idea/

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
.vscode/settings.json


================================================
FILE: LICENSE
================================================
Attribution-NonCommercial-ShareAlike 4.0 International

=======================================================================

Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.

Using Creative Commons Public Licenses

Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.

     Considerations for licensors: Our public licenses are
     intended for use by those authorized to give the public
     permission to use material in ways otherwise restricted by
     copyright and certain other rights. Our licenses are
     irrevocable. Licensors should read and understand the terms
     and conditions of the license they choose before applying it.
     Licensors should also secure all rights necessary before
     applying our licenses so that the public can reuse the
     material as expected. Licensors should clearly mark any
     material not subject to the license. This includes other CC-
     licensed material, or material used under an exception or
     limitation to copyright. More considerations for licensors:
    wiki.creativecommons.org/Considerations_for_licensors

     Considerations for the public: By using one of our public
     licenses, a licensor grants the public permission to use the
     licensed material under specified terms and conditions. If
     the licensor's permission is not necessary for any reason--for
     example, because of any applicable exception or limitation to
     copyright--then that use is not regulated by the license. Our
     licenses grant only permissions under copyright and certain
     other rights that a licensor has authority to grant. Use of
     the licensed material may still be restricted for other
     reasons, including because others have copyright or other
     rights in the material. A licensor may make special requests,
     such as asking that all changes be marked or described.
     Although not required by our licenses, you are encouraged to
     respect those requests where reasonable. More considerations
     for the public:
    wiki.creativecommons.org/Considerations_for_licensees

=======================================================================

Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
Public License

By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution-NonCommercial-ShareAlike 4.0 International Public License
("Public License"). To the extent this Public License may be
interpreted as a contract, You are granted the Licensed Rights in
consideration of Your acceptance of these terms and conditions, and the
Licensor grants You such rights in consideration of benefits the
Licensor receives from making the Licensed Material available under
these terms and conditions.


Section 1 -- Definitions.

  a. Adapted Material means material subject to Copyright and Similar
     Rights that is derived from or based upon the Licensed Material
     and in which the Licensed Material is translated, altered,
     arranged, transformed, or otherwise modified in a manner requiring
     permission under the Copyright and Similar Rights held by the
     Licensor. For purposes of this Public License, where the Licensed
     Material is a musical work, performance, or sound recording,
     Adapted Material is always produced where the Licensed Material is
     synched in timed relation with a moving image.

  b. Adapter's License means the license You apply to Your Copyright
     and Similar Rights in Your contributions to Adapted Material in
     accordance with the terms and conditions of this Public License.

  c. BY-NC-SA Compatible License means a license listed at
     creativecommons.org/compatiblelicenses, approved by Creative
     Commons as essentially the equivalent of this Public License.

  d. Copyright and Similar Rights means copyright and/or similar rights
     closely related to copyright including, without limitation,
     performance, broadcast, sound recording, and Sui Generis Database
     Rights, without regard to how the rights are labeled or
     categorized. For purposes of this Public License, the rights
     specified in Section 2(b)(1)-(2) are not Copyright and Similar
     Rights.

  e. Effective Technological Measures means those measures that, in the
     absence of proper authority, may not be circumvented under laws
     fulfilling obligations under Article 11 of the WIPO Copyright
     Treaty adopted on December 20, 1996, and/or similar international
     agreements.

  f. Exceptions and Limitations means fair use, fair dealing, and/or
     any other exception or limitation to Copyright and Similar Rights
     that applies to Your use of the Licensed Material.

  g. License Elements means the license attributes listed in the name
     of a Creative Commons Public License. The License Elements of this
     Public License are Attribution, NonCommercial, and ShareAlike.

  h. Licensed Material means the artistic or literary work, database,
     or other material to which the Licensor applied this Public
     License.

  i. Licensed Rights means the rights granted to You subject to the
     terms and conditions of this Public License, which are limited to
     all Copyright and Similar Rights that apply to Your use of the
     Licensed Material and that the Licensor has authority to license.

  j. Licensor means the individual(s) or entity(ies) granting rights
     under this Public License.

  k. NonCommercial means not primarily intended for or directed towards
     commercial advantage or monetary compensation. For purposes of
     this Public License, the exchange of the Licensed Material for
     other material subject to Copyright and Similar Rights by digital
     file-sharing or similar means is NonCommercial provided there is
     no payment of monetary compensation in connection with the
     exchange.

  l. Share means to provide material to the public by any means or
     process that requires permission under the Licensed Rights, such
     as reproduction, public display, public performance, distribution,
     dissemination, communication, or importation, and to make material
     available to the public including in ways that members of the
     public may access the material from a place and at a time
     individually chosen by them.

  m. Sui Generis Database Rights means rights other than copyright
     resulting from Directive 96/9/EC of the European Parliament and of
     the Council of 11 March 1996 on the legal protection of databases,
     as amended and/or succeeded, as well as other essentially
     equivalent rights anywhere in the world.

  n. You means the individual or entity exercising the Licensed Rights
     under this Public License. Your has a corresponding meaning.


Section 2 -- Scope.

  a. License grant.

       1. Subject to the terms and conditions of this Public License,
          the Licensor hereby grants You a worldwide, royalty-free,
          non-sublicensable, non-exclusive, irrevocable license to
          exercise the Licensed Rights in the Licensed Material to:

            a. reproduce and Share the Licensed Material, in whole or
               in part, for NonCommercial purposes only; and

            b. produce, reproduce, and Share Adapted Material for
               NonCommercial purposes only.

       2. Exceptions and Limitations. For the avoidance of doubt, where
          Exceptions and Limitations apply to Your use, this Public
          License does not apply, and You do not need to comply with
          its terms and conditions.

       3. Term. The term of this Public License is specified in Section
          6(a).

       4. Media and formats; technical modifications allowed. The
          Licensor authorizes You to exercise the Licensed Rights in
          all media and formats whether now known or hereafter created,
          and to make technical modifications necessary to do so. The
          Licensor waives and/or agrees not to assert any right or
          authority to forbid You from making technical modifications
          necessary to exercise the Licensed Rights, including
          technical modifications necessary to circumvent Effective
          Technological Measures. For purposes of this Public License,
          simply making modifications authorized by this Section 2(a)
          (4) never produces Adapted Material.

       5. Downstream recipients.

            a. Offer from the Licensor -- Licensed Material. Every
               recipient of the Licensed Material automatically
               receives an offer from the Licensor to exercise the
               Licensed Rights under the terms and conditions of this
               Public License.

            b. Additional offer from the Licensor -- Adapted Material.
               Every recipient of Adapted Material from You
               automatically receives an offer from the Licensor to
               exercise the Licensed Rights in the Adapted Material
               under the conditions of the Adapter's License You apply.

            c. No downstream restrictions. You may not offer or impose
               any additional or different terms or conditions on, or
               apply any Effective Technological Measures to, the
               Licensed Material if doing so restricts exercise of the
               Licensed Rights by any recipient of the Licensed
               Material.

       6. No endorsement. Nothing in this Public License constitutes or
          may be construed as permission to assert or imply that You
          are, or that Your use of the Licensed Material is, connected
          with, or sponsored, endorsed, or granted official status by,
          the Licensor or others designated to receive attribution as
          provided in Section 3(a)(1)(A)(i).

  b. Other rights.

       1. Moral rights, such as the right of integrity, are not
          licensed under this Public License, nor are publicity,
          privacy, and/or other similar personality rights; however, to
          the extent possible, the Licensor waives and/or agrees not to
          assert any such rights held by the Licensor to the limited
          extent necessary to allow You to exercise the Licensed
          Rights, but not otherwise.

       2. Patent and trademark rights are not licensed under this
          Public License.

       3. To the extent possible, the Licensor waives any right to
          collect royalties from You for the exercise of the Licensed
          Rights, whether directly or through a collecting society
          under any voluntary or waivable statutory or compulsory
          licensing scheme. In all other cases the Licensor expressly
          reserves any right to collect such royalties, including when
          the Licensed Material is used other than for NonCommercial
          purposes.


Section 3 -- License Conditions.

Your exercise of the Licensed Rights is expressly made subject to the
following conditions.

  a. Attribution.

       1. If You Share the Licensed Material (including in modified
          form), You must:

            a. retain the following if it is supplied by the Licensor
               with the Licensed Material:

                 i. identification of the creator(s) of the Licensed
                    Material and any others designated to receive
                    attribution, in any reasonable manner requested by
                    the Licensor (including by pseudonym if
                    designated);

                ii. a copyright notice;

               iii. a notice that refers to this Public License;

                iv. a notice that refers to the disclaimer of
                    warranties;

                 v. a URI or hyperlink to the Licensed Material to the
                    extent reasonably practicable;

            b. indicate if You modified the Licensed Material and
               retain an indication of any previous modifications; and

            c. indicate the Licensed Material is licensed under this
               Public License, and include the text of, or the URI or
               hyperlink to, this Public License.

       2. You may satisfy the conditions in Section 3(a)(1) in any
          reasonable manner based on the medium, means, and context in
          which You Share the Licensed Material. For example, it may be
          reasonable to satisfy the conditions by providing a URI or
          hyperlink to a resource that includes the required
          information.
       3. If requested by the Licensor, You must remove any of the
          information required by Section 3(a)(1)(A) to the extent
          reasonably practicable.

  b. ShareAlike.

     In addition to the conditions in Section 3(a), if You Share
     Adapted Material You produce, the following conditions also apply.

       1. The Adapter's License You apply must be a Creative Commons
          license with the same License Elements, this version or
          later, or a BY-NC-SA Compatible License.

       2. You must include the text of, or the URI or hyperlink to, the
          Adapter's License You apply. You may satisfy this condition
          in any reasonable manner based on the medium, means, and
          context in which You Share Adapted Material.

       3. You may not offer or impose any additional or different terms
          or conditions on, or apply any Effective Technological
          Measures to, Adapted Material that restrict exercise of the
          rights granted under the Adapter's License You apply.


Section 4 -- Sui Generis Database Rights.

Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:

  a. for the avoidance of doubt, Section 2(a)(1) grants You the right
     to extract, reuse, reproduce, and Share all or a substantial
     portion of the contents of the database for NonCommercial purposes
     only;

  b. if You include all or a substantial portion of the database
     contents in a database in which You have Sui Generis Database
     Rights, then the database in which You have Sui Generis Database
     Rights (but not its individual contents) is Adapted Material,
     including for purposes of Section 3(b); and

  c. You must comply with the conditions in Section 3(a) if You Share
     all or a substantial portion of the contents of the database.

For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.


Section 5 -- Disclaimer of Warranties and Limitation of Liability.

  a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
     EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
     AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
     ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
     IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
     WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
     PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
     ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
     KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
     ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.

  b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
     TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
     NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
     INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
     COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
     USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
     ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
     DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
     IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.

  c. The disclaimer of warranties and limitation of liability provided
     above shall be interpreted in a manner that, to the extent
     possible, most closely approximates an absolute disclaimer and
     waiver of all liability.


Section 6 -- Term and Termination.

  a. This Public License applies for the term of the Copyright and
     Similar Rights licensed here. However, if You fail to comply with
     this Public License, then Your rights under this Public License
     terminate automatically.

  b. Where Your right to use the Licensed Material has terminated under
     Section 6(a), it reinstates:

       1. automatically as of the date the violation is cured, provided
          it is cured within 30 days of Your discovery of the
          violation; or

       2. upon express reinstatement by the Licensor.

     For the avoidance of doubt, this Section 6(b) does not affect any
     right the Licensor may have to seek remedies for Your violations
     of this Public License.

  c. For the avoidance of doubt, the Licensor may also offer the
     Licensed Material under separate terms or conditions or stop
     distributing the Licensed Material at any time; however, doing so
     will not terminate this Public License.

  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
     License.


Section 7 -- Other Terms and Conditions.

  a. The Licensor shall not be bound by any additional or different
     terms or conditions communicated by You unless expressly agreed.

  b. Any arrangements, understandings, or agreements regarding the
     Licensed Material not stated herein are separate from and
     independent of the terms and conditions of this Public License.


Section 8 -- Interpretation.

  a. For the avoidance of doubt, this Public License does not, and
     shall not be interpreted to, reduce, limit, restrict, or impose
     conditions on any use of the Licensed Material that could lawfully
     be made without permission under this Public License.

  b. To the extent possible, if any provision of this Public License is
     deemed unenforceable, it shall be automatically reformed to the
     minimum extent necessary to make it enforceable. If the provision
     cannot be reformed, it shall be severed from this Public License
     without affecting the enforceability of the remaining terms and
     conditions.

  c. No term or condition of this Public License will be waived and no
     failure to comply consented to unless expressly agreed to by the
     Licensor.

  d. Nothing in this Public License constitutes or may be interpreted
     as a limitation upon, or waiver of, any privileges and immunities
     that apply to the Licensor or You, including from the legal
     processes of any jurisdiction or authority.

=======================================================================

Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the “Licensor.” The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.

Creative Commons may be contacted at creativecommons.org.

================================================
FILE: README.md
================================================
# EleGANt: Exquisite and Locally Editable GAN for Makeup Transfer

[![CC BY-NC-SA 4.0][cc-by-nc-sa-shield]][cc-by-nc-sa]

Official [PyTorch](https://pytorch.org/) implementation of ECCV 2022 paper "[EleGANt: Exquisite and Locally Editable GAN for Makeup Transfer](https://arxiv.org/abs/2207.09840)"

*Chenyu Yang, Wanrong He, Yingqing Xu, and Yang Gao*.

![teaser](assets/figs/teaser.png)

## Getting Started

- [Installation](assets/docs/install.md)
- [Prepare Dataset & Checkpoints](assets/docs/prepare.md)

## Test

To test our model, download the [weights](https://drive.google.com/drive/folders/1xzIS3Dfmsssxkk9OhhAS4svrZSPfQYRe?usp=sharing) of the trained model and run

```bash
python scripts/demo.py
```

Examples of makeup transfer results can be seen [here](assets/images/examples/).

## Train

To train a model from scratch, run

```bash
python scripts/train.py
```

## Customized Transfer

https://user-images.githubusercontent.com/61506577/180593092-ccadddff-76be-4b7b-921e-0d3b4cc27d9b.mp4

This is our demo of customized makeup editing. The interactive system is built upon [Streamlit](https://github.com/streamlit/streamlit) and the interface in `./training/inference.py`.

**Controllable makeup transfer.**

![control](assets/figs/control.png 'controllable makeup transfer')

**Local makeup editing.**

![edit](assets/figs/edit.png 'local makeup editing')

## Citation

If this work is helpful for your research, please consider citing the following BibTeX entry.

```text
@article{yang2022elegant,
  title={EleGANt: Exquisite and Locally Editable GAN for Makeup Transfer},
  author={Yang, Chenyu and He, Wanrong and Xu, Yingqing and Gao, Yang},
  journal={arXiv preprint arXiv:2207.09840},
  year={2022}
}
```

## Acknowledgement

Some of the code is built upon [PSGAN](https://github.com/wtjiang98/PSGAN) and [aster.Pytorch](https://github.com/ayumiymk/aster.pytorch).

## License

This work is licensed under a
[Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License][cc-by-nc-sa].

[![CC BY-NC-SA 4.0][cc-by-nc-sa-image]][cc-by-nc-sa]

[cc-by-nc-sa]: http://creativecommons.org/licenses/by-nc-sa/4.0/
[cc-by-nc-sa-image]: https://licensebuttons.net/l/by-nc-sa/4.0/88x31.png
[cc-by-nc-sa-shield]: https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg


================================================
FILE: assets/docs/install.md
================================================
# Installation Instructions

This code was tested on Ubuntu 20.04 with CUDA 11.1.

**a. Create a conda virtual environment and activate it.**

```bash
conda create -n elegant python=3.8
conda activate elegant
```

**b. Install PyTorch and torchvision following the [official instructions](https://pytorch.org/).**

```bash
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
```

**c. Install other required libraries.**

```bash
pip install opencv-python matplotlib dlib fvcore
```
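
After installing, you can quickly confirm that every required package is importable. The following is a small sanity-check sketch (not part of the repository); the package list mirrors the install commands above, with `cv2` being the import name of `opencv-python`.

```python
import importlib.util

# Packages pulled in by the install steps above (cv2 is opencv-python's module name)
REQUIRED = ["torch", "torchvision", "cv2", "matplotlib", "dlib", "fvcore"]

def missing_packages(names=REQUIRED):
    """Return the subset of names that cannot be imported in this environment."""
    return [n for n in names if importlib.util.find_spec(n) is None]

if __name__ == "__main__":
    missing = missing_packages()
    print("all set" if not missing else "missing: {}".format(missing))
```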


================================================
FILE: assets/docs/prepare.md
================================================
# Preparation Instructions

Clone this repository and prepare the dataset and weights through the following steps:

**a. Prepare model weights for face detection.**

Download the weights of the [dlib](https://github.com/davisking/dlib) 68-landmark face detector [here](http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2). Unzip it and move it to the directory `./faceutils/dlibutils`.

Download the weights of BiSeNet ([PyTorch implementation](https://github.com/zllrunning/face-parsing.PyTorch)) for face parsing [here](https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812). Rename it to `resnet.pth` and move it to the directory `./faceutils/mask`.

**b. Prepare Makeup Transfer (MT) dataset.**

Download the raw data of the MT Dataset [here](https://github.com/wtjiang98/PSGAN) and unzip it into the subdirectory `./data`.

Run the following command to preprocess data:

```bash
python training/preprocess.py
```

Your data directory should look like:

```text
data
└── MT-Dataset
    ├── images
    │   ├── makeup
    │   └── non-makeup
    ├── segs
    │   ├── makeup
    │   └── non-makeup
    ├── lms
    │   ├── makeup
    │   └── non-makeup
    ├── makeup.txt
    ├── non-makeup.txt
    └── ...
```
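
Once preprocessing finishes, a quick check like the following can verify the expected layout. This is only a sketch (not part of the repository); the paths come directly from the tree above.

```python
from pathlib import Path

# Sub-directories that should exist after preprocessing, per the tree above
EXPECTED_SUBDIRS = [
    "images/makeup", "images/non-makeup",
    "segs/makeup", "segs/non-makeup",
    "lms/makeup", "lms/non-makeup",
]

def missing_subdirs(root="data/MT-Dataset"):
    """Return the expected sub-directories that do not exist under root."""
    root = Path(root)
    return [d for d in EXPECTED_SUBDIRS if not (root / d).is_dir()]
```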

**c. Download weights of trained EleGANt.**

The weights of our trained model can be downloaded [here](https://drive.google.com/drive/folders/1xzIS3Dfmsssxkk9OhhAS4svrZSPfQYRe?usp=sharing). Put them under the directory `./ckpts`.


================================================
FILE: concern/__init__.py
================================================
from .image import load_image


================================================
FILE: concern/image.py
================================================
import numpy as np
import cv2
from io import BytesIO


def load_image(path):
    with path.open("rb") as reader:
        # np.fromstring is deprecated; frombuffer reads the raw bytes directly
        data = np.frombuffer(reader.read(), dtype=np.uint8)
        img = cv2.imdecode(data, cv2.IMREAD_COLOR)
        if img is None:
            return None
        img = img[..., ::-1]  # BGR -> RGB
    return img

def resize_by_max(image, max_side=512, force=False):
    h, w = image.shape[:2]
    if max(h, w) < max_side and not force:
        return image
    ratio = max(h, w) / max_side

    w = int(w / ratio + 0.5)
    h = int(h / ratio + 0.5)
    return cv2.resize(image, (w, h))

def image2buffer(image):
    is_success, buffer = cv2.imencode(".jpg", image)
    if not is_success:
        return None
    return BytesIO(buffer)


================================================
FILE: concern/track.py
================================================
import time

import torch


class Track:
    def __init__(self):
        self.log_point = time.time()
        self.enable_track = False

    def track(self, mark):
        if not self.enable_track:
            return

        if torch.cuda.is_available():
            torch.cuda.synchronize()
            print("{} memory:".format(mark), torch.cuda.memory_allocated() / 1024 / 1024, "M")
        print("{} time cost:".format(mark), time.time() - self.log_point)
        self.log_point = time.time()


================================================
FILE: concern/visualize.py
================================================
import numpy as np
import cv2


def channel_first(image, format):
    return image.transpose(
        format.index("C"), format.index("H"), format.index("W"))

def mask2image(mask: np.ndarray, format="HWC"):
    H, W = mask.shape

    canvas = np.zeros((H, W, 3), dtype=np.uint8)
    # color every non-background label; range(mask.max()) would miss the max label
    for i in range(1, int(mask.max()) + 1):
        color = np.random.rand(1, 1, 3) * 255
        canvas += (mask == i)[:, :, None] * color.astype(np.uint8)
    return canvas

def draw_points(image, points, color=(255, 0, 0)):
    for point in points:
        # points are stored as (y, x); cv2.circle expects an (x, y) center
        image = cv2.circle(image, (int(point[1]), int(point[0])), 3, color)

    if hasattr(image, "get"):
        return image.get()
    return image


================================================
FILE: faceutils/__init__.py
================================================
#!/usr/bin/python
# -*- encoding: utf-8 -*-
#from . import faceplusplus as fpp
from . import dlibutils as dlib
from . import mask


================================================
FILE: faceutils/dlibutils/__init__.py
================================================
#!/usr/bin/python
# -*- encoding: utf-8 -*-
from .main import detect, crop, landmarks, crop_from_array


================================================
FILE: faceutils/dlibutils/main.py
================================================
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import os.path as osp

import numpy as np
from PIL import Image
import dlib
import cv2
from concern.image import resize_by_max

detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor(osp.split(osp.realpath(__file__))[0] + '/shape_predictor_68_face_landmarks.dat')


def detect(image: Image.Image) -> dlib.rectangles:
    image = np.asarray(image)
    h, w = image.shape[:2]
    image = resize_by_max(image, 361)
    actual_h, actual_w = image.shape[:2]
    faces_on_small = detector(image, 1)
    faces = dlib.rectangles()
    for face in faces_on_small:
        faces.append(
            dlib.rectangle(
                int(face.left() / actual_w * w + 0.5),
                int(face.top() / actual_h * h + 0.5),
                int(face.right() / actual_w * w + 0.5),
                int(face.bottom() / actual_h * h  + 0.5)
            )
        )
    return faces

def crop(image: Image, face, up_ratio, down_ratio, width_ratio) -> (Image, 'face', 'crop_face'):
    width, height = image.size
    face_height = face.height()
    face_width = face.width()
    delta_up = up_ratio * face_height
    delta_down = down_ratio * face_height
    delta_width = width_ratio * width

    img_left = int(max(0, face.left() - delta_width))
    img_top = int(max(0, face.top() - delta_up))
    img_right = int(min(width, face.right() + delta_width))
    img_bottom = int(min(height, face.bottom() + delta_down))
    image = image.crop((img_left, img_top, img_right, img_bottom))
    face = dlib.rectangle(face.left() - img_left, face.top() - img_top,
                        face.right() - img_left, face.bottom() - img_top)
    face_expand = dlib.rectangle(img_left, img_top, img_right, img_bottom)
    center = face_expand.center()
    width, height = image.size
    # import ipdb; ipdb.set_trace()
    crop_left = img_left
    crop_top = img_top
    crop_right = img_right
    crop_bottom = img_bottom
    if width > height:
        left = int(center.x - height / 2)
        right = int(center.x + height / 2)
        if left < 0:
            left, right = 0, height
        elif right > width:
            left, right = width - height, width
        image = image.crop((left, 0, right, height))
        face = dlib.rectangle(face.left() - left, face.top(),
                              face.right() - left, face.bottom())
        crop_left += left
        crop_right = crop_left + height
    elif width < height:
        top = int(center.y - width / 2)
        bottom = int(center.y + width / 2)
        if top < 0:
            top, bottom = 0, width
        elif bottom > height:
            top, bottom = height - width, height
        image = image.crop((0, top, width, bottom))
        face = dlib.rectangle(face.left(), face.top() - top,
                              face.right(), face.bottom() - top)
        crop_top += top
        crop_bottom = crop_top + width
    crop_face = dlib.rectangle(crop_left, crop_top, crop_right, crop_bottom)
    return image, face, crop_face


def crop_by_image_size(image: Image, face) -> (Image, 'face'):
    center = face.center()
    width, height = image.size
    if width > height:
        left = int(center.x - height / 2)
        right = int(center.x + height / 2)
        if left < 0:
            left, right = 0, height
        elif right > width:
            left, right = width - height, width
        image = image.crop((left, 0, right, height))
        face = dlib.rectangle(face.left() - left, face.top(),
                              face.right() - left, face.bottom())
    elif width < height:
        top = int(center.y - width / 2)
        bottom = int(center.y + width / 2)
        if top < 0:
            top, bottom = 0, width
        elif bottom > height:
            top, bottom = height - width, height
        image = image.crop((0, top, width, bottom))
        face = dlib.rectangle(face.left(), face.top() - top, 
                              face.right(), face.bottom() - top)
    return image, face


def landmarks(image: Image, face):
    shape = predictor(np.asarray(image), face).parts()
    return np.array([[p.y, p.x] for p in shape])

def crop_from_array(image: np.ndarray, face) -> (np.ndarray, 'face'):
    ratio = 0.20 / 0.85 # delta_size / face_size
    height, width = image.shape[:2]
    face_height = face.height()
    face_width = face.width()
    delta_height = ratio * face_height
    delta_width = ratio * face_width

    img_left = int(max(0, face.left() - delta_width))
    img_top = int(max(0, face.top() - delta_height))
    img_right = int(min(width, face.right() + delta_width))
    img_bottom = int(min(height, face.bottom() + delta_height))
    image = image[img_top:img_bottom, img_left:img_right]
    face = dlib.rectangle(face.left() - img_left, face.top() - img_top,
                        face.right() - img_left, face.bottom() - img_top)
    center = face.center()
    height, width = image.shape[:2]
    if width > height:
        left = int(center.x - height / 2)
        right = int(center.x + height / 2)
        if left < 0:
            left, right = 0, height
        elif right > width:
            left, right = width - height, width
        image = image[0:height, left:right]
        face = dlib.rectangle(face.left() - left, face.top(),
                              face.right() - left, face.bottom())
    elif width < height:
        top = int(center.y - width / 2)
        bottom = int(center.y + width / 2)
        if top < 0:
            top, bottom = 0, width
        elif bottom > height:
            top, bottom = height - width, height
        image = image[top:bottom, 0:width]
        face = dlib.rectangle(face.left(), face.top() - top,
                              face.right(), face.bottom() - top)
    return image, face
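The crop helpers above all share the same square-window arithmetic: center a window of the shorter side's length on the face, then clamp it inside the image. A plain-Python sketch of the width-greater-than-height branch (the function name and sample values are illustrative, not from the repo):

```python
# Center a height-sized window on the face center along x, then clamp
# it into a width-by-height image, as the crop helpers do.
def square_window(center_x, width, height):
    left = int(center_x - height / 2)
    right = int(center_x + height / 2)
    if left < 0:                 # window sticks out on the left
        left, right = 0, height
    elif right > width:          # window sticks out on the right
        left, right = width - height, width
    return left, right

# Window fits: centered on x=50 in a 100x60 image.
assert square_window(50, 100, 60) == (20, 80)
# Clamped at the left edge.
assert square_window(10, 100, 60) == (0, 60)
# Clamped at the right edge.
assert square_window(95, 100, 60) == (40, 100)
```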



================================================
FILE: faceutils/mask/__init__.py
================================================
#!/usr/bin/python
# -*- encoding: utf-8 -*-
from .main import FaceParser


================================================
FILE: faceutils/mask/main.py
================================================
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import os.path as osp

import numpy as np
import cv2
from PIL import Image
import torch
import torchvision.transforms as transforms

from .model import BiSeNet


class FaceParser:
    def __init__(self, device="cpu"):
        mapper = [0, 1, 2, 3, 4, 5, 0, 11, 12, 0, 6, 8, 7, 9, 13, 0, 0, 10, 0]
        self.device = device
        self.dic = torch.tensor(mapper, device=device).unsqueeze(1)
        save_pth = osp.split(osp.realpath(__file__))[0] + '/resnet.pth'

        net = BiSeNet(n_classes=19)
        net.load_state_dict(torch.load(save_pth, map_location=device))
        self.net = net.to(device).eval()
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])


    def parse(self, image: np.ndarray):
        assert image.shape[:2] == (512, 512)
        with torch.no_grad():
            image = self.to_tensor(image).to(self.device)
            image = torch.unsqueeze(image, 0)
            out = self.net(image)[0]
            parsing = out.squeeze(0).argmax(0)
        parsing = torch.nn.functional.embedding(parsing, self.dic)
        return parsing.float().squeeze(2)
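`FaceParser.parse` collapses the 19 BiSeNet classes into the repo's smaller label set with a lookup table (`self.dic`), applied via `torch.nn.functional.embedding`. Plain numpy fancy indexing expresses the same remap; the 2x3 parsing map below is a made-up example:

```python
import numpy as np

# The mapper table from FaceParser.__init__: raw class id -> repo label.
mapper = np.array([0, 1, 2, 3, 4, 5, 0, 11, 12, 0, 6, 8, 7, 9, 13, 0, 0, 10, 0])

# Hypothetical 2x3 parsing map of raw BiSeNet class ids.
raw = np.array([[0, 7, 12],
                [17, 6, 13]])
remapped = mapper[raw]     # same shape, each id replaced by its mapped label
assert remapped.tolist() == [[0, 11, 7], [10, 0, 9]]
```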



================================================
FILE: faceutils/mask/model.py
================================================
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from .resnet import Resnet18


class ConvBNReLU(nn.Module):
    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_chan,
                out_chan,
                kernel_size = ks,
                stride = stride,
                padding = padding,
                bias = False)
        self.bn = nn.BatchNorm2d(out_chan)
        self.init_weight()

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(self.bn(x))
        return x

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

class BiSeNetOutput(nn.Module):
    def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
        super(BiSeNetOutput, self).__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
        self.init_weight()

    def forward(self, x):
        x = self.conv(x)
        x = self.conv_out(x)
        return x

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
                wd_params.append(module.weight)
                if not module.bias is None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class AttentionRefinementModule(nn.Module):
    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(AttentionRefinementModule, self).__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
        self.bn_atten = nn.BatchNorm2d(out_chan)
        self.sigmoid_atten = nn.Sigmoid()
        self.init_weight()

    def forward(self, x):
        feat = self.conv(x)
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv_atten(atten)
        atten = self.bn_atten(atten)
        atten = self.sigmoid_atten(atten)
        out = torch.mul(feat, atten)
        return out

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)
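The gating idea in `AttentionRefinementModule`, stripped to numpy: global-average-pool each channel, squash the result to (0, 1) with a sigmoid, and rescale the feature map per channel. The 1x1 conv and BatchNorm between pooling and sigmoid are deliberately omitted here, so this is a sketch of the mechanism, not the module itself:

```python
import numpy as np

def channel_gate(feat):                              # feat: (C, H, W)
    atten = feat.mean(axis=(1, 2), keepdims=True)    # global avg pool -> (C, 1, 1)
    atten = 1.0 / (1.0 + np.exp(-atten))             # sigmoid
    return feat * atten                              # per-channel rescale

feat = np.ones((2, 4, 4))
feat[1] *= -1                      # second channel is all -1
out = channel_gate(feat)
# Each channel is scaled by sigmoid of its own mean.
assert np.allclose(out[0], 1.0 / (1.0 + np.exp(-1.0)))
assert np.allclose(out[1], -1.0 / (1.0 + np.exp(1.0)))
```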


class ContextPath(nn.Module):
    def __init__(self, *args, **kwargs):
        super(ContextPath, self).__init__()
        self.resnet = Resnet18()
        self.arm16 = AttentionRefinementModule(256, 128)
        self.arm32 = AttentionRefinementModule(512, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)

        self.init_weight()

    def forward(self, x):
        H0, W0 = x.size()[2:]
        feat8, feat16, feat32 = self.resnet(x)
        H8, W8 = feat8.size()[2:]
        H16, W16 = feat16.size()[2:]
        H32, W32 = feat32.size()[2:]

        avg = F.avg_pool2d(feat32, feat32.size()[2:])
        avg = self.conv_avg(avg)
        avg_up = F.interpolate(avg, (H32, W32), mode='nearest')

        feat32_arm = self.arm32(feat32)
        feat32_sum = feat32_arm + avg_up
        feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
        feat32_up = self.conv_head32(feat32_up)

        feat16_arm = self.arm16(feat16)
        feat16_sum = feat16_arm + feat32_up
        feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
        feat16_up = self.conv_head16(feat16_up)

        return feat8, feat16_up, feat32_up  # x8, x8, x16

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if not module.bias is None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


### This is not used, since it is replaced by the resnet feature of the same size
class SpatialPath(nn.Module):
    def __init__(self, *args, **kwargs):
        super(SpatialPath, self).__init__()
        self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
        self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
        self.init_weight()

    def forward(self, x):
        feat = self.conv1(x)
        feat = self.conv2(feat)
        feat = self.conv3(feat)
        feat = self.conv_out(feat)
        return feat

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
                wd_params.append(module.weight)
                if not module.bias is None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class FeatureFusionModule(nn.Module):
    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(FeatureFusionModule, self).__init__()
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        self.conv1 = nn.Conv2d(out_chan,
                out_chan//4,
                kernel_size = 1,
                stride = 1,
                padding = 0,
                bias = False)
        self.conv2 = nn.Conv2d(out_chan//4,
                out_chan,
                kernel_size = 1,
                stride = 1,
                padding = 0,
                bias = False)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.init_weight()

    def forward(self, fsp, fcp):
        fcat = torch.cat([fsp, fcp], dim=1)
        feat = self.convblk(fcat)
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv1(atten)
        atten = self.relu(atten)
        atten = self.conv2(atten)
        atten = self.sigmoid(atten)
        feat_atten = torch.mul(feat, atten)
        feat_out = feat_atten + feat
        return feat_out

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
                wd_params.append(module.weight)
                if not module.bias is None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class BiSeNet(nn.Module):
    def __init__(self, n_classes, *args, **kwargs):
        super(BiSeNet, self).__init__()
        self.cp = ContextPath()
        ## here self.sp is deleted
        self.ffm = FeatureFusionModule(256, 256)
        self.conv_out = BiSeNetOutput(256, 256, n_classes)
        self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
        self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
        # self.init_weight()

    def forward(self, x):
        H, W = x.size()[2:]
        feat_res8, feat_cp8, feat_cp16 = self.cp(x)  # here return res3b1 feature
        feat_sp = feat_res8  # use res3b1 feature to replace spatial path feature
        feat_fuse = self.ffm(feat_sp, feat_cp8)

        feat_out = self.conv_out(feat_fuse)
        feat_out16 = self.conv_out16(feat_cp8)
        feat_out32 = self.conv_out32(feat_cp16)

        feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
        feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
        feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
        return feat_out, feat_out16, feat_out32

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
        for name, child in self.named_children():
            child_wd_params, child_nowd_params = child.get_params()
            if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
                lr_mul_wd_params += child_wd_params
                lr_mul_nowd_params += child_nowd_params
            else:
                wd_params += child_wd_params
                nowd_params += child_nowd_params
        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params


if __name__ == "__main__":
    net = BiSeNet(19)
    net.cuda()
    net.eval()
    in_ten = torch.randn(16, 3, 640, 480).cuda()
    out, out16, out32 = net(in_ten)
    print(out.shape)

    net.get_params()


================================================
FILE: faceutils/mask/resnet.py
================================================
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as modelzoo

resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    def __init__(self, in_chan, out_chan, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_chan, out_chan, stride)
        self.bn1 = nn.BatchNorm2d(out_chan)
        self.conv2 = conv3x3(out_chan, out_chan)
        self.bn2 = nn.BatchNorm2d(out_chan)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        if in_chan != out_chan or stride != 1:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_chan, out_chan,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_chan),
                )

    def forward(self, x):
        residual = self.conv1(x)
        residual = F.relu(self.bn1(residual))
        residual = self.conv2(residual)
        residual = self.bn2(residual)

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        out = shortcut + residual
        out = self.relu(out)
        return out


def create_layer_basic(in_chan, out_chan, bnum, stride=1):
    layers = [BasicBlock(in_chan, out_chan, stride=stride)]
    for i in range(bnum-1):
        layers.append(BasicBlock(out_chan, out_chan, stride=1))
    return nn.Sequential(*layers)


class Resnet18(nn.Module):
    def __init__(self):
        super(Resnet18, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
        self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
        self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
        self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
        # self.init_weight()

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.maxpool(x)

        x = self.layer1(x)
        feat8 = self.layer2(x) # 1/8
        feat16 = self.layer3(feat8) # 1/16
        feat32 = self.layer4(feat16) # 1/32
        return feat8, feat16, feat32

    def init_weight(self):
        state_dict = modelzoo.load_url(resnet18_url)
        self_state_dict = self.state_dict()
        for k, v in state_dict.items():
            if 'fc' in k: continue
            self_state_dict.update({k: v})
        self.load_state_dict(self_state_dict)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if not module.bias is None:
                    nowd_params.append(module.bias)
            elif isinstance(module,  nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


if __name__ == "__main__":
    net = Resnet18()
    x = torch.randn(16, 3, 224, 224)
    out = net(x)
    print(out[0].size())
    print(out[1].size())
    print(out[2].size())
    net.get_params()
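`Resnet18.init_weight` loads the pretrained ImageNet weights but drops the classifier: any key containing `'fc'` is skipped, and every other entry overwrites the freshly initialized state dict. The same filtering with plain dicts (placeholder integer "weights" stand in for tensors):

```python
# Hypothetical pretrained and own state dicts; values are placeholders.
pretrained = {"conv1.weight": 1, "fc.weight": 2, "fc.bias": 3}
own = {"conv1.weight": 0, "layer1.0.conv1.weight": 0}

for k, v in pretrained.items():
    if 'fc' in k:        # skip the classifier head
        continue
    own[k] = v           # backbone weights replace the fresh ones

assert own == {"conv1.weight": 1, "layer1.0.conv1.weight": 0}
```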


================================================
FILE: models/__init__.py
================================================


================================================
FILE: models/elegant.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

from .modules.module_base import ResidualBlock_IN, Downsample, Upsample, PositionalEmbedding, MergeBlock
from .modules.module_attn import Attention_apply, FeedForwardLayer, MultiheadAttention 
from .modules.sow_attention import SowAttention
from .modules.tps_transform import tps_spatial_transform


class Generator(nn.ModuleDict):
    """Generator. Encoder-Decoder Architecture."""
    def __init__(self, conv_dim=64, image_size=256, num_layer_e=2, num_layer_d=1, window_size=16, use_ff=False,
                 merge_mode='conv', num_head=1, double_encoder=False, **unused):
        super(Generator, self).__init__()

        # -------------------------- Encoder --------------------------

        layers = nn.Conv2d(3, conv_dim, kernel_size=7, stride=1, padding=3, bias=False)
        self.add_module('in_conv', layers)

        # Down-Sampling & Bottleneck
        curr_dim = conv_dim; feature_size = image_size
        for i in range(2):
            layers = Downsample(curr_dim, curr_dim * 2, affine=True)
            self.add_module('down_{:d}'.format(i+1), layers)
            curr_dim = curr_dim * 2; feature_size = feature_size // 2

            self.add_module('e_bottleneck_{:d}'.format(i+1), 
                nn.Sequential(*[ResidualBlock_IN(curr_dim, curr_dim, affine=True) for j in range(num_layer_e)])
            )

        ### second encoder
        self.double_encoder = double_encoder
        if self.double_encoder:
            layers = nn.Conv2d(3, conv_dim, kernel_size=7, stride=1, padding=3, bias=False)
            self.add_module('in_conv_s', layers)

            # Down-Sampling & Bottleneck
            curr_dim = conv_dim; feature_size = image_size
            for i in range(2):
                layers = Downsample(curr_dim, curr_dim * 2, affine=True)
                self.add_module('down_{:d}_s'.format(i+1), layers)
                curr_dim = curr_dim * 2; feature_size = feature_size // 2

                self.add_module('e_bottleneck_{:d}_s'.format(i+1), 
                    nn.Sequential(*[ResidualBlock_IN(curr_dim, curr_dim, affine=True) for j in range(num_layer_e)])
                )

        # --------------------------- Transfer ----------------------------
        curr_dim = conv_dim; feature_size = image_size
        self.use_ff = use_ff
        for i in range(2):
            curr_dim = curr_dim * 2; feature_size = feature_size // 2
            self.add_module('embedding_{:d}'.format(i+1), PositionalEmbedding(
                embedding_dim=136,
                feature_size=feature_size,
                max_size=image_size,
                embedding_type='l2_norm'
            ))
            if i < 1:
                self.add_module('attention_extract_{:d}'.format(i+1), SowAttention(
                    window_size=window_size,
                    in_channels=curr_dim + 136,
                    proj_channels=curr_dim + 136,
                    value_channels=curr_dim,
                    out_channels=curr_dim,
                    num_heads=num_head
                ))
            else:
                self.add_module('attention_extract_{:d}'.format(i+1), MultiheadAttention(
                    in_channels=curr_dim + 136,
                    proj_channels=curr_dim + 136,
                    value_channels=curr_dim,
                    out_channels=curr_dim,
                    num_heads=num_head
                ))
                
            if use_ff:
                self.add_module('feedforward_{:d}'.format(i+1), FeedForwardLayer(curr_dim, curr_dim))
            self.add_module('attention_apply_{:d}'.format(i+1), Attention_apply(curr_dim))           

        # --------------------------- Decoder ----------------------------

        # Bottleneck & Up-Sampling & Merge
        for i in range(2):
            self.add_module('d_bottleneck_{:d}'.format(i+1), 
                nn.Sequential(*[ResidualBlock_IN(curr_dim, curr_dim, affine=True) for j in range(num_layer_d)])
            )            
            layers = Upsample(curr_dim, curr_dim // 2, affine=True)
            self.add_module('up_{:d}'.format(i+1), layers)
            curr_dim = curr_dim // 2
            if i < 1:
                self.add_module('merge_{:d}'.format(i+1), MergeBlock(merge_mode, curr_dim))

        layers = nn.Sequential(
            nn.InstanceNorm2d(curr_dim, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False),
        )
        self.add_module('out_conv', layers)


    def get_transfer_input(self, image, mask, diff, lms, is_reference=False):
        feature_size = image.shape[2]; scale_factor = 1.0
        fea_list, mask_list, diff_list, lms_list = [], [], [], []

        # input conv
        if self.double_encoder and is_reference:
            fea = self['in_conv_s'](image)
        else:
            fea = self['in_conv'](image)

        # down-sampling & bottleneck
        for i in range(2):
            if self.double_encoder and is_reference:
                fea = self['down_{:d}_s'.format(i+1)](fea)
                fea_ = self['e_bottleneck_{:d}_s'.format(i+1)](fea)
            else:
                fea = self['down_{:d}'.format(i+1)](fea)
                fea_ = self['e_bottleneck_{:d}'.format(i+1)](fea)
            fea_list.append(fea_)
            
            feature_size = feature_size // 2; scale_factor = scale_factor * 0.5
            mask_ = F.interpolate(mask, feature_size, mode='nearest')
            mask_list.append(mask_)

            diff_ = self['embedding_{:d}'.format(i+1)](diff, mask)
            diff_list.append(diff_)
            
            lms_ = lms * scale_factor
            lms_list.append(lms_)
            
        return [fea_list, mask_list, diff_list, lms_list]


    def get_transfer_output(self, fea_c_list, mask_c_list, diff_c_list, lms_c_list,
                            fea_s_list, mask_s_list, diff_s_list, lms_s_list):
        attn_out_list = []
        for i in range(2):
            feature_size = fea_c_list[i].shape[2]

            # align
            if i == 0:
                fea_s_ = self.tps_align(feature_size, lms_s_list[i], lms_c_list[i], fea_s_list[i])
                mask_s_ = self.tps_align(feature_size, lms_s_list[i], lms_c_list[i], mask_s_list[i], 'nearest')
                diff_s_ = self.tps_align(feature_size, lms_s_list[i], lms_c_list[i], diff_s_list[i], 'nearest')
            else:
                fea_s_ = fea_s_list[i]
                mask_s_ = mask_s_list[i]
                diff_s_ = diff_s_list[i]

            # transfer
            input_q = torch.cat((fea_c_list[i], diff_c_list[i]), dim=1)
            input_k = torch.cat((fea_s_, diff_s_), dim=1)
            attn_out = self['attention_extract_{:d}'.format(i+1)](input_q, input_k, fea_s_, mask_c_list[i], mask_s_)
            if self.use_ff:
                attn_out = self['feedforward_{:d}'.format(i+1)](attn_out)
            attn_out_list.append(attn_out)
        
        return attn_out_list

    
    def decode(self, fea_c_list, attn_out_list):
        # apply
        for i in range(2): 
            fea_c_ = self['attention_apply_{:d}'.format(i+1)](fea_c_list[i], attn_out_list[i])
            fea_c_ = self['d_bottleneck_{:d}'.format(2-i)](fea_c_)
            fea_c_list[i] = fea_c_

        # up-sampling & merge
        fea_c = fea_c_list[1]
        for i in range(2):
            fea_c = self['up_{:d}'.format(i+1)](fea_c)
            if i < 1:  
                fea_c = self['merge_{:d}'.format(i+1)](fea_c_list[0], fea_c)

        fea_c = self['out_conv'](fea_c)
        return fea_c

    
    def forward(self, c, s, mask_c, mask_s, diff_c, diff_s, lms_c, lms_s):
        """
        c: content, stands for source image. shape: (b, c, h, w)
        s: style, stands for reference image. shape: (b, c, h, w)
        mask_c: (b, c', h, w)
        diff: (b, d, h, w)
        lms: (b, K, 2)
        """
        transfer_input_c = self.get_transfer_input(c, mask_c, diff_c, lms_c)
        transfer_input_s = self.get_transfer_input(s, mask_s, diff_s, lms_s, True)
        attn_out_list = self.get_transfer_output(*transfer_input_c, *transfer_input_s)
        fea_c = self.decode(transfer_input_c[0], attn_out_list)
        return fea_c


    def tps_align(self, feature_size, lms_s, lms_c, fea_s, sample_mode='bilinear'):
        '''
        fea: (B, C, H, W), lms: (B, K, 2)
        '''
        fea_out = []
        for l_s, l_c, f_s in zip(lms_s, lms_c, fea_s):
            l_c = torch.flip(l_c, dims=[1]) / (feature_size - 1)
            l_s = (torch.flip(l_s, dims=[1]) / (feature_size - 1)).unsqueeze(0)
            f_s = f_s.unsqueeze(0) # (1, C, H, W)
            fea_trans, _ = tps_spatial_transform(feature_size, feature_size, l_c, f_s, l_s, sample_mode)
            fea_out.append(fea_trans)
        return torch.cat(fea_out, dim=0)
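`tps_align` receives landmarks stored as (y, x) in pixel units, while the TPS transform expects (x, y) normalized to [0, 1], so each point is flipped and divided by `feature_size - 1`. A numpy sketch of that normalization (the helper name is illustrative):

```python
import numpy as np

def normalize_lms(lms, feature_size):      # lms: (K, 2) stored as (y, x)
    # flip (y, x) -> (x, y), then map pixel coords onto [0, 1]
    return np.flip(lms, axis=1) / (feature_size - 1)

lms = np.array([[0.0, 63.0],     # top-right corner of a 64x64 grid
                [63.0, 0.0]])    # bottom-left corner
out = normalize_lms(lms, 64)
assert np.allclose(out, [[1.0, 0.0], [0.0, 1.0]])
```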


================================================
FILE: models/loss.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

from .modules.histogram_matching import histogram_matching
from .modules.pseudo_gt import fine_align, expand_area, mask_blur


class GANLoss(nn.Module):
    """Define different GAN objectives.
    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode='lsgan', target_real_label=1.0, target_fake_label=0.0):
        """ Initialize the GANLoss class.
        Parameters:
            gan_mode (str) - - the type of GAN objective. It currently supports vanilla and lsgan.
            target_real_label (bool) - - label for a real image
            target_fake_label (bool) - - label of a fake image
        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def forward(self, prediction, target_is_real):
        """Calculate loss given Discriminator's output and grount truth labels.
        Parameters:
            prediction (tensor) - - tpyically the prediction output from a discriminator
            target_is_real (bool) - - if the ground truth label is for real images or fake images
        Returns:
            the calculated loss.
        """
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        target_tensor = target_tensor.expand_as(prediction).to(prediction.device)
        
        loss = self.loss(prediction, target_tensor)
        return loss
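As a torch-free numeric sketch (illustration only; the class above operates on tensors), the two supported objectives reduce to a mean squared error or a binary cross-entropy with logits against a constant target label:

```python
import math

def lsgan_loss(predictions, target_is_real):
    """MSE against a constant target label (1.0 for real, 0.0 for fake)."""
    target = 1.0 if target_is_real else 0.0
    return sum((p - target) ** 2 for p in predictions) / len(predictions)

def vanilla_gan_loss(logits, target_is_real):
    """BCE-with-logits against the same constant target label.
    Per element this is the numerically stable form used by
    nn.BCEWithLogitsLoss: max(x, 0) - x*z + log(1 + exp(-|x|))."""
    target = 1.0 if target_is_real else 0.0
    losses = [max(x, 0.0) - x * target + math.log1p(math.exp(-abs(x)))
              for x in logits]
    return sum(losses) / len(losses)

real_preds = [0.9, 1.1, 0.8]    # discriminator outputs near the real label
fake_preds = [0.1, -0.2, 0.05]  # outputs near the fake label
```

Feeding fake-looking predictions with `target_is_real=True` yields a larger loss than real-looking ones, which is what drives the generator update.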


def norm(x: torch.Tensor):
    return x * 2 - 1

def de_norm(x: torch.Tensor):
    out = (x + 1) / 2
    return out.clamp(0, 1)
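The rest of this file relies on two range conventions: network-space pixels live in [-1, 1], while histogram matching runs in [0, 255]. A scalar sketch of the round trip performed by `norm`/`de_norm` (hypothetical helper names, for illustration only):

```python
def de_norm_scalar(x):
    # [-1, 1] -> [0, 1], clamped like de_norm above
    return min(max((x + 1) / 2, 0.0), 1.0)

def norm_scalar(x):
    # [0, 1] -> [-1, 1]
    return x * 2 - 1

def to_uint8_space(x):
    return de_norm_scalar(x) * 255

def from_uint8_space(v):
    return norm_scalar(v / 255)
```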

def masked_his_match(image_s, image_r, mask_s, mask_r):
    '''
    image: (3, h, w)
    mask: (1, h, w)
    '''
    index_tmp = torch.nonzero(mask_s)
    x_A_index = index_tmp[:, 1]
    y_A_index = index_tmp[:, 2]
    index_tmp = torch.nonzero(mask_r)
    x_B_index = index_tmp[:, 1]
    y_B_index = index_tmp[:, 2]

    image_s = (de_norm(image_s) * 255) #[-1, 1] -> [0, 255]
    image_r = (de_norm(image_r) * 255)
    
    source_masked = image_s * mask_s
    target_masked = image_r * mask_r
    
    source_match = histogram_matching(
                source_masked, target_masked,
                [x_A_index, y_A_index, x_B_index, y_B_index])
    source_match = source_match.to(image_s.device)
    
    return norm(source_match / 255) #[0, 255] -> [-1, 1]


def generate_pgt(image_s, image_r, mask_s, mask_r, lms_s, lms_r, margins, blend_alphas, img_size=None):
    """
    image: (3, h, w)
    mask: (c, h, w), channels: lip, skin, left eye, right eye
    """
    if img_size is None:
        img_size = image_s.shape[1]
    pgt = image_s.detach().clone()

    # skin match
    skin_match = masked_his_match(image_s, image_r, mask_s[1:2], mask_r[1:2])
    pgt = (1 - mask_s[1:2]) * pgt + mask_s[1:2] * skin_match

    # lip match
    lip_match = masked_his_match(image_s, image_r, mask_s[0:1], mask_r[0:1])
    pgt = (1 - mask_s[0:1]) * pgt + mask_s[0:1] * lip_match

    # eye match
    mask_s_eye = expand_area(mask_s[2:4].sum(dim=0, keepdim=True), margins['eye']) * mask_s[1:2]
    mask_r_eye = expand_area(mask_r[2:4].sum(dim=0, keepdim=True), margins['eye']) * mask_r[1:2]
    eye_match = masked_his_match(image_s, image_r, mask_s_eye, mask_r_eye)
    mask_s_eye_blur = mask_blur(mask_s_eye, blur_size=5, mode='valid')
    pgt = (1 - mask_s_eye_blur) * pgt + mask_s_eye_blur * eye_match

    # tps align
    pgt = fine_align(img_size, lms_r, lms_s, image_r, pgt, mask_r, mask_s, margins, blend_alphas)
    return pgt


class LinearAnnealingFn():
    """
    Piecewise-linear annealing function defined by milestones and the values at them.
    """
    def __init__(self, milestones, f_values):
        assert len(milestones) == len(f_values)
        self.milestones = milestones
        self.f_values = f_values
        
    def __call__(self, t:int):
        if t < self.milestones[0]:
            return self.f_values[0]
        elif t >= self.milestones[-1]:
            return self.f_values[-1]
        else:
            for r in range(len(self.milestones) - 1):
                if self.milestones[r] <= t < self.milestones[r+1]:
                    return ((t - self.milestones[r]) * self.f_values[r+1] \
                            + (self.milestones[r+1] - t) * self.f_values[r]) \
                            / (self.milestones[r+1] - self.milestones[r])
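The schedule can be checked with a standalone restatement of the interpolation above (plain floats; the class itself is stateless apart from its milestones and values):

```python
def linear_anneal(t, milestones, values):
    # same logic as LinearAnnealingFn.__call__, restated for a quick check
    if t < milestones[0]:
        return values[0]
    if t >= milestones[-1]:
        return values[-1]
    for r in range(len(milestones) - 1):
        if milestones[r] <= t < milestones[r + 1]:
            span = milestones[r + 1] - milestones[r]
            return ((t - milestones[r]) * values[r + 1]
                    + (milestones[r + 1] - t) * values[r]) / span

# e.g. a blend alpha annealed from 1.0 down to 0.2 between steps 0 and 10
schedule = [linear_anneal(t, [0, 10], [1.0, 0.2]) for t in (0, 5, 10, 20)]
```

Before the first milestone the first value is held, after the last milestone the last value is held, and in between the output is linearly interpolated.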


class ComposePGT(nn.Module):
    def __init__(self, margins, skin_alpha, eye_alpha, lip_alpha):
        super(ComposePGT, self).__init__()
        self.margins = margins
        self.blend_alphas = {
            'skin':skin_alpha,
            'eye':eye_alpha,
            'lip':lip_alpha
        }

    @torch.no_grad()
    def forward(self, sources, targets, mask_srcs, mask_tars, lms_srcs, lms_tars):
        pgts = []
        for source, target, mask_src, mask_tar, lms_src, lms_tar in\
            zip(sources, targets, mask_srcs, mask_tars, lms_srcs, lms_tars):
            pgt = generate_pgt(source, target, mask_src, mask_tar, lms_src, lms_tar, 
                               self.margins, self.blend_alphas)
            pgts.append(pgt)
        pgts = torch.stack(pgts, dim=0)
        return pgts   

class AnnealingComposePGT(nn.Module):
    def __init__(self, margins,
            skin_alpha_milestones, skin_alpha_values,
            eye_alpha_milestones, eye_alpha_values,
            lip_alpha_milestones, lip_alpha_values
        ):
        super(AnnealingComposePGT, self).__init__()
        self.margins = margins
        self.skin_alpha_fn = LinearAnnealingFn(skin_alpha_milestones, skin_alpha_values)
        self.eye_alpha_fn = LinearAnnealingFn(eye_alpha_milestones, eye_alpha_values)
        self.lip_alpha_fn = LinearAnnealingFn(lip_alpha_milestones, lip_alpha_values)
        
        self.t = 0
        self.blend_alphas = {}
        self.step()

    def step(self):
        self.t += 1
        self.blend_alphas['skin'] = self.skin_alpha_fn(self.t)
        self.blend_alphas['eye'] = self.eye_alpha_fn(self.t)
        self.blend_alphas['lip'] = self.lip_alpha_fn(self.t)

    @torch.no_grad()
    def forward(self, sources, targets, mask_srcs, mask_tars, lms_srcs, lms_tars):
        pgts = []
        for source, target, mask_src, mask_tar, lms_src, lms_tar in\
            zip(sources, targets, mask_srcs, mask_tars, lms_srcs, lms_tars):
            pgt = generate_pgt(source, target, mask_src, mask_tar, lms_src, lms_tar,
                               self.margins, self.blend_alphas)
            pgts.append(pgt)
        pgts = torch.stack(pgts, dim=0)
        return pgts   


class MakeupLoss(nn.Module):
    """
    Define the makeup loss w.r.t. the pseudo ground truth
    """
    def __init__(self):
        super(MakeupLoss, self).__init__()

    def forward(self, x, target, mask=None):
        if mask is None:
            return F.l1_loss(x, target)
        else:
            return F.l1_loss(x * mask, target * mask)

================================================
FILE: models/model.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import VGG as TVGG
from torchvision.models.vgg import load_state_dict_from_url, model_urls, cfgs

from .modules.spectral_norm import spectral_norm as SpectralNorm
from .elegant import Generator


def get_generator(config):
    kwargs = {
        'conv_dim':config.MODEL.G_CONV_DIM,
        'image_size':config.DATA.IMG_SIZE,
        'num_head':config.MODEL.NUM_HEAD,
        'double_encoder':config.MODEL.DOUBLE_E,
        'use_ff':config.MODEL.USE_FF,
        'num_layer_e':config.MODEL.NUM_LAYER_E,
        'num_layer_d':config.MODEL.NUM_LAYER_D,
        'window_size':config.MODEL.WINDOW_SIZE,
        'merge_mode':config.MODEL.MERGE_MODE
    }
    G = Generator(**kwargs)
    return G


def get_discriminator(config):
    kwargs = {
        'input_channel': 3,
        'conv_dim':config.MODEL.D_CONV_DIM,
        'num_layers':config.MODEL.D_REPEAT_NUM,
        'norm':config.MODEL.D_TYPE
    }
    D = Discriminator(**kwargs)
    return D


class Discriminator(nn.Module):
    """Discriminator. PatchGAN."""
    def __init__(self, input_channel=3, conv_dim=64, num_layers=3, norm='SN', **unused):
        super(Discriminator, self).__init__()

        layers = []
        if norm=='SN':
            layers.append(SpectralNorm(nn.Conv2d(input_channel, conv_dim, kernel_size=4, stride=2, padding=1)))
        else:
            layers.append(nn.Conv2d(input_channel, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))

        curr_dim = conv_dim
        for i in range(1, num_layers):
            if norm=='SN':
                layers.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1)))
            else:
                layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            curr_dim = curr_dim * 2

        #k_size = int(image_size / np.power(2, repeat_num))
        if norm=='SN':
            layers.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=1, padding=1)))
        else:
            layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=1, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))
        curr_dim = curr_dim * 2

        self.main = nn.Sequential(*layers)
        if norm=='SN':
            self.conv1 = SpectralNorm(nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False))
        else:
            self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False)

    def forward(self, x):
        h = self.main(x)
        out_makeup = self.conv1(h)
        return out_makeup


class VGG(TVGG):
    def forward(self, x):
        x = self.features(x)
        return x


def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def vgg16(pretrained=False, progress=True, **kwargs):
    r"""VGG 16-layer model (configuration "D")
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)

================================================
FILE: models/modules/__init__.py
================================================


================================================
FILE: models/modules/histogram_matching.py
================================================
import copy
import torch

def cal_hist(image):
    """
        calculate the cumulative histogram (CDF) for each of the 3 channels
    """
    hists = []
    for i in range(0, 3):
        channel = torch.from_numpy(image[i])
        hist = torch.histc(channel, bins=256, min=0, max=256)
        hist = hist.numpy()
        total = hist.sum()
        pdf = [v / (total + 1e-10) for v in hist]
        for j in range(1, 256):
            pdf[j] = pdf[j - 1] + pdf[j]
        hists.append(pdf)
    return hists


def cal_trans(ref, adj):
    """
        calculate the transfer (lookup) table between two cumulative histograms,
        following the algorithm described in the Wikipedia article "Histogram matching"
    """
    table = list(range(0, 256))
    for i in list(range(1, 256)):
        for j in list(range(1, 256)):
            if ref[i] >= adj[j - 1] and ref[i] <= adj[j]:
                table[i] = j
                break
    table[255] = 255
    return table

def histogram_matching(dstImg, refImg, index):
    """
        perform histogram matching:
        dstImg is transformed to have the same histogram as refImg
        index[0], index[1]: indices of the pixels to be transformed in dstImg
        index[2], index[3]: indices of the pixels used to compute the histogram of refImg
    """
    index = [x.cpu().numpy() for x in index]
    dstImg = dstImg.detach().cpu().numpy()
    refImg = refImg.detach().cpu().numpy()
    dst_align = [dstImg[i, index[0], index[1]] for i in range(0, 3)]
    ref_align = [refImg[i, index[2], index[3]] for i in range(0, 3)]
    hist_ref = cal_hist(ref_align)
    hist_dst = cal_hist(dst_align)
    tables = [cal_trans(hist_dst[i], hist_ref[i]) for i in range(0, 3)]

    mid = copy.deepcopy(dst_align)
    for i in range(0, 3):
        for k in range(0, len(index[0])):
            dst_align[i][k] = tables[i][int(mid[i][k])]

    for i in range(0, 3):
        dstImg[i, index[0], index[1]] = dst_align[i]

    dstImg = torch.FloatTensor(dstImg)
    return dstImg
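The CDF-matching idea behind `cal_hist`/`cal_trans` can be inspected by hand with a toy palette of 4 intensity levels instead of 256. This sketch uses a nearest-CDF rule rather than the exact bracketing loop above, and the helper names are hypothetical:

```python
def toy_cdf(pixels, levels=4):
    # cumulative histogram of a flat list of integer intensities
    counts = [0] * levels
    for p in pixels:
        counts[p] += 1
    total = float(len(pixels))
    cdf, running = [], 0
    for c in counts:
        running += c
        cdf.append(running / total)
    return cdf

def toy_trans(dst_cdf, ref_cdf, levels=4):
    # for each source level, pick the reference level with the closest CDF value
    return [min(range(levels), key=lambda k: abs(ref_cdf[k] - dst_cdf[i]))
            for i in range(levels)]

src = [0, 0, 1, 1]          # dark source region
ref = [2, 2, 3, 3]          # bright reference region
table = toy_trans(toy_cdf(src), toy_cdf(ref))
matched = [table[p] for p in src]   # dark pixels are remapped to bright levels
```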

================================================
FILE: models/modules/module_attn.py
================================================
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiheadAttention_weight(nn.Module):
    def __init__(self, feature_dim, proj_dim, num_heads=1, dropout=0.0, bias=True):
        super(MultiheadAttention_weight, self).__init__()
        self.feature_dim = feature_dim
        self.proj_dim = proj_dim
        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout)
        self.head_dim = proj_dim // num_heads
        assert self.head_dim * num_heads == self.proj_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.q_proj = nn.Linear(feature_dim, proj_dim, bias=bias)
        self.k_proj = nn.Linear(feature_dim, proj_dim, bias=bias)

    def forward(self, fea_c, fea_s, mask_c, mask_s):
        '''
        fea_c: (b, d, h, w)
        mask_c: (b, c, h, w)
        '''
        bsz, dim, h, w = fea_c.shape; mask_channel = mask_c.shape[1]

        fea_c = fea_c.view(bsz, dim, h*w).transpose(1, 2) # (b, HW, d)
        fea_s = fea_s.view(bsz, dim, h*w).transpose(1, 2)
        with torch.no_grad():
            if mask_c.shape[2] != h:
                mask_c = F.interpolate(mask_c, size=(h, w)) 
                mask_s = F.interpolate(mask_s, size=(h, w)) 
            mask_c = mask_c.view(bsz, mask_channel, -1, h*w) # (b, m_c, 1, HW)
            mask_s = mask_s.view(bsz, mask_channel, -1, h*w)
            mask_attn = torch.matmul(mask_c.transpose(-2, -1), mask_s) # (b, m_c, HW, HW)
            mask_attn = torch.sum(mask_attn, dim=1, keepdim=True).clamp_(0, 1) # (b, 1, HW, HW)
            mask_sum = torch.sum(mask_attn, dim=-1, keepdim=True)
            mask_attn += (mask_sum == 0).float()
            mask_attn = mask_attn.masked_fill_(mask_attn == 0, float('-inf')).masked_fill_(mask_attn == 1, float(0.0))

        query = self.q_proj(fea_c) # (b, HW, D)
        key = self.k_proj(fea_s) # (b, HW, D)
        query = query.view(bsz, h*w, self.num_heads, self.head_dim).transpose(1, 2) # (b, h, HW, D)
        key = key.view(bsz, h*w, self.num_heads, self.head_dim).transpose(1, 2)

        weights = torch.matmul(query, key.transpose(-1, -2)) # (b, h, HW, HW)
        weights = weights * self.scaling
        weights = weights + mask_attn.detach()
        weights = self.dropout(F.softmax(weights, dim=-1))
        weights = weights * (1 - (mask_sum == 0).float().detach())
        return weights 
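A pure-Python sketch of the additive masking trick used above: a 0/1 attention mask is converted to {-inf, 0} and added to the scores, so softmax assigns exactly zero weight to masked-out positions. (The module additionally handles rows whose mask is all zero via the `mask_sum == 0` fallback; that case is omitted here.)

```python
import math

def masked_softmax(scores, mask):
    # mask[j] == 1 keeps position j, mask[j] == 0 removes it
    additive = [0.0 if m == 1 else float('-inf') for m in mask]
    shifted = [s + a for s, a in zip(scores, additive)]
    peak = max(shifted)                       # subtract max for stability
    exps = [math.exp(s - peak) for s in shifted]
    total = sum(exps)
    return [e / total for e in exps]

weights = masked_softmax([2.0, 1.0, 3.0], [1, 0, 1])
```

The masked position receives exactly zero weight while the remaining weights still sum to one.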


class MultiheadAttention_value(nn.Module):
    def __init__(self, feature_dim, proj_dim, num_heads=1, bias=True):
        super(MultiheadAttention_value, self).__init__()
        self.feature_dim = feature_dim
        self.proj_dim = proj_dim
        self.num_heads = num_heads
        self.head_dim = proj_dim // num_heads
        assert self.head_dim * num_heads == self.proj_dim, "embed_dim must be divisible by num_heads"
        
        self.v_proj = nn.Linear(feature_dim, proj_dim, bias=bias)

    def forward(self, weights, fea):
        '''
        weights: (b, h, HW, HW)
        fea: (b, d, H, W)
        '''
        bsz, dim, h, w = fea.shape
        fea = fea.view(bsz, dim, h*w).transpose(1, 2) #(b, HW, D)
        value = self.v_proj(fea)
        value = value.view(bsz, h*w, self.num_heads, self.head_dim).transpose(1, 2) #(b, h, HW, D)

        out = torch.matmul(weights, value)
        out = out.transpose(1, 2).contiguous().view(bsz, h*w, self.proj_dim) # (b, HW, D)
        out = out.transpose(1, 2).view(bsz, self.proj_dim, h, w) #(b, d, H, W)
        return out


class MultiheadAttention(nn.Module):
    def __init__(self, in_channels, proj_channels, value_channels, out_channels, num_heads=1, dropout=0.0, bias=True):
        super(MultiheadAttention, self).__init__()
        self.weight = MultiheadAttention_weight(in_channels, proj_channels, num_heads, dropout, bias)
        self.value = MultiheadAttention_value(value_channels, out_channels, num_heads, bias)

    def forward(self, fea_q, fea_k, fea_v, mask_q, mask_k):
        '''
        fea: (b, d, h, w)
        mask: (b, c, h, w)
        '''
        weights = self.weight(fea_q, fea_k, mask_q, mask_k)
        return self.value(weights, fea_v)


class FeedForwardLayer(nn.Module):
    def __init__(self, feature_dim, ff_dim, dropout=0.0):
        super(FeedForwardLayer, self).__init__()
        self.main = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(p=dropout, inplace=True),
            nn.Conv2d(feature_dim, ff_dim, kernel_size=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ff_dim, feature_dim, kernel_size=1)
        )

    def forward(self, x):
        return self.main(x)


class Attention_apply(nn.Module):
    def __init__(self, feature_dim, normalize=True):
        super(Attention_apply, self).__init__()
        self.normalize = normalize
        if normalize:
            self.norm = nn.InstanceNorm2d(feature_dim, affine=False)
        self.actv = nn.LeakyReLU(0.2, inplace=True)
        self.conv = nn.Conv2d(feature_dim, feature_dim, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x, attn_out):
        if self.normalize:
            x = self.norm(x) 
        x = x * (1 + attn_out)
        return self.conv(self.actv(x))


================================================
FILE: models/modules/module_base.py
================================================
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """Residual Block."""
    def __init__(self, dim_in, dim_out):            
        super(ResidualBlock, self).__init__()
        self.main = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False)
        )
        self.skip = nn.Identity() if dim_in == dim_out else nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=False)

    def forward(self, x):
        x = self.skip(x) + self.main(x)
        return x / math.sqrt(2)
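Why divide by sqrt(2): summing two roughly independent unit-variance signals doubles the variance, and the 1/sqrt(2) factor undoes that, keeping activation scale stable when many residual blocks are stacked. A quick statistical check with seeded Gaussians:

```python
import math
import random
import statistics

random.seed(0)
a = [random.gauss(0, 1) for _ in range(20000)]  # unit-variance "skip" path
b = [random.gauss(0, 1) for _ in range(20000)]  # unit-variance "residual" path
summed = [(x + y) / math.sqrt(2) for x, y in zip(a, b)]
var_sum = statistics.pvariance(summed)  # stays near 1 instead of near 2
```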


class ResidualBlock_IN(nn.Module):
    """Residual Block with InstanceNorm."""
    def __init__(self, dim_in, dim_out, affine=False):            
        super(ResidualBlock_IN, self).__init__()
        self.main = nn.Sequential(
            nn.InstanceNorm2d(dim_in, affine=affine),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=affine),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
        )
        self.skip = nn.Identity() if dim_in == dim_out else nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=False)

    def forward(self, x):
        x = self.skip(x) + self.main(x)
        return x / math.sqrt(2)


class ResidualBlock_Downsample(nn.Module):
    """Residual Block with InstanceNorm."""
    def __init__(self, dim_in, dim_out, affine=False):            
        super(ResidualBlock_Downsample, self).__init__()
        self.main = nn.Sequential(
            nn.InstanceNorm2d(dim_in, affine=affine),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, bias=False)    
        )
        if dim_in == dim_out:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1, bias=False)

    def forward(self, x):
        skip = F.interpolate(self.skip(x), scale_factor=0.5, mode='bilinear', align_corners=False, recompute_scale_factor=True)
        res = self.main(x)
        x = skip + res
        return x / math.sqrt(2)


class Downsample(nn.Module):
    """Residual Block with InstanceNorm."""
    def __init__(self, dim_in, dim_out, affine=False):            
        super(Downsample, self).__init__()
        self.main = nn.Sequential(
            nn.InstanceNorm2d(dim_in, affine=affine),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, bias=False)           
        )

    def forward(self, x):
        return self.main(x)


class ResidualBlock_Upsample(nn.Module):
    """Residual Block with InstanceNorm."""
    def __init__(self, dim_in, dim_out, normalize=True, affine=False):            
        super(ResidualBlock_Upsample, self).__init__()
        if normalize:
            self.main = nn.Sequential(
                nn.InstanceNorm2d(dim_in, affine=affine),
                nn.LeakyReLU(0.2, inplace=True),
                nn.ConvTranspose2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, bias=False)
            )
        else:
            self.main = nn.Sequential(
                nn.LeakyReLU(0.2, inplace=True),
                nn.ConvTranspose2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, bias=False)
            )
        if dim_in == dim_out:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1, bias=False)

    def forward(self, x):
        skip = F.interpolate(self.skip(x), scale_factor=2, mode='bilinear', align_corners=False)
        res = self.main(x)
        x = skip + res
        return x / math.sqrt(2)


class Upsample(nn.Module):
    """Residual Block with InstanceNorm."""
    def __init__(self, dim_in, dim_out, normalize=True, affine=False):            
        super(Upsample, self).__init__()
        if normalize:
            self.main = nn.Sequential(
                nn.InstanceNorm2d(dim_in, affine=affine),
                nn.LeakyReLU(0.2, inplace=True),
                nn.ConvTranspose2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, bias=False)
            )
        else:
            self.main = nn.Sequential(
                nn.LeakyReLU(0.2, inplace=True),
                nn.ConvTranspose2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, bias=False)
            )

    def forward(self, x):
        return self.main(x)


class PositionalEmbedding(nn.Module):
    def __init__(self, embedding_dim=136, feature_size=64, max_size=None, embedding_type='l2_norm'):
        super(PositionalEmbedding, self).__init__()
        self.embedding_dim = embedding_dim
        self.feature_size = feature_size
        self.max_size = max_size
        assert embedding_type in ['l2_norm', 'uniform', 'sin']
        self.embedding_type = embedding_type

    @torch.no_grad()
    def forward(self, diff, mask):
        '''
        diff: (b, d, h, w)
        mask: (b, 3, h, w)
        return: (b, d, h, w)
        '''
        bsz, init_dim, init_size, _ = diff.shape
        assert self.embedding_dim >= init_dim
        diff = F.interpolate(diff, self.feature_size) # (b, d, h, w)
        mask = F.interpolate(mask, size=self.feature_size)
        mask = torch.sum(mask, dim=1, keepdim=True) # (b, 1, h, w)
        diff = diff * mask
        
        if self.embedding_type == 'l2_norm':
            norm = torch.norm(diff, dim=1, keepdim=True)
            norm = (norm == 0) + norm
            diff = diff / norm
        elif self.embedding_type == 'uniform':
            diff = diff / self.max_size
        elif self.embedding_type == 'sin':
            diff = torch.sin(diff * math.pi / (2 * self.max_size))
        
        if self.embedding_dim > init_dim:
            zero_shape = (bsz, self.embedding_dim - init_dim, self.feature_size, self.feature_size)
            zero_padding = torch.zeros(zero_shape, device=diff.device)
            diff = torch.cat((diff, zero_padding), dim=1)

        diff = diff.detach(); diff.requires_grad = False
        return diff

class MergeBlock(nn.Module):
    def __init__(self, merge_mode, feature_dim, normalize=True):
        super(MergeBlock, self).__init__()
        assert merge_mode in ['conv', 'add', 'affine']
        self.merge_mode = merge_mode
        if merge_mode == 'affine':
            self.norm = nn.LayerNorm(feature_dim, elementwise_affine=False) if normalize else nn.Identity()
        else:
            self.norm = nn.InstanceNorm2d(feature_dim, affine=False) if normalize else nn.Identity()
        self.norm_r = nn.InstanceNorm2d(feature_dim, affine=False) if normalize else nn.Identity()
        self.actv = nn.LeakyReLU(0.2, inplace=True)
        if merge_mode == 'conv':
            self.conv = nn.Conv2d(2 * feature_dim, feature_dim, kernel_size=3, stride=1, padding=1, bias=False)
        else:
            self.conv = nn.Conv2d(feature_dim, feature_dim, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, fea_s, fea_r):
        if self.merge_mode == 'conv':
            fea_s = self.norm(fea_s)
            fea_r = self.norm_r(fea_r)
            fea_s = torch.cat((fea_s, fea_r), dim=1)
        elif self.merge_mode == 'add':
            fea_s = self.norm(fea_s)
            fea_r = self.norm_r(fea_r)
            fea_s = (fea_s + fea_r) / math.sqrt(2)
        elif self.merge_mode == 'affine':
            fea_s = fea_s.permute(0, 2, 3, 1)
            fea_s = self.norm(fea_s)
            fea_s = fea_s.permute(0, 3, 1, 2)
            fea_s = fea_s * (1 + fea_r)
        return self.conv(self.actv(fea_s))

================================================
FILE: models/modules/pseudo_gt.py
================================================
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import functional

from models.modules.tps_transform import tps_sampler, tps_spatial_transform


def expand_area(mask:torch.Tensor, margin:int):
    '''
    mask: (C, H, W) or (N, C, H, W)
    '''
    kernel = np.zeros((margin * 2 + 1, margin * 2 + 1), dtype=np.uint8)
    kernel = cv2.circle(kernel, (margin, margin), margin, (255, 0, 0), -1)
    kernel = torch.FloatTensor((kernel > 0)).unsqueeze(0).unsqueeze(0).to(mask.device)
    ndim = mask.ndimension()
    if ndim == 3:
        mask = mask.unsqueeze(0)
    expanded_mask = torch.zeros_like(mask)
    for i in range(mask.shape[1]):
        expanded_mask[:,i:i+1,:,:] = F.conv2d(mask[:,i:i+1,:,:], kernel, padding=margin)
    if ndim == 3:
        expanded_mask = expanded_mask.squeeze(0)
    return (expanded_mask > 0).float()
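`expand_area` is a morphological dilation: the torch version convolves the mask with a rasterized disc and thresholds the result. A minimal pure-Python analogue on a binary grid (disc test `dy*dy + dx*dx <= margin*margin` approximates the cv2 circle kernel):

```python
def dilate(grid, margin):
    # dilate a 0/1 grid with a disc of radius `margin`
    h, w = len(grid), len(grid[0])
    out = [[0] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            if grid[y][x]:
                for dy in range(-margin, margin + 1):
                    for dx in range(-margin, margin + 1):
                        if dy * dy + dx * dx <= margin * margin:
                            yy, xx = y + dy, x + dx
                            if 0 <= yy < h and 0 <= xx < w:
                                out[yy][xx] = 1
    return out

grid = [[0, 0, 0], [0, 1, 0], [0, 0, 0]]
grown = dilate(grid, 1)  # center pixel grows into a plus shape
```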

def mask_blur(mask:torch.Tensor, blur_size=3, mode='smooth'):
    """Blur the edge of mask so that the compose image have smooth transition
    Args:
        mask (torch.Tensor): [C, H, W]
        blur_size (int): size of blur kernel. Defaults to 3.
        mode (str) Defaults to 'smooth'.
    Returns:
        torch.Tensor: blurred mask
    """
    #kernel = torch.ones((1, 1, blur_size * 2 + 1, blur_size * 2 + 1)).to(mask.device)
    kernel = np.zeros((blur_size * 2 + 1, blur_size * 2 + 1), dtype=np.uint8)
    kernel = cv2.circle(kernel, (blur_size, blur_size), blur_size, (255, 0, 0), -1)
    kernel = torch.FloatTensor((kernel > 0)).unsqueeze(0).unsqueeze(0).to(mask.device)
    kernel = kernel / torch.sum(kernel)
    ndim = mask.ndimension()
    if ndim == 3:
        mask = mask.unsqueeze(0)
    mask_blur = torch.zeros_like(mask)
    for i in range(mask.shape[1]):
        mask_blur[:,i:i+1,:,:] = F.conv2d(mask[:,i:i+1,:,:], kernel, padding=blur_size)
    if mode == 'valid':
        mask_blur = (mask_blur.clamp(0.5, 1) - 0.5) * 2 * mask
    if ndim == 3:
        mask_blur = mask_blur.squeeze(0)
    return mask_blur.clamp(0, 1)
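The 'valid' mode above remaps blurred values so that only the interior of the original mask keeps weight: values at or below 0.5 go to 0 and the [0.5, 1] range is stretched back to [0, 1]. A scalar sketch of that remap (hypothetical helper name):

```python
def valid_remap(blurred, inside_original_mask):
    # (clamp(x, 0.5, 1) - 0.5) * 2, gated by the original binary mask
    clamped = min(max(blurred, 0.5), 1.0)
    return (clamped - 0.5) * 2 * (1 if inside_original_mask else 0)

edge_value = valid_remap(0.5, True)    # on the blurred boundary
core_value = valid_remap(1.0, True)    # deep inside the mask
outside = valid_remap(0.9, False)      # outside the original mask
```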

def mask_blend(mask, blend_alpha, mask_bound=None, blur_size=3, blend_mode='smooth'):
    if blur_size > 0:
        mask = mask_blur(mask, blur_size, blend_mode)
    mask = mask * blend_alpha
    if mask_bound is None:
        return mask
    else:
        return mask * mask_bound


def tps_align(img_size, lms_r, lms_s, image_r, image_s=None, 
              mask_r = None, mask_s=None, sample_mode='bilinear'):
    '''
    image: (C, H, W), lms: (K, 2), mask:(1, H, W)
    '''
    lms_s = torch.flip(lms_s, dims=[1]) / (img_size - 1)
    lms_r = (torch.flip(lms_r, dims=[1]) / (img_size - 1)).unsqueeze(0)
    image_r = image_r.unsqueeze(0)
    image_trans, _ = tps_spatial_transform(img_size, img_size, lms_s, image_r, lms_r, sample_mode)
    if mask_r is not None:
        mask_r_trans, _ = tps_spatial_transform(img_size, img_size, lms_s, mask_r.unsqueeze(0), 
                                                lms_r, 'nearest')
    if image_s is not None:
        mask_compose = torch.ones((1, img_size, img_size), device=lms_r.device)
        if mask_s is not None:
            mask_compose *= mask_s
        if mask_r is not None:
            mask_compose *= mask_r_trans.squeeze(0)
        return image_s * (1 - mask_compose) + image_trans.squeeze(0) * mask_compose
    else:
        return image_trans.squeeze(0)

def tps_blend(blend_alpha, img_size, lms_r, lms_s, image_r, image_s=None, mask_r = None, mask_s=None, 
              mask_s_bound=None, blur_size=7, sample_mode='bilinear', blend_mode='smooth'):
    '''
    image: (C, H, W), lms: (K, 2), mask:(1, H, W)
    '''
    lms_s = torch.flip(lms_s, dims=[1]) / (img_size - 1)
    lms_r = (torch.flip(lms_r, dims=[1]) / (img_size - 1)).unsqueeze(0)
    image_r = image_r.unsqueeze(0)
    image_trans, _ = tps_spatial_transform(img_size, img_size, lms_s, image_r, lms_r, sample_mode)
    if mask_r is not None:
        mask_r_trans, _ = tps_spatial_transform(img_size, img_size, lms_s, mask_r.unsqueeze(0), 
                                                lms_r, 'nearest')
    if image_s is not None:
        mask_compose = torch.ones((1, img_size, img_size), device=lms_r.device)
        if mask_s is not None:
            mask_compose *= mask_s
        if mask_r is not None:
            mask_compose *= mask_r_trans.squeeze(0)
        mask_compose = mask_blend(mask_compose, blend_alpha, mask_s_bound, blur_size, blend_mode)
        return image_s * (1 - mask_compose) + image_trans.squeeze(0) * mask_compose
    else:
        return image_trans.squeeze(0)


def fine_align(img_size, lms_r, lms_s, image_r, image_s, mask_r, mask_s, margins, blend_alphas):
    '''
    image: (C, H, W), lms: (K, 2)
    mask: (C, H, W), channels: lip, skin, left eye, right eye
    margins: dictionary, blend_alphas: dictionary
    '''
    # skin align
    image_s = tps_blend(blend_alphas['skin'], img_size, lms_r[:60], lms_s[:60], image_r, image_s, 
                        mask_r[1:2], mask_s[1:2], mask_s[1:2], blur_size=8, blend_mode='valid')

    # lip align
    mask_s_lip = expand_area(mask_s[0:1], margins['lip'])
    mask_r_lip = expand_area(mask_r[0:1], margins['lip'])
    image_s = tps_blend(blend_alphas['lip'], img_size, lms_r[48:], lms_s[48:], image_r, image_s, 
                        mask_r_lip, mask_s_lip, mask_s[0:1], blur_size=3)

    # left eye align
    mask_s_eye = expand_area(mask_s[2:3], margins['eye'])
    mask_r_eye = expand_area(mask_r[2:3], margins['eye']) * mask_r[1:2]
    image_s = tps_blend(blend_alphas['eye'], img_size, 
                        torch.cat((lms_r[14:17], lms_r[22:27], lms_r[27:31], lms_r[42:48]), dim=0), 
                        torch.cat((lms_s[14:17], lms_s[22:27], lms_s[27:31], lms_s[42:48]), dim=0), 
                        image_r, image_s, mask_r_eye, mask_s_eye, mask_s[1:2], 
                        blur_size=5, sample_mode='nearest')

    # right eye align
    mask_s_eye = expand_area(mask_s[3:4], margins['eye'])
    mask_r_eye = expand_area(mask_r[3:4], margins['eye']) * mask_r[1:2]
    image_s = tps_blend(blend_alphas['eye'], img_size, 
                        torch.cat((lms_r[0:3], lms_r[17:22], lms_r[27:31], lms_r[36:42]), dim=0), 
                        torch.cat((lms_s[0:3], lms_s[17:22], lms_s[27:31], lms_s[36:42]), dim=0), 
                        image_r, image_s, mask_r_eye, mask_s_eye, mask_s[1:2], 
                        blur_size=5, sample_mode='nearest')

    return image_s


if __name__ == "__main__":
    pass
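A note on the landmark convention used by `tps_blend` above: dlib-style landmarks arrive in (y, x) pixel order and are flipped to (x, y), then scaled into [0, 1], before being fed to the TPS warp. A minimal standalone sketch of that normalization (hypothetical landmark values; `img_size = 256` as in the default config):

```python
import torch

img_size = 256
# Hypothetical landmarks in (row, col) = (y, x) pixel coordinates.
lms = torch.tensor([[10.0, 20.0],
                    [100.0, 50.0]])

# Flip to (x, y) order and scale into [0, 1], as tps_blend does.
lms_norm = torch.flip(lms, dims=[1]) / (img_size - 1)
print(lms_norm)  # first point becomes (20/255, 10/255)
```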

================================================
FILE: models/modules/sow_attention.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F


class WindowAttention(nn.Module):
    def __init__(self, window_size, in_channels, proj_channels, value_channels, out_channels, 
                 num_heads=1, dropout=0.0, bias=True, weighted_output=True):
        super(WindowAttention, self).__init__()
        assert window_size % 2 == 0
        self.window_size = window_size
        self.weighted_output = weighted_output
        window_weight = self.generate_window_weight()
        self.register_buffer('window_weight', window_weight)

        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout)
        self.in_channels = in_channels
        self.proj_channels = proj_channels
        head_dim = proj_channels // num_heads
        assert head_dim * num_heads == self.proj_channels, "embed_dim must be divisible by num_heads"
        self.scaling = head_dim ** -0.5

        self.q_proj = nn.Conv2d(in_channels, proj_channels, kernel_size=1, bias=bias)
        self.k_proj = nn.Conv2d(in_channels, proj_channels, kernel_size=1, bias=bias)

        self.value_channels = value_channels
        self.out_channels = out_channels
        assert out_channels // num_heads * num_heads == self.out_channels
        self.v_proj = nn.Conv2d(value_channels, out_channels, kernel_size=1, bias=bias)

    @torch.no_grad()
    def generate_window_weight(self):
        yc = torch.arange(self.window_size // 2).unsqueeze(1).repeat(1, self.window_size // 2)
        xc = torch.arange(self.window_size // 2).unsqueeze(0).repeat(self.window_size // 2, 1)
        window_weight = xc * yc / (self.window_size // 2 - 1) ** 2
        window_weight = torch.cat((window_weight, torch.flip(window_weight, dims=[0])), dim=0)
        window_weight = torch.cat((window_weight, torch.flip(window_weight, dims=[1])), dim=1)
        return window_weight.view(-1)   

    def make_window(self, x: torch.Tensor):
        """
        input: (B, C, H, W)
        output: (B, h, H/S, W/S, S*S, C/h)
        """
        bsz, dim, h, w = x.shape
        x = x.view(bsz, self.num_heads, dim // self.num_heads, h // self.window_size, self.window_size, 
                   w // self.window_size, self.window_size)
        x = x.transpose(4, 5).contiguous().view(bsz, self.num_heads, dim // self.num_heads, 
                                                h // self.window_size, w // self.window_size, self.window_size**2)
        x = x.permute(0, 1, 3, 4, 5, 2)
        return x

    def demake_window(self, x: torch.Tensor):
        """
        input: (B, h, H/S, W/S, S*S, C/h)
        output: (B, C, H, W)
        """
        bsz, _, h_s, w_s, _, dim_h = x.shape
        x = x.permute(0, 1, 5, 2, 3, 4).contiguous()
        x = x.view(bsz, dim_h * self.num_heads, h_s, w_s, self.window_size, self.window_size)
        x = x.transpose(3, 4).contiguous().view(bsz, dim_h * self.num_heads, 
                                                h_s * self.window_size, w_s * self.window_size)
        return x

    @torch.no_grad()
    def make_mask_window(self, mask: torch.Tensor):
        """
        input: (B, C, H, W)
        output: (B, 1, H/S, W/S, S*S, C)
        """
        bsz, mask_channel, h, w = mask.shape
        mask = mask.view(bsz, 1, mask_channel, h // self.window_size, self.window_size, 
                         w // self.window_size, self.window_size)
        mask = mask.transpose(4, 5).contiguous().view(bsz, 1, mask_channel, 
                                                      h // self.window_size, w // self.window_size, self.window_size**2)
        mask = mask.permute(0, 1, 3, 4, 5, 2)
        return mask
    
    def forward(self, fea_q, fea_k, fea_v, mask_q=None, mask_k=None):
        '''
        fea: (b, d, h, w)
        mask: (b, c, h, w)
        '''
        query = self.q_proj(fea_q) # (B, D, H, W)
        key = self.k_proj(fea_k)
        value = self.v_proj(fea_v)
        query = self.make_window(query) # (B, h, H/S, W/S, S*S, D/h)
        key = self.make_window(key)
        value = self.make_window(value)
        
        weights = torch.matmul(query, key.transpose(-1, -2)) # (B, h, H/S, W/S, S*S, S*S)
        weights = weights * self.scaling
        if mask_q is not None and mask_k is not None:
            mask_q = self.make_mask_window(mask_q) # (B, 1, H/S, W/S, S*S, C)
            mask_k = self.make_mask_window(mask_k)
            with torch.no_grad():
                mask_attn = torch.matmul(mask_q, mask_k.transpose(-1, -2))
                mask_sum = torch.sum(mask_attn, dim=-1, keepdim=True)
                mask_attn += (mask_sum == 0).float()
                mask_attn = mask_attn.masked_fill_(mask_attn == 0, float('-inf')).masked_fill_(mask_attn == 1, float(0.0))
            weights += mask_attn        

        weights = self.dropout(F.softmax(weights, dim=-1))
        if mask_q is not None and mask_k is not None:
            weights = weights * (1 - (mask_sum == 0).float().detach())

        out = torch.matmul(weights, value) # (B, h, H/S, W/S, S*S, D/h)
        if self.weighted_output:
            window_weight = self.window_weight.view(1, 1, 1, 1, self.window_size ** 2, 1)
            out = out * window_weight
        out = self.demake_window(out) #(B, D, H, W)
        return out

class SowAttention(nn.Module):
    def __init__(self, window_size, in_channels, proj_channels, value_channels, out_channels, 
                 num_heads=1, dropout=0.0, bias=True):
        super(SowAttention, self).__init__()
        assert window_size % 2 == 0
        self.window_size = window_size
        self.pad = nn.ZeroPad2d(window_size // 2)
        self.window_attention = WindowAttention(window_size, in_channels, proj_channels, value_channels,
                                            out_channels, num_heads, dropout, bias)

    def forward(self, fea_q, fea_k, fea_v, mask_q=None, mask_k=None):
        '''
        fea: (b, d, h, w)
        mask: (b, c, h, w)
        '''
        out_0 = self.window_attention(fea_q, fea_k, fea_v, mask_q, mask_k)
        
        fea_q = self.pad(fea_q)
        fea_k = self.pad(fea_k)
        fea_v = self.pad(fea_v)
        if mask_q is not None and mask_k is not None:
            mask_q = self.pad(mask_q)
            mask_k = self.pad(mask_k)
        
        out_1 = self.window_attention(fea_q, fea_k, fea_v, mask_q, mask_k)
        out_1 = out_1[:, :, self.window_size//2:-self.window_size//2, self.window_size//2:-self.window_size//2]
        
        if mask_q is not None and mask_k is not None:
            out_2 = self.window_attention(
                fea_q[:, :, :, self.window_size//2:-self.window_size//2],
                fea_k[:, :, :, self.window_size//2:-self.window_size//2],
                fea_v[:, :, :, self.window_size//2:-self.window_size//2],
                mask_q[:, :, :, self.window_size//2:-self.window_size//2],
                mask_k[:, :, :, self.window_size//2:-self.window_size//2]
            )
        else:
            out_2 = self.window_attention(
                fea_q[:, :, :, self.window_size//2:-self.window_size//2],
                fea_k[:, :, :, self.window_size//2:-self.window_size//2],
                fea_v[:, :, :, self.window_size//2:-self.window_size//2],
            )
        out_2 = out_2[:, :, self.window_size//2:-self.window_size//2, :]

        if mask_q is not None and mask_k is not None:
            out_3 = self.window_attention(
                fea_q[:, :, self.window_size//2:-self.window_size//2, :],
                fea_k[:, :, self.window_size//2:-self.window_size//2, :],
                fea_v[:, :, self.window_size//2:-self.window_size//2, :],
                mask_q[:, :, self.window_size//2:-self.window_size//2, :],
                mask_k[:, :, self.window_size//2:-self.window_size//2, :]
            )
        else:
            out_3 = self.window_attention(
                fea_q[:, :, self.window_size//2:-self.window_size//2, :],
                fea_k[:, :, self.window_size//2:-self.window_size//2, :],
                fea_v[:, :, self.window_size//2:-self.window_size//2, :],
            )
        out_3 = out_3[:, :, :, self.window_size//2:-self.window_size//2]

        out = out_0 + out_1 + out_2 + out_3
        return out

class StridedwindowAttention(nn.Module):
    def __init__(self, stride, in_channels, proj_channels, value_channels, out_channels, 
                 num_heads=1, dropout=0.0, bias=True):
        super(StridedwindowAttention, self).__init__()
        self.stride = stride
        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout)
        self.in_channels = in_channels
        self.proj_channels = proj_channels
        head_dim = proj_channels // num_heads
        assert head_dim * num_heads == self.proj_channels, "embed_dim must be divisible by num_heads"
        self.scaling = head_dim ** -0.5

        self.q_proj = nn.Conv2d(in_channels, proj_channels, kernel_size=1, bias=bias)
        self.k_proj = nn.Conv2d(in_channels, proj_channels, kernel_size=1, bias=bias)

        self.value_channels = value_channels
        self.out_channels = out_channels
        assert out_channels // num_heads * num_heads == self.out_channels
        self.v_proj = nn.Conv2d(value_channels, out_channels, kernel_size=1, bias=bias)

    def make_window(self, x: torch.Tensor):
        """
        input: (B, C, H, W)
        output: (B, h, S(h), S(w), H/S * W/S, C/h)
        """
        bsz, dim, h, w = x.shape
        assert h % self.stride == 0 and w % self.stride == 0
        
        x = x.view(bsz, self.num_heads, dim // self.num_heads, h // self.stride, self.stride, 
                   w // self.stride, self.stride) # (B, h, C/h, H/S, S(h), W/S, S(w))
        x = x.permute(0, 1, 4, 6, 3, 5, 2).contiguous() # (B, h, S(h), S(w), H/S, W/S, C/h)
        x = x.view(bsz, self.num_heads, self.stride, self.stride,  
                   h // self.stride * w // self.stride, dim // self.num_heads)
        return x

    def demake_window(self, x: torch.Tensor, h, w):
        """
        input: (B, h, S(h), S(w), H/S * W/S, C/h)
        output: (B, C, H, W)
        """
        bsz, _, _, _, _, dim_h = x.shape
        x = x.view(bsz, self.num_heads, self.stride, self.stride,  
                   h // self.stride, w // self.stride, dim_h) # (B, h, S(h), S(w), H/S, W/S, C/h)
        x = x.permute(0, 1, 6, 4, 2, 5, 3).contiguous() # (B, h, C/h, H/S, S(h), W/S, S(w))
        x = x.view(bsz, dim_h * self.num_heads, h, w)
        return x

    @torch.no_grad()
    def make_mask_window(self, mask: torch.Tensor):
        """
        input: (B, C, H, W)
        output: (B, 1, S(h), S(w), H/S * W/S, C)
        """
        bsz, mask_channel, h, w = mask.shape
        assert h % self.stride == 0 and w % self.stride == 0

        mask = mask.view(bsz, 1, mask_channel, h // self.stride, self.stride, w // self.stride, self.stride)
        mask = mask.permute(0, 1, 4, 6, 3, 5, 2).contiguous()
        mask = mask.view(bsz, 1, self.stride, self.stride, h // self.stride * w // self.stride, mask_channel)
        return mask
    
    def forward(self, fea_q, fea_k, fea_v, mask_q=None, mask_k=None):
        '''
        fea: (b, d, h, w)
        mask: (b, c, h, w)
        '''
        bsz, _, h, w = fea_q.shape
        
        query = self.q_proj(fea_q) # (B, D, H, W)
        key = self.k_proj(fea_k)
        value = self.v_proj(fea_v)
        query = self.make_window(query) # (B, h, S(h), S(w), H/S * W/S, C/h)
        key = self.make_window(key)
        value = self.make_window(value)
        
        weights = torch.matmul(query, key.transpose(-1, -2)) # (B, h, S(h), S(w), H/S * W/S, H/S * W/S)
        weights = weights * self.scaling
        if mask_q is not None and mask_k is not None:
            mask_q = self.make_mask_window(mask_q) # (B, 1, S(h), S(w), H/S * W/S, C)
            mask_k = self.make_mask_window(mask_k)
            with torch.no_grad():
                mask_attn = torch.matmul(mask_q, mask_k.transpose(-1, -2))
                mask_sum = torch.sum(mask_attn, dim=-1, keepdim=True)
                mask_attn += (mask_sum == 0).float()
                mask_attn = mask_attn.masked_fill_(mask_attn == 0, float('-inf')).masked_fill_(mask_attn == 1, float(0.0))
            weights += mask_attn        

        weights = self.dropout(F.softmax(weights, dim=-1))
        if mask_q is not None and mask_k is not None:
            weights = weights * (1 - (mask_sum == 0).float().detach())

        out = torch.matmul(weights, value) # (B, h, S(h), S(w), H/S * W/S, D/h)
        out = self.demake_window(out, h, w) #(B, D, H, W)
        return out
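The bilinear `window_weight` built by `WindowAttention.generate_window_weight` is what lets the four shifted window passes in `SowAttention` blend seamlessly: for an even window size S, the separable tent profile and its half-window shift sum to 1 along each axis, so `out_0 + out_1 + out_2 + out_3` carries total weight 1 at interior pixels. A small stdlib check of that partition-of-unity property (re-deriving the 1-D profile rather than importing the module):

```python
# 1-D tent profile equivalent to one axis of generate_window_weight:
# f(i) = min(i, S - 1 - i) / (S / 2 - 1) over a window of even size S.
def tent(size):
    half = size // 2
    return [min(i, size - 1 - i) / (half - 1) for i in range(size)]

S = 16  # default MODEL.WINDOW_SIZE
f = tent(S)

# Shifting the window grid by S/2 samples the profile at (i + S/2) mod S.
# The two samples sum to 1 at every position, so the shifted-window
# outputs blend to a total weight of 1.
sums = [f[i] + f[(i + S // 2) % S] for i in range(S)]
print(sums)
```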
    

================================================
FILE: models/modules/spectral_norm.py
================================================
import torch
from torch.nn import Parameter

def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)

class SpectralNorm(object):
    def __init__(self):
        self.name = "weight"
        self.power_iterations = 1

    def compute_weight(self, module):
        u = getattr(module, self.name + "_u")
        v = getattr(module, self.name + "_v")
        w = getattr(module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))
        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        sigma = u.dot(w.view(height, -1).mv(v))
        return w / sigma.expand_as(w)

    @staticmethod
    def apply(module):
        name = "weight"
        fn = SpectralNorm()

        try:
            u = getattr(module, name + "_u")
            v = getattr(module, name + "_v")
            w = getattr(module, name + "_bar")
        except AttributeError:
            w = getattr(module, name)
            height = w.data.shape[0]
            width = w.view(height, -1).data.shape[1]
            u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
            v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
            w_bar = Parameter(w.data)

            module.register_parameter(name + "_u", u)
            module.register_parameter(name + "_v", v)
            module.register_parameter(name + "_bar", w_bar)

        # remove w from parameter list
        del module._parameters[name]

        setattr(module, name, fn.compute_weight(module))

        # recompute weight before every forward()
        module.register_forward_pre_hook(fn)

        return fn

    def remove(self, module):
        weight = self.compute_weight(module)
        delattr(module, self.name)
        del module._parameters[self.name + '_u']
        del module._parameters[self.name + '_v']
        del module._parameters[self.name + '_bar']
        module.register_parameter(self.name, Parameter(weight.data))

    def __call__(self, module, inputs):
        setattr(module, self.name, self.compute_weight(module))

def spectral_norm(module):
    SpectralNorm.apply(module)
    return module

def remove_spectral_norm(module):
    name = 'weight'
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError("spectral_norm of '{}' not found in {}"
                     .format(name, module))
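`SpectralNorm.compute_weight` above estimates the spectral norm sigma by power iteration, running only one step per forward pass and amortizing convergence across training. A stdlib-only sketch of the same iteration on a tiny fixed matrix whose largest singular value is known (plain lists stand in for tensors; this is illustrative, not the module's API):

```python
import math

def l2normalize(v, eps=1e-12):
    n = math.sqrt(sum(x * x for x in v))
    return [x / (n + eps) for x in v]

def matvec(m, v):
    return [sum(mij * vj for mij, vj in zip(row, v)) for row in m]

def transpose(m):
    return [list(col) for col in zip(*m)]

# W has singular values {3, 1}; power iteration should find sigma = 3.
W = [[3.0, 0.0],
     [0.0, 1.0]]
u = [1.0, 1.0]
for _ in range(50):  # SpectralNorm runs a single step per forward call
    v = l2normalize(matvec(transpose(W), u))
    u = l2normalize(matvec(W, v))
sigma = sum(ui * wvi for ui, wvi in zip(u, matvec(W, v)))
print(sigma)
```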

================================================
FILE: models/modules/tps_transform.py
================================================
from __future__ import absolute_import

import numpy as np
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F

# TF32 precision is insufficient for the TPS computation; full FP32 is required.
# Disable the automatic TF32 matmul/convolution enabled by default since PyTorch 1.7.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

def grid_sample(input, grid, mode='bilinear', canvas=None):
    output = F.grid_sample(input, grid, mode=mode, align_corners=True)
    if canvas is None:
        return output
    else:
        input_mask = input.data.new(input.size()).fill_(1)
        output_mask = F.grid_sample(input_mask, grid, mode='nearest', align_corners=True)
        padded_output = output * output_mask + canvas * (1 - output_mask)
        return padded_output


# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
def compute_partial_repr(input_points, control_points):
    N = input_points.size(0)
    M = control_points.size(0)
    pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2)
    # original implementation, very slow
    # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
    pairwise_diff_square = pairwise_diff * pairwise_diff
    pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1]
    repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
    #repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist + 1e-8)
    # fix numerical error for 0 * log(0), substitute all nan with 0
    mask = repr_matrix != repr_matrix
    repr_matrix.masked_fill_(mask, 0)
    return repr_matrix


# compute \Delta_c^-1
def bulid_delta_inverse(target_control_points):
    '''
    target_control_points: (N, 2)
    '''
    N = target_control_points.shape[0]
    forward_kernel = torch.zeros(N + 3, N + 3).to(target_control_points.device)
    target_control_partial_repr = compute_partial_repr(target_control_points, target_control_points)
    forward_kernel[:N, :N].copy_(target_control_partial_repr)
    forward_kernel[:N, -3].fill_(1)
    forward_kernel[-3, :N].fill_(1)
    forward_kernel[:N, -2:].copy_(target_control_points)
    forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))
    # compute inverse matrix
    inverse_kernel = torch.inverse(forward_kernel)
    return inverse_kernel


# create target coordinate matrix
def build_target_coordinate_matrix(target_height, target_width, target_control_points):
    '''
    target_control_points: (N, 2)
    '''
    HW = target_height * target_width
    target_coordinate = list(itertools.product(range(target_height), range(target_width)))
    target_coordinate = torch.Tensor(target_coordinate).to(target_control_points.device) # HW x 2
    Y, X = target_coordinate.split(1, dim = 1)
    Y = Y / (target_height - 1)
    X = X / (target_width - 1)
    target_coordinate = torch.cat([X, Y], dim = 1) # convert from (y, x) to (x, y)
    target_coordinate_partial_repr = compute_partial_repr(target_coordinate, target_control_points)
    target_coordinate_repr = torch.cat([
        target_coordinate_partial_repr, 
        torch.ones((HW, 1), device=target_control_points.device), 
        target_coordinate], dim = 1)
    return target_coordinate_repr


def tps_sampler(target_height, target_width, inverse_kernel, target_coordinate_repr,
                source, source_control_points, sample_mode='bilinear'):
    '''
    inverse_kernel: \Delta_C^-1
    target_coordinate_repr: \hat{p}
    source: (B, C, H, W)
    source_control_points: (B, N, 2)
    '''
    batch_size = source.shape[0]
    Y = torch.cat([source_control_points, torch.zeros((batch_size, 3, 2), device=source.device)], dim=1)
    mapping_matrix = torch.matmul(inverse_kernel, Y)
    source_coordinate = torch.matmul(target_coordinate_repr, mapping_matrix)

    grid = source_coordinate.view(-1, target_height, target_width, 2)
    grid = torch.clamp(grid, 0, 1) # source_control_points may fall outside [0, 1]
    # grid_sample expects coordinates normalized to [-1, 1], but ours are in [0, 1]
    grid = 2.0 * grid - 1.0
    output_maps = grid_sample(source, grid, mode=sample_mode, canvas=None)
    return output_maps, source_coordinate


def tps_spatial_transform(target_height, target_width, target_control_points, 
                          source, source_control_points, sample_mode='bilinear'):
    '''
    target_control_points: (N, 2)
    source: (B, C, H, W)
    source_control_points: (B, N, 2)
    '''
    inverse_kernel = bulid_delta_inverse(target_control_points)
    target_coordinate_repr = build_target_coordinate_matrix(target_height, target_width, target_control_points)
    
    return tps_sampler(target_height, target_width, inverse_kernel, target_coordinate_repr, 
                       source, source_control_points, sample_mode)


class TPSSpatialTransformer(nn.Module):

    def __init__(self, target_height, target_width, target_control_points):
        super(TPSSpatialTransformer, self).__init__()
        self.target_height, self.target_width = target_height, target_width
        self.num_control_points = target_control_points.shape[0]
    
        # create padded kernel matrix
        inverse_kernel = bulid_delta_inverse(target_control_points)
    
        # create target coordinate matrix
        target_coordinate_repr = build_target_coordinate_matrix(target_height, target_width, target_control_points)
    
        # register precomputed matrices
        self.register_buffer('inverse_kernel', inverse_kernel)
        #self.register_buffer('padding_matrix', torch.zeros(3, 2))
        self.register_buffer('target_coordinate_repr', target_coordinate_repr)
        self.register_buffer('target_control_points', target_control_points)
    
    def forward(self, source, source_control_points):
        assert source_control_points.ndimension() == 3
        assert source_control_points.size(1) == self.num_control_points
        assert source_control_points.size(2) == 2
        
        return tps_sampler(self.target_height, self.target_width,
                           self.inverse_kernel, self.target_coordinate_repr,
                           source, source_control_points)
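`compute_partial_repr` above evaluates the TPS radial basis phi(p, q) = r^2 * log(r) with r = ||p - q||, but computes it as 0.5 * d * log(d) with d = r^2 to avoid a square root, since 0.5 * r^2 * log(r^2) = r^2 * log(r). A quick stdlib check of that identity and of the 0 * log(0) := 0 convention (the module handles the latter via NaN masking):

```python
import math

def phi_direct(p, q):
    # textbook form: r^2 * log(r)
    r = math.dist(p, q)
    return 0.0 if r == 0.0 else r * r * math.log(r)

def phi_sqdist(p, q):
    # same quantity computed from the squared distance, as in compute_partial_repr
    d = (p[0] - q[0]) ** 2 + (p[1] - q[1]) ** 2
    return 0.0 if d == 0.0 else 0.5 * d * math.log(d)

p, q = (0.1, 0.7), (0.4, 0.2)
print(phi_direct(p, q), phi_sqdist(p, q))
```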
 

================================================
FILE: scripts/demo.py
================================================
import os
import sys
import argparse
import numpy as np
import cv2
import torch
from PIL import Image
sys.path.append('.')

from training.config import get_config
from training.inference import Inference
from training.utils import create_logger, print_args

def main(config, args):
    logger = create_logger(args.save_folder, args.name, 'info', console=True)
    print_args(args, logger)
    logger.info(config)

    inference = Inference(config, args, args.load_path)

    n_imgname = sorted(os.listdir(args.source_dir))
    m_imgname = sorted(os.listdir(args.reference_dir))
    
    for i, (imga_name, imgb_name) in enumerate(zip(n_imgname, m_imgname)):
        imgA = Image.open(os.path.join(args.source_dir, imga_name)).convert('RGB')
        imgB = Image.open(os.path.join(args.reference_dir, imgb_name)).convert('RGB')

        result = inference.transfer(imgA, imgB, postprocess=True) 
        if result is None:
            continue
        imgA = np.array(imgA); imgB = np.array(imgB)
        h, w, _ = imgA.shape
        result = result.resize((w, h)); result = np.array(result)  # PIL resize takes (width, height)
        vis_image = np.hstack((imgA, imgB, result))
        save_path = os.path.join(args.save_folder, f"result_{i}.png")
        Image.fromarray(vis_image.astype(np.uint8)).save(save_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("arguments for demo")
    parser.add_argument("--name", type=str, default='demo')
    parser.add_argument("--save_path", type=str, default='result', help="path to save results")
    parser.add_argument("--load_path", type=str, help="path to the model checkpoint", 
                        default='ckpts/sow_pyramid_a5_e3d2_remapped.pth')

    parser.add_argument("--source-dir", type=str, default="assets/images/non-makeup")
    parser.add_argument("--reference-dir", type=str, default="assets/images/makeup")
    parser.add_argument("--gpu", default='0', type=str, help="GPU id to use.")

    args = parser.parse_args()
    args.gpu = 'cuda:' + args.gpu
    args.device = torch.device(args.gpu)

    args.save_folder = os.path.join(args.save_path, args.name)
    if not os.path.exists(args.save_folder):
        os.makedirs(args.save_folder)
    
    config = get_config()
    main(config, args)

================================================
FILE: scripts/train.py
================================================
import os
import sys
import argparse
import torch
from torch.utils.data import DataLoader
sys.path.append('.')

from training.config import get_config
from training.dataset import MakeupDataset
from training.solver import Solver
from training.utils import create_logger, print_args


def main(config, args):
    logger = create_logger(args.save_folder, args.name, 'info', console=True)
    print_args(args, logger)
    logger.info(config)
    
    dataset = MakeupDataset(config)
    data_loader = DataLoader(dataset, batch_size=config.DATA.BATCH_SIZE, num_workers=config.DATA.NUM_WORKERS, shuffle=True)
    
    solver = Solver(config, args, logger)
    solver.train(data_loader)
    

if __name__ == "__main__":
    parser = argparse.ArgumentParser("argument for training")
    parser.add_argument("--name", type=str, default='elegant')
    parser.add_argument("--save_path", type=str, default='results', help="path to save model")
    parser.add_argument("--load_folder", type=str, help="path to load model", 
                        default=None)
    parser.add_argument("--keepon", default=False, action="store_true", help='keep on training')

    parser.add_argument("--gpu", default='0', type=str, help="GPU id to use.")

    args = parser.parse_args()
    config = get_config()
    
    #args.gpu = 'cuda:' + args.gpu
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    #args.device = torch.device(args.gpu)
    args.device = torch.device('cuda:0')

    args.save_folder = os.path.join(args.save_path, args.name)
    if not os.path.exists(args.save_folder):
        os.makedirs(args.save_folder)    
    
    main(config, args)

================================================
FILE: training/__init__.py
================================================


================================================
FILE: training/config.py
================================================
from fvcore.common.config import CfgNode

"""
This file defines default options of configurations.
It will be further merged by yaml files and options from
the command-line.
Note that *any* hyper-parameters should be firstly defined
here to enable yaml and command-line configuration.
"""

_C = CfgNode()

# Logging and saving
_C.LOG = CfgNode()
_C.LOG.SAVE_FREQ = 10
_C.LOG.VIS_FREQ = 1

# Data settings
_C.DATA = CfgNode()
_C.DATA.PATH = './data/MT-Dataset'
_C.DATA.NUM_WORKERS = 4
_C.DATA.BATCH_SIZE = 1
_C.DATA.IMG_SIZE = 256

# Training hyper-parameters
_C.TRAINING = CfgNode()
_C.TRAINING.G_LR = 2e-4
_C.TRAINING.D_LR = 2e-4
_C.TRAINING.BETA1 = 0.5
_C.TRAINING.BETA2 = 0.999
_C.TRAINING.NUM_EPOCHS = 50
_C.TRAINING.LR_DECAY_FACTOR = 5e-2
_C.TRAINING.DOUBLE_D = False

# Loss weights
_C.LOSS = CfgNode()
_C.LOSS.LAMBDA_A = 10.0
_C.LOSS.LAMBDA_B = 10.0
_C.LOSS.LAMBDA_IDT = 0.5
_C.LOSS.LAMBDA_REC = 10
_C.LOSS.LAMBDA_MAKEUP = 100
_C.LOSS.LAMBDA_SKIN = 0.1
_C.LOSS.LAMBDA_EYE = 1.5
_C.LOSS.LAMBDA_LIP = 1
_C.LOSS.LAMBDA_MAKEUP_LIP = _C.LOSS.LAMBDA_MAKEUP * _C.LOSS.LAMBDA_LIP
_C.LOSS.LAMBDA_MAKEUP_SKIN = _C.LOSS.LAMBDA_MAKEUP * _C.LOSS.LAMBDA_SKIN
_C.LOSS.LAMBDA_MAKEUP_EYE = _C.LOSS.LAMBDA_MAKEUP * _C.LOSS.LAMBDA_EYE
_C.LOSS.LAMBDA_VGG = 5e-3

# Model structure
_C.MODEL = CfgNode()
_C.MODEL.D_TYPE = 'SN'
_C.MODEL.D_REPEAT_NUM = 3
_C.MODEL.D_CONV_DIM = 64
_C.MODEL.G_CONV_DIM = 64
_C.MODEL.NUM_HEAD = 1
_C.MODEL.DOUBLE_E = False
_C.MODEL.USE_FF = False
_C.MODEL.NUM_LAYER_E = 3
_C.MODEL.NUM_LAYER_D = 2
_C.MODEL.WINDOW_SIZE = 16
_C.MODEL.MERGE_MODE = 'conv'

# Preprocessing
_C.PREPROCESS = CfgNode()
_C.PREPROCESS.UP_RATIO = 0.6 / 0.85  # delta_size / face_size
_C.PREPROCESS.DOWN_RATIO = 0.2 / 0.85  # delta_size / face_size
_C.PREPROCESS.WIDTH_RATIO = 0.2 / 0.85  # delta_size / face_size
_C.PREPROCESS.LIP_CLASS = [7, 9]
_C.PREPROCESS.FACE_CLASS = [1, 6]
_C.PREPROCESS.EYEBROW_CLASS = [2, 3]
_C.PREPROCESS.EYE_CLASS = [4, 5]
_C.PREPROCESS.LANDMARK_POINTS = 68

# Pseudo ground truth
_C.PGT = CfgNode()
_C.PGT.EYE_MARGIN = 12
_C.PGT.LIP_MARGIN = 4
_C.PGT.ANNEALING = True
_C.PGT.SKIN_ALPHA = 0.3
_C.PGT.SKIN_ALPHA_MILESTONES = (0, 12, 24, 50)
_C.PGT.SKIN_ALPHA_VALUES = (0.2, 0.4, 0.3, 0.2)
_C.PGT.EYE_ALPHA = 0.8
_C.PGT.EYE_ALPHA_MILESTONES = (0, 12, 24, 50)
_C.PGT.EYE_ALPHA_VALUES = (0.6, 0.8, 0.6, 0.4)
_C.PGT.LIP_ALPHA = 0.1
_C.PGT.LIP_ALPHA_MILESTONES = (0, 12, 24, 50)
_C.PGT.LIP_ALPHA_VALUES = (0.05, 0.2, 0.1, 0.0)

# Postprocessing
_C.POSTPROCESS = CfgNode()
_C.POSTPROCESS.WILL_DENOISE = False

def get_config() -> CfgNode:
    return _C
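Since `_C` is an fvcore `CfgNode`, the defaults above can be overridden by merging a YAML file on top of them; as the header comment notes, any key used this way must already be declared in this file. A hypothetical override file (illustrative path and values, not shipped with the repo):

```yaml
# hypothetical experiment config, e.g. configs/exp.yaml
DATA:
  BATCH_SIZE: 4
  IMG_SIZE: 256
TRAINING:
  NUM_EPOCHS: 80
MODEL:
  WINDOW_SIZE: 8
```

Merging (e.g. via `CfgNode.merge_from_file`) raises on keys not declared here, which is what the module docstring means by requiring every hyper-parameter to be defined in this file first.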


================================================
FILE: training/dataset.py
================================================
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader

from training.config import get_config
from training.preprocess import PreProcess

class MakeupDataset(Dataset):
    def __init__(self, config=None):
        super(MakeupDataset, self).__init__()
        if config is None:
            config = get_config()
        self.root = config.DATA.PATH
        with open(os.path.join(config.DATA.PATH, 'makeup.txt'), 'r') as f:
            self.makeup_names = [name.strip() for name in f.readlines()]
        with open(os.path.join(config.DATA.PATH, 'non-makeup.txt'), 'r') as f:
            self.non_makeup_names = [name.strip() for name in f.readlines()]
        self.preprocessor = PreProcess(config, need_parser=False)
        self.img_size = config.DATA.IMG_SIZE

    def load_from_file(self, img_name):
        image = Image.open(os.path.join(self.root, 'images', img_name)).convert('RGB')
        mask = self.preprocessor.load_mask(os.path.join(self.root, 'segs', img_name))
        base_name = os.path.splitext(img_name)[0]
        lms = self.preprocessor.load_lms(os.path.join(self.root, 'lms', f'{base_name}.npy'))
        return self.preprocessor.process(image, mask, lms)
    
    def __len__(self):
        return max(len(self.makeup_names), len(self.non_makeup_names))

    def __getitem__(self, index):
        idx_s = torch.randint(0, len(self.non_makeup_names), (1, )).item()
        idx_r = torch.randint(0, len(self.makeup_names), (1, )).item()
        name_s = self.non_makeup_names[idx_s]
        name_r = self.makeup_names[idx_r]
        source = self.load_from_file(name_s)
        reference = self.load_from_file(name_r)
        return source, reference

def get_loader(config):
    dataset = MakeupDataset(config)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=config.DATA.BATCH_SIZE,
                            num_workers=config.DATA.NUM_WORKERS)
    return dataloader


if __name__ == "__main__":
    dataset = MakeupDataset()
    dataloader = DataLoader(dataset, batch_size=1, num_workers=16)
    for e in range(10):
        for i, (point_s, point_r) in enumerate(dataloader):
            pass

================================================
FILE: training/inference.py
================================================
from typing import List
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.transforms import ToPILImage

from training.solver import Solver
from training.preprocess import PreProcess
from models.modules.pseudo_gt import expand_area, mask_blend

class InputSample:
    def __init__(self, inputs, apply_mask=None):
        self.inputs = inputs
        self.transfer_input = None
        self.attn_out_list = None
        self.apply_mask = apply_mask

    def clear(self):
        self.transfer_input = None
        self.attn_out_list = None


class Inference:
    """
    An inference wrapper for makeup transfer.
    It takes two images, `source` and `reference`,
    and transfers the makeup of the reference onto the source.
    """
    def __init__(self, config, args, model_path="G.pth"):

        self.device = args.device
        self.solver = Solver(config, args, inference=model_path)
        self.preprocess = PreProcess(config, device=args.device)
        self.denoise = config.POSTPROCESS.WILL_DENOISE
        self.img_size = config.DATA.IMG_SIZE
        # TODO: can be a hyper-parameter
        self.eyeblur = {'margin': 12, 'blur_size':7}

    def prepare_input(self, *data_inputs):
        """
        data_inputs: List[image, mask, diff, lms]
        """
        inputs = []
        for i in range(len(data_inputs)):
            inputs.append(data_inputs[i].to(self.device).unsqueeze(0))
        # prepare mask
        inputs[1] = torch.cat((inputs[1][:,0:1], inputs[1][:,1:].sum(dim=1, keepdim=True)), dim=1)
        return inputs

    def postprocess(self, source, crop_face, result):
        if crop_face is not None:
            source = source.crop(
                (crop_face.left(), crop_face.top(), crop_face.right(), crop_face.bottom()))
        source = np.array(source)
        result = np.array(result)

        height, width = source.shape[:2]
        small_source = cv2.resize(source, (self.img_size, self.img_size))
        laplacian_diff = source.astype(
            np.float64) - cv2.resize(small_source, (width, height)).astype(np.float64)
        result = (cv2.resize(result, (width, height)) +
                  laplacian_diff).round().clip(0, 255)

        result = result.astype(np.uint8)

        if self.denoise:
            result = cv2.fastNlMeansDenoisingColored(result)
        result = Image.fromarray(result).convert('RGB')
        return result
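The `postprocess` method above uses a Laplacian-residual trick: the network output carries the low-frequency makeup colors, while the high-frequency detail of the full-resolution source is added back on top. The sketch below isolates that idea with NumPy only; `downsample2`/`upsample2` are hypothetical stand-ins for the two `cv2.resize` calls (a 2x average pool and a nearest-neighbour upsample), not functions from this repository.

```python
import numpy as np

def downsample2(a):
    # hypothetical 2x average pooling, standing in for cv2.resize to (img_size, img_size)
    return a.reshape(a.shape[0] // 2, 2, a.shape[1] // 2, 2).mean(axis=(1, 3))

def upsample2(a):
    # hypothetical nearest-neighbour 2x upsample, standing in for cv2.resize back to (width, height)
    return a.repeat(2, axis=0).repeat(2, axis=1)

def transfer_details(source, result_small):
    # add the high-frequency residual of `source` onto the upsampled network output
    laplacian = source.astype(np.float64) - upsample2(downsample2(source))
    out = upsample2(result_small.astype(np.float64)) + laplacian
    return np.clip(np.rint(out), 0, 255).astype(np.uint8)
```

If the low-resolution result equals the downsampled source (i.e., the network changed nothing), the residual cancels exactly and the full-resolution source is recovered, which is what makes the trick safe for preserving skin texture.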

    
    def generate_source_sample(self, source_input):
        """
        source_input: List[image, mask, diff, lms]
        """
        source_input = self.prepare_input(*source_input)
        return InputSample(source_input)

    def generate_reference_sample(self, reference_input, apply_mask=None, 
                                  source_mask=None, mask_area=None, saturation=1.0):
        """
        All operations on the mask (partial region selection, saturation
        scaling, etc.) must ultimately be encoded in `apply_mask`.
        """
        if source_mask is not None and mask_area is not None:
            apply_mask = self.generate_partial_mask(source_mask, mask_area, saturation)
            apply_mask = apply_mask.unsqueeze(0).to(self.device)
        reference_input = self.prepare_input(*reference_input)
        
        if apply_mask is None:
            apply_mask = torch.ones(1, 1, self.img_size, self.img_size).to(self.device)
        return InputSample(reference_input, apply_mask)


    def generate_partial_mask(self, source_mask, mask_area='full', saturation=1.0):
        """
        source_mask: (C, H, W), lip, face, left eye, right eye
        return: apply_mask: (1, H, W)
        """
        assert mask_area in ['full', 'skin', 'lip', 'eye']
        if mask_area == 'full':
            return torch.sum(source_mask[0:2], dim=0, keepdim=True) * saturation
        elif mask_area == 'lip':
            return source_mask[0:1] * saturation
        elif mask_area == 'skin':
            mask_l_eye = expand_area(source_mask[2:3], self.eyeblur['margin']) #* source_mask[1:2]
            mask_r_eye = expand_area(source_mask[3:4], self.eyeblur['margin']) #* source_mask[1:2]
            mask_eye = mask_l_eye + mask_r_eye
            #mask_eye = mask_blend(mask_eye, 1.0, source_mask[1:2], blur_size=self.eyeblur['blur_size'])
            mask_eye = mask_blend(mask_eye, 1.0, blur_size=self.eyeblur['blur_size'])
            return source_mask[1:2] * (1 - mask_eye) * saturation
        elif mask_area == 'eye':
            mask_l_eye = expand_area(source_mask[2:3], self.eyeblur['margin']) #* source_mask[1:2]
            mask_r_eye = expand_area(source_mask[3:4], self.eyeblur['margin']) #* source_mask[1:2]
            mask_eye = mask_l_eye + mask_r_eye
            #mask_eye = mask_blend(mask_eye, saturation, source_mask[1:2], blur_size=self.eyeblur['blur_size'])
            mask_eye = mask_blend(mask_eye, saturation, blur_size=self.eyeblur['blur_size'])
            return mask_eye
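The `skin` branch above computes "face minus dilated eye regions": each eye mask is expanded by a margin, the two are merged, and the result is subtracted from the face mask. The following NumPy sketch mirrors that logic; `expand` is a naive square dilation standing in for the repository's `expand_area`, and the blur step of `mask_blend` is omitted.

```python
import numpy as np

def expand(mask, margin):
    # hypothetical square dilation standing in for expand_area (no blurring)
    out = np.zeros_like(mask)
    for y, x in zip(*np.nonzero(mask)):
        out[max(0, y - margin):y + margin + 1, max(0, x - margin):x + margin + 1] = 1
    return out

def skin_mask(face, l_eye, r_eye, margin=2):
    # skin = face area with the (dilated) eye regions carved out
    eye = np.clip(expand(l_eye, margin) + expand(r_eye, margin), 0, 1)
    return face * (1 - eye)
```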
  

    @torch.no_grad()
    def interface_transfer(self, source_sample: InputSample, reference_samples: List[InputSample]):
        """
        Input: a source sample and multiple reference samples
        Return: PIL.Image, the fused result
        """
        # encode source
        if source_sample.transfer_input is None:
            source_sample.transfer_input = self.solver.G.get_transfer_input(*source_sample.inputs)
        
        # encode references
        for r_sample in reference_samples:
            if r_sample.transfer_input is None:
                r_sample.transfer_input = self.solver.G.get_transfer_input(*r_sample.inputs, True)

        # self attention
        if source_sample.attn_out_list is None:
            source_sample.attn_out_list = self.solver.G.get_transfer_output(
                    *source_sample.transfer_input, *source_sample.transfer_input
                )
        
        # full transfer for each reference
        for r_sample in reference_samples:
            if r_sample.attn_out_list is None:
                r_sample.attn_out_list = self.solver.G.get_transfer_output(
                    *source_sample.transfer_input, *r_sample.transfer_input
                )

        # fusion
        # if the apply_mask is changed without changing source and references,
        # only the following steps are required
        fused_attn_out_list = []
        for i in range(len(source_sample.attn_out_list)):
            init_attn_out = torch.zeros_like(source_sample.attn_out_list[i], device=self.device)
            fused_attn_out_list.append(init_attn_out)
        apply_mask_sum = torch.zeros((1, 1, self.img_size, self.img_size), device=self.device)
        
        for r_sample in reference_samples:
            if r_sample.apply_mask is not None:
                apply_mask_sum += r_sample.apply_mask
                for i in range(len(source_sample.attn_out_list)):
                    feature_size = r_sample.attn_out_list[i].shape[2]
                    apply_mask = F.interpolate(r_sample.apply_mask, feature_size, mode='nearest')
                    fused_attn_out_list[i] += apply_mask * r_sample.attn_out_list[i]

        # self as reference
        source_apply_mask = 1 - apply_mask_sum.clamp(0, 1)
        for i in range(len(source_sample.attn_out_list)):
            feature_size = source_sample.attn_out_list[i].shape[2]
            apply_mask = F.interpolate(source_apply_mask, feature_size, mode='nearest')
            fused_attn_out_list[i] += apply_mask * source_sample.attn_out_list[i]

        # decode
        result = self.solver.G.decode(
            source_sample.transfer_input[0], fused_attn_out_list
        )
        result = self.solver.de_norm(result).squeeze(0)
        result = ToPILImage()(result.cpu())
        return result
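The fusion step in `interface_transfer` weights each reference's attention output by its `apply_mask` and lets the source's own self-attention output fill whatever area the reference masks leave uncovered. A minimal NumPy sketch of that weighting at a single feature scale (ignoring the per-scale `F.interpolate` of the masks):

```python
import numpy as np

def fuse(outs_ref, masks, out_self):
    # outs_ref: per-reference (H, W) attention outputs; masks: matching apply-masks
    # out_self: the source's self-attention output, used where no reference applies
    fused = np.zeros_like(out_self, dtype=np.float64)
    mask_sum = np.zeros_like(out_self, dtype=np.float64)
    for out, m in zip(outs_ref, masks):
        fused += m * out
        mask_sum += m
    fused += (1.0 - np.clip(mask_sum, 0, 1)) * out_self
    return fused
```

Wherever the reference masks sum to 1 the output is fully determined by the references; elsewhere it smoothly falls back to the source's own features, which is why changing only `apply_mask` does not require re-running the encoders.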

    
    def transfer(self, source: Image, reference: Image, postprocess=True):
        """
        Args:
            source (Image): The image to which the makeup will be transferred.
            reference (Image): The image containing the target makeup.
        Return:
            Image: The transferred image.
        """
        source_input, face, crop_face = self.preprocess(source)
        reference_input, _, _ = self.preprocess(reference)
        if not (source_input and reference_input):
            return None

        #source_sample = self.generate_source_sample(source_input)
        #reference_samples = [self.generate_reference_sample(reference_input)]
        #result = self.interface_transfer(source_sample, reference_samples)
        source_input = self.prepare_input(*source_input)
        reference_input = self.prepare_input(*reference_input)
        result = self.solver.test(*source_input, *reference_input)
        
        if not postprocess:
            return result
        else:
            return self.postprocess(source, crop_face, result)

    def joint_transfer(self, source: Image, reference_lip: Image, reference_skin: Image,
                       reference_eye: Image, postprocess=True):
        source_input, face, crop_face = self.preprocess(source)
        lip_input, _, _ = self.preprocess(reference_lip)
        skin_input, _, _ = self.preprocess(reference_skin)
        eye_input, _, _ = self.preprocess(reference_eye)
        if not (source_input and lip_input and skin_input and eye_input):
            return None

        source_mask = source_input[1]
        source_sample = self.generate_source_sample(source_input)
        reference_samples = [
            self.generate_reference_sample(lip_input, source_mask=source_mask, mask_area='lip'),
            self.generate_reference_sample(skin_input, source_mask=source_mask, mask_area='skin'),
            self.generate_reference_sample(eye_input, source_mask=source_mask, mask_area='eye')
        ]
        
        result = self.interface_transfer(source_sample, reference_samples)
        
        if not postprocess:
            return result
        else:
            return self.postprocess(source, crop_face, result)

================================================
FILE: training/preprocess.py
================================================
import os
import sys
import cv2
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms import functional
sys.path.append('.')

import faceutils as futils
from training.config import get_config

class PreProcess:

    def __init__(self, config, need_parser=True, device='cpu'):
        self.img_size = config.DATA.IMG_SIZE   
        self.device = device

        xs, ys = np.meshgrid(
            np.linspace(
                0, self.img_size - 1,
                self.img_size
            ),
            np.linspace(
                0, self.img_size - 1,
                self.img_size
            )
        )
        xs = xs[None].repeat(config.PREPROCESS.LANDMARK_POINTS, axis=0)
        ys = ys[None].repeat(config.PREPROCESS.LANDMARK_POINTS, axis=0)
        fix = np.concatenate([ys, xs], axis=0) 
        self.fix = torch.Tensor(fix) #(136, h, w)
        if need_parser:
            self.face_parse = futils.mask.FaceParser(device=device)

        self.up_ratio    = config.PREPROCESS.UP_RATIO
        self.down_ratio  = config.PREPROCESS.DOWN_RATIO
        self.width_ratio = config.PREPROCESS.WIDTH_RATIO
        self.lip_class   = config.PREPROCESS.LIP_CLASS
        self.face_class  = config.PREPROCESS.FACE_CLASS
        self.eyebrow_class  = config.PREPROCESS.EYEBROW_CLASS
        self.eye_class  = config.PREPROCESS.EYE_CLASS

        self.transform = transforms.Compose([
            transforms.Resize(config.DATA.IMG_SIZE),
            transforms.ToTensor(),
            transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])])
    
    ############################## Mask Process ##############################
    # mask attribute: 0:background 1:face 2:left-eyebrow 3:right-eyebrow 4:left-eye 5: right-eye 6: nose
    # 7: upper-lip 8: teeth 9: under-lip 10:hair 11: left-ear 12: right-ear 13: neck
    def mask_process(self, mask: torch.Tensor):
        '''
        mask: (1, h, w)
        '''        
        mask_lip = (mask == self.lip_class[0]).float() + (mask == self.lip_class[1]).float()
        mask_face = (mask == self.face_class[0]).float() + (mask == self.face_class[1]).float()

        #mask_eyebrow_left = (mask == self.eyebrow_class[0]).float()
        #mask_eyebrow_right = (mask == self.eyebrow_class[1]).float()
        mask_face += (mask == self.eyebrow_class[0]).float()
        mask_face += (mask == self.eyebrow_class[1]).float()

        mask_eye_left = (mask == self.eye_class[0]).float()
        mask_eye_right = (mask == self.eye_class[1]).float()

        #mask_list = [mask_lip, mask_face, mask_eyebrow_left, mask_eyebrow_right, mask_eye_left, mask_eye_right]
        mask_list = [mask_lip, mask_face, mask_eye_left, mask_eye_right]
        mask_aug = torch.cat(mask_list, 0) # (C, H, W)
        return mask_aug      
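`mask_process` turns the parser's integer label map into four binary channels (lip, face incl. eyebrows, left eye, right eye). The sketch below reproduces that channel layout with NumPy; the class ids are hypothetical examples following the attribute table in the comment above, whereas the real ids come from `config.PREPROCESS.*_CLASS`.

```python
import numpy as np

def split_mask(mask, lip_classes=(7, 9), face_classes=(1, 6),
               brow_classes=(2, 3), eye_classes=(4, 5)):
    # hypothetical class ids; real values are taken from the config
    lip = np.isin(mask, lip_classes).astype(np.float32)
    face = np.isin(mask, face_classes + brow_classes).astype(np.float32)  # eyebrows merged into face
    l_eye = (mask == eye_classes[0]).astype(np.float32)
    r_eye = (mask == eye_classes[1]).astype(np.float32)
    return np.stack([lip, face, l_eye, r_eye])  # (4, H, W)
```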

    def save_mask(self, mask: torch.Tensor, path):
        assert mask.shape[0] == 1
        mask = mask.squeeze(0).numpy().astype(np.uint8)
        mask = Image.fromarray(mask)
        mask.save(path)

    def load_mask(self, path):
        mask = np.array(Image.open(path).convert('L'))
        mask = torch.FloatTensor(mask).unsqueeze(0)
        mask = functional.resize(mask, self.img_size, transforms.InterpolationMode.NEAREST)
        return mask
    
    ############################## Landmarks Process ##############################
    def lms_process(self, image:Image):
        face = futils.dlib.detect(image)
        # face: list of detected face rectangles, each [(left, top), (right, bottom)]
        if not face:
            return None
        face = face[0]
        lms = futils.dlib.landmarks(image, face) * self.img_size / image.width # scale to fit self.img_size
        # lms: ndarray, the positions of the 68 key points, (68, 2)
        lms = torch.IntTensor(lms.round()).clamp_max_(self.img_size - 1)
        # distinguish upper and lower lips 
        lms[61:64,0] -= 1; lms[65:68,0] += 1
        for i in range(3):
            if torch.sum(torch.abs(lms[61+i] - lms[67-i])) == 0:
                lms[61+i,0] -= 1;  lms[67-i,0] += 1
        # double check
        '''for i in range(48, 67):
            for j in range(i+1, 68):
                if torch.sum(torch.abs(lms[i] - lms[j])) == 0:
                    lms[i,0] -= 1; lms[j,0] += 1'''
        return lms       
    
    def diff_process(self, lms: torch.Tensor, normalize=False):
        '''
        lms:(68, 2)
        '''
        lms = lms.transpose(1, 0).reshape(-1, 1, 1) # (136, 1, 1)
        diff = self.fix - lms # (136, h, w)

        if normalize:
            norm = torch.norm(diff, dim=0, keepdim=True).repeat(diff.shape[0], 1, 1)
            norm = torch.where(norm == 0, torch.tensor(1e10), norm)
            diff /= norm
        return diff
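`diff_process` subtracts each landmark from a fixed coordinate grid, yielding one offset map per landmark coordinate (the `(136, h, w)` tensor built in `__init__`). A NumPy sketch of the same construction, assuming each landmark row is ordered `(row, col)`:

```python
import numpy as np

def landmark_diff(lms, size):
    # lms: (K, 2) landmarks, assumed (row, col); returns (2K, size, size) offset maps
    xs, ys = np.meshgrid(np.arange(size), np.arange(size))
    K = lms.shape[0]
    grid = np.concatenate([np.broadcast_to(ys, (K, size, size)),
                           np.broadcast_to(xs, (K, size, size))], axis=0).astype(np.float64)
    # first K channels measure row offsets, last K channels column offsets
    return grid - lms.T.reshape(-1, 1, 1)
```

Each channel is zero exactly on the row (or column) of its landmark, so the maps encode every pixel's displacement from every landmark.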

    def save_lms(self, lms: torch.Tensor, path):
        lms = lms.numpy()
        np.save(path, lms)
    
    def load_lms(self, path):
        lms = np.load(path)
        return torch.IntTensor(lms)

    ############################## Compose Process ##############################
    def preprocess(self, image: Image, is_crop=True):
        '''
        return: ([image, mask, lms], face_on_image, crop_face)
        image: Image, cropped face; mask: Tensor, (1, H, W); lms: IntTensor, (68, 2)
        '''
        face = futils.dlib.detect(image)
        # face: list of detected face rectangles, each [(left, top), (right, bottom)]
        if not face:
            return None, None, None

        face_on_image = face[0]
        if is_crop:
            image, face, crop_face = futils.dlib.crop(
                image, face_on_image, self.up_ratio, self.down_ratio, self.width_ratio)
        else:
            face = face[0]; crop_face = None
        # image: Image, cropped face
        # face: the same as above
        # crop face: rectangle, face region in cropped face
        np_image = np.array(image) # (h', w', 3)

        mask = self.face_parse.parse(cv2.resize(np_image, (512, 512))).cpu()
        # obtain face parsing result
        # mask: Tensor, (512, 512)
        mask = F.interpolate(
            mask.view(1, 1, 512, 512),
            (self.img_size, self.img_size),
            mode="nearest").squeeze(0).long() #(1, H, W)

        lms = futils.dlib.landmarks(image, face) * self.img_size / image.width # scale to fit self.img_size
        # lms: ndarray, the positions of the 68 key points, (68, 2)
        lms = torch.IntTensor(lms.round()).clamp_max_(self.img_size - 1)
        # distinguish upper and lower lips 
        lms[61:64,0] -= 1; lms[65:68,0] += 1
        for i in range(3):
            if torch.sum(torch.abs(lms[61+i] - lms[67-i])) == 0:
                lms[61+i,0] -= 1;  lms[67-i,0] += 1

        image = image.resize((self.img_size, self.img_size), Image.LANCZOS)
        return [image, mask, lms], face_on_image, crop_face
    
    def process(self, image: Image, mask: torch.Tensor, lms: torch.Tensor):
        image = self.transform(image)
        mask = self.mask_process(mask)
        diff = self.diff_process(lms)
        return [image, mask, diff, lms]
    
    def __call__(self, image:Image, is_crop=True):
        source, face_on_image, crop_face = self.preprocess(image, is_crop)
        if source is None:
            return None, None, None
        return self.process(*source), face_on_image, crop_face


if __name__ == "__main__":
    config = get_config()
    preprocessor = PreProcess(config, device='cuda:0')
    os.makedirs(os.path.join(config.DATA.PATH, 'lms', 'makeup'), exist_ok=True)
    os.makedirs(os.path.join(config.DATA.PATH, 'lms', 'non-makeup'), exist_ok=True)
    
    # process makeup and non-makeup images
    for list_name in ('makeup.txt', 'non-makeup.txt'):
        print(f"Processing {os.path.splitext(list_name)[0]} images...")
        with open(os.path.join(config.DATA.PATH, list_name), 'r') as f:
            for line in f.readlines():
                img_name = line.strip()
                raw_image = Image.open(os.path.join(config.DATA.PATH, 'images', img_name)).convert('RGB')
                lms = preprocessor.lms_process(raw_image)
                if lms is not None:
                    base_name = os.path.splitext(img_name)[0]
                    preprocessor.save_lms(lms, os.path.join(config.DATA.PATH, 'lms', f'{base_name}.npy'))
        print("Done.")
    

================================================
FILE: training/solver.py
================================================
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import ToPILImage
from torchvision.utils import save_image, make_grid
import torch.nn.init as init
from tqdm import tqdm

from models.modules.pseudo_gt import expand_area
from models.model import get_discriminator, get_generator, vgg16
from models.loss import GANLoss, MakeupLoss, ComposePGT, AnnealingComposePGT

from training.utils import plot_curves

class Solver():
    def __init__(self, config, args, logger=None, inference=False):
        self.G = get_generator(config)
        if inference:
            self.G.load_state_dict(torch.load(inference, map_location=args.device))
            self.G = self.G.to(args.device).eval()
            return
        self.double_d = config.TRAINING.DOUBLE_D
        self.D_A = get_discriminator(config)
        if self.double_d:
            self.D_B = get_discriminator(config)
        
        self.load_folder = args.load_folder
        self.save_folder = args.save_folder
        self.vis_folder = os.path.join(args.save_folder, 'visualization')
        if not os.path.exists(self.vis_folder):
            os.makedirs(self.vis_folder)
        self.vis_freq = config.LOG.VIS_FREQ
        self.save_freq = config.LOG.SAVE_FREQ

        # Data & PGT
        self.img_size = config.DATA.IMG_SIZE
        self.margins = {'eye':config.PGT.EYE_MARGIN,
                        'lip':config.PGT.LIP_MARGIN}
        self.pgt_annealing = config.PGT.ANNEALING
        if self.pgt_annealing:
            self.pgt_maker = AnnealingComposePGT(self.margins, 
                config.PGT.SKIN_ALPHA_MILESTONES, config.PGT.SKIN_ALPHA_VALUES,
                config.PGT.EYE_ALPHA_MILESTONES, config.PGT.EYE_ALPHA_VALUES,
                config.PGT.LIP_ALPHA_MILESTONES, config.PGT.LIP_ALPHA_VALUES
            )
        else:
            self.pgt_maker = ComposePGT(self.margins, 
                config.PGT.SKIN_ALPHA,
                config.PGT.EYE_ALPHA,
                config.PGT.LIP_ALPHA
            )
        self.pgt_maker.eval()

        # Hyper-param
        self.num_epochs = config.TRAINING.NUM_EPOCHS
        self.g_lr = config.TRAINING.G_LR
        self.d_lr = config.TRAINING.D_LR
        self.beta1 = config.TRAINING.BETA1
        self.beta2 = config.TRAINING.BETA2
        self.lr_decay_factor = config.TRAINING.LR_DECAY_FACTOR

        # Loss param
        self.lambda_idt      = config.LOSS.LAMBDA_IDT
        self.lambda_A        = config.LOSS.LAMBDA_A
        self.lambda_B        = config.LOSS.LAMBDA_B
        self.lambda_lip  = config.LOSS.LAMBDA_MAKEUP_LIP
        self.lambda_skin = config.LOSS.LAMBDA_MAKEUP_SKIN
        self.lambda_eye  = config.LOSS.LAMBDA_MAKEUP_EYE
        self.lambda_vgg      = config.LOSS.LAMBDA_VGG

        self.device = args.device
        self.keepon = args.keepon
        self.logger = logger
        self.loss_logger = {
            'D-A-loss_real':[],
            'D-A-loss_fake':[],
            'D-B-loss_real':[],
            'D-B-loss_fake':[],
            'G-A-loss-adv':[],
            'G-B-loss-adv':[],
            'G-loss-idt':[],
            'G-loss-img-rec':[],
            'G-loss-vgg-rec':[],
            'G-loss-rec':[],
            'G-loss-skin-pgt':[],
            'G-loss-eye-pgt':[],
            'G-loss-lip-pgt':[],
            'G-loss-pgt':[],
            'G-loss':[],
            'D-A-loss':[],
            'D-B-loss':[]
        }

        self.build_model()
        super(Solver, self).__init__()

    def print_network(self, model, name):
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        if self.logger is not None:
            self.logger.info('{:s}, the number of parameters: {:d}'.format(name, num_params))
        else:
            print('{:s}, the number of parameters: {:d}'.format(name, num_params))
    
    # For generator
    def weights_init_xavier(self, m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            init.xavier_normal_(m.weight.data, gain=1.0)
        elif classname.find('Linear') != -1:
            init.xavier_normal_(m.weight.data, gain=1.0)

    def build_model(self):
        self.G.apply(self.weights_init_xavier)
        self.D_A.apply(self.weights_init_xavier)
        if self.double_d:
            self.D_B.apply(self.weights_init_xavier)
        if self.keepon:
            self.load_checkpoint()
        
        self.criterionL1 = torch.nn.L1Loss()
        self.criterionL2 = torch.nn.MSELoss()
        self.criterionGAN = GANLoss(gan_mode='lsgan')
        self.criterionPGT = MakeupLoss()
        self.vgg = vgg16(pretrained=True)

        # Optimizers
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_A_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D_A.parameters()), self.d_lr, [self.beta1, self.beta2])
        if self.double_d:
            self.d_B_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D_B.parameters()), self.d_lr, [self.beta1, self.beta2])
        self.g_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.g_optimizer, 
                    T_max=self.num_epochs, eta_min=self.g_lr * self.lr_decay_factor)
        self.d_A_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.d_A_optimizer, 
                    T_max=self.num_epochs, eta_min=self.d_lr * self.lr_decay_factor)
        if self.double_d:
            self.d_B_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.d_B_optimizer, 
                    T_max=self.num_epochs, eta_min=self.d_lr * self.lr_decay_factor)

        # Print networks
        self.print_network(self.G, 'G')
        self.print_network(self.D_A, 'D_A')
        if self.double_d: self.print_network(self.D_B, 'D_B')

        self.G.to(self.device)
        self.vgg.to(self.device)
        self.D_A.to(self.device)
        if self.double_d: self.D_B.to(self.device)
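    # The CosineAnnealingLR schedulers above decay each learning rate from its base
    # value down to base_lr * LR_DECAY_FACTOR (the eta_min passed in) over NUM_EPOCHS.
    # A closed-form sketch of that schedule, with placeholder values standing in for
    # the config entries:
    #
    #   import math
    #
    #   def cosine_lr(t, T_max, base_lr, eta_min):
    #       # eta(t) = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2
    #       return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t / T_max))
    #
    # cosine_lr(0, ...) returns base_lr and cosine_lr(T_max, ...) returns eta_min,
    # decreasing monotonically in between.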

    def train(self, data_loader):
        self.len_dataset = len(data_loader)
        
        for self.epoch in range(1, self.num_epochs + 1):
            self.start_time = time.time()
            loss_tmp = self.get_loss_tmp()
            self.G.train(); self.D_A.train(); 
            if self.double_d: self.D_B.train()
            losses_G = []; losses_D_A = []; losses_D_B = []
            
            with tqdm(data_loader, desc="training") as pbar:
                for step, (source, reference) in enumerate(pbar):
                    # image, mask, diff, lms
                    image_s, image_r = source[0].to(self.device), reference[0].to(self.device) # (b, c, h, w)
                    mask_s_full, mask_r_full = source[1].to(self.device), reference[1].to(self.device) # (b, c', h, w) 
                    diff_s, diff_r = source[2].to(self.device), reference[2].to(self.device) # (b, 136, h, w)
                    lms_s, lms_r = source[3].to(self.device), reference[3].to(self.device) # (b, K, 2)

                    # process input mask
                    mask_s = torch.cat((mask_s_full[:,0:1], mask_s_full[:,1:].sum(dim=1, keepdim=True)), dim=1)
                    mask_r = torch.cat((mask_r_full[:,0:1], mask_r_full[:,1:].sum(dim=1, keepdim=True)), dim=1)
                    #mask_s = mask_s_full[:,:2]; mask_r = mask_r_full[:,:2]

                    # ================= Generate ================== #
                    fake_A = self.G(image_s, image_r, mask_s, mask_r, diff_s, diff_r, lms_s, lms_r)
                    fake_B = self.G(image_r, image_s, mask_r, mask_s, diff_r, diff_s, lms_r, lms_s)

                    # generate pseudo ground truth
                    pgt_A = self.pgt_maker(image_s, image_r, mask_s_full, mask_r_full, lms_s, lms_r)
                    pgt_B = self.pgt_maker(image_r, image_s, mask_r_full, mask_s_full, lms_r, lms_s)
                    
                    # ================== Train D ================== #
                    # training D_A, D_A aims to distinguish class B
                    # Real
                    out = self.D_A(image_r)
                    d_loss_real = self.criterionGAN(out, True)
                    # Fake
                    out = self.D_A(fake_A.detach())
                    d_loss_fake = self.criterionGAN(out, False)

                    # Backward + Optimize
                    d_loss = (d_loss_real + d_loss_fake) * 0.5
                    self.d_A_optimizer.zero_grad()
                    d_loss.backward()
                    self.d_A_optimizer.step()                   

                    # Logging
                    loss_tmp['D-A-loss_real'] += d_loss_real.item()
                    loss_tmp['D-A-loss_fake'] += d_loss_fake.item()
                    losses_D_A.append(d_loss.item())

                    # training D_B, D_B aims to distinguish class A
                    # Real
                    if self.double_d:
                        out = self.D_B(image_s)
                    else:
                        out = self.D_A(image_s)
                    d_loss_real = self.criterionGAN(out, True)
                    # Fake
                    if self.double_d:
                        out = self.D_B(fake_B.detach())
                    else:
                        out = self.D_A(fake_B.detach())
                    d_loss_fake = self.criterionGAN(out, False)

                    # Backward + Optimize
                    d_loss = (d_loss_real + d_loss_fake) * 0.5
                    if self.double_d:
                        self.d_B_optimizer.zero_grad()
                        d_loss.backward()
                        self.d_B_optimizer.step()
                    else:
                        self.d_A_optimizer.zero_grad()
                        d_loss.backward()
                        self.d_A_optimizer.step()

                    # Logging
                    loss_tmp['D-B-loss_real'] += d_loss_real.item()
                    loss_tmp['D-B-loss_fake'] += d_loss_fake.item()
                    losses_D_B.append(d_loss.item())
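                    # The discriminator updates above use GANLoss(gan_mode='lsgan'),
                    # i.e. least-squares targets: real predictions are pushed toward 1,
                    # fakes toward 0, and the two halves are averaged. A minimal NumPy
                    # sketch of that objective (an illustration, not the repo's GANLoss):
                    #
                    #   import numpy as np
                    #
                    #   def lsgan_d_loss(pred_real, pred_fake):
                    #       loss_real = np.mean((pred_real - 1.0) ** 2)
                    #       loss_fake = np.mean(pred_fake ** 2)
                    #       return 0.5 * (loss_real + loss_fake)
                    #
                    # A perfect discriminator (real -> 1, fake -> 0) scores 0.0.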

                    # ================== Train G ================== #
                    
                    # G should reconstruct its input when source and reference are the same image
                    idt_A = self.G(image_s, image_s, mask_s, mask_s, diff_s, diff_s, lms_s, lms_s)
                    idt_B = self.G(image_r, image_r, mask_r, mask_r, diff_r, diff_r, lms_r, lms_r)
                    loss_idt_A = self.criterionL1(idt_A, image_s) * self.lambda_A * self.lambda_idt
                    loss_idt_B = self.criterionL1(idt_B, image_r) * self.lambda_B * self.lambda_idt
                    # loss_idt
                    loss_idt = (loss_idt_A + loss_idt_B) * 0.5

                    # GAN loss D_A(G_A(A))
                    pred_fake = self.D_A(fake_A)
                    g_A_loss_adv = self.criterionGAN(pred_fake, True)

                    # GAN loss D_B(G_B(B))
                    if self.double_d:
                        pred_fake = self.D_B(fake_B)
                    else:
                        pred_fake = self.D_A(fake_B)
                    g_B_loss_adv = self.criterionGAN(pred_fake, True)
                    
                    # Makeup loss
                    g_A_loss_pgt = 0; g_B_loss_pgt = 0
                    
                    g_A_lip_loss_pgt = self.criterionPGT(fake_A, pgt_A, mask_s_full[:,0:1]) * self.lambda_lip
                    g_B_lip_loss_pgt = self.criterionPGT(fake_B, pgt_B, mask_r_full[:,0:1]) * self.lambda_lip
                    g_A_loss_pgt += g_A_lip_loss_pgt
                    g_B_loss_pgt += g_B_lip_loss_pgt

                    mask_s_eye = expand_area(mask_s_full[:,2:4].sum(dim=1, keepdim=True), self.margins['eye'])
                    mask_r_eye = expand_area(mask_r_full[:,2:4].sum(dim=1, keepdim=True), self.margins['eye'])
                    mask_s_eye = mask_s_eye * mask_s_full[:,1:2]
                    mask_r_eye = mask_r_eye * mask_r_full[:,1:2]
                    g_A_eye_loss_pgt = self.criterionPGT(fake_A, pgt_A, mask_s_eye) * self.lambda_eye
                    g_B_eye_loss_pgt = self.criterionPGT(fake_B, pgt_B, mask_r_eye) * self.lambda_eye
                    g_A_loss_pgt += g_A_eye_loss_pgt
                    g_B_loss_pgt += g_B_eye_loss_pgt
                    
                    mask_s_skin = mask_s_full[:,1:2] * (1 - mask_s_eye)
                    mask_r_skin = mask_r_full[:,1:2] * (1 - mask_r_eye)
                    g_A_skin_loss_pgt = self.criterionPGT(fake_A, pgt_A, mask_s_skin) * self.lambda_skin
                    g_B_skin_loss_pgt = self.criterionPGT(fake_B, pgt_B, mask_r_skin) * self.lambda_skin
                    g_A_loss_pgt += g_A_skin_loss_pgt
                    g_B_loss_pgt += g_B_skin_loss_pgt
                    
                    # cycle loss
                    rec_A = self.G(fake_A, image_s, mask_s, mask_s, diff_s, diff_s, lms_s, lms_s)
                    rec_B = self.G(fake_B, image_r, mask_r, mask_r, diff_r, diff_r, lms_r, lms_r)

                    # cycle loss v2
                    # rec_A = self.G(fake_A, fake_B, mask_s, mask_r, diff_s, diff_r, lms_s, lms_r)
                    # rec_B = self.G(fake_B, fake_A, mask_r, mask_s, diff_r, diff_s, lms_r, lms_s)

                    g_loss_rec_A = self.criterionL1(rec_A, image_s) * self.lambda_A
                    g_loss_rec_B = self.criterionL1(rec_B, image_r) * self.lambda_B

                    # vgg loss
                    vgg_s = self.vgg(image_s).detach()
                    vgg_fake_A = self.vgg(fake_A)
                    g_loss_A_vgg = self.criterionL2(vgg_fake_A, vgg_s) * self.lambda_A * self.lambda_vgg

                    vgg_r = self.vgg(image_r).detach()
                    vgg_fake_B = self.vgg(fake_B)
                    g_loss_B_vgg = self.criterionL2(vgg_fake_B, vgg_r) * self.lambda_B * self.lambda_vgg

                    loss_rec = (g_loss_rec_A + g_loss_rec_B + g_loss_A_vgg + g_loss_B_vgg) * 0.5

                    # Combined loss
                    g_loss = g_A_loss_adv + g_B_loss_adv + loss_rec + loss_idt + g_A_loss_pgt + g_B_loss_pgt

                    self.g_optimizer.zero_grad()
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Logging
                    loss_tmp['G-A-loss-adv'] += g_A_loss_adv.item()
                    loss_tmp['G-B-loss-adv'] += g_B_loss_adv.item()
                    loss_tmp['G-loss-idt'] += loss_idt.item()
                    loss_tmp['G-loss-img-rec'] += (g_loss_rec_A + g_loss_rec_B).item() * 0.5
                    loss_tmp['G-loss-vgg-rec'] += (g_loss_A_vgg + g_loss_B_vgg).item() * 0.5
                    loss_tmp['G-loss-rec'] += loss_rec.item()
                    loss_tmp['G-loss-skin-pgt'] += (g_A_skin_loss_pgt + g_B_skin_loss_pgt).item()
                    loss_tmp['G-loss-eye-pgt'] += (g_A_eye_loss_pgt + g_B_eye_loss_pgt).item()
                    loss_tmp['G-loss-lip-pgt'] += (g_A_lip_loss_pgt + g_B_lip_loss_pgt).item()
                    loss_tmp['G-loss-pgt'] += (g_A_loss_pgt + g_B_loss_pgt).item()
                    losses_G.append(g_loss.item())
                    pbar.set_description("Epoch: %d, Step: %d, Loss_G: %0.4f, Loss_A: %0.4f, Loss_B: %0.4f" % \
                                (self.epoch, step + 1, np.mean(losses_G), np.mean(losses_D_A), np.mean(losses_D_B)))

            self.end_time = time.time()
            for k, v in loss_tmp.items():
                loss_tmp[k] = v / self.len_dataset  
            loss_tmp['G-loss'] = np.mean(losses_G)
            loss_tmp['D-A-loss'] = np.mean(losses_D_A)
            loss_tmp['D-B-loss'] = np.mean(losses_D_B)
            self.log_loss(loss_tmp)
            self.plot_loss()

            # Decay learning rate
            self.g_scheduler.step()
            self.d_A_scheduler.step()
            if self.double_d:
                self.d_B_scheduler.step()

            if self.pgt_annealing:
                self.pgt_maker.step()

            #save the images
            if (self.epoch) % self.vis_freq == 0:
                self.vis_train([image_s.detach().cpu(), 
                                image_r.detach().cpu(), 
                                fake_A.detach().cpu(), 
                                pgt_A.detach().cpu()])
            #                   rec_A.detach().cpu()])

            # Save model checkpoints
            if (self.epoch) % self.save_freq == 0:
                self.save_models()
   

    def get_loss_tmp(self):
        loss_tmp = {
            'D-A-loss_real':0.0,
            'D-A-loss_fake':0.0,
            'D-B-loss_real':0.0,
            'D-B-loss_fake':0.0,
            'G-A-loss-adv':0.0,
            'G-B-loss-adv':0.0,
            'G-loss-idt':0.0,
            'G-loss-img-rec':0.0,
            'G-loss-vgg-rec':0.0,
            'G-loss-rec':0.0,
            'G-loss-skin-pgt':0.0,
            'G-loss-eye-pgt':0.0,
            'G-loss-lip-pgt':0.0,
            'G-loss-pgt':0.0,
        }
        return loss_tmp

    def log_loss(self, loss_tmp):
        if self.logger is not None:
            self.logger.info('\n' + '='*40 + '\nEpoch {:d}, time {:.2f} s'
                            .format(self.epoch, self.end_time - self.start_time))
        else:
            print('\n' + '='*40 + '\nEpoch {:d}, time {:.2f} s'
                    .format(self.epoch, self.end_time - self.start_time))
        for k, v in loss_tmp.items():
            self.loss_logger[k].append(v)
            if self.logger is not None:
                self.logger.info('{:s}\t{:.6f}'.format(k, v))  
            else:
                print('{:s}\t{:.6f}'.format(k, v))  
        if self.logger is not None:
            self.logger.info('='*40)  
        else:
            print('='*40)

    def plot_loss(self):
        G_losses = []; G_names = []
        D_A_losses = []; D_A_names = []
        D_B_losses = []; D_B_names = []
        D_P_losses = []; D_P_names = []
        for k, v in self.loss_logger.items():
            if 'G' in k:
                G_names.append(k); G_losses.append(v)
            elif 'D-A' in k:
                D_A_names.append(k); D_A_losses.append(v)
            elif 'D-B' in k:
                D_B_names.append(k); D_B_losses.append(v)
            elif 'D-P' in k:
                D_P_names.append(k); D_P_losses.append(v)
        plot_curves(self.save_folder, 'G_loss', G_losses, G_names, ylabel='Loss')
        plot_curves(self.save_folder, 'D-A_loss', D_A_losses, D_A_names, ylabel='Loss')
        plot_curves(self.save_folder, 'D-B_loss', D_B_losses, D_B_names, ylabel='Loss')

    def load_checkpoint(self):
        G_path = os.path.join(self.load_folder, 'G.pth')
        if os.path.exists(G_path):
            self.G.load_state_dict(torch.load(G_path, map_location=self.device))
            print('loaded trained generator {}..!'.format(G_path))
        D_A_path = os.path.join(self.load_folder, 'D_A.pth')
        if os.path.exists(D_A_path):
            self.D_A.load_state_dict(torch.load(D_A_path, map_location=self.device))
            print('loaded trained discriminator A {}..!'.format(D_A_path))

        if self.double_d:
            D_B_path = os.path.join(self.load_folder, 'D_B.pth')
            if os.path.exists(D_B_path):
                self.D_B.load_state_dict(torch.load(D_B_path, map_location=self.device))
                print('loaded trained discriminator B {}..!'.format(D_B_path))
    
    def save_models(self):
        save_dir = os.path.join(self.save_folder, 'epoch_{:d}'.format(self.epoch))
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        torch.save(self.G.state_dict(), os.path.join(save_dir, 'G.pth'))
        torch.save(self.D_A.state_dict(), os.path.join(save_dir, 'D_A.pth'))
        if self.double_d:
            torch.save(self.D_B.state_dict(), os.path.join(save_dir, 'D_B.pth'))

    def de_norm(self, x):
        out = (x + 1) / 2
        return out.clamp(0, 1)
    
    def vis_train(self, img_train_batch):
        # saving training results
        img_train_batch = torch.cat(img_train_batch, dim=3)
        save_path = os.path.join(self.vis_folder, 'epoch_{:d}_fake.png'.format(self.epoch))
        vis_image = make_grid(self.de_norm(img_train_batch), 1)
        save_image(vis_image, save_path) #, normalize=True)

    def generate(self, image_A, image_B, mask_A=None, mask_B=None, 
                 diff_A=None, diff_B=None, lms_A=None, lms_B=None):
        """image_A is content, image_B is style"""
        with torch.no_grad():
            res = self.G(image_A, image_B, mask_A, mask_B, diff_A, diff_B, lms_A, lms_B)
        return res

    def test(self, image_A, mask_A, diff_A, lms_A, image_B, mask_B, diff_B, lms_B):        
        with torch.no_grad():
            fake_A = self.generate(image_A, image_B, mask_A, mask_B, diff_A, diff_B, lms_A, lms_B)
        fake_A = self.de_norm(fake_A)
        fake_A = fake_A.squeeze(0)
        return ToPILImage()(fake_A.cpu())
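
The `de_norm` helper above undoes the `[-1, 1]` normalization applied to network inputs before an image is saved or displayed. A minimal scalar sketch of the same mapping (plain Python rather than the tensor version, so the clamping behavior is easy to verify in isolation):

```python
def de_norm_scalar(v: float) -> float:
    # Map a value from the network's [-1, 1] output range to the image
    # range [0, 1], clamping anything outside the expected input range
    # (mirrors `(x + 1) / 2` followed by `.clamp(0, 1)` in the solver).
    return min(max((v + 1.0) / 2.0, 0.0), 1.0)

print(de_norm_scalar(0.0))   # 0.5
print(de_norm_scalar(-1.0))  # 0.0
print(de_norm_scalar(1.5))   # 1.0 (clamped)
```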

================================================
FILE: training/utils.py
================================================
import os
import logging
import numpy as np
import matplotlib.pyplot as plt

def create_logger(save_path='', file_type='', level='debug', console=True):
    if level == 'debug':
        _level = logging.DEBUG
    elif level == 'info':
        _level = logging.INFO
    else:
        # Fall back to INFO for unrecognized level strings instead of
        # leaving _level undefined (which would raise NameError below).
        _level = logging.INFO

    logger = logging.getLogger()
    logger.setLevel(_level)

    if console:
        cs = logging.StreamHandler()
        cs.setLevel(_level)
        logger.addHandler(cs)

    if save_path != '':
        file_name = os.path.join(save_path, file_type + '_log.txt')
        fh = logging.FileHandler(file_name, mode='w')
        fh.setLevel(_level)

        logger.addHandler(fh)

    return logger

def print_args(args, logger=None):
    for k, v in vars(args).items():
        if logger is not None:
            logger.info('{:<16} : {}'.format(k, v))
        else:
            print('{:<16} : {}'.format(k, v))

def plot_single_curve(path, name, point, freq=1, xlabel='Epoch', ylabel=None):
    x = (np.arange(len(point)) + 1) * freq
    plt.plot(x, point, color='purple')
    plt.xlabel(xlabel)
    if ylabel is None:
        ylabel = name
    plt.ylabel(ylabel)
    plt.savefig(os.path.join(path, name + '.png'))
    plt.close()

def plot_curves(path, name, point_list, curve_names=None, freq=1, xlabel='Epoch', ylabel=None):
    if curve_names is None:
        curve_names = [''] * len(point_list)
    else:
        assert len(point_list) == len(curve_names)

    x = (np.arange(len(point_list[0])) + 1) * freq
    if len(point_list) <= 10:
        cmap = plt.get_cmap('tab10')
    else:
        cmap = plt.get_cmap('tab20')
    for i, (point, curve_name) in enumerate(zip(point_list, curve_names)):
        assert len(point) == len(x)
        plt.plot(x, point, color=cmap(i), label=curve_name)
        
    plt.xlabel(xlabel)
    if ylabel is not None:
        plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(os.path.join(path, name + '.png'))
    plt.close()
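
`create_logger` attaches an optional console handler and, when `save_path` is given, writes a `<file_type>_log.txt` file in write mode. The sketch below is a stripped-down, self-contained mirror of that function (the logger name `'demo'` and the use of a temporary directory are illustrative, not part of the repository code), showing that a logged message lands in the expected file:

```python
import logging
import os
import tempfile

def create_logger(save_path='', file_type='', level='debug', console=True):
    # Minimal mirror of training/utils.py: set the level, then attach an
    # optional console handler and an optional <file_type>_log.txt file handler.
    _level = logging.DEBUG if level == 'debug' else logging.INFO
    logger = logging.getLogger('demo')
    logger.setLevel(_level)
    if console:
        logger.addHandler(logging.StreamHandler())
    if save_path != '':
        file_name = os.path.join(save_path, file_type + '_log.txt')
        fh = logging.FileHandler(file_name, mode='w')
        fh.setLevel(_level)
        logger.addHandler(fh)
    return logger

with tempfile.TemporaryDirectory() as d:
    log = create_logger(save_path=d, file_type='train', console=False)
    log.info('epoch 1 done')
    for h in log.handlers:      # close handlers so the file is flushed
        h.close()
    with open(os.path.join(d, 'train_log.txt')) as f:
        contents = f.read()

print('epoch 1 done' in contents)  # True
```

Note that the real `create_logger` configures the root logger (`logging.getLogger()` with no name), so handlers accumulate if it is called more than once per process; the training scripts call it a single time at startup.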
SYMBOL INDEX (232 symbols across 25 files)

FILE: concern/image.py
  function load_image (line 6) | def load_image(path):
  function resize_by_max (line 15) | def resize_by_max(image, max_side=512, force=False):
  function image2buffer (line 25) | def image2buffer(image):

FILE: concern/track.py
  class Track (line 6) | class Track:
    method __init__ (line 7) | def __init__(self):
    method track (line 11) | def track(self, mark):

FILE: concern/visualize.py
  function channel_first (line 5) | def channel_first(image, format):
  function mask2image (line 9) | def mask2image(mask:np.array, format="HWC"):
  function draw_points (line 18) | def draw_points(image, points, color=(255, 0, 0)):

FILE: faceutils/dlibutils/main.py
  function detect (line 15) | def detect(image: Image) -> 'faces':
  function crop (line 33) | def crop(image: Image, face, up_ratio, down_ratio, width_ratio) -> (Imag...
  function crop_by_image_size (line 84) | def crop_by_image_size(image: Image, face) -> (Image, 'face'):
  function landmarks (line 110) | def landmarks(image: Image, face):
  function crop_from_array (line 114) | def crop_from_array(image: np.array, face) -> (np.array, 'face'):

FILE: faceutils/mask/main.py
  class FaceParser (line 14) | class FaceParser:
    method __init__ (line 15) | def __init__(self, device="cpu"):
    method parse (line 30) | def parse(self, image: Image):

FILE: faceutils/mask/model.py
  class ConvBNReLU (line 11) | class ConvBNReLU(nn.Module):
    method __init__ (line 12) | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args...
    method forward (line 23) | def forward(self, x):
    method init_weight (line 28) | def init_weight(self):
  class BiSeNetOutput (line 34) | class BiSeNetOutput(nn.Module):
    method __init__ (line 35) | def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
    method forward (line 41) | def forward(self, x):
    method init_weight (line 46) | def init_weight(self):
    method get_params (line 52) | def get_params(self):
  class AttentionRefinementModule (line 64) | class AttentionRefinementModule(nn.Module):
    method __init__ (line 65) | def __init__(self, in_chan, out_chan, *args, **kwargs):
    method forward (line 73) | def forward(self, x):
    method init_weight (line 82) | def init_weight(self):
  class ContextPath (line 89) | class ContextPath(nn.Module):
    method __init__ (line 90) | def __init__(self, *args, **kwargs):
    method forward (line 101) | def forward(self, x):
    method init_weight (line 124) | def init_weight(self):
    method get_params (line 130) | def get_params(self):
  class SpatialPath (line 143) | class SpatialPath(nn.Module):
    method __init__ (line 144) | def __init__(self, *args, **kwargs):
    method forward (line 152) | def forward(self, x):
    method init_weight (line 159) | def init_weight(self):
    method get_params (line 165) | def get_params(self):
  class FeatureFusionModule (line 177) | class FeatureFusionModule(nn.Module):
    method __init__ (line 178) | def __init__(self, in_chan, out_chan, *args, **kwargs):
    method forward (line 197) | def forward(self, fsp, fcp):
    method init_weight (line 209) | def init_weight(self):
    method get_params (line 215) | def get_params(self):
  class BiSeNet (line 227) | class BiSeNet(nn.Module):
    method __init__ (line 228) | def __init__(self, n_classes, *args, **kwargs):
    method forward (line 238) | def forward(self, x):
    method init_weight (line 253) | def init_weight(self):
    method get_params (line 259) | def get_params(self):

FILE: faceutils/mask/resnet.py
  function conv3x3 (line 11) | def conv3x3(in_planes, out_planes, stride=1):
  class BasicBlock (line 17) | class BasicBlock(nn.Module):
    method __init__ (line 18) | def __init__(self, in_chan, out_chan, stride=1):
    method forward (line 33) | def forward(self, x):
  function create_layer_basic (line 48) | def create_layer_basic(in_chan, out_chan, bnum, stride=1):
  class Resnet18 (line 55) | class Resnet18(nn.Module):
    method __init__ (line 56) | def __init__(self):
    method forward (line 68) | def forward(self, x):
    method init_weight (line 79) | def init_weight(self):
    method get_params (line 87) | def get_params(self):

FILE: models/elegant.py
  class Generator (line 11) | class Generator(nn.ModuleDict):
    method __init__ (line 13) | def __init__(self, conv_dim=64, image_size=256, num_layer_e=2, num_lay...
    method get_transfer_input (line 104) | def get_transfer_input(self, image, mask, diff, lms, is_reference=False):
    method get_transfer_output (line 137) | def get_transfer_output(self, fea_c_list, mask_c_list, diff_c_list, lm...
    method decode (line 164) | def decode(self, fea_c_list, attn_out_list):
    method forward (line 182) | def forward(self, c, s, mask_c, mask_s, diff_c, diff_s, lms_c, lms_s):
    method tps_align (line 197) | def tps_align(self, feature_size, lms_s, lms_c, fea_s, sample_mode='bi...

FILE: models/loss.py
  class GANLoss (line 9) | class GANLoss(nn.Module):
    method __init__ (line 15) | def __init__(self, gan_mode='lsgan', target_real_label=1.0, target_fak...
    method forward (line 35) | def forward(self, prediction, target_is_real):
  function norm (line 53) | def norm(x: torch.Tensor):
  function de_norm (line 56) | def de_norm(x: torch.Tensor):
  function masked_his_match (line 60) | def masked_his_match(image_s, image_r, mask_s, mask_r):
  function generate_pgt (line 86) | def generate_pgt(image_s, image_r, mask_s, mask_r, lms_s, lms_r, margins...
  class LinearAnnealingFn (line 115) | class LinearAnnealingFn():
    method __init__ (line 119) | def __init__(self, milestones, f_values):
    method __call__ (line 124) | def __call__(self, t:int):
  class ComposePGT (line 137) | class ComposePGT(nn.Module):
    method __init__ (line 138) | def __init__(self, margins, skin_alpha, eye_alpha, lip_alpha):
    method forward (line 148) | def forward(self, sources, targets, mask_srcs, mask_tars, lms_srcs, lm...
  class AnnealingComposePGT (line 158) | class AnnealingComposePGT(nn.Module):
    method __init__ (line 159) | def __init__(self, margins,
    method step (line 174) | def step(self):
    method forward (line 181) | def forward(self, sources, targets, mask_srcs, mask_tars, lms_srcs, lm...
  class MakeupLoss (line 192) | class MakeupLoss(nn.Module):
    method __init__ (line 196) | def __init__(self):
    method forward (line 199) | def forward(self, x, target, mask=None):

FILE: models/model.py
  function get_generator (line 11) | def get_generator(config):
  function get_discriminator (line 27) | def get_discriminator(config):
  class Discriminator (line 38) | class Discriminator(nn.Module):
    method __init__ (line 40) | def __init__(self, input_channel=3, conv_dim=64, num_layers=3, norm='S...
    method forward (line 73) | def forward(self, x):
  class VGG (line 79) | class VGG(TVGG):
    method forward (line 80) | def forward(self, x):
  function make_layers (line 85) | def make_layers(cfg, batch_norm=False):
  function _vgg (line 101) | def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
  function vgg16 (line 112) | def vgg16(pretrained=False, progress=True, **kwargs):

FILE: models/modules/histogram_matching.py
  function cal_hist (line 4) | def cal_hist(image):
  function cal_trans (line 25) | def cal_trans(ref, adj):
  function histogram_matching (line 39) | def histogram_matching(dstImg, refImg, index):

FILE: models/modules/module_attn.py
  class MultiheadAttention_weight (line 7) | class MultiheadAttention_weight(nn.Module):
    method __init__ (line 8) | def __init__(self, feature_dim, proj_dim, num_heads=1, dropout=0.0, bi...
    method forward (line 21) | def forward(self, fea_c, fea_s, mask_c, mask_s):
  class MultiheadAttention_value (line 55) | class MultiheadAttention_value(nn.Module):
    method __init__ (line 56) | def __init__(self, feature_dim, proj_dim, num_heads=1, bias=True):
    method forward (line 66) | def forward(self, weights, fea):
  class MultiheadAttention (line 82) | class MultiheadAttention(nn.Module):
    method __init__ (line 83) | def __init__(self, in_channels, proj_channels, value_channels, out_cha...
    method forward (line 88) | def forward(self, fea_q, fea_k, fea_v, mask_q, mask_k):
  class FeedForwardLayer (line 97) | class FeedForwardLayer(nn.Module):
    method __init__ (line 98) | def __init__(self, feature_dim, ff_dim, dropout=0.0):
    method forward (line 108) | def forward(self, x):
  class Attention_apply (line 112) | class Attention_apply(nn.Module):
    method __init__ (line 113) | def __init__(self, feature_dim, normalize=True):
    method forward (line 121) | def forward(self, x, attn_out):

FILE: models/modules/module_base.py
  class ResidualBlock (line 7) | class ResidualBlock(nn.Module):
    method __init__ (line 9) | def __init__(self, dim_in, dim_out):
    method forward (line 19) | def forward(self, x):
  class ResidualBlock_IN (line 24) | class ResidualBlock_IN(nn.Module):
    method __init__ (line 26) | def __init__(self, dim_in, dim_out, affine=False):
    method forward (line 38) | def forward(self, x):
  class ResidualBlock_Downsample (line 43) | class ResidualBlock_Downsample(nn.Module):
    method __init__ (line 45) | def __init__(self, dim_in, dim_out, affine=False):
    method forward (line 57) | def forward(self, x):
  class Downsample (line 64) | class Downsample(nn.Module):
    method __init__ (line 66) | def __init__(self, dim_in, dim_out, affine=False):
    method forward (line 74) | def forward(self, x):
  class ResidualBlock_Upsample (line 78) | class ResidualBlock_Upsample(nn.Module):
    method __init__ (line 80) | def __init__(self, dim_in, dim_out, normalize=True, affine=False):
    method forward (line 98) | def forward(self, x):
  class Upsample (line 105) | class Upsample(nn.Module):
    method __init__ (line 107) | def __init__(self, dim_in, dim_out, normalize=True, affine=False):
    method forward (line 121) | def forward(self, x):
  class PositionalEmbedding (line 125) | class PositionalEmbedding(nn.Module):
    method __init__ (line 126) | def __init__(self, embedding_dim=136, feature_size=64, max_size=None, ...
    method forward (line 135) | def forward(self, diff, mask):
  class MergeBlock (line 165) | class MergeBlock(nn.Module):
    method __init__ (line 166) | def __init__(self, merge_mode, feature_dim, normalize=True):
    method forward (line 181) | def forward(self, fea_s, fea_r):

FILE: models/modules/pseudo_gt.py
  function expand_area (line 11) | def expand_area(mask:torch.Tensor, margin:int):
  function mask_blur (line 28) | def mask_blur(mask:torch.Tensor, blur_size=3, mode='smooth'):
  function mask_blend (line 54) | def mask_blend(mask, blend_alpha, mask_bound=None, blur_size=3, blend_mo...
  function tps_align (line 64) | def tps_align(img_size, lms_r, lms_s, image_r, image_s=None,
  function tps_blend (line 86) | def tps_blend(blend_alpha, img_size, lms_r, lms_s, image_r, image_s=None...
  function fine_align (line 110) | def fine_align(img_size, lms_r, lms_s, image_r, image_s, mask_r, mask_s,...

FILE: models/modules/sow_attention.py
  class WindowAttention (line 6) | class WindowAttention(nn.Module):
    method __init__ (line 7) | def __init__(self, window_size, in_channels, proj_channels, value_chan...
    method generate_window_weight (line 33) | def generate_window_weight(self):
    method make_window (line 41) | def make_window(self, x: torch.Tensor):
    method demake_window (line 54) | def demake_window(self, x: torch.Tensor):
    method make_mask_window (line 70) | def make_mask_window(self, mask: torch.Tensor):
    method forward (line 83) | def forward(self, fea_q, fea_k, fea_v, mask_q=None, mask_k=None):
  class SowAttention (line 118) | class SowAttention(nn.Module):
    method __init__ (line 119) | def __init__(self, window_size, in_channels, proj_channels, value_chan...
    method forward (line 128) | def forward(self, fea_q, fea_k, fea_v, mask_q=None, mask_k=None):
  class StridedwindowAttention (line 182) | class StridedwindowAttention(nn.Module):
    method __init__ (line 183) | def __init__(self, stride, in_channels, proj_channels, value_channels,...
    method make_window (line 203) | def make_window(self, x: torch.Tensor):
    method demake_window (line 218) | def demake_window(self, x: torch.Tensor, h, w):
    method make_mask_window (line 231) | def make_mask_window(self, mask: torch.Tensor):
    method forward (line 244) | def forward(self, fea_q, fea_k, fea_v, mask_q=None, mask_k=None):

FILE: models/modules/spectral_norm.py
  function l2normalize (line 4) | def l2normalize(v, eps=1e-12):
  class SpectralNorm (line 7) | class SpectralNorm(object):
    method __init__ (line 8) | def __init__(self):
    method compute_weight (line 13) | def compute_weight(self, module):
    method apply (line 27) | def apply(module):
    method remove (line 59) | def remove(self, module):
    method __call__ (line 67) | def __call__(self, module, inputs):
  function spectral_norm (line 70) | def spectral_norm(module):
  function remove_spectral_norm (line 74) | def remove_spectral_norm(module):

FILE: models/modules/tps_transform.py
  function grid_sample (line 15) | def grid_sample(input, grid, mode='bilinear', canvas=None):
  function compute_partial_repr (line 27) | def compute_partial_repr(input_points, control_points):
  function bulid_delta_inverse (line 44) | def bulid_delta_inverse(target_control_points):
  function build_target_coordinate_matrix (line 62) | def build_target_coordinate_matrix(target_height, target_width, target_c...
  function tps_sampler (line 81) | def tps_sampler(target_height, target_width, inverse_kernel, target_coor...
  function tps_spatial_transform (line 102) | def tps_spatial_transform(target_height, target_width, target_control_po...
  class TPSSpatialTransformer (line 116) | class TPSSpatialTransformer(nn.Module):
    method __init__ (line 118) | def __init__(self, target_height, target_width, target_control_points):
    method forward (line 135) | def forward(self, source, source_control_points):

FILE: scripts/demo.py
  function main (line 14) | def main(config, args):

FILE: scripts/train.py
  function main (line 14) | def main(config, args):

FILE: training/config.py
  function get_config (line 94) | def get_config()->CfgNode:

FILE: training/dataset.py
  class MakeupDataset (line 9) | class MakeupDataset(Dataset):
    method __init__ (line 10) | def __init__(self, config=None):
    method load_from_file (line 22) | def load_from_file(self, img_name):
    method __len__ (line 29) | def __len__(self):
    method __getitem__ (line 32) | def __getitem__(self, index):
  function get_loader (line 41) | def get_loader(config):

FILE: training/inference.py
  class InputSample (line 13) | class InputSample:
    method __init__ (line 14) | def __init__(self, inputs, apply_mask=None):
    method clear (line 20) | def clear(self):
  class Inference (line 25) | class Inference:
    method __init__ (line 31) | def __init__(self, config, args, model_path="G.pth"):
    method prepare_input (line 41) | def prepare_input(self, *data_inputs):
    method postprocess (line 52) | def postprocess(self, source, crop_face, result):
    method generate_source_sample (line 74) | def generate_source_sample(self, source_input):
    method generate_reference_sample (line 81) | def generate_reference_sample(self, reference_input, apply_mask=None,
    method generate_partial_mask (line 97) | def generate_partial_mask(self, source_mask, mask_area='full', saturat...
    method interface_transfer (line 124) | def interface_transfer(self, source_sample: InputSample, reference_sam...
    method transfer (line 184) | def transfer(self, source: Image, reference: Image, postprocess=True):
    method joint_transfer (line 209) | def joint_transfer(self, source: Image, reference_lip: Image, referenc...

FILE: training/preprocess.py
  class PreProcess (line 15) | class PreProcess:
    method __init__ (line 17) | def __init__(self, config, need_parser=True, device='cpu'):
    method mask_process (line 54) | def mask_process(self, mask: torch.Tensor):
    method save_mask (line 74) | def save_mask(self, mask: torch.Tensor, path):
    method load_mask (line 80) | def load_mask(self, path):
    method lms_process (line 87) | def lms_process(self, image:Image):
    method diff_process (line 108) | def diff_process(self, lms: torch.Tensor, normalize=False):
    method save_lms (line 121) | def save_lms(self, lms: torch.Tensor, path):
    method load_lms (line 125) | def load_lms(self, path):
    method preprocess (line 130) | def preprocess(self, image: Image, is_crop=True):
    method process (line 170) | def process(self, image: Image, mask: torch.Tensor, lms: torch.Tensor):
    method __call__ (line 176) | def __call__(self, image:Image, is_crop=True):

FILE: training/solver.py
  class Solver (line 18) | class Solver():
    method __init__ (line 19) | def __init__(self, config, args, logger=None, inference=False):
    method print_network (line 100) | def print_network(self, model, name):
    method weights_init_xavier (line 110) | def weights_init_xavier(self, m):
    method build_model (line 117) | def build_model(self):
    method train (line 154) | def train(self, data_loader):
    method get_loss_tmp (line 355) | def get_loss_tmp(self):
    method log_loss (line 374) | def log_loss(self, loss_tmp):
    method plot_loss (line 392) | def plot_loss(self):
    method load_checkpoint (line 410) | def load_checkpoint(self):
    method save_models (line 426) | def save_models(self):
    method de_norm (line 435) | def de_norm(self, x):
    method vis_train (line 439) | def vis_train(self, img_train_batch):
    method generate (line 446) | def generate(self, image_A, image_B, mask_A=None, mask_B=None,
    method test (line 453) | def test(self, image_A, mask_A, diff_A, lms_A, image_B, mask_B, diff_B...

FILE: training/utils.py
  function create_logger (line 6) | def create_logger(save_path='', file_type='', level='debug', console=True):
  function print_args (line 29) | def print_args(args, logger=None):
  function plot_single_curve (line 36) | def plot_single_curve(path, name, point, freq=1, xlabel='Epoch',ylabel=N...
  function plot_curves (line 47) | def plot_curves(path, name, point_list, curve_names=None, freq=1, xlabel...
  },
  {
    "path": "models/modules/module_attn.py",
    "chars": 5233,
    "preview": "import math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass MultiheadAttention_weight(nn.Modu"
  },
  {
    "path": "models/modules/module_base.py",
    "chars": 7943,
    "preview": "import math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass ResidualBlock(nn.Module):\n    \"\"\""
  },
  {
    "path": "models/modules/pseudo_gt.py",
    "chars": 6500,
    "preview": "import cv2\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torchvision.transf"
  },
  {
    "path": "models/modules/sow_attention.py",
    "chars": 12747,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass WindowAttention(nn.Module):\n    def __init__("
  },
  {
    "path": "models/modules/spectral_norm.py",
    "chars": 2791,
    "preview": "import torch\nfrom torch.nn import Parameter\n\ndef l2normalize(v, eps=1e-12):\n    return v / (v.norm() + eps)\n\nclass Spect"
  },
  {
    "path": "models/modules/tps_transform.py",
    "chars": 6127,
    "preview": "from __future__ import absolute_import\n\nimport numpy as np\nimport itertools\n\nimport torch\nimport torch.nn as nn\nimport t"
  },
  {
    "path": "scripts/demo.py",
    "chars": 2237,
    "preview": "import os\nimport sys\nimport argparse\nimport numpy as np\nimport cv2\nimport torch\nfrom PIL import Image\nsys.path.append('."
  },
  {
    "path": "scripts/train.py",
    "chars": 1637,
    "preview": "import os\nimport sys\nimport argparse\nimport torch\nfrom torch.utils.data import DataLoader\nsys.path.append('.')\n\nfrom tra"
  },
  {
    "path": "training/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "training/config.py",
    "chars": 2557,
    "preview": "from fvcore.common.config import CfgNode\n\n\"\"\"\nThis file defines default options of configurations.\nIt will be further me"
  },
  {
    "path": "training/dataset.py",
    "chars": 2196,
    "preview": "import os\nfrom PIL import Image\nimport torch\nfrom torch.utils.data import Dataset, DataLoader\n\nfrom training.config impo"
  },
  {
    "path": "training/inference.py",
    "chars": 9960,
    "preview": "from typing import List\nimport numpy as np\nimport cv2\nfrom PIL import Image\nimport torch\nimport torch.nn.functional as F"
  },
  {
    "path": "training/preprocess.py",
    "chars": 8774,
    "preview": "import os\nimport sys\nimport cv2\nfrom PIL import Image\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfr"
  },
  {
    "path": "training/solver.py",
    "chars": 21006,
    "preview": "import os\nimport time\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torchvi"
  },
  {
    "path": "training/utils.py",
    "chars": 1939,
    "preview": "import os\nimport logging\nimport numpy as np\nimport matplotlib.pyplot as plt\n\ndef create_logger(save_path='', file_type='"
  }
]
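Since the preview above is plain JSON (an array of objects with `path`, `chars`, and `preview` keys), it can be consumed programmatically. The snippet below is a minimal sketch; the two-entry `sample` string is a hypothetical stand-in for the downloaded .json file, and the schema is assumed from the preview shown above.

```python
import json

# Hypothetical two-entry sample mirroring the preview's schema
# ("path", "chars", "preview"); load the real .json file in practice.
sample = """
[
  {"path": ".gitignore", "chars": 1436, "preview": "/data\\n/results"},
  {"path": "models/elegant.py", "chars": 8927, "preview": "import torch"}
]
"""

entries = json.loads(sample)

# Total size in characters across all listed files.
total_chars = sum(e["chars"] for e in entries)

# Group file paths by top-level directory for a quick overview.
by_dir = {}
for e in entries:
    top = e["path"].split("/")[0] if "/" in e["path"] else "."
    by_dir.setdefault(top, []).append(e["path"])

print(total_chars)     # 10363 for the sample above
print(sorted(by_dir))  # ['.', 'models']
```

The same loop scales to the full 37-entry array, e.g. to pick out only the `models/` files before feeding them to a language model.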

About this extraction

This page contains the full source code of the Chenyu-Yang-2000/EleGANt GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 37 files (160.8 KB), approximately 42.9k tokens, and a symbol index of 232 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract, a GitHub repository-to-text converter, built by Nikandr Surkov.