Repository: Chenyu-Yang-2000/EleGANt
Branch: main
Commit: d033d3398751
Files: 37
Total size: 160.8 KB

Directory structure:
gitextract_z1ehcgp9/
├── .gitignore
├── LICENSE
├── README.md
├── assets/
│   └── docs/
│       ├── install.md
│       └── prepare.md
├── concern/
│   ├── __init__.py
│   ├── image.py
│   ├── track.py
│   └── visualize.py
├── faceutils/
│   ├── __init__.py
│   ├── dlibutils/
│   │   ├── __init__.py
│   │   └── main.py
│   └── mask/
│       ├── __init__.py
│       ├── main.py
│       ├── model.py
│       └── resnet.py
├── models/
│   ├── __init__.py
│   ├── elegant.py
│   ├── loss.py
│   ├── model.py
│   └── modules/
│       ├── __init__.py
│       ├── histogram_matching.py
│       ├── module_attn.py
│       ├── module_base.py
│       ├── pseudo_gt.py
│       ├── sow_attention.py
│       ├── spectral_norm.py
│       └── tps_transform.py
├── scripts/
│   ├── demo.py
│   └── train.py
└── training/
    ├── __init__.py
    ├── config.py
    ├── dataset.py
    ├── inference.py
    ├── preprocess.py
    ├── solver.py
    └── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
/data
/results
*.dat
*.pth
*.pt
# *.txt
# !/expe/*/*.txt

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

.idea/

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

.vscode/settings.json

================================================
FILE: LICENSE
================================================
Attribution-NonCommercial-ShareAlike 4.0 International

=======================================================================

Creative Commons Corporation ("Creative Commons") is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an "as-is" basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
Using Creative Commons Public Licenses Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses. Considerations for licensors: Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC- licensed material, or material used under an exception or limitation to copyright. More considerations for licensors: wiki.creativecommons.org/Considerations_for_licensors Considerations for the public: By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor's permission is not necessary for any reason--for example, because of any applicable exception or limitation to copyright--then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. More considerations for the public: wiki.creativecommons.org/Considerations_for_licensees ======================================================================= Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. Section 1 -- Definitions. a. Adapted Material means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. b. 
Adapter's License means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. c. BY-NC-SA Compatible License means a license listed at creativecommons.org/compatiblelicenses, approved by Creative Commons as essentially the equivalent of this Public License. d. Copyright and Similar Rights means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. e. Effective Technological Measures means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. f. Exceptions and Limitations means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. g. License Elements means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution, NonCommercial, and ShareAlike. h. Licensed Material means the artistic or literary work, database, or other material to which the Licensor applied this Public License. i. Licensed Rights means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. j. Licensor means the individual(s) or entity(ies) granting rights under this Public License. k. NonCommercial means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange. l. Share means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. m. Sui Generis Database Rights means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. n. You means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. Section 2 -- Scope. a. License grant. 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: a. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and b. 
produce, reproduce, and Share Adapted Material for NonCommercial purposes only. 2. Exceptions and Limitations. For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. 3. Term. The term of this Public License is specified in Section 6(a). 4. Media and formats; technical modifications allowed. The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a) (4) never produces Adapted Material. 5. Downstream recipients. a. Offer from the Licensor -- Licensed Material. Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. b. Additional offer from the Licensor -- Adapted Material. Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter's License You apply. c. No downstream restrictions. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. 6. No endorsement. Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). b. Other rights. 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. 2. Patent and trademark rights are not licensed under this Public License. 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes. Section 3 -- License Conditions. Your exercise of the Licensed Rights is expressly made subject to the following conditions. a. Attribution. 1. If You Share the Licensed Material (including in modified form), You must: a. retain the following if it is supplied by the Licensor with the Licensed Material: i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); ii. a copyright notice; iii. 
a notice that refers to this Public License; iv. a notice that refers to the disclaimer of warranties; v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; b. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and c. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. b. ShareAlike. In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply. 1. The Adapter's License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-NC-SA Compatible License. 2. You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material. 3. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply. Section 4 -- Sui Generis Database Rights. Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only; b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. Section 5 -- Disclaimer of Warranties and Limitation of Liability. a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. Section 6 -- Term and Termination. a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or 2. upon express reinstatement by the Licensor. For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License. Section 7 -- Other Terms and Conditions. a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. Section 8 -- Interpretation. a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. ======================================================================= Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” The text of the Creative Commons public licenses is dedicated to the public domain under the CC0 Public Domain Dedication. 
Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at creativecommons.org/policies, Creative Commons does not authorize the use of the trademark "Creative Commons" or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses. Creative Commons may be contacted at creativecommons.org.

================================================
FILE: README.md
================================================
# EleGANt: Exquisite and Locally Editable GAN for Makeup Transfer

[![CC BY-NC-SA 4.0][cc-by-nc-sa-shield]][cc-by-nc-sa]

Official [PyTorch](https://pytorch.org/) implementation of the ECCV 2022 paper "[EleGANt: Exquisite and Locally Editable GAN for Makeup Transfer](https://arxiv.org/abs/2207.09840)" by *Chenyu Yang, Wanrong He, Yingqing Xu, and Yang Gao*.

![teaser](assets/figs/teaser.png)

## Getting Started

- [Installation](assets/docs/install.md)
- [Prepare Dataset & Checkpoints](assets/docs/prepare.md)

## Test

To test our model, download the [weights](https://drive.google.com/drive/folders/1xzIS3Dfmsssxkk9OhhAS4svrZSPfQYRe?usp=sharing) of the trained model and run

```bash
python scripts/demo.py
```

Examples of makeup transfer results can be seen [here](assets/images/examples/).

## Train

To train a model from scratch, run

```bash
python scripts/train.py
```

## Customized Transfer

https://user-images.githubusercontent.com/61506577/180593092-ccadddff-76be-4b7b-921e-0d3b4cc27d9b.mp4

This is our demo of customized makeup editing. The interactive system is built upon [Streamlit](https://github.com/streamlit/streamlit) and the interface in `./training/inference.py`.

**Controllable makeup transfer.**

![control](assets/figs/control.png 'controllable makeup transfer')

**Local makeup editing.**

![edit](assets/figs/edit.png 'local makeup editing')

## Citation

If this work is helpful for your research, please consider citing the following BibTeX entry.

```text
@article{yang2022elegant,
  title={EleGANt: Exquisite and Locally Editable GAN for Makeup Transfer},
  author={Yang, Chenyu and He, Wanrong and Xu, Yingqing and Gao, Yang},
  journal={arXiv preprint arXiv:2207.09840},
  year={2022}
}
```

## Acknowledgement

Some of the code is built upon [PSGAN](https://github.com/wtjiang98/PSGAN) and [aster.Pytorch](https://github.com/ayumiymk/aster.pytorch).

## License

This work is licensed under a [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License][cc-by-nc-sa].

[![CC BY-NC-SA 4.0][cc-by-nc-sa-image]][cc-by-nc-sa]

[cc-by-nc-sa]: http://creativecommons.org/licenses/by-nc-sa/4.0/
[cc-by-nc-sa-image]: https://licensebuttons.net/l/by-nc-sa/4.0/88x31.png
[cc-by-nc-sa-shield]: https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg
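> Editor's note: the "Customized Transfer" section above says the interactive demo is built with Streamlit on top of the interface in `./training/inference.py`. The snippet below is a minimal sketch of how such a front end could be wired up; the `Inference` class name, its `transfer(source, reference)` method, and the checkpoint path are assumptions for illustration only, not the repository's actual demo code.

```python
# Illustrative Streamlit front end; NOT the repository's actual demo code.
# `Inference`, its constructor arguments, and `transfer` are assumed names --
# check ./training/inference.py for the real interface.
import streamlit as st
from PIL import Image


@st.cache_resource
def load_model():
    from training.inference import Inference        # assumed class name
    return Inference(model_path="ckpts/elegant.pth")  # assumed constructor / path


st.title("EleGANt makeup transfer (illustrative sketch)")
src_file = st.file_uploader("Source (non-makeup) image", type=["jpg", "jpeg", "png"])
ref_file = st.file_uploader("Reference (makeup) image", type=["jpg", "jpeg", "png"])

if src_file is not None and ref_file is not None and st.button("Transfer"):
    source = Image.open(src_file).convert("RGB")
    reference = Image.open(ref_file).convert("RGB")
    result = load_model().transfer(source, reference)  # assumed method
    st.image([source, reference, result],
             caption=["source", "reference", "result"], width=256)
```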
================================================
FILE: assets/docs/install.md
================================================
# Installation Instructions

This code was tested on Ubuntu 20.04 with CUDA 11.1.

**a. Create a conda virtual environment and activate it.**

```bash
conda create -n elegant python=3.8
conda activate elegant
```

**b. Install PyTorch and torchvision following the [official instructions](https://pytorch.org/).**

```bash
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
```

**c. Install other required libraries.**

```bash
pip install opencv-python matplotlib dlib fvcore
```

================================================
FILE: assets/docs/prepare.md
================================================
# Preparation Instructions

Clone this repository and prepare the dataset and weights through the following steps:

**a. Prepare model weights for face detection.**

Download the weights of the [dlib](https://github.com/davisking/dlib) 68-landmark face detector [here](http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2). Unzip it and move it to the directory `./faceutils/dlibutils`.

Download the weights of BiSeNet ([PyTorch implementation](https://github.com/zllrunning/face-parsing.PyTorch)) for face parsing [here](https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812). Rename the file to `resnet.pth` and move it to the directory `./faceutils/mask`.

**b. Prepare the Makeup Transfer (MT) dataset.**

Download the raw data of the MT dataset [here](https://github.com/wtjiang98/PSGAN) and unzip it into the subdirectory `./data`. Run the following command to preprocess the data:

```bash
python training/preprocess.py
```

Your data directory should look like:

```text
data
└── MT-Dataset
    ├── images
    │   ├── makeup
    │   └── non-makeup
    ├── segs
    │   ├── makeup
    │   └── non-makeup
    ├── lms
    │   ├── makeup
    │   └── non-makeup
    ├── makeup.txt
    ├── non-makeup.txt
    └── ...
```

**c. Download the weights of the trained EleGANt.**

The weights of our trained model can be downloaded [here](https://drive.google.com/drive/folders/1xzIS3Dfmsssxkk9OhhAS4svrZSPfQYRe?usp=sharing). Put them under the directory `./ckpts`.
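> Editor's note: before running `scripts/train.py` or `scripts/demo.py`, a quick check along the following lines can confirm that preprocessing produced the layout shown above and that downloaded checkpoints are in place. This script is not part of the repository, and the `.pth` extension for the checkpoints is an assumption.

```python
# Illustrative sanity check for the prepared MT-Dataset layout and checkpoints.
# Not part of the repository; adjust paths/extensions to match your download.
from pathlib import Path

EXPECTED = [
    "images/makeup", "images/non-makeup",
    "segs/makeup", "segs/non-makeup",
    "lms/makeup", "lms/non-makeup",
    "makeup.txt", "non-makeup.txt",
]


def check_layout(root="data/MT-Dataset", ckpt_dir="ckpts"):
    missing = [p for p in EXPECTED if not (Path(root) / p).exists()]
    if missing:
        print("Missing after preprocessing:", ", ".join(missing))
    else:
        print("MT-Dataset layout looks complete.")
    ckpts = sorted(Path(ckpt_dir).glob("*.pth"))  # assumed .pth extension
    print("Found {} checkpoint file(s) under {}/".format(len(ckpts), ckpt_dir))


if __name__ == "__main__":
    check_layout()
```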
================================================
FILE: concern/__init__.py
================================================
from .image import load_image

================================================
FILE: concern/image.py
================================================
import numpy as np
import cv2
from io import BytesIO


def load_image(path):
    # Read raw bytes, decode with OpenCV, then convert BGR -> RGB.
    with path.open("rb") as reader:
        data = np.frombuffer(reader.read(), dtype=np.uint8)  # np.fromstring is deprecated for binary data
        img = cv2.imdecode(data, cv2.IMREAD_COLOR)
        if img is None:
            return
        img = img[..., ::-1]
    return img


def resize_by_max(image, max_side=512, force=False):
    # Downscale so that the longer side is at most `max_side`, keeping the aspect ratio.
    h, w = image.shape[:2]
    if max(h, w) < max_side and not force:
        return image
    ratio = max(h, w) / max_side
    w = int(w / ratio + 0.5)
    h = int(h / ratio + 0.5)
    return cv2.resize(image, (w, h))


def image2buffer(image):
    # Encode an image as JPEG and wrap the bytes in a BytesIO buffer.
    is_success, buffer = cv2.imencode(".jpg", image)
    if not is_success:
        return None
    return BytesIO(buffer)

================================================
FILE: concern/track.py
================================================
import time

import torch


class Track:
    """Lightweight helper for logging GPU memory usage and elapsed time."""

    def __init__(self):
        self.log_point = time.time()
        self.enable_track = False

    def track(self, mark):
        if not self.enable_track:
            return
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        print("{} memory:".format(mark), torch.cuda.memory_allocated() / 1024 / 1024, "M")
        print("{} time cost:".format(mark), time.time() - self.log_point)
        self.log_point = time.time()

================================================
FILE: concern/visualize.py
================================================
import numpy as np
import cv2


def channel_first(image, format):
    return image.transpose(
        format.index("C"), format.index("H"), format.index("W"))


def mask2image(mask: np.array, format="HWC"):
    # Paint each label of a segmentation mask with a random color.
    H, W = mask.shape
    canvas = np.zeros((H, W, 3), dtype=np.uint8)
    for i in range(int(mask.max())):
        color = np.random.rand(1, 1, 3) * 255
        canvas += (mask == i)[:, :, None] * color.astype(np.uint8)
    return canvas


def draw_points(image, points, color=(255, 0, 0)):
    # Points are given as (y, x); OpenCV expects (x, y).
    for point in points:
        print(int(point[1]), int(point[0]))
        image = cv2.circle(image, (int(point[1]), int(point[0])), 3, color)
    if hasattr(image, "get"):
        return image.get()
    return image

================================================
FILE: faceutils/__init__.py
================================================
#!/usr/bin/python
# -*- encoding: utf-8 -*-
#from . import faceplusplus as fpp
from . import dlibutils as dlib
from .
import mask ================================================ FILE: faceutils/dlibutils/__init__.py ================================================ #!/usr/bin/python # -*- encoding: utf-8 -*- from .main import detect, crop, landmarks, crop_from_array ================================================ FILE: faceutils/dlibutils/main.py ================================================ #!/usr/bin/python # -*- encoding: utf-8 -*- import os.path as osp import numpy as np from PIL import Image import dlib import cv2 from concern.image import resize_by_max detector = dlib.get_frontal_face_detector() predictor = dlib.shape_predictor(osp.split(osp.realpath(__file__))[0] + '/shape_predictor_68_face_landmarks.dat') def detect(image: Image) -> 'faces': image = np.asarray(image) h, w = image.shape[:2] image = resize_by_max(image, 361) actual_h, actual_w = image.shape[:2] faces_on_small = detector(image, 1) faces = dlib.rectangles() for face in faces_on_small: faces.append( dlib.rectangle( int(face.left() / actual_w * w + 0.5), int(face.top() / actual_h * h + 0.5), int(face.right() / actual_w * w + 0.5), int(face.bottom() / actual_h * h + 0.5) ) ) return faces def crop(image: Image, face, up_ratio, down_ratio, width_ratio) -> (Image, 'face'): width, height = image.size face_height = face.height() face_width = face.width() delta_up = up_ratio * face_height delta_down = down_ratio * face_height delta_width = width_ratio * width img_left = int(max(0, face.left() - delta_width)) img_top = int(max(0, face.top() - delta_up)) img_right = int(min(width, face.right() + delta_width)) img_bottom = int(min(height, face.bottom() + delta_down)) image = image.crop((img_left, img_top, img_right, img_bottom)) face = dlib.rectangle(face.left() - img_left, face.top() - img_top, face.right() - img_left, face.bottom() - img_top) face_expand = dlib.rectangle(img_left, img_top, img_right, img_bottom) center = face_expand.center() width, height = image.size # import ipdb; ipdb.set_trace() crop_left = img_left crop_top = img_top crop_right = img_right crop_bottom = img_bottom if width > height: left = int(center.x - height / 2) right = int(center.x + height / 2) if left < 0: left, right = 0, height elif right > width: left, right = width - height, width image = image.crop((left, 0, right, height)) face = dlib.rectangle(face.left() - left, face.top(), face.right() - left, face.bottom()) crop_left += left crop_right = crop_left + height elif width < height: top = int(center.y - width / 2) bottom = int(center.y + width / 2) if top < 0: top, bottom = 0, width elif bottom > height: top, bottom = height - width, height image = image.crop((0, top, width, bottom)) face = dlib.rectangle(face.left(), face.top() - top, face.right(), face.bottom() - top) crop_top += top crop_bottom = crop_top + width crop_face = dlib.rectangle(crop_left, crop_top, crop_right, crop_bottom) return image, face, crop_face def crop_by_image_size(image: Image, face) -> (Image, 'face'): center = face.center() width, height = image.size if width > height: left = int(center.x - height / 2) right = int(center.x + height / 2) if left < 0: left, right = 0, height elif right > width: left, right = width - height, width image = image.crop((left, 0, right, height)) face = dlib.rectangle(face.left() - left, face.top(), face.right() - left, face.bottom()) elif width < height: top = int(center.y - width / 2) bottom = int(center.y + width / 2) if top < 0: top, bottom = 0, width elif bottom > height: top, bottom = height - width, height image = image.crop((0, top, width, 
bottom)) face = dlib.rectangle(face.left(), face.top() - top, face.right(), face.bottom() - top) return image, face def landmarks(image: Image, face): shape = predictor(np.asarray(image), face).parts() return np.array([[p.y, p.x] for p in shape]) def crop_from_array(image: np.array, face) -> (np.array, 'face'): ratio = 0.20 / 0.85 # delta_size / face_size height, width = image.shape[:2] face_height = face.height() face_width = face.width() delta_height = ratio * face_height delta_width = ratio * width img_left = int(max(0, face.left() - delta_width)) img_top = int(max(0, face.top() - delta_height)) img_right = int(min(width, face.right() + delta_width)) img_bottom = int(min(height, face.bottom() + delta_height)) image = image[img_top:img_bottom, img_left:img_right] face = dlib.rectangle(face.left() - img_left, face.top() - img_top, face.right() - img_left, face.bottom() - img_top) center = face.center() height, width = image.shape[:2] if width > height: left = int(center.x - height / 2) right = int(center.x + height / 2) if left < 0: left, right = 0, height elif right > width: left, right = width - height, width image = image[0:height, left:right] face = dlib.rectangle(face.left() - left, face.top(), face.right() - left, face.bottom()) elif width < height: top = int(center.y - width / 2) bottom = int(center.y + width / 2) if top < 0: top, bottom = 0, width elif bottom > height: top, bottom = height - width, height image = image[top:bottom, 0:width] face = dlib.rectangle(face.left(), face.top() - top, face.right(), face.bottom() - top) return image, face ================================================ FILE: faceutils/mask/__init__.py ================================================ #!/usr/bin/python # -*- encoding: utf-8 -*- from .main import FaceParser ================================================ FILE: faceutils/mask/main.py ================================================ #!/usr/bin/python # -*- encoding: utf-8 -*- import os.path as osp import numpy as np import cv2 from PIL import Image import torch import torchvision.transforms as transforms from .model import BiSeNet class FaceParser: def __init__(self, device="cpu"): mapper = [0, 1, 2, 3, 4, 5, 0, 11, 12, 0, 6, 8, 7, 9, 13, 0, 0, 10, 0] self.device = device self.dic = torch.tensor(mapper, device=device).unsqueeze(1) save_pth = osp.split(osp.realpath(__file__))[0] + '/resnet.pth' net = BiSeNet(n_classes=19) net.load_state_dict(torch.load(save_pth, map_location=device)) self.net = net.to(device).eval() self.to_tensor = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ]) def parse(self, image: Image): assert image.shape[:2] == (512, 512) with torch.no_grad(): image = self.to_tensor(image).to(self.device) image = torch.unsqueeze(image, 0) out = self.net(image)[0] parsing = out.squeeze(0).argmax(0) parsing = torch.nn.functional.embedding(parsing, self.dic) return parsing.float().squeeze(2) ================================================ FILE: faceutils/mask/model.py ================================================ #!/usr/bin/python # -*- encoding: utf-8 -*- import torch import torch.nn as nn import torch.nn.functional as F import torchvision from .resnet import Resnet18 class ConvBNReLU(nn.Module): def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): super(ConvBNReLU, self).__init__() self.conv = nn.Conv2d(in_chan, out_chan, kernel_size = ks, stride = stride, padding = padding, bias = False) self.bn = nn.BatchNorm2d(out_chan) 
self.init_weight() def forward(self, x): x = self.conv(x) x = F.relu(self.bn(x)) return x def init_weight(self): for ly in self.children(): if isinstance(ly, nn.Conv2d): nn.init.kaiming_normal_(ly.weight, a=1) if not ly.bias is None: nn.init.constant_(ly.bias, 0) class BiSeNetOutput(nn.Module): def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): super(BiSeNetOutput, self).__init__() self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False) self.init_weight() def forward(self, x): x = self.conv(x) x = self.conv_out(x) return x def init_weight(self): for ly in self.children(): if isinstance(ly, nn.Conv2d): nn.init.kaiming_normal_(ly.weight, a=1) if not ly.bias is None: nn.init.constant_(ly.bias, 0) def get_params(self): wd_params, nowd_params = [], [] for name, module in self.named_modules(): if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): wd_params.append(module.weight) if not module.bias is None: nowd_params.append(module.bias) elif isinstance(module, nn.BatchNorm2d): nowd_params += list(module.parameters()) return wd_params, nowd_params class AttentionRefinementModule(nn.Module): def __init__(self, in_chan, out_chan, *args, **kwargs): super(AttentionRefinementModule, self).__init__() self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False) self.bn_atten = nn.BatchNorm2d(out_chan) self.sigmoid_atten = nn.Sigmoid() self.init_weight() def forward(self, x): feat = self.conv(x) atten = F.avg_pool2d(feat, feat.size()[2:]) atten = self.conv_atten(atten) atten = self.bn_atten(atten) atten = self.sigmoid_atten(atten) out = torch.mul(feat, atten) return out def init_weight(self): for ly in self.children(): if isinstance(ly, nn.Conv2d): nn.init.kaiming_normal_(ly.weight, a=1) if not ly.bias is None: nn.init.constant_(ly.bias, 0) class ContextPath(nn.Module): def __init__(self, *args, **kwargs): super(ContextPath, self).__init__() self.resnet = Resnet18() self.arm16 = AttentionRefinementModule(256, 128) self.arm32 = AttentionRefinementModule(512, 128) self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) self.init_weight() def forward(self, x): H0, W0 = x.size()[2:] feat8, feat16, feat32 = self.resnet(x) H8, W8 = feat8.size()[2:] H16, W16 = feat16.size()[2:] H32, W32 = feat32.size()[2:] avg = F.avg_pool2d(feat32, feat32.size()[2:]) avg = self.conv_avg(avg) avg_up = F.interpolate(avg, (H32, W32), mode='nearest') feat32_arm = self.arm32(feat32) feat32_sum = feat32_arm + avg_up feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') feat32_up = self.conv_head32(feat32_up) feat16_arm = self.arm16(feat16) feat16_sum = feat16_arm + feat32_up feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') feat16_up = self.conv_head16(feat16_up) return feat8, feat16_up, feat32_up # x8, x8, x16 def init_weight(self): for ly in self.children(): if isinstance(ly, nn.Conv2d): nn.init.kaiming_normal_(ly.weight, a=1) if not ly.bias is None: nn.init.constant_(ly.bias, 0) def get_params(self): wd_params, nowd_params = [], [] for name, module in self.named_modules(): if isinstance(module, (nn.Linear, nn.Conv2d)): wd_params.append(module.weight) if not module.bias is None: nowd_params.append(module.bias) elif isinstance(module, nn.BatchNorm2d): nowd_params += 
list(module.parameters()) return wd_params, nowd_params ### This is not used, since I replace this with the resnet feature with the same size class SpatialPath(nn.Module): def __init__(self, *args, **kwargs): super(SpatialPath, self).__init__() self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3) self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0) self.init_weight() def forward(self, x): feat = self.conv1(x) feat = self.conv2(feat) feat = self.conv3(feat) feat = self.conv_out(feat) return feat def init_weight(self): for ly in self.children(): if isinstance(ly, nn.Conv2d): nn.init.kaiming_normal_(ly.weight, a=1) if not ly.bias is None: nn.init.constant_(ly.bias, 0) def get_params(self): wd_params, nowd_params = [], [] for name, module in self.named_modules(): if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): wd_params.append(module.weight) if not module.bias is None: nowd_params.append(module.bias) elif isinstance(module, nn.BatchNorm2d): nowd_params += list(module.parameters()) return wd_params, nowd_params class FeatureFusionModule(nn.Module): def __init__(self, in_chan, out_chan, *args, **kwargs): super(FeatureFusionModule, self).__init__() self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) self.conv1 = nn.Conv2d(out_chan, out_chan//4, kernel_size = 1, stride = 1, padding = 0, bias = False) self.conv2 = nn.Conv2d(out_chan//4, out_chan, kernel_size = 1, stride = 1, padding = 0, bias = False) self.relu = nn.ReLU(inplace=True) self.sigmoid = nn.Sigmoid() self.init_weight() def forward(self, fsp, fcp): fcat = torch.cat([fsp, fcp], dim=1) feat = self.convblk(fcat) atten = F.avg_pool2d(feat, feat.size()[2:]) atten = self.conv1(atten) atten = self.relu(atten) atten = self.conv2(atten) atten = self.sigmoid(atten) feat_atten = torch.mul(feat, atten) feat_out = feat_atten + feat return feat_out def init_weight(self): for ly in self.children(): if isinstance(ly, nn.Conv2d): nn.init.kaiming_normal_(ly.weight, a=1) if not ly.bias is None: nn.init.constant_(ly.bias, 0) def get_params(self): wd_params, nowd_params = [], [] for name, module in self.named_modules(): if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): wd_params.append(module.weight) if not module.bias is None: nowd_params.append(module.bias) elif isinstance(module, nn.BatchNorm2d): nowd_params += list(module.parameters()) return wd_params, nowd_params class BiSeNet(nn.Module): def __init__(self, n_classes, *args, **kwargs): super(BiSeNet, self).__init__() self.cp = ContextPath() ## here self.sp is deleted self.ffm = FeatureFusionModule(256, 256) self.conv_out = BiSeNetOutput(256, 256, n_classes) self.conv_out16 = BiSeNetOutput(128, 64, n_classes) self.conv_out32 = BiSeNetOutput(128, 64, n_classes) # self.init_weight() def forward(self, x): H, W = x.size()[2:] feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature feat_fuse = self.ffm(feat_sp, feat_cp8) feat_out = self.conv_out(feat_fuse) feat_out16 = self.conv_out16(feat_cp8) feat_out32 = self.conv_out32(feat_cp16) feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True) feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True) return feat_out, feat_out16, feat_out32 def 
init_weight(self): for ly in self.children(): if isinstance(ly, nn.Conv2d): nn.init.kaiming_normal_(ly.weight, a=1) if not ly.bias is None: nn.init.constant_(ly.bias, 0) def get_params(self): wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] for name, child in self.named_children(): child_wd_params, child_nowd_params = child.get_params() if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput): lr_mul_wd_params += child_wd_params lr_mul_nowd_params += child_nowd_params else: wd_params += child_wd_params nowd_params += child_nowd_params return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params if __name__ == "__main__": net = BiSeNet(19) net.cuda() net.eval() in_ten = torch.randn(16, 3, 640, 480).cuda() out, out16, out32 = net(in_ten) print(out.shape) net.get_params() ================================================ FILE: faceutils/mask/resnet.py ================================================ #!/usr/bin/python # -*- encoding: utf-8 -*- import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.model_zoo as modelzoo resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) class BasicBlock(nn.Module): def __init__(self, in_chan, out_chan, stride=1): super(BasicBlock, self).__init__() self.conv1 = conv3x3(in_chan, out_chan, stride) self.bn1 = nn.BatchNorm2d(out_chan) self.conv2 = conv3x3(out_chan, out_chan) self.bn2 = nn.BatchNorm2d(out_chan) self.relu = nn.ReLU(inplace=True) self.downsample = None if in_chan != out_chan or stride != 1: self.downsample = nn.Sequential( nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(out_chan), ) def forward(self, x): residual = self.conv1(x) residual = F.relu(self.bn1(residual)) residual = self.conv2(residual) residual = self.bn2(residual) shortcut = x if self.downsample is not None: shortcut = self.downsample(x) out = shortcut + residual out = self.relu(out) return out def create_layer_basic(in_chan, out_chan, bnum, stride=1): layers = [BasicBlock(in_chan, out_chan, stride=stride)] for i in range(bnum-1): layers.append(BasicBlock(out_chan, out_chan, stride=1)) return nn.Sequential(*layers) class Resnet18(nn.Module): def __init__(self): super(Resnet18, self).__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(64) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) # self.init_weight() def forward(self, x): x = self.conv1(x) x = F.relu(self.bn1(x)) x = self.maxpool(x) x = self.layer1(x) feat8 = self.layer2(x) # 1/8 feat16 = self.layer3(feat8) # 1/16 feat32 = self.layer4(feat16) # 1/32 return feat8, feat16, feat32 def init_weight(self): state_dict = modelzoo.load_url(resnet18_url) self_state_dict = self.state_dict() for k, v in state_dict.items(): if 'fc' in k: continue self_state_dict.update({k: v}) self.load_state_dict(self_state_dict) def get_params(self): wd_params, nowd_params = [], [] for name, module in self.named_modules(): if isinstance(module, (nn.Linear, nn.Conv2d)): wd_params.append(module.weight) if not module.bias is None: 
nowd_params.append(module.bias) elif isinstance(module, nn.BatchNorm2d): nowd_params += list(module.parameters()) return wd_params, nowd_params if __name__ == "__main__": net = Resnet18() x = torch.randn(16, 3, 224, 224) out = net(x) print(out[0].size()) print(out[1].size()) print(out[2].size()) net.get_params() ================================================ FILE: models/__init__.py ================================================ ================================================ FILE: models/elegant.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from .modules.module_base import ResidualBlock_IN, Downsample, Upsample, PositionalEmbedding, MergeBlock from .modules.module_attn import Attention_apply, FeedForwardLayer, MultiheadAttention from .modules.sow_attention import SowAttention from .modules.tps_transform import tps_spatial_transform class Generator(nn.ModuleDict): """Generator. Encoder-Decoder Architecture.""" def __init__(self, conv_dim=64, image_size=256, num_layer_e=2, num_layer_d=1, window_size=16, use_ff=False, merge_mode='conv', num_head=1, double_encoder=False, **unused): super(Generator, self).__init__() # -------------------------- Encoder -------------------------- layers = nn.Conv2d(3, conv_dim, kernel_size=7, stride=1, padding=3, bias=False) self.add_module('in_conv', layers) # Down-Sampling & Bottleneck curr_dim = conv_dim; feature_size = image_size for i in range(2): layers = Downsample(curr_dim, curr_dim * 2, affine=True) self.add_module('down_{:d}'.format(i+1), layers) curr_dim = curr_dim * 2; feature_size = feature_size // 2 self.add_module('e_bottleneck_{:d}'.format(i+1), nn.Sequential(*[ResidualBlock_IN(curr_dim, curr_dim, affine=True) for j in range(num_layer_e)]) ) ### second encoder self.double_encoder = double_encoder if self.double_encoder: layers = nn.Conv2d(3, conv_dim, kernel_size=7, stride=1, padding=3, bias=False) self.add_module('in_conv_s', layers) # Down-Sampling & Bottleneck curr_dim = conv_dim; feature_size = image_size for i in range(2): layers = Downsample(curr_dim, curr_dim * 2, affine=True) self.add_module('down_{:d}_s'.format(i+1), layers) curr_dim = curr_dim * 2; feature_size = feature_size // 2 self.add_module('e_bottleneck_{:d}_s'.format(i+1), nn.Sequential(*[ResidualBlock_IN(curr_dim, curr_dim, affine=True) for j in range(num_layer_e)]) ) # --------------------------- Transfer ---------------------------- curr_dim = conv_dim; feature_size = image_size self.use_ff = use_ff for i in range(2): curr_dim = curr_dim * 2; feature_size = feature_size // 2 self.add_module('embedding_{:d}'.format(i+1), PositionalEmbedding( embedding_dim=136, feature_size=feature_size, max_size=image_size, embedding_type='l2_norm' )) if i < 1: self.add_module('attention_extract_{:d}'.format(i+1), SowAttention( window_size=window_size, in_channels=curr_dim + 136, proj_channels=curr_dim + 136, value_channels=curr_dim, out_channels=curr_dim, num_heads=num_head )) else: self.add_module('attention_extract_{:d}'.format(i+1), MultiheadAttention( in_channels=curr_dim + 136, proj_channels=curr_dim + 136, value_channels=curr_dim, out_channels=curr_dim, num_heads=num_head )) if use_ff: self.add_module('feedforward_{:d}'.format(i+1), FeedForwardLayer(curr_dim, curr_dim)) self.add_module('attention_apply_{:d}'.format(i+1), Attention_apply(curr_dim)) # --------------------------- Decoder ---------------------------- # Bottleneck & Up-Sampling & Merge for i in range(2): 
self.add_module('d_bottleneck_{:d}'.format(i+1), nn.Sequential(*[ResidualBlock_IN(curr_dim, curr_dim, affine=True) for j in range(num_layer_d)]) ) layers = Upsample(curr_dim, curr_dim // 2, affine=True) self.add_module('up_{:d}'.format(i+1), layers) curr_dim = curr_dim // 2 if i < 1: self.add_module('merge_{:d}'.format(i+1), MergeBlock(merge_mode, curr_dim)) layers = nn.Sequential( nn.InstanceNorm2d(curr_dim, affine=True), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False), ) self.add_module('out_conv', layers) def get_transfer_input(self, image, mask, diff, lms, is_reference=False): feature_size = image.shape[2]; scale_factor = 1.0 fea_list, mask_list, diff_list, lms_list = [], [], [], [] # input conv if self.double_encoder and is_reference: fea = self['in_conv_s'](image) else: fea = self['in_conv'](image) # down-sampling & bottleneck for i in range(2): if self.double_encoder and is_reference: fea = self['down_{:d}_s'.format(i+1)](fea) fea_ = self['e_bottleneck_{:d}_s'.format(i+1)](fea) else: fea = self['down_{:d}'.format(i+1)](fea) fea_ = self['e_bottleneck_{:d}'.format(i+1)](fea) fea_list.append(fea_) feature_size = feature_size // 2; scale_factor = scale_factor * 0.5 mask_ = F.interpolate(mask, feature_size, mode='nearest') mask_list.append(mask_) diff_ = self['embedding_{:d}'.format(i+1)](diff, mask) diff_list.append(diff_) lms_ = lms * scale_factor lms_list.append(lms_) return [fea_list, mask_list, diff_list, lms_list] def get_transfer_output(self, fea_c_list, mask_c_list, diff_c_list, lms_c_list, fea_s_list, mask_s_list, diff_s_list, lms_s_list): attn_out_list = [] for i in range(2): feature_size = fea_c_list[i].shape[2] # align if i == 0: fea_s_ = self.tps_align(feature_size, lms_s_list[i], lms_c_list[i], fea_s_list[i]) mask_s_ = self.tps_align(feature_size, lms_s_list[i], lms_c_list[i], mask_s_list[i], 'nearest') diff_s_ = self.tps_align(feature_size, lms_s_list[i], lms_c_list[i], diff_s_list[i], 'nearest') else: fea_s_ = fea_s_list[i] mask_s_ = mask_s_list[i] diff_s_ = diff_s_list[i] # transfer input_q = torch.cat((fea_c_list[i], diff_c_list[i]), dim=1) input_k = torch.cat((fea_s_, diff_s_), dim=1) attn_out = self['attention_extract_{:d}'.format(i+1)](input_q, input_k, fea_s_, mask_c_list[i], mask_s_) if self.use_ff: attn_out = self['feedforward_{:d}'.format(i+1)](attn_out) attn_out_list.append(attn_out) return attn_out_list def decode(self, fea_c_list, attn_out_list): # apply for i in range(2): fea_c_ = self['attention_apply_{:d}'.format(i+1)](fea_c_list[i], attn_out_list[i]) fea_c_ = self['d_bottleneck_{:d}'.format(2-i)](fea_c_) fea_c_list[i] = fea_c_ # up-sampling & merge fea_c = fea_c_list[1] for i in range(2): fea_c = self['up_{:d}'.format(i+1)](fea_c) if i < 1: fea_c = self['merge_{:d}'.format(i+1)](fea_c_list[0], fea_c) fea_c = self['out_conv'](fea_c) return fea_c def forward(self, c, s, mask_c, mask_s, diff_c, diff_s, lms_c, lms_s): """ c: content, stands for source image. shape: (b, c, h, w) s: style, stands for reference image. 
shape: (b, c, h, w) mask_c: (b, c', h, w) diff: (b, d, h, w) lms: (b, K, 2) """ transfer_input_c = self.get_transfer_input(c, mask_c, diff_c, lms_c) transfer_input_s = self.get_transfer_input(s, mask_s, diff_s, lms_s, True) attn_out_list = self.get_transfer_output(*transfer_input_c, *transfer_input_s) fea_c = self.decode(transfer_input_c[0], attn_out_list) return fea_c def tps_align(self, feature_size, lms_s, lms_c, fea_s, sample_mode='bilinear'): ''' fea: (B, C, H, W), lms: (B, K, 2) ''' fea_out = [] for l_s, l_c, f_s in zip(lms_s, lms_c, fea_s): l_c = torch.flip(l_c, dims=[1]) / (feature_size - 1) l_s = (torch.flip(l_s, dims=[1]) / (feature_size - 1)).unsqueeze(0) f_s = f_s.unsqueeze(0) # (1, C, H, W) fea_trans, _ = tps_spatial_transform(feature_size, feature_size, l_c, f_s, l_s, sample_mode) fea_out.append(fea_trans) return torch.cat(fea_out, dim=0) ================================================ FILE: models/loss.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from .modules.histogram_matching import histogram_matching from .modules.pseudo_gt import fine_align, expand_area, mask_blur class GANLoss(nn.Module): """Define different GAN objectives. The GANLoss class abstracts away the need to create the target label tensor that has the same size as the input. """ def __init__(self, gan_mode='lsgan', target_real_label=1.0, target_fake_label=0.0): """ Initialize the GANLoss class. Parameters: gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. target_real_label (bool) - - label for a real image target_fake_label (bool) - - label of a fake image Note: Do not use sigmoid as the last layer of Discriminator. LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. """ super(GANLoss, self).__init__() self.register_buffer('real_label', torch.tensor(target_real_label)) self.register_buffer('fake_label', torch.tensor(target_fake_label)) self.gan_mode = gan_mode if gan_mode == 'lsgan': self.loss = nn.MSELoss() elif gan_mode == 'vanilla': self.loss = nn.BCEWithLogitsLoss() else: raise NotImplementedError('gan mode %s not implemented' % gan_mode) def forward(self, prediction, target_is_real): """Calculate loss given Discriminator's output and grount truth labels. Parameters: prediction (tensor) - - tpyically the prediction output from a discriminator target_is_real (bool) - - if the ground truth label is for real images or fake images Returns: the calculated loss. 
""" if target_is_real: target_tensor = self.real_label else: target_tensor = self.fake_label target_tensor = target_tensor.expand_as(prediction).to(prediction.device) loss = self.loss(prediction, target_tensor) return loss def norm(x: torch.Tensor): return x * 2 - 1 def de_norm(x: torch.Tensor): out = (x + 1) / 2 return out.clamp(0, 1) def masked_his_match(image_s, image_r, mask_s, mask_r): ''' image: (3, h, w) mask: (1, h, w) ''' index_tmp = torch.nonzero(mask_s) x_A_index = index_tmp[:, 1] y_A_index = index_tmp[:, 2] index_tmp = torch.nonzero(mask_r) x_B_index = index_tmp[:, 1] y_B_index = index_tmp[:, 2] image_s = (de_norm(image_s) * 255) #[-1, 1] -> [0, 255] image_r = (de_norm(image_r) * 255) source_masked = image_s * mask_s target_masked = image_r * mask_r source_match = histogram_matching( source_masked, target_masked, [x_A_index, y_A_index, x_B_index, y_B_index]) source_match = source_match.to(image_s.device) return norm(source_match / 255) #[0, 255] -> [-1, 1] def generate_pgt(image_s, image_r, mask_s, mask_r, lms_s, lms_r, margins, blend_alphas, img_size=None): """ input_data: (3, h, w) mask: (c, h, w), lip, skin, left eye, right eye """ if img_size is None: img_size = image_s.shape[1] pgt = image_s.detach().clone() # skin match skin_match = masked_his_match(image_s, image_r, mask_s[1:2], mask_r[1:2]) pgt = (1 - mask_s[1:2]) * pgt + mask_s[1:2] * skin_match # lip match lip_match = masked_his_match(image_s, image_r, mask_s[0:1], mask_r[0:1]) pgt = (1 - mask_s[0:1]) * pgt + mask_s[0:1] * lip_match # eye match mask_s_eye = expand_area(mask_s[2:4].sum(dim=0, keepdim=True), margins['eye']) * mask_s[1:2] mask_r_eye = expand_area(mask_r[2:4].sum(dim=0, keepdim=True), margins['eye']) * mask_r[1:2] eye_match = masked_his_match(image_s, image_r, mask_s_eye, mask_r_eye) mask_s_eye_blur = mask_blur(mask_s_eye, blur_size=5, mode='valid') pgt = (1 - mask_s_eye_blur) * pgt + mask_s_eye_blur * eye_match # tps align pgt = fine_align(img_size, lms_r, lms_s, image_r, pgt, mask_r, mask_s, margins, blend_alphas) return pgt class LinearAnnealingFn(): """ define the linear annealing function with milestones """ def __init__(self, milestones, f_values): assert len(milestones) == len(f_values) self.milestones = milestones self.f_values = f_values def __call__(self, t:int): if t < self.milestones[0]: return self.f_values[0] elif t >= self.milestones[-1]: return self.f_values[-1] else: for r in range(len(self.milestones) - 1): if self.milestones[r] <= t < self.milestones[r+1]: return ((t - self.milestones[r]) * self.f_values[r+1] \ + (self.milestones[r+1] - t) * self.f_values[r]) \ / (self.milestones[r+1] - self.milestones[r]) class ComposePGT(nn.Module): def __init__(self, margins, skin_alpha, eye_alpha, lip_alpha): super(ComposePGT, self).__init__() self.margins = margins self.blend_alphas = { 'skin':skin_alpha, 'eye':eye_alpha, 'lip':lip_alpha } @torch.no_grad() def forward(self, sources, targets, mask_srcs, mask_tars, lms_srcs, lms_tars): pgts = [] for source, target, mask_src, mask_tar, lms_src, lms_tar in\ zip(sources, targets, mask_srcs, mask_tars, lms_srcs, lms_tars): pgt = generate_pgt(source, target, mask_src, mask_tar, lms_src, lms_tar, self.margins, self.blend_alphas) pgts.append(pgt) pgts = torch.stack(pgts, dim=0) return pgts class AnnealingComposePGT(nn.Module): def __init__(self, margins, skin_alpha_milestones, skin_alpha_values, eye_alpha_milestones, eye_alpha_values, lip_alpha_milestones, lip_alpha_values ): super(AnnealingComposePGT, self).__init__() self.margins = margins 
self.skin_alpha_fn = LinearAnnealingFn(skin_alpha_milestones, skin_alpha_values) self.eye_alpha_fn = LinearAnnealingFn(eye_alpha_milestones, eye_alpha_values) self.lip_alpha_fn = LinearAnnealingFn(lip_alpha_milestones, lip_alpha_values) self.t = 0 self.blend_alphas = {} self.step() def step(self): self.t += 1 self.blend_alphas['skin'] = self.skin_alpha_fn(self.t) self.blend_alphas['eye'] = self.eye_alpha_fn(self.t) self.blend_alphas['lip'] = self.lip_alpha_fn(self.t) @torch.no_grad() def forward(self, sources, targets, mask_srcs, mask_tars, lms_srcs, lms_tars): pgts = [] for source, target, mask_src, mask_tar, lms_src, lms_tar in\ zip(sources, targets, mask_srcs, mask_tars, lms_srcs, lms_tars): pgt = generate_pgt(source, target, mask_src, mask_tar, lms_src, lms_tar, self.margins, self.blend_alphas) pgts.append(pgt) pgts = torch.stack(pgts, dim=0) return pgts class MakeupLoss(nn.Module): """ Define the makeup loss w.r.t pseudo ground truth """ def __init__(self): super(MakeupLoss, self).__init__() def forward(self, x, target, mask=None): if mask is None: return F.l1_loss(x, target) else: return F.l1_loss(x * mask, target * mask) ================================================ FILE: models/model.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from torchvision.models import VGG as TVGG from torchvision.models.vgg import load_state_dict_from_url, model_urls, cfgs from .modules.spectral_norm import spectral_norm as SpectralNorm from .elegant import Generator def get_generator(config): kwargs = { 'conv_dim':config.MODEL.G_CONV_DIM, 'image_size':config.DATA.IMG_SIZE, 'num_head':config.MODEL.NUM_HEAD, 'double_encoder':config.MODEL.DOUBLE_E, 'use_ff':config.MODEL.USE_FF, 'num_layer_e':config.MODEL.NUM_LAYER_E, 'num_layer_d':config.MODEL.NUM_LAYER_D, 'window_size':config.MODEL.WINDOW_SIZE, 'merge_mode':config.MODEL.MERGE_MODE } G = Generator(**kwargs) return G def get_discriminator(config): kwargs = { 'input_channel': 3, 'conv_dim':config.MODEL.D_CONV_DIM, 'num_layers':config.MODEL.D_REPEAT_NUM, 'norm':config.MODEL.D_TYPE } D = Discriminator(**kwargs) return D class Discriminator(nn.Module): """Discriminator. 
PatchGAN.""" def __init__(self, input_channel=3, conv_dim=64, num_layers=3, norm='SN', **unused): super(Discriminator, self).__init__() layers = [] if norm=='SN': layers.append(SpectralNorm(nn.Conv2d(input_channel, conv_dim, kernel_size=4, stride=2, padding=1))) else: layers.append(nn.Conv2d(input_channel, conv_dim, kernel_size=4, stride=2, padding=1)) layers.append(nn.LeakyReLU(0.01, inplace=True)) curr_dim = conv_dim for i in range(1, num_layers): if norm=='SN': layers.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))) else: layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1)) layers.append(nn.LeakyReLU(0.01, inplace=True)) curr_dim = curr_dim * 2 #k_size = int(image_size / np.power(2, repeat_num)) if norm=='SN': layers.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=1, padding=1))) else: layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=1, padding=1)) layers.append(nn.LeakyReLU(0.01, inplace=True)) curr_dim = curr_dim * 2 self.main = nn.Sequential(*layers) if norm=='SN': self.conv1 = SpectralNorm(nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False)) else: self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False) def forward(self, x): h = self.main(x) out_makeup = self.conv1(h) return out_makeup class VGG(TVGG): def forward(self, x): x = self.features(x) return x def make_layers(cfg, batch_norm=False): layers = [] in_channels = 3 for v in cfg: if v == 'M': layers += [nn.MaxPool2d(kernel_size=2, stride=2)] else: conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) if batch_norm: layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] else: layers += [conv2d, nn.ReLU(inplace=True)] in_channels = v return nn.Sequential(*layers) def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs): if pretrained: kwargs['init_weights'] = False model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) if pretrained: state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) model.load_state_dict(state_dict) return model def vgg16(pretrained=False, progress=True, **kwargs): r"""VGG 16-layer model (configuration "D") `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) ================================================ FILE: models/modules/__init__.py ================================================ ================================================ FILE: models/modules/histogram_matching.py ================================================ import copy import torch def cal_hist(image): """ cal cumulative hist for channel list """ hists = [] for i in range(0, 3): channel = image[i] # channel = image[i, :, :] channel = torch.from_numpy(channel) # hist, _ = np.histogram(channel, bins=256, range=(0,255)) hist = torch.histc(channel, bins=256, min=0, max=256) hist = hist.numpy() # refHist=hist.view(256,1) sum = hist.sum() pdf = [v / (sum + 1e-10) for v in hist] for i in range(1, 256): pdf[i] = pdf[i - 1] + pdf[i] hists.append(pdf) return hists def cal_trans(ref, adj): """ calculate transfer function algorithm refering to wiki item: Histogram matching """ table = list(range(0, 256)) for i in list(range(1, 256)): for j in list(range(1, 256)): if ref[i] >= adj[j - 1] and ref[i] <= 
adj[j]: table[i] = j break table[255] = 255 return table def histogram_matching(dstImg, refImg, index): """ perform histogram matching dstImg is transformed to have the same the histogram with refImg's index[0], index[1]: the index of pixels that need to be transformed in dstImg index[2], index[3]: the index of pixels that to compute histogram in refImg """ index = [x.cpu().numpy() for x in index] dstImg = dstImg.detach().cpu().numpy() refImg = refImg.detach().cpu().numpy() dst_align = [dstImg[i, index[0], index[1]] for i in range(0, 3)] ref_align = [refImg[i, index[2], index[3]] for i in range(0, 3)] hist_ref = cal_hist(ref_align) hist_dst = cal_hist(dst_align) tables = [cal_trans(hist_dst[i], hist_ref[i]) for i in range(0, 3)] mid = copy.deepcopy(dst_align) for i in range(0, 3): for k in range(0, len(index[0])): dst_align[i][k] = tables[i][int(mid[i][k])] for i in range(0, 3): dstImg[i, index[0], index[1]] = dst_align[i] dstImg = torch.FloatTensor(dstImg) return dstImg ================================================ FILE: models/modules/module_attn.py ================================================ import math import torch import torch.nn as nn import torch.nn.functional as F class MultiheadAttention_weight(nn.Module): def __init__(self, feature_dim, proj_dim, num_heads=1, dropout=0.0, bias=True): super(MultiheadAttention_weight, self).__init__() self.feature_dim = feature_dim self.proj_dim = proj_dim self.num_heads = num_heads self.dropout = nn.Dropout(dropout) self.head_dim = proj_dim // num_heads assert self.head_dim * num_heads == self.proj_dim, "embed_dim must be divisible by num_heads" self.scaling = self.head_dim ** -0.5 self.q_proj = nn.Linear(feature_dim, proj_dim, bias=bias) self.k_proj = nn.Linear(feature_dim, proj_dim, bias=bias) def forward(self, fea_c, fea_s, mask_c, mask_s): ''' fea_c: (b, d, h, w) mask_c: (b, c, h, w) ''' bsz, dim, h, w = fea_c.shape; mask_channel = mask_c.shape[1] fea_c = fea_c.view(bsz, dim, h*w).transpose(1, 2) # (b, HW, d) fea_s = fea_s.view(bsz, dim, h*w).transpose(1, 2) with torch.no_grad(): if mask_c.shape[2] != h: mask_c = F.interpolate(mask_c, size=(h, w)) mask_s = F.interpolate(mask_s, size=(h, w)) mask_c = mask_c.view(bsz, mask_channel, -1, h*w) # (b, m_c, 1, HW) mask_s = mask_s.view(bsz, mask_channel, -1, h*w) mask_attn = torch.matmul(mask_c.transpose(-2, -1), mask_s) # (b, m_c, HW, HW) mask_attn = torch.sum(mask_attn, dim=1, keepdim=True).clamp_(0, 1) # (b, 1, HW, HW) mask_sum = torch.sum(mask_attn, dim=-1, keepdim=True) mask_attn += (mask_sum == 0).float() mask_attn = mask_attn.masked_fill_(mask_attn == 0, float('-inf')).masked_fill_(mask_attn == 1, float(0.0)) query = self.q_proj(fea_c) # (b, HW, D) key = self.k_proj(fea_s) # (b, HW, D) query = query.view(bsz, h*w, self.num_heads, self.head_dim).transpose(1, 2) # (b, h, HW, D) key = key.view(bsz, h*w, self.num_heads, self.head_dim).transpose(1, 2) weights = torch.matmul(query, key.transpose(-1, -2)) # (b, h, HW, HW) weights = weights * self.scaling weights = weights + mask_attn.detach() weights = self.dropout(F.softmax(weights, dim=-1)) weights = weights * (1 - (mask_sum == 0).float().detach()) return weights class MultiheadAttention_value(nn.Module): def __init__(self, feature_dim, proj_dim, num_heads=1, bias=True): super(MultiheadAttention_value, self).__init__() self.feature_dim = feature_dim self.proj_dim = proj_dim self.num_heads = num_heads self.head_dim = proj_dim // num_heads assert self.head_dim * num_heads == self.proj_dim, "embed_dim must be divisible by num_heads" self.v_proj 
= nn.Linear(feature_dim, proj_dim, bias=bias) def forward(self, weights, fea): ''' weights: (b, h, HW. HW) fea: (b, d, H, W) ''' bsz, dim, h, w = fea.shape fea = fea.view(bsz, dim, h*w).transpose(1, 2) #(b, HW, D) value = self.v_proj(fea) value = value.view(bsz, h*w, self.num_heads, self.head_dim).transpose(1, 2) #(b, h, HW, D) out = torch.matmul(weights, value) out = out.transpose(1, 2).contiguous().view(bsz, h*w, self.proj_dim) # (b, HW, D) out = out.transpose(1, 2).view(bsz, self.proj_dim, h, w) #(b, d, H, W) return out class MultiheadAttention(nn.Module): def __init__(self, in_channels, proj_channels, value_channels, out_channels, num_heads=1, dropout=0.0, bias=True): super(MultiheadAttention, self).__init__() self.weight = MultiheadAttention_weight(in_channels, proj_channels, num_heads, dropout, bias) self.value = MultiheadAttention_value(value_channels, out_channels, num_heads, bias) def forward(self, fea_q, fea_k, fea_v, mask_q, mask_k): ''' fea: (b, d, h, w) mask: (b, c, h, w) ''' weights = self.weight(fea_q, fea_k, mask_q, mask_k) return self.value(weights, fea_v) class FeedForwardLayer(nn.Module): def __init__(self, feature_dim, ff_dim, dropout=0.0): super(FeedForwardLayer, self).__init__() self.main = nn.Sequential( nn.LeakyReLU(0.2, inplace=True), nn.Dropout(p=dropout, inplace=True), nn.Conv2d(feature_dim, ff_dim, kernel_size=1), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(ff_dim, feature_dim, kernel_size=1) ) def forward(self, x): return self.main(x) class Attention_apply(nn.Module): def __init__(self, feature_dim, normalize=True): super(Attention_apply, self).__init__() self.normalize = normalize if normalize: self.norm = nn.InstanceNorm2d(feature_dim, affine=False) self.actv = nn.LeakyReLU(0.2, inplace=True) self.conv = nn.Conv2d(feature_dim, feature_dim, kernel_size=3, stride=1, padding=1, bias=False) def forward(self, x, attn_out): if self.normalize: x = self.norm(x) x = x * (1 + attn_out) return self.conv(self.actv(x)) ================================================ FILE: models/modules/module_base.py ================================================ import math import torch import torch.nn as nn import torch.nn.functional as F class ResidualBlock(nn.Module): """Residual Block.""" def __init__(self, dim_in, dim_out): super(ResidualBlock, self).__init__() self.main = nn.Sequential( nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False) ) self.skip = nn.Identity() if dim_in == dim_out else nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=False) def forward(self, x): x = self.skip(x) + self.main(x) return x / math.sqrt(2) class ResidualBlock_IN(nn.Module): """Residual Block with InstanceNorm.""" def __init__(self, dim_in, dim_out, affine=False): super(ResidualBlock_IN, self).__init__() self.main = nn.Sequential( nn.InstanceNorm2d(dim_in, affine=affine), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False), nn.InstanceNorm2d(dim_out, affine=affine), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False), ) self.skip = nn.Identity() if dim_in == dim_out else nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=False) def forward(self, x): x = self.skip(x) + self.main(x) return x / math.sqrt(2) class ResidualBlock_Downsample(nn.Module): """Residual Block with InstanceNorm.""" def __init__(self, dim_in, dim_out, 
affine=False): super(ResidualBlock_Downsample, self).__init__() self.main = nn.Sequential( nn.InstanceNorm2d(dim_in, affine=affine), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, bias=False) ) if dim_in == dim_out: self.skip = nn.Identity() else: self.skip = nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1, bias=False) def forward(self, x): skip = F.interpolate(self.skip(x), scale_factor=0.5, mode='bilinear', align_corners=False, recompute_scale_factor=True) res = self.main(x) x = skip + res return x / math.sqrt(2) class Downsample(nn.Module): """Residual Block with InstanceNorm.""" def __init__(self, dim_in, dim_out, affine=False): super(Downsample, self).__init__() self.main = nn.Sequential( nn.InstanceNorm2d(dim_in, affine=affine), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, bias=False) ) def forward(self, x): return self.main(x) class ResidualBlock_Upsample(nn.Module): """Residual Block with InstanceNorm.""" def __init__(self, dim_in, dim_out, normalize=True, affine=False): super(ResidualBlock_Upsample, self).__init__() if normalize: self.main = nn.Sequential( nn.InstanceNorm2d(dim_in, affine=affine), nn.LeakyReLU(0.2, inplace=True), nn.ConvTranspose2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, bias=False) ) else: self.main = nn.Sequential( nn.LeakyReLU(0.2, inplace=True), nn.ConvTranspose2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, bias=False) ) if dim_in == dim_out: self.skip = nn.Identity() else: self.skip = nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1, bias=False) def forward(self, x): skip = F.interpolate(self.skip(x), scale_factor=2, mode='bilinear', align_corners=False) res = self.main(x) x = skip + res return x / math.sqrt(2) class Upsample(nn.Module): """Residual Block with InstanceNorm.""" def __init__(self, dim_in, dim_out, normalize=True, affine=False): super(Upsample, self).__init__() if normalize: self.main = nn.Sequential( nn.InstanceNorm2d(dim_in, affine=affine), nn.LeakyReLU(0.2, inplace=True), nn.ConvTranspose2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, bias=False) ) else: self.main = nn.Sequential( nn.LeakyReLU(0.2, inplace=True), nn.ConvTranspose2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, bias=False) ) def forward(self, x): return self.main(x) class PositionalEmbedding(nn.Module): def __init__(self, embedding_dim=136, feature_size=64, max_size=None, embedding_type='l2_norm'): super(PositionalEmbedding, self).__init__() self.embedding_dim = embedding_dim self.feature_size = feature_size self.max_size = max_size assert embedding_type in ['l2_norm', 'uniform', 'sin'] self.embedding_type = embedding_type @torch.no_grad() def forward(self, diff, mask): ''' diff: (b, d, h, w) mask: (b, 3, h, w) return: (b, d, h, w) ''' bsz, init_dim, init_size, _ = diff.shape assert self.embedding_dim >= init_dim diff = F.interpolate(diff, self.feature_size) # (b, d, h, w) mask = F.interpolate(mask, size=self.feature_size) mask = torch.sum(mask, dim=1, keepdim=True) # (b, 1, h, w) diff = diff * mask if self.embedding_type == 'l2_norm': norm = torch.norm(diff, dim=1, keepdim=True) norm = (norm == 0) + norm diff = diff / norm elif self.embedding_type == 'uniform': diff = diff / self.max_size elif self.embedding_type == 'sin': diff = torch.sin(diff * math.pi / (2 * self.max_size)) if self.embedding_dim > init_dim: zero_shape = (bsz, self.embedding_dim - init_dim, self.feature_size, self.feature_size) zero_padding = torch.zeros(zero_shape, 
                                        device=diff.device)
            diff = torch.cat((diff, zero_padding), dim=1)
        diff = diff.detach(); diff.requires_grad = False
        return diff


class MergeBlock(nn.Module):
    def __init__(self, merge_mode, feature_dim, normalize=True):
        super(MergeBlock, self).__init__()
        assert merge_mode in ['conv', 'add', 'affine']
        self.merge_mode = merge_mode
        if merge_mode == 'affine':
            self.norm = nn.LayerNorm(feature_dim, elementwise_affine=False) if normalize else nn.Identity()
        else:
            self.norm = nn.InstanceNorm2d(feature_dim, affine=False) if normalize else nn.Identity()
        self.norm_r = nn.InstanceNorm2d(feature_dim, affine=False) if normalize else nn.Identity()
        self.actv = nn.LeakyReLU(0.2, inplace=True)
        if merge_mode == 'conv':
            self.conv = nn.Conv2d(2 * feature_dim, feature_dim, kernel_size=3, stride=1, padding=1, bias=False)
        else:
            self.conv = nn.Conv2d(feature_dim, feature_dim, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, fea_s, fea_r):
        if self.merge_mode == 'conv':
            fea_s = self.norm(fea_s)
            fea_r = self.norm_r(fea_r)
            fea_s = torch.cat((fea_s, fea_r), dim=1)
        elif self.merge_mode == 'add':
            fea_s = self.norm(fea_s)
            fea_r = self.norm_r(fea_r)
            fea_s = (fea_s + fea_r) / math.sqrt(2)
        elif self.merge_mode == 'affine':
            fea_s = fea_s.permute(0, 2, 3, 1)
            fea_s = self.norm(fea_s)
            fea_s = fea_s.permute(0, 3, 1, 2)
            fea_s = fea_s * (1 + fea_r)
        return self.conv(self.actv(fea_s))


================================================
FILE: models/modules/pseudo_gt.py
================================================
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import functional

from models.modules.tps_transform import tps_sampler, tps_spatial_transform


def expand_area(mask: torch.Tensor, margin: int):
    '''
    mask: (C, H, W) or (N, C, H, W)
    '''
    kernel = np.zeros((margin * 2 + 1, margin * 2 + 1), dtype=np.uint8)
    kernel = cv2.circle(kernel, (margin, margin), margin, (255, 0, 0), -1)
    kernel = torch.FloatTensor((kernel > 0)).unsqueeze(0).unsqueeze(0).to(mask.device)
    ndim = mask.ndimension()
    if ndim == 3:
        mask = mask.unsqueeze(0)
    expanded_mask = torch.zeros_like(mask)
    for i in range(mask.shape[1]):
        expanded_mask[:, i:i+1, :, :] = F.conv2d(mask[:, i:i+1, :, :], kernel, padding=margin)
    if ndim == 3:
        expanded_mask = expanded_mask.squeeze(0)
    return (expanded_mask > 0).float()


def mask_blur(mask: torch.Tensor, blur_size=3, mode='smooth'):
    """Blur the edge of the mask so that the composed image has a smooth transition.

    Args:
        mask (torch.Tensor): [C, H, W]
        blur_size (int): size of blur kernel. Defaults to 3.
        mode (str): Defaults to 'smooth'.
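            'smooth' returns the raw blurred mask; 'valid' rescales the blurred values
            from [0.5, 1] to [0, 1] and multiplies by the original mask, so the soft
            transition stays inside the masked region.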
Returns: torch.Tensor: blurred mask """ #kernel = torch.ones((1, 1, blur_size * 2 + 1, blur_size * 2 + 1)).to(mask.device) kernel = np.zeros((blur_size * 2 + 1, blur_size * 2 + 1), dtype=np.uint8) kernel = cv2.circle(kernel, (blur_size, blur_size), blur_size, (255, 0, 0), -1) kernel = torch.FloatTensor((kernel > 0)).unsqueeze(0).unsqueeze(0).to(mask.device) kernel = kernel / torch.sum(kernel) ndim = mask.ndimension() if ndim == 3: mask = mask.unsqueeze(0) mask_blur = torch.zeros_like(mask) for i in range(mask.shape[1]): mask_blur[:,i:i+1,:,:] = F.conv2d(mask[:,i:i+1,:,:], kernel, padding=blur_size) if mode == 'valid': mask_blur = (mask_blur.clamp(0.5, 1) - 0.5) * 2 * mask if ndim == 3: mask_blur = mask_blur.squeeze(0) return mask_blur.clamp(0, 1) def mask_blend(mask, blend_alpha, mask_bound=None, blur_size=3, blend_mode='smooth'): if blur_size > 0: mask = mask_blur(mask, blur_size, blend_mode) mask = mask * blend_alpha if mask_bound is None: return mask else: return mask * mask_bound def tps_align(img_size, lms_r, lms_s, image_r, image_s=None, mask_r = None, mask_s=None, sample_mode='bilinear'): ''' image: (C, H, W), lms: (K, 2), mask:(1, H, W) ''' lms_s = torch.flip(lms_s, dims=[1]) / (img_size - 1) lms_r = (torch.flip(lms_r, dims=[1]) / (img_size - 1)).unsqueeze(0) image_r = image_r.unsqueeze(0) image_trans, _ = tps_spatial_transform(img_size, img_size, lms_s, image_r, lms_r, sample_mode) if mask_r is not None: mask_r_trans, _ = tps_spatial_transform(img_size, img_size, lms_s, mask_r.unsqueeze(0), lms_r, 'nearest') if image_s is not None: mask_compose = torch.ones((1, img_size, img_size), device=lms_r.device) if mask_s is not None: mask_compose *= mask_s if mask_r is not None: mask_compose *= mask_r_trans.squeeze(0) return image_s * (1 - mask_compose) + image_trans.squeeze(0) * mask_compose else: return image_trans.squeeze(0) def tps_blend(blend_alpha, img_size, lms_r, lms_s, image_r, image_s=None, mask_r = None, mask_s=None, mask_s_bound=None, blur_size=7, sample_mode='bilinear', blend_mode='smooth'): ''' image: (C, H, W), lms: (K, 2), mask:(1, H, W) ''' lms_s = torch.flip(lms_s, dims=[1]) / (img_size - 1) lms_r = (torch.flip(lms_r, dims=[1]) / (img_size - 1)).unsqueeze(0) image_r = image_r.unsqueeze(0) image_trans, _ = tps_spatial_transform(img_size, img_size, lms_s, image_r, lms_r, sample_mode) if mask_r is not None: mask_r_trans, _ = tps_spatial_transform(img_size, img_size, lms_s, mask_r.unsqueeze(0), lms_r, 'nearest') if image_s is not None: mask_compose = torch.ones((1, img_size, img_size), device=lms_r.device) if mask_s is not None: mask_compose *= mask_s if mask_r is not None: mask_compose *= mask_r_trans.squeeze(0) mask_compose = mask_blend(mask_compose, blend_alpha, mask_s_bound, blur_size, blend_mode) return image_s * (1 - mask_compose) + image_trans.squeeze(0) * mask_compose else: return image_trans.squeeze(0) def fine_align(img_size, lms_r, lms_s, image_r, image_s, mask_r, mask_s, margins, blend_alphas): ''' image: (C, H, W), lms: (K, 2) mask: (C, H, W), lip, face, left eye, right eye margins: dictionary, blend_alphas: dictionary ''' # skin align image_s = tps_blend(blend_alphas['skin'], img_size, lms_r[:60], lms_s[:60], image_r, image_s, mask_r[1:2], mask_s[1:2], mask_s[1:2], blur_size=8, blend_mode='valid') # lip align mask_s_lip = expand_area(mask_s[0:1], margins['lip']) mask_r_lip = expand_area(mask_r[0:1], margins['lip']) image_s = tps_blend(blend_alphas['lip'], img_size, lms_r[48:], lms_s[48:], image_r, image_s, mask_r_lip, mask_s_lip, mask_s[0:1], blur_size=3) # left 
eye align mask_s_eye = expand_area(mask_s[2:3], margins['eye']) mask_r_eye = expand_area(mask_r[2:3], margins['eye']) * mask_r[1:2] image_s = tps_blend(blend_alphas['eye'], img_size, torch.cat((lms_r[14:17], lms_r[22:27], lms_r[27:31], lms_r[42:48]), dim=0), torch.cat((lms_s[14:17], lms_s[22:27], lms_s[27:31], lms_s[42:48]), dim=0), image_r, image_s, mask_r_eye, mask_s_eye, mask_s[1:2], blur_size=5, sample_mode='nearest') # right eye align mask_s_eye = expand_area(mask_s[3:4], margins['eye']) mask_r_eye = expand_area(mask_r[3:4], margins['eye']) * mask_r[1:2] image_s = tps_blend(blend_alphas['eye'], img_size, torch.cat((lms_r[0:3], lms_r[17:22], lms_r[27:31], lms_r[36:42]), dim=0), torch.cat((lms_s[0:3], lms_s[17:22], lms_s[27:31], lms_s[36:42]), dim=0), image_r, image_s, mask_r_eye, mask_s_eye, mask_s[1:2], blur_size=5, sample_mode='nearest') return image_s if __name__ == "__main__": pass ================================================ FILE: models/modules/sow_attention.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F class WindowAttention(nn.Module): def __init__(self, window_size, in_channels, proj_channels, value_channels, out_channels, num_heads=1, dropout=0.0, bias=True, weighted_output=True): super(WindowAttention, self).__init__() assert window_size % 2 == 0 self.window_size = window_size self.weighted_output = weighted_output window_weight = self.generate_window_weight() self.register_buffer('window_weight', window_weight) self.num_heads = num_heads self.dropout = nn.Dropout(dropout) self.in_channels = in_channels self.proj_channels = proj_channels head_dim = proj_channels // num_heads assert head_dim * num_heads == self.proj_channels, "embed_dim must be divisible by num_heads" self.scaling = head_dim ** -0.5 self.q_proj = nn.Conv2d(in_channels, proj_channels, kernel_size=1, bias=bias) self.k_proj = nn.Conv2d(in_channels, proj_channels, kernel_size=1, bias=bias) self.value_channels = value_channels self.out_channels = out_channels assert out_channels // num_heads * num_heads == self.out_channels self.v_proj = nn.Conv2d(value_channels, out_channels, kernel_size=1, bias=bias) @torch.no_grad() def generate_window_weight(self): yc = torch.arange(self.window_size // 2).unsqueeze(1).repeat(1, self.window_size // 2) xc = torch.arange(self.window_size // 2).unsqueeze(0).repeat(self.window_size // 2, 1) window_weight = xc * yc / (self.window_size // 2 - 1) ** 2 window_weight = torch.cat((window_weight, torch.flip(window_weight, dims=[0])), dim=0) window_weight = torch.cat((window_weight, torch.flip(window_weight, dims=[1])), dim=1) return window_weight.view(-1) def make_window(self, x: torch.Tensor): """ input: (B, C, H, W) output: (B, h, H/S, W/S, S*S, C/h) """ bsz, dim, h, w = x.shape x = x.view(bsz, self.num_heads, dim // self.num_heads, h // self.window_size, self.window_size, w // self.window_size, self.window_size) x = x.transpose(4, 5).contiguous().view(bsz, self.num_heads, dim // self.num_heads, h // self.window_size, w // self.window_size, self.window_size**2) x = x.permute(0, 1, 3, 4, 5, 2) return x def demake_window(self, x: torch.Tensor): """ input: (B, h, H/S, W/S, S*S, C/h) output: (B, C, H, W) """ bsz, _, h_s, w_s, _, dim_h = x.shape x = x.permute(0, 1, 5, 2, 3, 4).contiguous() #print(x.shape) x = x.view(bsz, dim_h * self.num_heads, h_s, w_s, self.window_size, self.window_size) #print(x.shape) x = x.transpose(3, 4).contiguous().view(bsz, dim_h * self.num_heads, h_s * self.window_size, w_s * 
self.window_size) #print(x.shape) return x @torch.no_grad() def make_mask_window(self, mask: torch.Tensor): """ input: (B, C, H, W) output: (B, 1, H/S, W/S, S*S, C) """ bsz, mask_channel, h, w = mask.shape mask = mask.view(bsz, 1, mask_channel, h // self.window_size, self.window_size, w // self.window_size, self.window_size) mask = mask.transpose(4, 5).contiguous().view(bsz, 1, mask_channel, h // self.window_size, w // self.window_size, self.window_size**2) mask = mask.permute(0, 1, 3, 4, 5, 2) return mask def forward(self, fea_q, fea_k, fea_v, mask_q=None, mask_k=None): ''' fea: (b, d, h, w) mask: (b, c, h, w) ''' query = self.q_proj(fea_q) # (B, D, H, W) key = self.k_proj(fea_k) value = self.v_proj(fea_v) query = self.make_window(query) # (B, h, H/S, W/S, S*S, D/h) key = self.make_window(key) value = self.make_window(value) weights = torch.matmul(query, key.transpose(-1, -2)) # (B, h, H/S, W/S, S*S, S*S) weights = weights * self.scaling if mask_q is not None and mask_k is not None: mask_q = self.make_mask_window(mask_q) # (B, 1, H/S, W/S, S*S, C) mask_k = self.make_mask_window(mask_k) with torch.no_grad(): mask_attn = torch.matmul(mask_q, mask_k.transpose(-1, -2)) mask_sum = torch.sum(mask_attn, dim=-1, keepdim=True) mask_attn += (mask_sum == 0).float() mask_attn = mask_attn.masked_fill_(mask_attn == 0, float('-inf')).masked_fill_(mask_attn == 1, float(0.0)) weights += mask_attn weights = self.dropout(F.softmax(weights, dim=-1)) if mask_q is not None and mask_k is not None: weights = weights * (1 - (mask_sum == 0).float().detach()) out = torch.matmul(weights, value) # (B, h, H/S, W/S, S*S, D/h) if self.weighted_output: window_weight = self.window_weight.view(1, 1, 1, 1, self.window_size ** 2, 1) out = out * window_weight out = self.demake_window(out) #(B, D, H, W) return out class SowAttention(nn.Module): def __init__(self, window_size, in_channels, proj_channels, value_channels, out_channels, num_heads=1, dropout=0.0, bias=True): super(SowAttention, self).__init__() assert window_size % 2 == 0 self.window_size = window_size self.pad = nn.ZeroPad2d(window_size // 2) self.window_attention = WindowAttention(window_size, in_channels, proj_channels, value_channels, out_channels, num_heads, dropout, bias) def forward(self, fea_q, fea_k, fea_v, mask_q=None, mask_k=None): ''' fea: (b, d, h, w) mask: (b, c, h, w) ''' out_0 = self.window_attention(fea_q, fea_k, fea_v, mask_q, mask_k) fea_q = self.pad(fea_q) fea_k = self.pad(fea_k) fea_v = self.pad(fea_v) if mask_q is not None and mask_k is not None: mask_q = self.pad(mask_q) mask_k = self.pad(mask_k) else: mask_q = None; mask_k = None out_1 = self.window_attention(fea_q, fea_k, fea_v, mask_q, mask_k) out_1 = out_1[:, :, self.window_size//2:-self.window_size//2, self.window_size//2:-self.window_size//2] if mask_q is not None and mask_k is not None: out_2 = self.window_attention( fea_q[:, :, :, self.window_size//2:-self.window_size//2], fea_k[:, :, :, self.window_size//2:-self.window_size//2], fea_v[:, :, :, self.window_size//2:-self.window_size//2], mask_q[:, :, :, self.window_size//2:-self.window_size//2], mask_k[:, :, :, self.window_size//2:-self.window_size//2] ) else: out_2 = self.window_attention( fea_q[:, :, :, self.window_size//2:-self.window_size//2], fea_k[:, :, :, self.window_size//2:-self.window_size//2], fea_v[:, :, :, self.window_size//2:-self.window_size//2], ) out_2 = out_2[:, :, self.window_size//2:-self.window_size//2, :] if mask_q is not None and mask_k is not None: out_3 = self.window_attention( fea_q[:, :, 
self.window_size//2:-self.window_size//2, :], fea_k[:, :, self.window_size//2:-self.window_size//2, :], fea_v[:, :, self.window_size//2:-self.window_size//2, :], mask_q[:, :, self.window_size//2:-self.window_size//2, :], mask_k[:, :, self.window_size//2:-self.window_size//2, :] ) else: out_3 = self.window_attention( fea_q[:, :, self.window_size//2:-self.window_size//2, :], fea_k[:, :, self.window_size//2:-self.window_size//2, :], fea_v[:, :, self.window_size//2:-self.window_size//2, :], ) out_3 = out_3[:, :, :, self.window_size//2:-self.window_size//2] out = out_0 + out_1 + out_2 + out_3 return out class StridedwindowAttention(nn.Module): def __init__(self, stride, in_channels, proj_channels, value_channels, out_channels, num_heads=1, dropout=0.0, bias=True): super(StridedwindowAttention, self).__init__() self.stride = stride self.num_heads = num_heads self.dropout = nn.Dropout(dropout) self.in_channels = in_channels self.proj_channels = proj_channels head_dim = proj_channels // num_heads assert head_dim * num_heads == self.proj_channels, "embed_dim must be divisible by num_heads" self.scaling = head_dim ** -0.5 self.q_proj = nn.Conv2d(in_channels, proj_channels, kernel_size=1, bias=bias) self.k_proj = nn.Conv2d(in_channels, proj_channels, kernel_size=1, bias=bias) self.value_channels = value_channels self.out_channels = out_channels assert out_channels // num_heads * num_heads == self.out_channels self.v_proj = nn.Conv2d(value_channels, out_channels, kernel_size=1, bias=bias) def make_window(self, x: torch.Tensor): """ input: (B, C, H, W) output: (B, h, S(h), S(w), H/S * W/S, C/h) """ bsz, dim, h, w = x.shape assert h % self.stride == 0 and w % self.stride == 0 x = x.view(bsz, self.num_heads, dim // self.num_heads, h // self.stride, self.stride, w // self.stride, self.stride) # (B, h, C/h, H/S, S(h), W/S, S(w)) x = x.permute(0, 1, 4, 6, 3, 5, 2).contiguous() # (B, h, S(h), S(w), H/S, W/S, C/h) x = x.view(bsz, self.num_heads, self.stride, self.stride, h // self.stride * w // self.stride, dim // self.num_heads) return x def demake_window(self, x: torch.Tensor, h, w): """ input: (B, h, S(h), S(w), H/S * W/S, C/h) output: (B, C, H, W) """ bsz, _, _, _, _, dim_h = x.shape x = x.view(bsz, self.num_heads, self.stride, self.stride, h // self.stride, w // self.stride, dim_h) # (B, h, S(h), S(w), H/S, W/S, C/h) x = x.permute(0, 1, 6, 4, 2, 5, 3).contiguous() # (B, h, C/h, H/S, S(h), W/S, S(w)) x = x.view(bsz, dim_h * self.num_heads, h, w) return x @torch.no_grad() def make_mask_window(self, mask: torch.Tensor): """ input: (B, C, H, W) output: (B, 1, S(h), S(w), H/S * W/S, C) """ bsz, mask_channel, h, w = mask.shape assert h % self.stride == 0 and w % self.stride == 0 mask = mask.view(bsz, 1, mask_channel, h // self.stride, self.stride, w // self.stride, self.stride) mask = mask.permute(0, 1, 4, 6, 3, 5, 2).contiguous() mask = mask.view(bsz, 1, self.stride, self.stride, h // self.stride * w // self.stride, mask_channel) return mask def forward(self, fea_q, fea_k, fea_v, mask_q=None, mask_k=None): ''' fea: (b, d, h, w) mask: (b, c, h, w) ''' bsz, _, h, w = fea_q.shape query = self.q_proj(fea_q) # (B, D, H, W) key = self.k_proj(fea_k) value = self.v_proj(fea_v) query = self.make_window(query) # (B, h, S(h), S(w), H/S * W/S, C/h) key = self.make_window(key) value = self.make_window(value) weights = torch.matmul(query, key.transpose(-1, -2)) # (B, h, S(h), S(w), H/S * W/S, H/S * W/S) weights = weights * self.scaling if mask_q is not None and mask_k is not None: mask_q = self.make_mask_window(mask_q) # 
(B, 1, S(h), S(w), H/S * W/S, C) mask_k = self.make_mask_window(mask_k) with torch.no_grad(): mask_attn = torch.matmul(mask_q, mask_k.transpose(-1, -2)) mask_sum = torch.sum(mask_attn, dim=-1, keepdim=True) mask_attn += (mask_sum == 0).float() mask_attn = mask_attn.masked_fill_(mask_attn == 0, float('-inf')).masked_fill_(mask_attn == 1, float(0.0)) weights += mask_attn weights = self.dropout(F.softmax(weights, dim=-1)) if mask_q is not None and mask_k is not None: weights = weights * (1 - (mask_sum == 0).float().detach()) out = torch.matmul(weights, value) # (B, h, S(h), S(w), H/S * W/S, D/h) out = self.demake_window(out, h, w) #(B, D, H, W) return out ================================================ FILE: models/modules/spectral_norm.py ================================================ import torch from torch.nn import Parameter def l2normalize(v, eps=1e-12): return v / (v.norm() + eps) class SpectralNorm(object): def __init__(self): self.name = "weight" #print(self.name) self.power_iterations = 1 def compute_weight(self, module): u = getattr(module, self.name + "_u") v = getattr(module, self.name + "_v") w = getattr(module, self.name + "_bar") height = w.data.shape[0] for _ in range(self.power_iterations): v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data)) u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data)) # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data)) sigma = u.dot(w.view(height, -1).mv(v)) return w / sigma.expand_as(w) @staticmethod def apply(module): name = "weight" fn = SpectralNorm() try: u = getattr(module, name + "_u") v = getattr(module, name + "_v") w = getattr(module, name + "_bar") except AttributeError: w = getattr(module, name) height = w.data.shape[0] width = w.view(height, -1).data.shape[1] u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) w_bar = Parameter(w.data) #del module._parameters[name] module.register_parameter(name + "_u", u) module.register_parameter(name + "_v", v) module.register_parameter(name + "_bar", w_bar) # remove w from parameter list del module._parameters[name] setattr(module, name, fn.compute_weight(module)) # recompute weight before every forward() module.register_forward_pre_hook(fn) return fn def remove(self, module): weight = self.compute_weight(module) delattr(module, self.name) del module._parameters[self.name + '_u'] del module._parameters[self.name + '_v'] del module._parameters[self.name + '_bar'] module.register_parameter(self.name, Parameter(weight.data)) def __call__(self, module, inputs): setattr(module, self.name, self.compute_weight(module)) def spectral_norm(module): SpectralNorm.apply(module) return module def remove_spectral_norm(module): name = 'weight' for k, hook in module._forward_pre_hooks.items(): if isinstance(hook, SpectralNorm) and hook.name == name: hook.remove(module) del module._forward_pre_hooks[k] return module raise ValueError("spectral_norm of '{}' not found in {}" .format(name, module)) ================================================ FILE: models/modules/tps_transform.py ================================================ from __future__ import absolute_import import numpy as np import itertools import torch import torch.nn as nn import torch.nn.functional as F # TF32 is not enough, require FP32 # Disable automatic TF32 since Pytorch 1.7 torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False def grid_sample(input, grid, mode='bilinear', 
canvas=None): output = F.grid_sample(input, grid, mode=mode, align_corners=True) if canvas is None: return output else: input_mask = input.data.new(input.size()).fill_(1) output_mask = F.grid_sample(input_mask, grid, mode='nearest', align_corners=True) padded_output = output * output_mask + canvas * (1 - output_mask) return padded_output # phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2 def compute_partial_repr(input_points, control_points): N = input_points.size(0) M = control_points.size(0) pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2) # original implementation, very slow # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance pairwise_diff_square = pairwise_diff * pairwise_diff pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1] repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist) #repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist + 1e-8) # fix numerical error for 0 * log(0), substitute all nan with 0 mask = repr_matrix != repr_matrix repr_matrix.masked_fill_(mask, 0) return repr_matrix # compute \Delta_c^-1 def bulid_delta_inverse(target_control_points): ''' target_control_points: (N, 2) ''' N = target_control_points.shape[0] forward_kernel = torch.zeros(N + 3, N + 3).to(target_control_points.device) target_control_partial_repr = compute_partial_repr(target_control_points, target_control_points) forward_kernel[:N, :N].copy_(target_control_partial_repr) forward_kernel[:N, -3].fill_(1) forward_kernel[-3, :N].fill_(1) forward_kernel[:N, -2:].copy_(target_control_points) forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1)) # compute inverse matrix inverse_kernel = torch.inverse(forward_kernel) return inverse_kernel # create target coordinate matrix def build_target_coordinate_matrix(target_height, target_width, target_control_points): ''' target_control_points: (N, 2) ''' HW = target_height * target_width target_coordinate = list(itertools.product(range(target_height), range(target_width))) target_coordinate = torch.Tensor(target_coordinate).to(target_control_points.device) # HW x 2 Y, X = target_coordinate.split(1, dim = 1) Y = Y / (target_height - 1) X = X / (target_width - 1) target_coordinate = torch.cat([X, Y], dim = 1) # convert from (y, x) to (x, y) target_coordinate_partial_repr = compute_partial_repr(target_coordinate, target_control_points) target_coordinate_repr = torch.cat([ target_coordinate_partial_repr, torch.ones((HW, 1), device=target_control_points.device), target_coordinate], dim = 1) return target_coordinate_repr def tps_sampler(target_height, target_width, inverse_kernel, target_coordinate_repr, source, source_control_points, sample_mode='bilinear'): ''' inverse_kernel: \Delta_C^-1 target_coordinate_repr: \hat{p} source: (B, C, H, W) source_control_points: (B, N, 2) ''' batch_size = source.shape[0] Y = torch.cat([source_control_points, torch.zeros((batch_size, 3, 2), device=source.device)], dim=1) mapping_matrix = torch.matmul(inverse_kernel, Y) source_coordinate = torch.matmul(target_coordinate_repr, mapping_matrix) grid = source_coordinate.view(-1, target_height, target_width, 2) grid = torch.clamp(grid, 0, 1) # the source_control_points may be out of [0, 1]. 
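    # Notes added for clarity on the solve above:
    #   mapping_matrix    = Delta_C^-1 @ Y, where Y stacks the source control points with
    #                       three zero rows (the affine part of the TPS system);
    #   source_coordinate = \hat{p} @ mapping_matrix gives, for every target pixel, the
    #                       (x, y) location in the source image to sample from.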
# the input to grid_sample is normalized [-1, 1], but what we get is [0, 1] grid = 2.0 * grid - 1.0 output_maps = grid_sample(source, grid, mode=sample_mode, canvas=None) return output_maps, source_coordinate def tps_spatial_transform(target_height, target_width, target_control_points, source, source_control_points, sample_mode='bilinear'): ''' target_control_points: (N, 2) source: (B, C, H, W) source_control_points: (B, N, 2) ''' inverse_kernel = bulid_delta_inverse(target_control_points) target_coordinate_repr = build_target_coordinate_matrix(target_height, target_width, target_control_points) return tps_sampler(target_height, target_width, inverse_kernel, target_coordinate_repr, source, source_control_points, sample_mode) class TPSSpatialTransformer(nn.Module): def __init__(self, target_height, target_width, target_control_points): super(TPSSpatialTransformer, self).__init__() self.target_height, self.target_width = target_height, target_width self.num_control_points = target_control_points.shape[0] # create padded kernel matrix inverse_kernel = bulid_delta_inverse(target_control_points) # create target coordinate matrix target_coordinate_repr = build_target_coordinate_matrix(target_height, target_width, target_control_points) # register precomputed matrices self.register_buffer('inverse_kernel', inverse_kernel) #self.register_buffer('padding_matrix', torch.zeros(3, 2)) self.register_buffer('target_coordinate_repr', target_coordinate_repr) self.register_buffer('target_control_points', target_control_points) def forward(self, source, source_control_points): assert source_control_points.ndimension() == 3 assert source_control_points.size(1) == self.num_control_points assert source_control_points.size(2) == 2 return tps_sampler(self.target_height, self.target_width, self.inverse_kernel, self.target_coordinate_repr, source, source_control_points) ================================================ FILE: scripts/demo.py ================================================ import os import sys import argparse import numpy as np import cv2 import torch from PIL import Image sys.path.append('.') from training.config import get_config from training.inference import Inference from training.utils import create_logger, print_args def main(config, args): logger = create_logger(args.save_folder, args.name, 'info', console=True) print_args(args, logger) logger.info(config) inference = Inference(config, args, args.load_path) n_imgname = sorted(os.listdir(args.source_dir)) m_imgname = sorted(os.listdir(args.reference_dir)) for i, (imga_name, imgb_name) in enumerate(zip(n_imgname, m_imgname)): imgA = Image.open(os.path.join(args.source_dir, imga_name)).convert('RGB') imgB = Image.open(os.path.join(args.reference_dir, imgb_name)).convert('RGB') result = inference.transfer(imgA, imgB, postprocess=True) if result is None: continue imgA = np.array(imgA); imgB = np.array(imgB) h, w, _ = imgA.shape result = result.resize((h, w)); result = np.array(result) vis_image = np.hstack((imgA, imgB, result)) save_path = os.path.join(args.save_folder, f"result_{i}.png") Image.fromarray(vis_image.astype(np.uint8)).save(save_path) if __name__ == "__main__": parser = argparse.ArgumentParser("argument for training") parser.add_argument("--name", type=str, default='demo') parser.add_argument("--save_path", type=str, default='result', help="path to save model") parser.add_argument("--load_path", type=str, help="folder to load model", default='ckpts/sow_pyramid_a5_e3d2_remapped.pth') parser.add_argument("--source-dir", type=str, 
default="assets/images/non-makeup") parser.add_argument("--reference-dir", type=str, default="assets/images/makeup") parser.add_argument("--gpu", default='0', type=str, help="GPU id to use.") args = parser.parse_args() args.gpu = 'cuda:' + args.gpu args.device = torch.device(args.gpu) args.save_folder = os.path.join(args.save_path, args.name) if not os.path.exists(args.save_folder): os.makedirs(args.save_folder) config = get_config() main(config, args) ================================================ FILE: scripts/train.py ================================================ import os import sys import argparse import torch from torch.utils.data import DataLoader sys.path.append('.') from training.config import get_config from training.dataset import MakeupDataset from training.solver import Solver from training.utils import create_logger, print_args def main(config, args): logger = create_logger(args.save_folder, args.name, 'info', console=True) print_args(args, logger) logger.info(config) dataset = MakeupDataset(config) data_loader = DataLoader(dataset, batch_size=config.DATA.BATCH_SIZE, num_workers=config.DATA.NUM_WORKERS, shuffle=True) solver = Solver(config, args, logger) solver.train(data_loader) if __name__ == "__main__": parser = argparse.ArgumentParser("argument for training") parser.add_argument("--name", type=str, default='elegant') parser.add_argument("--save_path", type=str, default='results', help="path to save model") parser.add_argument("--load_folder", type=str, help="path to load model", default=None) parser.add_argument("--keepon", default=False, action="store_true", help='keep on training') parser.add_argument("--gpu", default='0', type=str, help="GPU id to use.") args = parser.parse_args() config = get_config() #args.gpu = 'cuda:' + args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu #args.device = torch.device(args.gpu) args.device = torch.device('cuda:0') args.save_folder = os.path.join(args.save_path, args.name) if not os.path.exists(args.save_folder): os.makedirs(args.save_folder) main(config, args) ================================================ FILE: training/__init__.py ================================================ ================================================ FILE: training/config.py ================================================ from fvcore.common.config import CfgNode """ This file defines default options of configurations. It will be further merged by yaml files and options from the command-line. Note that *any* hyper-parameters should be firstly defined here to enable yaml and command-line configuration. 
""" _C = CfgNode() # Logging and saving _C.LOG = CfgNode() _C.LOG.SAVE_FREQ = 10 _C.LOG.VIS_FREQ = 1 # Data settings _C.DATA = CfgNode() _C.DATA.PATH = './data/MT-Dataset' _C.DATA.NUM_WORKERS = 4 _C.DATA.BATCH_SIZE = 1 _C.DATA.IMG_SIZE = 256 # Training hyper-parameters _C.TRAINING = CfgNode() _C.TRAINING.G_LR = 2e-4 _C.TRAINING.D_LR = 2e-4 _C.TRAINING.BETA1 = 0.5 _C.TRAINING.BETA2 = 0.999 _C.TRAINING.NUM_EPOCHS = 50 _C.TRAINING.LR_DECAY_FACTOR = 5e-2 _C.TRAINING.DOUBLE_D = False # Loss weights _C.LOSS = CfgNode() _C.LOSS.LAMBDA_A = 10.0 _C.LOSS.LAMBDA_B = 10.0 _C.LOSS.LAMBDA_IDT = 0.5 _C.LOSS.LAMBDA_REC = 10 _C.LOSS.LAMBDA_MAKEUP = 100 _C.LOSS.LAMBDA_SKIN = 0.1 _C.LOSS.LAMBDA_EYE = 1.5 _C.LOSS.LAMBDA_LIP = 1 _C.LOSS.LAMBDA_MAKEUP_LIP = _C.LOSS.LAMBDA_MAKEUP * _C.LOSS.LAMBDA_LIP _C.LOSS.LAMBDA_MAKEUP_SKIN = _C.LOSS.LAMBDA_MAKEUP * _C.LOSS.LAMBDA_SKIN _C.LOSS.LAMBDA_MAKEUP_EYE = _C.LOSS.LAMBDA_MAKEUP * _C.LOSS.LAMBDA_EYE _C.LOSS.LAMBDA_VGG = 5e-3 # Model structure _C.MODEL = CfgNode() _C.MODEL.D_TYPE = 'SN' _C.MODEL.D_REPEAT_NUM = 3 _C.MODEL.D_CONV_DIM = 64 _C.MODEL.G_CONV_DIM = 64 _C.MODEL.NUM_HEAD = 1 _C.MODEL.DOUBLE_E = False _C.MODEL.USE_FF = False _C.MODEL.NUM_LAYER_E = 3 _C.MODEL.NUM_LAYER_D = 2 _C.MODEL.WINDOW_SIZE = 16 _C.MODEL.MERGE_MODE = 'conv' # Preprocessing _C.PREPROCESS = CfgNode() _C.PREPROCESS.UP_RATIO = 0.6 / 0.85 # delta_size / face_size _C.PREPROCESS.DOWN_RATIO = 0.2 / 0.85 # delta_size / face_size _C.PREPROCESS.WIDTH_RATIO = 0.2 / 0.85 # delta_size / face_size _C.PREPROCESS.LIP_CLASS = [7, 9] _C.PREPROCESS.FACE_CLASS = [1, 6] _C.PREPROCESS.EYEBROW_CLASS = [2, 3] _C.PREPROCESS.EYE_CLASS = [4, 5] _C.PREPROCESS.LANDMARK_POINTS = 68 # Pseudo ground truth _C.PGT = CfgNode() _C.PGT.EYE_MARGIN = 12 _C.PGT.LIP_MARGIN = 4 _C.PGT.ANNEALING = True _C.PGT.SKIN_ALPHA = 0.3 _C.PGT.SKIN_ALPHA_MILESTONES = (0, 12, 24, 50) _C.PGT.SKIN_ALPHA_VALUES = (0.2, 0.4, 0.3, 0.2) _C.PGT.EYE_ALPHA = 0.8 _C.PGT.EYE_ALPHA_MILESTONES = (0, 12, 24, 50) _C.PGT.EYE_ALPHA_VALUES = (0.6, 0.8, 0.6, 0.4) _C.PGT.LIP_ALPHA = 0.1 _C.PGT.LIP_ALPHA_MILESTONES = (0, 12, 24, 50) _C.PGT.LIP_ALPHA_VALUES = (0.05, 0.2, 0.1, 0.0) # Postprocessing _C.POSTPROCESS = CfgNode() _C.POSTPROCESS.WILL_DENOISE = False def get_config()->CfgNode: return _C ================================================ FILE: training/dataset.py ================================================ import os from PIL import Image import torch from torch.utils.data import Dataset, DataLoader from training.config import get_config from training.preprocess import PreProcess class MakeupDataset(Dataset): def __init__(self, config=None): super(MakeupDataset, self).__init__() if config is None: config = get_config() self.root = config.DATA.PATH with open(os.path.join(config.DATA.PATH, 'makeup.txt'), 'r') as f: self.makeup_names = [name.strip() for name in f.readlines()] with open(os.path.join(config.DATA.PATH, 'non-makeup.txt'), 'r') as f: self.non_makeup_names = [name.strip() for name in f.readlines()] self.preprocessor = PreProcess(config, need_parser=False) self.img_size = config.DATA.IMG_SIZE def load_from_file(self, img_name): image = Image.open(os.path.join(self.root, 'images', img_name)).convert('RGB') mask = self.preprocessor.load_mask(os.path.join(self.root, 'segs', img_name)) base_name = os.path.splitext(img_name)[0] lms = self.preprocessor.load_lms(os.path.join(self.root, 'lms', f'{base_name}.npy')) return self.preprocessor.process(image, mask, lms) def __len__(self): return max(len(self.makeup_names), len(self.non_makeup_names)) def 
__getitem__(self, index): idx_s = torch.randint(0, len(self.non_makeup_names), (1, )).item() idx_r = torch.randint(0, len(self.makeup_names), (1, )).item() name_s = self.non_makeup_names[idx_s] name_r = self.makeup_names[idx_r] source = self.load_from_file(name_s) reference = self.load_from_file(name_r) return source, reference def get_loader(config): dataset = MakeupDataset(config) dataloader = DataLoader(dataset=dataset, batch_size=config.DATA.BATCH_SIZE, num_workers=config.DATA.NUM_WORKERS) return dataloader if __name__ == "__main__": dataset = MakeupDataset() dataloader = DataLoader(dataset, batch_size=1, num_workers=16) for e in range(10): for i, (point_s, point_r) in enumerate(dataloader): pass ================================================ FILE: training/inference.py ================================================ from typing import List import numpy as np import cv2 from PIL import Image import torch import torch.nn.functional as F from torchvision.transforms import ToPILImage from training.solver import Solver from training.preprocess import PreProcess from models.modules.pseudo_gt import expand_area, mask_blend class InputSample: def __init__(self, inputs, apply_mask=None): self.inputs = inputs self.transfer_input = None self.attn_out_list = None self.apply_mask = apply_mask def clear(self): self.transfer_input = None self.attn_out_list = None class Inference: """ An inference wrapper for makeup transfer. It takes two image `source` and `reference` in, and transfers the makeup of reference to source. """ def __init__(self, config, args, model_path="G.pth"): self.device = args.device self.solver = Solver(config, args, inference=model_path) self.preprocess = PreProcess(config, args.device) self.denoise = config.POSTPROCESS.WILL_DENOISE self.img_size = config.DATA.IMG_SIZE # TODO: can be a hyper-parameter self.eyeblur = {'margin': 12, 'blur_size':7} def prepare_input(self, *data_inputs): """ data_inputs: List[image, mask, diff, lms] """ inputs = [] for i in range(len(data_inputs)): inputs.append(data_inputs[i].to(self.device).unsqueeze(0)) # prepare mask inputs[1] = torch.cat((inputs[1][:,0:1], inputs[1][:,1:].sum(dim=1, keepdim=True)), dim=1) return inputs def postprocess(self, source, crop_face, result): if crop_face is not None: source = source.crop( (crop_face.left(), crop_face.top(), crop_face.right(), crop_face.bottom())) source = np.array(source) result = np.array(result) height, width = source.shape[:2] small_source = cv2.resize(source, (self.img_size, self.img_size)) laplacian_diff = source.astype( np.float) - cv2.resize(small_source, (width, height)).astype(np.float) result = (cv2.resize(result, (width, height)) + laplacian_diff).round().clip(0, 255) result = result.astype(np.uint8) if self.denoise: result = cv2.fastNlMeansDenoisingColored(result) result = Image.fromarray(result).convert('RGB') return result def generate_source_sample(self, source_input): """ source_input: List[image, mask, diff, lms] """ source_input = self.prepare_input(*source_input) return InputSample(source_input) def generate_reference_sample(self, reference_input, apply_mask=None, source_mask=None, mask_area=None, saturation=1.0): """ all the operations on the mask, e.g., partial mask, saturation, should be finally defined in apply_mask """ if source_mask is not None and mask_area is not None: apply_mask = self.generate_partial_mask(source_mask, mask_area, saturation) apply_mask = apply_mask.unsqueeze(0).to(self.device) reference_input = self.prepare_input(*reference_input) if apply_mask is None: 
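            # no apply mask was given or constructed: fall back to an all-ones mask,
            # i.e. the reference makeup is applied over the entire image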
apply_mask = torch.ones(1, 1, self.img_size, self.img_size).to(self.device) return InputSample(reference_input, apply_mask) def generate_partial_mask(self, source_mask, mask_area='full', saturation=1.0): """ source_mask: (C, H, W), lip, face, left eye, right eye return: apply_mask: (1, H, W) """ assert mask_area in ['full', 'skin', 'lip', 'eye'] if mask_area == 'full': return torch.sum(source_mask[0:2], dim=0, keepdim=True) * saturation elif mask_area == 'lip': return source_mask[0:1] * saturation elif mask_area == 'skin': mask_l_eye = expand_area(source_mask[2:3], self.eyeblur['margin']) #* source_mask[1:2] mask_r_eye = expand_area(source_mask[3:4], self.eyeblur['margin']) #* source_mask[1:2] mask_eye = mask_l_eye + mask_r_eye #mask_eye = mask_blend(mask_eye, 1.0, source_mask[1:2], blur_size=self.eyeblur['blur_size']) mask_eye = mask_blend(mask_eye, 1.0, blur_size=self.eyeblur['blur_size']) return source_mask[1:2] * (1 - mask_eye) * saturation elif mask_area == 'eye': mask_l_eye = expand_area(source_mask[2:3], self.eyeblur['margin']) #* source_mask[1:2] mask_r_eye = expand_area(source_mask[3:4], self.eyeblur['margin']) #* source_mask[1:2] mask_eye = mask_l_eye + mask_r_eye #mask_eye = mask_blend(mask_eye, saturation, source_mask[1:2], blur_size=self.eyeblur['blur_size']) mask_eye = mask_blend(mask_eye, saturation, blur_size=self.eyeblur['blur_size']) return mask_eye @torch.no_grad() def interface_transfer(self, source_sample: InputSample, reference_samples: List[InputSample]): """ Input: a source sample and multiple reference samples Return: PIL.Image, the fused result """ # encode source if source_sample.transfer_input is None: source_sample.transfer_input = self.solver.G.get_transfer_input(*source_sample.inputs) # encode references for r_sample in reference_samples: if r_sample.transfer_input is None: r_sample.transfer_input = self.solver.G.get_transfer_input(*r_sample.inputs, True) # self attention if source_sample.attn_out_list is None: source_sample.attn_out_list = self.solver.G.get_transfer_output( *source_sample.transfer_input, *source_sample.transfer_input ) # full transfer for each reference for r_sample in reference_samples: if r_sample.attn_out_list is None: r_sample.attn_out_list = self.solver.G.get_transfer_output( *source_sample.transfer_input, *r_sample.transfer_input ) # fusion # if the apply_mask is changed without changing source and references, # only the following steps are required fused_attn_out_list = [] for i in range(len(source_sample.attn_out_list)): init_attn_out = torch.zeros_like(source_sample.attn_out_list[i], device=self.device) fused_attn_out_list.append(init_attn_out) apply_mask_sum = torch.zeros((1, 1, self.img_size, self.img_size), device=self.device) for r_sample in reference_samples: if r_sample.apply_mask is not None: apply_mask_sum += r_sample.apply_mask for i in range(len(source_sample.attn_out_list)): feature_size = r_sample.attn_out_list[i].shape[2] apply_mask = F.interpolate(r_sample.apply_mask, feature_size, mode='nearest') fused_attn_out_list[i] += apply_mask * r_sample.attn_out_list[i] # self as reference source_apply_mask = 1 - apply_mask_sum.clamp(0, 1) for i in range(len(source_sample.attn_out_list)): feature_size = source_sample.attn_out_list[i].shape[2] apply_mask = F.interpolate(source_apply_mask, feature_size, mode='nearest') fused_attn_out_list[i] += apply_mask * source_sample.attn_out_list[i] # decode result = self.solver.G.decode( source_sample.transfer_input[0], fused_attn_out_list ) result = self.solver.de_norm(result).squeeze(0) 
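        # de_norm is assumed to map the generator output from the normalized [-1, 1] range
        # back to [0, 1], so the (3, H, W) tensor can be converted to a PIL image below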
result = ToPILImage()(result.cpu()) return result def transfer(self, source: Image, reference: Image, postprocess=True): """ Args: source (Image): The image where makeup will be transfered to. reference (Image): Image containing targeted makeup. Return: Image: Transfered image. """ source_input, face, crop_face = self.preprocess(source) reference_input, _, _ = self.preprocess(reference) if not (source_input and reference_input): return None #source_sample = self.generate_source_sample(source_input) #reference_samples = [self.generate_reference_sample(reference_input)] #result = self.interface_transfer(source_sample, reference_samples) source_input = self.prepare_input(*source_input) reference_input = self.prepare_input(*reference_input) result = self.solver.test(*source_input, *reference_input) if not postprocess: return result else: return self.postprocess(source, crop_face, result) def joint_transfer(self, source: Image, reference_lip: Image, reference_skin: Image, reference_eye: Image, postprocess=True): source_input, face, crop_face = self.preprocess(source) lip_input, _, _ = self.preprocess(reference_lip) skin_input, _, _ = self.preprocess(reference_skin) eye_input, _, _ = self.preprocess(reference_eye) if not (source_input and lip_input and skin_input and eye_input): return None source_mask = source_input[1] source_sample = self.generate_source_sample(source_input) reference_samples = [ self.generate_reference_sample(lip_input, source_mask=source_mask, mask_area='lip'), self.generate_reference_sample(skin_input, source_mask=source_mask, mask_area='skin'), self.generate_reference_sample(eye_input, source_mask=source_mask, mask_area='eye') ] result = self.interface_transfer(source_sample, reference_samples) if not postprocess: return result else: return self.postprocess(source, crop_face, result) ================================================ FILE: training/preprocess.py ================================================ import os import sys import cv2 from PIL import Image import numpy as np import torch import torch.nn.functional as F from torchvision import transforms from torchvision.transforms import functional sys.path.append('.') import faceutils as futils from training.config import get_config class PreProcess: def __init__(self, config, need_parser=True, device='cpu'): self.img_size = config.DATA.IMG_SIZE self.device = device xs, ys = np.meshgrid( np.linspace( 0, self.img_size - 1, self.img_size ), np.linspace( 0, self.img_size - 1, self.img_size ) ) xs = xs[None].repeat(config.PREPROCESS.LANDMARK_POINTS, axis=0) ys = ys[None].repeat(config.PREPROCESS.LANDMARK_POINTS, axis=0) fix = np.concatenate([ys, xs], axis=0) self.fix = torch.Tensor(fix) #(136, h, w) if need_parser: self.face_parse = futils.mask.FaceParser(device=device) self.up_ratio = config.PREPROCESS.UP_RATIO self.down_ratio = config.PREPROCESS.DOWN_RATIO self.width_ratio = config.PREPROCESS.WIDTH_RATIO self.lip_class = config.PREPROCESS.LIP_CLASS self.face_class = config.PREPROCESS.FACE_CLASS self.eyebrow_class = config.PREPROCESS.EYEBROW_CLASS self.eye_class = config.PREPROCESS.EYE_CLASS self.transform = transforms.Compose([ transforms.Resize(config.DATA.IMG_SIZE), transforms.ToTensor(), transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])]) ############################## Mask Process ############################## # mask attribute: 0:background 1:face 2:left-eyebrow 3:right-eyebrow 4:left-eye 5: right-eye 6: nose # 7: upper-lip 8: teeth 9: under-lip 10:hair 11: left-ear 12: right-ear 13: neck def mask_process(self, 
================================================ FILE: training/preprocess.py ================================================
import os
import sys
import cv2
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms import functional

sys.path.append('.')

import faceutils as futils
from training.config import get_config


class PreProcess:

    def __init__(self, config, need_parser=True, device='cpu'):
        self.img_size = config.DATA.IMG_SIZE
        self.device = device

        xs, ys = np.meshgrid(
            np.linspace(0, self.img_size - 1, self.img_size),
            np.linspace(0, self.img_size - 1, self.img_size)
        )
        xs = xs[None].repeat(config.PREPROCESS.LANDMARK_POINTS, axis=0)
        ys = ys[None].repeat(config.PREPROCESS.LANDMARK_POINTS, axis=0)
        fix = np.concatenate([ys, xs], axis=0)
        self.fix = torch.Tensor(fix)  # (136, h, w)
        if need_parser:
            self.face_parse = futils.mask.FaceParser(device=device)

        self.up_ratio = config.PREPROCESS.UP_RATIO
        self.down_ratio = config.PREPROCESS.DOWN_RATIO
        self.width_ratio = config.PREPROCESS.WIDTH_RATIO
        self.lip_class = config.PREPROCESS.LIP_CLASS
        self.face_class = config.PREPROCESS.FACE_CLASS
        self.eyebrow_class = config.PREPROCESS.EYEBROW_CLASS
        self.eye_class = config.PREPROCESS.EYE_CLASS

        self.transform = transforms.Compose([
            transforms.Resize(config.DATA.IMG_SIZE),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    ############################## Mask Process ##############################
    # mask attribute: 0:background 1:face 2:left-eyebrow 3:right-eyebrow 4:left-eye 5:right-eye 6:nose
    # 7:upper-lip 8:teeth 9:under-lip 10:hair 11:left-ear 12:right-ear 13:neck
    def mask_process(self, mask: torch.Tensor):
        '''
        mask: (1, h, w)
        '''
        mask_lip = (mask == self.lip_class[0]).float() + (mask == self.lip_class[1]).float()
        mask_face = (mask == self.face_class[0]).float() + (mask == self.face_class[1]).float()

        #mask_eyebrow_left = (mask == self.eyebrow_class[0]).float()
        #mask_eyebrow_right = (mask == self.eyebrow_class[1]).float()
        mask_face += (mask == self.eyebrow_class[0]).float()
        mask_face += (mask == self.eyebrow_class[1]).float()

        mask_eye_left = (mask == self.eye_class[0]).float()
        mask_eye_right = (mask == self.eye_class[1]).float()

        #mask_list = [mask_lip, mask_face, mask_eyebrow_left, mask_eyebrow_right, mask_eye_left, mask_eye_right]
        mask_list = [mask_lip, mask_face, mask_eye_left, mask_eye_right]
        mask_aug = torch.cat(mask_list, 0)  # (C, H, W)
        return mask_aug

    def save_mask(self, mask: torch.Tensor, path):
        assert mask.shape[0] == 1
        mask = mask.squeeze(0).numpy().astype(np.uint8)
        mask = Image.fromarray(mask)
        mask.save(path)

    def load_mask(self, path):
        mask = np.array(Image.open(path).convert('L'))
        mask = torch.FloatTensor(mask).unsqueeze(0)
        mask = functional.resize(mask, self.img_size, transforms.InterpolationMode.NEAREST)
        return mask

    ############################## Landmarks Process ##############################
    def lms_process(self, image: Image):
        face = futils.dlib.detect(image)
        # face: rectangles, list of rectangles of the face region: [(left, top), (right, bottom)]
        if not face:
            return None
        face = face[0]
        lms = futils.dlib.landmarks(image, face) * self.img_size / image.width  # scale to fit self.img_size
        # lms: ndarray, the positions of the 68 key points, (68, 2)
        lms = torch.IntTensor(lms.round()).clamp_max_(self.img_size - 1)
        # distinguish upper and lower lips
        lms[61:64, 0] -= 1
        lms[65:68, 0] += 1
        for i in range(3):
            if torch.sum(torch.abs(lms[61 + i] - lms[67 - i])) == 0:
                lms[61 + i, 0] -= 1
                lms[67 - i, 0] += 1
        # double check
        '''for i in range(48, 67):
            for j in range(i + 1, 68):
                if torch.sum(torch.abs(lms[i] - lms[j])) == 0:
                    lms[i, 0] -= 1; lms[j, 0] += 1'''
        return lms

    def diff_process(self, lms: torch.Tensor, normalize=False):
        '''
        lms: (68, 2)
        '''
        lms = lms.transpose(1, 0).reshape(-1, 1, 1)  # (136, 1, 1)
        diff = self.fix - lms  # (136, h, w)

        if normalize:
            norm = torch.norm(diff, dim=0, keepdim=True).repeat(diff.shape[0], 1, 1)
            norm = torch.where(norm == 0, torch.tensor(1e10), norm)
            diff /= norm
        return diff

    def save_lms(self, lms: torch.Tensor, path):
        lms = lms.numpy()
        np.save(path, lms)

    def load_lms(self, path):
        lms = np.load(path)
        return torch.IntTensor(lms)

    ############################## Compose Process ##############################
    def preprocess(self, image: Image, is_crop=True):
        '''
        return: image: Image, (H, W), mask: tensor, (1, H, W)
        '''
        face = futils.dlib.detect(image)
        # face: rectangles, list of rectangles of the face region: [(left, top), (right, bottom)]
        if not face:
            return None, None, None
        face_on_image = face[0]
        if is_crop:
            image, face, crop_face = futils.dlib.crop(
                image, face_on_image, self.up_ratio, self.down_ratio, self.width_ratio)
        else:
            face = face[0]
            crop_face = None
        # image: Image, cropped face
        # face: the same as above
        # crop_face: rectangle, face region in the cropped face
        np_image = np.array(image)  # (h', w', 3)
        mask = self.face_parse.parse(cv2.resize(np_image, (512, 512))).cpu()
        # obtain face parsing result
        # mask: Tensor, (512, 512)
        mask = F.interpolate(
            mask.view(1, 1, 512, 512),
            (self.img_size, self.img_size),
            mode="nearest").squeeze(0).long()  # (1, H, W)

        lms = futils.dlib.landmarks(image, face) * self.img_size / image.width  # scale to fit self.img_size
        # lms: ndarray, the positions of the 68 key points, (68, 2)
        lms = torch.IntTensor(lms.round()).clamp_max_(self.img_size - 1)
        # distinguish upper and lower lips
        lms[61:64, 0] -= 1
        lms[65:68, 0] += 1
        for i in range(3):
            if torch.sum(torch.abs(lms[61 + i] - lms[67 - i])) == 0:
                lms[61 + i, 0] -= 1
                lms[67 - i, 0] += 1

        image = image.resize((self.img_size, self.img_size), Image.ANTIALIAS)
        return [image, mask, lms], face_on_image, crop_face

    def process(self, image: Image, mask: torch.Tensor, lms: torch.Tensor):
        image = self.transform(image)
        mask = self.mask_process(mask)
        diff = self.diff_process(lms)
        return [image, mask, diff, lms]

    def __call__(self, image: Image, is_crop=True):
        source, face_on_image, crop_face = self.preprocess(image, is_crop)
        if source is None:
            return None, None, None
        return self.process(*source), face_on_image, crop_face


if __name__ == "__main__":
    config = get_config()
    preprocessor = PreProcess(config, device='cuda:0')

    if not os.path.exists(os.path.join(config.DATA.PATH, 'lms')):
        os.makedirs(os.path.join(config.DATA.PATH, 'lms', 'makeup'))
        os.makedirs(os.path.join(config.DATA.PATH, 'lms', 'non-makeup'))

    # process makeup images
    print("Processing makeup images...")
    with open(os.path.join(config.DATA.PATH, 'makeup.txt'), 'r') as f:
        for line in f.readlines():
            img_name = line.strip()
            raw_image = Image.open(os.path.join(config.DATA.PATH, 'images', img_name)).convert('RGB')
            lms = preprocessor.lms_process(raw_image)
            if lms is not None:
                base_name = os.path.splitext(img_name)[0]
                preprocessor.save_lms(lms, os.path.join(config.DATA.PATH, 'lms', f'{base_name}.npy'))
    print("Done.")

    # process non-makeup images
    print("Processing non-makeup images...")
    with open(os.path.join(config.DATA.PATH, 'non-makeup.txt'), 'r') as f:
        for line in f.readlines():
            img_name = line.strip()
            raw_image = Image.open(os.path.join(config.DATA.PATH, 'images', img_name)).convert('RGB')
            lms = preprocessor.lms_process(raw_image)
            if lms is not None:
                base_name = os.path.splitext(img_name)[0]
                preprocessor.save_lms(lms, os.path.join(config.DATA.PATH, 'lms', f'{base_name}.npy'))
    print("Done.")
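
For intuition, the relative-position feature built by PreProcess.diff_process above can be reproduced on a toy grid. This sketch is illustrative only: the 4x4 grid and the 2 landmark coordinates are made up, whereas the real pipeline uses config.DATA.IMG_SIZE and 68 dlib landmarks (136 channels).

import numpy as np
import torch

img_size, n_lms = 4, 2   # toy sizes for illustration only
xs, ys = np.meshgrid(np.linspace(0, img_size - 1, img_size),
                     np.linspace(0, img_size - 1, img_size))
fix = torch.Tensor(np.concatenate([ys[None].repeat(n_lms, axis=0),
                                   xs[None].repeat(n_lms, axis=0)], axis=0))  # (2*n_lms, H, W)
lms = torch.IntTensor([[1, 2], [3, 0]])              # two hypothetical landmark coordinates
diff = fix - lms.transpose(1, 0).reshape(-1, 1, 1)   # (2*n_lms, H, W)
# channel k holds (ys - lms[k, 0]) and channel n_lms + k holds (xs - lms[k, 1]),
# i.e. every pixel's offset to landmark k, mirroring what diff_process computes.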
================================================ FILE: training/solver.py ================================================
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import ToPILImage
from torchvision.utils import save_image, make_grid
import torch.nn.init as init
from tqdm import tqdm

from models.modules.pseudo_gt import expand_area
from models.model import get_discriminator, get_generator, vgg16
from models.loss import GANLoss, MakeupLoss, ComposePGT, AnnealingComposePGT
from training.utils import plot_curves


class Solver():
    def __init__(self, config, args, logger=None, inference=False):
        self.G = get_generator(config)
        if inference:
            self.G.load_state_dict(torch.load(inference, map_location=args.device))
            self.G = self.G.to(args.device).eval()
            return
        self.double_d = config.TRAINING.DOUBLE_D
        self.D_A = get_discriminator(config)
        if self.double_d:
            self.D_B = get_discriminator(config)

        self.load_folder = args.load_folder
        self.save_folder = args.save_folder
        self.vis_folder = os.path.join(args.save_folder, 'visualization')
        if not os.path.exists(self.vis_folder):
            os.makedirs(self.vis_folder)
        self.vis_freq = config.LOG.VIS_FREQ
        self.save_freq = config.LOG.SAVE_FREQ

        # Data & PGT
        self.img_size = config.DATA.IMG_SIZE
        self.margins = {'eye': config.PGT.EYE_MARGIN, 'lip': config.PGT.LIP_MARGIN}
        self.pgt_annealing = config.PGT.ANNEALING
        if self.pgt_annealing:
            self.pgt_maker = AnnealingComposePGT(self.margins,
                config.PGT.SKIN_ALPHA_MILESTONES, config.PGT.SKIN_ALPHA_VALUES,
                config.PGT.EYE_ALPHA_MILESTONES, config.PGT.EYE_ALPHA_VALUES,
                config.PGT.LIP_ALPHA_MILESTONES, config.PGT.LIP_ALPHA_VALUES
            )
        else:
            self.pgt_maker = ComposePGT(self.margins,
                config.PGT.SKIN_ALPHA,
                config.PGT.EYE_ALPHA,
                config.PGT.LIP_ALPHA
            )
        self.pgt_maker.eval()

        # Hyper-param
        self.num_epochs = config.TRAINING.NUM_EPOCHS
        self.g_lr = config.TRAINING.G_LR
        self.d_lr = config.TRAINING.D_LR
        self.beta1 = config.TRAINING.BETA1
        self.beta2 = config.TRAINING.BETA2
        self.lr_decay_factor = config.TRAINING.LR_DECAY_FACTOR

        # Loss param
        self.lambda_idt = config.LOSS.LAMBDA_IDT
        self.lambda_A = config.LOSS.LAMBDA_A
        self.lambda_B = config.LOSS.LAMBDA_B
        self.lambda_lip = config.LOSS.LAMBDA_MAKEUP_LIP
        self.lambda_skin = config.LOSS.LAMBDA_MAKEUP_SKIN
        self.lambda_eye = config.LOSS.LAMBDA_MAKEUP_EYE
        self.lambda_vgg = config.LOSS.LAMBDA_VGG

        self.device = args.device
        self.keepon = args.keepon
        self.logger = logger
        self.loss_logger = {
            'D-A-loss_real': [],
            'D-A-loss_fake': [],
            'D-B-loss_real': [],
            'D-B-loss_fake': [],
            'G-A-loss-adv': [],
            'G-B-loss-adv': [],
            'G-loss-idt': [],
            'G-loss-img-rec': [],
            'G-loss-vgg-rec': [],
            'G-loss-rec': [],
            'G-loss-skin-pgt': [],
            'G-loss-eye-pgt': [],
            'G-loss-lip-pgt': [],
            'G-loss-pgt': [],
            'G-loss': [],
            'D-A-loss': [],
            'D-B-loss': []
        }

        self.build_model()
        super(Solver, self).__init__()

    def print_network(self, model, name):
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        if self.logger is not None:
            self.logger.info('{:s}, the number of parameters: {:d}'.format(name, num_params))
        else:
            print('{:s}, the number of parameters: {:d}'.format(name, num_params))

    # For generator
    def weights_init_xavier(self, m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            init.xavier_normal_(m.weight.data, gain=1.0)
        elif classname.find('Linear') != -1:
            init.xavier_normal_(m.weight.data, gain=1.0)

    def build_model(self):
        self.G.apply(self.weights_init_xavier)
        self.D_A.apply(self.weights_init_xavier)
        if self.double_d:
            self.D_B.apply(self.weights_init_xavier)
        if self.keepon:
            self.load_checkpoint()

        self.criterionL1 = torch.nn.L1Loss()
        self.criterionL2 = torch.nn.MSELoss()
        self.criterionGAN = GANLoss(gan_mode='lsgan')
        self.criterionPGT = MakeupLoss()
        self.vgg = vgg16(pretrained=True)

        # Optimizers
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_A_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D_A.parameters()), self.d_lr, [self.beta1, self.beta2])
        if self.double_d:
            self.d_B_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D_B.parameters()), self.d_lr, [self.beta1, self.beta2])
        self.g_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.g_optimizer, T_max=self.num_epochs, eta_min=self.g_lr * self.lr_decay_factor)
        self.d_A_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.d_A_optimizer, T_max=self.num_epochs, eta_min=self.d_lr * self.lr_decay_factor)
        if self.double_d:
            self.d_B_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.d_B_optimizer, T_max=self.num_epochs, eta_min=self.d_lr * self.lr_decay_factor)

        # Print networks
        self.print_network(self.G, 'G')
        self.print_network(self.D_A, 'D_A')
        if self.double_d:
            self.print_network(self.D_B, 'D_B')

        self.G.to(self.device)
        self.vgg.to(self.device)
        self.D_A.to(self.device)
        if self.double_d:
            self.D_B.to(self.device)

    def train(self, data_loader):
        self.len_dataset = len(data_loader)

        for self.epoch in range(1, self.num_epochs + 1):
            self.start_time = time.time()
            loss_tmp = self.get_loss_tmp()
            self.G.train()
            self.D_A.train()
            if self.double_d:
                self.D_B.train()
            losses_G = []
            losses_D_A = []
            losses_D_B = []

            with tqdm(data_loader, desc="training") as pbar:
                for step, (source, reference) in enumerate(pbar):
                    # image, mask, diff, lms
                    image_s, image_r = source[0].to(self.device), reference[0].to(self.device)  # (b, c, h, w)
                    mask_s_full, mask_r_full = source[1].to(self.device), reference[1].to(self.device)  # (b, c', h, w)
                    diff_s, diff_r = source[2].to(self.device), reference[2].to(self.device)  # (b, 136, h, w)
                    lms_s, lms_r = source[3].to(self.device), reference[3].to(self.device)  # (b, K, 2)

                    # process input mask
                    mask_s = torch.cat((mask_s_full[:, 0:1], mask_s_full[:, 1:].sum(dim=1, keepdim=True)), dim=1)
                    mask_r = torch.cat((mask_r_full[:, 0:1], mask_r_full[:, 1:].sum(dim=1, keepdim=True)), dim=1)
                    #mask_s = mask_s_full[:,:2]; mask_r = mask_r_full[:,:2]

                    # ================= Generate ================== #
                    fake_A = self.G(image_s, image_r, mask_s, mask_r, diff_s, diff_r, lms_s, lms_r)
                    fake_B = self.G(image_r, image_s, mask_r, mask_s, diff_r, diff_s, lms_r, lms_s)

                    # generate pseudo ground truth
                    pgt_A = self.pgt_maker(image_s, image_r, mask_s_full, mask_r_full, lms_s, lms_r)
                    pgt_B = self.pgt_maker(image_r, image_s, mask_r_full, mask_s_full, lms_r, lms_s)

                    # ================== Train D ================== #
                    # training D_A, D_A aims to distinguish class B
                    # Real
                    out = self.D_A(image_r)
                    d_loss_real = self.criterionGAN(out, True)
                    # Fake
                    out = self.D_A(fake_A.detach())
                    d_loss_fake = self.criterionGAN(out, False)
                    # Backward + Optimize
                    d_loss = (d_loss_real + d_loss_fake) * 0.5
                    self.d_A_optimizer.zero_grad()
                    d_loss.backward()
                    self.d_A_optimizer.step()

                    # Logging
                    loss_tmp['D-A-loss_real'] += d_loss_real.item()
                    loss_tmp['D-A-loss_fake'] += d_loss_fake.item()
                    losses_D_A.append(d_loss.item())

                    # training D_B, D_B aims to distinguish class A
                    # Real
                    if self.double_d:
                        out = self.D_B(image_s)
                    else:
                        out = self.D_A(image_s)
                    d_loss_real = self.criterionGAN(out, True)
                    # Fake
                    if self.double_d:
                        out = self.D_B(fake_B.detach())
                    else:
                        out = self.D_A(fake_B.detach())
                    d_loss_fake = self.criterionGAN(out, False)
                    # Backward + Optimize
                    d_loss = (d_loss_real + d_loss_fake) * 0.5
                    if self.double_d:
                        self.d_B_optimizer.zero_grad()
                        d_loss.backward()
                        self.d_B_optimizer.step()
                    else:
                        self.d_A_optimizer.zero_grad()
                        d_loss.backward()
                        self.d_A_optimizer.step()

                    # Logging
                    loss_tmp['D-B-loss_real'] += d_loss_real.item()
                    loss_tmp['D-B-loss_fake'] += d_loss_fake.item()
                    losses_D_B.append(d_loss.item())

                    # ================== Train G ================== #
                    # G should be identity if ref_B or org_A is fed
                    idt_A = self.G(image_s, image_s, mask_s, mask_s, diff_s, diff_s, lms_s, lms_s)
                    idt_B = self.G(image_r, image_r, mask_r, mask_r, diff_r, diff_r, lms_r, lms_r)
                    loss_idt_A = self.criterionL1(idt_A, image_s) * self.lambda_A * self.lambda_idt
                    loss_idt_B = self.criterionL1(idt_B, image_r) * self.lambda_B * self.lambda_idt
                    # loss_idt
                    loss_idt = (loss_idt_A + loss_idt_B) * 0.5

                    # GAN loss D_A(G_A(A))
                    pred_fake = self.D_A(fake_A)
                    g_A_loss_adv = self.criterionGAN(pred_fake, True)

                    # GAN loss D_B(G_B(B))
                    if self.double_d:
                        pred_fake = self.D_B(fake_B)
                    else:
                        pred_fake = self.D_A(fake_B)
                    g_B_loss_adv = self.criterionGAN(pred_fake, True)

                    # Makeup loss
                    g_A_loss_pgt = 0
                    g_B_loss_pgt = 0

                    g_A_lip_loss_pgt = self.criterionPGT(fake_A, pgt_A, mask_s_full[:, 0:1]) * self.lambda_lip
                    g_B_lip_loss_pgt = self.criterionPGT(fake_B, pgt_B, mask_r_full[:, 0:1]) * self.lambda_lip
                    g_A_loss_pgt += g_A_lip_loss_pgt
                    g_B_loss_pgt += g_B_lip_loss_pgt
                    mask_s_eye = expand_area(mask_s_full[:, 2:4].sum(dim=1, keepdim=True), self.margins['eye'])
                    mask_r_eye = expand_area(mask_r_full[:, 2:4].sum(dim=1, keepdim=True), self.margins['eye'])
                    mask_s_eye = mask_s_eye * mask_s_full[:, 1:2]
                    mask_r_eye = mask_r_eye * mask_r_full[:, 1:2]
                    g_A_eye_loss_pgt = self.criterionPGT(fake_A, pgt_A, mask_s_eye) * self.lambda_eye
                    g_B_eye_loss_pgt = self.criterionPGT(fake_B, pgt_B, mask_r_eye) * self.lambda_eye
                    g_A_loss_pgt += g_A_eye_loss_pgt
                    g_B_loss_pgt += g_B_eye_loss_pgt

                    mask_s_skin = mask_s_full[:, 1:2] * (1 - mask_s_eye)
                    mask_r_skin = mask_r_full[:, 1:2] * (1 - mask_r_eye)
                    g_A_skin_loss_pgt = self.criterionPGT(fake_A, pgt_A, mask_s_skin) * self.lambda_skin
                    g_B_skin_loss_pgt = self.criterionPGT(fake_B, pgt_B, mask_r_skin) * self.lambda_skin
                    g_A_loss_pgt += g_A_skin_loss_pgt
                    g_B_loss_pgt += g_B_skin_loss_pgt

                    # cycle loss
                    rec_A = self.G(fake_A, image_s, mask_s, mask_s, diff_s, diff_s, lms_s, lms_s)
                    rec_B = self.G(fake_B, image_r, mask_r, mask_r, diff_r, diff_r, lms_r, lms_r)
                    # cycle loss v2
                    # rec_A = self.G(fake_A, fake_B, mask_s, mask_r, diff_s, diff_r, lms_s, lms_r)
                    # rec_B = self.G(fake_B, fake_A, mask_r, mask_s, diff_r, diff_s, lms_r, lms_s)
                    g_loss_rec_A = self.criterionL1(rec_A, image_s) * self.lambda_A
                    g_loss_rec_B = self.criterionL1(rec_B, image_r) * self.lambda_B

                    # vgg loss
                    vgg_s = self.vgg(image_s).detach()
                    vgg_fake_A = self.vgg(fake_A)
                    g_loss_A_vgg = self.criterionL2(vgg_fake_A, vgg_s) * self.lambda_A * self.lambda_vgg
                    vgg_r = self.vgg(image_r).detach()
                    vgg_fake_B = self.vgg(fake_B)
                    g_loss_B_vgg = self.criterionL2(vgg_fake_B, vgg_r) * self.lambda_B * self.lambda_vgg

                    loss_rec = (g_loss_rec_A + g_loss_rec_B + g_loss_A_vgg + g_loss_B_vgg) * 0.5

                    # Combined loss
                    g_loss = g_A_loss_adv + g_B_loss_adv + loss_rec + loss_idt + g_A_loss_pgt + g_B_loss_pgt

                    self.g_optimizer.zero_grad()
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Logging
                    loss_tmp['G-A-loss-adv'] += g_A_loss_adv.item()
                    loss_tmp['G-B-loss-adv'] += g_B_loss_adv.item()
                    loss_tmp['G-loss-idt'] += loss_idt.item()
                    loss_tmp['G-loss-img-rec'] += (g_loss_rec_A + g_loss_rec_B).item() * 0.5
                    loss_tmp['G-loss-vgg-rec'] += (g_loss_A_vgg + g_loss_B_vgg).item() * 0.5
                    loss_tmp['G-loss-rec'] += loss_rec.item()
                    loss_tmp['G-loss-skin-pgt'] += (g_A_skin_loss_pgt + g_B_skin_loss_pgt).item()
                    loss_tmp['G-loss-eye-pgt'] += (g_A_eye_loss_pgt + g_B_eye_loss_pgt).item()
                    loss_tmp['G-loss-lip-pgt'] += (g_A_lip_loss_pgt + g_B_lip_loss_pgt).item()
                    loss_tmp['G-loss-pgt'] += (g_A_loss_pgt + g_B_loss_pgt).item()
                    losses_G.append(g_loss.item())

                    pbar.set_description("Epoch: %d, Step: %d, Loss_G: %0.4f, Loss_A: %0.4f, Loss_B: %0.4f" % \
                        (self.epoch, step + 1, np.mean(losses_G), np.mean(losses_D_A), np.mean(losses_D_B)))

            self.end_time = time.time()
            for k, v in loss_tmp.items():
                loss_tmp[k] = v / self.len_dataset
            loss_tmp['G-loss'] = np.mean(losses_G)
            loss_tmp['D-A-loss'] = np.mean(losses_D_A)
            loss_tmp['D-B-loss'] = np.mean(losses_D_B)
            self.log_loss(loss_tmp)
            self.plot_loss()

            # Decay learning rate
            self.g_scheduler.step()
            self.d_A_scheduler.step()
            if self.double_d:
                self.d_B_scheduler.step()

            if self.pgt_annealing:
                self.pgt_maker.step()

            # save the images
            if (self.epoch) % self.vis_freq == 0:
                self.vis_train([image_s.detach().cpu(),
                                image_r.detach().cpu(),
                                fake_A.detach().cpu(),
                                pgt_A.detach().cpu()])
                                # rec_A.detach().cpu()])

            # Save model checkpoints
            if (self.epoch) % self.save_freq == 0:
                self.save_models()

    def get_loss_tmp(self):
        loss_tmp = {
            'D-A-loss_real': 0.0,
            'D-A-loss_fake': 0.0,
            'D-B-loss_real': 0.0,
            'D-B-loss_fake': 0.0,
            'G-A-loss-adv': 0.0,
            'G-B-loss-adv': 0.0,
            'G-loss-idt': 0.0,
            'G-loss-img-rec': 0.0,
            'G-loss-vgg-rec': 0.0,
            'G-loss-rec': 0.0,
            'G-loss-skin-pgt': 0.0,
            'G-loss-eye-pgt': 0.0,
            'G-loss-lip-pgt': 0.0,
            'G-loss-pgt': 0.0,
        }
        return loss_tmp

    def log_loss(self, loss_tmp):
        if self.logger is not None:
            self.logger.info('\n' + '=' * 40 + '\nEpoch {:d}, time {:.2f} s'
                             .format(self.epoch, self.end_time - self.start_time))
        else:
            print('\n' + '=' * 40 + '\nEpoch {:d}, time {:.2f} s'
                  .format(self.epoch, self.end_time - self.start_time))
        for k, v in loss_tmp.items():
            self.loss_logger[k].append(v)
            if self.logger is not None:
                self.logger.info('{:s}\t{:.6f}'.format(k, v))
            else:
                print('{:s}\t{:.6f}'.format(k, v))
        if self.logger is not None:
            self.logger.info('=' * 40)
        else:
            print('=' * 40)

    def plot_loss(self):
        G_losses = []
        G_names = []
        D_A_losses = []
        D_A_names = []
        D_B_losses = []
        D_B_names = []
        D_P_losses = []
        D_P_names = []
        for k, v in self.loss_logger.items():
            if 'G' in k:
                G_names.append(k)
                G_losses.append(v)
            elif 'D-A' in k:
                D_A_names.append(k)
                D_A_losses.append(v)
            elif 'D-B' in k:
                D_B_names.append(k)
                D_B_losses.append(v)
            elif 'D-P' in k:
                D_P_names.append(k)
                D_P_losses.append(v)
        plot_curves(self.save_folder, 'G_loss', G_losses, G_names, ylabel='Loss')
        plot_curves(self.save_folder, 'D-A_loss', D_A_losses, D_A_names, ylabel='Loss')
        plot_curves(self.save_folder, 'D-B_loss', D_B_losses, D_B_names, ylabel='Loss')

    def load_checkpoint(self):
        G_path = os.path.join(self.load_folder, 'G.pth')
        if os.path.exists(G_path):
            self.G.load_state_dict(torch.load(G_path, map_location=self.device))
            print('loaded trained generator {}..!'.format(G_path))
        D_A_path = os.path.join(self.load_folder, 'D_A.pth')
        if os.path.exists(D_A_path):
            self.D_A.load_state_dict(torch.load(D_A_path, map_location=self.device))
            print('loaded trained discriminator A {}..!'.format(D_A_path))
        if self.double_d:
            D_B_path = os.path.join(self.load_folder, 'D_B.pth')
            if os.path.exists(D_B_path):
                self.D_B.load_state_dict(torch.load(D_B_path, map_location=self.device))
                print('loaded trained discriminator B {}..!'.format(D_B_path))

    def save_models(self):
        save_dir = os.path.join(self.save_folder, 'epoch_{:d}'.format(self.epoch))
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        torch.save(self.G.state_dict(), os.path.join(save_dir, 'G.pth'))
        torch.save(self.D_A.state_dict(), os.path.join(save_dir, 'D_A.pth'))
        if self.double_d:
            torch.save(self.D_B.state_dict(), os.path.join(save_dir, 'D_B.pth'))

    def de_norm(self, x):
        out = (x + 1) / 2
        return out.clamp(0, 1)

    def vis_train(self, img_train_batch):
        # saving training results
        img_train_batch = torch.cat(img_train_batch, dim=3)
        save_path = os.path.join(self.vis_folder, 'epoch_{:d}_fake.png'.format(self.epoch))
        vis_image = make_grid(self.de_norm(img_train_batch), 1)
        save_image(vis_image, save_path)  #, normalize=True)

    def generate(self, image_A, image_B, mask_A=None, mask_B=None, diff_A=None,
                 diff_B=None, lms_A=None, lms_B=None):
        """image_A is content, image_B is style"""
        with torch.no_grad():
            res = self.G(image_A, image_B, mask_A, mask_B, diff_A, diff_B, lms_A, lms_B)
        return res

    def test(self, image_A, mask_A, diff_A, lms_A, image_B, mask_B, diff_B, lms_B):
        with torch.no_grad():
            fake_A = self.generate(image_A, image_B, mask_A, mask_B, diff_A, diff_B, lms_A, lms_B)
            fake_A = self.de_norm(fake_A)
            fake_A = fake_A.squeeze(0)
            return ToPILImage()(fake_A.cpu())
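
Solver also serves as a lightweight inference wrapper: when the `inference` argument is a checkpoint path, `__init__` loads the generator, moves it to `args.device`, and returns before any training state is built. A minimal sketch under that assumption (the checkpoint path and device are illustrative; the repository's scripts/ directory holds the actual entry points):

from argparse import Namespace

from training.config import get_config
from training.solver import Solver

config = get_config()
args = Namespace(device='cuda:0')                       # only .device is used in inference mode
solver = Solver(config, args, inference='ckpts/G.pth')  # hypothetical checkpoint path

# solver.test(...) takes the image, mask, diff, lms tensors for the source followed by the
# same four for the reference (as in training, batched (b, c, h, w) tensors on args.device)
# and returns the transferred result as a PIL image.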
================================================ FILE: training/utils.py ================================================
import os
import logging

import numpy as np
import matplotlib.pyplot as plt


def create_logger(save_path='', file_type='', level='debug', console=True):
    if level == 'debug':
        _level = logging.DEBUG
    elif level == 'info':
        _level = logging.INFO
    else:
        _level = logging.INFO  # fall back to INFO for unrecognized level strings

    logger = logging.getLogger()
    logger.setLevel(_level)

    if console:
        cs = logging.StreamHandler()
        cs.setLevel(_level)
        logger.addHandler(cs)

    if save_path != '':
        file_name = os.path.join(save_path, file_type + '_log.txt')
        fh = logging.FileHandler(file_name, mode='w')
        fh.setLevel(_level)
        logger.addHandler(fh)

    return logger


def print_args(args, logger=None):
    for k, v in vars(args).items():
        if logger is not None:
            logger.info('{:<16} : {}'.format(k, v))
        else:
            print('{:<16} : {}'.format(k, v))


def plot_single_curve(path, name, point, freq=1, xlabel='Epoch', ylabel=None):
    x = (np.arange(len(point)) + 1) * freq
    plt.plot(x, point, color='purple')
    plt.xlabel(xlabel)
    if ylabel is None:
        ylabel = name
    plt.ylabel(ylabel)
    plt.savefig(os.path.join(path, name + '.png'))
    plt.close()


def plot_curves(path, name, point_list, curve_names=None, freq=1, xlabel='Epoch', ylabel=None):
    if curve_names is None:
        curve_names = [''] * len(point_list)
    else:
        assert len(point_list) == len(curve_names)
    x = (np.arange(len(point_list[0])) + 1) * freq
    if len(point_list) <= 10:
        cmap = plt.get_cmap('tab10')
    else:
        cmap = plt.get_cmap('tab20')
    for i, (point, curve_name) in enumerate(zip(point_list, curve_names)):
        assert len(point) == len(x)
        plt.plot(x, point, color=cmap(i), label=curve_name)
    plt.xlabel(xlabel)
    if ylabel is not None:
        plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(os.path.join(path, name + '.png'))
    plt.close()
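
A brief usage sketch of the helpers above (the output folder and loss values are made up): create_logger attaches console and file handlers to the root logger, and plot_curves writes a single PNG containing one curve per list.

import os

from training.utils import create_logger, plot_curves

save_dir = 'results/demo_run'            # hypothetical output folder
os.makedirs(save_dir, exist_ok=True)

logger = create_logger(save_path=save_dir, file_type='train', level='info')
logger.info('logging to console and to %s', os.path.join(save_dir, 'train_log.txt'))

# two curves recorded once per epoch, plotted into results/demo_run/G_loss.png
plot_curves(save_dir, 'G_loss',
            [[1.00, 0.80, 0.65], [0.50, 0.42, 0.37]],
            curve_names=['G-A-loss-adv', 'G-B-loss-adv'], ylabel='Loss')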