Repository: bj80heyue/One_Shot_Face_Reenactment
Branch: master
Commit: 880a80b2f891
Files: 42
Total size: 282.7 KB

Directory structure:
One_Shot_Face_Reenactment/

├── .gitignore
├── LICENSE
├── README.md
├── data/
│   ├── poseGuide/
│   │   └── lms_poseGuide.out
│   └── reference/
│       └── lms_ref.out
├── fusion/
│   ├── README.md
│   ├── affineFace.py
│   ├── calcAffine.py
│   ├── parts2lms.py
│   ├── points2heatmap.py
│   ├── test.py
│   └── warper.py
├── loader/
│   ├── __init__.py
│   ├── dataset_basic.py
│   ├── dataset_loader_demo.py
│   └── dataset_loader_train.py
├── model/
│   ├── base_model.py
│   └── spade_model.py
├── net/
│   ├── ResNet.py
│   ├── appear_decoder_net.py
│   ├── appear_encoder_net.py
│   ├── base_net.py
│   ├── discriminator_net.py
│   ├── face_id_mlp_net.py
│   ├── face_id_net.py
│   ├── generaotr_net.py
│   ├── generator_net_concat_1Layer.py
│   └── vgg_net.py
├── opt/
│   ├── __init__.py
│   ├── config.py
│   └── configTrain.py
├── requirements.txt
├── test.py
└── utils/
    ├── __init__.py
    ├── affineFace.py
    ├── affine_util.py
    ├── calcAffine.py
    ├── lms.test
    ├── metric.py
    ├── points2heatmap.py
    ├── transforms.py
    └── warper.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2019 Stan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# One-shot Face Reenactment

[[Project]](https://wywu.github.io/projects/ReenactGAN/OneShotReenact.html) [[Paper]](https://arxiv.org/abs/1908.03251) [[Demo]](https://www.youtube.com/watch?v=FE-D6wh11_A)  

The official PyTorch test script for the BMVC 2019 spotlight paper 'One-shot Face Reenactment'.

<img src="https://github.com/bj80heyue/Learning_One_Shot_Face_Reenactment/blob/master/pics/main.png" width = 900 align=middle>

## Installation

### Requirements
- Linux
- Python 3.6
- PyTorch 0.4+
- CUDA 9.0+
- GCC 4.9+

### Easy Install
```shell
pip install -r requirements.txt
```

## Getting Started

### Prepare Data
It is recommended to symlink the dataset root to `$PROJECT/data`.
```shell
Project
├── data
│   ├── poseGuide
│   │   ├── imgs
│   │   ├── lms
│   ├── reference
│   │   ├── imgs
│   │   ├── lms
```
- imgs : store images
- lms : store landmarks extracted from images
	- format : 106 common facial key points & 20+20 gaze key points
	
<div align="center"><img src="https://github.com/bj80heyue/Learning_One_Shot_Face_Reenactment/blob/master/pics/lms.png" width = 500></div>

Example input data is provided in the `data` folder. If you want to test with your own data, organize it in the same format as the example input.
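The `lms_*.out` files (e.g. `data/poseGuide/lms_poseGuide.out`) interleave image filenames with whitespace-separated x y coordinates. A minimal parser sketch, assuming 106 + 40 keypoints, i.e. 146 (x, y) pairs per face; the function name `load_landmarks` is illustrative, not part of this repo:

```python
def load_landmarks(path):
    """Parse an lms_*.out file: each image filename is followed by
    whitespace-separated x y coordinates (106 facial + 40 gaze
    keypoints, i.e. 146 (x, y) pairs per face)."""
    lms, name, coords = {}, None, []
    with open(path) as f:
        for tok in f.read().split():
            if tok.lower().endswith(('.jpg', '.jpeg', '.png')):
                if name is not None:  # flush the previous image's points
                    lms[name] = list(zip(coords[0::2], coords[1::2]))
                name, coords = tok, []
            else:
                coords.append(float(tok))
    if name is not None:  # flush the last image's points
        lms[name] = list(zip(coords[0::2], coords[1::2]))
    return lms
```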

Output images are saved to the `output` folder.

Due to company policy, the model used to extract the 106 + 40 facial landmarks cannot be released. However, if you want access to the following datasets, please fill in the license file in the repo (license/celebHQlms_license.pdf) and email the signed copy to siwei.1995@163.com:
- our preprocessed 106 + 40 facial landmark annotations for the celebHQ dataset
- an additional 80 pose-guide images with corresponding 106 + 40 facial landmark annotations


### Inference with pretrained model
```
python test.py --pose_path PATH/TO/POSE/GUIDE/IMG/DIR --ref_path PATH/TO/REF/IMG/DIR --pose_lms PATH/TO/POSE/LANDMARK/FILE --ref_lms PATH/TO/REF/LANDMARK/FILE
```
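For instance, with the bundled sample landmark files (this assumes the `imgs` directories contain the corresponding images, as in the layout shown under "Prepare Data"):

```
python test.py \
    --pose_path data/poseGuide/imgs \
    --ref_path data/reference/imgs \
    --pose_lms data/poseGuide/lms_poseGuide.out \
    --ref_lms data/reference/lms_ref.out
```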

```
output sequence:
    ref1-pose1, ref1-pose2, ref1-pose3, ...
    ref2-pose1, ref2-pose2, ref2-pose3, ...
    ref3-pose1, ref3-pose2, ref3-pose3, ...
    ...
```
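Equivalently, results pair every reference with every pose guide, references in the outer loop and pose guides in the inner loop. A small sketch of the pairing order (the names are illustrative):

```python
from itertools import product

refs = ['ref1', 'ref2', 'ref3']      # reference (appearance) images
poses = ['pose1', 'pose2', 'pose3']  # pose-guide images

# Outer loop over references, inner loop over pose guides,
# matching the output sequence shown above.
pairs = [f'{r}-{p}' for r, p in product(refs, poses)]
```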

### Pretrained model
You can download the pretrained models from [here](https://drive.google.com/open?id=1Wnc2TGwFQM4PdCdeSn-trI75UeGbuY_E).
```shell
Project
├── pretrainModel
│   ├── id_200.pth
│   ├── vgg16-397923af.pth
├── trained_model
│   ├── latest_net_appEnc.pth
│   ├── latest_net_appDnc.pth
│   ├── latest_net_netG.pth
│   ├── latest_net_netD64.pth
│   ├── latest_net_netD128.pth
│   ├── latest_net_netD256.pth
```

### Visualization of results
You can download our sample data and the corresponding results from [here](https://drive.google.com/open?id=1Ia8YJrtYTvNRwBfcKK7iBSAf5vb8gkqw).

## License and Citation
This software is released under the **MIT License**.
```
@inproceedings{OneShotFace2019,
  title={One-shot Face Reenactment},
  author={Zhang, Yunxuan and Zhang, Siwei and He, Yue and Li, Cheng and Loy, Chen Change and Liu, Ziwei},
  booktitle={British Machine Vision Conference (BMVC)},
  year={2019}
}
```


================================================
FILE: data/poseGuide/lms_poseGuide.out
================================================
pose_1.jpg
90.44765716236793 148.92923857343578 91.57595251229804 162.8479934005091 92.9120000469461 176.46386363738992 94.4750702020267 189.95495589886139 97.02199692128647 203.29694912957 99.79529533135087 216.69020667075688 102.741526161524 230.20452567522284 107.24603891660149 243.02974731417868 112.59511909643379 255.8834214948385 119.47399642829623 267.9355943070909 127.98055826477298 278.4243267720386 137.41574048038206 287.8953381465611 147.8454651322761 296.04741539656624 159.34423801793673 301.4650047702521 172.04009073265115 306.14600984182664 184.8476145369723 309.1498658924654 198.1901765257437 309.2894505953745 210.300733963732 306.8638379074362 221.74102561555833 302.8266755169669 232.31119706456093 296.7210222250431 242.4531426761971 288.6321111047236 250.61525106310273 280.1429009636254 257.9335619693902 269.4977033895718 264.79963843573 258.62440630155766 269.747212813302 247.98834091017915 272.8998867302482 235.1301723854457 275.20565595853645 222.93601399814503 277.7970566608542 210.45285890859884 279.74133799998094 197.64873719251972 280.61281413259303 184.5151342861813 281.86739874645264 171.85992721532836 281.4792935965238 158.93559678704815 281.5092061164478 146.28295371195867 114.76937329328939 130.951598068198 128.08263322780795 117.98798787017822 142.75231443631833 115.1611843213403 157.86672034108483 116.6153814472695 173.90706075198796 120.19246576692521 206.74686782862727 121.02894499848259 220.99542433006582 116.48930599229095 236.1074216934221 114.23420102476143 250.73115505761848 115.86996634388765 263.3734861763337 127.31835338341557 191.33970313484292 140.2993282288778 192.45687057588907 153.99556019589238 193.40006374594554 167.7197145214593 194.32236469376062 181.56555264352028 174.32287399383358 199.9117998945228 184.1300182390903 200.7080648332646 194.13823396512498 200.9199239355632 203.11092044225438 199.86497334690608 212.42735582667092 198.39314190081785 131.27202642017005 145.25804889557963 139.60930390659155 138.39140341160612 
159.2010831026048 137.80280976717103 166.42600991645304 143.9995468330668 158.77062924567076 148.28800882103064 140.55642625303972 148.86503435426226 213.4374142301142 143.2204520243767 220.39543471642332 136.74147974141397 239.92071927456345 137.12800959501914 247.93163833918334 142.84138090328952 240.00825505051557 147.15669896848198 221.6303427783007 146.4370558472341 128.0094815119112 127.30603943524707 143.35837320872042 124.87580588006517 158.72832845001858 126.4964054707917 173.69199365904 128.65707163001753 207.1552887017947 128.7516632692907 221.6274310617116 125.50257185064652 235.7439188520525 123.88641465163721 250.25624675797326 124.72534010103783 149.3436339426612 135.51948593642157 149.7487980940789 149.8507149705502 149.2291436246387 143.5389557296295 230.21048834458065 134.5325702413665 230.55831327048872 147.71896200290422 230.8006705913076 142.14327359606972 179.39126786362544 141.84649970810392 202.05051090427742 141.56137209773732 173.8039715509792 178.44659547107824 211.45557809523098 178.39411289071995 168.10655914660424 192.04339505424878 217.73949281404384 190.9817893520048 163.77671869691835 245.9297018326953 174.2606984559743 234.04905150188023 188.89939874358288 227.84672346345394 196.7329061228004 228.2799145924343 204.10161247457313 226.22829637685066 215.79634899438292 232.07471557651525 225.75460645391786 242.76781565926217 218.67711851772015 251.95207930554216 208.64262130213024 258.23235688939445 196.93652094532126 261.1568196873806 184.66714682070506 260.51896117531226 172.74227783176593 254.4916159172529 167.5079387479136 245.17813995913383 181.0178831606829 236.37151769034227 196.79519987565118 235.13433671419762 210.26497756865683 235.59731745760598 221.92064485126207 242.76948430139294 209.8973809584047 248.94593227955767 196.5173815999451 251.3505374509133 180.9974006859345 249.87787983113742 149.24129540873582 143.53587629472423 230.8157833355565 142.13912050783574 147.133 140.543 146.515 142.549 145.134 144.308 143.303 
145.533 141.195 146.09 139.025 146.064 137.013 145.291 135.327 143.832 134.289 141.928 133.844 140.046 133.787 139.301 134.208 137.402 135.234 135.42 136.895 134.055 138.873 133.177 141.105 132.998 143.193 133.411 145.055 134.593 146.375 136.332 147.056 138.359 211.016 134.501 211.428 135.177 212.929 139.128 215.603 142.521 219.29 144.814 223.556 145.707 227.95 145.317 232.052 143.676 235.325 140.702 237.262 136.833 237.814 132.601 237.099 128.309 235.103 124.352 231.856 121.555 227.845 120.103 223.408 119.916 219.124 120.821 215.37 122.99 212.551 126.254 211.221 130.155 
pose_2.jpg
100.093194977701 161.4013559052053 100.53716781745582 174.1857192569044 101.90737024900535 186.330944275622 103.3735344503214 198.64378987292483 106.0354925127567 210.26739998238142 108.63544049277088 222.61087067996962 111.59644983652562 234.0313677912341 115.58285449158832 244.8659427127749 119.6903965821058 256.22302381253576 125.40724643070543 266.6858725541848 132.9721330817453 276.1873768745778 141.3874107390709 284.37879171946804 150.87961259137933 291.7796379709438 161.32208063157066 297.27698204120384 173.02991152939967 301.9868103222333 183.95909621077863 305.35365302459354 195.7679601753028 306.2800992849018 208.18786280319534 305.7388693777715 219.90713032948378 302.63420928143375 231.0674983570704 298.4270380602661 241.87467583557293 292.4041270658705 251.11430378350548 284.80162413595906 259.75921009425394 275.4774005034841 267.5936934573682 265.5298987339355 273.10328426726045 254.9936831616392 276.9043510885375 242.3450654350767 279.6842731168115 230.00664245652425 282.2207554656838 217.85061671549514 283.6271167176647 205.31749176197863 284.91158825286465 192.74645559564587 286.28405217837695 180.92652850754328 286.4006652981046 168.26830468338156 286.37802009263464 156.15744637477152 115.4900160996773 125.082193502199 125.97839064692833 111.80204842041258 140.26601488172514 106.87664397110419 155.14218390910088 106.70852137040197 169.09935687644764 110.27262179717772 206.40394643424622 109.51418879556448 220.59279429423606 105.65324388595778 235.85397794810945 105.46019037708922 250.0199630890238 110.31695931104736 261.8146174979548 122.08371614206064 188.27026113905077 141.11795251994283 188.94549470038856 155.79707940298127 189.08404129537774 170.40651295539388 189.0338830757779 185.08851954718264 171.16007470459334 206.46071102387398 180.93779563127492 207.0502404698997 190.87321681074877 206.35219052398807 200.97355808008274 204.79966402335427 211.08925122271853 204.3564037192203 129.99709531216052 146.82516613202017 138.45895319487153 
138.65812639817906 158.9199691214858 139.59960169134345 165.7290131584893 147.94525838106998 156.7653382067972 151.544880685268 138.46813920180864 150.68082990304023 212.80803931969268 146.85949161785712 219.26231272640433 138.08401938986523 239.75751734527222 136.58540730340738 248.649051930705 143.87512069902633 240.21094885731281 148.74329648092575 221.15190422708247 149.42092601802844 126.86960645378379 120.33246236674327 140.8901007864652 116.54432321777693 154.93940323118028 117.14727923840363 168.52913532187097 118.78941690089661 207.75083328136253 118.21893080257749 222.02013148499196 115.72501499112207 236.30085682138304 115.43394613063205 249.963450327016 118.47854574588834 148.87490777257472 135.70718113818126 147.59434754904237 152.29876550495365 148.03748477546577 145.4896501933588 229.46599844979602 134.41317010619932 230.84721286967346 149.98417040083237 230.72517795734979 143.7379041389614 177.48241144005863 144.42981886262302 200.57731720506337 143.59579850675797 171.84385977785854 184.30467672499688 208.85265199799937 182.28690286702104 164.7157107525598 198.3152809435389 216.5810910941354 196.00943146426206 162.08994193115757 242.61385974228384 172.0858024620631 231.3517041461669 184.59666727117178 225.07210003155583 192.4492468801826 226.01037622068648 200.5588240919513 224.77790926122546 213.2634454736595 229.9549048592879 224.9381988568802 240.3493923009241 215.6243032223897 247.39660041613203 205.36281791541535 252.38760036802856 193.5081632119767 254.17419095783012 182.49965019869023 253.80076635628163 170.88804392816343 249.59330429304387 167.57043887732817 242.73629516138735 180.19135949493068 238.375810332754 192.85939239711735 237.72187925766596 206.8976616720181 236.7157494378059 219.61588976882643 239.68200286313336 205.99954107757281 239.46004693594395 192.87395458689463 239.71164424499196 180.12457142150345 240.76992809596996 148.06175692301161 145.49601398018515 230.75213110892832 143.74539610046523 151.81004827804105 
143.84939898317916 150.63144534208192 145.6441302756924 148.86913851684324 146.98729337759272 146.8314499157171 147.6939404451874 144.67901038421724 147.69476928060448 142.64910917316251 147.06602696415425 141.02872336421495 145.72688544037987 139.94956991302985 143.81564142727416 139.60619050377977 141.69161161706296 139.7873028215118 139.78759526241075 139.95654438421695 139.21832295211956 140.85339649250864 137.5185476486653 142.29640969176268 135.85395234715398 144.20622126850992 134.9262398383987 146.28762452136618 134.58486734190467 148.44938709275277 134.98154598440036 150.30228524201237 135.933428538162 151.69219101976256 137.56935702042549 152.39817535520797 139.56902405102318 152.4303516793084 141.772208047627 220.83087992185017 138.27456575190863 221.03940931940252 140.4295799316757 221.93522532027163 142.46753600857878 223.44225732523256 144.06060385796238 225.3969587320836 145.0880821494304 227.5350677449888 145.46792778177556 229.65563945935935 145.05177339486923 231.56306625944353 143.89634876044013 232.94612166470642 142.1701057576689 233.7429995319805 140.35694118816687 233.86461164644606 139.79091564686422 233.88532984187515 137.84824845062215 233.4031521329124 135.67294387241222 232.13258946772032 133.9224424167932 230.40870849293913 132.67311152665408 228.2875682138373 132.0047132346189 226.16157424927616 131.9614449361577 224.10586896110965 132.72545572561967 222.45252030932784 134.1380106049866 221.29291179707224 136.12143267485757
pose_3.jpg
100.68002376329395 153.79703744153008 101.4051291762957 167.29352731036488 102.36608555212445 180.38026696181367 103.5115085866463 193.2796673806298 105.60208513182278 206.22525492792363 108.05407168715283 219.1911996360299 110.72916034828148 232.04327056505645 114.662128816159 244.31299676982297 119.22729693999361 256.96663174064025 124.91038003602449 268.85800394432505 132.06981400976326 280.20769626878973 140.08399028936486 290.3547602041902 149.26992550691995 299.5458046262653 159.67522608282252 305.8702960749566 171.6053687509525 310.79222087866896 183.76463958488898 314.0718881616237 196.40643310382438 314.9296751937251 209.08260222463787 312.7460533253202 221.08109456329373 309.46714244205816 232.5904138851381 304.34993137439795 243.38614067947708 296.72779844763835 251.95556374662897 287.92424367136994 259.56128384934743 277.37881002812264 266.6055223864016 266.02289533102015 271.8223444494196 255.0853207857532 275.6882976698365 241.69952614724372 278.1332311458493 229.07474592014216 280.865130942551 216.16749828113603 282.6715383437519 203.17098731003347 283.4855144423068 189.7047846995279 284.54365417249846 176.60752783316747 284.2440201782424 163.29309326444684 284.41190175601093 150.34890426294282 119.32299327945958 131.65871208074068 131.24354204447127 119.4271237085135 144.7233606197231 116.76436231641082 158.3771063414871 117.49771461977107 172.81182338036558 120.10444812210659 203.4335669303016 118.8589482271039 217.87278465955853 115.02702907304996 232.91907468182228 113.44929822161964 247.31294938041242 115.33190812950377 260.2315024010835 126.97841904273312 188.22273343612926 139.3076198752563 188.63686152102878 152.0386047432662 188.7611761515883 164.79384529821388 188.91978090528823 177.55227075180463 169.7111954193623 196.77551803916765 180.06821536557868 197.67351115261314 190.63790976752796 197.7888044966272 200.8804305939036 196.19666268162234 211.24318198311573 194.77596339340437 132.59280263568814 145.38064500335693 140.33753945021442 
139.0346919634389 158.93063754126615 138.179268931794 166.69930149762808 143.11074009078462 159.10171084791972 147.3394655494072 141.46611287868313 148.17486895478004 211.87448685223558 141.32109648552972 218.99927225556576 135.38629032260937 238.14200872137235 135.77969758234636 246.3949750306624 141.10883871630253 238.28261153261315 144.81048398646308 219.98203674843285 144.61163538712265 131.35538709568743 128.44829961151106 145.3210348187722 126.39810197004851 159.1387735034233 127.31037555545663 172.68949902201916 128.8156827288658 203.93057601866838 127.37584575978363 218.4936246992536 124.46089694241138 232.73524799057964 123.07160759666456 247.05706726507663 124.18618894799454 149.47918521005658 136.43638963621842 150.41669906636992 148.76269602136756 149.83792696451576 143.299579440643 228.63413541056707 133.3277985159665 228.90817521054475 145.45268978778392 229.24034281706986 140.37335793597077 177.9319146074116 140.92273666581062 199.7174690978967 140.0987596833736 169.30041816138362 174.84747082055745 210.38993478839757 173.7937801486287 163.225771472235 188.3553133264286 217.35450961702531 186.66968050371196 160.1394551245839 242.89775811526653 168.54736414930602 225.85190441340148 184.2487138192256 217.81560362723332 192.7308484038665 217.93019388434595 200.94263381193832 216.52182806789176 217.13170151835243 223.50943332277305 228.81417226616497 239.95764581058992 219.63434860680593 251.1279338845803 207.24764191007648 257.99722671866533 193.60312998843597 260.80504085767797 180.81085488093902 259.80105624440637 168.6046248037361 253.04358964232298 164.01974631861418 241.9637881083339 175.805269270773 228.99728727707162 192.6064447100979 225.38489739724236 210.80045024409787 227.89017795717766 224.7944515525944 239.57564391391026 209.9331052046141 247.97277555588565 193.54048661491902 251.31926099718322 177.56903552895676 248.71010475718293 149.84582414347676 143.2984574691387 229.2632446360568 140.37010421860836 154.301461172647 141.28211941410973 
153.98117099887304 143.8251636812882 152.89421223464961 146.2853838191459 151.1345283671787 148.25224364347358 148.80387764360444 149.42124922459743 146.23896934961363 149.66158067587995 143.84019972098628 148.7576380963099 141.85285060979106 146.94446913051868 140.55571670009954 144.5909333525596 139.88406480440915 142.26213159782014 139.72922406113173 141.18360134604575 139.81607205312503 138.8517732020959 140.515012731171 136.2065936517422 142.06859816590202 134.099020893086 144.20611444560933 132.65372132086355 146.81577006048656 132.16694641865024 149.30762030185718 132.6089132914571 151.4890601731642 134.0416290964237 153.06288105736593 136.13107669270323 154.0035887610266 138.58439595152163 215.59932696651867 137.60874451806953 216.85701835983605 139.9114879293186 219.56175336875458 143.1107864900333 223.08959800249056 145.3157822728301 227.1133598042387 146.31031579493748 231.25404792720292 146.05235790736765 235.10368937026135 144.48868952422154 238.22739565617837 141.6545095990617 240.1139335953012 137.84722886724492 240.63358006265992 133.6508493395479 239.9784498785409 129.64070899264615 238.04789058947233 125.85688882307153 234.99614525947288 122.83499049133499 231.17294812335706 121.25605822396645 227.08449592576687 120.91765881391294 222.96594299710102 121.87581135122085 219.39622629853466 123.95239154343281 216.78751784202973 127.16779559194907 215.28458818424303 131.02692301680122 214.9035510490429 133.3522078497409

================================================
FILE: data/reference/lms_ref.out
================================================
ref_1.png
91.5305110043 157.054464213 92.3770956017 170.115229042 93.7010275718 182.787981616 95.1548975719 195.332331205 97.3347102602 207.716766281 99.662236308 220.4665801 102.33052797 232.892260979 106.438173842 244.606825996 111.526303747 256.554821386 118.326501636 267.241535072 126.907554178 276.813854107 136.377258498 285.26307926 146.792371645 292.544235232 158.110919394 297.836709658 170.660884454 302.234441612 183.128704907 305.026855756 196.052484086 305.619976224 209.015038563 303.932347537 221.107902017 300.754207398 232.695360635 295.924145465 243.967578069 289.213744481 253.60576167 281.5663184 262.425012464 271.991909527 270.536534297 261.818957787 276.649322878 251.47003397 280.982866424 238.753417493 283.502200732 226.415968603 285.877343058 213.884350894 287.329861916 201.312966922 288.080086341 188.333151939 289.089321689 176.008753975 288.82966812 163.061934768 288.800723064 150.601424089 111.166514647 133.040716434 123.727850015 119.151674836 138.403570358 115.195718003 153.931016799 115.048281714 169.720489595 117.81474832 208.638735958 116.075327861 223.825603912 111.908791938 239.973167757 110.52262007 255.36881296 113.263118909 268.240874597 126.677018495 189.682041913 143.924696725 190.854041755 158.428013381 191.543603835 173.099139858 192.249936075 187.783461235 170.693937826 203.850342545 181.552014432 205.0961696 192.539693343 205.753910309 203.005895575 203.645338586 213.804201864 201.942667541 129.861998337 150.469447406 138.682325087 144.971438056 158.318952147 142.947134846 166.562767713 148.175911133 158.434543813 151.783553428 139.921782751 153.00548942 213.982344278 146.516031156 221.257153182 140.292642262 241.333860112 140.473593503 250.974868308 145.430719473 241.612688503 149.159617782 222.219029473 148.786122779 124.003302311 128.855280973 139.004461521 125.423794795 154.413255848 125.532565126 169.382919275 126.426350997 209.298274471 124.561489414 224.585108448 121.842109834 239.789645816 120.983759394 254.844007729 123.011316913 
148.178048442 141.920938125 149.36627429 153.337310947 148.454744756 148.355641177 231.301475374 138.364326844 231.767179017 149.639055154 232.49717489 145.093280533 178.536600869 145.793460055 201.690026469 144.834209602 170.808068504 182.238448616 212.783621495 181.284539841 163.878859691 195.537120954 220.17502113 193.566153645 155.061694709 237.314899133 168.691774343 229.41669664 184.445533833 225.636771942 192.922252371 226.622102972 201.359987418 224.797551625 216.244562439 227.506028901 231.183770684 234.595217305 219.339705183 241.029675481 206.8282294 244.653452066 193.48598518 246.593223075 180.50678043 247.025890411 166.939154463 243.418623657 159.66420291 237.284550105 176.247472 234.802214372 193.01302325 235.541054382 209.885009946 233.825001783 226.476721585 234.730269342 209.673760794 234.29342297 193.075478992 235.183984642 176.303512976 235.572025677 148.470422359 148.355139093 232.508279859 145.09292489 155.279565975 145.477647446 154.609748868 147.675943207 153.171909168 149.577768712 151.204735999 150.877299851 148.911294491 151.503950821 146.567872056 151.488761482 144.388146126 150.70457047 142.543317208 149.188398723 141.370292655 147.169439163 140.856966418 145.147661715 140.800527419 144.549185855 141.191517902 142.546836043 142.207488061 140.417230767 143.951139828 138.896645721 146.047975359 137.929722343 148.444171137 137.632617459 150.732458337 137.968023504 152.812426749 139.128788161 154.338207325 140.912161647 155.171859036 143.155816221 226.565177903 142.22102609 227.396870455 144.280996543 228.961440761 145.98923668 230.960699023 147.080002841 233.215103849 147.500847529 235.465646862 147.31107009 237.492576996 146.386927482 239.138685693 144.781189341 240.084062174 142.754545992 240.398573086 140.793879684 240.370511298 140.223920684 239.804647089 138.317604988 238.641406911 136.359145098 236.836733786 135.055513762 234.751298886 134.295766975 232.428229204 134.201457416 230.262410384 134.688663192 228.365602855 135.959784698 
227.066455763 137.781969604 226.454437003 139.988222896
ref_2.png
92.039137758 156.12329245 92.9050358285 169.485384495 94.4162795053 182.406301682 96.1104208304 195.161424701 98.7067967479 207.76323503 101.650204186 220.637302423 104.656856193 233.139233976 109.003742607 245.093591736 113.935069382 257.263007266 120.449074914 268.455902014 128.651936371 278.822349738 137.573958433 288.281461939 147.617540506 296.645662325 158.961544532 302.47668445 171.597037381 307.133502321 184.26512518 310.22100852 197.36392408 310.812898844 210.115403508 308.7057881 221.723349692 304.902468686 232.433400559 299.163784305 242.376146097 291.152868024 250.247482235 282.220521752 257.279485702 271.442518292 263.809502706 260.366941295 268.315940536 249.563609167 271.33120125 236.662909339 273.438509977 224.500099191 275.749154425 212.326589649 277.225648649 200.04842333 278.089551694 187.299923103 279.124329773 175.181193895 278.652381786 162.544083451 278.440504309 150.367990519 111.907895958 131.531235069 125.275455017 117.599022821 141.013648304 113.249077652 157.206272745 113.356300912 173.625129093 116.363822971 208.025288848 117.11531083 222.629721985 113.376275223 237.988667189 112.278082276 252.366591804 115.164869446 264.708164729 127.833223769 192.266751163 140.459419472 193.995037436 154.739435239 195.437956791 169.13898836 196.781258921 183.608614787 173.455824323 202.774889104 184.130273167 204.044877226 195.136226464 204.092632645 205.483673375 202.652360976 216.199353772 200.823558418 127.959007135 146.38452335 136.576082323 138.250970054 158.359299034 137.18681324 166.537550992 144.585496212 157.568374849 148.991600409 137.497804009 149.872248132 214.242507489 143.822636835 221.288462029 136.104347199 242.420508443 136.275974524 250.738034297 143.624391031 242.152287276 148.082170471 222.516329474 147.196935303 125.761681549 126.665056316 141.819047605 123.172883829 157.944320564 123.833332018 173.46129159 125.367303396 208.675931776 125.775806801 223.493937889 123.183216472 237.993139935 122.390416611 252.240579909 124.406580011 
147.362642976 134.611399153 147.696202459 150.589866343 147.427284461 143.886617095 231.859031092 133.272454046 232.088616173 148.400661618 232.608514118 142.302373647 179.8362925 142.274270671 203.070888711 141.780055796 173.602126818 180.473167562 215.098561496 179.091802358 167.039173259 194.238218679 221.816089887 192.044322644 159.552088031 241.879093966 171.939655217 230.56629131 187.823790129 224.704613898 196.280822527 225.320683281 204.407476176 223.261294069 217.308645593 228.39691398 228.456682798 239.236012992 219.483389839 248.10694088 208.77547541 253.650756487 196.003122315 256.115729966 182.941888656 255.641470828 170.063697991 250.148592102 164.18724798 241.650019357 179.885120337 238.392818704 195.939883631 238.82889921 210.514669451 237.464503248 224.1035427 239.103085833 209.988551017 237.922111755 195.836065085 238.750814194 179.870307548 239.237023435 147.444582396 143.885883792 232.62927164 142.301493684 155.127028764 140.537331076 154.454429858 142.837332404 152.976286017 144.840845065 150.930815747 146.215450575 148.534658311 146.888194252 146.08006162 146.885503984 143.793948875 146.082695223 141.852325861 144.511123652 140.601332083 142.405653462 140.03895484 140.291092756 139.965669494 139.657877826 140.353049746 137.562666354 141.394324906 135.326056255 143.204773107 133.721580259 145.391970301 132.697944438 147.895657387 132.370688549 150.291740828 132.70996288 152.476167005 133.910104603 154.090852598 135.76517545 154.986592178 138.107517235 223.322353839 139.639890978 224.129231465 141.815482269 225.962528748 144.608836529 228.558435579 146.734926501 231.734648617 147.866683781 235.120995672 147.90542935 238.300858089 146.752373692 240.951424244 144.575776476 242.634433954 141.649299848 243.324806591 138.493184533 243.265992152 136.026638803 242.358140849 132.924412989 240.531109003 129.984536177 237.812378784 128.051499592 234.617845994 127.141637251 231.141121329 127.390918972 227.996289538 128.585465568 225.45099488 130.776315405 
223.805921696 133.616618182 223.184736412 136.246464014
ref_3.png
89.7446818354 155.927528419 90.7091086051 169.45220918 92.1771452547 182.544666392 93.8190295425 195.494446186 96.315176578 208.265813727 99.1361861374 221.338389755 102.098395321 234.028489192 106.444773034 246.116655146 111.705677341 258.252967158 118.660060954 269.336292739 127.394685486 279.259712586 136.871520014 288.23173066 147.382824167 296.007851938 159.024171274 301.495403945 171.945588057 305.928041318 184.877657529 308.713101069 198.13098132 309.294782126 210.91666423 307.472922348 222.673864878 304.076525209 233.717266677 298.590029705 243.962595248 290.867063882 252.172612606 282.179417886 259.363784266 271.591163705 266.09216319 260.628978059 270.629296692 249.790642218 273.693222139 236.971069668 275.55531311 224.74186416 277.617786921 212.52923176 278.992960992 200.310134731 279.92342743 187.623430919 280.883729231 175.645397515 280.492032806 163.019975017 280.433642765 150.866565497 112.458881803 134.482923002 125.793396564 120.718675865 141.412204212 116.58009044 157.736266482 116.938735036 174.143618313 120.077149335 207.617215666 120.47122235 222.705830998 116.583712441 238.679772245 115.291092678 253.775657689 118.319344335 266.429632531 131.441166189 191.414226473 140.041489615 192.850822523 153.790601065 193.867209985 167.723469328 194.867778998 181.592537517 172.201717619 200.524667806 182.989602302 201.72866697 194.063698079 201.954069529 204.340615551 200.625265206 214.877793463 198.954489056 128.827440635 145.575422565 137.250661365 137.697763146 158.281694321 136.971861385 166.378688858 144.135400398 157.511216356 148.548000686 137.936630177 149.278622465 214.027648544 143.561100774 221.391715781 135.894704045 242.726200152 135.873806118 251.291436373 143.451470985 242.813075433 148.060481283 222.781634755 147.099606436 126.207495596 129.795013955 142.190211571 126.429900213 158.428127855 127.200662026 173.966468532 128.642187902 208.14434046 128.731465173 223.429448025 126.136964343 238.557853799 125.324946357 253.455571586 
127.501648322 147.7060987 134.440665618 147.84086471 150.150124609 147.803482172 143.377713126 232.027555847 132.998389028 232.615159742 148.51752631 232.794367298 142.144380431 179.326938943 141.853703357 202.59662599 141.480726012 172.243260843 178.68304969 213.723770212 177.726980694 165.473446815 192.214971599 220.965017207 190.718945231 155.728922492 239.057606299 170.181175172 231.096228175 186.43091055 226.976940276 195.40539851 227.815860474 204.081826624 225.735051841 218.526224478 229.255215261 232.915680778 236.620503481 222.162507919 244.526047894 209.767832475 249.304907615 196.051622376 251.316624098 182.214783646 250.94673925 168.042340229 246.246573353 160.252157314 239.028301437 177.729957901 236.168927085 195.513964855 236.985212136 212.244918867 235.532484803 228.424807435 236.717161803 211.918022534 237.853759569 195.535605688 239.005363972 177.818326321 238.916362828 147.819888173 143.377165012 232.811456882 142.143809478 155.497042729 139.080294968 154.977936922 140.994646125 153.559323194 143.49473855 151.400137905 145.418032555 148.676044822 146.532825019 145.736284972 146.804864582 142.901565346 146.124206394 140.40355375 144.501587927 138.641411515 142.147454709 137.703539473 139.536527092 137.457627258 137.686058117 137.926338285 134.980233171 139.247572242 132.289724122 141.470811945 130.399191762 144.168431101 129.33181933 147.192121197 129.175367631 150.028614396 129.806693626 152.516618781 131.375583064 154.337109022 133.632642765 155.271760892 136.147185127 225.024125179 139.763086823 225.894796831 142.021768143 227.699222084 144.269099369 230.12902897 145.816665373 232.963430422 146.508283733 235.864918962 146.360705139 238.522957238 145.276177727 240.710690042 143.30607876 242.040227711 140.741212021 242.541665625 138.113677869 242.503864887 136.634015215 241.735070996 134.085873437 240.153703585 131.591568754 237.781859577 129.976368993 235.047504717 129.186283992 232.076311952 129.316345317 229.362318434 130.200859979 
227.082765877 131.98309395 225.550964847 134.372682048 224.898900853 136.875222715
ref_4.png
92.6273695423 151.863222686 93.5634544076 165.442673083 94.8552471172 178.650325358 96.1474669783 191.749293298 98.0906383413 204.755684179 100.285175961 218.185819107 102.771752809 231.218781595 106.694038527 243.517218451 111.578585938 256.102201048 118.105511997 267.511376806 126.399356176 278.087505265 135.433629718 287.672366713 145.548504929 296.115766537 156.790602001 302.502752179 169.567545866 307.633094546 182.648571695 310.755146004 196.200805323 311.565606788 209.726313039 309.554460514 222.045251117 305.930018398 233.409945926 300.075564793 244.155133582 291.95486222 252.97666835 283.019263644 260.964792575 272.252452492 268.496273838 260.855603759 273.975458397 249.594996703 278.038081101 236.255777027 280.263208217 223.304503915 282.451046605 210.303387374 283.742767823 197.41335475 284.44342505 183.967466947 285.381033961 171.210624512 285.241893339 157.86537336 285.678359822 144.974773932 113.571695264 137.773282591 126.423070037 126.077276146 141.172084872 122.730020142 156.786673448 122.593110636 172.634015099 124.836543465 207.450243229 123.717710033 222.274178504 120.207772688 238.075031062 119.034652488 253.256669102 120.945992859 266.378505116 132.251581801 190.798224303 140.716603443 192.4233368 158.219556799 193.510106082 176.018613632 194.593526928 193.511670397 171.785536057 203.409870542 182.366082381 206.668729094 193.510218649 208.510021741 204.046347598 205.124224658 214.935880011 201.577604946 131.014695568 145.367011031 139.467183559 138.890210226 158.250907697 137.78525507 166.742471689 143.410065194 158.406075004 146.365828471 140.52095462 147.231028374 212.998014834 141.651829149 220.247853138 134.842563125 239.687697507 134.623601761 249.027437664 140.825322088 239.784804206 143.673153006 221.191753251 143.161764316 126.683321326 133.941381975 141.829788841 130.897089652 157.413555648 131.388575641 172.45413795 132.438648698 207.91546898 131.300051503 222.85468888 128.599313474 237.814246381 127.719136616 252.80098701 
129.302383065 148.675890017 136.023315453 149.560589595 147.399107602 149.075840872 142.890987351 229.896886431 132.295537257 230.1985368 143.703648159 231.003209289 139.701959552 179.198639948 141.826319745 201.798114566 140.820250173 171.076364201 183.059199156 214.990800393 181.784780491 164.273786399 195.597389697 222.433535348 193.198757019 150.300478792 234.949007132 166.707377679 229.025010333 185.00149523 227.935187017 194.085939754 228.77216209 202.846508575 226.960413473 219.437476402 226.808102856 236.225997626 231.52430752 224.084792536 240.17210242 209.752520443 244.824966141 194.238170627 247.014452882 179.226720943 247.450978793 163.6234829 243.117198749 154.322548581 235.645769488 174.051832814 235.379514895 193.84864258 236.667679431 213.380385096 233.964877996 232.149237377 232.054578053 213.174011753 234.392647807 193.899421574 236.315921248 174.046701477 236.207236101 149.091263629 142.890591717 231.003868388 139.702585682 154.232787412 139.498496855 153.559249613 141.367723554 152.229615328 142.978560688 150.492771132 144.056476928 148.496682712 144.553950002 146.471422963 144.490799448 144.61066478 143.746310258 143.063676723 142.373883662 142.135785891 140.594608659 141.784370627 138.852631608 141.815688391 138.368906882 142.252695627 136.604132584 143.200795401 134.784449975 144.76419749 133.538005825 146.595971128 132.776601117 148.674070754 132.59339881 150.63604423 132.921818296 152.388039139 133.963608833 153.620955407 135.526071872 154.266246094 137.49393295 226.065859144 136.26361424 226.868195419 138.262170727 228.342357179 139.92904768 230.222562721 141.030566979 232.367304026 141.489978956 234.524264239 141.322758726 236.433610911 140.376741347 237.931692714 138.764798283 238.743687983 136.778202108 238.970360889 134.885714623 238.876300783 134.227084712 238.297546768 132.371334475 237.177941243 130.488304076 235.450083136 129.234334949 233.449811878 128.515930955 231.227253172 128.482277069 229.195261323 129.034043292 227.487947055 
130.335622184 226.376902766 132.120259444 225.883529367 134.119259343
ref_5.png
95.8641666481 153.876216957 96.6446302127 167.22924384 97.9311325716 180.266378675 99.184959939 193.083168223 101.027597231 205.919003828 103.289807178 219.021574555 105.748259956 231.850852856 109.997457645 243.874168633 115.064778665 256.019911195 121.866812215 266.937287877 130.243278876 277.246220556 139.086598138 286.677196779 148.628167036 295.190484457 159.245096382 301.855525478 171.270807106 307.693979843 183.679301808 311.556503455 197.019283776 312.547509766 210.208056594 310.222936374 222.167477955 305.845679539 232.959303274 299.539884924 243.333168051 291.22890252 251.895625423 282.104000409 259.899547789 271.605086811 267.550411672 260.67534063 273.404815111 250.031511959 277.612279363 237.086154859 280.224974721 224.611605969 282.351300904 211.899376581 283.388452301 199.238960208 283.853415984 186.1044991 284.81094413 173.540800443 284.680510281 160.42545556 284.984174898 147.773904598 113.702919061 133.934425806 125.923380435 121.087837609 140.381271813 117.228598734 155.527007262 117.096417177 171.10423348 119.939581571 209.024718207 119.116092608 223.467736291 114.959349376 238.90147459 113.643068334 253.66274328 116.095616255 266.264419508 128.3682792 189.927230613 142.786284433 191.031431781 159.91561173 191.716484945 177.299107727 192.396402803 194.505825337 170.397055103 203.027801289 180.851436938 205.830811793 191.881432183 208.401207438 202.330991514 204.306038302 213.26702099 200.924225841 131.375682053 147.069434029 140.453815794 140.747815903 159.785033195 140.092646074 167.499734139 146.150652377 159.288103199 149.060921564 141.104433377 149.332161205 212.243008654 144.989128685 219.401922188 138.089053946 239.431172281 137.201448633 248.880813497 142.819445264 239.623307623 146.601572588 220.520097369 146.669603505 126.179256591 128.811103972 140.988925759 125.438628592 156.052577039 126.045256262 170.724849743 127.519761685 209.45863941 126.541451262 224.10024653 123.376039483 238.649646206 122.200545878 253.326614294 124.112638099 
149.962146842 137.966419304 150.317830278 149.909939276 149.71119027 145.110796371 229.253804711 135.16935239 229.818628112 147.046677606 230.519357909 142.504417332 179.128902539 144.294270439 200.958214909 143.528436518 170.009334396 183.169208698 212.880929723 181.855934049 162.584727006 195.337426237 220.839542107 192.709053035 147.281824005 232.261709769 164.002657088 223.554253707 183.37014552 221.50797366 192.259443622 222.355737203 200.802915419 220.478495682 219.03648697 220.316533677 237.226191351 227.304014643 227.693805169 244.124502154 212.897176317 255.947635359 193.652382038 260.711780201 174.873795913 258.887172586 158.517661986 247.971968875 151.041529776 232.921207339 171.196964243 227.007403955 192.011174916 227.081699422 213.015536573 224.988189242 233.485042643 228.134080153 216.193361049 244.952965166 193.416850404 251.289293202 169.881928882 247.149151994 149.727135814 145.110213144 230.543964925 142.504182588 152.451951943 142.39464313 151.824151025 144.383527388 150.476923652 146.092693718 148.688914 147.264463788 146.621652379 147.822409065 144.506891575 147.809279979 142.551471213 147.077800954 140.910884152 145.677503436 139.903731134 143.842890243 139.482997763 142.034728686 139.470055027 141.498981138 139.860599833 139.662476694 140.810033368 137.736340433 142.40855364 136.385886219 144.306882875 135.520104602 146.474911233 135.279807056 148.526122102 135.592643636 150.374028544 136.658702515 151.697185583 138.274853377 152.412463839 140.296443166 227.934250089 139.199356545 228.741799454 141.124431296 230.251117618 142.71913679 232.144641023 143.740948837 234.257583507 144.122713168 236.372046431 143.927618466 238.271866802 143.029848555 239.799014676 141.49239821 240.651000461 139.575813795 240.911201815 137.740750384 240.855348809 137.195927746 240.289193514 135.377118292 239.162060095 133.536248939 237.43771659 132.337935585 235.464632569 131.65101154 233.278206253 131.606574092 231.256229646 132.09300793 229.501001958 
133.306157299 228.323876427 135.027729975 227.783515958 137.097242098

================================================
FILE: fusion/README.md
================================================
Overview:
	This module fuses texture from a reference image into a Face2Face generation result.

Algorithm input:
	generated image img_gen
	its 106 landmarks + 40 eye points (optional)

	reference image img_ref
	its 106 landmarks + 40 eye points (optional)

Algorithm output:
	the fused image


Algorithm options:
	1. alpha blending
	2. Poisson blending
	3. NCC mask net
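As a rough illustration of option 1, here is a minimal NumPy sketch of alpha blending under an assumed soft mask in [0, 1]; the function name `alpha_blend` and the `ratio` parameter are illustrative, not this repo's API (compare `fusion/affineFace.py:fusion`). Option 2 (Poisson blending) is commonly available as OpenCV's `cv2.seamlessClone`.

```python
import numpy as np

def alpha_blend(img_gen, img_ref, mask, ratio=0.2):
    """Blend img_ref into img_gen inside a soft mask.

    img_gen, img_ref: HxWx3 uint8 images; mask: HxW float in [0, 1];
    ratio: weight kept from img_gen inside the masked region.
    """
    m = mask.astype(np.float32)[..., None]      # HxWx1, broadcasts over RGB
    blended = img_gen * ratio + img_ref * (1.0 - ratio)
    fused = img_gen * (1.0 - m) + m * blended   # mask mixes the two layers
    return np.clip(fused, 0, 255).astype(np.uint8)
```

Outside the mask the generated image is untouched; inside, the reference texture dominates for small `ratio`.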


================================================
FILE: fusion/affineFace.py
================================================
from fusion.points2heatmap import *
from fusion.calcAffine import *
from fusion.warper import warping as warp
import matplotlib.pyplot as plt
from fusion.parts2lms import parts2lms
import time
from tqdm import *
import random
import multiprocessing
import sys


def gammaTrans(img, gamma):
	gamma_table = [np.power(x/255.0, gamma)*255.0 for x in range(256)]
	gamma_table = np.round(np.array(gamma_table)).astype(np.uint8)
	return cv2.LUT(img, gamma_table)

def erodeAndBlur(img,kernelSize=21,blurSize=21):
	#img : ndarray float32
	kernel = np.ones((int(kernelSize), int(kernelSize)), np.uint8)
	res = cv2.erode(img,kernel)
	res = cv2.GaussianBlur(res, (blurSize, blurSize), math.sqrt(blurSize))
	return res

def affineface(img,src_pt,dst_pt,heatmapSize=256,needImg=True):
	#src/dst_pt[ndarray] : [...,[x,y],...] in [0.0,1.0],with gaze
	#naive mode: align 5 parts 
	curves_src,_ = points2curves(src_pt.copy())
	pts_fivesense_src = np.vstack(curves_src[1:])
	curves_dst,_ = points2curves(dst_pt.copy())
	pts_fivesense_dst = np.vstack(curves_dst[1:])
	affine_mat = calAffine(pts_fivesense_src,pts_fivesense_dst)

	pt_aligned = affinePts(affine_mat,src_pt*255.0)/255.0
	if needImg:
		img_aligned = affineImg(img,affine_mat)
		return pt_aligned,img_aligned
	else:
		return pt_aligned

def affineface_parts(img,src_pt,dst_pt):
	curves_src,_ = points2curves(src_pt.copy())
	curves_dst,_ = points2curves(dst_pt.copy())#[0,255]

	parts_src = curves2parts(curves_src)
	parts_dst = curves2parts(curves_dst)	#[0,255]

	partsList = []
	for i in range(len(parts_src)-2):
		affine_mat = calAffine(parts_src[i],parts_dst[i])
		parts_aligned = affinePts(affine_mat,parts_src[i])	#[0,255]
		partsList.append(parts_aligned)
	partsList.append(parts_src[-2])
	partsList.append(parts_src[-1])
	
	'''
	A = []
	B = []
	for i in range(len(parts_src)):
		A.append(parts_src[i])
		B.append(partsList[i])
	A = np.vstack(A)
	B = np.vstack(B)
	res = warp(img,A,B)
	'''
	lms = parts2lms(partsList)
	#bound
	lms[:33] = dst_pt[:33]*256
	res = warp(img,src_pt[:106]*256,lms[:106])

	return lms/255.0,res

def lightEye(img_ref,lms_ref,img_gen,lms_gen,ratio=0.1):
	#blend the reference eye regions into img_gen
	#ratio: weight kept from img_gen inside the eye mask
	#get curves
	curves_ref,_ = points2curves(lms_ref.copy())
	curves_gen,_ = points2curves(lms_gen.copy())

	parts_ref = curves2parts(curves_ref)
	parts_gen = curves2parts(curves_gen)	#[0,255]

	#get rois
	gaze_ref = curves2gaze(curves_ref)
	gaze_gen = curves2gaze(curves_gen)

	#img_gazeL = np.dot(gaze_ref[0],  img_ref)
	img_gazeL = multi(img_ref,gaze_ref[0])
	#img_gazeR = np.dot(gaze_ref[1] , img_ref)
	img_gazeR = multi(img_ref,gaze_ref[1])

	affine_mat = calAffine(parts_ref[-2],parts_gen[-2])
	img_gazeL_affined = affineImg(img_gazeL,affine_mat)
	affine_mat = calAffine(parts_ref[-1],parts_gen[-1])
	img_gazeR_affined = affineImg(img_gazeR,affine_mat)

	img_ref = img_gazeL_affined + img_gazeR_affined
	
	mask = gaze_gen[0] + gaze_gen[1]
	mask = erodeAndBlur(mask,5,5)

	R = img_gen[:,:,0] * (1-mask) + mask* (img_gen[:,:,0]*ratio + img_ref[:,:,0]*(1-ratio))
	G = img_gen[:,:,1] * (1-mask) + mask* (img_gen[:,:,1]*ratio + img_ref[:,:,1]*(1-ratio))
	B = img_gen[:,:,2] * (1-mask) + mask* (img_gen[:,:,2]*ratio + img_ref[:,:,2]*(1-ratio))

	res = np.stack([R,G,B]).transpose((1,2,0))
	seg = mask
	seg = seg * 127
	return res,seg,img_ref

def multi(img,mask):
	R = img[:,:,0] * mask
	G = img[:,:,1] * mask
	B = img[:,:,2] * mask
	res = np.stack([R,G,B]).transpose((1,2,0))
	return res


def fusion(img_ref,lms_ref,img_gen,lms_gen,ratio=0.2):
	#img*: ndarray(np.uint8) [0,255]
	#lms*: ndarray , [...,[x,y],...] in [0,1]
	#ratio: weight of gen 
	#--------------------------------------------
	#get curves
	curves_ref,_ = points2curves(lms_ref.copy())
	curves_gen,_ = points2curves(lms_gen.copy())
	#get rois
	roi_ref = curves2segments(curves_ref)
	roi_gen = curves2segments(curves_gen)
	#get seg
	seg_ref = roi_ref.sum(0)
	seg_gen = roi_gen.sum(0)
	seg_ref = seg_ref / seg_ref.max() * 255
	seg_gen = seg_gen / seg_gen.max() * 255
	#get skin mask
	skin_src = roi_ref[0] - roi_ref[2:].max(0)
	skin_gen = roi_gen[0] - roi_gen[2:].max(0)
	#blur edge
	skin_src = erodeAndBlur(skin_src,7,7)
	skin_gen = erodeAndBlur(skin_gen,7,7)
	#fusion 
	skin = skin_src * skin_gen

	R = img_gen[:,:,0] * (1-skin) + skin * (img_gen[:,:,0]*ratio + img_ref[:,:,0]*(1-ratio))
	G = img_gen[:,:,1] * (1-skin) + skin * (img_gen[:,:,1]*ratio + img_ref[:,:,1]*(1-ratio))
	B = img_gen[:,:,2] * (1-skin) + skin * (img_gen[:,:,2]*ratio + img_ref[:,:,2]*(1-ratio))

	res = np.stack([R,G,B]).transpose((1,2,0))
	return res,seg_ref,seg_gen


def loaddata(head,path_lms,flag=256,num = 50000):
	#head: directory prefix joined to each image name
	#return res: list of (path, lms) tuples, lms normalized to [0,1]
	fin = open(path_lms,'r')
	data = fin.read().splitlines()
	res = []
	for i in tqdm(range(min(len(data)//2,num))):
		name = data[2*i]
		path = os.path.join(head,name)
		lms = list(map(float,data[2*i+1].split()))
		if flag==256:
			lms = np.array(lms).reshape(-1,2) / 255.0
		else:
			lms = (np.array(lms).reshape(-1,2)-64) / 255.0
		res.append((path,lms))
	return res

def gray2rgb(img):
	res = np.stack([img,img,img]).transpose((1,2,0))
	return res.astype(np.uint8)

def process(index, album_ref, album_gen, album_pose):
	# 30ms
	img_gen = cv2.imread(album_gen[index][0])
	lms_gen = album_gen[index][1]
	img_ref = cv2.imread(album_ref[index // 100][0])[64:64 + 256, 64:64 + 256, :]
	lms_ref = album_ref[index // 100][1]
	img_pose = cv2.imread(album_pose[index % 100][0])[64:64 + 256, 64:64 + 256, :]
	lms_pose = album_pose[index % 100][1]

	# affine
	# 4ms
	lms_ref_, img_ref_ = affineface(img_ref, lms_ref, lms_gen)
	# 200ms
	lms_ref_parts, img_ref_parts = affineface_parts(img_ref, lms_ref, lms_gen)

	# fusion
	# fuse_all,seg_ref_,seg_gen = fusion(img_ref_,lms_ref_,img_gen,lms_gen,0.1)
	fuse_parts, seg_ref_parts, seg_gen = fusion(img_ref_parts, lms_ref_parts, img_gen, lms_gen, 0.1)
	fuse_eye, mask_eye, img_eye = lightEye(img_ref, lms_ref, fuse_parts, lms_gen, 0.1)

	res = np.hstack([img_ref, img_pose, img_gen, fuse_eye])
	cv2.imwrite('proposed_wild/fuse/%d.jpg' % (index), fuse_eye)

	




================================================
FILE: fusion/calcAffine.py
================================================
# -*- coding: utf-8 -*-
"""
Created on Fri Dec 29 13:43:03 2017
"""
import numpy as np
import cv2


#solve a similarity transform (rotation + uniform scale + translation)
#between two point sets via least squares
#src_p[input] -- np.array([[x,y],...])
#dst_p[input] -- np.array([[x,y],...])
#affine_mat[output] -- np.array() | 2x3 affine matrix
def calAffine(src_p, dst_p):
	p_N = len(src_p)
	U = np.mat(list(dst_p[:,0]) + list(dst_p[:,1]))
	xx_src,yy_src = list(src_p[:,0]),list(src_p[:,1])

	X = np.mat(np.stack([xx_src + yy_src, yy_src + [-ii for ii in xx_src], \
	[1 for ii in range(p_N)] + [0 for ii in range(p_N)], \
	[0 for ii in range(p_N)] + [1 for ii in range(p_N)]], axis=1))

	result = np.linalg.pinv(X) * U.T

	affine_mat = np.zeros([2, 3])
	affine_mat[0][0] = result[0][0]
	affine_mat[0][1] = result[1][0]
	affine_mat[0][2] = result[2][0]
	affine_mat[1][0] = -result[1][0]
	affine_mat[1][1] = result[0][0]
	affine_mat[1][2] = result[3][0]
	return affine_mat

def affinePts(affine_mat,pt):
	src_align = pt.T
	new_align = np.mat(affine_mat[:2, :2]) * np.mat(src_align) + np.reshape(affine_mat[:, 2], (-1, 1))
	pt_align = np.array(np.reshape(new_align.T, -1))[0].reshape(-1,2)
	return pt_align
	
#affine Image from pt_src to pt_mean
#img[input] -- np.array()
#pt_src,pt_mean[input] -- list[float] format = x0,y0,x1,y1,...,xn,yn
#img_align -- np.array() | aligned image
def affineImg(img,TransMat,dsize = 256):
	img_align = cv2.warpAffine(img, TransMat, (dsize, dsize), borderValue=(155, 155, 155) )
	return img_align

	
if __name__ == '__main__':
	# NOTE: affineList is not defined in this module; this demo block is left
	# over from an earlier script and will not run as-is.
	path_src = '/media/heyue/8d1c3fac-68d3-4428-af91-bc478fbdd541/Project/Face2Face/detectface/samples/common/output/landmarks.txt'
	output_pt = 'lms/lms.txt' 
	output_img = 'imgs'
	affineList(path_src, output_pt,output_img,'meanpose384.txt',k=2,head='/media/heyue/8d1c3fac-68d3-4428-af91-bc478fbdd541/Project/Face2Face/Data/test')
	
	'''
	path_src = 'alignedPoints_256.txt'
	output_pt = 'output/AU_points.txt'
	output_img = 'output/AU'
	head = '/media/heyue/8d1c3fac-68d3-4428-af91-bc478fbdd541/Project/Face2Face/net/GANimation/dataset_emo'
	affineList(path_src,output_pt,output_img,'meanpose384.txt',k=2,head = head)
	print('done')
	'''


================================================
FILE: fusion/parts2lms.py
================================================
import numpy as np

def parts2lms(parts):
	bound,browL,browR,eyeL,eyeR,nose,lipU,lipD,gazeL,gazeR = parts
	res = list()
	res.append(bound)	#0-32
	res.append(browL[:5]) #33- 37
	res.append(browR[:5])	#38-42
	res.append(nose[:4])	#43,44,45,46
	res.append(nose[6:6+5])	#47,48,49,50,51
	res.append(eyeL[:2])	#52,53
	res.append(eyeL[3:3+2])	#54,55
	res.append(eyeL[6])		#56
	res.append(eyeL[8])		#57
	res.append(eyeR[:2])	#58,59
	res.append(eyeR[3:3+2])	#60,61
	res.append(eyeR[6])		#62
	res.append(eyeR[8])		#63
	res.append(browL[6:6+4])#64,65,66,67
	res.append(browR[5:5+4])#68,69,70,71
	res.append(eyeL[2])	#72
	res.append(eyeL[7])	#73
	res.append((eyeL[2]+eyeL[7])/2)	#74 useless
	res.append(eyeR[2])	#75
	res.append(eyeR[7])	#76
	res.append((eyeR[2]+eyeR[7])/2)	#77 useless
	res.append((nose[0]+eyeL[4])/2)	#78
	res.append((nose[0]+eyeR[0])/2)	#79
	res.append(nose[4])	#80
	res.append(nose[12])	#81
	res.append(nose[5])	#82
	res.append(nose[11])	#83
	res.append(lipU[:7])	#84,85,86,87,88,89,90
	res.append(lipD[10])	#91
	res.append(lipD[9])	#92
	res.append(lipD[8])	#93
	res.append(lipD[7])	#94
	res.append(lipD[6])	#95
	res.append(lipU[7:7+5]) #96,97,98,99,100
	res.append(lipD[3])	#101
	res.append(lipD[2])	#102
	res.append(lipD[1])	#103
	res.append((eyeL[2]+eyeL[7])/2)	#104
	res.append((eyeR[2]+eyeR[7])/2)	#105
	res.append(gazeL)
	res.append(gazeR)
	res = np.vstack(res)
	return res





================================================
FILE: fusion/points2heatmap.py
================================================
import numpy as np
import cv2
import os
import math

def curve_interp(points, heatmapSize=256, sigma=3):
	sigma = max(1,(sigma // 2)*2 + 1)
	img = np.zeros((heatmapSize, heatmapSize), np.uint8)
	for ii in range(1, points.shape[0]):
		cv2.line(img, tuple(points[ii-1].astype(np.int32)),tuple(points[ii].astype(np.int32)), (255), sigma)
	img = cv2.GaussianBlur(img, (sigma, sigma), sigma)
	return img.astype(np.float64)/255.0

def curve_fill(points, heatmapSize=256, sigma=3, erode=False):
	sigma = max(1,(sigma // 2)*2 + 1)
	points = points.astype(np.int32)
	canvas = np.zeros([heatmapSize, heatmapSize])
	cv2.fillPoly(canvas,np.array([points]),255)
	'''
	kernel = np.ones((sigma, sigma), np.uint8)
	if erode:
		erode_kernel = np.ones((int(0.5*sigma), int(0.5*sigma)), np.uint8)
		canvas = cv2.erode(canvas, erode_kernel)
	else:
		canvas = cv2.dilate(canvas, kernel)
	'''
	canvas = cv2.GaussianBlur(canvas, (sigma,sigma), sigma)
	return canvas.astype(np.float64)/255.0

def curves2heatmap(curves,heatmapSize=256,sigma=3,flag='line'):
	#-----------------------input--------------------------
	# curves [list of ndarray] : points coordinate in [0,heatmapSize]
	# heatmapSize[int]: the size of the generated heatmap 
	# sigma[float]: Boundary vagueness
	# flag[string]: 'line' or 'segment'
	#-----------------------output----------------
	# heatmap[ndarray,float64]: [D,D,num of curves],range in (0.0,1.0)
	#=============================================
	heatmap = np.zeros((heatmapSize, heatmapSize, len(curves)),np.float64)
	for i in range(len(curves)):
		if flag == 'line':
			heatmap[:, :, i] = curve_interp(curves[i], heatmapSize, sigma)
		else:
			heatmap[:, :, i] = curve_fill(curves[i], heatmapSize, sigma)
	return heatmap

def curves2segments(curves,heatmapSize=256,sigma=3):
	#res[ndarray]: range in (0,1) [Channel,Size,Size]
	face = curve_fill(np.vstack([curves[0],curves[2][::-1],curves[1][::-1]]),heatmapSize,sigma)
	browL = curve_fill(np.vstack([curves[1],curves[13][::-1]]),heatmapSize,sigma)
	browR = curve_fill(np.vstack([curves[2],curves[14][::-1]]),heatmapSize,sigma)
	eyeL = curve_fill(np.vstack([curves[5],curves[6]]),heatmapSize,sigma)
	eyeR = curve_fill(np.vstack([curves[7],curves[8]]),heatmapSize,sigma)
	gazeL = curve_fill(curves[15],heatmapSize,sigma)
	gazeR = curve_fill(curves[16],heatmapSize,sigma)

	#intersect eye and gaze
	gazeL = gazeL * eyeL
	gazeR = gazeR * eyeR
	#2 to 1
	eye = np.max([eyeL,eyeR],axis=0)
	gaze = np.max([gazeL,gazeR],axis=0)
	brow = np.max([browL,browR],axis=0)

	nose = curve_fill(np.vstack([curves[3][0:1],curves[4]]),heatmapSize,sigma)
	lipU= curve_fill(np.vstack([curves[9],curves[10][::-1]]),heatmapSize,sigma)
	lipD= curve_fill(np.vstack([curves[11],curves[12][::-1]]),heatmapSize,sigma)
	tooth = curve_fill(np.vstack([curves[10],curves[11][::-1]]),heatmapSize,sigma)
	return np.stack([face,brow,eye,gaze,nose,lipU,lipD,tooth])

def curves2gaze(curves,heatmapSize=256,sigma=3):
	eyeL = curve_fill(np.vstack([curves[5],curves[6]]),heatmapSize,sigma)
	eyeR = curve_fill(np.vstack([curves[7],curves[8]]),heatmapSize,sigma)
	gazeL = curve_fill(curves[15],heatmapSize,sigma)
	gazeR = curve_fill(curves[16],heatmapSize,sigma)
	#intersect eye and gaze
	gazeL = gazeL * eyeL
	gazeR = gazeR * eyeR
	return np.stack([gazeL,gazeR])
	
def curves2parts(curves):
	bound = curves[0]
	browL = np.vstack([curves[1],curves[13]])
	browR = np.vstack([curves[2],curves[14]])
	eyeL = np.vstack([curves[5],curves[6]])
	eyeR = np.vstack([curves[7],curves[8]])
	gazeL = curves[15]
	gazeR = curves[16]
	nose = np.vstack([curves[3],curves[4]])
	lipU= np.vstack([curves[9],curves[10]])
	lipD= np.vstack([curves[11],curves[12]])
	return [bound,browL,browR,eyeL,eyeR,nose,lipU,lipD,gazeL,gazeR]
	


def points2curves(points, heatmapSize=256,  sigma=1, heatmap_num=17):
	#-----------------------input--------------------------
	# points[ndarray]: [...,[x,y],...],range in (0.0,1.0)
	# heatmapSize[int]: the size of the generated heatmap 
	# heatmapNum: number of heatmap channels
	#-----------------------output----------------
	# curves [list of ndarray] : points coordinate in [0,heatmapSize]
	# =====================================================
	# resize points (0-1) to heatmapSize(0-D)
	for i in range(points.shape[0]):
		points[i] *= (float(heatmapSize))
	# curve define
	curves = [0]*heatmap_num
	curves[0] = np.zeros((33, 2))  # contour
	curves[1] = np.zeros((5, 2))  # left top eyebrow
	curves[2] = np.zeros((5, 2))  # right top eyebrow
	curves[3] = np.zeros((4, 2))  # nose bridge
	curves[4] = np.zeros((9, 2))  # nose tip
	curves[5] = np.zeros((5, 2))  # left top eye
	curves[6] = np.zeros((5, 2))  # left bottom eye
	curves[7] = np.zeros((5, 2))  # right top eye
	curves[8] = np.zeros((5, 2))  # right bottom eye
	curves[9] = np.zeros((7, 2))  # up up lip
	curves[10] = np.zeros((5, 2))  # up bottom lip
	curves[11] = np.zeros((5, 2))  # bottom up lip
	curves[12] = np.zeros((7, 2))  # bottom bottom lip
	curves[13] = np.zeros((5, 2))  # left bottom eyebrow
	curves[14] = np.zeros((5, 2))  # right bottom eyebrow
	if heatmap_num == 17:
		curves[15] = np.zeros((20, 2))  # left gaze
		curves[16] = np.zeros((20, 2))  # right gaze
	# assignment process
	# contour
	for i in range(33):
		curves[0][i] = points[i]
	for i in range(5):
		# left top eyebrow
		curves[1][i] = points[i+33]
		# right top eyebrow
		curves[2][i] = points[i+38]
	# nose bridge
	for i in range(4):
		curves[3][i] = points[i+43]
	# nose tip
	curves[4][0] = points[80]
	curves[4][1] = points[82]
	for i in range(5):
		curves[4][i+2] = points[i+47]
	curves[4][7] = points[83]
	curves[4][8] = points[81]
	# left top eye
	curves[5][0] = points[52]
	curves[5][1] = points[53]
	curves[5][2] = points[72]
	curves[5][3] = points[54]
	curves[5][4] = points[55]
	# left bottom eye
	curves[6][0] = points[55]
	curves[6][1] = points[56]
	curves[6][2] = points[73]
	curves[6][3] = points[57]
	curves[6][4] = points[52]
	# right top eye
	curves[7][0] = points[58]
	curves[7][1] = points[59]
	curves[7][2] = points[75]
	curves[7][3] = points[60]
	curves[7][4] = points[61]
	# right bottom eye
	curves[8][0] = points[61]
	curves[8][1] = points[62]
	curves[8][2] = points[76]
	curves[8][3] = points[63]
	curves[8][4] = points[58]
	# up up lip
	for i in range(7):
		curves[9][i] = points[i+84]
	# up bottom lip
	for i in range(5):
		curves[10][i] = points[i+96]
	# bottom up lip
	curves[11][0] = points[96]
	curves[11][1] = points[103]
	curves[11][2] = points[102]
	curves[11][3] = points[101]
	curves[11][4] = points[100]
	# bottom bottom lip
	curves[12][0] = points[84]
	curves[12][1] = points[95]
	curves[12][2] = points[94]
	curves[12][3] = points[93]
	curves[12][4] = points[92]
	curves[12][5] = points[91]
	curves[12][6] = points[90]
	# left bottom eyebrow
	curves[13][0] = points[33]
	curves[13][1] = points[64]
	curves[13][2] = points[65]
	curves[13][3] = points[66]
	curves[13][4] = points[67]
	# right bottom eyebrow
	curves[14][0] = points[68]
	curves[14][1] = points[69]
	curves[14][2] = points[70]
	curves[14][3] = points[71]
	curves[14][4] = points[42]
	if heatmap_num == 17:
		# left gaze
		for i in range(20):
			curves[15][i] = points[106+i]
		# right gaze
		for i in range(20):
			curves[16][i] = points[106+20+i]

	return curves,None

def distance(p1, p2):
	return math.sqrt((p1[0]-p2[0])*(p1[0]-p2[0])+(p1[1]-p2[1])*(p1[1]-p2[1]))

def curve_fitting(points, heatmap_size, sigma):
	curve_tmp = curve_interp(points, heatmap_size, sigma)
	return curve_tmp


if __name__ == '__main__':
	import matplotlib.pyplot as plt
	res = list()
	path = '../2019CVPR_reconstruct/data/celebHQ/lms.txt'
	head = '../2019CVPR_reconstruct/data/celebHQ/align_384'
	with open(path, 'r') as fin:
		data = fin.read().splitlines()
		N = len(data)//2
		for i in range(N):
			imgPath = os.path.join(head, data[2*i+0])
			landmarks = list(map(float, data[2*i+1].split()))
			res.append((imgPath, landmarks))
	for path,landmark in res:
		points = (np.array(landmark).reshape(-1,2).astype(np.float32)-64)/256.0
		curves,_ = points2curves(points)
		segments = curves2segments(curves)
		img= np.sum(segments,axis=0)

		plt.figure()
		plt.imshow(img)
		for i in range(len(points)):
			plt.plot(points[i][0], points[i][1], '.', markersize=1)
			if i<=106:
				plt.text(points[i][0], (points[i][1]), str(i), fontsize=5)
			else:
				plt.text(points[i][0], (points[i][1]), str(i), fontsize=3)
		plt.show()
	


================================================
FILE: fusion/test.py
================================================
import multiprocessing
import time
from tqdm import *

values = [1, 2, 3, 4]  # renamed from `list` to avoid shadowing the builtin

def func(i):
    msg = "hello %d" % values[i]
    print("msg:", msg)
    time.sleep(3)
    print("end")


if __name__ == "__main__":
    pool = multiprocessing.Pool(processes=4)
    for i in tqdm(range(4)):
        # apply_async keeps the number of running processes at `processes`;
        # as soon as one worker finishes, the next queued task is started
        pool.apply_async(func, (i,))

    print("Mark~ Mark~ Mark~~~~~~~~~~~~~~~~~~~~~~")
    pool.close()
    pool.join()
    print("Sub-process(es) done.")

================================================
FILE: fusion/warper.py
================================================
import numpy as np
import scipy.spatial as spatial
from builtins import range
import cv2


def warping(img, src_bound, dst_bound, size=(256, 256)):
    d = 254
    bound = np.array([0, 0, 0, d, d, 0, d, d]).reshape(-1, 2)
    src_bound = np.vstack([src_bound, bound]).astype(np.int32)
    dst_bound = np.vstack([dst_bound, bound]).astype(np.int32)
    src_bound[src_bound > d] = d
    dst_bound[dst_bound > d] = d
    src_bound = src_bound.astype(np.int32)
    dst_bound = dst_bound.astype(np.int32)
    res = warp_image(img, src_bound, dst_bound, size)
    return res


def bilinear_interpolate(img, coords):
    """ Interpolates over every image channel
    http://en.wikipedia.org/wiki/Bilinear_interpolation

    :param img: max 3 channel image
    :param coords: 2 x _m_ array. 1st row = xcoords, 2nd row = ycoords
    :returns: array of interpolated pixels with same shape as coords
    """
    int_coords = np.int32(coords)
    x0, y0 = int_coords
    dx, dy = coords - int_coords

    # 4 Neighour pixels
    q11 = img[y0, x0]
    q21 = img[y0, x0+1]
    q12 = img[y0+1, x0]
    q22 = img[y0+1, x0+1]

    btm = q21.T * dx + q11.T * (1 - dx)
    top = q22.T * dx + q12.T * (1 - dx)
    inter_pixel = top * dy + btm * (1 - dy)

    return inter_pixel.T


def grid_coordinates(points):
    """ x,y grid coordinates within the ROI of supplied points

    :param points: points to generate grid coordinates
    :returns: array of (x, y) coordinates
    """
    xmin = np.min(points[:, 0])
    xmax = np.max(points[:, 0]) + 1
    ymin = np.min(points[:, 1])
    ymax = np.max(points[:, 1]) + 1
    return np.asarray([(x, y) for y in range(ymin, ymax)
                       for x in range(xmin, xmax)], np.uint32)


def process_warp(src_img, result_img, tri_affines, dst_points, delaunay):
    """
    Warp each triangle from the src_image only within the
    ROI of the destination image (points in dst_points).
    """
    roi_coords = grid_coordinates(dst_points)
    # indices to vertices. -1 if pixel is not in any triangle
    roi_tri_indices = delaunay.find_simplex(roi_coords)

    for simplex_index in range(len(delaunay.simplices)):
        coords = roi_coords[roi_tri_indices == simplex_index]
        num_coords = len(coords)
        out_coords = np.dot(tri_affines[simplex_index],
                            np.vstack((coords.T, np.ones(num_coords))))
        x, y = coords.T
        result_img[y, x] = bilinear_interpolate(src_img, out_coords)

    return None


def triangular_affine_matrices(vertices, src_points, dest_points):
    """
    Calculate the affine transformation matrix for each
    triangle (x,y) vertex from dest_points to src_points

    :param vertices: array of triplet indices to corners of triangle
    :param src_points: array of [x, y] points to landmarks for source image
    :param dest_points: array of [x, y] points to landmarks for destination image
    :returns: generator yielding a 2 x 3 affine transformation matrix per triangle
    """
    ones = [1, 1, 1]
    for tri_indices in vertices:
        src_tri = np.vstack((src_points[tri_indices, :].T, ones))
        dst_tri = np.vstack((dest_points[tri_indices, :].T, ones))
        mat = np.dot(src_tri, np.linalg.inv(dst_tri))[:2, :]
        yield mat
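As a sanity check (toy coordinates, not repository data), a matrix built the same way as in `triangular_affine_matrices` must map each destination vertex exactly onto its source counterpart:

```python
import numpy as np

# Hypothetical triangle pair in homogeneous coordinates.
src = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])
dst = np.array([[5.0, 5.0], [25.0, 5.0], [5.0, 25.0]])
ones = [1, 1, 1]
src_tri = np.vstack((src.T, ones))
dst_tri = np.vstack((dst.T, ones))
mat = np.dot(src_tri, np.linalg.inv(dst_tri))[:2, :]
mapped = np.dot(mat, dst_tri)   # dest vertices mapped back to the source plane
```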


def warp_image(src_img, src_points, dest_points, dest_shape, dtype=np.uint8):
    # Resultant image will not have an alpha channel
    num_chans = 3
    src_img = src_img[:, :, :3]

    rows, cols = dest_shape[:2]
    result_img = np.zeros((rows, cols, num_chans), dtype)

    delaunay = spatial.Delaunay(dest_points)
    tri_affines = np.asarray(list(triangular_affine_matrices(
        delaunay.simplices, src_points, dest_points)))

    process_warp(src_img, result_img, tri_affines, dest_points, delaunay)

    return result_img


if __name__ == "__main__":
    pass


================================================
FILE: loader/__init__.py
================================================


================================================
FILE: loader/dataset_basic.py
================================================
# coding:utf-8
import sys
from utils.points2heatmap import curves2segments, points2curves
from utils import warper
import os
import numpy as np
import torch
import torch.utils.data
import cv2
from tqdm import *


class DatasetBasic(torch.utils.data.Dataset):
	def __init__(self, imgSize=256):
		# imgSize[int]
		self.boundList = None
		self.appearList = None
		self.imgSize = imgSize
		self.sigma = 3

	def __len__(self):
		return -1

	def shape(self):
		return self.__len__()

	def loadtxt(self, path, head=''):
		# path[string] : path to lms.txt
		#		format:	subpath of img
		#				landmarks [106*2 + 20*2] or 40*2
		# head[string] : head of subpath
		res = list()
		with open(path, 'r') as fin:
			data = fin.read().splitlines()
			N = len(data)//2
			for i in tqdm(range(N)):
				imgPath = os.path.join(head, data[2*i+0])
				landmarks = list(map(float, data[2*i+1].split()))
				res.append((imgPath, landmarks))
		return res

	def loadtxtList(self, pathList, head):
		res = list()
		for path in pathList:
			res += self.loadtxt(path, head)
		return res

	def warp(self, img, srcPt, dstPt):
		# img[ndarray]: shape = (3,D,D)
		# srcPt[ndarray]: shape = (K,2) ,K key points
		# dstPt[ndarray]: shape = (K,2) ,K key points
		return warper.warping(img, srcPt.reshape(-1, 2), dstPt.reshape(-1, 2), (self.imgSize, self.imgSize))

	def np2tensor(self, img, scale=1/255.0):
		#========input=======
		#img[ndarray][H,W,C] values in (0.0, 1/scale)
		#========output======
		#img[tensor][C,H,W] values in (-1.0, 1.0)
		img = img.transpose((2, 0, 1))
		img = torch.from_numpy(img).float() * scale
		img = (img - 0.5) / 0.5
		return img
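A numpy-only sketch of the same normalisation (torch omitted for brevity): HWC uint8 in [0, 255] becomes CHW float in [-1, 1], so black maps to -1 and white to +1.

```python
import numpy as np

# One pixel with 3 channels at black, mid-grey, and white (HWC layout).
img = np.array([[[0, 128, 255]]], dtype=np.uint8)
chw = img.transpose(2, 0, 1).astype(np.float32) * (1 / 255.0)  # HWC -> CHW, [0, 1]
chw = (chw - 0.5) / 0.5                                        # [0, 1] -> [-1, 1]
```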

	def points2heatmap(self, landmarks, mapSize, sigma, landmarkSize=255.0, heatmap_num=17):
		# landmarks[ndarray]	: shape = (K,2), K = 106 + 20*2
		# landmarkSize [float]
		# mapSize[int] 		: output size
		# sigma[float] 		: gaussian sigma
		# _________Return__________
		# heatmap[tensor]: [C,H,W]
		# curve[list] :	list[list]
		if landmarks.max() > 1:
			landmarks /= landmarkSize
			landmarks[landmarks > 1] = 1
		curves, boundary = points2curves(landmarks, mapSize, sigma, heatmap_num)
		# [C,H,W] (0.0,1.0)
		heatmap = curves2segments(curves)
		# np 2 tensor
		heatmap = torch.from_numpy(heatmap).float()

		# boundary heatmap
		boundary = boundary.transpose([2, 0, 1])
		boundary = torch.from_numpy(boundary).float()

		return heatmap, curves, boundary



	'''
	def getRois(self, curve, sigma, onlyMask=False):
		# curves[list[list]] :
		# sigma[float]		: gaussian sigma
		# onlyMask[bool] 	: for train ,just need mask of face
		bound = np.vstack([curve[1], curve[2], curve[0]])

		mask_bound = genROI(bound, D=5, sigma=5)
		if onlyMask:
			return None, mask_bound
		browL = np.vstack([curve[1], curve[13]])
		browR = np.vstack([curve[2], curve[14]])
		eyeL = np.vstack([curve[5], curve[6]])
		eyeR = np.vstack([curve[7], curve[8]])
		nose = np.vstack([curve[3], curve[4]])
		teeth = np.vstack([curve[10], curve[11]])
		mouth = np.vstack([curve[9], curve[12]])

		mask_browL = genROI(browL)
		mask_browR = genROI(browR)
		mask_eyeL = genROI(eyeL, )
		mask_eyeR = genROI(eyeR)
		mask_nose = genROI(nose)
		mask_mouth = genROI(mouth)
		mask_teeth = genROI(teeth)
		mask_skin = (1-mask_eyeL)*(1-mask_eyeR)*(1-mask_nose)*(1-mask_browL)*(1-mask_browR)\
			* (1-mask_teeth)*(1-mask_mouth)

		if len(curve) == 17:
			# gaze
			gazeL = curve[15]
			gazeR = curve[16]
			mask_gazeL = genROI(gazeL,erode=True)
			mask_gazeR = genROI(gazeR,erode=True)
			return {'browL': mask_browL, 'browR': mask_browR, 'eyeL': mask_eyeL, 'eyeR': mask_eyeR, 'nose': mask_nose,
					'mouth': mask_mouth, 'teeth': mask_teeth, 'skin': mask_skin, 'gazeL': mask_gazeL, 'gazeR': mask_gazeR}, mask_bound
		else:
			return {'browL': mask_browL, 'browR': mask_browR, 'eyeL': mask_eyeL, 'eyeR': mask_eyeR, 'nose': mask_nose,
					'mouth': mask_mouth, 'teeth': mask_teeth, 'skin': mask_skin, }, mask_bound

	def fix_gaze(self, eye_roi, gaze_roi):
		intersect = eye_roi * gaze_roi
		return intersect
	'''

	def gammaTrans(self, img, gamma):
		gamma_table = [np.power(x/255.0, gamma)*255.0 for x in range(256)]
		gamma_table = np.round(np.array(gamma_table)).astype(np.uint8)
		return cv2.LUT(img, gamma_table)
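`gammaTrans` builds a 256-entry lookup table and applies it with `cv2.LUT`; the table itself can be sketched without cv2 (a standalone example, not repository code). With gamma = 0.5 midtones are brightened while 0 and 255 are fixed points.

```python
import numpy as np

# Same 256-entry gamma lookup table as gammaTrans, built in one vectorised step.
gamma = 0.5
table = np.round(np.power(np.arange(256) / 255.0, gamma) * 255.0).astype(np.uint8)
```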

	def __getitem__(self, index):
		pass


================================================
FILE: loader/dataset_loader_demo.py
================================================
from loader.dataset_basic import *
import random
import numpy as np
import copy
import torch as th
from utils.affineFace import affineface

class DatasetLoaderDemo(DatasetBasic):
	def __init__(self, imgSize=256, gaze=True):
		super(DatasetLoaderDemo, self).__init__(imgSize)
		self.boundList = None
		self.appearList = None
		self.rule = 'sequence'
		self.indexAppear = 0

	def loadBounds(self, pathList, head):
		self.boundList = self.loadtxtList(pathList, head)

	def loadAppears(self, pathList, head):
		self.appearList = self.loadtxtList(pathList, head)

	def setAppearRule(self, flag='random'):
		# flag[string]: random / similar / sequence
		# call this function after load data
		if self.appearList is None:
			print('please call setAppearRule after loading data!')
			return
		if flag not in ('random', 'similar', 'sequence'):
			print('rule must be one of: random / similar / sequence')
		else:
			self.rule = flag
			if flag == 'random':
				self.indexAppear = random.randint(0, len(self.appearList)-1)
			else:
				pass

	def findSimilar(self, pt_dst):
		minVal = 1e5
		res = 0
		for index in range(len(self.appearList)):
			_, pt = self.appearList[index]
			pt = (np.array(pt) - 64).reshape(-1, 2)
			diff = np.linalg.norm(pt[:106] - pt_dst[:106])
			if diff < minVal:
				res = index
				minVal = diff
		return res
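The search above can also be done in one vectorised step (a sketch on made-up data, assuming the candidate landmarks are stacked into a single array): compute all L2 distances at once and take the argmin.

```python
import numpy as np

# 5 hypothetical appearance entries, each with 106 (x, y) landmarks.
rng = np.random.RandomState(0)
candidates = rng.rand(5, 106, 2)
target = candidates[3] + 0.001     # target constructed to be closest to entry 3
dists = np.linalg.norm((candidates - target).reshape(5, -1), axis=1)
best = int(np.argmin(dists))
```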

	def adjustPose(self, img_src, pt_src, pt_dst):
		img_align, pt_align = affineface(img_src, pt_src, pt_dst)
		return img_align, pt_align

	def add_nose_bridge(self, boundary, heatmap):
		# add nose bridge boundary and dilate
		nose_bridge = copy.copy(boundary[3:4])
		kernel = np.ones((4, 4), np.uint8)
		nose_bridge = 255 * torch.from_numpy(cv2.dilate(nose_bridge.squeeze(0).numpy(), kernel)).unsqueeze(0).float()
		heatmap = torch.cat((heatmap, nose_bridge), 0)
		return heatmap

	def __getitem__(self, index):
		# load dst
		path, pt = self.boundList[index]
		img_dst = cv2.imread(path, 1)[64:64+256, 64:64+256]
		pt_dst = (np.array(pt) - 64).reshape(-1, 2)
		# dst
		heatmap_dst, curves_dst, boundary_dst = self.points2heatmap(pt_dst, self.imgSize, sigma=self.sigma)
		heatmap_dst = self.add_nose_bridge(boundary_dst, heatmap_dst)  # add nose bridge boundary and dilate
		weighted_mask_dst = heatmap_dst[0:1] + 2 * heatmap_dst[1:2] + 3 * heatmap_dst[2:3] + 4 * heatmap_dst[3:4] + 2 * heatmap_dst[4:5] + \
							3 * heatmap_dst[5:6] + 3 * heatmap_dst[6:7] + 2 * heatmap_dst[7:8]  +  heatmap_dst[8:]
		#select reference
		if self.rule == 'random':
			index = self.indexAppear
		elif self.rule == 'similar':
			index = self.findSimilar(pt_dst)
		elif self.rule == 'sequence':
			index = min(index, len(self.appearList)-1)

		# load src
		path, pt = self.appearList[index]
		img_src = cv2.imread(path, 1)[64:64+256, 64:64+256]
		img_src_np = img_src
		img_src = self.gammaTrans(img_src, 0.5)
		pt_src = (np.array(pt) - 64).reshape(-1, 2)
		pt_src_np = pt_src

		# align pose src 2 dst
		img_src,pt_src = self.adjustPose(img_src,pt_src/256.0,pt_dst/256.0)
		img_src = self.warp(img_src, pt_src, np.vstack([pt_dst[:33], pt_src[33:]]))

		# src
		heatmap_src, curves_src, boundary_src = self.points2heatmap(pt_src, self.imgSize, sigma=self.sigma)

		#np 2 tensor scale = [-1,1]
		img_src = self.np2tensor(img_src)
		img_dst = self.np2tensor(img_dst)

		return {'img_src': img_src, 'face_mask_src': heatmap_src[0:1],
				'img_dst': img_dst, 'face_mask_dst': heatmap_dst[0:1], 'seg_dst': heatmap_dst, 'weighted_mask_dst': weighted_mask_dst,
				'pt_src': pt_src_np, 'pt_dst': pt_dst, 'img_src_np': img_src_np}

	def __len__(self):
		return len(self.boundList)


================================================
FILE: loader/dataset_loader_train.py
================================================
from loader.dataset_basic import *
from utils.transforms import initAlignTransfer
#from utils.transforms import shakeCurve
import random
import numpy as np
import torch as th
import copy

class DatasetLoaderTrain(DatasetBasic):
	def __init__(self, imgSize=256, gaze=True):
		super(DatasetLoaderTrain, self).__init__(imgSize)
		self.transformAlign = initAlignTransfer(self.imgSize,mirror=False, gaze=gaze)
		self.dataList = None
		self.isTransform = True
		self.SampleCurveType = 'Bound'
		if gaze:
			self.heatmap_num = 17
		else:
			self.heatmap_num = 15
	
	def setSampleCurve(self, flag='Bound'):
		self.SampleCurveType = flag

	def transform(self, img, pt):
		pack = self.transformAlign(pt, img)
		pt = pack[0]
		img = pack[1]
		return img, pt

	def loaddata(self, pathList, head):
		# pathList[list] : [path1,path2,...]
		# head[string]:	head of subpath
		self.dataList = self.loadtxtList(pathList, head)

	def sampleCurve(self, img, pt, flag='Bound'):
		# if flag == 'None':
		# 	return img_src.copy(), pt_src.copy()
		if flag == 'Bound':
			_, pt_sample = random.sample(self.dataList, 1)[0]
			pt_sample = (np.array(pt_sample) - 64).reshape(-1, 2)
			pt_ = pt.copy()
			pt_[:33] = pt_sample[:33]
			img_warped = self.warp(img, pt, pt_)
			return img_warped, pt_
		# if flag == 'Shake':
		# 	pt_ = shakeCurve(pt)
		# 	img_warped = self.warp(img, pt, pt_)
		# 	return img_warped, pt_
		return None

	def add_nose_bridge(self, boundary, heatmap):
		# add nose bridge boundary and dilate
		nose_bridge = copy.copy(boundary[3:4])
		kernel = np.ones((4, 4), np.uint8)
		nose_bridge = 255 * torch.from_numpy(cv2.dilate(nose_bridge.squeeze(0).numpy(), kernel)).unsqueeze(0).float()  # todo
		heatmap = torch.cat((heatmap, nose_bridge), 0)
		return heatmap

	def __getitem__(self, index):
		path, pt = self.dataList[index]
		img_src = cv2.imread(path, 1)[64:64+256, 64:64+256] # [256, 256, 3]
		pt_src = (np.array(pt) - 64).reshape(-1, 2) # [146,2]
		img_src = self.gammaTrans(img_src, 0.5)
		# sample strategy
		img_dst, pt_dst = self.sampleCurve(img_src, pt_src, flag=self.SampleCurveType)
		# img_dst = copy.copy(img_src)
		# pt_dst = copy.copy(pt_src)
		# data augmentation
		if self.isTransform:
			img_dst, pt_dst = self.transform(img_dst, pt_dst)
		# src
		heatmap_src, curves_src, boundary_src = self.points2heatmap(pt_src, self.imgSize, sigma=self.sigma) # heatmap_src: tensor [8, 256, 256]; curves_src: list of 17 arrays
		heatmap_src = self.add_nose_bridge(boundary_src, heatmap_src)
		weighted_mask_src = heatmap_src[0:1] + 2 * heatmap_src[1:2] + 3 * heatmap_src[2:3] + 4 * heatmap_src[3:4] + 2 * heatmap_src[4:5] + \
							3 * heatmap_src[5:6] + 3 * heatmap_src[6:7] + 2 * heatmap_src[7:8] + heatmap_src[8:]
		# dst
		heatmap_dst, curves_dst, boundary_dst = self.points2heatmap(pt_dst, self.imgSize, sigma=self.sigma)
		heatmap_dst = self.add_nose_bridge(boundary_dst, heatmap_dst)  # add nose bridge boundary and dilate
		weighted_mask_dst = heatmap_dst[0:1] + 2 * heatmap_dst[1:2] + 3 * heatmap_dst[2:3] + 4 * heatmap_dst[3:4] + 2 * heatmap_dst[4:5] + \
							3 * heatmap_dst[5:6] + 3 * heatmap_dst[6:7] + 2 * heatmap_dst[7:8] + heatmap_dst[8:]

		img_src = self.np2tensor(img_src)
		img_dst = self.np2tensor(img_dst)
		return {'img_src':img_src,'seg_src':heatmap_src,'face_mask_src':heatmap_src[0:1], 'boundary_dst': boundary_dst,
				'img_dst':img_dst,'seg_dst':heatmap_dst,'face_mask_dst':heatmap_dst[0:1], 'weighted_mask_dst': weighted_mask_dst}

	def __len__(self):
		return len(self.dataList)


================================================
FILE: model/base_model.py
================================================
import os
import torch
from collections import OrderedDict
import net.base_net as base_net
import shutil
from tensorboardX import SummaryWriter
import json
import random
import logging
import datetime


class BaseModel():

    # modify parser to add command line options,
    # and also change the default values if needed
    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def name(self):
        return 'BaseModel'

    def initialize(self, opt):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = ('train' == opt.phase)
        self.device = torch.device('cuda:{}'.format(
            self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        # self.save_dir = opt.save_dir
        self.loss_names = []
        self.visual_names = []
        self.image_paths = []
        self.train_model_name = []

        # if os.path.exists(self.save_dir):
        #     shutil.rmtree(self.save_dir)
        # os.makedirs(self.save_dir)

    def set_input(self, input):
        pass

    def forward(self):
        pass

    def set_logger(self, opt):
        if self.isTrain:
            run_id = random.randint(1,100000)
            self.logdir = os.path.join(opt.save_dir,str(run_id))
            self.writer = SummaryWriter(self.logdir)
            self.logger = self.get_logger(self.logdir)
            self.logger.info('Let the games begin')
            self.logger.info('save dir: runs/{}'.format(run_id))
            print('log dir : ', self.logdir)
        else:
            self.logdir = os.path.join(opt.save_dir,'test_res')
            self.writer = SummaryWriter(self.logdir)

    def get_logger(self, logdir):
        logger = logging.getLogger('myLogger')
        ts = str(datetime.datetime.now()).split('.')[0].replace(" ", "_")
        ts = ts.replace(":", "_").replace("-", "_")
        file_path = os.path.join(logdir, 'run_{}.log'.format(ts))
        hdlr = logging.FileHandler(file_path)
        formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
        hdlr.setFormatter(formatter)
        logger.addHandler(hdlr)
        logger.setLevel(logging.INFO)
        return logger

    def save_config(self, config):
        param_path = os.path.join(self.logdir, "params.json")
        print("[*] PARAM path: %s" % param_path)

        with open(param_path, 'w') as fp:
            json.dump(config.__dict__, fp, indent=4, sort_keys=True)


    # load and print networks; create schedulers
    def setup(self, opt, parser=None):
        if self.isTrain:
            self.schedulers = [base_net.get_scheduler(
                optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt.load_path:
            # load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
            load_suffix = '{}/{}_net'.format(opt.load_path, opt.load_model_iter)
            self.load_networks_all(load_suffix)
            print('load {} successful!'.format(load_suffix))
        self.print_networks(opt.verbose)

    def load_networks_all(self, prefix):
        for name in self.train_model_name:
            if 'netD' in name:
                continue
            net = getattr(self, name)
            load_filename = '{}_{}.pth'.format(prefix, name)

            self.load_networks(net, load_filename)

    # load model
    def load_networks(self, model, path):
        if isinstance(model, torch.nn.DataParallel):
            model = model.module

        pretrainDict = torch.load(path, map_location=self.device)
        modelDict = model.state_dict()
        for kk, vv in pretrainDict.items():
            kk = kk.replace('module.', '')
            if kk in modelDict:
                modelDict[kk].copy_(vv)
            else:
                print('{} not in modelDict'.format(kk))
        # model.load_state_dict(pretrainDict)
        # print(modelDict.keys())

    # make models eval mode during test time
    def eval(self):
        for name in self.train_model_name:
            if isinstance(name, str):
                net = getattr(self, name)
                net.eval()

    # used in test time, wrapping `forward` in no_grad() so we don't save
    # intermediate steps for backprop
    def test(self):
        with torch.no_grad():
            self.forward()

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def optimize_parameters(self):
        pass

    # update learning rate (called once every epoch)
    def update_learning_rate(self):
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    # # return visualization images. train.py will display these images, and save the images to a html
    # def get_current_visuals(self):
    #     visual_ret = OrderedDict()
    #     for name in self.visual_names:
    #         if isinstance(name, str):
    #             visual_ret[name] = getattr(self, name)
    #     return visual_ret
    #
    # # return traning losses/errors. train.py will print out these errors as debugging information
    # def get_current_losses(self):
    #     errors_ret = OrderedDict()
    #     for name in self.loss_names:
    #         if isinstance(name, str):
    #             # float(...) works for both scalar tensor and float number
    #             errors_ret[name] = float(getattr(self, 'loss_' + name))
    #     return errors_ret

    # save models to the disk
    def save_networks(self, epoch):
        for name in self.train_model_name:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.logdir, save_filename)
                net = getattr(self, name)

                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    torch.save(net.module.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                    if len(self.gpu_ids) > 1:
                        net = torch.nn.DataParallel(net, self.opt.gpu_ids)
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    # print network information

    def print_networks(self, verbose):
        print('---------- Networks initialized -------------')
        for name in self.train_model_name:
            if isinstance(name, str):
                net = getattr(self, name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' %
                      (name, num_params / 1e6))
        print('-----------------------------------------------')

    # set requires_grad=False to avoid unnecessary computation

    def set_requires_grad(self, nets, requires_grad=False):
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad


================================================
FILE: model/spade_model.py
================================================
import sys
from model.base_model import BaseModel
import net.vgg_net as vgg_net
import net.generaotr_net as generator_net
import net.discriminator_net as discriminator_net
import net.appear_decoder_net as appDec
import net.appear_encoder_net as appEnc
import net.face_id_net as face_id_net
import torch
import torch.nn.functional as F
import itertools
from utils import metric

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SpadeModel(BaseModel):
    def __init__(self, opt):
        super(SpadeModel, self).initialize(opt)
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        # self.loss_names = ['vgg', 'id', 'reconstruct', 'gan']
        self.train_model_name = ['appEnc', 'appDnc', 'netG']
        # define appearance encoder, decoder
        self.appEnc = appEnc.defineAppEnc(
            3, norm='instance', init_type='normal', init_gain=0.02, gpu_ids=self.opt.gpu_ids, conv_k=3)
        self.appDnc = appDec.defineAppDec(
            3, norm='instance', init_type='normal', init_gain=0.02, gpu_ids=self.opt.gpu_ids)
        self.netG = generator_net.defineSPADEGenerator(opt.input_nc, opt.output_nc, 64, norm='instance',
                                                       init_type='normal', init_gain=0.02, gpu_ids=self.opt.gpu_ids,
                                                       latent_chl=1024, up_mode='convT')
        # -----pass-----
        if self.isTrain:
            # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
            self.pretrain_model_name = []
            if self.opt.loss_percept:
                self.pretrain_model_name.append('vgg')
            if self.opt.loss_faceID:
                self.pretrain_model_name.append('faceId')

            self.train_model_name += ['netD256', 'netD128', 'netD64']
            # load vgg and faceID networks
            if self.opt.loss_percept:
                self.vgg = vgg_net.defineVGG(
                    init_type='no', gpu_ids=self.opt.gpu_ids).eval()
            if self.opt.loss_faceID:
                self.faceId = face_id_net.defineFaceID(
                    input_nc=opt.output_nc, gpu_ids=self.opt.gpu_ids).eval()
                faceId_path = 'pretrainModel/id_200.pth'
                self.load_networks(self.faceId, faceId_path)

            use_sigmoid = opt.no_lsgan
            self.netD256 = discriminator_net.define_D(opt.output_nc, opt.ndf, opt.netD,
                                                      opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type,
                                                      opt.init_gain,
                                                      self.gpu_ids)
            self.netD128 = discriminator_net.define_D(opt.output_nc, opt.ndf, opt.netD,
                                                      2, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
                                                      self.gpu_ids)
            self.netD64 = discriminator_net.define_D(opt.output_nc, opt.ndf, opt.netD,
                                                     2, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
                                                     self.gpu_ids)

            # initialize optimizers
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                       itertools.chain(self.netG.parameters(),
                                                                       self.appEnc.parameters(),
                                                                       self.appDnc.parameters())),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(
                itertools.chain(self.netD256.parameters(), self.netD128.parameters(), self.netD64.parameters()),
                lr=0.5 * opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            self.loss_D = torch.tensor(0).float().to(
                device)  # initialize D_loss and gan_loss for G to 0 (first several epochs may not use gan)
            self.loss_gan = torch.tensor(0).float().to(device)

            # define loss functions
            self.criterionVGG = torch.nn.L1Loss().to(self.device)
            self.criterionId = torch.nn.L1Loss().to(self.device)
            self.criterionReconstruct = torch.nn.L1Loss().to(self.device)
            self.criterionPix = torch.nn.L1Loss().to(self.device)
            self.criterionGAN = discriminator_net.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)


    def set_input(self, input):
        self.seg_dst = input['seg_dst'].to(self.device)
        self.img_src = input['img_src'].to(self.device)
        self.srcMask = input['face_mask_src'].to(self.device)
        self.dstMask = input['face_mask_dst'].to(self.device)
        # apply ref & mask
        if 'img_dst' in input and self.isTrain:
            self.groundtruth = input['img_dst'].to(self.device)
            self.groundtruth = self.groundtruth * self.srcMask
        if 'weighted_mask_dst' in input:
            self.weightMask = input['weighted_mask_dst'].to(self.device)


    def forward(self):
        sample_z, kl_loss, _ = self.appEnc(self.img_src)  # [batch_size,1024,1,1]
        out16, out32, out64, out128, self.out256 = self.appDnc(sample_z)  # [1024, 16, 16,] [512, 32, 32], [256, 64, 64], [128, 128, 128], [3, 256, 256]
        self.fake_B = self.netG(self.seg_dst, sample_z, [out16, out32, out64, out128]) # [batch_size, 3, 256, 256]

        if self.isTrain:
            self.gt128 = F.max_pool2d(self.groundtruth, 3, stride=2)
            self.gt64 = F.max_pool2d(self.gt128, 3, stride=2)
            self.fake128 = F.max_pool2d(self.fake_B, 3, stride=2)
            self.fake64 = F.max_pool2d(self.fake128, 3, stride=2)
        return self.fake_B

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        self.loss_D = (loss_D_real + loss_D_fake) * 0.5
        return self.loss_D
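What `backward_D_basic` averages can be sketched numerically (assuming `criterionGAN` with `use_lsgan` reduces to MSE against a target of 1 for real and 0 for fake, the standard least-squares GAN formulation; the prediction values below are made up):

```python
import numpy as np

# Hypothetical discriminator outputs on real and fake batches.
pred_real = np.array([0.9, 1.1])
pred_fake = np.array([0.2, -0.1])
loss_real = np.mean((pred_real - 1.0) ** 2)   # push real predictions toward 1
loss_fake = np.mean((pred_fake - 0.0) ** 2)   # push fake predictions toward 0
loss_D = 0.5 * (loss_real + loss_fake)        # combined discriminator loss
```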

    def backward_D(self):
        self.lossD256 = self.backward_D_basic(self.netD256, self.groundtruth, self.fake_B)
        self.lossD128 = self.backward_D_basic(self.netD128, self.gt128, self.fake128)
        self.lossD64 = self.backward_D_basic(self.netD64, self.gt64, self.fake64)
        self.loss_D = self.lossD256 + self.lossD128 + self.lossD64
        self.loss_D.backward()

    def backward_G(self, epoch):
        # perceptual loss
        if self.opt.loss_percept:
            self.loss_vgg = self.opt.lambda_vgg * (
                self.vgg.module.perceptual_loss(self.fake_B, self.groundtruth, self.criterionVGG) if hasattr(
                    self.vgg, 'module') else self.vgg.perceptual_loss(self.fake_B, self.groundtruth, self.criterionVGG))

        # Identity loss
        if self.opt.loss_faceID:
            fake_B_id = self.fake_B
            gt_id = self.groundtruth
            fake_B_id = fake_B_id[:,:,28:228, 28:228]
            gt_id = gt_id[:,:,28:228, 28:228]
            self.loss_id = self.opt.lambda_id * (
                self.faceId.module.face_id_loss(fake_B_id, gt_id, self.criterionId) if hasattr(
                    self.faceId, 'module') else self.faceId.face_id_loss(fake_B_id, gt_id, self.criterionId))

        # GAN loss
        if epoch >= self.opt.gan_start_epoch:
                self.loss_gan = self.opt.lambda_gan * (
                    self.criterionGAN(self.netD256(self.fake_B), True) + self.criterionGAN(self.netD128(self.fake128), True) + self.criterionGAN(self.netD64(self.fake64), True))

        # AE reconstruction loss
        self.loss_reconstruct = self.opt.lambda_reconstruct * self.criterionReconstruct(self.out256, self.img_src)

        # pixel loss between gt and generated image
        fake_B_pix = self.fake_B * (0.5 + self.weightMask)
        gt_pix = self.groundtruth * (0.5 + self.weightMask)
        self.loss_pix = self.opt.lambda_pix * self.criterionPix(fake_B_pix, gt_pix)

        # combined loss
        self.loss_G = torch.tensor(0).float().to(device)
        self.loss_G += self.loss_reconstruct
        self.loss_G += self.loss_pix
        if self.opt.loss_percept:
            self.loss_G += self.loss_vgg
        if self.opt.loss_faceID:
            self.loss_G += self.loss_id
        self.loss_G += self.loss_gan

        self.loss_G.backward()


    def func_require_grad(self, model_, flag_):
        for mm in model_:
            self.set_requires_grad(mm, flag_)

    def func_zero_grad(self, model_):
        for mm in model_:
            mm.zero_grad()

    def optimize_parameters(self, epoch):
        self.forward()
        # D
        if epoch >= self.opt.gan_start_epoch:  # start to include D after xxx epochs
            self.func_require_grad([self.netD256, self.netD128, self.netD64], True)
            self.func_zero_grad([self.netD256, self.netD128, self.netD64])
            self.backward_D()
            self.optimizer_D.step()
        # G
        self.func_require_grad([self.netD256, self.netD128, self.netD64], False)
        self.optimizer_G.zero_grad()
        self.backward_G(epoch)  # start to include gan loss for G after xxx epochs
        self.optimizer_G.step()


================================================
FILE: net/ResNet.py
================================================
# -*- coding: utf-8 -*-
"""
Created on 18-5-21 at 5:26 PM
@author: ronghuaiyang
"""
import torch
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch.nn.utils.weight_norm as weight_norm
import torch.nn.functional as F


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class IRBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        super(IRBlock, self).__init__()
        self.bn0 = nn.BatchNorm2d(inplanes)
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.prelu = nn.PReLU()
        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        if self.use_se:
            self.se = SEBlock(planes)

    def forward(self, x):
        residual = x
        out = self.bn0(x)
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.prelu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        if self.use_se:
            out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.prelu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.PReLU(),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


class ResNetFace(nn.Module):
    def __init__(self, block, layers, input_nc, use_se=True):
        self.inplanes = 64
        self.use_se = use_se
        super(ResNetFace, self).__init__()
        self.conv1 = nn.Conv2d(input_nc, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.bn4 = nn.BatchNorm2d(512)
        self.dropout = nn.Dropout()
        #self.fc5 = nn.Linear(512 * 8 * 8, 512) # 128
        self.fc5 = nn.Linear(512 * 13 * 13, 512) # 200
        self.bn5 = nn.BatchNorm1d(512)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride,
                            downsample, use_se=self.use_se))
        self.inplanes = planes
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, use_se=self.use_se))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.bn4(x)
        #print(x.size())
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = self.fc5(x)
        x = self.bn5(x)
        return x


class ResNet(nn.Module):

    def __init__(self, block, layers):
        self.inplanes = 64
        super(ResNet, self).__init__()
        # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
        #                        bias=False)
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # self.avgpool = nn.AvgPool2d(8, stride=1)
        # self.fc = nn.Linear(512 * block.expansion, num_classes)
        self.fc5 = nn.Linear(512 * 8 * 8, 512)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # x = nn.AvgPool2d(kernel_size=x.size()[2:])(x)
        # x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc5(x)

        return x


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
    return model


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return model


def resnet_face18(input_nc, use_se=True, **kwargs):
    model = ResNetFace(IRBlock, [2, 2, 2, 2], input_nc, use_se=use_se, **kwargs)
    return model
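The hard-coded `nn.Linear(512 * 13 * 13, 512)` in `ResNetFace` (and the commented-out `512 * 8 * 8` alternative) correspond to 200x200 and 128x128 inputs respectively. This can be checked with the standard convolution output-size formula; a small sketch, pure arithmetic, no torch required (helper names are made up):

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of Conv2d/MaxPool2d: floor((H + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

def resnet_face_feature_size(input_size):
    """Trace the spatial size through ResNetFace: conv1 (k=3, s=1, p=1),
    maxpool (k=2, s=2), then layer1..layer4 with strides 1, 2, 2, 2
    (each strided stage downsamples in its first conv3x3)."""
    h = conv_out(input_size, 3, 1, 1)   # conv1 keeps the size
    h = conv_out(h, 2, 2, 0)            # maxpool halves it
    for stride in (1, 2, 2, 2):         # layer1..layer4
        h = conv_out(h, 3, stride, 1)
    return h
```

For a 200px input this gives 200 -> 100 -> 100 -> 50 -> 25 -> 13, matching `fc5`; for 128px it gives 8.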


================================================
FILE: net/appear_decoder_net.py
================================================
import torch as th
from torch import nn
import net.base_net as base_net
###############################################################################
# define
###############################################################################


def defineAppDec(input_nc, size_=256, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
    net = None
    norm_layer = base_net.get_norm_layer(norm_type=norm)
    if 128 == size_:
        net = appearDec128(input_nc, norm_layer=norm_layer, size_=size_)
    elif 256 == size_:
        net = appearDec(input_nc, norm_layer=norm_layer, size_=size_)
    else:
        raise NotImplementedError('decoder size [%d] is not supported' % size_)
    return base_net.init_net(net, init_type, init_gain, gpu_ids)


class appearDec(nn.Module):
    def __init__(self, input_c, norm_layer, size_=256):
        super(appearDec, self).__init__()
        # input: latent code [N, 1024, 1, 1]
        # decoder
        layers = []

        channel_list = [1024, 1024, 1024, 1024]  
        c0 = 1024
        for cc in channel_list:
            layers.append(nn.ConvTranspose2d(c0, cc, 4, 2, 1))
            layers.append(norm_layer(cc))
            layers.append(nn.ReLU(True))
            c0 = cc
        self.decoder16 = nn.Sequential(*layers)

        self.decoder32 = nn.Sequential(nn.ConvTranspose2d(1024, 512, 4, 2, 1), norm_layer(512), nn.ReLU(True))
        self.decoder64 = nn.Sequential(nn.ConvTranspose2d(512, 256, 4, 2, 1), norm_layer(256), nn.ReLU(True))
        self.decoder128 = nn.Sequential(nn.ConvTranspose2d(256, 128, 4, 2, 1), norm_layer(128), nn.ReLU(True))
        layers = []
        layers.append(nn.ConvTranspose2d(128, 3, 4, 2, 1))
        layers.append(nn.Tanh())
        self.decoder256 = nn.Sequential(*layers)


    def forward(self, input):
        out16 = self.decoder16(input)
        out32 = self.decoder32(out16)
        out64 = self.decoder64(out32)
        out128 = self.decoder128(out64)
        out256 = self.decoder256(out128)
        return out16, out32, out64, out128, out256

class appearDec128(nn.Module):
    def __init__(self, input_c, norm_layer, size_=256):
        super(appearDec128, self).__init__()
        # input: latent code [N, 1024, 1, 1]
        # decoder
        layers = []

        channel_list = [1024, 1024, 1024]  
        c0 = 1024
        for cc in channel_list:
            layers.append(nn.ConvTranspose2d(c0, cc, 4, 2, 1))
            layers.append(norm_layer(cc))
            layers.append(nn.ReLU(True))
            c0 = cc
        self.decoder8 = nn.Sequential(*layers)

        self.decoder16 = nn.Sequential(nn.ConvTranspose2d(1024, 512, 4, 2, 1), norm_layer(512), nn.ReLU(True))
        self.decoder32 = nn.Sequential(nn.ConvTranspose2d(512, 256, 4, 2, 1), norm_layer(256), nn.ReLU(True))
        self.decoder64 = nn.Sequential(nn.ConvTranspose2d(256, 128, 4, 2, 1), norm_layer(128), nn.ReLU(True))
        layers = []
        layers.append(nn.ConvTranspose2d(128, 3, 4, 2, 1))
        layers.append(nn.Tanh())
        self.decoder128 = nn.Sequential(*layers)


    def forward(self, input):
        out8 = self.decoder8(input)
        out16 = self.decoder16(out8)
        out32 = self.decoder32(out16)
        out64 = self.decoder64(out32)
        out128 = self.decoder128(out64)
        return out8, out16, out32, out64, out128
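Every `ConvTranspose2d(c0, cc, 4, 2, 1)` above doubles the spatial size, since H_out = (H - 1)*2 - 2*1 + 4 = 2H. Starting from the 1x1 latent produced by the appearance encoder, `appearDec` stacks 8 such layers (reaching 256) and `appearDec128` stacks 7 (reaching 128). A sketch of that trace (helper names are hypothetical):

```python
def deconv_out(size, kernel=4, stride=2, padding=1):
    """ConvTranspose2d output size: (H - 1) * s - 2p + k."""
    return (size - 1) * stride - 2 * padding + kernel

def decoder_trace(num_layers, start=1):
    """Spatial size after each deconv, starting from a 1x1 latent."""
    sizes, h = [], start
    for _ in range(num_layers):
        h = deconv_out(h)
        sizes.append(h)
    return sizes
```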

================================================
FILE: net/appear_encoder_net.py
================================================
import torch as th
from torch import nn
import net.base_net as base_net
###############################################################################
# define
###############################################################################


def defineAppEnc(input_nc, size_=256, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[], conv_k=4):
    net = None
    norm_layer = base_net.get_norm_layer(norm_type=norm)

    net = appearEnc(input_nc, norm_layer=norm_layer, size_=size_, conv_k=conv_k)
    return base_net.init_net(net, init_type, init_gain, gpu_ids)


# class appearEnc(nn.Module):
#     def __init__(self, input_c, norm_layer, size_=256, conv_k=4):
#         super(appearEnc, self).__init__()
#         # input 3x256x256
#         # encoder
#         channel_list = [128, 256, 512, 1024, 1024, 1024]
#         c0 = 64
#         self.layer1 = nn.Sequential(
#             nn.Conv2d(input_c, c0, conv_k, 2, 1),
#             nn.LeakyReLU(0.2)
#         )
#         self.layer2 = nn.Sequential(
#             nn.Conv2d(c0, channel_list[0], conv_k, 2, 1),
#             norm_layer(channel_list[0]),
#             nn.LeakyReLU(0.2)
#         )
#         self.layer3 = nn.Sequential(
#             nn.Conv2d(channel_list[0], channel_list[1], conv_k, 2, 1),
#             norm_layer(channel_list[1]),
#             nn.LeakyReLU(0.2)
#         )
#         self.layer4 = nn.Sequential(
#             nn.Conv2d(channel_list[1], channel_list[2], conv_k, 2, 1),
#             norm_layer(channel_list[2]),
#             nn.LeakyReLU(0.2)
#         )
#         self.layer5 = nn.Sequential(
#             nn.Conv2d(channel_list[2], channel_list[3], conv_k, 2, 1),
#             norm_layer(channel_list[3]),
#             nn.LeakyReLU(0.2)
#         )
#         self.layer6 = nn.Sequential(
#             nn.Conv2d(channel_list[3], channel_list[4], conv_k, 2, 1),
#             norm_layer(channel_list[4]),
#             nn.LeakyReLU(0.2)
#         )
#         self.layer7 = nn.Sequential(
#             nn.Conv2d(channel_list[4], channel_list[5], conv_k, 2, 1),
#             norm_layer(channel_list[5]),
#             nn.LeakyReLU(0.2)
#         )
#         self.mean = nn.Conv2d(1024, 1024, conv_k, 2, 1)

class appearEnc(nn.Module):
    def __init__(self, input_c, norm_layer, size_=256, conv_k=4):
        super(appearEnc, self).__init__()
        # input 3x256x256
        # encoder
        layers = []
        channel_list = [128, 256, 512, 1024, 1024, 1024]

        c0 = 64
        layers.append(nn.Conv2d(input_c, c0, conv_k, 2, 1))
        layers.append(nn.LeakyReLU(0.2))
        for cc in channel_list:
            layers.append(nn.Conv2d(c0, cc, conv_k, 2, 1))
            layers.append(norm_layer(cc))
            layers.append(nn.LeakyReLU(0.2))
            c0 = cc
        self.encoder = nn.Sequential(*layers)
        # mean
        layers = []
        layers.append(nn.Conv2d(1024, 1024, conv_k, 2, 1))
        # layers.append(nn.ReLU())
        self.mean = nn.Sequential(*layers)
        # self.logvar = nn.Sequential(*layers)



    def sample_z(self, z_mu):
        z_std = 1.0
        eps = th.randn(z_mu.size()).type_as(z_mu)  # standard normal noise N(0, 1)
        return z_mu + z_std * eps

    # def sample_z(self, z_mu, z_logvar):
    #     z_std = th.exp(0.5 * z_logvar)
    #     eps = th.randn_like(z_std)
    #     return z_mu + z_std * eps

    def kl_loss(self, z_mu):
        #kl_loss = torch.mean(0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1. - z_var, 1))
        z_var = th.ones(z_mu.size()).type_as(z_mu) # [batch_size, 1024, 1, 1]
        kl_loss_ = th.mean(0.5 * th.sum(th.exp(z_var) + z_mu**2 - 1. - z_var, 1))
        return kl_loss_ # scalar loss

    # def kl_loss(self, z_mu, z_logvar):
    #     kl_loss = -0.5 * th.mean(1 + z_logvar - z_mu.pow(2) - z_logvar.exp())
    #     return kl_loss # scalar loss

    def freeze(self):
        for module_ in self.encoder:
            for p in module_.parameters():
                p.requires_grad = False

        for module_ in self.mean:
            for p in module_.parameters():
                p.requires_grad = False

    # def forward(self, input):
    #     encoder = self.encoder(input)  # input: [batch_size,3,256,256], encoder: [1, 1024, 2, 2]
    #     z_mu = self.mean(encoder)   # [batch_size,1024,1,1]
    #     z_logvar = self.logvar(encoder)  # [batch_size,1024,1,1]
    #
    #     sample_z = self.sample_z(z_mu, z_logvar) # [batch_size,1024,1,1]
    #     kl_loss = self.kl_loss(z_mu, z_logvar) # scalar KL loss
    #     return sample_z, kl_loss, z_mu


    def forward(self, input):
        encoder = self.encoder(input)  # input: [N, 3, 256, 256] -> [N, 1024, 2, 2]
        z_mu = self.mean(encoder)  # [1,1024,1,1]
        sample_z = self.sample_z(z_mu) # [1,1024,1,1]
        kl_loss = self.kl_loss(z_mu) # scalar KL loss
        return sample_z, kl_loss, z_mu

    # def forward(self, input):
    #     encode128 = self.layer1(input)
    #     encode64 = self.layer2(encode128)
    #     encode32 = self.layer3(encode64)
    #     encode16 = self.layer4(encode32)
    #     encode8 = self.layer5(encode16)
    #     encode4 = self.layer6(encode8)
    #     encode2 = self.layer7(encode4)
    #     z_mu = self.mean(encode2)  # [1,1024,1,1]
    #     sample_z = self.sample_z(z_mu)  # [1,1024,1,1]
    #     return sample_z, z_mu, encode2, encode4, encode8, encode16, encode32, encode64, encode128
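Because the log-variance branch is commented out, `sample_z` simply adds unit-variance Gaussian noise to the mean, and `kl_loss` evaluates the usual KL expression with the variance pinned to 1: per dimension it reduces to 0.5 * (e + mu^2 - 2), i.e. a constant plus an L2 penalty on the mean. A pure-Python sketch of that reduced loss for flattened latents (the function name is made up):

```python
import math

def kl_fixed_var(z_mu):
    """KL term as computed by appearEnc.kl_loss with z_var fixed to 1:
    batch mean of sum_i 0.5 * (exp(1) + mu_i^2 - 1 - 1). Up to the
    constant 0.5 * (e - 2) per dimension, this penalizes ||mu||^2."""
    per_sample = [sum(0.5 * (math.e + m * m - 2.0) for m in mu) for mu in z_mu]
    return sum(per_sample) / len(per_sample)
```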


================================================
FILE: net/base_net.py
================================================
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler

###############################################################################
# base module set
###############################################################################


def get_norm_layer(norm_type='instance'):
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(
            nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == 'none':
        norm_layer = None
    else:
        raise NotImplementedError(
            'normalization layer [%s] is not found' % norm_type)
    return norm_layer


def get_scheduler(optimizer, opt):
    if opt.lr_policy == 'lambda':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(
            optimizer, step_size=opt.lr_decay_iters, gamma=0.5)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=opt.niter, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler


def init_weights(net, init_type='normal', gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError(
                    'initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    if init_type == 'no':
        print('not init')
    else:
        print('initialize network with %s' % init_type)
        net.apply(init_func)


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)

    init_weights(net, init_type, gain=init_gain)
    return net
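Under the 'lambda' policy the learning rate stays constant for the first `opt.niter` epochs and then decays linearly towards zero over the following `opt.niter_decay` epochs. The multiplier can be reproduced stand-alone (helper name hypothetical):

```python
def lambda_lr_multiplier(epoch, niter, niter_decay):
    """LR multiplier used by the 'lambda' policy above: 1.0 for the first
    `niter` epochs, then a linear ramp down to ~0 over `niter_decay` epochs."""
    return 1.0 - max(0, epoch - niter) / float(niter_decay + 1)
```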


================================================
FILE: net/discriminator_net.py
================================================
import torch
import torch.nn as nn
import functools
import net.base_net as base_net
###############################################################################
# discriminator define
###############################################################################


def define_D(input_nc, ndf, netD,
             n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
    net = None
    norm_layer = base_net.get_norm_layer(norm_type=norm)

    if netD == 'basic':
        net = NLayerDiscriminator(
            input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
    elif netD == 'n_layers':
        net = NLayerDiscriminator(
            input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
    elif netD == 'pixel':
        net = PixelDiscriminator(
            input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
    else:
        raise NotImplementedError(
            'Discriminator model name [%s] is not recognized' % netD)
    return base_net.init_net(net, init_type, init_gain, gpu_ids)


##############################################################################
# Classes
##############################################################################


# Defines the GAN loss, using either LSGAN or the regular GAN objective.
# With LSGAN it is essentially MSELoss, but it abstracts away the need
# to create a target label tensor of the same size as the input.
class GANLoss(nn.Module):
    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(input)

    def __call__(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)


# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
        super(NLayerDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1,
                               kernel_size=kw, stride=1, padding=padw)]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        return self.model(input)


class PixelDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
        super(PixelDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        self.net = [
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1,
                      stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]

        if use_sigmoid:
            self.net.append(nn.Sigmoid())

        self.net = nn.Sequential(*self.net)

    def forward(self, input):
        return self.net(input)
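`NLayerDiscriminator` is a PatchGAN: its output is a spatial map of real/fake scores, one per receptive-field patch, rather than a single scalar. With the default `n_layers=3`, a 256x256 input yields a 30x30 patch map, and `GANLoss` compares every patch score against the broadcast target label. A pure-arithmetic sketch of the output size (helper names made up):

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    """Conv2d output size: floor((H + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

def patchgan_output_size(input_size, n_layers=3):
    """Spatial size of the NLayerDiscriminator patch map: one stride-2 conv,
    then (n_layers - 1) stride-2 blocks, then two stride-1 convs (the last
    norm block and the final 1-channel conv)."""
    h = conv_out(input_size)
    for _ in range(1, n_layers):
        h = conv_out(h)
    h = conv_out(h, stride=1)
    h = conv_out(h, stride=1)
    return h
```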


================================================
FILE: net/face_id_mlp_net.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import math


class MLP(nn.Module):
    def __init__(self, input_nc, output_nc):
        super(MLP, self).__init__()
        self.fc = nn.Linear(input_nc, output_nc)

    def forward(self, input):
        return self.fc(input)


class ArcMarginProduct(nn.Module):
    r"""Implementation of the large margin arc distance (ArcFace):
        Args:
            in_features: size of each input sample
            out_features: size of each output sample
            s: norm of input feature
            m: margin
            cos(theta + m)
        """

    def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------------
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # torch.where(cond, x, y) -> out_i = x_i if cond_i else y_i (torch >= 0.4)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        # print(output)

        return output


class AddMarginProduct(nn.Module):
    r"""Implementation of the large margin cosine distance (CosFace):
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        s: norm of input feature
        m: margin
        cos(theta) - m
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.40):
        super(AddMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        phi = cosine - self.m
        # --------------------------- convert label to one-hot ---------------------------
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # torch.where(cond, x, y) -> out_i = x_i if cond_i else y_i (torch >= 0.4)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        # print(output)

        return output

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'in_features=' + str(self.in_features) \
            + ', out_features=' + str(self.out_features) \
            + ', s=' + str(self.s) \
            + ', m=' + str(self.m) + ')'


class SphereProduct(nn.Module):
    r"""Implementation of the multiplicative angular margin (SphereFace):
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        m: margin
        cos(m*theta)
    """

    def __init__(self, in_features, out_features, m=4):
        super(SphereProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.m = m
        self.base = 1000.0
        self.gamma = 0.12
        self.power = 1
        self.LambdaMin = 5.0
        self.iter = 0
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        # duplication formula
        self.mlambda = [
            lambda x: x ** 0,
            lambda x: x ** 1,
            lambda x: 2 * x ** 2 - 1,
            lambda x: 4 * x ** 3 - 3 * x,
            lambda x: 8 * x ** 4 - 8 * x ** 2 + 1,
            lambda x: 16 * x ** 5 - 20 * x ** 3 + 5 * x
        ]

    def forward(self, input, label):
        # lambda = max(lambda_min,base*(1+gamma*iteration)^(-power))
        self.iter += 1
        self.lamb = max(self.LambdaMin, self.base *
                        (1 + self.gamma * self.iter) ** (-1 * self.power))

        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cos_theta = F.linear(F.normalize(input), F.normalize(self.weight))
        cos_theta = cos_theta.clamp(-1, 1)
        cos_m_theta = self.mlambda[self.m](cos_theta)
        theta = cos_theta.data.acos()
        k = (self.m * theta / 3.14159265).floor()
        phi_theta = ((-1.0) ** k) * cos_m_theta - 2 * k
        NormOfFeature = torch.norm(input, 2, 1)

        # --------------------------- convert label to one-hot ---------------------------
        one_hot = torch.zeros(cos_theta.size())
        one_hot = one_hot.cuda() if cos_theta.is_cuda else one_hot
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)  # scatter_ requires long indices

        # --------------------------- Calculate output ---------------------------
        output = (one_hot * (phi_theta - cos_theta) /
                  (1 + self.lamb)) + cos_theta
        output *= NormOfFeature.view(-1, 1)

        return output

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'in_features=' + str(self.in_features) \
            + ', out_features=' + str(self.out_features) \
            + ', m=' + str(self.m) + ')'
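The `mlambda` table above lists Chebyshev polynomials T_m, which satisfy T_m(cos t) == cos(m*t); `forward()` uses them to obtain cos(m*theta) directly from cos(theta) without computing the angle. A quick numeric self-check for m = 4, independent of torch:

```python
import math

# Same polynomial as mlambda[4] in the class above.
def T4(x):
    return 8 * x ** 4 - 8 * x ** 2 + 1

t = 0.3
# Chebyshev multiple-angle identity: T4(cos t) == cos(4t)
assert abs(T4(math.cos(t)) - math.cos(4 * t)) < 1e-9
```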


================================================
FILE: net/face_id_net.py
================================================
import torch as th
from torch import nn
from net.ResNet import resnet_face18 as resnet18
from net.face_id_mlp_net import MLP
import net.base_net as base_net
###############################################################################
# define
###############################################################################


def defineFaceID(input_nc=3, class_num=10173, init_type='normal', init_gain=0.02, gpu_ids=[]):
    net = None
    net = faceIDNet(input_nc, class_num)
    return base_net.init_net(net, init_type, init_gain, gpu_ids)


class faceIDNet(nn.Module):
    def __init__(self, input_nc, class_num):
        super(faceIDNet, self).__init__()
        # input 3x256x256
        self.feat = resnet18(input_nc, use_se=False)
        self.mlp = MLP(512, class_num)

    def forward(self, input):
        feat = self.feat(input)
        pred = self.mlp(feat)
        return pred

    def face_id_loss(self, x, target, loss_func):
        targetIdFeat256 = self.feat(target).detach()
        faceIDFeat = self.feat(x)
        id_loss = loss_func(faceIDFeat, targetIdFeat256)
        return id_loss
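Note the `.detach()` in `face_id_loss`: the target's identity feature is treated as a constant, so the gradient flows only through the generated image `x`. The shape of the computation, sketched torch-free with an L1 distance and illustrative feature values:

```python
# Hypothetical 3-dim ID features standing in for self.feat(...) outputs.
def l1(a, b):
    # mean absolute difference, analogous to nn.L1Loss over feature vectors
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)

gen_feat = [0.2, -0.5, 0.1]   # feat(x): gradients flow through this branch
tgt_feat = [0.0, -0.4, 0.3]   # feat(target).detach(): held constant
id_loss = l1(gen_feat, tgt_feat)
```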


================================================
FILE: net/generaotr_net.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import functools
import net.base_net as base_net


###############################################################################
# define spatially adaptive normalized (SPADE) generator
# input: boundary heatmaps and appearance latent vector
###############################################################################


def defineSPADEGenerator(input_nc, output_nc, ngf, norm='instance', use_dropout=False, init_type='normal',
               init_gain=0.02, gpu_ids=[], latent_chl=1024, up_mode='NF'):
    norm_layer = base_net.get_norm_layer(norm_type=norm)
    net = SPADEGenerator(input_nc, output_nc, ngf,
                            norm_layer=norm_layer, latent_chl=latent_chl, up_mode=up_mode)
    return base_net.init_net(net, init_type, init_gain, gpu_ids)


##############################################################################
# Classes
##############################################################################

# class BasicSPADE(nn.Module):
#     def __init__(self, norm_layer, input_nc, planes):
#         super(BasicSPADE, self).__init__()
#         self.norm = norm_layer(planes, affine=False)
#
#         self.conv_weight1=nn.Conv2d(input_nc, input_nc, kernel_size=3, stride=1, padding=1)
#         self.conv_bias1=nn.Conv2d(input_nc, input_nc, kernel_size=3, stride=1, padding=1)
#         self.conv_weight2=nn.Conv2d(input_nc, input_nc, kernel_size=3, stride=1, padding=1)
#         self.conv_bias2=nn.Conv2d(input_nc, input_nc, kernel_size=3, stride=1, padding=1)
#         self.conv_weight3=nn.Conv2d(input_nc, input_nc, kernel_size=3, stride=1, padding=1)
#         self.conv_bias3=nn.Conv2d(input_nc, input_nc, kernel_size=3, stride=1, padding=1)
#         self.conv_weight4=nn.Conv2d(input_nc, input_nc, kernel_size=3, stride=1, padding=1)
#         self.conv_bias4=nn.Conv2d(input_nc, input_nc, kernel_size=3, stride=1, padding=1)
#
#         self.conv_weight=nn.Conv2d(input_nc, planes, kernel_size=3, stride=1, padding=1)
#         self.conv_bias=nn.Conv2d(input_nc, planes, kernel_size=3, stride=1, padding=1)
#
#     def forward(self, x, bound):
#         out = self.norm(x)
#
#         weight_norm1 = self.conv_weight1(bound)
#         bias_norm1 = self.conv_bias1(bound)
#         weight_norm2 = self.conv_weight2(weight_norm1)
#         bias_norm2 = self.conv_bias2(bias_norm1)
#         weight_norm3 = self.conv_weight3(weight_norm2)
#         bias_norm3 = self.conv_bias3(bias_norm2)
#         weight_norm4 = self.conv_weight4(weight_norm3)
#         bias_norm4 = self.conv_bias4(bias_norm3)
#
#         weight_norm = self.conv_weight(weight_norm4)
#         bias_norm = self.conv_bias(bias_norm4)
#
#         out = out * weight_norm + bias_norm
#         return out

class BasicSPADE(nn.Module):
    def __init__(self, norm_layer, input_nc, planes):
        super(BasicSPADE, self).__init__()
        self.conv_weight = nn.Conv2d(input_nc, planes, kernel_size=3, stride=1, padding=1)
        self.conv_bias = nn.Conv2d(input_nc, planes, kernel_size=3, stride=1, padding=1)
        self.norm = norm_layer(planes, affine=False)

    def forward(self, x, bound):
        out = self.norm(x)
        weight_norm = self.conv_weight(bound)
        bias_norm = self.conv_bias(bound)
        out = out * weight_norm + bias_norm
        return out
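`BasicSPADE.forward` above denormalizes an instance-normalized activation with *spatial* scale and shift maps predicted from the boundary heatmap: `out = norm(x) * gamma(bound) + beta(bound)`. A toy 1-channel, 2x2 illustration with hand-picked stand-ins for the conv outputs (the gamma/beta values below are hypothetical, not learned):

```python
import math

x = [1.0, 2.0, 3.0, 4.0]                      # flattened 2x2 activation
mu = sum(x) / len(x)
var = sum((v - mu) ** 2 for v in x) / len(x)
normed = [(v - mu) / math.sqrt(var + 1e-5) for v in x]   # instance norm

gamma = [0.5, 0.5, 2.0, 2.0]   # stand-in for conv_weight(bound): per-pixel scale
beta = [0.0, 0.0, 1.0, 1.0]    # stand-in for conv_bias(bound): per-pixel shift
out = [n * g + b for n, g, b in zip(normed, gamma, beta)]
```

Because gamma and beta vary per pixel, the boundary map can modulate each spatial location differently, unlike the single affine pair of a plain norm layer.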


class ResBlkSPADE(nn.Module):
    def __init__(self, norm_layer, input_nc, planes, conv_kernel_size=1, padding=0):   # todo: change conv kernel size, kernel=3, padding=1 or kernel=1, padding=0
        super(ResBlkSPADE, self).__init__()
        self.spade1 = BasicSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(planes, planes, kernel_size=conv_kernel_size, stride=1, padding=padding)
        self.spade2 = BasicSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=conv_kernel_size, stride=1, padding=padding)
        self.spade_res = BasicSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=planes)
        self.conv_res = nn.Conv2d(planes, planes, kernel_size=conv_kernel_size, stride=1, padding=padding)

    def forward(self, x, bound):
        out = self.spade1(x, bound)
        out = self.relu(out)
        out = self.conv1(out)
        out = self.spade2(out, bound)
        out = self.relu(out)
        out = self.conv2(out)

        residual = x
        residual = self.spade_res(residual, bound)
        residual = self.relu(residual)
        residual = self.conv_res(residual)

        out = out + residual

        return out

# Defines the SPADE generator. The appearance latent (spatially 1x1) passes
# through eight doubling stages to reach 256x256; stages 3-8 are SPADE blocks
# conditioned on the boundary map at the matching resolution, and stages 5-8
# additionally concatenate appearance-decoder skip features.
class SPADEGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64,
                 norm_layer=nn.InstanceNorm2d, latent_chl=1024, up_mode='NF'):
        super(SPADEGenerator, self).__init__()

        layers = []
        self.up_mode = up_mode

        self.up1 = nn.ConvTranspose2d(in_channels=latent_chl, out_channels=512, kernel_size=4, stride=2, padding=1)
        self.up2 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1)

        if self.up_mode == 'convT':
            self.up3 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1)
            self.up4 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1)
            self.up5 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=4, stride=2, padding=1)
            self.up6 = nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1)
            self.up7 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1)
            self.up8 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1)
            # self.up3 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1)
            # self.up4 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1)
            # self.up5 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1)
            # self.up6 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=4, stride=2, padding=1)
            # self.up7 = nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1)
            # self.up8 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1)
        elif self.up_mode == 'NF':
            self.up3 = nn.Upsample(scale_factor=2, mode='nearest')
            self.up4 = nn.Upsample(scale_factor=2, mode='nearest')
            self.up5 = nn.Upsample(scale_factor=2, mode='nearest')
            self.up6 = nn.Upsample(scale_factor=2, mode='nearest')
            self.up7 = nn.Upsample(scale_factor=2, mode='nearest')
            self.up8 = nn.Upsample(scale_factor=2, mode='nearest')

        self.spade_blc3 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=512)
        self.spade_blc4 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=512)
        self.spade_blc5 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=1024+512)
        self.spade_blc6 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=512+256)
        self.spade_blc7 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=256+128)
        self.spade_blc8 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=128+64)

        self.conv5 = nn.Conv2d(in_channels=1024+512, out_channels=256, kernel_size=1, stride=1, padding=0)
        self.conv6 = nn.Conv2d(in_channels=512+256, out_channels=128, kernel_size=1, stride=1, padding=0)
        self.conv7 = nn.Conv2d(in_channels=256+128, out_channels=64, kernel_size=1, stride=1, padding=0)
        self.conv8 = nn.Conv2d(in_channels=128+64, out_channels=64, kernel_size=1, stride=1, padding=0)

        # self.spade_blc3 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=512)
        # self.spade_blc4 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=512)
        # self.spade_blc5 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=512)
        # self.spade_blc6 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=256)
        # self.spade_blc7 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=128)
        # self.spade_blc8 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=64)
        #
        # self.conv5 = nn.Conv2d(in_channels=1024 + 512, out_channels=512, kernel_size=1, stride=1, padding=0)
        # self.conv6 = nn.Conv2d(in_channels=512 + 512, out_channels=256, kernel_size=1, stride=1, padding=0)
        # self.conv7 = nn.Conv2d(in_channels=256 + 256, out_channels=128, kernel_size=1, stride=1, padding=0)
        # self.conv8 = nn.Conv2d(in_channels=128 + 128, out_channels=64, kernel_size=1, stride=1, padding=0)


        self.same = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()


    def forward(self, input, latent_z, decoder_result): # input: bound, batch_size*17*256*256
        bound128 = F.interpolate(input, scale_factor=0.5)
        bound64 = F.interpolate(bound128, scale_factor=0.5)
        bound32 = F.interpolate(bound64, scale_factor=0.5)
        bound16 = F.interpolate(bound32, scale_factor=0.5)
        bound8 = F.interpolate(bound16, scale_factor=0.5)
        bound4 = F.interpolate(bound8, scale_factor=0.5)

        x_up1 = self.up1(latent_z)
        x_up2 = self.up2(x_up1)

        x_up3 = self.spade_blc3(x_up2, bound4) # 4*4 bound
        x_up3 = self.up3(x_up3)

        x_up4 = self.spade_blc4(x_up3, bound8) # 8*8 bound
        x_up4 = self.up4(x_up4)

        x_up5 = self.spade_blc5(torch.cat([x_up4, decoder_result[0]], 1), bound16) # 16*16 bound
        x_up5 = self.conv5(x_up5)
        x_up5 = self.up5(x_up5)

        x_up6 = self.spade_blc6(torch.cat([x_up5, decoder_result[1]], 1), bound32) # 32*32 bound
        x_up6 = self.conv6(x_up6)
        x_up6 = self.up6(x_up6)

        x_up7 = self.spade_blc7(torch.cat([x_up6, decoder_result[2]], 1), bound64) # 64*64 bound
        x_up7 = self.conv7(x_up7)
        x_up7 = self.up7(x_up7)

        x_up8 = self.spade_blc8(torch.cat([x_up7, decoder_result[3]], 1), bound128) # 128*128 bound
        x_up8 = self.conv8(x_up8)
        x_up8 = self.up8(x_up8)


        # x_up5 = self.conv5(torch.cat([x_up4, decoder_result[0]], 1))
        # x_up5 = self.spade_blc5(x_up5, bound16)  # 16*16 bound
        # x_up5 = self.up5(x_up5)
        #
        # x_up6 = self.conv6(torch.cat([x_up5, decoder_result[1]], 1))
        # x_up6 = self.spade_blc6(x_up6, bound32)  # 16*16 bound
        # x_up6 = self.up6(x_up6)
        #
        # x_up7 = self.conv7(torch.cat([x_up6, decoder_result[2]], 1))
        # x_up7 = self.spade_blc7(x_up7, bound64)  # 16*16 bound
        # x_up7 = self.up7(x_up7)
        #
        # x_up8 = self.conv8(torch.cat([x_up7, decoder_result[3]], 1))
        # x_up8 = self.spade_blc8(x_up8, bound128)  # 16*16 bound
        # x_up8 = self.up8(x_up8)


        x_out = self.same(x_up8)
        x_out = self.tanh(x_out)

        return x_out
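The spatial bookkeeping of the forward pass above can be verified with the standard ConvTranspose2d size formula (no output_padding, dilation=1): with kernel=4, stride=2, padding=1 each stage maps `in` to `(in - 1)*2 - 2 + 4 = 2*in`, so the 1x1 latent reaches 256x256 after the eight up stages, matching the 256x256 boundary input (and `nn.Upsample(scale_factor=2)` in 'NF' mode doubles identically):

```python
def convT_out(size, kernel=4, stride=2, pad=1):
    # ConvTranspose2d output size: (in - 1)*stride - 2*pad + kernel
    return (size - 1) * stride - 2 * pad + kernel

size = 1                  # latent_z is 1x1 spatially
for _ in range(8):        # up1 .. up8
    size = convT_out(size)
```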



# # define upSample Module
# class UpSampleBlock(nn.Module):
#     def __init__(self, input_nc, output_nc,
#                  outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
#         super(UpSampleBlock, self).__init__()
#
#         if type(norm_layer) == functools.partial:
#             use_bias = norm_layer.func == nn.InstanceNorm2d
#         else:
#             use_bias = norm_layer == nn.InstanceNorm2d
#
#         uprelu = nn.ReLU(True)
#         upnorm = norm_layer(output_nc)
#
#         if outermost:
#             upconv = nn.ConvTranspose2d(input_nc, output_nc,
#                                         kernel_size=4, stride=2,
#                                         padding=1)
#             up = [uprelu, upconv, nn.Tanh()]
#
#         elif innermost:
#             upconv = nn.ConvTranspose2d(input_nc, output_nc,
#                                         kernel_size=4, stride=2,
#                                         padding=1, bias=use_bias)
#             up = [uprelu, upconv, upnorm]
#
#         else:
#             upconv = nn.ConvTranspose2d(input_nc, output_nc,
#                                         kernel_size=4, stride=2,
#                                         padding=1, bias=use_bias)
#             up = [uprelu, upconv, upnorm]
#             if use_dropout:
#                 up = up + [nn.Dropout(0.5)]
#
#         self.up = nn.Sequential(*up)
#
#     def forward(self, x):
#         return self.up(x)




================================================
FILE: net/generator_net_concat_1Layer.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import functools
import net.base_net as base_net


###############################################################################
# define spatially adaptive normalized (SPADE) generator
# input: boundary heatmaps and appearance latent vector
###############################################################################


def defineSPADEGenerator(input_nc, output_nc, ngf, norm='instance', use_dropout=False, init_type='normal',
               init_gain=0.02, gpu_ids=[], latent_chl=1024, up_mode='NF'):
    norm_layer = base_net.get_norm_layer(norm_type=norm)
    net = SPADEGenerator(input_nc, output_nc, ngf,
                            norm_layer=norm_layer, latent_chl=latent_chl, up_mode=up_mode)
    return base_net.init_net(net, init_type, init_gain, gpu_ids)


##############################################################################
# Classes
##############################################################################
class BasicSPADE(nn.Module):
    def __init__(self, norm_layer, input_nc, planes):
        super(BasicSPADE, self).__init__()
        self.conv_weight = nn.Conv2d(input_nc, planes, kernel_size=3, stride=1, padding=1)
        self.conv_bias = nn.Conv2d(input_nc, planes, kernel_size=3, stride=1, padding=1)
        self.norm = norm_layer(planes, affine=False)

    def forward(self, x, bound):
        out = self.norm(x)
        weight_norm = self.conv_weight(bound)
        bias_norm = self.conv_bias(bound)
        out = out * weight_norm + bias_norm
        return out


class ResBlkSPADE(nn.Module):
    def __init__(self, norm_layer, input_nc, planes, conv_kernel_size=1, padding=0):   # todo: change conv kernel size, kernel=3, padding=1 or kernel=1, padding=0
        super(ResBlkSPADE, self).__init__()
        self.spade1 = BasicSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(planes, planes, kernel_size=conv_kernel_size, stride=1, padding=padding)
        self.spade2 = BasicSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=conv_kernel_size, stride=1, padding=padding)
        self.spade_res = BasicSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=planes)
        self.conv_res = nn.Conv2d(planes, planes, kernel_size=conv_kernel_size, stride=1, padding=padding)

    def forward(self, x, bound):
        out = self.spade1(x, bound)
        out = self.relu(out)
        out = self.conv1(out)
        out = self.spade2(out, bound)
        out = self.relu(out)
        out = self.conv2(out)

        residual = x
        residual = self.spade_res(residual, bound)
        residual = self.relu(residual)
        residual = self.conv_res(residual)

        out = out + residual

        return out

# Defines the SPADE generator variant without decoder skip-connections:
# same eight doubling stages as generaotr_net.py, but forward() takes only
# (boundary, latent); the concat alternatives are kept commented out below.
class SPADEGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64,
                 norm_layer=nn.InstanceNorm2d, latent_chl=1024, up_mode='NF'):
        super(SPADEGenerator, self).__init__()

        layers = []
        self.up_mode = up_mode

        self.up1 = nn.ConvTranspose2d(in_channels=latent_chl, out_channels=512, kernel_size=4, stride=2, padding=1)
        self.up2 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1)

        self.up3 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1)
        self.up4 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1)
        self.up5 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=4, stride=2, padding=1)
        self.up6 = nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1)
        self.up7 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.up8 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1)

        # without concat
        self.spade_blc3 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=512)
        self.spade_blc4 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=512)
        self.spade_blc5 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=512)
        self.spade_blc6 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=256)
        self.spade_blc7 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=128)
        self.spade_blc8 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=64)

        self.conv5 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0)
        self.conv6 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, stride=1, padding=0)
        self.conv7 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=1, stride=1, padding=0)
        self.conv8 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1, stride=1, padding=0)

        # # only concat out16
        # self.spade_blc3 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=512)
        # self.spade_blc4 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=512)
        # self.spade_blc5 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=1024+512)
        # self.spade_blc6 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=256)
        # self.spade_blc7 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=128)
        # self.spade_blc8 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=64)
        #
        # self.conv5 = nn.Conv2d(in_channels=1024+512, out_channels=256, kernel_size=1, stride=1, padding=0)
        # self.conv6 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, stride=1, padding=0)
        # self.conv7 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=1, stride=1, padding=0)
        # self.conv8 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1, stride=1, padding=0)

        # # only concat out32
        # self.spade_blc3 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=512)
        # self.spade_blc4 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=512)
        # self.spade_blc5 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=512)
        # self.spade_blc6 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=256+512)
        # self.spade_blc7 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=128)
        # self.spade_blc8 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=64)
        #
        # self.conv5 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0)
        # self.conv6 = nn.Conv2d(in_channels=256+512, out_channels=128, kernel_size=1, stride=1, padding=0)
        # self.conv7 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=1, stride=1, padding=0)
        # self.conv8 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1, stride=1, padding=0)
        #
        # # only concat out64
        # self.spade_blc3 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=512)
        # self.spade_blc4 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=512)
        # self.spade_blc5 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=512)
        # self.spade_blc6 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=256)
        # self.spade_blc7 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=128+256)
        # self.spade_blc8 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=64)
        #
        # self.conv5 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0)
        # self.conv6 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, stride=1, padding=0)
        # self.conv7 = nn.Conv2d(in_channels=128+256, out_channels=64, kernel_size=1, stride=1, padding=0)
        # self.conv8 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1, stride=1, padding=0)
        #
        # # only concat out128
        # self.spade_blc3 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=512)
        # self.spade_blc4 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=512)
        # self.spade_blc5 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=512)
        # self.spade_blc6 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=256)
        # self.spade_blc7 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=128)
        # self.spade_blc8 = ResBlkSPADE(norm_layer=norm_layer, input_nc=input_nc, planes=64+128)
        #
        # self.conv5 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0)
        # self.conv6 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, stride=1, padding=0)
        # self.conv7 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=1, stride=1, padding=0)
        # self.conv8 = nn.Conv2d(in_channels=64+128, out_channels=64, kernel_size=1, stride=1, padding=0)


        self.same = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()


    # def forward(self, input, latent_z, decoder_result): # input: bound, batch_size*17*256*256
    def forward(self, input, latent_z):
        bound128 = F.interpolate(input, scale_factor=0.5)
        bound64 = F.interpolate(bound128, scale_factor=0.5)
        bound32 = F.interpolate(bound64, scale_factor=0.5)
        bound16 = F.interpolate(bound32, scale_factor=0.5)
        bound8 = F.interpolate(bound16, scale_factor=0.5)
        bound4 = F.interpolate(bound8, scale_factor=0.5)

        x_up1 = self.up1(latent_z)
        x_up2 = self.up2(x_up1)

        x_up3 = self.spade_blc3(x_up2, bound4) # 4*4 bound
        x_up3 = self.up3(x_up3)

        x_up4 = self.spade_blc4(x_up3, bound8) # 8*8 bound
        x_up4 = self.up4(x_up4)

        # x_up5 = self.spade_blc5(torch.cat([x_up4, decoder_result[0]], 1), bound16) # 16*16 bound
        x_up5 = self.spade_blc5(x_up4, bound16)  # 16*16 bound
        x_up5 = self.conv5(x_up5)
        x_up5 = self.up5(x_up5)

        # x_up6 = self.spade_blc6(torch.cat([x_up5, decoder_result[1]], 1), bound32) # 32*32 bound
        x_up6 = self.spade_blc6(x_up5, bound32)
        x_up6 = self.conv6(x_up6)
        x_up6 = self.up6(x_up6)

        # x_up7 = self.spade_blc7(torch.cat([x_up6, decoder_result[2]], 1), bound64) # 64*64 bound
        x_up7 = self.spade_blc7(x_up6, bound64)  # 64*64 bound
        x_up7 = self.conv7(x_up7)
        x_up7 = self.up7(x_up7)

        # x_up8 = self.spade_blc8(torch.cat([x_up7, decoder_result[3]], 1), bound128) # 128*128 bound
        x_up8 = self.spade_blc8(x_up7, bound128)
        x_up8 = self.conv8(x_up8)
        x_up8 = self.up8(x_up8)

        x_out = self.same(x_up8)
        x_out = self.tanh(x_out)

        return x_out

================================================
FILE: net/vgg_net.py
================================================
import torch as th
from torch import nn
from torchvision.models import vgg16

import net.base_net as base_net
from utils.metric import gram_matrix
###############################################################################
# define
###############################################################################


def defineVGG(init_type='normal', init_gain=0.02, gpu_ids=[]):
    net = VGGNet()
    return base_net.init_net(net, init_type, init_gain, gpu_ids)


class VGGNet(nn.Module):
    def __init__(self):
        super(VGGNet, self).__init__()
        self.net = vgg16()
        vgg_path = 'pretrainModel/vgg16-397923af.pth'
        self.net.load_state_dict(th.load(vgg_path))

    def forward(self, x):
        vgg_layers = self.net.features
        layer_name_mapping = {
            '3': "relu1_2",
            '8': "relu2_2",
            '15': "relu3_3",
            '22': "relu4_3"
        }
        output = []
        for name, module in vgg_layers._modules.items():
            x = module(x)
            #x = nn.parallel.data_parallel(module, x, range(num_gpu))
            if name in layer_name_mapping:
                output.append(x)
        return output

    def perceptual_loss(self, x, target, loss_func):
        self.x_result = self.forward(x)
        self.target_result = self.forward(target)
        loss_ = 0
        for xx, yy in zip(self.x_result, self.target_result):
            loss_ += loss_func(xx, yy.detach())
        return loss_

    def style_loss(self, x, target, loss_func):
        # Reuses self.x_result / self.target_result cached by perceptual_loss,
        # so perceptual_loss(x, target, ...) must be called first each step.
        loss_ = 0
        for xx, yy in zip(self.x_result, self.target_result):
            loss_ += loss_func(gram_matrix(xx), gram_matrix(yy.detach()))
        return loss_
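`style_loss` compares Gram matrices of VGG feature maps. The usual definition (assumed here; this repo's version lives in `utils/metric.py` and may normalize differently) for a feature map flattened to C rows of N spatial values is `G = F @ F.T / (C * N)`. A tiny pure-Python version:

```python
def gram(F):
    # F: C x N list-of-lists (channels x flattened spatial positions)
    c, n = len(F), len(F[0])
    return [[sum(F[i][k] * F[j][k] for k in range(n)) / (c * n)
             for j in range(c)] for i in range(c)]

feat = [[1.0, 2.0],    # channel 0
        [0.0, 1.0]]    # channel 1
G = gram(feat)         # 2x2 matrix of normalized channel correlations
```

Matching Gram matrices penalizes differences in channel correlations (texture/style) while discarding spatial layout, which is why it complements the position-wise perceptual loss.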

================================================
FILE: opt/__init__.py
================================================


================================================
FILE: opt/config.py
================================================
# -*- coding: utf-8 -*-
import argparse
import torch


class BaseOptions():
    def __init__(self):
        """Reset the class; indicates the class hasn't been initailized"""
        self.initialized = False

    def initialize(self, misc_arg):
        # data set
        misc_arg.add_argument('--batch_size', type=int,
                            default=6, help='input batch size')
        misc_arg.add_argument('--no_flip', action='store_true',
                            help='if specified, do not flip the images for data augmentation')

        # net set
        misc_arg.add_argument('--input_nc', type=int, default=9,
                            help='# of input image channels')
        misc_arg.add_argument('--output_nc', type=int, default=3,
                            help='# of output image channels')
        misc_arg.add_argument('--ngf', type=int, default=64,
                            help='# of gen filters in first conv layer')
        misc_arg.add_argument('--ndf', type=int, default=64,
                            help='# of discrim filters in first conv layer')

        misc_arg.add_argument('--netD', type=str, default='basic',
                            help='selects model to use for netD')

        misc_arg.add_argument('--n_layers_D', type=int, default=3,
                            help='only used if netD==n_layers')
        misc_arg.add_argument('--gpu_ids', type=str, default='0',
                            help='gpu ids, e.g. "0", "0,1,2" or "0,2"; use -1 for CPU')

        # loss set
        misc_arg.add_argument('--loss_percept', action='store_true',
                              help='include perceptual loss')
        misc_arg.add_argument('--loss_faceID', action='store_true',
                              help='include face identity loss')
        # misc_arg.add_argument('--loss_percept', type=bool, default=True,
        #                       help='include perceptual loss')
        # misc_arg.add_argument('--loss_faceID', type=bool, default=True,
        #                       help='include face identity loss')

        misc_arg.add_argument('--gan_start_epoch', type=int, default=0,
                              help='start to include GAN loss from which epoch')

        # path and name
        misc_arg.add_argument('--name', type=str, default='experiment_name',
                            help='name of the experiment. It decides where to store samples and models')
        misc_arg.add_argument('--load_path', type=str, default='trained_model')
        misc_arg.add_argument('--load_model_iter', type=str, default='latest',
                            help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
        
        misc_arg.add_argument('--num_threads', default=4, type=int,
                            help='# threads for loading data')
        misc_arg.add_argument('--save_dir', type=str,
                            default='runs', help='output path')

        # norm and dropout
        misc_arg.add_argument('--norm', type=str, default='instance',
                            help='instance normalization or batch normalization for discriminator')
        # misc_arg.add_argument('--no_dropout', action='store_true',
        #                     help='no dropout for the generator')

        # init
        misc_arg.add_argument('--init_type', type=str, default='normal',
                            help='network initialization [normal|xavier|kaiming|orthogonal]')
        misc_arg.add_argument('--init_gain', type=float, default=0.02,
                            help='scaling factor for normal, xavier and orthogonal.')
        misc_arg.add_argument('--verbose', action='store_true',
                            help='if specified, print more debugging information')
        
        # display
        misc_arg.add_argument('--log_step', type=int, default=200,
                            help='log after n iters')
        misc_arg.add_argument('--save_step', type=int,
                            default=200, help='save after n iters')
                            
        misc_arg.add_argument('--save_by_iter', action='store_true',
                            help='whether saves model by iteration')
        misc_arg.add_argument('--phase', type=str, default='test',
                            help='train, val, test, etc')
        # optimizer                
        misc_arg.add_argument('--niter', type=int, default=100,
                            help='# of iter at starting learning rate')
        misc_arg.add_argument('--niter_decay', type=int, default=100,
                            help='# of iter to linearly decay learning rate to zero')
        misc_arg.add_argument('--beta1', type=float, default=0.5,
                            help='momentum term of adam')
        misc_arg.add_argument('--lr', type=float, default=0.0001,
                            help='initial learning rate for adam')
        misc_arg.add_argument('--no_lsgan', action='store_true',
                            help='if specified, use vanilla GAN loss instead of least-squares GAN (LSGAN)')
        misc_arg.add_argument('--pool_size', type=int, default=50,
                            help='the size of image buffer that stores previously generated images')
        
        misc_arg.add_argument('--lr_policy', type=str, default='lambda',
                            help='learning rate policy: lambda|step|plateau|cosine')
        misc_arg.add_argument('--lr_decay_iters', type=int, default=50,
                            help='multiply by a gamma every lr_decay_iters iterations')
        self.initialized = True
        return misc_arg

    def get_config(self):
        """Initialize our parser with basic options(only once).
        Add additional model-specific and dataset-specific options.
        These options are defined in the <modify_commandline_options> function
        in model and dataset classes.
        """
        if not self.initialized:  # check if it has been initialized
            parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
            self.parser = self.initialize(parser)

        # get the basic options
        config, _ = self.parser.parse_known_args()

        # set gpu ids, transform the comma-separated string into a list of ints
        str_ids = config.gpu_ids.split(',')
        config.gpu_ids = []
        for str_id in str_ids:
            gpu_id = int(str_id)
            if gpu_id >= 0:
                config.gpu_ids.append(gpu_id)
        if len(config.gpu_ids) > 0:
            torch.cuda.set_device(config.gpu_ids[0])

        return config
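The gpu-id handling above turns a comma-separated string into a list of ints and drops negative entries (CPU). A standalone sketch of that parsing (`parse_gpu_ids` is an illustrative name, not repo API):

```python
def parse_gpu_ids(gpu_ids: str):
    """Parse a comma-separated GPU id string, e.g. '0,1,2'; negative ids (CPU) are dropped."""
    ids = []
    for s in gpu_ids.split(','):
        i = int(s)
        if i >= 0:  # -1 means CPU; it contributes no device id
            ids.append(i)
    return ids

print(parse_gpu_ids('0,2'))  # [0, 2]
print(parse_gpu_ids('-1'))   # [] -> run on CPU
```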


    


================================================
FILE: opt/configTrain.py
================================================
# -*- coding: utf-8 -*-
from opt.config import BaseOptions

class TrainOptions(BaseOptions):
    """This class includes training options.
    It also includes shared options defined in BaseOptions.
    """

    def initialize(self, misc_arg):
        misc_arg = BaseOptions.initialize(self, misc_arg)
        misc_arg.add_argument('--lambda_vgg', type=int, default=1,
                              help='weight of the VGG perceptual loss')
        misc_arg.add_argument('--lambda_reconstruct', type=int, default=25,
                              help='weight of the reconstruction loss')
        misc_arg.add_argument('--lambda_pix', type=int, default=25,
                              help='weight of the pixel-wise loss')
        misc_arg.add_argument('--lambda_id', type=int, default=1,
                              help='weight of the face identity loss')
        misc_arg.add_argument('--lambda_gan', type=int, default=1,
                              help='weight of the GAN loss')
        self.initialized = True
        return misc_arg
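`TrainOptions` only layers loss-weight flags onto the base parser and reads them with `parse_known_args`, so unrecognized flags do not abort parsing. A minimal sketch of that pattern (flag names mirror the file above; the helper name is illustrative):

```python
import argparse

def add_loss_weights(parser):
    # loss weights layered onto a base parser, as TrainOptions.initialize does
    parser.add_argument('--lambda_vgg', type=int, default=1)
    parser.add_argument('--lambda_reconstruct', type=int, default=25)
    return parser

parser = add_loss_weights(argparse.ArgumentParser())
# parse_known_args tolerates flags this parser does not define
opt, extras = parser.parse_known_args(['--lambda_vgg', '3', '--unknown_flag'])
print(opt.lambda_vgg)         # 3
print(opt.lambda_reconstruct) # 25
print(extras)                 # ['--unknown_flag']
```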



================================================
FILE: requirements.txt
================================================
tqdm==4.32.2
opencv_python==3.4.1.15
torch==0.4.1
torchvision==0.2.1
scipy==1.0.1
matplotlib==2.2.2
numpy==1.15.0
tensorboardX==1.8


================================================
FILE: test.py
================================================
import time
import numpy as np
import cv2
import torch
import argparse
from tqdm import tqdm
from model.spade_model import SpadeModel
from opt.configTrain import TrainOptions
from loader.dataset_loader_demo import DatasetLoaderDemo
from fusion.affineFace import *


parser = argparse.ArgumentParser()
parser.add_argument('--pose_path', type=str, default='data/poseGuide/imgs', help='path to pose guide images')
parser.add_argument('--ref_path', type=str, default='data/reference/imgs', help='path to appearance/reference images')
parser.add_argument('--pose_lms', type=str, default='data/poseGuide/lms_poseGuide.out', help='path to pose guide landmark file')
parser.add_argument('--ref_lms', type=str, default='data/reference/lms_ref.out', help='path to reference landmark file')
args = parser.parse_args()


if __name__ == '__main__':
    trainConfig = TrainOptions()
    opt = trainConfig.get_config()  # namespace of arguments
    # init test dataset
    dataset = DatasetLoaderDemo(gaze=(opt.input_nc == 9), imgSize=256)

    root = args.pose_path  # directory of pose guide images
    path_bounds = args.pose_lms  # pose guide landmark file
    dataset.loadBounds([path_bounds], head='{}/'.format(root))

    root = args.ref_path  # directory of reference images
    path_appears = args.ref_lms  # reference landmark file
    dataset.loadAppears([path_appears], '{}/'.format(root))
    dataset.setAppearRule('sequence')

    # dataloader
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=12, drop_last=False)
    print('dataset size: {}\n'.format(dataset.shape()))

    # output sequence: ref1-pose1, ref1-pose2,  ref1-pose3, ... ref2-pose1, ref2-pose2, ref2-pose3, ...
    boundNew = []
    appNew = []
    for aa in dataset.appearList:
        for bb in dataset.boundList:
            boundNew.append(bb)
            appNew.append(aa)
    dataset.boundList = boundNew
    dataset.appearList = appNew

    model = SpadeModel(opt)  # define model
    model.setup(opt)  # initialize schedulers (if isTrain), load pretrained models
    model.set_logger(opt) # set writer to runs/test_res
    model.eval()

    iter_start_time = time.time()
    cnt = 1
    with torch.no_grad():
        for step, data in tqdm(enumerate(data_loader)):
            model.set_input(data)  # set device for data
            model.forward()
            # fusionNet
            for i in range(data['img_src'].shape[0]):
                img_gen = model.fake_B.cpu().numpy()[i].transpose(1, 2, 0)
                img_gen = (img_gen * 0.5 + 0.5) * 255.0
                img_gen = img_gen.astype(np.uint8)
                img_gen = dataset.gammaTrans(img_gen, 2.0) # model output image, 256*256*3
                # cv2.imwrite('output_noFusion/{}.jpg'.format(cnt), img_gen)

                lms_gen = data['pt_dst'].cpu().numpy()[i] / 255.0 # [146, 2]
                img_ref = data['img_src_np'].cpu().numpy()[i]
                lms_ref = data['pt_src'].cpu().numpy()[i] / 255.0
                lms_ref_parts, img_ref_parts = affineface_parts(img_ref, lms_ref, lms_gen)

                # fusion
                fuse_parts, seg_ref_parts, seg_gen = fusion(img_ref_parts, lms_ref_parts, img_gen, lms_gen, 0.1)
                fuse_eye, mask_eye, img_eye = lightEye(img_ref, lms_ref, fuse_parts, lms_gen, 0.1)
                # res = np.hstack([img_ref, img_pose, img_gen, fuse_eye])
                cv2.imwrite('output/{}.jpg'.format(cnt), fuse_eye)
                cnt += 1
    iter_end_time = time.time()

    print('length of dataset:', len(dataset))
    print('time per img: ', (iter_end_time - iter_start_time) / len(dataset))
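The nested loops over `appearList` and `boundList` in `test.py` enumerate every (reference, pose) pair in the order ref1-pose1, ref1-pose2, ..., ref2-pose1, .... The same ordering can be produced with `itertools.product`; a sketch with illustrative list contents, not repo data:

```python
from itertools import product

appearList = ['ref1', 'ref2']
boundList = ['pose1', 'pose2', 'pose3']

# product(A, B) iterates B fastest, matching the nested loops in test.py
appNew, boundNew = [], []
for aa, bb in product(appearList, boundList):
    appNew.append(aa)
    boundNew.append(bb)

print(appNew)    # ['ref1', 'ref1', 'ref1', 'ref2', 'ref2', 'ref2']
print(boundNew)  # ['pose1', 'pose2', 'pose3', 'pose1', 'pose2', 'pose3']
```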






================================================
FILE: utils/__init__.py
================================================
#coding:utf-8


================================================
FILE: utils/affineFace.py
================================================
import numpy as np

from utils.points2heatmap import *
from utils.calcAffine import *


def affineface(img, src_pt, dst_pt, heatmapSize=256):
    # naive mode: align the facial-feature curves of src to dst
    curves_src, _ = points2curves(src_pt)
    pts_fivesense_src = np.vstack(curves_src[1:])
    curves_dst, _ = points2curves(dst_pt)
    pts_fivesense_dst = np.vstack(curves_dst[1:])
    affine_mat = calAffine(pts_fivesense_src, pts_fivesense_dst)

    pt_aligned = affinePts(affine_mat, src_pt)
    img_aligned = affineImg(img, affine_mat, dsize=heatmapSize)
    return img_aligned, pt_aligned


if __name__ == '__main__':
    pass



================================================
FILE: utils/affine_util.py
================================================
from __future__ import print_function
import numpy as np
import cv2

def th_affine2d(x, matrix, output_img_width, output_img_height, center=True, is_landmarks=False):
    """
    2D affine transform of an image (or landmark array) given a 2x3 (or 3x3) matrix.
    """
    assert(matrix.ndim == 2)
    transform_matrix = matrix[:2, :]
    src = x
    if is_landmarks:
        dst = np.empty((x.shape[0], 2), dtype=np.float32)
        for i in range(src.shape[0]):
            dst[i, :] = AffinePoint(np.expand_dims(
                src[i, :], axis=0), transform_matrix)

    else:
        # keyword arguments: the 4th positional slot of warpAffine is dst, not flags
        dst = cv2.warpAffine(src, transform_matrix, (output_img_width, output_img_height),
                             flags=cv2.INTER_AREA, borderMode=cv2.BORDER_CONSTANT,
                             borderValue=(0, 0, 0))
        # for gray image, keep a trailing channel axis
        if dst.ndim == 2:
            dst = np.expand_dims(np.asarray(dst), axis=2)

    return dst


def AffinePoint(point, affine_mat):
    """
    Affine 2d point
    """
    assert(affine_mat.shape[0] == 2)
    assert(affine_mat.shape[1] == 3)
    assert(point.shape[1] == 2)

    point_x = point[0, 0]
    point_y = point[0, 1]
    result = np.empty((1, 2), dtype=np.float32)
    result[0, 0] = affine_mat[0, 0] * point_x + \
        affine_mat[0, 1] * point_y + \
        affine_mat[0, 2]
    result[0, 1] = affine_mat[1, 0] * point_x + \
        affine_mat[1, 1] * point_y + \
        affine_mat[1, 2]

    return result


def exchange_landmarks(input_tf, corr_list):
    """
    Exchange value of pair of landmarks
    """
    #print(corr_list.shape)
    for i in range(corr_list.shape[0]):
        temp = input_tf[corr_list[i][0], :].copy()
        input_tf[corr_list[i][0], :] = input_tf[corr_list[i][1], :]
        input_tf[corr_list[i][1], :] = temp

    return input_tf
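`AffinePoint` computes `[x', y'] = A[:, :2] @ [x, y] + A[:, 2]` one point at a time; the same result can be checked with a single matrix product (matrix and point values here are illustrative):

```python
import numpy as np

A = np.array([[1.0, 0.0, 10.0],
              [0.0, 1.0, -5.0]])  # identity rotation + translation (10, -5)
pt = np.array([3.0, 4.0])

# apply the 2x3 affine: rotate/scale part times the point, plus translation column
out = A[:, :2] @ pt + A[:, 2]
print(out)  # [13. -1.]
```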


================================================
FILE: utils/calcAffine.py
================================================
# -*- coding: utf-8 -*-
"""
Created on Fri Dec 29 13:43:03 2017
"""
import numpy as np
import cv2


# solve a similarity transform (rotation + uniform scale + translation) by least squares
# src_p[input] -- np.array([[x,y],...]) | source points
# dst_p[input] -- np.array([[x,y],...]) | target points
# affine_mat[output] -- np.array() | 2x3 affine matrix
def calAffine(src_p, dst_p):
    p_N = len(src_p)
    U = np.mat(list(dst_p[:, 0]) + list(dst_p[:, 1]))
    xx_src, yy_src = list(src_p[:, 0]), list(src_p[:, 1])

    X = np.mat(np.stack([xx_src + yy_src,
                         yy_src + [-ii for ii in xx_src],
                         [1 for ii in range(p_N)] + [0 for ii in range(p_N)],
                         [0 for ii in range(p_N)] + [1 for ii in range(p_N)]], axis=1))

    result = np.linalg.pinv(X) * U.T

    affine_mat = np.zeros([2, 3])
    affine_mat[0][0] = result[0][0]
    affine_mat[0][1] = result[1][0]
    affine_mat[0][2] = result[2][0]
    affine_mat[1][0] = -result[1][0]
    affine_mat[1][1] = result[0][0]
    affine_mat[1][2] = result[3][0]
    return affine_mat


def affinePts(affine_mat, pt):
    src_align = pt.T
    new_align = np.mat(affine_mat[:2, :2]) * np.mat(src_align) + np.reshape(affine_mat[:, 2], (-1, 1))
    pt_align = np.array(np.reshape(new_align.T, -1))[0].reshape(-1, 2)
    return pt_align


# warp an image with a 2x3 affine matrix
# img[input] -- np.array()
# TransMat[input] -- np.array() | 2x3 affine matrix
# img_align[output] -- np.array() | warped image of size dsize x dsize
def affineImg(img, TransMat, dsize=256):
    img_align = cv2.warpAffine(img, TransMat, (dsize, dsize), borderValue=(155, 155, 155))
    return img_align
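`calAffine` parameterizes a similarity transform as (a, b, tx, ty) with x' = a*x + b*y + tx and y' = -b*x + a*y + ty, then solves by pseudo-inverse. A self-contained sketch of the same solve, checked against a known rotation plus translation (`cal_affine_np` is an illustrative name):

```python
import numpy as np

def cal_affine_np(src, dst):
    """Least-squares similarity fit, same parameterization as calAffine."""
    n = len(src)
    U = np.concatenate([dst[:, 0], dst[:, 1]])
    x, y = src[:, 0], src[:, 1]
    X = np.stack([np.concatenate([x, y]),
                  np.concatenate([y, -x]),
                  np.concatenate([np.ones(n), np.zeros(n)]),
                  np.concatenate([np.zeros(n), np.ones(n)])], axis=1)
    a, b, tx, ty = np.linalg.pinv(X) @ U
    return np.array([[a, b, tx], [-b, a, ty]])

# recover a known 90-degree rotation plus translation (2, 3)
src = np.array([[0., 0.], [1., 0.], [0., 1.], [1., 1.]])
theta = np.pi / 2
R = np.array([[np.cos(theta),  np.sin(theta)],
              [-np.sin(theta), np.cos(theta)]])
dst = src @ R.T + np.array([2., 3.])

M = cal_affine_np(src, dst)
print(np.round(M, 6))  # [[ 0.  1.  2.]  [-1.  0.  3.]]
```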



================================================
FILE: utils/lms.test
================================================
9/19256.png
59.2707332221 140.50139565 60.7953501505 155.840424728 63.0589911357 170.835862842 65.4653911442 185.631580966 68.932073428 200.297533329 73.1550212248 214.977364271 78.3881163686 228.991811878 85.782502989 242.139749636 94.4189019168 254.626618041 104.924653651 265.730908768 117.005647331 275.139883829 129.815078926 283.476031021 143.37127461 290.826952689 157.705309344 295.664839878 172.967870215 299.278662395 188.142331646 301.003628508 203.391112954 299.106629196 214.045289167 294.2596111 222.110408022 286.9237645 227.824291319 277.536297759 232.693593582 267.177742414 236.99960223 257.480352842 241.334591417 246.280666171 245.480590158 235.443775065 247.893018636 225.0246153 250.039055401 213.309511006 252.46659434 202.406234293 254.773921111 191.514755894 256.225473022 180.407283758 256.372423739 168.78606124 256.828096529 158.1029349 256.357407963 146.970661227 256.985755887 136.084872502 114.586808868 135.321729885 132.74553607 124.70268833 152.040049834 122.438303213 170.91265272 124.446420684 189.598270691 128.753271858 218.845519791 130.545330844 228.999942361 125.463643143 239.492576365 121.450320408 249.411474877 118.880680476 258.63657448 126.027651995 205.708653694 147.024191699 209.353867509 164.673407501 213.07951614 182.28544999 216.65254958 199.749653618 185.393391977 212.936340413 196.392694475 214.773223325 208.047812188 215.409272233 216.617837427 213.903127509 224.635758048 208.808564363 131.858618758 149.400812206 141.803898017 141.837268135 165.392099346 141.877548881 173.496075357 150.424830919 163.490303701 155.657311596 141.529345276 154.72979133 218.371201941 148.56854991 223.046388058 140.281171009 241.050909426 136.381149039 248.061006408 142.790943892 242.89240324 149.838912486 225.871684848 151.684002643 133.124127304 132.692006387 152.894224179 131.235446775 171.782386626 133.906632401 189.462860167 136.799074227 219.621570411 137.486735175 230.071839672 132.992997019 239.766994382 129.429999215 249.825871997 127.146855576 
153.647803621 138.716963931 152.532061597 156.828066976 153.08423455 148.835235305 231.661808615 135.504713658 234.498881293 151.922103669 233.452793735 144.753889108 189.998753616 148.65085426 211.829934016 147.614811963 183.50083607 192.236196887 224.472572542 190.084188921 177.735356983 205.116053997 229.367980432 201.176585545 165.681562295 246.508253322 182.299302633 238.905006917 200.693193025 234.064595888 208.664367869 234.20392116 215.735333293 232.49493141 222.794397869 233.948770072 228.546120971 240.699878289 223.70320508 247.896904973 216.815692657 252.697373346 207.576602906 255.753809513 193.485770695 256.258864646 178.869826197 252.538349197 169.738702535 246.227459702 188.733957672 243.310082701 207.91329263 243.195118021 217.199482234 241.363744993 225.538937849 241.024455906 216.417664974 241.68074608 207.352268122 242.826354474 188.391096125 244.116216393 153.102035611 148.834301512 233.454847703 144.753781362 153.725930258 145.114166383 153.084286376 147.501284993 151.563270626 149.641380309 149.466570187 151.172505503 146.994850192 151.941333886 144.421277117 151.998757233 142.008750244 151.173766886 139.965573287 149.524089837 138.660797851 147.318713561 138.066185917 145.105605337 137.981505574 144.276587636 138.369191154 142.03460668 139.460180447 139.654124445 141.345427985 137.9398006 143.643632066 136.821013852 146.278905487 136.500155145 148.780330685 136.899632999 151.03364833 138.197076184 152.669510423 140.170197508 153.563862259 142.548142291 221.73503399 142.544143402 222.324681326 144.296157045 223.747033131 146.727416182 225.871748815 148.622716569 228.565630813 149.654700302 231.454922255 149.674084725 234.115583168 148.620042416 236.265651641 146.666641629 237.611757594 144.110753208 238.165864921 141.363673677 238.111969948 139.014414155 237.364507682 136.34796902 235.853386482 133.808087344 233.553001096 132.136889175 230.82140856 131.419945958 227.864313086 131.743955739 225.252816986 132.923802015 223.246766375 
134.939675132 222.010432246 137.428703155 221.574681609 139.64437078
9/18313.png
106.48914616 132.672242138 107.415715889 145.711047527 108.089733661 158.217376145 109.097645413 170.943788737 111.49622388 183.266693209 114.497992701 195.61203933 117.4158239 207.928001363 121.647678133 219.699144661 125.653509329 231.889662553 130.676008838 243.633838019 136.936671504 254.620283763 143.778754614 265.558231432 150.979034341 276.034626287 158.980762336 285.070427474 168.636700268 293.83382138 179.453387051 300.075893109 192.162578599 302.277568977 206.841866015 301.3882526 220.759331914 297.429317905 233.685173048 290.916029164 246.050769856 282.338019796 256.605411476 272.805536915 266.830653498 261.502662583 276.33434463 250.115474513 284.151631088 238.342934186 289.489777878 224.304358509 293.74386166 210.381942125 297.746770812 196.27342125 300.729988231 181.733988216 302.378821257 167.174265776 304.249079801 152.938836638 304.658511111 138.276047208 305.071288243 123.954987635 112.847168039 131.899352655 125.202342376 123.503510827 138.703473483 124.821782401 152.072583536 127.758609354 165.981191284 131.636220127 203.018679197 126.642964252 219.322462372 119.756744827 236.917121769 114.683829451 254.534167436 112.995767178 270.818333052 122.148798935 182.686748366 149.842311592 181.890113083 168.336652534 180.591242808 186.82105254 179.082507225 205.369823936 164.673462438 213.922283489 173.873411976 217.958740668 183.403257567 219.123765839 194.323037043 216.449363124 205.249015343 213.573475192 128.488183749 151.31398383 136.87693746 147.744556713 155.183765167 148.024973775 163.032068502 152.258383687 154.814616274 154.89594445 137.053182665 154.737244799 213.122829346 150.005593506 220.596053083 142.584600921 242.236942333 140.518517549 252.98372581 143.783167642 243.973767857 150.058514219 222.740135367 151.989105048 124.883888821 133.189669134 138.785197683 134.127358206 152.084180916 136.918700167 165.205495626 139.182662383 203.899336715 134.347678637 220.240167696 129.001630836 236.932956357 124.692478267 254.349869402 122.707408406 
145.810942286 146.553356026 145.984092888 155.815542063 145.753005675 151.444313936 231.248896524 139.688929543 233.28793687 152.067511801 232.949303358 146.533350971 173.385300061 150.939051443 197.712775443 149.666354053 164.406491946 193.347326232 205.614876812 192.808966763 159.113212686 205.590542538 211.774756737 204.963119795 159.334247959 246.17033926 166.361691723 240.000440723 175.891471738 237.031314863 183.560816212 238.085511132 191.568983323 236.614927606 206.898009026 239.205955284 223.0520094 245.333374235 211.611766532 253.858648396 199.353988616 259.518893372 185.877344433 261.787193023 175.724784804 260.756856459 166.151721142 254.625118803 163.100185194 246.472029038 173.701483102 246.022122151 184.69960444 247.829447571 201.401870894 245.822144351 218.333681184 245.686391746 201.318101332 246.444242567 185.088446002 247.508323029 173.863314132 246.861554284 145.771761246 151.443138211 232.961306923 146.532598507 152.459077151 150.360388735 151.880913147 152.034778676 150.687841843 153.463769705 149.114967358 154.396881419 147.323728686 154.826253251 145.522226445 154.779511285 143.873800922 154.134946958 142.489117828 152.927800571 141.65883591 151.36128393 141.311614635 149.823233132 141.344734109 149.402393953 141.743326134 147.807046173 142.598382835 146.188356315 144.025360235 145.126560133 145.65737248 144.500047682 147.49907096 144.358989834 149.212416454 144.633443477 150.760126283 145.545671647 151.860918363 146.88584938 152.480365294 148.597387485 227.924359649 143.328780536 228.776704934 145.535615715 230.567435638 147.669316343 232.937918279 149.103130649 235.667134951 149.687475432 238.451677638 149.496351824 240.994527509 148.41732722 243.08501528 146.477918036 244.344520333 143.994089069 244.817729155 141.474442122 244.82921745 140.215926534 244.110093791 137.803409348 242.600089712 135.41410888 240.322193509 133.871306219 237.694357661 133.065952494 234.826894059 133.147792164 232.21192646 133.96546826 229.994982439 135.677815247 
228.497249141 138.001842309 227.831235297 140.545626305
9/19869.png
110.245464177 148.829301092 109.958189408 161.084203848 109.730843772 173.034262362 109.703148885 184.953274131 110.916384523 196.671940639 113.143185036 208.720416614 115.752609491 220.390328909 119.946544729 231.346135802 124.347461589 242.609061802 129.981212104 253.288024557 136.91716085 263.271156305 144.070928045 273.036081186 151.492454317 282.448806427 159.563820795 290.668665906 169.184807464 298.297906136 179.867381002 303.752905258 192.143799331 305.526843759 206.24083174 304.108790136 219.470034352 300.325274358 232.277552849 295.280395413 245.211004896 288.568311956 256.672340579 281.003872168 267.540954022 271.907935343 277.918025824 262.282467807 286.591136251 252.298826524 292.950527309 239.504488134 297.274798901 226.666050766 300.799673183 213.35461447 302.804299458 199.763470165 304.156516395 185.843153816 305.621658405 172.364673163 305.978605954 158.397528001 306.344036754 144.847427091 115.158408801 130.593369953 125.425441538 119.748865704 137.4032323 118.079374989 149.884781878 119.656907731 162.719434597 123.50112244 201.00370918 120.63808576 216.996071381 115.407738672 234.637367238 113.135471059 251.870550227 115.413017023 267.731011378 126.321561511 180.770483915 147.683802573 179.679613463 163.421609287 178.059323268 179.209566407 176.610655732 195.201430783 162.33964871 207.993385361 171.860432863 211.177135312 181.665372 211.519677529 193.687845484 207.812373844 205.711453541 205.052011661 127.2937918 151.246436155 135.399786431 143.656834038 156.388568435 144.762517163 163.54204846 152.166695511 154.903900711 155.65213948 135.779192622 155.485417804 210.936260237 150.107521268 218.781286879 141.369186129 242.343032089 139.727266587 252.365699653 145.971512562 243.113231495 151.393710048 221.003683915 152.36716455 125.458452893 127.680096213 137.771680259 126.615953447 150.045906506 128.786541335 162.193643974 131.462699002 201.756171426 128.96726159 218.000932428 124.592997073 234.776869397 122.025030265 251.804959575 123.101921725 
145.938906608 141.262334928 145.436329759 156.675274923 145.53128924 150.086237403 230.399095311 137.532877223 232.004505762 152.798354303 231.712248934 146.590700504 172.575974279 149.75751461 195.776529787 148.601153999 162.457116329 186.948189109 204.796838533 184.870006129 156.005636829 199.960778122 213.203829315 196.137553083 151.395921113 233.946685725 162.837917004 228.078855349 177.219584857 225.339022655 184.891009804 225.126288288 192.518319946 224.257546259 215.058344801 224.357921954 237.677210995 227.443201326 224.046897741 243.328721249 207.601315789 255.323932135 187.126877986 260.766652313 171.94667495 258.166824733 159.785582063 247.588910705 155.872394208 235.230526391 170.417260005 232.66075966 185.720946114 232.710684618 209.613279863 229.613787876 232.91840261 229.14295881 210.711825919 243.101625664 186.32442677 247.400546622 169.584992156 244.641951312 145.545987519 150.085773429 231.731645027 146.590757006 154.172787887 146.897621223 153.402829241 149.263956198 151.804805735 151.350802737 149.604541104 152.769833872 147.037637941 153.421322286 144.426923571 153.320490677 142.024766445 152.369088795 140.030218457 150.623028808 138.784199077 148.331700328 138.268655644 146.046070342 138.23486678 145.229900865 138.734284085 143.000509643 139.933199842 140.664103987 141.918387227 139.020467918 144.287662217 138.061565141 146.952257375 137.877806511 149.463285489 138.401103214 151.696301445 139.799232027 153.302375298 141.844456817 154.131819591 144.307433446 224.682316483 143.905470924 225.62914445 146.252388917 227.410255457 148.207154693 229.696735598 149.455797814 232.273203237 149.933558815 234.846022687 149.702539529 237.165950716 148.642203992 239.048083799 146.805871587 240.133154552 144.485766239 240.487980555 142.230844643 240.449351438 141.557944585 239.807654626 139.385974986 238.475737586 137.148853232 236.415543577 135.653647638 234.027570233 134.791413109 231.37351349 134.694624745 228.895926096 135.273072995 226.731815861 
136.735770404 225.246307919 138.827759218 224.554754566 141.347442455
9/18824.png
69.4807351723 151.801376805 70.2700786681 166.669154882 71.8143754474 181.241441228 73.4269978983 195.55873371 76.2527942641 209.799132528 80.0043269856 223.968911671 84.8441464888 237.466449121 91.8781325344 250.004923335 100.188678082 261.797310386 110.160991578 272.024613754 121.769699443 280.811142146 133.89600159 288.788219684 146.713776859 295.961609825 160.525106331 300.493339418 175.307859426 304.226228537 189.91308948 306.171458044 204.221558442 303.956575289 213.323620951 298.455847863 219.22790111 290.601383296 223.374967269 281.460247035 227.643970704 271.693185861 231.593874444 262.524965878 235.968566423 252.060185303 240.35214401 242.119502116 243.165813416 232.687239338 245.348476302 221.666267774 247.349904237 211.386183925 249.018476926 201.06432608 249.498083565 190.297317376 247.903824721 179.234798687 246.743679457 169.20659787 244.739155353 158.487321574 244.238579347 148.053784737 125.096213735 116.720124836 140.081668454 105.896495598 155.067981074 103.540847004 169.68936401 104.447227641 184.639339075 108.401194107 214.707007946 113.958461967 221.008128327 112.470160527 227.603550136 112.180608215 233.442581862 113.653703605 239.327521598 122.446098298 204.111672709 147.838110575 209.606371136 161.160861426 215.466825351 174.494745106 221.161057828 187.997954986 186.465877545 212.474209367 199.336571231 212.283758816 212.428688328 211.637724313 221.738915823 210.458519637 229.935808671 205.838826967 137.223800922 144.126809682 146.829807002 138.094431327 168.772421818 140.051422374 175.774119466 149.162405184 165.696483193 152.449315441 145.496437668 149.459778139 217.567376811 152.587047705 221.540459568 146.267503436 236.241405671 144.799717822 240.913939855 150.633364888 236.385024141 155.688155581 223.052157322 156.054708474 140.112800115 114.549578725 155.80052298 113.380681874 170.537003696 115.706630278 184.613403176 118.915844355 215.310954826 122.387534635 221.804408493 121.319734116 227.354584946 121.143225458 233.528637123 
121.744218044 157.929718299 136.172247856 155.425268244 152.612596014 156.91544926 145.444983919 228.826987893 143.334772928 229.701924243 157.00342149 229.595553687 150.854403745 190.246896155 148.296107106 210.582852603 149.946050881 185.962823549 186.981646229 227.617763787 184.383302305 179.027541323 202.195566143 233.168787548 196.85038802 167.145586134 246.648869672 186.555274252 239.135192231 206.694950861 232.161149718 213.87626209 232.684044523 219.856638543 232.038303182 225.499612324 236.286903068 228.035337953 243.14564591 224.196178584 250.672063621 218.978127782 255.970407105 211.339631337 259.487898436 196.188639049 258.621431544 180.748136138 253.921478848 171.60830814 246.888249998 191.93566201 244.827638362 212.028395822 245.149003446 219.127299427 244.12759024 225.165272106 243.994396347 218.490940622 244.586956248 211.694394091 244.782245347 191.750990766 245.612113479 156.93274299 145.44396392 229.596218831 150.854364514 172.474556654 140.488510493 171.820439665 142.9779942 170.040697724 147.2392217 167.0493125 150.908460333 163.018419099 153.35081736 158.331348966 154.128525109 153.685226026 153.042946702 149.706897066 150.356131081 146.958625721 146.42337564 145.544828545 141.917325918 145.224489794 137.357462427 146.117691929 132.765403269 148.299332118 128.449100982 151.792609744 125.390351905 156.169352453 123.7645879 161.063809693 123.843677741 165.569437351 125.47818312 169.11164603 128.556488098 171.448275582 132.537550978 172.309353356 135.673205226 214.77651454 149.516489364 215.404518082 151.1674488 216.878394689 153.806414101 219.112808663 155.929964853 221.952639081 157.137431413 225.047389323 157.167151907 227.936031163 156.010942924 230.236804261 153.895562952 231.644607371 151.113999539 232.164982597 148.077079527 231.976165067 145.124705563 231.029750895 142.228362052 229.293492133 139.588185835 226.781661447 137.898825644 223.817967893 137.261998628 220.661911894 137.765855125 217.901300241 139.2255479 215.911027798 
141.546879951 214.770808519 144.315998334 214.48742804 146.391818237
9/19595.png
74.4244224358 133.664743499 75.7215088889 148.601916915 77.5045565718 163.196653964 79.372922785 177.746965257 82.0947246067 192.140287097 85.1699660466 206.698737736 88.6776595779 221.007755836 94.1089962315 234.738807963 100.841536719 248.141829955 109.626610062 260.195864335 120.212759837 270.885950946 131.437117639 280.639295749 143.538643841 289.344346602 156.635847135 295.887845075 170.876182508 301.3601746 185.332310332 304.391321099 200.486917729 304.356448861 213.185886687 300.626640372 224.063752659 294.420254316 232.957510443 285.2431258 240.855873541 274.566582542 247.298216757 263.979042936 253.443482862 251.953029683 259.442473639 240.368286023 263.532753707 228.861978904 266.332035021 215.791926375 268.418502645 203.429471171 270.485203653 190.973954974 272.023174831 178.485633389 272.534792985 165.514580256 273.071758094 153.191967367 272.331694216 140.52304056 272.151693699 128.190352571 110.303971361 137.062857659 126.409594615 126.71107324 143.870491914 125.333281857 161.272057564 127.566410872 178.541846833 131.663706289 214.532898112 131.261295824 228.119166808 125.812189923 242.347597053 121.660183258 256.199047753 119.927602512 268.12735852 128.136923585 198.646036779 146.974272846 201.163338592 166.444477235 203.403248911 186.132748885 205.394274456 205.564629019 178.846316127 213.16832158 190.005644421 216.508481686 201.593162862 218.767305983 210.562510683 215.610155382 219.750087388 210.43823679 130.382465566 148.780554258 140.671786435 143.040899546 161.959339805 142.842848778 169.837128152 149.631126133 160.611529936 152.593054854 140.500967755 152.198327072 216.285222655 147.351955144 222.526275362 140.047176789 241.920904494 137.473142137 251.133152725 142.102756151 242.984825721 147.31136935 224.272999256 148.560671699 126.614115503 135.246504733 144.40996987 134.085376593 161.760119924 136.556068166 178.108201583 139.09450361 215.24301508 138.022410922 229.031723194 133.594276038 242.345269617 130.225272566 256.028206416 
128.458423855 151.346459205 140.686512334 150.565711107 153.476446626 150.574161738 148.042763817 231.972476668 136.613570114 233.484430776 148.655362906 233.628068292 143.788081302 184.701490912 148.204111106 207.32031875 146.925008419 177.405145437 192.669867461 219.294310999 191.203743729 171.303442528 204.762176519 225.387123161 202.379446638 155.321585787 240.472759857 173.410240067 239.224676358 192.392173918 238.744526493 200.378743862 238.942749996 208.05991076 236.743701944 219.250296194 236.788541255 231.401795941 236.95067661 222.963494325 246.363338758 212.79694842 253.722414247 200.062823497 257.332821485 184.178010836 256.469342495 168.672586426 249.706529102 159.94236505 241.828740835 179.914510128 243.970996493 200.120400264 246.304086892 213.956847516 242.720078847 227.168608703 238.686498525 213.556908547 243.142432657 199.828198637 245.971925079 179.588351937 244.872634313 150.591700421 148.041968548 233.629448009 143.788694692 157.052391021 144.000865004 156.725924638 145.193851973 155.537185169 148.04686971 153.535098909 150.433859528 150.784096514 151.986885995 147.641126014 152.587695073 144.480942824 152.182792582 141.550740224 150.79747126 139.311047603 148.540631736 137.974005847 145.767668669 137.51970161 143.157464544 137.892553227 140.080524813 139.207082419 137.094724392 141.521431953 134.960576324 144.416661598 133.741013759 147.6975064 133.527356833 150.798184153 134.125575894 153.556505935 135.700521927 155.619484955 138.040404932 156.704023739 140.834956686 224.447863221 141.639108912 225.19253752 143.676458598 226.759139195 145.668427655 228.889131066 147.026494014 231.377366851 147.638964412 233.92721994 147.538215636 236.275745596 146.619906878 238.22486648 144.898427306 239.432089843 142.646844392 239.911467719 140.339039045 239.942716151 139.194587831 239.334721309 136.981376723 238.022392403 134.768332983 235.982589614 133.297348897 233.603989522 132.508891438 230.983595582 132.499954865 228.565585496 133.182803358 
226.493655688 134.700306517 225.065470119 136.789540991 224.407701278 139.098675723
9/19022.png
105.853911786 150.688842749 106.304755687 163.478940126 107.058193774 175.858800554 108.093608467 188.254071331 110.229936105 200.396513414 112.766738097 212.860480821 115.210241137 225.029971473 118.716371193 236.735315629 122.300263461 248.923973988 127.100442957 260.508494302 133.61804031 271.318109138 140.868919038 281.410912745 149.169107994 290.565297806 158.540567204 297.753250206 169.432213007 304.124802445 180.946535516 308.6148877 193.542154585 309.948068083 207.549631581 308.505301114 220.76968111 305.208204526 233.451038713 300.090467454 245.975225286 292.97329979 256.672108338 284.777100812 266.381493607 274.504398947 275.362423646 263.545693043 282.061802271 252.456912497 286.854969624 238.833684208 289.928046451 225.49834043 292.790172123 212.034993583 294.366884196 198.369996927 295.474931021 184.364889576 296.831564829 170.878509905 296.903260249 156.938634422 297.266115108 143.399566056 116.304739438 132.882383554 127.912101677 120.314053712 141.71663141 117.955114229 155.815006556 119.31210216 170.258223912 123.024928223 199.684637675 121.269258893 215.400914864 115.767697653 232.299514693 113.01896136 248.852728629 114.843558631 264.054600974 126.807430409 185.04317789 141.528634167 185.183521016 157.61319909 184.738888808 173.74555895 184.290973012 189.749690102 168.237920197 204.288286214 177.145813987 206.714088635 186.403867424 207.14750767 196.980990698 204.61604892 207.730063248 202.267617824 132.39440308 145.597696654 140.744441652 137.702594423 160.522211282 138.494458912 167.145615638 145.593700621 159.357529903 149.313758229 141.099701459 149.300662591 209.918812446 143.450660882 216.583980293 134.931619486 237.958407537 133.411665142 247.353427929 140.202166925 238.825297548 145.378180406 218.771264736 146.113453826 128.244438385 128.968131144 142.40245084 127.061883512 156.396350369 128.854706759 170.03519427 131.101479494 200.311398144 129.538880662 216.186949721 125.32534772 232.331881862 122.835928446 248.788032596 123.68794603 
150.760209549 135.090558168 150.280482337 150.424354998 150.065194766 143.988975242 227.144313534 131.12358597 228.71299771 146.634009031 228.667597337 140.35964247 176.638400006 143.360093585 197.331703788 142.19173124 168.345583532 183.783124179 207.410676297 182.164130291 162.476447813 196.749745738 213.989893988 193.982610498 158.243354322 242.16225384 166.597125534 231.693527273 179.101008482 226.682022275 187.224280453 227.232709673 195.451885443 225.917992458 210.86120123 229.524986825 225.09774603 239.637971883 214.81208181 249.816069769 202.589610148 256.6727466 188.389553595 259.804956041 176.405821916 258.802618562 165.482536053 252.01117807 161.314320622 241.739365482 173.479273129 235.07489689 187.42068506 234.461425635 204.747234017 234.212221293 221.29974001 239.398270593 205.17836789 245.854164848 188.109749171 248.6157591 173.864728577 247.179831606 150.082426937 143.988441253 228.682761647 140.35917256 161.280066397 138.860105662 161.07098251 139.415783897 160.109143918 142.4199581 158.282766052 145.044542752 155.599418544 146.843596422 152.426789823 147.570452718 149.177360876 147.26614573 146.139145028 146.007304175 143.726931864 143.820401682 142.245033322 140.974057598 141.717994561 137.837323827 142.07024103 134.620906772 143.365123516 131.565705272 145.677808647 129.360718675 148.628181238 128.163305121 151.962385722 127.997199873 155.123444948 128.67819745 157.887808842 130.317407447 159.944511568 132.720931168 160.98639282 135.639556093 220.549012597 137.867294167 221.362117279 140.127207891 223.248967977 142.482178967 225.797227766 144.094278085 228.778872635 144.885707534 231.874281312 144.895349546 234.71252579 143.909085163 237.047996569 141.882779151 238.496647026 139.180896086 239.074888831 136.22172627 238.876118418 133.503278306 237.700859522 130.755754057 235.615071044 128.460073692 232.815525094 127.177277495 229.788044255 126.710164399 226.687095102 127.211946065 224.048886248 128.552854942 222.092566418 130.813723559 
220.85843946 133.491710279 220.414261706 134.810167405
9/19916.png
95.2767910751 126.809887821 96.1610041532 140.668978114 97.1712750455 154.117666679 98.4616574186 167.729488962 100.946784689 181.011968182 103.82395298 194.399018442 106.804852556 207.638992554 111.326188645 220.272837509 116.288335188 233.153280591 122.68213832 245.247516364 130.50029673 256.423697433 138.752939865 267.349674917 147.384651295 277.874381193 156.911089211 286.907533908 168.162598012 295.261146595 180.504350865 300.860278984 194.363685923 302.529834898 209.025394305 300.909261294 222.570543927 296.277509068 234.922790029 288.92073999 246.79259185 279.454876017 256.627390857 269.182875567 265.980324209 257.221358307 274.684298266 245.28327563 281.807384886 233.053646155 286.57573427 218.81150233 289.923920006 204.836247919 292.906251951 190.753547444 295.102180642 176.380123942 296.321788426 161.903134424 297.759033857 147.813987976 297.945669692 133.238594198 298.320027286 119.04763484 108.261737636 134.004490065 122.09696225 125.843724022 137.160102923 127.177999277 152.081053372 130.451708983 167.373611453 134.708808776 207.540427842 133.687352451 223.633872672 128.039538618 240.641058617 123.614570224 257.345820876 121.735653215 272.721029723 129.621207649 186.223892518 151.843003149 186.594737554 171.41424043 186.656708261 191.009175528 186.546199842 210.555770115 168.448512052 215.80480723 178.302937662 220.021394777 188.381443098 222.257169315 198.995154047 218.026183293 209.769180464 213.865183
SYMBOL INDEX (223 symbols across 30 files)

FILE: fusion/affineFace.py
  function gammaTrans (line 13) | def gammaTrans(img, gamma):
  function erodeAndBlur (line 18) | def erodeAndBlur(img,kernelSize=21,blurSize=21):
  function affineface (line 25) | def affineface(img,src_pt,dst_pt,heatmapSize=256,needImg=True):
  function affineface_parts (line 41) | def affineface_parts(img,src_pt,dst_pt):
  function lightEye (line 73) | def lightEye(img_ref,lms_ref,img_gen,lms_gen,ratio=0.1):
  function multi (line 109) | def multi(img,mask):
  function fusion (line 117) | def fusion(img_ref,lms_ref,img_gen,lms_gen,ratio=0.2):
  function loaddata (line 150) | def loaddata(head,path_lms,flag=256,num = 50000):
  function gray2rgb (line 167) | def gray2rgb(img):
  function process (line 171) | def process(index, album_ref, album_gen, album_pose):

FILE: fusion/calcAffine.py
  function calAffine (line 14) | def calAffine(src_p, dst_p):
  function affinePts (line 34) | def affinePts(affine_mat,pt):
  function affineImg (line 44) | def affineImg(img,TransMat,dsize = 256):
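
The calcAffine helpers compute a 2x3 affine matrix from point correspondences and apply it to points. A minimal numpy sketch of the first two operations (the repo's exact solver is not visible from this index, so the least-squares fit below is an assumption about what calAffine does):

```python
import numpy as np

def cal_affine(src_p, dst_p):
    """Least-squares 2x3 affine matrix mapping src points onto dst points.

    Illustration only: the repo's calAffine(src_p, dst_p) shares this
    interface, but its internal solver is assumed here.
    """
    src = np.asarray(src_p, dtype=np.float64)
    dst = np.asarray(dst_p, dtype=np.float64)
    n = src.shape[0]
    # Design matrix [x y 1]; solve A @ M = dst in the least-squares sense.
    A = np.hstack([src, np.ones((n, 1))])
    M, *_ = np.linalg.lstsq(A, dst, rcond=None)
    return M.T  # 2x3: dst ~= mat[:, :2] @ src + mat[:, 2]

def affine_pts(mat, pts):
    """Apply a 2x3 affine matrix to an (N, 2) point array."""
    pts = np.asarray(pts, dtype=np.float64)
    return pts @ mat[:, :2].T + mat[:, 2]
```

With three non-collinear correspondences the fit is exact, matching what `cv2.getAffineTransform` would return.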

FILE: fusion/parts2lms.py
  function parts2lms (line 3) | def parts2lms(parts):

FILE: fusion/points2heatmap.py
  function curve_interp (line 6) | def curve_interp(points, heatmapSize=256, sigma=3):
  function curve_fill (line 14) | def curve_fill(points, heatmapSize=256, sigma=3, erode=False):
  function curves2heatmap (line 30) | def curves2heatmap(curves,heatmapSize=256,sigma=3,flag='line'):
  function curves2segments (line 47) | def curves2segments(curves,heatmapSize=256,sigma=3):
  function curves2gaze (line 71) | def curves2gaze(curves,heatmapSize=256,sigma=3):
  function curves2parts (line 81) | def curves2parts(curves):
  function points2curves (line 96) | def points2curves(points, heatmapSize=256,  sigma=1, heatmap_num=17):
  function distance (line 212) | def distance(p1, p2):
  function curve_fitting (line 215) | def curve_fitting(points, heatmap_size, sigma):
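
These functions turn landmark points into curve and heatmap representations. As a rough illustration of the underlying idea — not the repo's curve-interpolation pipeline, which fits and rasterises boundary curves — a single-channel Gaussian splat over a point set can be sketched as:

```python
import numpy as np

def points_to_heatmap(points, heatmap_size=256, sigma=3):
    """Render (x, y) landmark points as one Gaussian heatmap channel.

    Minimal sketch of the idea behind points2curves/curves2heatmap;
    the repo interpolates curves between points, this just splats
    an isotropic Gaussian at each point.
    """
    ys, xs = np.mgrid[0:heatmap_size, 0:heatmap_size]
    heatmap = np.zeros((heatmap_size, heatmap_size), dtype=np.float32)
    for x, y in points:
        g = np.exp(-((xs - x) ** 2 + (ys - y) ** 2) / (2.0 * sigma ** 2))
        heatmap = np.maximum(heatmap, g)  # max-blend overlapping points
    return heatmap
```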

FILE: fusion/test.py
  function func (line 6) | def func(i):

FILE: fusion/warper.py
  function warping (line 8) | def warping(img, src_bound, dst_bound, size=(256, 256)):
  function bilinear_interpolate (line 21) | def bilinear_interpolate(img, coords):
  function grid_coordinates (line 46) | def grid_coordinates(points):
  function process_warp (line 60) | def process_warp(src_img, result_img, tri_affines, dst_points, delaunay):
  function triangular_affine_matrices (line 80) | def triangular_affine_matrices(vertices, src_points, dest_points):
  function warp_image (line 98) | def warp_image(src_img, src_points, dest_points, dest_shape, dtype=np.ui...
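
bilinear_interpolate is the sampling core of the piecewise-affine warp. A self-contained numpy sketch for a grayscale image, assuming coords is a (2, N) array with an x row then a y row (the repo's exact layout is an assumption):

```python
import numpy as np

def bilinear_interpolate(img, coords):
    """Sample a (H, W) image at float coordinates coords = (x_row, y_row)."""
    x, y = coords
    x0 = np.floor(x).astype(int)
    y0 = np.floor(y).astype(int)
    x1, y1 = x0 + 1, y0 + 1
    # Weights from the unclipped neighbour positions.
    wa = (x1 - x) * (y1 - y)
    wb = (x1 - x) * (y - y0)
    wc = (x - x0) * (y1 - y)
    wd = (x - x0) * (y - y0)
    # Clip indices so border samples read the edge pixel.
    h, w = img.shape[:2]
    x0, x1 = np.clip(x0, 0, w - 1), np.clip(x1, 0, w - 1)
    y0, y1 = np.clip(y0, 0, h - 1), np.clip(y1, 0, h - 1)
    return (wa * img[y0, x0] + wb * img[y1, x0]
            + wc * img[y0, x1] + wd * img[y1, x1])
```

Computing the weights before clipping keeps samples exactly on the last row/column correct, since the out-of-range neighbour then carries zero weight.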

FILE: loader/dataset_basic.py
  class DatasetBasic (line 13) | class DatasetBasic(torch.utils.data.Dataset):
    method __init__ (line 14) | def __init__(self, imgSize=256):
    method __len__ (line 21) | def __len__(self):
    method shape (line 24) | def shape(self):
    method loadtxt (line 27) | def loadtxt(self, path, head=''):
    method loadtxtList (line 42) | def loadtxtList(self, pathList, head):
    method warp (line 48) | def warp(self, img, srcPt, dstPt):
    method np2tensor (line 54) | def np2tensor(self, img, scale=1/255.0):
    method points2heatmap (line 64) | def points2heatmap(self, landmarks, mapSize, sigma, landmarkSize=255.0...
    method gammaTrans (line 134) | def gammaTrans(self, img, gamma):
    method __getitem__ (line 139) | def __getitem__(self, index):
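
np2tensor's listed signature (scale=1/255.0) suggests the usual HWC-to-CHW conversion into [0, 1]. A numpy stand-in for illustration — the real method presumably returns a torch tensor, and everything beyond the transpose-and-scale convention is an assumption:

```python
import numpy as np

def np2tensor(img, scale=1.0 / 255.0):
    """HWC uint8 image -> CHW float32 array scaled into [0, 1].

    numpy sketch of DatasetBasic.np2tensor; a torch version would
    finish with torch.from_numpy on the result.
    """
    arr = np.asarray(img, dtype=np.float32) * scale
    if arr.ndim == 2:                 # grayscale: add a channel axis
        arr = arr[:, :, None]
    return np.ascontiguousarray(arr.transpose(2, 0, 1))
```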

FILE: loader/dataset_loader_demo.py
  class DatasetLoaderDemo (line 8) | class DatasetLoaderDemo(DatasetBasic):
    method __init__ (line 9) | def __init__(self, imgSize=256, gaze=True):
    method loadBounds (line 16) | def loadBounds(self, pathList, head):
    method loadAppears (line 19) | def loadAppears(self, pathList, head):
    method setAppearRule (line 22) | def setAppearRule(self, flag='random'):
    method findSimilar (line 36) | def findSimilar(self, pt_dst):
    method adjustPose (line 48) | def adjustPose(self, img_src, pt_src, pt_dst):
    method add_nose_bridge (line 52) | def add_nose_bridge(self, boundary, heatmap):
    method __getitem__ (line 60) | def __getitem__(self, index):
    method __len__ (line 101) | def __len__(self):

FILE: loader/dataset_loader_train.py
  class DatasetLoaderTrain (line 9) | class DatasetLoaderTrain(DatasetBasic):
    method __init__ (line 10) | def __init__(self, imgSize=256, gaze=True):
    method setSampleCurve (line 21) | def setSampleCurve(self, flag='Bound'):
    method transform (line 24) | def transform(self, img, pt):
    method loaddata (line 30) | def loaddata(self, pathList, head):
    method sampleCurve (line 35) | def sampleCurve(self, img, pt, flag='Bound'):
    method add_nose_bridge (line 51) | def add_nose_bridge(self, boundary, heatmap):
    method __getitem__ (line 59) | def __getitem__(self, index):
    method __len__ (line 87) | def __len__(self):

FILE: model/base_model.py
  class BaseModel (line 13) | class BaseModel():
    method modify_commandline_options (line 18) | def modify_commandline_options(parser, is_train):
    method name (line 21) | def name(self):
    method initialize (line 24) | def initialize(self, opt):
    method set_input (line 40) | def set_input(self, input):
    method forward (line 43) | def forward(self):
    method set_logger (line 46) | def set_logger(self, opt):
    method get_logger (line 59) | def get_logger(self, logdir):
    method save_config (line 71) | def save_config(self, config):
    method setup (line 80) | def setup(self, opt, parser=None):
    method load_networks_all (line 91) | def load_networks_all(self, prefix):
    method load_networks (line 101) | def load_networks(self, model, path):
    method eval (line 117) | def eval(self):
    method test (line 125) | def test(self):
    method get_image_paths (line 130) | def get_image_paths(self):
    method optimize_parameters (line 133) | def optimize_parameters(self):
    method update_learning_rate (line 137) | def update_learning_rate(self):
    method save_networks (line 161) | def save_networks(self, epoch):
    method print_networks (line 178) | def print_networks(self, verbose):
    method set_requires_grad (line 194) | def set_requires_grad(self, nets, requires_grad=False):

FILE: model/spade_model.py
  class SpadeModel (line 16) | class SpadeModel(BaseModel):
    method __init__ (line 17) | def __init__(self, opt):
    method set_input (line 86) | def set_input(self, input):
    method forward (line 99) | def forward(self):
    method backward_D_basic (line 111) | def backward_D_basic(self, netD, real, fake):
    method backward_D (line 122) | def backward_D(self):
    method backward_G (line 129) | def backward_G(self, epoch):
    method func_require_grad (line 172) | def func_require_grad(self, model_, flag_):
    method func_zero_grad (line 176) | def func_zero_grad(self, model_):
    method optimize_parameters (line 180) | def optimize_parameters(self, epoch):

FILE: net/ResNet.py
  function conv3x3 (line 23) | def conv3x3(in_planes, out_planes, stride=1):
  class BasicBlock (line 29) | class BasicBlock(nn.Module):
    method __init__ (line 32) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 42) | def forward(self, x):
  class IRBlock (line 61) | class IRBlock(nn.Module):
    method __init__ (line 64) | def __init__(self, inplanes, planes, stride=1, downsample=None, use_se...
    method forward (line 78) | def forward(self, x):
  class Bottleneck (line 99) | class Bottleneck(nn.Module):
    method __init__ (line 102) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 116) | def forward(self, x):
  class SEBlock (line 139) | class SEBlock(nn.Module):
    method __init__ (line 140) | def __init__(self, channel, reduction=16):
    method forward (line 150) | def forward(self, x):
  class ResNetFace (line 157) | class ResNetFace(nn.Module):
    method __init__ (line 158) | def __init__(self, block, layers, input_nc, use_se=True):
    method _make_layer (line 186) | def _make_layer(self, block, planes, blocks, stride=1):
    method forward (line 203) | def forward(self, x):
  class ResNet (line 222) | class ResNet(nn.Module):
    method __init__ (line 224) | def __init__(self, block, layers):
    method _make_layer (line 250) | def _make_layer(self, block, planes, blocks, stride=1):
    method forward (line 267) | def forward(self, x):
  function resnet18 (line 285) | def resnet18(pretrained=False, **kwargs):
  function resnet34 (line 296) | def resnet34(pretrained=False, **kwargs):
  function resnet50 (line 307) | def resnet50(pretrained=False, **kwargs):
  function resnet101 (line 318) | def resnet101(pretrained=False, **kwargs):
  function resnet152 (line 329) | def resnet152(pretrained=False, **kwargs):
  function resnet_face18 (line 340) | def resnet_face18(input_nc, use_se=True, **kwargs):

FILE: net/appear_decoder_net.py
  function defineAppDec (line 9) | def defineAppDec(input_nc, size_=256, norm='batch', init_type='normal', ...
  class appearDec (line 19) | class appearDec(nn.Module):
    method __init__ (line 20) | def __init__(self, input_c, norm_layer, size_=256):
    method forward (line 44) | def forward(self, input):
  class appearDec128 (line 52) | class appearDec128(nn.Module):
    method __init__ (line 53) | def __init__(self, input_c, norm_layer, size_=256):
    method forward (line 77) | def forward(self, input):

FILE: net/appear_encoder_net.py
  function defineAppEnc (line 9) | def defineAppEnc(input_nc, size_=256, norm='batch', init_type='normal', ...
  class appearEnc (line 60) | class appearEnc(nn.Module):
    method __init__ (line 61) | def __init__(self, input_c, norm_layer, size_=256, conv_k=4):
    method sample_z (line 86) | def sample_z(self, z_mu):
    method kl_loss (line 96) | def kl_loss(self, z_mu):
    method freeze (line 106) | def freeze(self):
    method forward (line 125) | def forward(self, input):

FILE: net/base_net.py
  function get_norm_layer (line 12) | def get_norm_layer(norm_type='instance'):
  function get_scheduler (line 26) | def get_scheduler(optimizer, opt):
  function init_weights (line 47) | def init_weights(net, init_type='normal', gain=0.02):
  function init_net (line 75) | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):

FILE: net/discriminator_net.py
  function define_D (line 10) | def define_D(input_nc, ndf, netD,
  class GANLoss (line 39) | class GANLoss(nn.Module):
    method __init__ (line 40) | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_...
    method get_target_tensor (line 49) | def get_target_tensor(self, input, target_is_real):
    method __call__ (line 56) | def __call__(self, input, target_is_real):
  class NLayerDiscriminator (line 62) | class NLayerDiscriminator(nn.Module):
    method __init__ (line 63) | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNo...
    method forward (line 106) | def forward(self, input):
  class PixelDiscriminator (line 110) | class PixelDiscriminator(nn.Module):
    method __init__ (line 111) | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_si...
    method forward (line 132) | def forward(self, input):

FILE: net/face_id_mlp_net.py
  class MLP (line 8) | class MLP(nn.Module):
    method __init__ (line 9) | def __init__(self, input_nc, output_nc):
    method forward (line 13) | def forward(self, input):
  class ArcMarginProduct (line 17) | class ArcMarginProduct(nn.Module):
    method __init__ (line 27) | def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_mar...
    method forward (line 42) | def forward(self, input, label):
  class AddMarginProduct (line 64) | class AddMarginProduct(nn.Module):
    method __init__ (line 74) | def __init__(self, in_features, out_features, s=30.0, m=0.40):
    method forward (line 83) | def forward(self, input, label):
    method __repr__ (line 99) | def __repr__(self):
  class SphereProduct (line 107) | class SphereProduct(nn.Module):
    method __init__ (line 116) | def __init__(self, in_features, out_features, m=4):
    method forward (line 139) | def forward(self, input, label):
    method __repr__ (line 166) | def __repr__(self):
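
ArcMarginProduct implements the ArcFace additive angular margin: logits are cosine similarities between L2-normalised features and class weights, with cos(theta + m) substituted at the ground-truth class and the result scaled by s. A numpy sketch of that math (hypothetical helper name; the module's easy_margin branch is omitted):

```python
import numpy as np

def arc_margin_logits(embeddings, weights, label, s=30.0, m=0.50):
    """ArcFace-style logits: cos(theta + m) at the target class, scaled by s."""
    x = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    w = weights / np.linalg.norm(weights, axis=1, keepdims=True)
    cosine = x @ w.T                                   # (B, num_classes)
    sine = np.sqrt(np.clip(1.0 - cosine ** 2, 0.0, 1.0))
    phi = cosine * np.cos(m) - sine * np.sin(m)        # cos(theta + m)
    one_hot = np.eye(w.shape[0])[np.asarray(label)]    # (B, num_classes)
    return s * (one_hot * phi + (1.0 - one_hot) * cosine)
```

The margin only shrinks the target-class logit (cos(theta + m) <= cos(theta) near theta = 0), which is what forces tighter angular clustering during training.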

FILE: net/face_id_net.py
  function defineFaceID (line 11) | def defineFaceID(input_nc=3, class_num=10173, init_type='normal', init_g...
  class faceIDNet (line 17) | class faceIDNet(nn.Module):
    method __init__ (line 18) | def __init__(self, input_nc, class_num):
    method forward (line 24) | def forward(self, input):
    method face_id_loss (line 29) | def face_id_loss(self, x, target, loss_func):

FILE: net/generaotr_net.py
  function defineSPADEGenerator (line 14) | def defineSPADEGenerator(input_nc, output_nc, ngf, norm='instance', use_...
  class BasicSPADE (line 61) | class BasicSPADE(nn.Module):
    method __init__ (line 62) | def __init__(self, norm_layer, input_nc, planes):
    method forward (line 68) | def forward(self, x, bound):
  class ResBlkSPADE (line 76) | class ResBlkSPADE(nn.Module):
    method __init__ (line 77) | def __init__(self, norm_layer, input_nc, planes, conv_kernel_size=1, p...
    method forward (line 87) | def forward(self, x, bound):
  class SPADEGenerator (line 108) | class SPADEGenerator(nn.Module):
    method __init__ (line 109) | def __init__(self, input_nc, output_nc, ngf=64,
    method forward (line 169) | def forward(self, input, latent_z, decoder_result): # input: bound, ba...

FILE: net/generator_net_concat_1Layer.py
  function defineSPADEGenerator (line 14) | def defineSPADEGenerator(input_nc, output_nc, ngf, norm='instance', use_...
  class BasicSPADE (line 25) | class BasicSPADE(nn.Module):
    method __init__ (line 26) | def __init__(self, norm_layer, input_nc, planes):
    method forward (line 32) | def forward(self, x, bound):
  class ResBlkSPADE (line 40) | class ResBlkSPADE(nn.Module):
    method __init__ (line 41) | def __init__(self, norm_layer, input_nc, planes, conv_kernel_size=1, p...
    method forward (line 51) | def forward(self, x, bound):
  class SPADEGenerator (line 72) | class SPADEGenerator(nn.Module):
    method __init__ (line 73) | def __init__(self, input_nc, output_nc, ngf=64,
    method forward (line 161) | def forward(self, input, latent_z):

FILE: net/vgg_net.py
  function defineVGG (line 12) | def defineVGG(init_type='normal', init_gain=0.02, gpu_ids=[]):
  class VGGNet (line 17) | class VGGNet(nn.Module):
    method __init__ (line 18) | def __init__(self):
    method forward (line 24) | def forward(self, x):
    method perceptual_loss (line 41) | def perceptual_loss(self, x, target, loss_func):
    method style_loss (line 49) | def style_loss(self, x, target, loss_func):

FILE: opt/config.py
  class BaseOptions (line 6) | class BaseOptions():
    method __init__ (line 7) | def __init__(self):
    method initialize (line 11) | def initialize(self, misc_arg):
    method get_config (line 106) | def get_config(self):

FILE: opt/configTrain.py
  class TrainOptions (line 4) | class TrainOptions(BaseOptions):
    method initialize (line 9) | def initialize(self, misc_arg):

FILE: utils/affineFace.py
  function affineface (line 5) | def affineface(img, src_pt, dst_pt, heatmapSize=256):

FILE: utils/affine_util.py
  function th_affine2d (line 16) | def th_affine2d(x, matrix, output_img_width, output_img_height, center=T...
  function AffinePoint (line 42) | def AffinePoint(point, affine_mat):
  function exchange_landmarks (line 63) | def exchange_landmarks(input_tf, corr_list):

FILE: utils/calcAffine.py
  function calAffine (line 17) | def calAffine(src_p, dst_p):
  function affinePts (line 36) | def affinePts(affine_mat, pt):
  function affineImg (line 47) | def affineImg(img, TransMat, dsize=256):

FILE: utils/metric.py
  function gram_matrix (line 5) | def gram_matrix(feat):
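
gram_matrix in utils/metric.py is used for the style loss and, per its comment, follows the PyTorch fast-neural-style example. The same computation in numpy, assuming a (B, C, H, W) feature map with (C * H * W) normalisation:

```python
import numpy as np

def gram_matrix(feat):
    """Batched channel-wise Gram matrix of a (B, C, H, W) feature map."""
    b, c, h, w = feat.shape
    f = feat.reshape(b, c, h * w)
    # (B, C, HW) @ (B, HW, C) -> (B, C, C), normalised by C*H*W.
    return f @ f.transpose(0, 2, 1) / (c * h * w)
```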

FILE: utils/points2heatmap.py
  function curve_interp (line 6) | def curve_interp(points, heatmapSize=256, sigma=3):
  function curve_fill (line 14) | def curve_fill(points, heatmapSize=256, sigma=3, erode=False):
  function curves2heatmap (line 30) | def curves2heatmap(curves,heatmapSize=256,sigma=3,flag='line'):
  function curves2segments (line 47) | def curves2segments(curves,heatmapSize=256,sigma=3):
  function points2curves (line 72) | def points2curves(points, heatmapSize=256,  sigma=1, heatmap_num=17):
  function distance (line 201) | def distance(p1, p2):
  function curve_fitting (line 204) | def curve_fitting(points, heatmap_size, sigma):

FILE: utils/transforms.py
  function initAlignTransfer (line 16) | def initAlignTransfer(size, mirror=False, gaze=True):
  class AffineCompose (line 39) | class AffineCompose(object):
    method __init__ (line 41) | def __init__(self,
    method __call__ (line 57) | def __call__(self, *inputs):
  function dealcurve (line 122) | def dealcurve(curve):

FILE: utils/warper.py
  function warping (line 8) | def warping(img, src_bound, dst_bound, size=(256, 256)):
  function bilinear_interpolate (line 21) | def bilinear_interpolate(img, coords):
  function grid_coordinates (line 46) | def grid_coordinates(points):
  function process_warp (line 60) | def process_warp(src_img, result_img, tri_affines, dst_points, delaunay):
  function triangular_affine_matrices (line 80) | def triangular_affine_matrices(vertices, src_points, dest_points):
  function warp_image (line 98) | def warp_image(src_img, src_points, dest_points, dest_shape, dtype=np.ui...
Condensed preview — 42 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".gitignore",
    "chars": 1203,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "LICENSE",
    "chars": 1061,
    "preview": "MIT License\n\nCopyright (c) 2019 Stan\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof th"
  },
  {
    "path": "README.md",
    "chars": 3115,
    "preview": "# One-shot Face Reenactment\n\n[[Project]](https://wywu.github.io/projects/ReenactGAN/OneShotReenact.html) [[Paper]](https"
  },
  {
    "path": "data/poseGuide/lms_poseGuide.out",
    "chars": 15508,
    "preview": "pose_1.jpg\n90.44765716236793 148.92923857343578 91.57595251229804 162.8479934005091 92.9120000469461 176.46386363738992 "
  },
  {
    "path": "data/reference/lms_ref.out",
    "chars": 20331,
    "preview": "ref_1.png\n91.5305110043 157.054464213 92.3770956017 170.115229042 93.7010275718 182.787981616 95.1548975719 195.33233120"
  },
  {
    "path": "fusion/README.md",
    "chars": 200,
    "preview": "Introduction:\n\tThis project fuses textures from the reference image onto Face2Face generated results.\n\nAlgorithm inputs:\n\tgenerated image img_gen\n\tgenerated image's 106 landmarks + 40 eye points (optional)\n\n\treference image img_ref\n\treference image's 106 landmarks + 4"
  },
  {
    "path": "fusion/affineFace.py",
    "chars": 6034,
    "preview": "from fusion.points2heatmap import *\nfrom fusion.calcAffine import *\nfrom fusion.warper import warping as warp\nimport mat"
  },
  {
    "path": "fusion/calcAffine.py",
    "chars": 2160,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"\nCreated on Fri Dec 29 13:43:03 2017\n\"\"\"\nimport numpy as np\nimport cv2\n\n\n#affine points via l"
  },
  {
    "path": "fusion/parts2lms.py",
    "chars": 1391,
    "preview": "import numpy as np\n\ndef parts2lms(parts):\n\tbound,browL,browR,eyeL,eyeR,nose,lipU,lipD,gazeL,gazeR = parts\n\tres = list()\n"
  },
  {
    "path": "fusion/points2heatmap.py",
    "chars": 8373,
    "preview": "import numpy as np\nimport cv2\nimport os\nimport math\n\ndef curve_interp(points, heatmapSize=256, sigma=3):\n\tsigma = max(1,"
  },
  {
    "path": "fusion/test.py",
    "chars": 604,
    "preview": "import multiprocessing\nimport time\nfrom tqdm import *\n\n# list = [1, 2, 3, 4]\ndef func(i):\n    msg = \"hello %d\" % (list[i"
  },
  {
    "path": "fusion/warper.py",
    "chars": 3845,
    "preview": "import numpy as np\nimport scipy.spatial as spatial\nfrom builtins import range\nimport cv2\nfrom matplotlib import pyplot a"
  },
  {
    "path": "loader/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "loader/dataset_basic.py",
    "chars": 4223,
    "preview": "# coding:utf-8\nimport sys\nfrom utils.points2heatmap import curves2segments,points2curves\nfrom utils import warper\nimport"
  },
  {
    "path": "loader/dataset_loader_demo.py",
    "chars": 3577,
    "preview": "from loader.dataset_basic import *\nimport random\nimport numpy as np\nimport copy\nimport torch as th\nfrom utils.affineFace"
  },
  {
    "path": "loader/dataset_loader_train.py",
    "chars": 3529,
    "preview": "from loader.dataset_basic import *\nfrom utils.transforms import initAlignTransfer\n#from utils.transforms import shakeCur"
  },
  {
    "path": "model/base_model.py",
    "chars": 7186,
    "preview": "import os\nimport torch\nfrom collections import OrderedDict\nimport net.base_net as base_net\nimport shutil\nfrom tensorboar"
  },
  {
    "path": "model/spade_model.py",
    "chars": 9631,
    "preview": "import sys\nfrom model.base_model import BaseModel\nimport net.vgg_net as vgg_net\nimport net.generaotr_net as generator_ne"
  },
  {
    "path": "net/ResNet.py",
    "chars": 11095,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"\nCreated on 18-5-21 5:26 PM\n@author: ronghuaiyang\n\"\"\"\nimport torch\nimport torch.nn as nn\nimpor"
  },
  {
    "path": "net/appear_decoder_net.py",
    "chars": 3214,
    "preview": "import torch as th\nfrom torch import nn\nimport net.base_net as base_net\n################################################"
  },
  {
    "path": "net/appear_encoder_net.py",
    "chars": 5392,
    "preview": "import torch as th\nfrom torch import nn\nimport net.base_net as base_net\n################################################"
  },
  {
    "path": "net/base_net.py",
    "chars": 3104,
    "preview": "import torch\nimport torch.nn as nn\nfrom torch.nn import init\nimport functools\nfrom torch.optim import lr_scheduler\n\n####"
  },
  {
    "path": "net/discriminator_net.py",
    "chars": 4880,
    "preview": "import torch\nimport torch.nn as nn\nimport functools\nimport net.base_net as base_net\n####################################"
  },
  {
    "path": "net/face_id_mlp_net.py",
    "chars": 6394,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import Parameter\nimport math\n\n\nclass ML"
  },
  {
    "path": "net/face_id_net.py",
    "chars": 1110,
    "preview": "import torch as th\nfrom torch import nn\nfrom net.ResNet import resnet_face18 as resnet18\nfrom net.face_id_mlp_net import"
  },
  {
    "path": "net/generaotr_net.py",
    "chars": 12926,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport functools\nimport net.base_net as base_net\n\n\n##"
  },
  {
    "path": "net/generator_net_concat_1Layer.py",
    "chars": 11131,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport functools\nimport net.base_net as base_net\n\n\n##"
  },
  {
    "path": "net/vgg_net.py",
    "chars": 1848,
    "preview": "import torch as th\nfrom torch import nn\nfrom torchvision.models import vgg16\n\nimport net.base_net as base_net\nfrom utils"
  },
  {
    "path": "opt/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "opt/config.py",
    "chars": 6649,
    "preview": "# -*- coding: utf-8 -*-\nimport argparse\nimport torch\n\n\nclass BaseOptions():\n    def __init__(self):\n        \"\"\"Reset the"
  },
  {
    "path": "opt/configTrain.py",
    "chars": 702,
    "preview": "# -*- coding: utf-8 -*-\nfrom opt.config import BaseOptions\n\nclass TrainOptions(BaseOptions):\n    \"\"\"This class includes "
  },
  {
    "path": "requirements.txt",
    "chars": 132,
    "preview": "tqdm==4.32.2\nopencv_python==3.4.1.15\ntorch==0.4.1\ntorchvision==0.2.1\nscipy==1.0.1\nmatplotlib==2.2.2\nnumpy==1.15.0\ntensor"
  },
  {
    "path": "test.py",
    "chars": 3897,
    "preview": "import time\nimport scipy.misc as m\nimport numpy as np\nimport cv2\nimport torch\nimport torchvision.utils as vutils\nimport "
  },
  {
    "path": "utils/__init__.py",
    "chars": 14,
    "preview": "#coding:utf-8\n"
  },
  {
    "path": "utils/affineFace.py",
    "chars": 555,
    "preview": "from utils.points2heatmap import *\nfrom utils.calcAffine import *\n\n\ndef affineface(img, src_pt, dst_pt, heatmapSize=256)"
  },
  {
    "path": "utils/affine_util.py",
    "chars": 1979,
    "preview": "from __future__ import print_function\nimport torch\nimport numpy as np\nimport inspect\nimport re\nimport numpy as np\nimport"
  },
  {
    "path": "utils/calcAffine.py",
    "chars": 2383,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"\nCreated on Fri Dec 29 13:43:03 2017\n\"\"\"\nimport numpy as np\nimport os, sys, shutil\nimport cv2"
  },
  {
    "path": "utils/lms.test",
    "chars": 101679,
    "preview": "9/19256.png\n59.2707332221 140.50139565 60.7953501505 155.840424728 63.0589911357 170.835862842 65.4653911442 185.6315809"
  },
  {
    "path": "utils/metric.py",
    "chars": 364,
    "preview": "import torch as th\r\nimport torch.nn as nn\r\n\r\n\r\ndef gram_matrix(feat):\r\n    # https://github.com/pytorch/examples/blob/ma"
  },
  {
    "path": "utils/points2heatmap.py",
    "chars": 8148,
    "preview": "import numpy as np\nimport cv2\nimport os\nimport math\n\ndef curve_interp(points, heatmapSize=256, sigma=3):\n\tsigma = max(1,"
  },
  {
    "path": "utils/transforms.py",
    "chars": 6040,
    "preview": "\"\"\"\nAffine transforms implemented on torch tensors, and\nrequiring only one interpolation\n\"\"\"\n\nimport math\nimport random\n"
  },
  {
    "path": "utils/warper.py",
    "chars": 3845,
    "preview": "import numpy as np\nimport scipy.spatial as spatial\nfrom builtins import range\nimport cv2\nfrom matplotlib import pyplot a"
  }
]
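Each entry in the manifest array above has the same three fields: `path` (repo-relative file path), `chars` (file size in characters), and `preview` (a truncated snippet of the file's contents). A minimal sketch of how such a manifest could be consumed — the inline sample entries and the `manifest_json` name are illustrative, not part of the GitExtract output format beyond the three fields shown:

```python
import json

# Hypothetical: a small excerpt of the manifest array, inlined as a string.
# In practice the full array would be read from the downloaded .txt/.json.
manifest_json = """
[
  {"path": "net/base_net.py", "chars": 3104, "preview": "import torch"},
  {"path": "utils/metric.py", "chars": 364, "preview": "import torch as th"}
]
"""

entries = json.loads(manifest_json)

# Aggregate the per-file character counts and find the largest file.
total_chars = sum(e["chars"] for e in entries)
largest = max(entries, key=lambda e: e["chars"])

print(f"{len(entries)} files, {total_chars} chars; largest: {largest['path']}")
# → 2 files, 3468 chars; largest: net/base_net.py
```

Summing `chars` across all 42 entries of the real manifest should approximate the 282.7 KB total reported in the extraction header.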

About this extraction

This page contains the full source code of the bj80heyue/One_Shot_Face_Reenactment GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 42 files (282.7 KB), approximately 104.8k tokens, and a symbol index with 223 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
