Repository: Chenyu-Yang-2000/EleGANt
Branch: main
Commit: d033d3398751
Files: 37
Total size: 160.8 KB
Directory structure:
gitextract_z1ehcgp9/
├── .gitignore
├── LICENSE
├── README.md
├── assets/
│   └── docs/
│       ├── install.md
│       └── prepare.md
├── concern/
│   ├── __init__.py
│   ├── image.py
│   ├── track.py
│   └── visualize.py
├── faceutils/
│   ├── __init__.py
│   ├── dlibutils/
│   │   ├── __init__.py
│   │   └── main.py
│   └── mask/
│       ├── __init__.py
│       ├── main.py
│       ├── model.py
│       └── resnet.py
├── models/
│   ├── __init__.py
│   ├── elegant.py
│   ├── loss.py
│   ├── model.py
│   └── modules/
│       ├── __init__.py
│       ├── histogram_matching.py
│       ├── module_attn.py
│       ├── module_base.py
│       ├── pseudo_gt.py
│       ├── sow_attention.py
│       ├── spectral_norm.py
│       └── tps_transform.py
├── scripts/
│   ├── demo.py
│   └── train.py
└── training/
    ├── __init__.py
    ├── config.py
    ├── dataset.py
    ├── inference.py
    ├── preprocess.py
    ├── solver.py
    └── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
/data
/results
*.dat
*.pth
*.pt
# *.txt
# !/expe/*/*.txt
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
.idea/
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
.vscode/settings.json
================================================
FILE: LICENSE
================================================
Attribution-NonCommercial-ShareAlike 4.0 International
=======================================================================
Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.
Using Creative Commons Public Licenses
Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.
Considerations for licensors: Our public licenses are
intended for use by those authorized to give the public
permission to use material in ways otherwise restricted by
copyright and certain other rights. Our licenses are
irrevocable. Licensors should read and understand the terms
and conditions of the license they choose before applying it.
Licensors should also secure all rights necessary before
applying our licenses so that the public can reuse the
material as expected. Licensors should clearly mark any
material not subject to the license. This includes other CC-
licensed material, or material used under an exception or
limitation to copyright. More considerations for licensors:
wiki.creativecommons.org/Considerations_for_licensors
Considerations for the public: By using one of our public
licenses, a licensor grants the public permission to use the
licensed material under specified terms and conditions. If
the licensor's permission is not necessary for any reason--for
example, because of any applicable exception or limitation to
copyright--then that use is not regulated by the license. Our
licenses grant only permissions under copyright and certain
other rights that a licensor has authority to grant. Use of
the licensed material may still be restricted for other
reasons, including because others have copyright or other
rights in the material. A licensor may make special requests,
such as asking that all changes be marked or described.
Although not required by our licenses, you are encouraged to
respect those requests where reasonable. More considerations
for the public:
wiki.creativecommons.org/Considerations_for_licensees
=======================================================================
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
Public License
By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution-NonCommercial-ShareAlike 4.0 International Public License
("Public License"). To the extent this Public License may be
interpreted as a contract, You are granted the Licensed Rights in
consideration of Your acceptance of these terms and conditions, and the
Licensor grants You such rights in consideration of benefits the
Licensor receives from making the Licensed Material available under
these terms and conditions.
Section 1 -- Definitions.
a. Adapted Material means material subject to Copyright and Similar
Rights that is derived from or based upon the Licensed Material
and in which the Licensed Material is translated, altered,
arranged, transformed, or otherwise modified in a manner requiring
permission under the Copyright and Similar Rights held by the
Licensor. For purposes of this Public License, where the Licensed
Material is a musical work, performance, or sound recording,
Adapted Material is always produced where the Licensed Material is
synched in timed relation with a moving image.
b. Adapter's License means the license You apply to Your Copyright
and Similar Rights in Your contributions to Adapted Material in
accordance with the terms and conditions of this Public License.
c. BY-NC-SA Compatible License means a license listed at
creativecommons.org/compatiblelicenses, approved by Creative
Commons as essentially the equivalent of this Public License.
d. Copyright and Similar Rights means copyright and/or similar rights
closely related to copyright including, without limitation,
performance, broadcast, sound recording, and Sui Generis Database
Rights, without regard to how the rights are labeled or
categorized. For purposes of this Public License, the rights
specified in Section 2(b)(1)-(2) are not Copyright and Similar
Rights.
e. Effective Technological Measures means those measures that, in the
absence of proper authority, may not be circumvented under laws
fulfilling obligations under Article 11 of the WIPO Copyright
Treaty adopted on December 20, 1996, and/or similar international
agreements.
f. Exceptions and Limitations means fair use, fair dealing, and/or
any other exception or limitation to Copyright and Similar Rights
that applies to Your use of the Licensed Material.
g. License Elements means the license attributes listed in the name
of a Creative Commons Public License. The License Elements of this
Public License are Attribution, NonCommercial, and ShareAlike.
h. Licensed Material means the artistic or literary work, database,
or other material to which the Licensor applied this Public
License.
i. Licensed Rights means the rights granted to You subject to the
terms and conditions of this Public License, which are limited to
all Copyright and Similar Rights that apply to Your use of the
Licensed Material and that the Licensor has authority to license.
j. Licensor means the individual(s) or entity(ies) granting rights
under this Public License.
k. NonCommercial means not primarily intended for or directed towards
commercial advantage or monetary compensation. For purposes of
this Public License, the exchange of the Licensed Material for
other material subject to Copyright and Similar Rights by digital
file-sharing or similar means is NonCommercial provided there is
no payment of monetary compensation in connection with the
exchange.
l. Share means to provide material to the public by any means or
process that requires permission under the Licensed Rights, such
as reproduction, public display, public performance, distribution,
dissemination, communication, or importation, and to make material
available to the public including in ways that members of the
public may access the material from a place and at a time
individually chosen by them.
m. Sui Generis Database Rights means rights other than copyright
resulting from Directive 96/9/EC of the European Parliament and of
the Council of 11 March 1996 on the legal protection of databases,
as amended and/or succeeded, as well as other essentially
equivalent rights anywhere in the world.
n. You means the individual or entity exercising the Licensed Rights
under this Public License. Your has a corresponding meaning.
Section 2 -- Scope.
a. License grant.
1. Subject to the terms and conditions of this Public License,
the Licensor hereby grants You a worldwide, royalty-free,
non-sublicensable, non-exclusive, irrevocable license to
exercise the Licensed Rights in the Licensed Material to:
a. reproduce and Share the Licensed Material, in whole or
in part, for NonCommercial purposes only; and
b. produce, reproduce, and Share Adapted Material for
NonCommercial purposes only.
2. Exceptions and Limitations. For the avoidance of doubt, where
Exceptions and Limitations apply to Your use, this Public
License does not apply, and You do not need to comply with
its terms and conditions.
3. Term. The term of this Public License is specified in Section
6(a).
4. Media and formats; technical modifications allowed. The
Licensor authorizes You to exercise the Licensed Rights in
all media and formats whether now known or hereafter created,
and to make technical modifications necessary to do so. The
Licensor waives and/or agrees not to assert any right or
authority to forbid You from making technical modifications
necessary to exercise the Licensed Rights, including
technical modifications necessary to circumvent Effective
Technological Measures. For purposes of this Public License,
simply making modifications authorized by this Section 2(a)
(4) never produces Adapted Material.
5. Downstream recipients.
a. Offer from the Licensor -- Licensed Material. Every
recipient of the Licensed Material automatically
receives an offer from the Licensor to exercise the
Licensed Rights under the terms and conditions of this
Public License.
b. Additional offer from the Licensor -- Adapted Material.
Every recipient of Adapted Material from You
automatically receives an offer from the Licensor to
exercise the Licensed Rights in the Adapted Material
under the conditions of the Adapter's License You apply.
c. No downstream restrictions. You may not offer or impose
any additional or different terms or conditions on, or
apply any Effective Technological Measures to, the
Licensed Material if doing so restricts exercise of the
Licensed Rights by any recipient of the Licensed
Material.
6. No endorsement. Nothing in this Public License constitutes or
may be construed as permission to assert or imply that You
are, or that Your use of the Licensed Material is, connected
with, or sponsored, endorsed, or granted official status by,
the Licensor or others designated to receive attribution as
provided in Section 3(a)(1)(A)(i).
b. Other rights.
1. Moral rights, such as the right of integrity, are not
licensed under this Public License, nor are publicity,
privacy, and/or other similar personality rights; however, to
the extent possible, the Licensor waives and/or agrees not to
assert any such rights held by the Licensor to the limited
extent necessary to allow You to exercise the Licensed
Rights, but not otherwise.
2. Patent and trademark rights are not licensed under this
Public License.
3. To the extent possible, the Licensor waives any right to
collect royalties from You for the exercise of the Licensed
Rights, whether directly or through a collecting society
under any voluntary or waivable statutory or compulsory
licensing scheme. In all other cases the Licensor expressly
reserves any right to collect such royalties, including when
the Licensed Material is used other than for NonCommercial
purposes.
Section 3 -- License Conditions.
Your exercise of the Licensed Rights is expressly made subject to the
following conditions.
a. Attribution.
1. If You Share the Licensed Material (including in modified
form), You must:
a. retain the following if it is supplied by the Licensor
with the Licensed Material:
i. identification of the creator(s) of the Licensed
Material and any others designated to receive
attribution, in any reasonable manner requested by
the Licensor (including by pseudonym if
designated);
ii. a copyright notice;
iii. a notice that refers to this Public License;
iv. a notice that refers to the disclaimer of
warranties;
v. a URI or hyperlink to the Licensed Material to the
extent reasonably practicable;
b. indicate if You modified the Licensed Material and
retain an indication of any previous modifications; and
c. indicate the Licensed Material is licensed under this
Public License, and include the text of, or the URI or
hyperlink to, this Public License.
2. You may satisfy the conditions in Section 3(a)(1) in any
reasonable manner based on the medium, means, and context in
which You Share the Licensed Material. For example, it may be
reasonable to satisfy the conditions by providing a URI or
hyperlink to a resource that includes the required
information.
3. If requested by the Licensor, You must remove any of the
information required by Section 3(a)(1)(A) to the extent
reasonably practicable.
b. ShareAlike.
In addition to the conditions in Section 3(a), if You Share
Adapted Material You produce, the following conditions also apply.
1. The Adapter's License You apply must be a Creative Commons
license with the same License Elements, this version or
later, or a BY-NC-SA Compatible License.
2. You must include the text of, or the URI or hyperlink to, the
Adapter's License You apply. You may satisfy this condition
in any reasonable manner based on the medium, means, and
context in which You Share Adapted Material.
3. You may not offer or impose any additional or different terms
or conditions on, or apply any Effective Technological
Measures to, Adapted Material that restrict exercise of the
rights granted under the Adapter's License You apply.
Section 4 -- Sui Generis Database Rights.
Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
to extract, reuse, reproduce, and Share all or a substantial
portion of the contents of the database for NonCommercial purposes
only;
b. if You include all or a substantial portion of the database
contents in a database in which You have Sui Generis Database
Rights, then the database in which You have Sui Generis Database
Rights (but not its individual contents) is Adapted Material,
including for purposes of Section 3(b); and
c. You must comply with the conditions in Section 3(a) if You Share
all or a substantial portion of the contents of the database.
For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
c. The disclaimer of warranties and limitation of liability provided
above shall be interpreted in a manner that, to the extent
possible, most closely approximates an absolute disclaimer and
waiver of all liability.
Section 6 -- Term and Termination.
a. This Public License applies for the term of the Copyright and
Similar Rights licensed here. However, if You fail to comply with
this Public License, then Your rights under this Public License
terminate automatically.
b. Where Your right to use the Licensed Material has terminated under
Section 6(a), it reinstates:
1. automatically as of the date the violation is cured, provided
it is cured within 30 days of Your discovery of the
violation; or
2. upon express reinstatement by the Licensor.
For the avoidance of doubt, this Section 6(b) does not affect any
right the Licensor may have to seek remedies for Your violations
of this Public License.
c. For the avoidance of doubt, the Licensor may also offer the
Licensed Material under separate terms or conditions or stop
distributing the Licensed Material at any time; however, doing so
will not terminate this Public License.
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
License.
Section 7 -- Other Terms and Conditions.
a. The Licensor shall not be bound by any additional or different
terms or conditions communicated by You unless expressly agreed.
b. Any arrangements, understandings, or agreements regarding the
Licensed Material not stated herein are separate from and
independent of the terms and conditions of this Public License.
Section 8 -- Interpretation.
a. For the avoidance of doubt, this Public License does not, and
shall not be interpreted to, reduce, limit, restrict, or impose
conditions on any use of the Licensed Material that could lawfully
be made without permission under this Public License.
b. To the extent possible, if any provision of this Public License is
deemed unenforceable, it shall be automatically reformed to the
minimum extent necessary to make it enforceable. If the provision
cannot be reformed, it shall be severed from this Public License
without affecting the enforceability of the remaining terms and
conditions.
c. No term or condition of this Public License will be waived and no
failure to comply consented to unless expressly agreed to by the
Licensor.
d. Nothing in this Public License constitutes or may be interpreted
as a limitation upon, or waiver of, any privileges and immunities
that apply to the Licensor or You, including from the legal
processes of any jurisdiction or authority.
=======================================================================
Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the “Licensor.” The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.
Creative Commons may be contacted at creativecommons.org.
================================================
FILE: README.md
================================================
# EleGANt: Exquisite and Locally Editable GAN for Makeup Transfer
[![CC BY-NC-SA 4.0][cc-by-nc-sa-shield]][cc-by-nc-sa]
Official [PyTorch](https://pytorch.org/) implementation of the ECCV 2022 paper "[EleGANt: Exquisite and Locally Editable GAN for Makeup Transfer](https://arxiv.org/abs/2207.09840)"
*Chenyu Yang, Wanrong He, Yingqing Xu, and Yang Gao*.

## Getting Started
- [Installation](assets/docs/install.md)
- [Prepare Dataset & Checkpoints](assets/docs/prepare.md)
## Test
To test our model, download the [weights](https://drive.google.com/drive/folders/1xzIS3Dfmsssxkk9OhhAS4svrZSPfQYRe?usp=sharing) of the trained model and run
```bash
python scripts/demo.py
```
Examples of makeup transfer results can be seen [here](assets/images/examples/).
## Train
To train a model from scratch, run
```bash
python scripts/train.py
```
## Customized Transfer
https://user-images.githubusercontent.com/61506577/180593092-ccadddff-76be-4b7b-921e-0d3b4cc27d9b.mp4
This is our demo of customized makeup editing. The interactive system is built on [Streamlit](https://github.com/streamlit/streamlit) and the inference interface in `./training/inference.py`.
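For reference, below is a minimal sketch of such a front end. The file name and the commented transfer call are assumptions for illustration, not the actual interface shipped in `./training/inference.py`:
```python
# app_sketch.py -- hypothetical Streamlit front end for makeup transfer.
import streamlit as st
from PIL import Image

st.title("EleGANt Makeup Transfer (sketch)")
src_file = st.file_uploader("Source (non-makeup) image", type=["jpg", "png"])
ref_file = st.file_uploader("Reference (makeup) image", type=["jpg", "png"])

if src_file and ref_file:
    source = Image.open(src_file).convert("RGB")
    reference = Image.open(ref_file).convert("RGB")
    st.image([source, reference], caption=["source", "reference"])
    # result = inference.transfer(source, reference)  # hypothetical call into training/inference.py
    # st.image(result, caption="transfer result")
```
Launch a script like this with `streamlit run app_sketch.py`.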
**Controllable makeup transfer.**

**Local makeup editing.**

## Citation
If this work is helpful for your research, please consider citing it using the following BibTeX entry.
```text
@article{yang2022elegant,
  title={EleGANt: Exquisite and Locally Editable GAN for Makeup Transfer},
  author={Yang, Chenyu and He, Wanrong and Xu, Yingqing and Gao, Yang},
  journal={arXiv preprint arXiv:2207.09840},
  year={2022}
}
```
## Acknowledgement
Parts of the code are built upon [PSGAN](https://github.com/wtjiang98/PSGAN) and [aster.pytorch](https://github.com/ayumiymk/aster.pytorch).
## License
This work is licensed under a
[Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License][cc-by-nc-sa].
[![CC BY-NC-SA 4.0][cc-by-nc-sa-image]][cc-by-nc-sa]
[cc-by-nc-sa]: http://creativecommons.org/licenses/by-nc-sa/4.0/
[cc-by-nc-sa-image]: https://licensebuttons.net/l/by-nc-sa/4.0/88x31.png
[cc-by-nc-sa-shield]: https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg
================================================
FILE: assets/docs/install.md
================================================
# Installation Instructions
This code was tested on Ubuntu 20.04 with CUDA 11.1.
**a. Create a conda virtual environment and activate it.**
```bash
conda create -n elegant python=3.8
conda activate elegant
```
**b. Install PyTorch and torchvision following the [official instructions](https://pytorch.org/).**
```bash
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
```
**c. Install other required libraries.**
```bash
pip install opencv-python matplotlib dlib fvcore
```
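**d. (Optional) Verify the installation.**
A quick sanity check that the core libraries import and that PyTorch sees the GPU:
```python
import torch, torchvision, cv2, dlib
print("torch", torch.__version__, "| torchvision", torchvision.__version__)
print("CUDA available:", torch.cuda.is_available())
```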
================================================
FILE: assets/docs/prepare.md
================================================
# Preparation Instructions
Clone this repository and prepare the dataset and weights through the following steps:
**a. Prepare model weights for face detection and parsing.**
Download the weights of the [dlib](https://github.com/davisking/dlib) 68-point facial landmark predictor [here](http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2). Extract the archive and move the `.dat` file to the directory `./faceutils/dlibutils`.
Download the weights of BiSeNet ([PyTorch implementation](https://github.com/zllrunning/face-parsing.PyTorch)) for face parsing [here](https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812). Rename it to `resnet.pth` and move it to the directory `./faceutils/mask`.
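Optionally, sanity-check that both weight files are in place (a minimal sketch; the paths assume you run it from the repository root):
```python
import dlib, torch

# Both calls fail loudly if a file is missing or corrupted.
dlib.shape_predictor("faceutils/dlibutils/shape_predictor_68_face_landmarks.dat")
state = torch.load("faceutils/mask/resnet.pth", map_location="cpu")
print("BiSeNet checkpoint entries:", len(state))
```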
**b. Prepare the Makeup Transfer (MT) dataset.**
Download the raw data of the MT dataset [here](https://github.com/wtjiang98/PSGAN) and unzip it into the `./data` directory.
Run the following command to preprocess data:
```bash
python training/preprocess.py
```
Your data directory should look like:
```text
data
└── MT-Dataset
    ├── images
    │   ├── makeup
    │   └── non-makeup
    ├── segs
    │   ├── makeup
    │   └── non-makeup
    ├── lms
    │   ├── makeup
    │   └── non-makeup
    ├── makeup.txt
    ├── non-makeup.txt
    └── ...
```
**c. Download the weights of the trained EleGANt model.**
The weights of our trained model can be downloaded [here](https://drive.google.com/drive/folders/1xzIS3Dfmsssxkk9OhhAS4svrZSPfQYRe?usp=sharing). Put them under the directory `./ckpts`.
================================================
FILE: concern/__init__.py
================================================
from .image import load_image
================================================
FILE: concern/image.py
================================================
import numpy as np
import cv2
from io import BytesIO
def load_image(path):
with path.open("rb") as reader:
        data = np.frombuffer(reader.read(), dtype=np.uint8)  # np.fromstring is deprecated
img = cv2.imdecode(data, cv2.IMREAD_COLOR)
if img is None:
return
img = img[..., ::-1]
return img
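# Example usage (sketch):
#   from pathlib import Path
#   img = load_image(Path("face.jpg"))  # RGB uint8 ndarray, or None if decoding fails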
def resize_by_max(image, max_side=512, force=False):
h, w = image.shape[:2]
if max(h, w) < max_side and not force:
return image
ratio = max(h, w) / max_side
w = int(w / ratio + 0.5)
h = int(h / ratio + 0.5)
return cv2.resize(image, (w, h))
def image2buffer(image):
is_success, buffer = cv2.imencode(".jpg", image)
if not is_success:
return None
return BytesIO(buffer)
================================================
FILE: concern/track.py
================================================
import time
import torch
class Track:
def __init__(self):
self.log_point = time.time()
self.enable_track = False
def track(self, mark):
if not self.enable_track:
return
if torch.cuda.is_available():
torch.cuda.synchronize()
print("{} memory:".format(mark), torch.cuda.memory_allocated() / 1024 / 1024, "M")
print("{} time cost:".format(mark), time.time() - self.log_point)
self.log_point = time.time()
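# Example usage (sketch):
#   tracker = Track()
#   tracker.enable_track = True
#   tracker.track("forward")  # prints CUDA memory (MiB) and seconds since the previous mark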
================================================
FILE: concern/visualize.py
================================================
import numpy as np
import cv2
def channel_first(image, format):
return image.transpose(
format.index("C"), format.index("H"), format.index("W"))
def mask2image(mask:np.array, format="HWC"):
H, W = mask.shape
canvas = np.zeros((H, W, 3), dtype=np.uint8)
    for i in range(1, int(mask.max()) + 1):  # one color per part label; skip background (0), include the top label
        color = np.random.rand(1, 1, 3) * 255
        canvas += (mask == i)[:, :, None] * color.astype(np.uint8)
return canvas
def draw_points(image, points, color=(255, 0, 0)):
    for point in points:
        # points are given as (y, x) pairs; cv2.circle expects (x, y)
        image = cv2.circle(image, (int(point[1]), int(point[0])), 3, color)
if hasattr(image, "get"):
return image.get()
return image
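# Example usage (sketch):
#   canvas = mask2image(seg)                  # one random color per part label
#   vis = draw_points(img.copy(), landmarks)  # landmark points given as (y, x) pairs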
================================================
FILE: faceutils/__init__.py
================================================
#!/usr/bin/python
# -*- encoding: utf-8 -*-
#from . import faceplusplus as fpp
from . import dlibutils as dlib
from . import mask
================================================
FILE: faceutils/dlibutils/__init__.py
================================================
#!/usr/bin/python
# -*- encoding: utf-8 -*-
from .main import detect, crop, landmarks, crop_from_array
================================================
FILE: faceutils/dlibutils/main.py
================================================
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import os.path as osp
import numpy as np
from PIL import Image
import dlib
import cv2
from concern.image import resize_by_max
detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor(osp.split(osp.realpath(__file__))[0] + '/shape_predictor_68_face_landmarks.dat')
def detect(image: Image) -> 'faces':
image = np.asarray(image)
h, w = image.shape[:2]
image = resize_by_max(image, 361)
actual_h, actual_w = image.shape[:2]
faces_on_small = detector(image, 1)
faces = dlib.rectangles()
for face in faces_on_small:
faces.append(
dlib.rectangle(
int(face.left() / actual_w * w + 0.5),
int(face.top() / actual_h * h + 0.5),
int(face.right() / actual_w * w + 0.5),
int(face.bottom() / actual_h * h + 0.5)
)
)
return faces
def crop(image: Image, face, up_ratio, down_ratio, width_ratio) -> (Image, 'face'):
width, height = image.size
face_height = face.height()
face_width = face.width()
delta_up = up_ratio * face_height
delta_down = down_ratio * face_height
delta_width = width_ratio * width
img_left = int(max(0, face.left() - delta_width))
img_top = int(max(0, face.top() - delta_up))
img_right = int(min(width, face.right() + delta_width))
img_bottom = int(min(height, face.bottom() + delta_down))
image = image.crop((img_left, img_top, img_right, img_bottom))
face = dlib.rectangle(face.left() - img_left, face.top() - img_top,
face.right() - img_left, face.bottom() - img_top)
face_expand = dlib.rectangle(img_left, img_top, img_right, img_bottom)
center = face_expand.center()
width, height = image.size
# import ipdb; ipdb.set_trace()
crop_left = img_left
crop_top = img_top
crop_right = img_right
crop_bottom = img_bottom
if width > height:
left = int(center.x - height / 2)
right = int(center.x + height / 2)
if left < 0:
left, right = 0, height
elif right > width:
left, right = width - height, width
image = image.crop((left, 0, right, height))
face = dlib.rectangle(face.left() - left, face.top(),
face.right() - left, face.bottom())
crop_left += left
crop_right = crop_left + height
elif width < height:
top = int(center.y - width / 2)
bottom = int(center.y + width / 2)
if top < 0:
top, bottom = 0, width
elif bottom > height:
top, bottom = height - width, height
image = image.crop((0, top, width, bottom))
face = dlib.rectangle(face.left(), face.top() - top,
face.right(), face.bottom() - top)
crop_top += top
crop_bottom = crop_top + width
crop_face = dlib.rectangle(crop_left, crop_top, crop_right, crop_bottom)
return image, face, crop_face
def crop_by_image_size(image: Image, face) -> (Image, 'face'):
center = face.center()
width, height = image.size
if width > height:
left = int(center.x - height / 2)
right = int(center.x + height / 2)
if left < 0:
left, right = 0, height
elif right > width:
left, right = width - height, width
image = image.crop((left, 0, right, height))
face = dlib.rectangle(face.left() - left, face.top(),
face.right() - left, face.bottom())
elif width < height:
top = int(center.y - width / 2)
bottom = int(center.y + width / 2)
if top < 0:
top, bottom = 0, width
elif bottom > height:
top, bottom = height - width, height
image = image.crop((0, top, width, bottom))
face = dlib.rectangle(face.left(), face.top() - top,
face.right(), face.bottom() - top)
return image, face
def landmarks(image: Image, face):
shape = predictor(np.asarray(image), face).parts()
return np.array([[p.y, p.x] for p in shape])
def crop_from_array(image: np.array, face) -> (np.array, 'face'):
ratio = 0.20 / 0.85 # delta_size / face_size
height, width = image.shape[:2]
face_height = face.height()
face_width = face.width()
delta_height = ratio * face_height
    delta_width = ratio * face_width  # ratio is delta_size / face_size, so scale by the face width
img_left = int(max(0, face.left() - delta_width))
img_top = int(max(0, face.top() - delta_height))
img_right = int(min(width, face.right() + delta_width))
img_bottom = int(min(height, face.bottom() + delta_height))
image = image[img_top:img_bottom, img_left:img_right]
face = dlib.rectangle(face.left() - img_left, face.top() - img_top,
face.right() - img_left, face.bottom() - img_top)
center = face.center()
height, width = image.shape[:2]
if width > height:
left = int(center.x - height / 2)
right = int(center.x + height / 2)
if left < 0:
left, right = 0, height
elif right > width:
left, right = width - height, width
image = image[0:height, left:right]
face = dlib.rectangle(face.left() - left, face.top(),
face.right() - left, face.bottom())
elif width < height:
top = int(center.y - width / 2)
bottom = int(center.y + width / 2)
if top < 0:
top, bottom = 0, width
elif bottom > height:
top, bottom = height - width, height
image = image[top:bottom, 0:width]
face = dlib.rectangle(face.left(), face.top() - top,
face.right(), face.bottom() - top)
return image, face
================================================
FILE: faceutils/mask/__init__.py
================================================
#!/usr/bin/python
# -*- encoding: utf-8 -*-
from .main import FaceParser
================================================
FILE: faceutils/mask/main.py
================================================
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import os.path as osp
import numpy as np
import cv2
from PIL import Image
import torch
import torchvision.transforms as transforms
from .model import BiSeNet
class FaceParser:
def __init__(self, device="cpu"):
mapper = [0, 1, 2, 3, 4, 5, 0, 11, 12, 0, 6, 8, 7, 9, 13, 0, 0, 10, 0]
self.device = device
self.dic = torch.tensor(mapper, device=device).unsqueeze(1)
save_pth = osp.split(osp.realpath(__file__))[0] + '/resnet.pth'
net = BiSeNet(n_classes=19)
net.load_state_dict(torch.load(save_pth, map_location=device))
self.net = net.to(device).eval()
self.to_tensor = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
    def parse(self, image: np.ndarray):
        # expects a 512x512 RGB array (a PIL Image has no .shape attribute)
        assert image.shape[:2] == (512, 512)
with torch.no_grad():
image = self.to_tensor(image).to(self.device)
image = torch.unsqueeze(image, 0)
out = self.net(image)[0]
parsing = out.squeeze(0).argmax(0)
parsing = torch.nn.functional.embedding(parsing, self.dic)
return parsing.float().squeeze(2)
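# Example usage (sketch; the input is assumed to be a 512x512 RGB array):
#   parser = FaceParser(device="cuda" if torch.cuda.is_available() else "cpu")
#   seg = parser.parse(face)  # (512, 512) tensor of remapped face-part labels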
================================================
FILE: faceutils/mask/model.py
================================================
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from .resnet import Resnet18
class ConvBNReLU(nn.Module):
def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
super(ConvBNReLU, self).__init__()
self.conv = nn.Conv2d(in_chan,
out_chan,
kernel_size = ks,
stride = stride,
padding = padding,
bias = False)
self.bn = nn.BatchNorm2d(out_chan)
self.init_weight()
def forward(self, x):
x = self.conv(x)
x = F.relu(self.bn(x))
return x
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
class BiSeNetOutput(nn.Module):
def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
super(BiSeNetOutput, self).__init__()
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
self.init_weight()
def forward(self, x):
x = self.conv(x)
x = self.conv_out(x)
return x
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
wd_params.append(module.weight)
if not module.bias is None:
nowd_params.append(module.bias)
elif isinstance(module, nn.BatchNorm2d):
nowd_params += list(module.parameters())
return wd_params, nowd_params
class AttentionRefinementModule(nn.Module):
def __init__(self, in_chan, out_chan, *args, **kwargs):
super(AttentionRefinementModule, self).__init__()
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
self.bn_atten = nn.BatchNorm2d(out_chan)
self.sigmoid_atten = nn.Sigmoid()
self.init_weight()
def forward(self, x):
feat = self.conv(x)
atten = F.avg_pool2d(feat, feat.size()[2:])
atten = self.conv_atten(atten)
atten = self.bn_atten(atten)
atten = self.sigmoid_atten(atten)
out = torch.mul(feat, atten)
return out
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
class ContextPath(nn.Module):
def __init__(self, *args, **kwargs):
super(ContextPath, self).__init__()
self.resnet = Resnet18()
self.arm16 = AttentionRefinementModule(256, 128)
self.arm32 = AttentionRefinementModule(512, 128)
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
self.init_weight()
def forward(self, x):
H0, W0 = x.size()[2:]
feat8, feat16, feat32 = self.resnet(x)
H8, W8 = feat8.size()[2:]
H16, W16 = feat16.size()[2:]
H32, W32 = feat32.size()[2:]
avg = F.avg_pool2d(feat32, feat32.size()[2:])
avg = self.conv_avg(avg)
avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
feat32_arm = self.arm32(feat32)
feat32_sum = feat32_arm + avg_up
feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
feat32_up = self.conv_head32(feat32_up)
feat16_arm = self.arm16(feat16)
feat16_sum = feat16_arm + feat32_up
feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
feat16_up = self.conv_head16(feat16_up)
return feat8, feat16_up, feat32_up # x8, x8, x16
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
wd_params.append(module.weight)
if not module.bias is None:
nowd_params.append(module.bias)
elif isinstance(module, nn.BatchNorm2d):
nowd_params += list(module.parameters())
return wd_params, nowd_params
### This is not used, since it is replaced by the resnet feature of the same size
class SpatialPath(nn.Module):
def __init__(self, *args, **kwargs):
super(SpatialPath, self).__init__()
self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
self.init_weight()
def forward(self, x):
feat = self.conv1(x)
feat = self.conv2(feat)
feat = self.conv3(feat)
feat = self.conv_out(feat)
return feat
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
wd_params.append(module.weight)
if not module.bias is None:
nowd_params.append(module.bias)
elif isinstance(module, nn.BatchNorm2d):
nowd_params += list(module.parameters())
return wd_params, nowd_params
class FeatureFusionModule(nn.Module):
def __init__(self, in_chan, out_chan, *args, **kwargs):
super(FeatureFusionModule, self).__init__()
self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
self.conv1 = nn.Conv2d(out_chan,
out_chan//4,
kernel_size = 1,
stride = 1,
padding = 0,
bias = False)
self.conv2 = nn.Conv2d(out_chan//4,
out_chan,
kernel_size = 1,
stride = 1,
padding = 0,
bias = False)
self.relu = nn.ReLU(inplace=True)
self.sigmoid = nn.Sigmoid()
self.init_weight()
def forward(self, fsp, fcp):
fcat = torch.cat([fsp, fcp], dim=1)
feat = self.convblk(fcat)
atten = F.avg_pool2d(feat, feat.size()[2:])
atten = self.conv1(atten)
atten = self.relu(atten)
atten = self.conv2(atten)
atten = self.sigmoid(atten)
feat_atten = torch.mul(feat, atten)
feat_out = feat_atten + feat
return feat_out
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
wd_params.append(module.weight)
if not module.bias is None:
nowd_params.append(module.bias)
elif isinstance(module, nn.BatchNorm2d):
nowd_params += list(module.parameters())
return wd_params, nowd_params
class BiSeNet(nn.Module):
def __init__(self, n_classes, *args, **kwargs):
super(BiSeNet, self).__init__()
self.cp = ContextPath()
## here self.sp is deleted
self.ffm = FeatureFusionModule(256, 256)
self.conv_out = BiSeNetOutput(256, 256, n_classes)
self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
# self.init_weight()
def forward(self, x):
H, W = x.size()[2:]
feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
feat_fuse = self.ffm(feat_sp, feat_cp8)
feat_out = self.conv_out(feat_fuse)
feat_out16 = self.conv_out16(feat_cp8)
feat_out32 = self.conv_out32(feat_cp16)
feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
return feat_out, feat_out16, feat_out32
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
for name, child in self.named_children():
child_wd_params, child_nowd_params = child.get_params()
if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
lr_mul_wd_params += child_wd_params
lr_mul_nowd_params += child_nowd_params
else:
wd_params += child_wd_params
nowd_params += child_nowd_params
return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
if __name__ == "__main__":
net = BiSeNet(19)
net.cuda()
net.eval()
in_ten = torch.randn(16, 3, 640, 480).cuda()
out, out16, out32 = net(in_ten)
print(out.shape)
net.get_params()
================================================
FILE: faceutils/mask/resnet.py
================================================
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as modelzoo
resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
class BasicBlock(nn.Module):
def __init__(self, in_chan, out_chan, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(in_chan, out_chan, stride)
self.bn1 = nn.BatchNorm2d(out_chan)
self.conv2 = conv3x3(out_chan, out_chan)
self.bn2 = nn.BatchNorm2d(out_chan)
self.relu = nn.ReLU(inplace=True)
self.downsample = None
if in_chan != out_chan or stride != 1:
self.downsample = nn.Sequential(
nn.Conv2d(in_chan, out_chan,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(out_chan),
)
def forward(self, x):
residual = self.conv1(x)
residual = F.relu(self.bn1(residual))
residual = self.conv2(residual)
residual = self.bn2(residual)
shortcut = x
if self.downsample is not None:
shortcut = self.downsample(x)
out = shortcut + residual
out = self.relu(out)
return out
def create_layer_basic(in_chan, out_chan, bnum, stride=1):
layers = [BasicBlock(in_chan, out_chan, stride=stride)]
for i in range(bnum-1):
layers.append(BasicBlock(out_chan, out_chan, stride=1))
return nn.Sequential(*layers)
class Resnet18(nn.Module):
def __init__(self):
super(Resnet18, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
# self.init_weight()
def forward(self, x):
x = self.conv1(x)
x = F.relu(self.bn1(x))
x = self.maxpool(x)
x = self.layer1(x)
feat8 = self.layer2(x) # 1/8
feat16 = self.layer3(feat8) # 1/16
feat32 = self.layer4(feat16) # 1/32
return feat8, feat16, feat32
def init_weight(self):
state_dict = modelzoo.load_url(resnet18_url)
self_state_dict = self.state_dict()
for k, v in state_dict.items():
if 'fc' in k: continue
self_state_dict.update({k: v})
self.load_state_dict(self_state_dict)
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
wd_params.append(module.weight)
if not module.bias is None:
nowd_params.append(module.bias)
elif isinstance(module, nn.BatchNorm2d):
nowd_params += list(module.parameters())
return wd_params, nowd_params
if __name__ == "__main__":
net = Resnet18()
x = torch.randn(16, 3, 224, 224)
out = net(x)
print(out[0].size())
print(out[1].size())
print(out[2].size())
net.get_params()
================================================
FILE: models/__init__.py
================================================
================================================
FILE: models/elegant.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from .modules.module_base import ResidualBlock_IN, Downsample, Upsample, PositionalEmbedding, MergeBlock
from .modules.module_attn import Attention_apply, FeedForwardLayer, MultiheadAttention
from .modules.sow_attention import SowAttention
from .modules.tps_transform import tps_spatial_transform
class Generator(nn.ModuleDict):
"""Generator. Encoder-Decoder Architecture."""
def __init__(self, conv_dim=64, image_size=256, num_layer_e=2, num_layer_d=1, window_size=16, use_ff=False,
merge_mode='conv', num_head=1, double_encoder=False, **unused):
super(Generator, self).__init__()
# -------------------------- Encoder --------------------------
layers = nn.Conv2d(3, conv_dim, kernel_size=7, stride=1, padding=3, bias=False)
self.add_module('in_conv', layers)
# Down-Sampling & Bottleneck
curr_dim = conv_dim; feature_size = image_size
for i in range(2):
layers = Downsample(curr_dim, curr_dim * 2, affine=True)
self.add_module('down_{:d}'.format(i+1), layers)
curr_dim = curr_dim * 2; feature_size = feature_size // 2
self.add_module('e_bottleneck_{:d}'.format(i+1),
nn.Sequential(*[ResidualBlock_IN(curr_dim, curr_dim, affine=True) for j in range(num_layer_e)])
)
### second encoder
self.double_encoder = double_encoder
if self.double_encoder:
layers = nn.Conv2d(3, conv_dim, kernel_size=7, stride=1, padding=3, bias=False)
self.add_module('in_conv_s', layers)
# Down-Sampling & Bottleneck
curr_dim = conv_dim; feature_size = image_size
for i in range(2):
layers = Downsample(curr_dim, curr_dim * 2, affine=True)
self.add_module('down_{:d}_s'.format(i+1), layers)
curr_dim = curr_dim * 2; feature_size = feature_size // 2
self.add_module('e_bottleneck_{:d}_s'.format(i+1),
nn.Sequential(*[ResidualBlock_IN(curr_dim, curr_dim, affine=True) for j in range(num_layer_e)])
)
# --------------------------- Transfer ----------------------------
curr_dim = conv_dim; feature_size = image_size
self.use_ff = use_ff
for i in range(2):
curr_dim = curr_dim * 2; feature_size = feature_size // 2
self.add_module('embedding_{:d}'.format(i+1), PositionalEmbedding(
embedding_dim=136,
feature_size=feature_size,
max_size=image_size,
embedding_type='l2_norm'
))
if i < 1:
self.add_module('attention_extract_{:d}'.format(i+1), SowAttention(
window_size=window_size,
in_channels=curr_dim + 136,
proj_channels=curr_dim + 136,
value_channels=curr_dim,
out_channels=curr_dim,
num_heads=num_head
))
else:
self.add_module('attention_extract_{:d}'.format(i+1), MultiheadAttention(
in_channels=curr_dim + 136,
proj_channels=curr_dim + 136,
value_channels=curr_dim,
out_channels=curr_dim,
num_heads=num_head
))
if use_ff:
self.add_module('feedforward_{:d}'.format(i+1), FeedForwardLayer(curr_dim, curr_dim))
self.add_module('attention_apply_{:d}'.format(i+1), Attention_apply(curr_dim))
# --------------------------- Decoder ----------------------------
# Bottleneck & Up-Sampling & Merge
for i in range(2):
self.add_module('d_bottleneck_{:d}'.format(i+1),
nn.Sequential(*[ResidualBlock_IN(curr_dim, curr_dim, affine=True) for j in range(num_layer_d)])
)
layers = Upsample(curr_dim, curr_dim // 2, affine=True)
self.add_module('up_{:d}'.format(i+1), layers)
curr_dim = curr_dim // 2
if i < 1:
self.add_module('merge_{:d}'.format(i+1), MergeBlock(merge_mode, curr_dim))
layers = nn.Sequential(
nn.InstanceNorm2d(curr_dim, affine=True),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False),
)
self.add_module('out_conv', layers)
def get_transfer_input(self, image, mask, diff, lms, is_reference=False):
feature_size = image.shape[2]; scale_factor = 1.0
fea_list, mask_list, diff_list, lms_list = [], [], [], []
# input conv
if self.double_encoder and is_reference:
fea = self['in_conv_s'](image)
else:
fea = self['in_conv'](image)
# down-sampling & bottleneck
for i in range(2):
if self.double_encoder and is_reference:
fea = self['down_{:d}_s'.format(i+1)](fea)
fea_ = self['e_bottleneck_{:d}_s'.format(i+1)](fea)
else:
fea = self['down_{:d}'.format(i+1)](fea)
fea_ = self['e_bottleneck_{:d}'.format(i+1)](fea)
fea_list.append(fea_)
feature_size = feature_size // 2; scale_factor = scale_factor * 0.5
mask_ = F.interpolate(mask, feature_size, mode='nearest')
mask_list.append(mask_)
diff_ = self['embedding_{:d}'.format(i+1)](diff, mask)
diff_list.append(diff_)
lms_ = lms * scale_factor
lms_list.append(lms_)
return [fea_list, mask_list, diff_list, lms_list]
def get_transfer_output(self, fea_c_list, mask_c_list, diff_c_list, lms_c_list,
fea_s_list, mask_s_list, diff_s_list, lms_s_list):
attn_out_list = []
for i in range(2):
feature_size = fea_c_list[i].shape[2]
# align
if i == 0:
fea_s_ = self.tps_align(feature_size, lms_s_list[i], lms_c_list[i], fea_s_list[i])
mask_s_ = self.tps_align(feature_size, lms_s_list[i], lms_c_list[i], mask_s_list[i], 'nearest')
diff_s_ = self.tps_align(feature_size, lms_s_list[i], lms_c_list[i], diff_s_list[i], 'nearest')
else:
fea_s_ = fea_s_list[i]
mask_s_ = mask_s_list[i]
diff_s_ = diff_s_list[i]
# transfer
input_q = torch.cat((fea_c_list[i], diff_c_list[i]), dim=1)
input_k = torch.cat((fea_s_, diff_s_), dim=1)
attn_out = self['attention_extract_{:d}'.format(i+1)](input_q, input_k, fea_s_, mask_c_list[i], mask_s_)
if self.use_ff:
attn_out = self['feedforward_{:d}'.format(i+1)](attn_out)
attn_out_list.append(attn_out)
return attn_out_list
def decode(self, fea_c_list, attn_out_list):
# apply
for i in range(2):
fea_c_ = self['attention_apply_{:d}'.format(i+1)](fea_c_list[i], attn_out_list[i])
fea_c_ = self['d_bottleneck_{:d}'.format(2-i)](fea_c_)
fea_c_list[i] = fea_c_
# up-sampling & merge
fea_c = fea_c_list[1]
for i in range(2):
fea_c = self['up_{:d}'.format(i+1)](fea_c)
if i < 1:
fea_c = self['merge_{:d}'.format(i+1)](fea_c_list[0], fea_c)
fea_c = self['out_conv'](fea_c)
return fea_c
def forward(self, c, s, mask_c, mask_s, diff_c, diff_s, lms_c, lms_s):
"""
c: content, stands for source image. shape: (b, c, h, w)
s: style, stands for reference image. shape: (b, c, h, w)
mask_c: (b, c', h, w)
diff: (b, d, h, w)
lms: (b, K, 2)
"""
transfer_input_c = self.get_transfer_input(c, mask_c, diff_c, lms_c)
transfer_input_s = self.get_transfer_input(s, mask_s, diff_s, lms_s, True)
attn_out_list = self.get_transfer_output(*transfer_input_c, *transfer_input_s)
fea_c = self.decode(transfer_input_c[0], attn_out_list)
return fea_c
def tps_align(self, feature_size, lms_s, lms_c, fea_s, sample_mode='bilinear'):
'''
fea: (B, C, H, W), lms: (B, K, 2)
'''
fea_out = []
for l_s, l_c, f_s in zip(lms_s, lms_c, fea_s):
l_c = torch.flip(l_c, dims=[1]) / (feature_size - 1)
l_s = (torch.flip(l_s, dims=[1]) / (feature_size - 1)).unsqueeze(0)
f_s = f_s.unsqueeze(0) # (1, C, H, W)
fea_trans, _ = tps_spatial_transform(feature_size, feature_size, l_c, f_s, l_s, sample_mode)
fea_out.append(fea_trans)
return torch.cat(fea_out, dim=0)
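# Shape sketch (assumptions: 256x256 inputs and 68 landmarks, which give the
# 136-channel positional embedding used above):
#   G = Generator(conv_dim=64, image_size=256)
#   fake = G(c, s, mask_c, mask_s, diff_c, diff_s, lms_c, lms_s)  # (B, 3, 256, 256)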
================================================
FILE: models/loss.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from .modules.histogram_matching import histogram_matching
from .modules.pseudo_gt import fine_align, expand_area, mask_blur
class GANLoss(nn.Module):
"""Define different GAN objectives.
The GANLoss class abstracts away the need to create the target label tensor
that has the same size as the input.
"""
def __init__(self, gan_mode='lsgan', target_real_label=1.0, target_fake_label=0.0):
""" Initialize the GANLoss class.
Parameters:
            gan_mode (str) - - the type of GAN objective. It currently supports vanilla and lsgan.
            target_real_label (float) - - label for a real image
            target_fake_label (float) - - label for a fake image
Note: Do not use sigmoid as the last layer of Discriminator.
LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
"""
super(GANLoss, self).__init__()
self.register_buffer('real_label', torch.tensor(target_real_label))
self.register_buffer('fake_label', torch.tensor(target_fake_label))
self.gan_mode = gan_mode
if gan_mode == 'lsgan':
self.loss = nn.MSELoss()
elif gan_mode == 'vanilla':
self.loss = nn.BCEWithLogitsLoss()
else:
raise NotImplementedError('gan mode %s not implemented' % gan_mode)
def forward(self, prediction, target_is_real):
"""Calculate loss given Discriminator's output and grount truth labels.
Parameters:
prediction (tensor) - - tpyically the prediction output from a discriminator
target_is_real (bool) - - if the ground truth label is for real images or fake images
Returns:
the calculated loss.
"""
if target_is_real:
target_tensor = self.real_label
else:
target_tensor = self.fake_label
target_tensor = target_tensor.expand_as(prediction).to(prediction.device)
loss = self.loss(prediction, target_tensor)
return loss
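# Example usage (sketch; netD is any discriminator producing raw scores):
#   criterion = GANLoss('lsgan')
#   loss_d = criterion(netD(real), True) + criterion(netD(fake.detach()), False)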
def norm(x: torch.Tensor):
return x * 2 - 1
def de_norm(x: torch.Tensor):
out = (x + 1) / 2
return out.clamp(0, 1)
def masked_his_match(image_s, image_r, mask_s, mask_r):
'''
image: (3, h, w)
mask: (1, h, w)
'''
index_tmp = torch.nonzero(mask_s)
x_A_index = index_tmp[:, 1]
y_A_index = index_tmp[:, 2]
index_tmp = torch.nonzero(mask_r)
x_B_index = index_tmp[:, 1]
y_B_index = index_tmp[:, 2]
image_s = (de_norm(image_s) * 255) #[-1, 1] -> [0, 255]
image_r = (de_norm(image_r) * 255)
source_masked = image_s * mask_s
target_masked = image_r * mask_r
source_match = histogram_matching(
source_masked, target_masked,
[x_A_index, y_A_index, x_B_index, y_B_index])
source_match = source_match.to(image_s.device)
return norm(source_match / 255) #[0, 255] -> [-1, 1]
def generate_pgt(image_s, image_r, mask_s, mask_r, lms_s, lms_r, margins, blend_alphas, img_size=None):
"""
    image: (3, h, w)
    mask: (c, h, w); channels are lip, skin, left eye, right eye
"""
if img_size is None:
img_size = image_s.shape[1]
pgt = image_s.detach().clone()
# skin match
skin_match = masked_his_match(image_s, image_r, mask_s[1:2], mask_r[1:2])
pgt = (1 - mask_s[1:2]) * pgt + mask_s[1:2] * skin_match
# lip match
lip_match = masked_his_match(image_s, image_r, mask_s[0:1], mask_r[0:1])
pgt = (1 - mask_s[0:1]) * pgt + mask_s[0:1] * lip_match
# eye match
mask_s_eye = expand_area(mask_s[2:4].sum(dim=0, keepdim=True), margins['eye']) * mask_s[1:2]
mask_r_eye = expand_area(mask_r[2:4].sum(dim=0, keepdim=True), margins['eye']) * mask_r[1:2]
eye_match = masked_his_match(image_s, image_r, mask_s_eye, mask_r_eye)
mask_s_eye_blur = mask_blur(mask_s_eye, blur_size=5, mode='valid')
pgt = (1 - mask_s_eye_blur) * pgt + mask_s_eye_blur * eye_match
# tps align
pgt = fine_align(img_size, lms_r, lms_s, image_r, pgt, mask_r, mask_s, margins, blend_alphas)
return pgt
class LinearAnnealingFn():
"""
Piecewise-linear annealing function over the given milestones and values
"""
def __init__(self, milestones, f_values):
assert len(milestones) == len(f_values)
self.milestones = milestones
self.f_values = f_values
def __call__(self, t:int):
if t < self.milestones[0]:
return self.f_values[0]
elif t >= self.milestones[-1]:
return self.f_values[-1]
else:
for r in range(len(self.milestones) - 1):
if self.milestones[r] <= t < self.milestones[r+1]:
return ((t - self.milestones[r]) * self.f_values[r+1] \
+ (self.milestones[r+1] - t) * self.f_values[r]) \
/ (self.milestones[r+1] - self.milestones[r])
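# --- Illustrative sanity check (not part of the original training code) ---
# LinearAnnealingFn interpolates piecewise-linearly between f_values at the
# given milestones; the numbers mirror the skin-alpha schedule in training/config.py.
def _example_annealing():
    fn = LinearAnnealingFn(milestones=[0, 12, 24, 50], f_values=[0.2, 0.4, 0.3, 0.2])
    assert abs(fn(0) - 0.2) < 1e-8   # value at the first milestone
    assert abs(fn(6) - 0.3) < 1e-8   # halfway between 0.2 and 0.4
    assert fn(100) == 0.2            # clamped beyond the last milestone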
class ComposePGT(nn.Module):
def __init__(self, margins, skin_alpha, eye_alpha, lip_alpha):
super(ComposePGT, self).__init__()
self.margins = margins
self.blend_alphas = {
'skin':skin_alpha,
'eye':eye_alpha,
'lip':lip_alpha
}
@torch.no_grad()
def forward(self, sources, targets, mask_srcs, mask_tars, lms_srcs, lms_tars):
pgts = []
for source, target, mask_src, mask_tar, lms_src, lms_tar in\
zip(sources, targets, mask_srcs, mask_tars, lms_srcs, lms_tars):
pgt = generate_pgt(source, target, mask_src, mask_tar, lms_src, lms_tar,
self.margins, self.blend_alphas)
pgts.append(pgt)
pgts = torch.stack(pgts, dim=0)
return pgts
class AnnealingComposePGT(nn.Module):
def __init__(self, margins,
skin_alpha_milestones, skin_alpha_values,
eye_alpha_milestones, eye_alpha_values,
lip_alpha_milestones, lip_alpha_values
):
super(AnnealingComposePGT, self).__init__()
self.margins = margins
self.skin_alpha_fn = LinearAnnealingFn(skin_alpha_milestones, skin_alpha_values)
self.eye_alpha_fn = LinearAnnealingFn(eye_alpha_milestones, eye_alpha_values)
self.lip_alpha_fn = LinearAnnealingFn(lip_alpha_milestones, lip_alpha_values)
self.t = 0
self.blend_alphas = {}
self.step()
def step(self):
self.t += 1
self.blend_alphas['skin'] = self.skin_alpha_fn(self.t)
self.blend_alphas['eye'] = self.eye_alpha_fn(self.t)
self.blend_alphas['lip'] = self.lip_alpha_fn(self.t)
@torch.no_grad()
def forward(self, sources, targets, mask_srcs, mask_tars, lms_srcs, lms_tars):
pgts = []
for source, target, mask_src, mask_tar, lms_src, lms_tar in\
zip(sources, targets, mask_srcs, mask_tars, lms_srcs, lms_tars):
pgt = generate_pgt(source, target, mask_src, mask_tar, lms_src, lms_tar,
self.margins, self.blend_alphas)
pgts.append(pgt)
pgts = torch.stack(pgts, dim=0)
return pgts
class MakeupLoss(nn.Module):
"""
Define the makeup loss w.r.t. the pseudo ground truth
"""
def __init__(self):
super(MakeupLoss, self).__init__()
def forward(self, x, target, mask=None):
if mask is None:
return F.l1_loss(x, target)
else:
return F.l1_loss(x * mask, target * mask)
================================================
FILE: models/model.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import VGG as TVGG
from torchvision.models.vgg import load_state_dict_from_url, model_urls, cfgs
from .modules.spectral_norm import spectral_norm as SpectralNorm
from .elegant import Generator
def get_generator(config):
kwargs = {
'conv_dim':config.MODEL.G_CONV_DIM,
'image_size':config.DATA.IMG_SIZE,
'num_head':config.MODEL.NUM_HEAD,
'double_encoder':config.MODEL.DOUBLE_E,
'use_ff':config.MODEL.USE_FF,
'num_layer_e':config.MODEL.NUM_LAYER_E,
'num_layer_d':config.MODEL.NUM_LAYER_D,
'window_size':config.MODEL.WINDOW_SIZE,
'merge_mode':config.MODEL.MERGE_MODE
}
G = Generator(**kwargs)
return G
def get_discriminator(config):
kwargs = {
'input_channel': 3,
'conv_dim':config.MODEL.D_CONV_DIM,
'num_layers':config.MODEL.D_REPEAT_NUM,
'norm':config.MODEL.D_TYPE
}
D = Discriminator(**kwargs)
return D
class Discriminator(nn.Module):
"""Discriminator. PatchGAN."""
def __init__(self, input_channel=3, conv_dim=64, num_layers=3, norm='SN', **unused):
super(Discriminator, self).__init__()
layers = []
if norm=='SN':
layers.append(SpectralNorm(nn.Conv2d(input_channel, conv_dim, kernel_size=4, stride=2, padding=1)))
else:
layers.append(nn.Conv2d(input_channel, conv_dim, kernel_size=4, stride=2, padding=1))
layers.append(nn.LeakyReLU(0.01, inplace=True))
curr_dim = conv_dim
for i in range(1, num_layers):
if norm=='SN':
layers.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1)))
else:
layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))
layers.append(nn.LeakyReLU(0.01, inplace=True))
curr_dim = curr_dim * 2
#k_size = int(image_size / np.power(2, repeat_num))
if norm=='SN':
layers.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=1, padding=1)))
else:
layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=1, padding=1))
layers.append(nn.LeakyReLU(0.01, inplace=True))
curr_dim = curr_dim * 2
self.main = nn.Sequential(*layers)
if norm=='SN':
self.conv1 = SpectralNorm(nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False))
else:
self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False)
def forward(self, x):
h = self.main(x)
out_makeup = self.conv1(h)
return out_makeup
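# --- Illustrative shape check (not part of the original training code) ---
# With the default config (conv_dim=64, num_layers=3), this PatchGAN maps a
# 256x256 image to a 30x30 grid of per-patch real/fake logits.
def _example_discriminator_shape():
    D = Discriminator(input_channel=3, conv_dim=64, num_layers=3, norm='SN')
    x = torch.randn(1, 3, 256, 256)
    assert D(x).shape == (1, 1, 30, 30)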
class VGG(TVGG):
def forward(self, x):
x = self.features(x)
return x
def make_layers(cfg, batch_norm=False):
layers = []
in_channels = 3
for v in cfg:
if v == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = v
return nn.Sequential(*layers)
def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
if pretrained:
kwargs['init_weights'] = False
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
progress=progress)
model.load_state_dict(state_dict)
return model
def vgg16(pretrained=False, progress=True, **kwargs):
r"""VGG 16-layer model (configuration "D")
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)
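# --- Illustrative usage (not part of the original training code) ---
# The trimmed VGG above returns the final convolutional feature map (no
# classifier head); pretrained=True would download the ImageNet weights instead.
def _example_vgg_features():
    vgg = vgg16(pretrained=False)
    x = torch.randn(1, 3, 256, 256)
    return vgg(x).shape  # torch.Size([1, 512, 8, 8]) after five max-pools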
================================================
FILE: models/modules/__init__.py
================================================
================================================
FILE: models/modules/histogram_matching.py
================================================
import copy
import torch
def cal_hist(image):
"""
calculate the cumulative histogram (CDF) for each channel in the list
"""
hists = []
for i in range(0, 3):
channel = image[i]
# channel = image[i, :, :]
channel = torch.from_numpy(channel)
# hist, _ = np.histogram(channel, bins=256, range=(0,255))
hist = torch.histc(channel, bins=256, min=0, max=256)
hist = hist.numpy()
# refHist=hist.view(256,1)
total = hist.sum()
pdf = [v / (total + 1e-10) for v in hist]
for i in range(1, 256):
pdf[i] = pdf[i - 1] + pdf[i]
hists.append(pdf)
return hists
def cal_trans(ref, adj):
"""
calculate the transfer function;
the algorithm follows the Wikipedia article on histogram matching
"""
table = list(range(0, 256))
for i in list(range(1, 256)):
for j in list(range(1, 256)):
if ref[i] >= adj[j - 1] and ref[i] <= adj[j]:
table[i] = j
break
table[255] = 255
return table
def histogram_matching(dstImg, refImg, index):
"""
perform histogram matching:
dstImg is transformed to have the same histogram as refImg
index[0], index[1]: indices of the pixels to be transformed in dstImg
index[2], index[3]: indices of the pixels used to compute the histogram of refImg
"""
index = [x.cpu().numpy() for x in index]
dstImg = dstImg.detach().cpu().numpy()
refImg = refImg.detach().cpu().numpy()
dst_align = [dstImg[i, index[0], index[1]] for i in range(0, 3)]
ref_align = [refImg[i, index[2], index[3]] for i in range(0, 3)]
hist_ref = cal_hist(ref_align)
hist_dst = cal_hist(dst_align)
tables = [cal_trans(hist_dst[i], hist_ref[i]) for i in range(0, 3)]
mid = copy.deepcopy(dst_align)
for i in range(0, 3):
for k in range(0, len(index[0])):
dst_align[i][k] = tables[i][int(mid[i][k])]
for i in range(0, 3):
dstImg[i, index[0], index[1]] = dst_align[i]
dstImg = torch.FloatTensor(dstImg)
return dstImg
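# --- Illustrative usage (not part of the original training code) ---
# Match the histogram of one random "image" to another over a full-image index
# set; inputs are expected channel-first in the [0, 255] range.
def _example_histogram_matching():
    h = w = 8
    dst = torch.rand(3, h, w) * 255
    ref = torch.rand(3, h, w) * 255
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w))
    index = [ys.flatten(), xs.flatten(), ys.flatten(), xs.flatten()]
    matched = histogram_matching(dst, ref, index)
    return matched  # FloatTensor of shape (3, h, w)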
================================================
FILE: models/modules/module_attn.py
================================================
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiheadAttention_weight(nn.Module):
def __init__(self, feature_dim, proj_dim, num_heads=1, dropout=0.0, bias=True):
super(MultiheadAttention_weight, self).__init__()
self.feature_dim = feature_dim
self.proj_dim = proj_dim
self.num_heads = num_heads
self.dropout = nn.Dropout(dropout)
self.head_dim = proj_dim // num_heads
assert self.head_dim * num_heads == self.proj_dim, "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
self.q_proj = nn.Linear(feature_dim, proj_dim, bias=bias)
self.k_proj = nn.Linear(feature_dim, proj_dim, bias=bias)
def forward(self, fea_c, fea_s, mask_c, mask_s):
'''
fea_c: (b, d, h, w)
mask_c: (b, c, h, w)
'''
bsz, dim, h, w = fea_c.shape; mask_channel = mask_c.shape[1]
fea_c = fea_c.view(bsz, dim, h*w).transpose(1, 2) # (b, HW, d)
fea_s = fea_s.view(bsz, dim, h*w).transpose(1, 2)
with torch.no_grad():
if mask_c.shape[2] != h:
mask_c = F.interpolate(mask_c, size=(h, w))
mask_s = F.interpolate(mask_s, size=(h, w))
mask_c = mask_c.view(bsz, mask_channel, -1, h*w) # (b, m_c, 1, HW)
mask_s = mask_s.view(bsz, mask_channel, -1, h*w)
mask_attn = torch.matmul(mask_c.transpose(-2, -1), mask_s) # (b, m_c, HW, HW)
mask_attn = torch.sum(mask_attn, dim=1, keepdim=True).clamp_(0, 1) # (b, 1, HW, HW)
mask_sum = torch.sum(mask_attn, dim=-1, keepdim=True)
mask_attn += (mask_sum == 0).float()
mask_attn = mask_attn.masked_fill_(mask_attn == 0, float('-inf')).masked_fill_(mask_attn == 1, float(0.0))
query = self.q_proj(fea_c) # (b, HW, D)
key = self.k_proj(fea_s) # (b, HW, D)
query = query.view(bsz, h*w, self.num_heads, self.head_dim).transpose(1, 2) # (b, h, HW, D)
key = key.view(bsz, h*w, self.num_heads, self.head_dim).transpose(1, 2)
weights = torch.matmul(query, key.transpose(-1, -2)) # (b, h, HW, HW)
weights = weights * self.scaling
weights = weights + mask_attn.detach()
weights = self.dropout(F.softmax(weights, dim=-1))
weights = weights * (1 - (mask_sum == 0).float().detach())
return weights
class MultiheadAttention_value(nn.Module):
def __init__(self, feature_dim, proj_dim, num_heads=1, bias=True):
super(MultiheadAttention_value, self).__init__()
self.feature_dim = feature_dim
self.proj_dim = proj_dim
self.num_heads = num_heads
self.head_dim = proj_dim // num_heads
assert self.head_dim * num_heads == self.proj_dim, "embed_dim must be divisible by num_heads"
self.v_proj = nn.Linear(feature_dim, proj_dim, bias=bias)
def forward(self, weights, fea):
'''
weights: (b, h, HW, HW)
fea: (b, d, H, W)
'''
bsz, dim, h, w = fea.shape
fea = fea.view(bsz, dim, h*w).transpose(1, 2) #(b, HW, D)
value = self.v_proj(fea)
value = value.view(bsz, h*w, self.num_heads, self.head_dim).transpose(1, 2) #(b, h, HW, D)
out = torch.matmul(weights, value)
out = out.transpose(1, 2).contiguous().view(bsz, h*w, self.proj_dim) # (b, HW, D)
out = out.transpose(1, 2).view(bsz, self.proj_dim, h, w) #(b, d, H, W)
return out
class MultiheadAttention(nn.Module):
def __init__(self, in_channels, proj_channels, value_channels, out_channels, num_heads=1, dropout=0.0, bias=True):
super(MultiheadAttention, self).__init__()
self.weight = MultiheadAttention_weight(in_channels, proj_channels, num_heads, dropout, bias)
self.value = MultiheadAttention_value(value_channels, out_channels, num_heads, bias)
def forward(self, fea_q, fea_k, fea_v, mask_q, mask_k):
'''
fea: (b, d, h, w)
mask: (b, c, h, w)
'''
weights = self.weight(fea_q, fea_k, mask_q, mask_k)
return self.value(weights, fea_v)
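# --- Illustrative shape check (not part of the original training code) ---
# Cross-attention maps content features (queries) onto style features
# (keys/values) under part-mask constraints; all tensors are random placeholders.
def _example_multihead_attention():
    attn = MultiheadAttention(in_channels=64, proj_channels=64,
                              value_channels=64, out_channels=64, num_heads=1)
    fea = torch.randn(1, 64, 16, 16)
    mask = torch.ones(1, 2, 16, 16)  # two face-part channels, fully "on"
    assert attn(fea, fea, fea, mask, mask).shape == (1, 64, 16, 16)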
class FeedForwardLayer(nn.Module):
def __init__(self, feature_dim, ff_dim, dropout=0.0):
super(FeedForwardLayer, self).__init__()
self.main = nn.Sequential(
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout(p=dropout, inplace=True),
nn.Conv2d(feature_dim, ff_dim, kernel_size=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ff_dim, feature_dim, kernel_size=1)
)
def forward(self, x):
return self.main(x)
class Attention_apply(nn.Module):
def __init__(self, feature_dim, normalize=True):
super(Attention_apply, self).__init__()
self.normalize = normalize
if normalize:
self.norm = nn.InstanceNorm2d(feature_dim, affine=False)
self.actv = nn.LeakyReLU(0.2, inplace=True)
self.conv = nn.Conv2d(feature_dim, feature_dim, kernel_size=3, stride=1, padding=1, bias=False)
def forward(self, x, attn_out):
if self.normalize:
x = self.norm(x)
x = x * (1 + attn_out)
return self.conv(self.actv(x))
================================================
FILE: models/modules/module_base.py
================================================
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
"""Residual Block."""
def __init__(self, dim_in, dim_out):
super(ResidualBlock, self).__init__()
self.main = nn.Sequential(
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False)
)
self.skip = nn.Identity() if dim_in == dim_out else nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=False)
def forward(self, x):
x = self.skip(x) + self.main(x)
return x / math.sqrt(2)
class ResidualBlock_IN(nn.Module):
"""Residual Block with InstanceNorm."""
def __init__(self, dim_in, dim_out, affine=False):
super(ResidualBlock_IN, self).__init__()
self.main = nn.Sequential(
nn.InstanceNorm2d(dim_in, affine=affine),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
nn.InstanceNorm2d(dim_out, affine=affine),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
)
self.skip = nn.Identity() if dim_in == dim_out else nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=False)
def forward(self, x):
x = self.skip(x) + self.main(x)
return x / math.sqrt(2)
class ResidualBlock_Downsample(nn.Module):
"""Residual Block with InstanceNorm."""
def __init__(self, dim_in, dim_out, affine=False):
super(ResidualBlock_Downsample, self).__init__()
self.main = nn.Sequential(
nn.InstanceNorm2d(dim_in, affine=affine),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, bias=False)
)
if dim_in == dim_out:
self.skip = nn.Identity()
else:
self.skip = nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1, bias=False)
def forward(self, x):
skip = F.interpolate(self.skip(x), scale_factor=0.5, mode='bilinear', align_corners=False, recompute_scale_factor=True)
res = self.main(x)
x = skip + res
return x / math.sqrt(2)
class Downsample(nn.Module):
"""Residual Block with InstanceNorm."""
def __init__(self, dim_in, dim_out, affine=False):
super(Downsample, self).__init__()
self.main = nn.Sequential(
nn.InstanceNorm2d(dim_in, affine=affine),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, bias=False)
)
def forward(self, x):
return self.main(x)
class ResidualBlock_Upsample(nn.Module):
"""Residual Block with InstanceNorm."""
def __init__(self, dim_in, dim_out, normalize=True, affine=False):
super(ResidualBlock_Upsample, self).__init__()
if normalize:
self.main = nn.Sequential(
nn.InstanceNorm2d(dim_in, affine=affine),
nn.LeakyReLU(0.2, inplace=True),
nn.ConvTranspose2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, bias=False)
)
else:
self.main = nn.Sequential(
nn.LeakyReLU(0.2, inplace=True),
nn.ConvTranspose2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, bias=False)
)
if dim_in == dim_out:
self.skip = nn.Identity()
else:
self.skip = nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1, bias=False)
def forward(self, x):
skip = F.interpolate(self.skip(x), scale_factor=2, mode='bilinear', align_corners=False)
res = self.main(x)
x = skip + res
return x / math.sqrt(2)
class Upsample(nn.Module):
"""Residual Block with InstanceNorm."""
def __init__(self, dim_in, dim_out, normalize=True, affine=False):
super(Upsample, self).__init__()
if normalize:
self.main = nn.Sequential(
nn.InstanceNorm2d(dim_in, affine=affine),
nn.LeakyReLU(0.2, inplace=True),
nn.ConvTranspose2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, bias=False)
)
else:
self.main = nn.Sequential(
nn.LeakyReLU(0.2, inplace=True),
nn.ConvTranspose2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, bias=False)
)
def forward(self, x):
return self.main(x)
class PositionalEmbedding(nn.Module):
def __init__(self, embedding_dim=136, feature_size=64, max_size=None, embedding_type='l2_norm'):
super(PositionalEmbedding, self).__init__()
self.embedding_dim = embedding_dim
self.feature_size = feature_size
self.max_size = max_size
assert embedding_type in ['l2_norm', 'uniform', 'sin']
self.embedding_type = embedding_type
@torch.no_grad()
def forward(self, diff, mask):
'''
diff: (b, d, h, w)
mask: (b, 3, h, w)
return: (b, d, h, w)
'''
bsz, init_dim, init_size, _ = diff.shape
assert self.embedding_dim >= init_dim
diff = F.interpolate(diff, self.feature_size) # (b, d, h, w)
mask = F.interpolate(mask, size=self.feature_size)
mask = torch.sum(mask, dim=1, keepdim=True) # (b, 1, h, w)
diff = diff * mask
if self.embedding_type == 'l2_norm':
norm = torch.norm(diff, dim=1, keepdim=True)
norm = (norm == 0) + norm
diff = diff / norm
elif self.embedding_type == 'uniform':
diff = diff / self.max_size
elif self.embedding_type == 'sin':
diff = torch.sin(diff * math.pi / (2 * self.max_size))
if self.embedding_dim > init_dim:
zero_shape = (bsz, self.embedding_dim - init_dim, self.feature_size, self.feature_size)
zero_padding = torch.zeros(zero_shape, device=diff.device)
diff = torch.cat((diff, zero_padding), dim=1)
diff = diff.detach(); diff.requires_grad = False
return diff
class MergeBlock(nn.Module):
def __init__(self, merge_mode, feature_dim, normalize=True):
super(MergeBlock, self).__init__()
assert merge_mode in ['conv', 'add', 'affine']
self.merge_mode = merge_mode
if merge_mode == 'affine':
self.norm = nn.LayerNorm(feature_dim, elementwise_affine=False) if normalize else nn.Identity()
else:
self.norm = nn.InstanceNorm2d(feature_dim, affine=False) if normalize else nn.Identity()
self.norm_r = nn.InstanceNorm2d(feature_dim, affine=False) if normalize else nn.Identity()
self.actv = nn.LeakyReLU(0.2, inplace=True)
if merge_mode == 'conv':
self.conv = nn.Conv2d(2 * feature_dim, feature_dim, kernel_size=3, stride=1, padding=1, bias=False)
else:
self.conv = nn.Conv2d(feature_dim, feature_dim, kernel_size=3, stride=1, padding=1, bias=False)
def forward(self, fea_s, fea_r):
if self.merge_mode == 'conv':
fea_s = self.norm(fea_s)
fea_r = self.norm_r(fea_r)
fea_s = torch.cat((fea_s, fea_r), dim=1)
elif self.merge_mode == 'add':
fea_s = self.norm(fea_s)
fea_r = self.norm_r(fea_r)
fea_s = (fea_s + fea_r) / math.sqrt(2)
elif self.merge_mode == 'affine':
fea_s = fea_s.permute(0, 2, 3, 1)
fea_s = self.norm(fea_s)
fea_s = fea_s.permute(0, 3, 1, 2)
fea_s = fea_s * (1 + fea_r)
return self.conv(self.actv(fea_s))
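# --- Illustrative usage (not part of the original training code) ---
# The three merge modes fuse a source feature map with a reference-derived one
# of the same shape: 'conv' concatenates, 'add' averages, 'affine' modulates.
def _example_merge_block():
    fea_s = torch.randn(2, 64, 32, 32)
    fea_r = torch.randn(2, 64, 32, 32)
    for mode in ('conv', 'add', 'affine'):
        merge = MergeBlock(mode, feature_dim=64)
        assert merge(fea_s, fea_r).shape == (2, 64, 32, 32)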
================================================
FILE: models/modules/pseudo_gt.py
================================================
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import functional
from models.modules.tps_transform import tps_sampler, tps_spatial_transform
def expand_area(mask:torch.Tensor, margin:int):
'''
mask: (C, H, W) or (N, C, H, W)
'''
kernel = np.zeros((margin * 2 + 1, margin * 2 + 1), dtype=np.uint8)
kernel = cv2.circle(kernel, (margin, margin), margin, (255, 0, 0), -1)
kernel = torch.FloatTensor((kernel > 0)).unsqueeze(0).unsqueeze(0).to(mask.device)
ndim = mask.ndimension()
if ndim == 3:
mask = mask.unsqueeze(0)
expanded_mask = torch.zeros_like(mask)
for i in range(mask.shape[1]):
expanded_mask[:,i:i+1,:,:] = F.conv2d(mask[:,i:i+1,:,:], kernel, padding=margin)
if ndim == 3:
expanded_mask = expanded_mask.squeeze(0)
return (expanded_mask > 0).float()
def mask_blur(mask:torch.Tensor, blur_size=3, mode='smooth'):
"""Blur the edge of mask so that the compose image have smooth transition
Args:
mask (torch.Tensor): [C, H, W]
blur_size (int): size of blur kernel. Defaults to 3.
mode (str) Defaults to 'smooth'.
Returns:
torch.Tensor: blurred mask
"""
#kernel = torch.ones((1, 1, blur_size * 2 + 1, blur_size * 2 + 1)).to(mask.device)
kernel = np.zeros((blur_size * 2 + 1, blur_size * 2 + 1), dtype=np.uint8)
kernel = cv2.circle(kernel, (blur_size, blur_size), blur_size, (255, 0, 0), -1)
kernel = torch.FloatTensor((kernel > 0)).unsqueeze(0).unsqueeze(0).to(mask.device)
kernel = kernel / torch.sum(kernel)
ndim = mask.ndimension()
if ndim == 3:
mask = mask.unsqueeze(0)
mask_blur = torch.zeros_like(mask)
for i in range(mask.shape[1]):
mask_blur[:,i:i+1,:,:] = F.conv2d(mask[:,i:i+1,:,:], kernel, padding=blur_size)
if mode == 'valid':
mask_blur = (mask_blur.clamp(0.5, 1) - 0.5) * 2 * mask
if ndim == 3:
mask_blur = mask_blur.squeeze(0)
return mask_blur.clamp(0, 1)
def mask_blend(mask, blend_alpha, mask_bound=None, blur_size=3, blend_mode='smooth'):
if blur_size > 0:
mask = mask_blur(mask, blur_size, blend_mode)
mask = mask * blend_alpha
if mask_bound is None:
return mask
else:
return mask * mask_bound
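# --- Illustrative usage (not part of the original training code) ---
# Dilate a binary part mask with a disk, then feather and scale it for alpha
# blending; shapes follow the (C, H, W) convention used above.
def _example_mask_ops():
    mask = torch.zeros(1, 64, 64)
    mask[:, 24:40, 24:40] = 1.0             # a square "part" region
    expanded = expand_area(mask, margin=4)  # dilate by a disk of radius 4
    blended = mask_blend(expanded, blend_alpha=0.5, blur_size=3)
    assert blended.max() <= 0.5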
def tps_align(img_size, lms_r, lms_s, image_r, image_s=None,
mask_r = None, mask_s=None, sample_mode='bilinear'):
'''
image: (C, H, W), lms: (K, 2), mask:(1, H, W)
'''
lms_s = torch.flip(lms_s, dims=[1]) / (img_size - 1)
lms_r = (torch.flip(lms_r, dims=[1]) / (img_size - 1)).unsqueeze(0)
image_r = image_r.unsqueeze(0)
image_trans, _ = tps_spatial_transform(img_size, img_size, lms_s, image_r, lms_r, sample_mode)
if mask_r is not None:
mask_r_trans, _ = tps_spatial_transform(img_size, img_size, lms_s, mask_r.unsqueeze(0),
lms_r, 'nearest')
if image_s is not None:
mask_compose = torch.ones((1, img_size, img_size), device=lms_r.device)
if mask_s is not None:
mask_compose *= mask_s
if mask_r is not None:
mask_compose *= mask_r_trans.squeeze(0)
return image_s * (1 - mask_compose) + image_trans.squeeze(0) * mask_compose
else:
return image_trans.squeeze(0)
def tps_blend(blend_alpha, img_size, lms_r, lms_s, image_r, image_s=None, mask_r = None, mask_s=None,
mask_s_bound=None, blur_size=7, sample_mode='bilinear', blend_mode='smooth'):
'''
image: (C, H, W), lms: (K, 2), mask:(1, H, W)
'''
lms_s = torch.flip(lms_s, dims=[1]) / (img_size - 1)
lms_r = (torch.flip(lms_r, dims=[1]) / (img_size - 1)).unsqueeze(0)
image_r = image_r.unsqueeze(0)
image_trans, _ = tps_spatial_transform(img_size, img_size, lms_s, image_r, lms_r, sample_mode)
if mask_r is not None:
mask_r_trans, _ = tps_spatial_transform(img_size, img_size, lms_s, mask_r.unsqueeze(0),
lms_r, 'nearest')
if image_s is not None:
mask_compose = torch.ones((1, img_size, img_size), device=lms_r.device)
if mask_s is not None:
mask_compose *= mask_s
if mask_r is not None:
mask_compose *= mask_r_trans.squeeze(0)
mask_compose = mask_blend(mask_compose, blend_alpha, mask_s_bound, blur_size, blend_mode)
return image_s * (1 - mask_compose) + image_trans.squeeze(0) * mask_compose
else:
return image_trans.squeeze(0)
def fine_align(img_size, lms_r, lms_s, image_r, image_s, mask_r, mask_s, margins, blend_alphas):
'''
image: (C, H, W), lms: (K, 2)
mask: (C, H, W), lip, face, left eye, right eye
margins: dictionary, blend_alphas: dictionary
'''
# skin align
image_s = tps_blend(blend_alphas['skin'], img_size, lms_r[:60], lms_s[:60], image_r, image_s,
mask_r[1:2], mask_s[1:2], mask_s[1:2], blur_size=8, blend_mode='valid')
# lip align
mask_s_lip = expand_area(mask_s[0:1], margins['lip'])
mask_r_lip = expand_area(mask_r[0:1], margins['lip'])
image_s = tps_blend(blend_alphas['lip'], img_size, lms_r[48:], lms_s[48:], image_r, image_s,
mask_r_lip, mask_s_lip, mask_s[0:1], blur_size=3)
# left eye align
mask_s_eye = expand_area(mask_s[2:3], margins['eye'])
mask_r_eye = expand_area(mask_r[2:3], margins['eye']) * mask_r[1:2]
image_s = tps_blend(blend_alphas['eye'], img_size,
torch.cat((lms_r[14:17], lms_r[22:27], lms_r[27:31], lms_r[42:48]), dim=0),
torch.cat((lms_s[14:17], lms_s[22:27], lms_s[27:31], lms_s[42:48]), dim=0),
image_r, image_s, mask_r_eye, mask_s_eye, mask_s[1:2],
blur_size=5, sample_mode='nearest')
# right eye align
mask_s_eye = expand_area(mask_s[3:4], margins['eye'])
mask_r_eye = expand_area(mask_r[3:4], margins['eye']) * mask_r[1:2]
image_s = tps_blend(blend_alphas['eye'], img_size,
torch.cat((lms_r[0:3], lms_r[17:22], lms_r[27:31], lms_r[36:42]), dim=0),
torch.cat((lms_s[0:3], lms_s[17:22], lms_s[27:31], lms_s[36:42]), dim=0),
image_r, image_s, mask_r_eye, mask_s_eye, mask_s[1:2],
blur_size=5, sample_mode='nearest')
return image_s
if __name__ == "__main__":
pass
================================================
FILE: models/modules/sow_attention.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class WindowAttention(nn.Module):
def __init__(self, window_size, in_channels, proj_channels, value_channels, out_channels,
num_heads=1, dropout=0.0, bias=True, weighted_output=True):
super(WindowAttention, self).__init__()
assert window_size % 2 == 0
self.window_size = window_size
self.weighted_output = weighted_output
window_weight = self.generate_window_weight()
self.register_buffer('window_weight', window_weight)
self.num_heads = num_heads
self.dropout = nn.Dropout(dropout)
self.in_channels = in_channels
self.proj_channels = proj_channels
head_dim = proj_channels // num_heads
assert head_dim * num_heads == self.proj_channels, "embed_dim must be divisible by num_heads"
self.scaling = head_dim ** -0.5
self.q_proj = nn.Conv2d(in_channels, proj_channels, kernel_size=1, bias=bias)
self.k_proj = nn.Conv2d(in_channels, proj_channels, kernel_size=1, bias=bias)
self.value_channels = value_channels
self.out_channels = out_channels
assert out_channels // num_heads * num_heads == self.out_channels
self.v_proj = nn.Conv2d(value_channels, out_channels, kernel_size=1, bias=bias)
@torch.no_grad()
def generate_window_weight(self):
yc = torch.arange(self.window_size // 2).unsqueeze(1).repeat(1, self.window_size // 2)
xc = torch.arange(self.window_size // 2).unsqueeze(0).repeat(self.window_size // 2, 1)
window_weight = xc * yc / (self.window_size // 2 - 1) ** 2
window_weight = torch.cat((window_weight, torch.flip(window_weight, dims=[0])), dim=0)
window_weight = torch.cat((window_weight, torch.flip(window_weight, dims=[1])), dim=1)
return window_weight.view(-1)
def make_window(self, x: torch.Tensor):
"""
input: (B, C, H, W)
output: (B, h, H/S, W/S, S*S, C/h)
"""
bsz, dim, h, w = x.shape
x = x.view(bsz, self.num_heads, dim // self.num_heads, h // self.window_size, self.window_size,
w // self.window_size, self.window_size)
x = x.transpose(4, 5).contiguous().view(bsz, self.num_heads, dim // self.num_heads,
h // self.window_size, w // self.window_size, self.window_size**2)
x = x.permute(0, 1, 3, 4, 5, 2)
return x
def demake_window(self, x: torch.Tensor):
"""
input: (B, h, H/S, W/S, S*S, C/h)
output: (B, C, H, W)
"""
bsz, _, h_s, w_s, _, dim_h = x.shape
x = x.permute(0, 1, 5, 2, 3, 4).contiguous()
#print(x.shape)
x = x.view(bsz, dim_h * self.num_heads, h_s, w_s, self.window_size, self.window_size)
#print(x.shape)
x = x.transpose(3, 4).contiguous().view(bsz, dim_h * self.num_heads,
h_s * self.window_size, w_s * self.window_size)
#print(x.shape)
return x
@torch.no_grad()
def make_mask_window(self, mask: torch.Tensor):
"""
input: (B, C, H, W)
output: (B, 1, H/S, W/S, S*S, C)
"""
bsz, mask_channel, h, w = mask.shape
mask = mask.view(bsz, 1, mask_channel, h // self.window_size, self.window_size,
w // self.window_size, self.window_size)
mask = mask.transpose(4, 5).contiguous().view(bsz, 1, mask_channel,
h // self.window_size, w // self.window_size, self.window_size**2)
mask = mask.permute(0, 1, 3, 4, 5, 2)
return mask
def forward(self, fea_q, fea_k, fea_v, mask_q=None, mask_k=None):
'''
fea: (b, d, h, w)
mask: (b, c, h, w)
'''
query = self.q_proj(fea_q) # (B, D, H, W)
key = self.k_proj(fea_k)
value = self.v_proj(fea_v)
query = self.make_window(query) # (B, h, H/S, W/S, S*S, D/h)
key = self.make_window(key)
value = self.make_window(value)
weights = torch.matmul(query, key.transpose(-1, -2)) # (B, h, H/S, W/S, S*S, S*S)
weights = weights * self.scaling
if mask_q is not None and mask_k is not None:
mask_q = self.make_mask_window(mask_q) # (B, 1, H/S, W/S, S*S, C)
mask_k = self.make_mask_window(mask_k)
with torch.no_grad():
mask_attn = torch.matmul(mask_q, mask_k.transpose(-1, -2))
mask_sum = torch.sum(mask_attn, dim=-1, keepdim=True)
mask_attn += (mask_sum == 0).float()
mask_attn = mask_attn.masked_fill_(mask_attn == 0, float('-inf')).masked_fill_(mask_attn == 1, float(0.0))
weights += mask_attn
weights = self.dropout(F.softmax(weights, dim=-1))
if mask_q is not None and mask_k is not None:
weights = weights * (1 - (mask_sum == 0).float().detach())
out = torch.matmul(weights, value) # (B, h, H/S, W/S, S*S, D/h)
if self.weighted_output:
window_weight = self.window_weight.view(1, 1, 1, 1, self.window_size ** 2, 1)
out = out * window_weight
out = self.demake_window(out) #(B, D, H, W)
return out
class SowAttention(nn.Module):
def __init__(self, window_size, in_channels, proj_channels, value_channels, out_channels,
num_heads=1, dropout=0.0, bias=True):
super(SowAttention, self).__init__()
assert window_size % 2 == 0
self.window_size = window_size
self.pad = nn.ZeroPad2d(window_size // 2)
self.window_attention = WindowAttention(window_size, in_channels, proj_channels, value_channels,
out_channels, num_heads, dropout, bias)
def forward(self, fea_q, fea_k, fea_v, mask_q=None, mask_k=None):
'''
fea: (b, d, h, w)
mask: (b, c, h, w)
'''
out_0 = self.window_attention(fea_q, fea_k, fea_v, mask_q, mask_k)
fea_q = self.pad(fea_q)
fea_k = self.pad(fea_k)
fea_v = self.pad(fea_v)
if mask_q is not None and mask_k is not None:
mask_q = self.pad(mask_q)
mask_k = self.pad(mask_k)
else:
mask_q = None; mask_k = None
out_1 = self.window_attention(fea_q, fea_k, fea_v, mask_q, mask_k)
out_1 = out_1[:, :, self.window_size//2:-self.window_size//2, self.window_size//2:-self.window_size//2]
if mask_q is not None and mask_k is not None:
out_2 = self.window_attention(
fea_q[:, :, :, self.window_size//2:-self.window_size//2],
fea_k[:, :, :, self.window_size//2:-self.window_size//2],
fea_v[:, :, :, self.window_size//2:-self.window_size//2],
mask_q[:, :, :, self.window_size//2:-self.window_size//2],
mask_k[:, :, :, self.window_size//2:-self.window_size//2]
)
else:
out_2 = self.window_attention(
fea_q[:, :, :, self.window_size//2:-self.window_size//2],
fea_k[:, :, :, self.window_size//2:-self.window_size//2],
fea_v[:, :, :, self.window_size//2:-self.window_size//2],
)
out_2 = out_2[:, :, self.window_size//2:-self.window_size//2, :]
if mask_q is not None and mask_k is not None:
out_3 = self.window_attention(
fea_q[:, :, self.window_size//2:-self.window_size//2, :],
fea_k[:, :, self.window_size//2:-self.window_size//2, :],
fea_v[:, :, self.window_size//2:-self.window_size//2, :],
mask_q[:, :, self.window_size//2:-self.window_size//2, :],
mask_k[:, :, self.window_size//2:-self.window_size//2, :]
)
else:
out_3 = self.window_attention(
fea_q[:, :, self.window_size//2:-self.window_size//2, :],
fea_k[:, :, self.window_size//2:-self.window_size//2, :],
fea_v[:, :, self.window_size//2:-self.window_size//2, :],
)
out_3 = out_3[:, :, :, self.window_size//2:-self.window_size//2]
out = out_0 + out_1 + out_2 + out_3
return out
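# --- Illustrative shape check (not part of the original training code) ---
# SowAttention sums four window partitions (the base grid plus three
# half-window shifts), so every pixel attends within overlapping windows;
# feature sides must be divisible by the window size.
def _example_sow_attention():
    attn = SowAttention(window_size=16, in_channels=64, proj_channels=64,
                        value_channels=64, out_channels=64)
    fea = torch.randn(1, 64, 64, 64)
    assert attn(fea, fea, fea).shape == (1, 64, 64, 64)  # masks are optional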
class StridedwindowAttention(nn.Module):
def __init__(self, stride, in_channels, proj_channels, value_channels, out_channels,
num_heads=1, dropout=0.0, bias=True):
super(StridedwindowAttention, self).__init__()
self.stride = stride
self.num_heads = num_heads
self.dropout = nn.Dropout(dropout)
self.in_channels = in_channels
self.proj_channels = proj_channels
head_dim = proj_channels // num_heads
assert head_dim * num_heads == self.proj_channels, "embed_dim must be divisible by num_heads"
self.scaling = head_dim ** -0.5
self.q_proj = nn.Conv2d(in_channels, proj_channels, kernel_size=1, bias=bias)
self.k_proj = nn.Conv2d(in_channels, proj_channels, kernel_size=1, bias=bias)
self.value_channels = value_channels
self.out_channels = out_channels
assert out_channels // num_heads * num_heads == self.out_channels
self.v_proj = nn.Conv2d(value_channels, out_channels, kernel_size=1, bias=bias)
def make_window(self, x: torch.Tensor):
"""
input: (B, C, H, W)
output: (B, h, S(h), S(w), H/S * W/S, C/h)
"""
bsz, dim, h, w = x.shape
assert h % self.stride == 0 and w % self.stride == 0
x = x.view(bsz, self.num_heads, dim // self.num_heads, h // self.stride, self.stride,
w // self.stride, self.stride) # (B, h, C/h, H/S, S(h), W/S, S(w))
x = x.permute(0, 1, 4, 6, 3, 5, 2).contiguous() # (B, h, S(h), S(w), H/S, W/S, C/h)
x = x.view(bsz, self.num_heads, self.stride, self.stride,
h // self.stride * w // self.stride, dim // self.num_heads)
return x
def demake_window(self, x: torch.Tensor, h, w):
"""
input: (B, h, S(h), S(w), H/S * W/S, C/h)
output: (B, C, H, W)
"""
bsz, _, _, _, _, dim_h = x.shape
x = x.view(bsz, self.num_heads, self.stride, self.stride,
h // self.stride, w // self.stride, dim_h) # (B, h, S(h), S(w), H/S, W/S, C/h)
x = x.permute(0, 1, 6, 4, 2, 5, 3).contiguous() # (B, h, C/h, H/S, S(h), W/S, S(w))
x = x.view(bsz, dim_h * self.num_heads, h, w)
return x
@torch.no_grad()
def make_mask_window(self, mask: torch.Tensor):
"""
input: (B, C, H, W)
output: (B, 1, S(h), S(w), H/S * W/S, C)
"""
bsz, mask_channel, h, w = mask.shape
assert h % self.stride == 0 and w % self.stride == 0
mask = mask.view(bsz, 1, mask_channel, h // self.stride, self.stride, w // self.stride, self.stride)
mask = mask.permute(0, 1, 4, 6, 3, 5, 2).contiguous()
mask = mask.view(bsz, 1, self.stride, self.stride, h // self.stride * w // self.stride, mask_channel)
return mask
def forward(self, fea_q, fea_k, fea_v, mask_q=None, mask_k=None):
'''
fea: (b, d, h, w)
mask: (b, c, h, w)
'''
bsz, _, h, w = fea_q.shape
query = self.q_proj(fea_q) # (B, D, H, W)
key = self.k_proj(fea_k)
value = self.v_proj(fea_v)
query = self.make_window(query) # (B, h, S(h), S(w), H/S * W/S, C/h)
key = self.make_window(key)
value = self.make_window(value)
weights = torch.matmul(query, key.transpose(-1, -2)) # (B, h, S(h), S(w), H/S * W/S, H/S * W/S)
weights = weights * self.scaling
if mask_q is not None and mask_k is not None:
mask_q = self.make_mask_window(mask_q) # (B, 1, S(h), S(w), H/S * W/S, C)
mask_k = self.make_mask_window(mask_k)
with torch.no_grad():
mask_attn = torch.matmul(mask_q, mask_k.transpose(-1, -2))
mask_sum = torch.sum(mask_attn, dim=-1, keepdim=True)
mask_attn += (mask_sum == 0).float()
mask_attn = mask_attn.masked_fill_(mask_attn == 0, float('-inf')).masked_fill_(mask_attn == 1, float(0.0))
weights += mask_attn
weights = self.dropout(F.softmax(weights, dim=-1))
if mask_q is not None and mask_k is not None:
weights = weights * (1 - (mask_sum == 0).float().detach())
out = torch.matmul(weights, value) # (B, h, S(h), S(w), H/S * W/S, D/h)
out = self.demake_window(out, h, w) #(B, D, H, W)
return out
================================================
FILE: models/modules/spectral_norm.py
================================================
import torch
from torch.nn import Parameter
def l2normalize(v, eps=1e-12):
return v / (v.norm() + eps)
class SpectralNorm(object):
def __init__(self):
self.name = "weight"
#print(self.name)
self.power_iterations = 1
def compute_weight(self, module):
u = getattr(module, self.name + "_u")
v = getattr(module, self.name + "_v")
w = getattr(module, self.name + "_bar")
height = w.data.shape[0]
for _ in range(self.power_iterations):
v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))
# sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
sigma = u.dot(w.view(height, -1).mv(v))
return w / sigma.expand_as(w)
@staticmethod
def apply(module):
name = "weight"
fn = SpectralNorm()
try:
u = getattr(module, name + "_u")
v = getattr(module, name + "_v")
w = getattr(module, name + "_bar")
except AttributeError:
w = getattr(module, name)
height = w.data.shape[0]
width = w.view(height, -1).data.shape[1]
u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
w_bar = Parameter(w.data)
#del module._parameters[name]
module.register_parameter(name + "_u", u)
module.register_parameter(name + "_v", v)
module.register_parameter(name + "_bar", w_bar)
# remove w from parameter list
del module._parameters[name]
setattr(module, name, fn.compute_weight(module))
# recompute weight before every forward()
module.register_forward_pre_hook(fn)
return fn
def remove(self, module):
weight = self.compute_weight(module)
delattr(module, self.name)
del module._parameters[self.name + '_u']
del module._parameters[self.name + '_v']
del module._parameters[self.name + '_bar']
module.register_parameter(self.name, Parameter(weight.data))
def __call__(self, module, inputs):
setattr(module, self.name, self.compute_weight(module))
def spectral_norm(module):
SpectralNorm.apply(module)
return module
def remove_spectral_norm(module):
name = 'weight'
for k, hook in module._forward_pre_hooks.items():
if isinstance(hook, SpectralNorm) and hook.name == name:
hook.remove(module)
del module._forward_pre_hooks[k]
return module
raise ValueError("spectral_norm of '{}' not found in {}"
.format(name, module))
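# --- Illustrative usage (not part of the original training code) ---
# Wrap a layer so that its weight is divided by a power-iteration estimate of
# its largest singular value, recomputed by the forward pre-hook on every call.
def _example_spectral_norm():
    conv = spectral_norm(torch.nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1))
    x = torch.randn(1, 3, 32, 32)
    return conv(x).shape  # torch.Size([1, 64, 16, 16])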
================================================
FILE: models/modules/tps_transform.py
================================================
from __future__ import absolute_import
import numpy as np
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
# TF32 precision is not sufficient here; full FP32 is required
# Disable the automatic TF32 paths introduced in PyTorch 1.7
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
def grid_sample(input, grid, mode='bilinear', canvas=None):
output = F.grid_sample(input, grid, mode=mode, align_corners=True)
if canvas is None:
return output
else:
input_mask = input.data.new(input.size()).fill_(1)
output_mask = F.grid_sample(input_mask, grid, mode='nearest', align_corners=True)
padded_output = output * output_mask + canvas * (1 - output_mask)
return padded_output
# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
def compute_partial_repr(input_points, control_points):
N = input_points.size(0)
M = control_points.size(0)
pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2)
# original implementation, very slow
# pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
pairwise_diff_square = pairwise_diff * pairwise_diff
pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1]
repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
#repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist + 1e-8)
# fix numerical error for 0 * log(0), substitute all nan with 0
mask = repr_matrix != repr_matrix
repr_matrix.masked_fill_(mask, 0)
return repr_matrix
# compute \Delta_c^-1
def build_delta_inverse(target_control_points):
'''
target_control_points: (N, 2)
'''
N = target_control_points.shape[0]
forward_kernel = torch.zeros(N + 3, N + 3).to(target_control_points.device)
target_control_partial_repr = compute_partial_repr(target_control_points, target_control_points)
forward_kernel[:N, :N].copy_(target_control_partial_repr)
forward_kernel[:N, -3].fill_(1)
forward_kernel[-3, :N].fill_(1)
forward_kernel[:N, -2:].copy_(target_control_points)
forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))
# compute inverse matrix
inverse_kernel = torch.inverse(forward_kernel)
return inverse_kernel
# create target coordinate matrix
def build_target_coordinate_matrix(target_height, target_width, target_control_points):
'''
target_control_points: (N, 2)
'''
HW = target_height * target_width
target_coordinate = list(itertools.product(range(target_height), range(target_width)))
target_coordinate = torch.Tensor(target_coordinate).to(target_control_points.device) # HW x 2
Y, X = target_coordinate.split(1, dim = 1)
Y = Y / (target_height - 1)
X = X / (target_width - 1)
target_coordinate = torch.cat([X, Y], dim = 1) # convert from (y, x) to (x, y)
target_coordinate_partial_repr = compute_partial_repr(target_coordinate, target_control_points)
target_coordinate_repr = torch.cat([
target_coordinate_partial_repr,
torch.ones((HW, 1), device=target_control_points.device),
target_coordinate], dim = 1)
return target_coordinate_repr
def tps_sampler(target_height, target_width, inverse_kernel, target_coordinate_repr,
source, source_control_points, sample_mode='bilinear'):
'''
inverse_kernel: \Delta_C^-1
target_coordinate_repr: \hat{p}
source: (B, C, H, W)
source_control_points: (B, N, 2)
'''
batch_size = source.shape[0]
Y = torch.cat([source_control_points, torch.zeros((batch_size, 3, 2), device=source.device)], dim=1)
mapping_matrix = torch.matmul(inverse_kernel, Y)
source_coordinate = torch.matmul(target_coordinate_repr, mapping_matrix)
grid = source_coordinate.view(-1, target_height, target_width, 2)
grid = torch.clamp(grid, 0, 1) # the source_control_points may be out of [0, 1].
# the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
grid = 2.0 * grid - 1.0
output_maps = grid_sample(source, grid, mode=sample_mode, canvas=None)
return output_maps, source_coordinate
def tps_spatial_transform(target_height, target_width, target_control_points,
source, source_control_points, sample_mode='bilinear'):
'''
target_control_points: (N, 2)
source: (B, C, H, W)
source_control_points: (B, N, 2)
'''
inverse_kernel = build_delta_inverse(target_control_points)
target_coordinate_repr = build_target_coordinate_matrix(target_height, target_width, target_control_points)
return tps_sampler(target_height, target_width, inverse_kernel, target_coordinate_repr,
source, source_control_points, sample_mode)
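# --- Illustrative usage (not part of the original training code) ---
# Warp an image with a thin-plate spline. Control points are normalized (x, y)
# coordinates in [0, 1]; identical source and target points give an
# (approximately) identity warp.
def _example_tps_identity():
    pts = torch.tensor([[0.1, 0.1], [0.9, 0.1], [0.1, 0.9], [0.9, 0.9], [0.5, 0.5]])
    source = torch.rand(1, 3, 32, 32)
    warped, _ = tps_spatial_transform(32, 32, pts, source, pts.unsqueeze(0))
    assert warped.shape == (1, 3, 32, 32)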
class TPSSpatialTransformer(nn.Module):
def __init__(self, target_height, target_width, target_control_points):
super(TPSSpatialTransformer, self).__init__()
self.target_height, self.target_width = target_height, target_width
self.num_control_points = target_control_points.shape[0]
# create padded kernel matrix
inverse_kernel = build_delta_inverse(target_control_points)
# create target coordinate matrix
target_coordinate_repr = build_target_coordinate_matrix(target_height, target_width, target_control_points)
# register precomputed matrices
self.register_buffer('inverse_kernel', inverse_kernel)
#self.register_buffer('padding_matrix', torch.zeros(3, 2))
self.register_buffer('target_coordinate_repr', target_coordinate_repr)
self.register_buffer('target_control_points', target_control_points)
def forward(self, source, source_control_points):
assert source_control_points.ndimension() == 3
assert source_control_points.size(1) == self.num_control_points
assert source_control_points.size(2) == 2
return tps_sampler(self.target_height, self.target_width,
self.inverse_kernel, self.target_coordinate_repr,
source, source_control_points)
================================================
FILE: scripts/demo.py
================================================
import os
import sys
import argparse
import numpy as np
import cv2
import torch
from PIL import Image
sys.path.append('.')
from training.config import get_config
from training.inference import Inference
from training.utils import create_logger, print_args
def main(config, args):
logger = create_logger(args.save_folder, args.name, 'info', console=True)
print_args(args, logger)
logger.info(config)
inference = Inference(config, args, args.load_path)
n_imgname = sorted(os.listdir(args.source_dir))
m_imgname = sorted(os.listdir(args.reference_dir))
for i, (imga_name, imgb_name) in enumerate(zip(n_imgname, m_imgname)):
imgA = Image.open(os.path.join(args.source_dir, imga_name)).convert('RGB')
imgB = Image.open(os.path.join(args.reference_dir, imgb_name)).convert('RGB')
result = inference.transfer(imgA, imgB, postprocess=True)
if result is None:
continue
imgA = np.array(imgA); imgB = np.array(imgB)
h, w, _ = imgA.shape
result = result.resize((w, h)); result = np.array(result)  # PIL's resize takes (width, height)
vis_image = np.hstack((imgA, imgB, result))
save_path = os.path.join(args.save_folder, f"result_{i}.png")
Image.fromarray(vis_image.astype(np.uint8)).save(save_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser("argument for training")
parser.add_argument("--name", type=str, default='demo')
parser.add_argument("--save_path", type=str, default='result', help="path to save model")
parser.add_argument("--load_path", type=str, help="folder to load model",
default='ckpts/sow_pyramid_a5_e3d2_remapped.pth')
parser.add_argument("--source-dir", type=str, default="assets/images/non-makeup")
parser.add_argument("--reference-dir", type=str, default="assets/images/makeup")
parser.add_argument("--gpu", default='0', type=str, help="GPU id to use.")
args = parser.parse_args()
args.gpu = 'cuda:' + args.gpu
args.device = torch.device(args.gpu)
args.save_folder = os.path.join(args.save_path, args.name)
if not os.path.exists(args.save_folder):
os.makedirs(args.save_folder)
config = get_config()
main(config, args)
================================================
FILE: scripts/train.py
================================================
import os
import sys
import argparse
import torch
from torch.utils.data import DataLoader
sys.path.append('.')
from training.config import get_config
from training.dataset import MakeupDataset
from training.solver import Solver
from training.utils import create_logger, print_args
def main(config, args):
logger = create_logger(args.save_folder, args.name, 'info', console=True)
print_args(args, logger)
logger.info(config)
dataset = MakeupDataset(config)
data_loader = DataLoader(dataset, batch_size=config.DATA.BATCH_SIZE, num_workers=config.DATA.NUM_WORKERS, shuffle=True)
solver = Solver(config, args, logger)
solver.train(data_loader)
if __name__ == "__main__":
parser = argparse.ArgumentParser("argument for training")
parser.add_argument("--name", type=str, default='elegant')
parser.add_argument("--save_path", type=str, default='results', help="path to save model")
parser.add_argument("--load_folder", type=str, help="path to load model",
default=None)
parser.add_argument("--keepon", default=False, action="store_true", help='keep on training')
parser.add_argument("--gpu", default='0', type=str, help="GPU id to use.")
args = parser.parse_args()
config = get_config()
#args.gpu = 'cuda:' + args.gpu
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
#args.device = torch.device(args.gpu)
args.device = torch.device('cuda:0')
args.save_folder = os.path.join(args.save_path, args.name)
if not os.path.exists(args.save_folder):
os.makedirs(args.save_folder)
main(config, args)
================================================
FILE: training/__init__.py
================================================
================================================
FILE: training/config.py
================================================
from fvcore.common.config import CfgNode
"""
This file defines the default configuration options.
They can be further merged with YAML files and
command-line options.
Note that *any* hyper-parameter must first be defined
here to be configurable via YAML or the command line.
"""
_C = CfgNode()
# Logging and saving
_C.LOG = CfgNode()
_C.LOG.SAVE_FREQ = 10
_C.LOG.VIS_FREQ = 1
# Data settings
_C.DATA = CfgNode()
_C.DATA.PATH = './data/MT-Dataset'
_C.DATA.NUM_WORKERS = 4
_C.DATA.BATCH_SIZE = 1
_C.DATA.IMG_SIZE = 256
# Training hyper-parameters
_C.TRAINING = CfgNode()
_C.TRAINING.G_LR = 2e-4
_C.TRAINING.D_LR = 2e-4
_C.TRAINING.BETA1 = 0.5
_C.TRAINING.BETA2 = 0.999
_C.TRAINING.NUM_EPOCHS = 50
_C.TRAINING.LR_DECAY_FACTOR = 5e-2
_C.TRAINING.DOUBLE_D = False
# Loss weights
_C.LOSS = CfgNode()
_C.LOSS.LAMBDA_A = 10.0
_C.LOSS.LAMBDA_B = 10.0
_C.LOSS.LAMBDA_IDT = 0.5
_C.LOSS.LAMBDA_REC = 10
_C.LOSS.LAMBDA_MAKEUP = 100
_C.LOSS.LAMBDA_SKIN = 0.1
_C.LOSS.LAMBDA_EYE = 1.5
_C.LOSS.LAMBDA_LIP = 1
_C.LOSS.LAMBDA_MAKEUP_LIP = _C.LOSS.LAMBDA_MAKEUP * _C.LOSS.LAMBDA_LIP
_C.LOSS.LAMBDA_MAKEUP_SKIN = _C.LOSS.LAMBDA_MAKEUP * _C.LOSS.LAMBDA_SKIN
_C.LOSS.LAMBDA_MAKEUP_EYE = _C.LOSS.LAMBDA_MAKEUP * _C.LOSS.LAMBDA_EYE
_C.LOSS.LAMBDA_VGG = 5e-3
# Model structure
_C.MODEL = CfgNode()
_C.MODEL.D_TYPE = 'SN'
_C.MODEL.D_REPEAT_NUM = 3
_C.MODEL.D_CONV_DIM = 64
_C.MODEL.G_CONV_DIM = 64
_C.MODEL.NUM_HEAD = 1
_C.MODEL.DOUBLE_E = False
_C.MODEL.USE_FF = False
_C.MODEL.NUM_LAYER_E = 3
_C.MODEL.NUM_LAYER_D = 2
_C.MODEL.WINDOW_SIZE = 16
_C.MODEL.MERGE_MODE = 'conv'
# Preprocessing
_C.PREPROCESS = CfgNode()
_C.PREPROCESS.UP_RATIO = 0.6 / 0.85 # delta_size / face_size
_C.PREPROCESS.DOWN_RATIO = 0.2 / 0.85 # delta_size / face_size
_C.PREPROCESS.WIDTH_RATIO = 0.2 / 0.85 # delta_size / face_size
_C.PREPROCESS.LIP_CLASS = [7, 9]
_C.PREPROCESS.FACE_CLASS = [1, 6]
_C.PREPROCESS.EYEBROW_CLASS = [2, 3]
_C.PREPROCESS.EYE_CLASS = [4, 5]
_C.PREPROCESS.LANDMARK_POINTS = 68
# Pseudo ground truth
_C.PGT = CfgNode()
_C.PGT.EYE_MARGIN = 12
_C.PGT.LIP_MARGIN = 4
_C.PGT.ANNEALING = True
_C.PGT.SKIN_ALPHA = 0.3
_C.PGT.SKIN_ALPHA_MILESTONES = (0, 12, 24, 50)
_C.PGT.SKIN_ALPHA_VALUES = (0.2, 0.4, 0.3, 0.2)
_C.PGT.EYE_ALPHA = 0.8
_C.PGT.EYE_ALPHA_MILESTONES = (0, 12, 24, 50)
_C.PGT.EYE_ALPHA_VALUES = (0.6, 0.8, 0.6, 0.4)
_C.PGT.LIP_ALPHA = 0.1
_C.PGT.LIP_ALPHA_MILESTONES = (0, 12, 24, 50)
_C.PGT.LIP_ALPHA_VALUES = (0.05, 0.2, 0.1, 0.0)
# Postprocessing
_C.POSTPROCESS = CfgNode()
_C.POSTPROCESS.WILL_DENOISE = False
def get_config()->CfgNode:
return _C
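# --- Illustrative usage (not part of the original repo code) ---
# fvcore's CfgNode inherits the yacs API, so these defaults can be overridden
# from a YAML file or from command-line style key/value pairs; the YAML path
# below is hypothetical.
def _example_config_override():
    cfg = get_config().clone()
    cfg.merge_from_list(['DATA.BATCH_SIZE', 4, 'MODEL.NUM_HEAD', 2])
    # cfg.merge_from_file('configs/elegant.yaml')  # hypothetical path
    return cfg.DATA.BATCH_SIZE  # 4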
================================================
FILE: training/dataset.py
================================================
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from training.config import get_config
from training.preprocess import PreProcess
class MakeupDataset(Dataset):
def __init__(self, config=None):
super(MakeupDataset, self).__init__()
if config is None:
config = get_config()
self.root = config.DATA.PATH
with open(os.path.join(config.DATA.PATH, 'makeup.txt'), 'r') as f:
self.makeup_names = [name.strip() for name in f.readlines()]
with open(os.path.join(config.DATA.PATH, 'non-makeup.txt'), 'r') as f:
self.non_makeup_names = [name.strip() for name in f.readlines()]
self.preprocessor = PreProcess(config, need_parser=False)
self.img_size = config.DATA.IMG_SIZE
def load_from_file(self, img_name):
image = Image.open(os.path.join(self.root, 'images', img_name)).convert('RGB')
mask = self.preprocessor.load_mask(os.path.join(self.root, 'segs', img_name))
base_name = os.path.splitext(img_name)[0]
lms = self.preprocessor.load_lms(os.path.join(self.root, 'lms', f'{base_name}.npy'))
return self.preprocessor.process(image, mask, lms)
def __len__(self):
return max(len(self.makeup_names), len(self.non_makeup_names))
def __getitem__(self, index):
idx_s = torch.randint(0, len(self.non_makeup_names), (1, )).item()
idx_r = torch.randint(0, len(self.makeup_names), (1, )).item()
name_s = self.non_makeup_names[idx_s]
name_r = self.makeup_names[idx_r]
source = self.load_from_file(name_s)
reference = self.load_from_file(name_r)
return source, reference
def get_loader(config):
dataset = MakeupDataset(config)
dataloader = DataLoader(dataset=dataset,
batch_size=config.DATA.BATCH_SIZE,
num_workers=config.DATA.NUM_WORKERS)
return dataloader
if __name__ == "__main__":
dataset = MakeupDataset()
dataloader = DataLoader(dataset, batch_size=1, num_workers=16)
for e in range(10):
for i, (point_s, point_r) in enumerate(dataloader):
pass
================================================
FILE: training/inference.py
================================================
from typing import List
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.transforms import ToPILImage
from training.solver import Solver
from training.preprocess import PreProcess
from models.modules.pseudo_gt import expand_area, mask_blend
class InputSample:
def __init__(self, inputs, apply_mask=None):
self.inputs = inputs
self.transfer_input = None
self.attn_out_list = None
self.apply_mask = apply_mask
def clear(self):
self.transfer_input = None
self.attn_out_list = None
class Inference:
"""
An inference wrapper for makeup transfer.
It takes two images, `source` and `reference`,
and transfers the makeup of `reference` onto `source`.
"""
def __init__(self, config, args, model_path="G.pth"):
self.device = args.device
self.solver = Solver(config, args, inference=model_path)
self.preprocess = PreProcess(config, args.device)
self.denoise = config.POSTPROCESS.WILL_DENOISE
self.img_size = config.DATA.IMG_SIZE
# TODO: can be a hyper-parameter
self.eyeblur = {'margin': 12, 'blur_size':7}
def prepare_input(self, *data_inputs):
"""
data_inputs: List[image, mask, diff, lms]
"""
inputs = []
for i in range(len(data_inputs)):
inputs.append(data_inputs[i].to(self.device).unsqueeze(0))
# prepare mask
inputs[1] = torch.cat((inputs[1][:,0:1], inputs[1][:,1:].sum(dim=1, keepdim=True)), dim=1)
return inputs
def postprocess(self, source, crop_face, result):
if crop_face is not None:
source = source.crop(
(crop_face.left(), crop_face.top(), crop_face.right(), crop_face.bottom()))
source = np.array(source)
result = np.array(result)
height, width = source.shape[:2]
small_source = cv2.resize(source, (self.img_size, self.img_size))
laplacian_diff = source.astype(
    np.float64) - cv2.resize(small_source, (width, height)).astype(np.float64)  # np.float was removed from NumPy
result = (cv2.resize(result, (width, height)) +
laplacian_diff).round().clip(0, 255)
result = result.astype(np.uint8)
if self.denoise:
result = cv2.fastNlMeansDenoisingColored(result)
result = Image.fromarray(result).convert('RGB')
return result
def generate_source_sample(self, source_input):
"""
source_input: List[image, mask, diff, lms]
"""
source_input = self.prepare_input(*source_input)
return InputSample(source_input)
def generate_reference_sample(self, reference_input, apply_mask=None,
source_mask=None, mask_area=None, saturation=1.0):
"""
all operations on the mask (e.g., partial masking, saturation)
should ultimately be folded into apply_mask
"""
if source_mask is not None and mask_area is not None:
apply_mask = self.generate_partial_mask(source_mask, mask_area, saturation)
apply_mask = apply_mask.unsqueeze(0).to(self.device)
reference_input = self.prepare_input(*reference_input)
if apply_mask is None:
apply_mask = torch.ones(1, 1, self.img_size, self.img_size).to(self.device)
return InputSample(reference_input, apply_mask)
def generate_partial_mask(self, source_mask, mask_area='full', saturation=1.0):
"""
source_mask: (C, H, W), lip, face, left eye, right eye
return: apply_mask: (1, H, W)
"""
assert mask_area in ['full', 'skin', 'lip', 'eye']
if mask_area == 'full':
return torch.sum(source_mask[0:2], dim=0, keepdim=True) * saturation
elif mask_area == 'lip':
return source_mask[0:1] * saturation
elif mask_area == 'skin':
mask_l_eye = expand_area(source_mask[2:3], self.eyeblur['margin']) #* source_mask[1:2]
mask_r_eye = expand_area(source_mask[3:4], self.eyeblur['margin']) #* source_mask[1:2]
mask_eye = mask_l_eye + mask_r_eye
#mask_eye = mask_blend(mask_eye, 1.0, source_mask[1:2], blur_size=self.eyeblur['blur_size'])
mask_eye = mask_blend(mask_eye, 1.0, blur_size=self.eyeblur['blur_size'])
return source_mask[1:2] * (1 - mask_eye) * saturation
elif mask_area == 'eye':
mask_l_eye = expand_area(source_mask[2:3], self.eyeblur['margin']) #* source_mask[1:2]
mask_r_eye = expand_area(source_mask[3:4], self.eyeblur['margin']) #* source_mask[1:2]
mask_eye = mask_l_eye + mask_r_eye
#mask_eye = mask_blend(mask_eye, saturation, source_mask[1:2], blur_size=self.eyeblur['blur_size'])
mask_eye = mask_blend(mask_eye, saturation, blur_size=self.eyeblur['blur_size'])
return mask_eye
@torch.no_grad()
def interface_transfer(self, source_sample: InputSample, reference_samples: List[InputSample]):
"""
Input: a source sample and multiple reference samples
Return: PIL.Image, the fused result
"""
# encode source
if source_sample.transfer_input is None:
source_sample.transfer_input = self.solver.G.get_transfer_input(*source_sample.inputs)
# encode references
for r_sample in reference_samples:
if r_sample.transfer_input is None:
r_sample.transfer_input = self.solver.G.get_transfer_input(*r_sample.inputs, True)
# self attention
if source_sample.attn_out_list is None:
source_sample.attn_out_list = self.solver.G.get_transfer_output(
*source_sample.transfer_input, *source_sample.transfer_input
)
# full transfer for each reference
for r_sample in reference_samples:
if r_sample.attn_out_list is None:
r_sample.attn_out_list = self.solver.G.get_transfer_output(
*source_sample.transfer_input, *r_sample.transfer_input
)
# fusion
# if the apply_mask is changed without changing source and references,
# only the following steps are required
fused_attn_out_list = []
for i in range(len(source_sample.attn_out_list)):
init_attn_out = torch.zeros_like(source_sample.attn_out_list[i], device=self.device)
fused_attn_out_list.append(init_attn_out)
apply_mask_sum = torch.zeros((1, 1, self.img_size, self.img_size), device=self.device)
for r_sample in reference_samples:
if r_sample.apply_mask is not None:
apply_mask_sum += r_sample.apply_mask
for i in range(len(source_sample.attn_out_list)):
feature_size = r_sample.attn_out_list[i].shape[2]
apply_mask = F.interpolate(r_sample.apply_mask, feature_size, mode='nearest')
fused_attn_out_list[i] += apply_mask * r_sample.attn_out_list[i]
# self as reference
source_apply_mask = 1 - apply_mask_sum.clamp(0, 1)
for i in range(len(source_sample.attn_out_list)):
feature_size = source_sample.attn_out_list[i].shape[2]
apply_mask = F.interpolate(source_apply_mask, feature_size, mode='nearest')
fused_attn_out_list[i] += apply_mask * source_sample.attn_out_list[i]
# decode
result = self.solver.G.decode(
source_sample.transfer_input[0], fused_attn_out_list
)
result = self.solver.de_norm(result).squeeze(0)
result = ToPILImage()(result.cpu())
return result
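    # The fusion above computes, for every feature level i,
    #   fused_i = sum_r M_r * A_i^(r) + (1 - clamp(sum_r M_r, 0, 1)) * A_i^(self),
    # where M_r is reference r's apply_mask (resized to the feature resolution)
    # and A_i^(self) is the source attended to itself, so any region no
    # reference claims keeps the source's own appearance. A toy check with
    # hypothetical disjoint masks:
    #
    #   m1 = torch.zeros(1, 1, 4, 4); m1[..., :2] = 1.0   # left half
    #   m2 = torch.zeros(1, 1, 4, 4); m2[..., 2:] = 1.0   # right half
    #   self_weight = 1 - (m1 + m2).clamp(0, 1)           # all zeros: fully covered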
def transfer(self, source: Image, reference: Image, postprocess=True):
"""
Args:
            source (Image): The image that the makeup will be transferred to.
reference (Image): Image containing targeted makeup.
Return:
            Image: Transferred image.
"""
source_input, face, crop_face = self.preprocess(source)
reference_input, _, _ = self.preprocess(reference)
if not (source_input and reference_input):
return None
#source_sample = self.generate_source_sample(source_input)
#reference_samples = [self.generate_reference_sample(reference_input)]
#result = self.interface_transfer(source_sample, reference_samples)
source_input = self.prepare_input(*source_input)
reference_input = self.prepare_input(*reference_input)
result = self.solver.test(*source_input, *reference_input)
if not postprocess:
return result
else:
return self.postprocess(source, crop_face, result)
def joint_transfer(self, source: Image, reference_lip: Image, reference_skin: Image,
reference_eye: Image, postprocess=True):
source_input, face, crop_face = self.preprocess(source)
lip_input, _, _ = self.preprocess(reference_lip)
skin_input, _, _ = self.preprocess(reference_skin)
eye_input, _, _ = self.preprocess(reference_eye)
if not (source_input and lip_input and skin_input and eye_input):
return None
source_mask = source_input[1]
source_sample = self.generate_source_sample(source_input)
reference_samples = [
self.generate_reference_sample(lip_input, source_mask=source_mask, mask_area='lip'),
self.generate_reference_sample(skin_input, source_mask=source_mask, mask_area='skin'),
self.generate_reference_sample(eye_input, source_mask=source_mask, mask_area='eye')
]
result = self.interface_transfer(source_sample, reference_samples)
if not postprocess:
return result
else:
return self.postprocess(source, crop_face, result)
================================================
FILE: training/preprocess.py
================================================
import os
import sys
import cv2
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms import functional
sys.path.append('.')
import faceutils as futils
from training.config import get_config
class PreProcess:
def __init__(self, config, need_parser=True, device='cpu'):
self.img_size = config.DATA.IMG_SIZE
self.device = device
xs, ys = np.meshgrid(
np.linspace(
0, self.img_size - 1,
self.img_size
),
np.linspace(
0, self.img_size - 1,
self.img_size
)
)
xs = xs[None].repeat(config.PREPROCESS.LANDMARK_POINTS, axis=0)
ys = ys[None].repeat(config.PREPROCESS.LANDMARK_POINTS, axis=0)
fix = np.concatenate([ys, xs], axis=0)
self.fix = torch.Tensor(fix) #(136, h, w)
if need_parser:
self.face_parse = futils.mask.FaceParser(device=device)
self.up_ratio = config.PREPROCESS.UP_RATIO
self.down_ratio = config.PREPROCESS.DOWN_RATIO
self.width_ratio = config.PREPROCESS.WIDTH_RATIO
self.lip_class = config.PREPROCESS.LIP_CLASS
self.face_class = config.PREPROCESS.FACE_CLASS
self.eyebrow_class = config.PREPROCESS.EYEBROW_CLASS
self.eye_class = config.PREPROCESS.EYE_CLASS
self.transform = transforms.Compose([
transforms.Resize(config.DATA.IMG_SIZE),
transforms.ToTensor(),
transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])])
############################## Mask Process ##############################
# mask attribute: 0:background 1:face 2:left-eyebrow 3:right-eyebrow 4:left-eye 5: right-eye 6: nose
# 7: upper-lip 8: teeth 9: under-lip 10:hair 11: left-ear 12: right-ear 13: neck
def mask_process(self, mask: torch.Tensor):
'''
mask: (1, h, w)
'''
mask_lip = (mask == self.lip_class[0]).float() + (mask == self.lip_class[1]).float()
mask_face = (mask == self.face_class[0]).float() + (mask == self.face_class[1]).float()
#mask_eyebrow_left = (mask == self.eyebrow_class[0]).float()
#mask_eyebrow_right = (mask == self.eyebrow_class[1]).float()
mask_face += (mask == self.eyebrow_class[0]).float()
mask_face += (mask == self.eyebrow_class[1]).float()
mask_eye_left = (mask == self.eye_class[0]).float()
mask_eye_right = (mask == self.eye_class[1]).float()
#mask_list = [mask_lip, mask_face, mask_eyebrow_left, mask_eyebrow_right, mask_eye_left, mask_eye_right]
mask_list = [mask_lip, mask_face, mask_eye_left, mask_eye_right]
mask_aug = torch.cat(mask_list, 0) # (C, H, W)
return mask_aug
def save_mask(self, mask: torch.Tensor, path):
assert mask.shape[0] == 1
mask = mask.squeeze(0).numpy().astype(np.uint8)
mask = Image.fromarray(mask)
mask.save(path)
def load_mask(self, path):
mask = np.array(Image.open(path).convert('L'))
mask = torch.FloatTensor(mask).unsqueeze(0)
mask = functional.resize(mask, self.img_size, transforms.InterpolationMode.NEAREST)
return mask
############################## Landmarks Process ##############################
def lms_process(self, image:Image):
face = futils.dlib.detect(image)
# face: rectangles, List of rectangles of face region: [(left, top), (right, bottom)]
if not face:
return None
face = face[0]
lms = futils.dlib.landmarks(image, face) * self.img_size / image.width # scale to fit self.img_size
# lms: narray, the position of 68 key points, (68 ,2)
lms = torch.IntTensor(lms.round()).clamp_max_(self.img_size - 1)
# distinguish upper and lower lips
lms[61:64,0] -= 1; lms[65:68,0] += 1
for i in range(3):
if torch.sum(torch.abs(lms[61+i] - lms[67-i])) == 0:
lms[61+i,0] -= 1; lms[67-i,0] += 1
# double check
'''for i in range(48, 67):
for j in range(i+1, 68):
if torch.sum(torch.abs(lms[i] - lms[j])) == 0:
lms[i,0] -= 1; lms[j,0] += 1'''
return lms
def diff_process(self, lms: torch.Tensor, normalize=False):
'''
lms:(68, 2)
'''
lms = lms.transpose(1, 0).reshape(-1, 1, 1) # (136, 1, 1)
diff = self.fix - lms # (136, h, w)
if normalize:
norm = torch.norm(diff, dim=0, keepdim=True).repeat(diff.shape[0], 1, 1)
norm = torch.where(norm == 0, torch.tensor(1e10), norm)
diff /= norm
return diff
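    # diff_process builds a 136-channel positional map: channels [0:68] hold
    # row - y_k and channels [68:136] hold col - x_k, i.e. each pixel's offset
    # from each of the 68 landmarks. A toy check on a 3x3 grid with one
    # "landmark" at (1, 2) (hypothetical sizes; the real map uses 68 points on
    # an img_size x img_size grid):
    #
    #   ys, xs = torch.meshgrid(torch.arange(3.), torch.arange(3.), indexing='ij')
    #   diff_y = ys - 1.0   # per-pixel row offset from the landmark
    #   diff_x = xs - 2.0   # per-pixel column offset from the landmark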
def save_lms(self, lms: torch.Tensor, path):
lms = lms.numpy()
np.save(path, lms)
def load_lms(self, path):
lms = np.load(path)
return torch.IntTensor(lms)
############################## Compose Process ##############################
def preprocess(self, image: Image, is_crop=True):
'''
        return: ([image: Image, mask: (1, H, W) tensor, lms: (68, 2) tensor],
                 face_on_image, crop_face)
'''
face = futils.dlib.detect(image)
# face: rectangles, List of rectangles of face region: [(left, top), (right, bottom)]
if not face:
return None, None, None
face_on_image = face[0]
if is_crop:
image, face, crop_face = futils.dlib.crop(
image, face_on_image, self.up_ratio, self.down_ratio, self.width_ratio)
else:
face = face[0]; crop_face = None
# image: Image, cropped face
# face: the same as above
# crop face: rectangle, face region in cropped face
np_image = np.array(image) # (h', w', 3)
mask = self.face_parse.parse(cv2.resize(np_image, (512, 512))).cpu()
# obtain face parsing result
# mask: Tensor, (512, 512)
mask = F.interpolate(
mask.view(1, 1, 512, 512),
(self.img_size, self.img_size),
mode="nearest").squeeze(0).long() #(1, H, W)
lms = futils.dlib.landmarks(image, face) * self.img_size / image.width # scale to fit self.img_size
# lms: narray, the position of 68 key points, (68 ,2)
lms = torch.IntTensor(lms.round()).clamp_max_(self.img_size - 1)
# distinguish upper and lower lips
lms[61:64,0] -= 1; lms[65:68,0] += 1
for i in range(3):
if torch.sum(torch.abs(lms[61+i] - lms[67-i])) == 0:
lms[61+i,0] -= 1; lms[67-i,0] += 1
        image = image.resize((self.img_size, self.img_size), Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10
return [image, mask, lms], face_on_image, crop_face
def process(self, image: Image, mask: torch.Tensor, lms: torch.Tensor):
image = self.transform(image)
mask = self.mask_process(mask)
diff = self.diff_process(lms)
return [image, mask, diff, lms]
def __call__(self, image:Image, is_crop=True):
source, face_on_image, crop_face = self.preprocess(image, is_crop)
if source is None:
return None, None, None
return self.process(*source), face_on_image, crop_face
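# A minimal end-to-end usage sketch (hypothetical path; assumes the dlib
# landmark and BiSeNet parser weights described in assets/docs/prepare.md
# are in place):
#
#   config = get_config()
#   pp = PreProcess(config, device='cpu')
#   inputs, face_on_image, crop_face = pp(Image.open('face.png').convert('RGB'))
#   if inputs is not None:
#       image, mask, diff, lms = inputs   # tensors ready for the generator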
if __name__ == "__main__":
config = get_config()
preprocessor = PreProcess(config, device='cuda:0')
if not os.path.exists(os.path.join(config.DATA.PATH, 'lms')):
os.makedirs(os.path.join(config.DATA.PATH, 'lms', 'makeup'))
os.makedirs(os.path.join(config.DATA.PATH, 'lms', 'non-makeup'))
# process makeup images
print("Processing makeup images...")
with open(os.path.join(config.DATA.PATH, 'makeup.txt'), 'r') as f:
for line in f.readlines():
img_name = line.strip()
raw_image = Image.open(os.path.join(config.DATA.PATH, 'images', img_name)).convert('RGB')
lms = preprocessor.lms_process(raw_image)
if lms is not None:
base_name = os.path.splitext(img_name)[0]
preprocessor.save_lms(lms, os.path.join(config.DATA.PATH, 'lms', f'{base_name}.npy'))
print("Done.")
# process non-makeup images
print("Processing non-makeup images...")
with open(os.path.join(config.DATA.PATH, 'non-makeup.txt'), 'r') as f:
for line in f.readlines():
img_name = line.strip()
raw_image = Image.open(os.path.join(config.DATA.PATH, 'images', img_name)).convert('RGB')
lms = preprocessor.lms_process(raw_image)
if lms is not None:
base_name = os.path.splitext(img_name)[0]
preprocessor.save_lms(lms, os.path.join(config.DATA.PATH, 'lms', f'{base_name}.npy'))
print("Done.")
================================================
FILE: training/solver.py
================================================
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import ToPILImage
from torchvision.utils import save_image, make_grid
import torch.nn.init as init
from tqdm import tqdm
from models.modules.pseudo_gt import expand_area
from models.model import get_discriminator, get_generator, vgg16
from models.loss import GANLoss, MakeupLoss, ComposePGT, AnnealingComposePGT
from training.utils import plot_curves
class Solver():
def __init__(self, config, args, logger=None, inference=False):
self.G = get_generator(config)
if inference:
            # `inference` doubles as the path to the generator checkpoint
            self.G.load_state_dict(torch.load(inference, map_location=args.device))
self.G = self.G.to(args.device).eval()
return
self.double_d = config.TRAINING.DOUBLE_D
self.D_A = get_discriminator(config)
if self.double_d:
self.D_B = get_discriminator(config)
self.load_folder = args.load_folder
self.save_folder = args.save_folder
self.vis_folder = os.path.join(args.save_folder, 'visualization')
if not os.path.exists(self.vis_folder):
os.makedirs(self.vis_folder)
self.vis_freq = config.LOG.VIS_FREQ
self.save_freq = config.LOG.SAVE_FREQ
# Data & PGT
self.img_size = config.DATA.IMG_SIZE
self.margins = {'eye':config.PGT.EYE_MARGIN,
'lip':config.PGT.LIP_MARGIN}
self.pgt_annealing = config.PGT.ANNEALING
if self.pgt_annealing:
self.pgt_maker = AnnealingComposePGT(self.margins,
config.PGT.SKIN_ALPHA_MILESTONES, config.PGT.SKIN_ALPHA_VALUES,
config.PGT.EYE_ALPHA_MILESTONES, config.PGT.EYE_ALPHA_VALUES,
config.PGT.LIP_ALPHA_MILESTONES, config.PGT.LIP_ALPHA_VALUES
)
else:
self.pgt_maker = ComposePGT(self.margins,
config.PGT.SKIN_ALPHA,
config.PGT.EYE_ALPHA,
config.PGT.LIP_ALPHA
)
self.pgt_maker.eval()
# Hyper-param
self.num_epochs = config.TRAINING.NUM_EPOCHS
self.g_lr = config.TRAINING.G_LR
self.d_lr = config.TRAINING.D_LR
self.beta1 = config.TRAINING.BETA1
self.beta2 = config.TRAINING.BETA2
self.lr_decay_factor = config.TRAINING.LR_DECAY_FACTOR
# Loss param
self.lambda_idt = config.LOSS.LAMBDA_IDT
self.lambda_A = config.LOSS.LAMBDA_A
self.lambda_B = config.LOSS.LAMBDA_B
self.lambda_lip = config.LOSS.LAMBDA_MAKEUP_LIP
self.lambda_skin = config.LOSS.LAMBDA_MAKEUP_SKIN
self.lambda_eye = config.LOSS.LAMBDA_MAKEUP_EYE
self.lambda_vgg = config.LOSS.LAMBDA_VGG
self.device = args.device
self.keepon = args.keepon
self.logger = logger
self.loss_logger = {
'D-A-loss_real':[],
'D-A-loss_fake':[],
'D-B-loss_real':[],
'D-B-loss_fake':[],
'G-A-loss-adv':[],
'G-B-loss-adv':[],
'G-loss-idt':[],
'G-loss-img-rec':[],
'G-loss-vgg-rec':[],
'G-loss-rec':[],
'G-loss-skin-pgt':[],
'G-loss-eye-pgt':[],
'G-loss-lip-pgt':[],
'G-loss-pgt':[],
'G-loss':[],
'D-A-loss':[],
'D-B-loss':[]
}
self.build_model()
super(Solver, self).__init__()
def print_network(self, model, name):
num_params = 0
for p in model.parameters():
num_params += p.numel()
if self.logger is not None:
self.logger.info('{:s}, the number of parameters: {:d}'.format(name, num_params))
else:
print('{:s}, the number of parameters: {:d}'.format(name, num_params))
# For generator
def weights_init_xavier(self, m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
init.xavier_normal_(m.weight.data, gain=1.0)
elif classname.find('Linear') != -1:
init.xavier_normal_(m.weight.data, gain=1.0)
def build_model(self):
self.G.apply(self.weights_init_xavier)
self.D_A.apply(self.weights_init_xavier)
if self.double_d:
self.D_B.apply(self.weights_init_xavier)
if self.keepon:
self.load_checkpoint()
self.criterionL1 = torch.nn.L1Loss()
self.criterionL2 = torch.nn.MSELoss()
self.criterionGAN = GANLoss(gan_mode='lsgan')
self.criterionPGT = MakeupLoss()
self.vgg = vgg16(pretrained=True)
# Optimizers
self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
self.d_A_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D_A.parameters()), self.d_lr, [self.beta1, self.beta2])
if self.double_d:
self.d_B_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D_B.parameters()), self.d_lr, [self.beta1, self.beta2])
self.g_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.g_optimizer,
T_max=self.num_epochs, eta_min=self.g_lr * self.lr_decay_factor)
self.d_A_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.d_A_optimizer,
T_max=self.num_epochs, eta_min=self.d_lr * self.lr_decay_factor)
if self.double_d:
self.d_B_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.d_B_optimizer,
T_max=self.num_epochs, eta_min=self.d_lr * self.lr_decay_factor)
# Print networks
self.print_network(self.G, 'G')
self.print_network(self.D_A, 'D_A')
if self.double_d: self.print_network(self.D_B, 'D_B')
self.G.to(self.device)
self.vgg.to(self.device)
self.D_A.to(self.device)
if self.double_d: self.D_B.to(self.device)
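    # The schedulers above follow PyTorch's CosineAnnealingLR: over T_max
    # epochs the learning rate decays from the base value to eta_min along
    #   lr(t) = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * t / T_max)).
    # With, e.g., G_LR = 2e-4 and LR_DECAY_FACTOR = 5e-2 (hypothetical config
    # values), the generator's rate would end at eta_min = 1e-5 after
    # NUM_EPOCHS epochs.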
def train(self, data_loader):
self.len_dataset = len(data_loader)
for self.epoch in range(1, self.num_epochs + 1):
self.start_time = time.time()
loss_tmp = self.get_loss_tmp()
self.G.train(); self.D_A.train();
if self.double_d: self.D_B.train()
losses_G = []; losses_D_A = []; losses_D_B = []
with tqdm(data_loader, desc="training") as pbar:
for step, (source, reference) in enumerate(pbar):
# image, mask, diff, lms
image_s, image_r = source[0].to(self.device), reference[0].to(self.device) # (b, c, h, w)
mask_s_full, mask_r_full = source[1].to(self.device), reference[1].to(self.device) # (b, c', h, w)
diff_s, diff_r = source[2].to(self.device), reference[2].to(self.device) # (b, 136, h, w)
lms_s, lms_r = source[3].to(self.device), reference[3].to(self.device) # (b, K, 2)
# process input mask
mask_s = torch.cat((mask_s_full[:,0:1], mask_s_full[:,1:].sum(dim=1, keepdim=True)), dim=1)
mask_r = torch.cat((mask_r_full[:,0:1], mask_r_full[:,1:].sum(dim=1, keepdim=True)), dim=1)
#mask_s = mask_s_full[:,:2]; mask_r = mask_r_full[:,:2]
# ================= Generate ================== #
fake_A = self.G(image_s, image_r, mask_s, mask_r, diff_s, diff_r, lms_s, lms_r)
fake_B = self.G(image_r, image_s, mask_r, mask_s, diff_r, diff_s, lms_r, lms_s)
# generate pseudo ground truth
pgt_A = self.pgt_maker(image_s, image_r, mask_s_full, mask_r_full, lms_s, lms_r)
pgt_B = self.pgt_maker(image_r, image_s, mask_r_full, mask_s_full, lms_r, lms_s)
# ================== Train D ================== #
# training D_A, D_A aims to distinguish class B
# Real
out = self.D_A(image_r)
d_loss_real = self.criterionGAN(out, True)
# Fake
out = self.D_A(fake_A.detach())
d_loss_fake = self.criterionGAN(out, False)
# Backward + Optimize
d_loss = (d_loss_real + d_loss_fake) * 0.5
self.d_A_optimizer.zero_grad()
d_loss.backward()
self.d_A_optimizer.step()
# Logging
loss_tmp['D-A-loss_real'] += d_loss_real.item()
loss_tmp['D-A-loss_fake'] += d_loss_fake.item()
losses_D_A.append(d_loss.item())
# training D_B, D_B aims to distinguish class A
# Real
if self.double_d:
out = self.D_B(image_s)
else:
out = self.D_A(image_s)
d_loss_real = self.criterionGAN(out, True)
# Fake
if self.double_d:
out = self.D_B(fake_B.detach())
else:
out = self.D_A(fake_B.detach())
d_loss_fake = self.criterionGAN(out, False)
# Backward + Optimize
d_loss = (d_loss_real+ d_loss_fake) * 0.5
if self.double_d:
self.d_B_optimizer.zero_grad()
d_loss.backward()
self.d_B_optimizer.step()
else:
self.d_A_optimizer.zero_grad()
d_loss.backward()
self.d_A_optimizer.step()
# Logging
loss_tmp['D-B-loss_real'] += d_loss_real.item()
loss_tmp['D-B-loss_fake'] += d_loss_fake.item()
losses_D_B.append(d_loss.item())
# ================== Train G ================== #
                    # G should act as an identity mapping when source and reference coincide
idt_A = self.G(image_s, image_s, mask_s, mask_s, diff_s, diff_s, lms_s, lms_s)
idt_B = self.G(image_r, image_r, mask_r, mask_r, diff_r, diff_r, lms_r, lms_r)
loss_idt_A = self.criterionL1(idt_A, image_s) * self.lambda_A * self.lambda_idt
loss_idt_B = self.criterionL1(idt_B, image_r) * self.lambda_B * self.lambda_idt
# loss_idt
loss_idt = (loss_idt_A + loss_idt_B) * 0.5
# GAN loss D_A(G_A(A))
pred_fake = self.D_A(fake_A)
g_A_loss_adv = self.criterionGAN(pred_fake, True)
# GAN loss D_B(G_B(B))
if self.double_d:
pred_fake = self.D_B(fake_B)
else:
pred_fake = self.D_A(fake_B)
g_B_loss_adv = self.criterionGAN(pred_fake, True)
# Makeup loss
g_A_loss_pgt = 0; g_B_loss_pgt = 0
g_A_lip_loss_pgt = self.criterionPGT(fake_A, pgt_A, mask_s_full[:,0:1]) * self.lambda_lip
g_B_lip_loss_pgt = self.criterionPGT(fake_B, pgt_B, mask_r_full[:,0:1]) * self.lambda_lip
g_A_loss_pgt += g_A_lip_loss_pgt
g_B_loss_pgt += g_B_lip_loss_pgt
mask_s_eye = expand_area(mask_s_full[:,2:4].sum(dim=1, keepdim=True), self.margins['eye'])
mask_r_eye = expand_area(mask_r_full[:,2:4].sum(dim=1, keepdim=True), self.margins['eye'])
mask_s_eye = mask_s_eye * mask_s_full[:,1:2]
mask_r_eye = mask_r_eye * mask_r_full[:,1:2]
g_A_eye_loss_pgt = self.criterionPGT(fake_A, pgt_A, mask_s_eye) * self.lambda_eye
g_B_eye_loss_pgt = self.criterionPGT(fake_B, pgt_B, mask_r_eye) * self.lambda_eye
g_A_loss_pgt += g_A_eye_loss_pgt
g_B_loss_pgt += g_B_eye_loss_pgt
mask_s_skin = mask_s_full[:,1:2] * (1 - mask_s_eye)
mask_r_skin = mask_r_full[:,1:2] * (1 - mask_r_eye)
g_A_skin_loss_pgt = self.criterionPGT(fake_A, pgt_A, mask_s_skin) * self.lambda_skin
g_B_skin_loss_pgt = self.criterionPGT(fake_B, pgt_B, mask_r_skin) * self.lambda_skin
g_A_loss_pgt += g_A_skin_loss_pgt
g_B_loss_pgt += g_B_skin_loss_pgt
# cycle loss
rec_A = self.G(fake_A, image_s, mask_s, mask_s, diff_s, diff_s, lms_s, lms_s)
rec_B = self.G(fake_B, image_r, mask_r, mask_r, diff_r, diff_r, lms_r, lms_r)
# cycle loss v2
# rec_A = self.G(fake_A, fake_B, mask_s, mask_r, diff_s, diff_r, lms_s, lms_r)
# rec_B = self.G(fake_B, fake_A, mask_r, mask_s, diff_r, diff_s, lms_r, lms_s)
g_loss_rec_A = self.criterionL1(rec_A, image_s) * self.lambda_A
g_loss_rec_B = self.criterionL1(rec_B, image_r) * self.lambda_B
# vgg loss
vgg_s = self.vgg(image_s).detach()
vgg_fake_A = self.vgg(fake_A)
g_loss_A_vgg = self.criterionL2(vgg_fake_A, vgg_s) * self.lambda_A * self.lambda_vgg
vgg_r = self.vgg(image_r).detach()
vgg_fake_B = self.vgg(fake_B)
g_loss_B_vgg = self.criterionL2(vgg_fake_B, vgg_r) * self.lambda_B * self.lambda_vgg
loss_rec = (g_loss_rec_A + g_loss_rec_B + g_loss_A_vgg + g_loss_B_vgg) * 0.5
# Combined loss
g_loss = g_A_loss_adv + g_B_loss_adv + loss_rec + loss_idt + g_A_loss_pgt + g_B_loss_pgt
self.g_optimizer.zero_grad()
g_loss.backward()
self.g_optimizer.step()
# Logging
loss_tmp['G-A-loss-adv'] += g_A_loss_adv.item()
loss_tmp['G-B-loss-adv'] += g_B_loss_adv.item()
loss_tmp['G-loss-idt'] += loss_idt.item()
loss_tmp['G-loss-img-rec'] += (g_loss_rec_A + g_loss_rec_B).item() * 0.5
loss_tmp['G-loss-vgg-rec'] += (g_loss_A_vgg + g_loss_B_vgg).item() * 0.5
loss_tmp['G-loss-rec'] += loss_rec.item()
loss_tmp['G-loss-skin-pgt'] += (g_A_skin_loss_pgt + g_B_skin_loss_pgt).item()
loss_tmp['G-loss-eye-pgt'] += (g_A_eye_loss_pgt + g_B_eye_loss_pgt).item()
loss_tmp['G-loss-lip-pgt'] += (g_A_lip_loss_pgt + g_B_lip_loss_pgt).item()
loss_tmp['G-loss-pgt'] += (g_A_loss_pgt + g_B_loss_pgt).item()
losses_G.append(g_loss.item())
pbar.set_description("Epoch: %d, Step: %d, Loss_G: %0.4f, Loss_A: %0.4f, Loss_B: %0.4f" % \
(self.epoch, step + 1, np.mean(losses_G), np.mean(losses_D_A), np.mean(losses_D_B)))
self.end_time = time.time()
for k, v in loss_tmp.items():
loss_tmp[k] = v / self.len_dataset
loss_tmp['G-loss'] = np.mean(losses_G)
loss_tmp['D-A-loss'] = np.mean(losses_D_A)
loss_tmp['D-B-loss'] = np.mean(losses_D_B)
self.log_loss(loss_tmp)
self.plot_loss()
# Decay learning rate
self.g_scheduler.step()
self.d_A_scheduler.step()
if self.double_d:
self.d_B_scheduler.step()
if self.pgt_annealing:
self.pgt_maker.step()
#save the images
if (self.epoch) % self.vis_freq == 0:
self.vis_train([image_s.detach().cpu(),
image_r.detach().cpu(),
fake_A.detach().cpu(),
pgt_A.detach().cpu()])
# rec_A.detach().cpu()])
# Save model checkpoints
if (self.epoch) % self.save_freq == 0:
self.save_models()
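    # Note on the adversarial terms used throughout train(): with
    # gan_mode='lsgan', criterionGAN is a least-squares objective, so each
    # discriminator step above minimizes (per patch, averaged)
    #   0.5 * ((D(x_real) - 1)^2 + D(x_fake)^2),
    # while the generator's adversarial loss pushes D(G(...)) toward 1.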
def get_loss_tmp(self):
loss_tmp = {
'D-A-loss_real':0.0,
'D-A-loss_fake':0.0,
'D-B-loss_real':0.0,
'D-B-loss_fake':0.0,
'G-A-loss-adv':0.0,
'G-B-loss-adv':0.0,
'G-loss-idt':0.0,
'G-loss-img-rec':0.0,
'G-loss-vgg-rec':0.0,
'G-loss-rec':0.0,
'G-loss-skin-pgt':0.0,
'G-loss-eye-pgt':0.0,
'G-loss-lip-pgt':0.0,
'G-loss-pgt':0.0,
}
return loss_tmp
def log_loss(self, loss_tmp):
if self.logger is not None:
self.logger.info('\n' + '='*40 + '\nEpoch {:d}, time {:.2f} s'
.format(self.epoch, self.end_time - self.start_time))
else:
            print('\n' + '='*40 + '\nEpoch {:d}, time {:.2f} s'
                  .format(self.epoch, self.end_time - self.start_time))
for k, v in loss_tmp.items():
self.loss_logger[k].append(v)
if self.logger is not None:
self.logger.info('{:s}\t{:.6f}'.format(k, v))
else:
print('{:s}\t{:.6f}'.format(k, v))
if self.logger is not None:
self.logger.info('='*40)
else:
print('='*40)
def plot_loss(self):
G_losses = []; G_names = []
D_A_losses = []; D_A_names = []
D_B_losses = []; D_B_names = []
D_P_losses = []; D_P_names = []
for k, v in self.loss_logger.items():
if 'G' in k:
G_names.append(k); G_losses.append(v)
elif 'D-A' in k:
D_A_names.append(k); D_A_losses.append(v)
elif 'D-B' in k:
D_B_names.append(k); D_B_losses.append(v)
elif 'D-P' in k:
D_P_names.append(k); D_P_losses.append(v)
plot_curves(self.save_folder, 'G_loss', G_losses, G_names, ylabel='Loss')
plot_curves(self.save_folder, 'D-A_loss', D_A_losses, D_A_names, ylabel='Loss')
plot_curves(self.save_folder, 'D-B_loss', D_B_losses, D_B_names, ylabel='Loss')
def load_checkpoint(self):
G_path = os.path.join(self.load_folder, 'G.pth')
if os.path.exists(G_path):
self.G.load_state_dict(torch.load(G_path, map_location=self.device))
print('loaded trained generator {}..!'.format(G_path))
D_A_path = os.path.join(self.load_folder, 'D_A.pth')
if os.path.exists(D_A_path):
self.D_A.load_state_dict(torch.load(D_A_path, map_location=self.device))
print('loaded trained discriminator A {}..!'.format(D_A_path))
if self.double_d:
D_B_path = os.path.join(self.load_folder, 'D_B.pth')
if os.path.exists(D_B_path):
self.D_B.load_state_dict(torch.load(D_B_path, map_location=self.device))
print('loaded trained discriminator B {}..!'.format(D_B_path))
def save_models(self):
save_dir = os.path.join(self.save_folder, 'epoch_{:d}'.format(self.epoch))
if not os.path.exists(save_dir):
os.makedirs(save_dir)
torch.save(self.G.state_dict(), os.path.join(save_dir, 'G.pth'))
torch.save(self.D_A.state_dict(), os.path.join(save_dir, 'D_A.pth'))
if self.double_d:
torch.save(self.D_B.state_dict(), os.path.join(save_dir, 'D_B.pth'))
def de_norm(self, x):
out = (x + 1) / 2
return out.clamp(0, 1)
def vis_train(self, img_train_batch):
# saving training results
img_train_batch = torch.cat(img_train_batch, dim=3)
save_path = os.path.join(self.vis_folder, 'epoch_{:d}_fake.png'.format(self.epoch))
vis_image = make_grid(self.de_norm(img_train_batch), 1)
save_image(vis_image, save_path) #, normalize=True)
def generate(self, image_A, image_B, mask_A=None, mask_B=None,
diff_A=None, diff_B=None, lms_A=None, lms_B=None):
"""image_A is content, image_B is style"""
with torch.no_grad():
res = self.G(image_A, image_B, mask_A, mask_B, diff_A, diff_B, lms_A, lms_B)
return res
def test(self, image_A, mask_A, diff_A, lms_A, image_B, mask_B, diff_B, lms_B):
with torch.no_grad():
fake_A = self.generate(image_A, image_B, mask_A, mask_B, diff_A, diff_B, lms_A, lms_B)
fake_A = self.de_norm(fake_A)
fake_A = fake_A.squeeze(0)
return ToPILImage()(fake_A.cpu())
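    # A minimal inference-mode sketch: passing a checkpoint path as `inference`
    # skips the discriminator/optimizer setup and loads only the generator.
    # Hypothetical paths and imports; `args` only needs a `device` attribute,
    # and the input tensors come from training.preprocess.PreProcess:
    #
    #   import argparse
    #   from training.config import get_config
    #   args = argparse.Namespace(device='cpu')
    #   solver = Solver(get_config(), args, inference='path/to/G.pth')
    #   out = solver.test(image_s, mask_s, diff_s, lms_s,
    #                     image_r, mask_r, diff_r, lms_r)   # PIL.Image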
================================================
FILE: training/utils.py
================================================
import os
import logging
import numpy as np
import matplotlib.pyplot as plt
def create_logger(save_path='', file_type='', level='debug', console=True):
if level == 'debug':
_level = logging.DEBUG
    elif level == 'info':
        _level = logging.INFO
    else:  # fall back to INFO so _level is always bound
        _level = logging.INFO
logger = logging.getLogger()
logger.setLevel(_level)
if console:
cs = logging.StreamHandler()
cs.setLevel(_level)
logger.addHandler(cs)
if save_path != '':
file_name = os.path.join(save_path, file_type + '_log.txt')
fh = logging.FileHandler(file_name, mode='w')
fh.setLevel(_level)
logger.addHandler(fh)
return logger
def print_args(args, logger=None):
for k, v in vars(args).items():
if logger is not None:
logger.info('{:<16} : {}'.format(k, v))
else:
print('{:<16} : {}'.format(k, v))
def plot_single_curve(path, name, point, freq=1, xlabel='Epoch',ylabel=None):
x = (np.arange(len(point)) + 1) * freq
plt.plot(x, point, color='purple')
plt.xlabel(xlabel)
if ylabel is None:
ylabel = name
plt.ylabel(ylabel)
plt.savefig(os.path.join(path, name + '.png'))
plt.close()
def plot_curves(path, name, point_list, curve_names=None, freq=1, xlabel='Epoch',ylabel=None):
if curve_names is None:
curve_names = [''] * len(point_list)
else:
assert len(point_list) == len(curve_names)
x = (np.arange(len(point_list[0])) + 1) * freq
if len(point_list) <= 10:
cmap = plt.get_cmap('tab10')
else:
cmap = plt.get_cmap('tab20')
for i, (point, curve_name) in enumerate(zip(point_list, curve_names)):
assert len(point) == len(x)
plt.plot(x, point, color=cmap(i), label=curve_name)
plt.xlabel(xlabel)
if ylabel is not None:
plt.ylabel(ylabel)
plt.legend()
plt.savefig(os.path.join(path, name + '.png'))
plt.close()
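# A minimal usage sketch of plot_curves with made-up loss histories
# (hypothetical values; assumes the output directory already exists):
#
#   g = [1.2, 0.9, 0.7]; d = [0.8, 0.6, 0.5]
#   plot_curves('results', 'toy_loss', [g, d], ['G-loss', 'D-loss'],
#               ylabel='Loss')
#   # writes results/toy_loss.png with one labeled curve per series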
================================================
SYMBOL INDEX (232 symbols across 25 files)
================================================
FILE: concern/image.py
function load_image (line 6) | def load_image(path):
function resize_by_max (line 15) | def resize_by_max(image, max_side=512, force=False):
function image2buffer (line 25) | def image2buffer(image):
FILE: concern/track.py
class Track (line 6) | class Track:
method __init__ (line 7) | def __init__(self):
method track (line 11) | def track(self, mark):
FILE: concern/visualize.py
function channel_first (line 5) | def channel_first(image, format):
function mask2image (line 9) | def mask2image(mask:np.array, format="HWC"):
function draw_points (line 18) | def draw_points(image, points, color=(255, 0, 0)):
FILE: faceutils/dlibutils/main.py
function detect (line 15) | def detect(image: Image) -> 'faces':
function crop (line 33) | def crop(image: Image, face, up_ratio, down_ratio, width_ratio) -> (Imag...
function crop_by_image_size (line 84) | def crop_by_image_size(image: Image, face) -> (Image, 'face'):
function landmarks (line 110) | def landmarks(image: Image, face):
function crop_from_array (line 114) | def crop_from_array(image: np.array, face) -> (np.array, 'face'):
FILE: faceutils/mask/main.py
class FaceParser (line 14) | class FaceParser:
method __init__ (line 15) | def __init__(self, device="cpu"):
method parse (line 30) | def parse(self, image: Image):
FILE: faceutils/mask/model.py
class ConvBNReLU (line 11) | class ConvBNReLU(nn.Module):
method __init__ (line 12) | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args...
method forward (line 23) | def forward(self, x):
method init_weight (line 28) | def init_weight(self):
class BiSeNetOutput (line 34) | class BiSeNetOutput(nn.Module):
method __init__ (line 35) | def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
method forward (line 41) | def forward(self, x):
method init_weight (line 46) | def init_weight(self):
method get_params (line 52) | def get_params(self):
class AttentionRefinementModule (line 64) | class AttentionRefinementModule(nn.Module):
method __init__ (line 65) | def __init__(self, in_chan, out_chan, *args, **kwargs):
method forward (line 73) | def forward(self, x):
method init_weight (line 82) | def init_weight(self):
class ContextPath (line 89) | class ContextPath(nn.Module):
method __init__ (line 90) | def __init__(self, *args, **kwargs):
method forward (line 101) | def forward(self, x):
method init_weight (line 124) | def init_weight(self):
method get_params (line 130) | def get_params(self):
class SpatialPath (line 143) | class SpatialPath(nn.Module):
method __init__ (line 144) | def __init__(self, *args, **kwargs):
method forward (line 152) | def forward(self, x):
method init_weight (line 159) | def init_weight(self):
method get_params (line 165) | def get_params(self):
class FeatureFusionModule (line 177) | class FeatureFusionModule(nn.Module):
method __init__ (line 178) | def __init__(self, in_chan, out_chan, *args, **kwargs):
method forward (line 197) | def forward(self, fsp, fcp):
method init_weight (line 209) | def init_weight(self):
method get_params (line 215) | def get_params(self):
class BiSeNet (line 227) | class BiSeNet(nn.Module):
method __init__ (line 228) | def __init__(self, n_classes, *args, **kwargs):
method forward (line 238) | def forward(self, x):
method init_weight (line 253) | def init_weight(self):
method get_params (line 259) | def get_params(self):
FILE: faceutils/mask/resnet.py
function conv3x3 (line 11) | def conv3x3(in_planes, out_planes, stride=1):
class BasicBlock (line 17) | class BasicBlock(nn.Module):
method __init__ (line 18) | def __init__(self, in_chan, out_chan, stride=1):
method forward (line 33) | def forward(self, x):
function create_layer_basic (line 48) | def create_layer_basic(in_chan, out_chan, bnum, stride=1):
class Resnet18 (line 55) | class Resnet18(nn.Module):
method __init__ (line 56) | def __init__(self):
method forward (line 68) | def forward(self, x):
method init_weight (line 79) | def init_weight(self):
method get_params (line 87) | def get_params(self):
FILE: models/elegant.py
class Generator (line 11) | class Generator(nn.ModuleDict):
method __init__ (line 13) | def __init__(self, conv_dim=64, image_size=256, num_layer_e=2, num_lay...
method get_transfer_input (line 104) | def get_transfer_input(self, image, mask, diff, lms, is_reference=False):
method get_transfer_output (line 137) | def get_transfer_output(self, fea_c_list, mask_c_list, diff_c_list, lm...
method decode (line 164) | def decode(self, fea_c_list, attn_out_list):
method forward (line 182) | def forward(self, c, s, mask_c, mask_s, diff_c, diff_s, lms_c, lms_s):
method tps_align (line 197) | def tps_align(self, feature_size, lms_s, lms_c, fea_s, sample_mode='bi...
FILE: models/loss.py
class GANLoss (line 9) | class GANLoss(nn.Module):
method __init__ (line 15) | def __init__(self, gan_mode='lsgan', target_real_label=1.0, target_fak...
method forward (line 35) | def forward(self, prediction, target_is_real):
function norm (line 53) | def norm(x: torch.Tensor):
function de_norm (line 56) | def de_norm(x: torch.Tensor):
function masked_his_match (line 60) | def masked_his_match(image_s, image_r, mask_s, mask_r):
function generate_pgt (line 86) | def generate_pgt(image_s, image_r, mask_s, mask_r, lms_s, lms_r, margins...
class LinearAnnealingFn (line 115) | class LinearAnnealingFn():
method __init__ (line 119) | def __init__(self, milestones, f_values):
method __call__ (line 124) | def __call__(self, t:int):
class ComposePGT (line 137) | class ComposePGT(nn.Module):
method __init__ (line 138) | def __init__(self, margins, skin_alpha, eye_alpha, lip_alpha):
method forward (line 148) | def forward(self, sources, targets, mask_srcs, mask_tars, lms_srcs, lm...
class AnnealingComposePGT (line 158) | class AnnealingComposePGT(nn.Module):
method __init__ (line 159) | def __init__(self, margins,
method step (line 174) | def step(self):
method forward (line 181) | def forward(self, sources, targets, mask_srcs, mask_tars, lms_srcs, lm...
class MakeupLoss (line 192) | class MakeupLoss(nn.Module):
method __init__ (line 196) | def __init__(self):
method forward (line 199) | def forward(self, x, target, mask=None):
FILE: models/model.py
function get_generator (line 11) | def get_generator(config):
function get_discriminator (line 27) | def get_discriminator(config):
class Discriminator (line 38) | class Discriminator(nn.Module):
method __init__ (line 40) | def __init__(self, input_channel=3, conv_dim=64, num_layers=3, norm='S...
method forward (line 73) | def forward(self, x):
class VGG (line 79) | class VGG(TVGG):
method forward (line 80) | def forward(self, x):
function make_layers (line 85) | def make_layers(cfg, batch_norm=False):
function _vgg (line 101) | def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
function vgg16 (line 112) | def vgg16(pretrained=False, progress=True, **kwargs):
FILE: models/modules/histogram_matching.py
function cal_hist (line 4) | def cal_hist(image):
function cal_trans (line 25) | def cal_trans(ref, adj):
function histogram_matching (line 39) | def histogram_matching(dstImg, refImg, index):
FILE: models/modules/module_attn.py
class MultiheadAttention_weight (line 7) | class MultiheadAttention_weight(nn.Module):
method __init__ (line 8) | def __init__(self, feature_dim, proj_dim, num_heads=1, dropout=0.0, bi...
method forward (line 21) | def forward(self, fea_c, fea_s, mask_c, mask_s):
class MultiheadAttention_value (line 55) | class MultiheadAttention_value(nn.Module):
method __init__ (line 56) | def __init__(self, feature_dim, proj_dim, num_heads=1, bias=True):
method forward (line 66) | def forward(self, weights, fea):
class MultiheadAttention (line 82) | class MultiheadAttention(nn.Module):
method __init__ (line 83) | def __init__(self, in_channels, proj_channels, value_channels, out_cha...
method forward (line 88) | def forward(self, fea_q, fea_k, fea_v, mask_q, mask_k):
class FeedForwardLayer (line 97) | class FeedForwardLayer(nn.Module):
method __init__ (line 98) | def __init__(self, feature_dim, ff_dim, dropout=0.0):
method forward (line 108) | def forward(self, x):
class Attention_apply (line 112) | class Attention_apply(nn.Module):
method __init__ (line 113) | def __init__(self, feature_dim, normalize=True):
method forward (line 121) | def forward(self, x, attn_out):
FILE: models/modules/module_base.py
class ResidualBlock (line 7) | class ResidualBlock(nn.Module):
method __init__ (line 9) | def __init__(self, dim_in, dim_out):
method forward (line 19) | def forward(self, x):
class ResidualBlock_IN (line 24) | class ResidualBlock_IN(nn.Module):
method __init__ (line 26) | def __init__(self, dim_in, dim_out, affine=False):
method forward (line 38) | def forward(self, x):
class ResidualBlock_Downsample (line 43) | class ResidualBlock_Downsample(nn.Module):
method __init__ (line 45) | def __init__(self, dim_in, dim_out, affine=False):
method forward (line 57) | def forward(self, x):
class Downsample (line 64) | class Downsample(nn.Module):
method __init__ (line 66) | def __init__(self, dim_in, dim_out, affine=False):
method forward (line 74) | def forward(self, x):
class ResidualBlock_Upsample (line 78) | class ResidualBlock_Upsample(nn.Module):
method __init__ (line 80) | def __init__(self, dim_in, dim_out, normalize=True, affine=False):
method forward (line 98) | def forward(self, x):
class Upsample (line 105) | class Upsample(nn.Module):
method __init__ (line 107) | def __init__(self, dim_in, dim_out, normalize=True, affine=False):
method forward (line 121) | def forward(self, x):
class PositionalEmbedding (line 125) | class PositionalEmbedding(nn.Module):
method __init__ (line 126) | def __init__(self, embedding_dim=136, feature_size=64, max_size=None, ...
method forward (line 135) | def forward(self, diff, mask):
class MergeBlock (line 165) | class MergeBlock(nn.Module):
method __init__ (line 166) | def __init__(self, merge_mode, feature_dim, normalize=True):
method forward (line 181) | def forward(self, fea_s, fea_r):
FILE: models/modules/pseudo_gt.py
function expand_area (line 11) | def expand_area(mask:torch.Tensor, margin:int):
function mask_blur (line 28) | def mask_blur(mask:torch.Tensor, blur_size=3, mode='smooth'):
function mask_blend (line 54) | def mask_blend(mask, blend_alpha, mask_bound=None, blur_size=3, blend_mo...
function tps_align (line 64) | def tps_align(img_size, lms_r, lms_s, image_r, image_s=None,
function tps_blend (line 86) | def tps_blend(blend_alpha, img_size, lms_r, lms_s, image_r, image_s=None...
function fine_align (line 110) | def fine_align(img_size, lms_r, lms_s, image_r, image_s, mask_r, mask_s,...
FILE: models/modules/sow_attention.py
class WindowAttention (line 6) | class WindowAttention(nn.Module):
method __init__ (line 7) | def __init__(self, window_size, in_channels, proj_channels, value_chan...
method generate_window_weight (line 33) | def generate_window_weight(self):
method make_window (line 41) | def make_window(self, x: torch.Tensor):
method demake_window (line 54) | def demake_window(self, x: torch.Tensor):
method make_mask_window (line 70) | def make_mask_window(self, mask: torch.Tensor):
method forward (line 83) | def forward(self, fea_q, fea_k, fea_v, mask_q=None, mask_k=None):
class SowAttention (line 118) | class SowAttention(nn.Module):
method __init__ (line 119) | def __init__(self, window_size, in_channels, proj_channels, value_chan...
method forward (line 128) | def forward(self, fea_q, fea_k, fea_v, mask_q=None, mask_k=None):
class StridedwindowAttention (line 182) | class StridedwindowAttention(nn.Module):
method __init__ (line 183) | def __init__(self, stride, in_channels, proj_channels, value_channels,...
method make_window (line 203) | def make_window(self, x: torch.Tensor):
method demake_window (line 218) | def demake_window(self, x: torch.Tensor, h, w):
method make_mask_window (line 231) | def make_mask_window(self, mask: torch.Tensor):
method forward (line 244) | def forward(self, fea_q, fea_k, fea_v, mask_q=None, mask_k=None):
FILE: models/modules/spectral_norm.py
function l2normalize (line 4) | def l2normalize(v, eps=1e-12):
class SpectralNorm (line 7) | class SpectralNorm(object):
method __init__ (line 8) | def __init__(self):
method compute_weight (line 13) | def compute_weight(self, module):
method apply (line 27) | def apply(module):
method remove (line 59) | def remove(self, module):
method __call__ (line 67) | def __call__(self, module, inputs):
function spectral_norm (line 70) | def spectral_norm(module):
function remove_spectral_norm (line 74) | def remove_spectral_norm(module):
FILE: models/modules/tps_transform.py
function grid_sample (line 15) | def grid_sample(input, grid, mode='bilinear', canvas=None):
function compute_partial_repr (line 27) | def compute_partial_repr(input_points, control_points):
function bulid_delta_inverse (line 44) | def bulid_delta_inverse(target_control_points):
function build_target_coordinate_matrix (line 62) | def build_target_coordinate_matrix(target_height, target_width, target_c...
function tps_sampler (line 81) | def tps_sampler(target_height, target_width, inverse_kernel, target_coor...
function tps_spatial_transform (line 102) | def tps_spatial_transform(target_height, target_width, target_control_po...
class TPSSpatialTransformer (line 116) | class TPSSpatialTransformer(nn.Module):
method __init__ (line 118) | def __init__(self, target_height, target_width, target_control_points):
method forward (line 135) | def forward(self, source, source_control_points):
FILE: scripts/demo.py
function main (line 14) | def main(config, args):
FILE: scripts/train.py
function main (line 14) | def main(config, args):
FILE: training/config.py
function get_config (line 94) | def get_config()->CfgNode:
FILE: training/dataset.py
class MakeupDataset (line 9) | class MakeupDataset(Dataset):
method __init__ (line 10) | def __init__(self, config=None):
method load_from_file (line 22) | def load_from_file(self, img_name):
method __len__ (line 29) | def __len__(self):
method __getitem__ (line 32) | def __getitem__(self, index):
function get_loader (line 41) | def get_loader(config):
FILE: training/inference.py
class InputSample (line 13) | class InputSample:
method __init__ (line 14) | def __init__(self, inputs, apply_mask=None):
method clear (line 20) | def clear(self):
class Inference (line 25) | class Inference:
method __init__ (line 31) | def __init__(self, config, args, model_path="G.pth"):
method prepare_input (line 41) | def prepare_input(self, *data_inputs):
method postprocess (line 52) | def postprocess(self, source, crop_face, result):
method generate_source_sample (line 74) | def generate_source_sample(self, source_input):
method generate_reference_sample (line 81) | def generate_reference_sample(self, reference_input, apply_mask=None,
method generate_partial_mask (line 97) | def generate_partial_mask(self, source_mask, mask_area='full', saturat...
method interface_transfer (line 124) | def interface_transfer(self, source_sample: InputSample, reference_sam...
method transfer (line 184) | def transfer(self, source: Image, reference: Image, postprocess=True):
method joint_transfer (line 209) | def joint_transfer(self, source: Image, reference_lip: Image, referenc...
FILE: training/preprocess.py
class PreProcess (line 15) | class PreProcess:
method __init__ (line 17) | def __init__(self, config, need_parser=True, device='cpu'):
method mask_process (line 54) | def mask_process(self, mask: torch.Tensor):
method save_mask (line 74) | def save_mask(self, mask: torch.Tensor, path):
method load_mask (line 80) | def load_mask(self, path):
method lms_process (line 87) | def lms_process(self, image:Image):
method diff_process (line 108) | def diff_process(self, lms: torch.Tensor, normalize=False):
method save_lms (line 121) | def save_lms(self, lms: torch.Tensor, path):
method load_lms (line 125) | def load_lms(self, path):
method preprocess (line 130) | def preprocess(self, image: Image, is_crop=True):
method process (line 170) | def process(self, image: Image, mask: torch.Tensor, lms: torch.Tensor):
method __call__ (line 176) | def __call__(self, image:Image, is_crop=True):
FILE: training/solver.py
class Solver (line 18) | class Solver():
method __init__ (line 19) | def __init__(self, config, args, logger=None, inference=False):
method print_network (line 100) | def print_network(self, model, name):
method weights_init_xavier (line 110) | def weights_init_xavier(self, m):
method build_model (line 117) | def build_model(self):
method train (line 154) | def train(self, data_loader):
method get_loss_tmp (line 355) | def get_loss_tmp(self):
method log_loss (line 374) | def log_loss(self, loss_tmp):
method plot_loss (line 392) | def plot_loss(self):
method load_checkpoint (line 410) | def load_checkpoint(self):
method save_models (line 426) | def save_models(self):
method de_norm (line 435) | def de_norm(self, x):
method vis_train (line 439) | def vis_train(self, img_train_batch):
method generate (line 446) | def generate(self, image_A, image_B, mask_A=None, mask_B=None,
method test (line 453) | def test(self, image_A, mask_A, diff_A, lms_A, image_B, mask_B, diff_B...
FILE: training/utils.py
function create_logger (line 6) | def create_logger(save_path='', file_type='', level='debug', console=True):
function print_args (line 29) | def print_args(args, logger=None):
function plot_single_curve (line 36) | def plot_single_curve(path, name, point, freq=1, xlabel='Epoch',ylabel=N...
function plot_curves (line 47) | def plot_curves(path, name, point_list, curve_names=None, freq=1, xlabel...