Full Code of RenYurui/PIRender for AI

main 9e59f194f1a0 cached
56 files
241.6 KB
60.7k tokens
389 symbols
1 requests
Download .txt
Showing preview only (256K chars total). Download the full file or copy to clipboard to get everything.
Repository: RenYurui/PIRender
Branch: main
Commit: 9e59f194f1a0
Files: 56
Total size: 241.6 KB

Directory structure:
gitextract_uz8rj1lb/

├── .gitmodules
├── DatasetHelper.md
├── LICENSE.md
├── README.md
├── config/
│   ├── face.yaml
│   └── face_demo.yaml
├── config.py
├── data/
│   ├── __init__.py
│   ├── image_dataset.py
│   ├── vox_dataset.py
│   └── vox_video_dataset.py
├── demo_images/
│   └── expression.mat
├── generators/
│   ├── base_function.py
│   └── face_model.py
├── inference.py
├── intuitive_control.py
├── loss/
│   └── perceptual.py
├── requirements.txt
├── scripts/
│   ├── coeff_detector.py
│   ├── download_demo_dataset.sh
│   ├── download_weights.sh
│   ├── extract_kp_videos.py
│   ├── face_recon_images.py
│   ├── face_recon_videos.py
│   ├── inference_options.py
│   └── prepare_vox_lmdb.py
├── third_part/
│   └── PerceptualSimilarity/
│       ├── models/
│       │   ├── __init__.py
│       │   ├── base_model.py
│       │   ├── dist_model.py
│       │   ├── models.py
│       │   ├── networks_basic.py
│       │   └── pretrained_networks.py
│       ├── util/
│       │   ├── __init__.py
│       │   ├── html.py
│       │   ├── util.py
│       │   └── visualizer.py
│       └── weights/
│           ├── v0.0/
│           │   ├── alex.pth
│           │   ├── squeeze.pth
│           │   └── vgg.pth
│           └── v0.1/
│               ├── alex.pth
│               ├── squeeze.pth
│               └── vgg.pth
├── train.py
├── trainers/
│   ├── __init__.py
│   ├── base.py
│   └── face_trainer.py
└── util/
    ├── cudnn.py
    ├── distributed.py
    ├── flow_util.py
    ├── init_weight.py
    ├── io.py
    ├── logging.py
    ├── lpips.py
    ├── meters.py
    ├── misc.py
    └── trainer.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitmodules
================================================
[submodule "Deep3DFaceRecon_pytorch"]
	path = Deep3DFaceRecon_pytorch
	url = https://github.com/sicxu/Deep3DFaceRecon_pytorch


================================================
FILE: DatasetHelper.md
================================================
### Extract 3DMM Coefficients for Videos

We provide scripts for extracting 3dmm coefficients for videos by using [DeepFaceRecon_pytorch](https://github.com/sicxu/Deep3DFaceRecon_pytorch/tree/73d491102af6731bded9ae6b3cc7466c3b2e9e48).

1. Follow the instructions of their repo to build the environment of DeepFaceRecon.

2. Copy the provided scrips to the folder `Deep3DFaceRecon_pytorch`.

   ```bash
   cp scripts/face_recon_videos.py ./Deep3DFaceRecon_pytorch
   cp scripts/extract_kp_videos.py ./Deep3DFaceRecon_pytorch
   cp scripts/coeff_detector.py ./Deep3DFaceRecon_pytorch
   cp scripts/inference_options.py ./Deep3DFaceRecon_pytorch/options

   cd Deep3DFaceRecon_pytorch
   ```

3. Extract facial landmarks from videos.

   ```bash
   python extract_kp_videos.py \
   --input_dir path_to_viodes \
   --output_dir path_to_keypoint \
   --device_ids 0,1,2,3 \
   --workers 12
   ```

4. Extract coefficients for videos

   ```bash
   python face_recon_videos.py \
   --input_dir path_to_videos \
   --keypoint_dir path_to_keypoint \
   --output_dir output_dir \
   --inference_batch_size 100 \
   --name=model_name \
   --epoch=20 \
   --model facerecon
   ```

   





================================================
FILE: LICENSE.md
================================================
## creative commons

# Attribution-NonCommercial 4.0 International

Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.

### Using Creative Commons Public Licenses

Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.

* __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).

* __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).

## Creative Commons Attribution-NonCommercial 4.0 International Public License

By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.

### Section 1 – Definitions.

a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.

b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.

c. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.

d. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.

e. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.

f. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.

g. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.

h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.

i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.

j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.

k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.

l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.

### Section 2 – Scope.

a. ___License grant.___

   1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:

       A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and

       B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.

   2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.

   3. __Term.__ The term of this Public License is specified in Section 6(a).

   4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.

   5. __Downstream recipients.__

        A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.

        B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.

   6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).

b. ___Other rights.___

   1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.

   2. Patent and trademark rights are not licensed under this Public License.

   3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.

### Section 3 – License Conditions.

Your exercise of the Licensed Rights is expressly made subject to the following conditions.

a. ___Attribution.___

   1. If You Share the Licensed Material (including in modified form), You must:

       A. retain the following if it is supplied by the Licensor with the Licensed Material:

         i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);

         ii. a copyright notice;

         iii. a notice that refers to this Public License;

         iv. a notice that refers to the disclaimer of warranties;

         v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;

       B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and

       C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.

   2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.

   3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.

   4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License.

### Section 4 – Sui Generis Database Rights.

Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:

a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;

b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and

c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.

For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.

### Section 5 – Disclaimer of Warranties and Limitation of Liability.

a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__

b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__

c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.

### Section 6 – Term and Termination.

a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.

b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:

   1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or

   2. upon express reinstatement by the Licensor.

   For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.

c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.

d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.

### Section 7 – Other Terms and Conditions.

a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.

b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.

### Section 8 – Interpretation.

a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.

b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.

c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.

d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.

> Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
>
> Creative Commons may be contacted at creativecommons.org


================================================
FILE: README.md
================================================
<p align='center'>
  <b>
    <a href="https://renyurui.github.io/PIRender_web/"> Website</a>
    | 
    <a href="https://arxiv.org/abs/2109.08379">ArXiv</a>
    | 
    <a href="#Get-Start">Get Start</a>
    | 
    <a href="https://youtu.be/gDhcRcPI1JU">Video</a>
  </b>
</p> 


# PIRenderer

The source code of the ICCV2021 paper "[PIRenderer: Controllable Portrait Image Generation via Semantic Neural Rendering](https://arxiv.org/abs/2109.08379)" (ICCV2021)

The proposed **PIRenderer** can synthesis portrait images by intuitively controlling the face motions with fully disentangled 3DMM parameters. This model can be applied to tasks such as:

* **Intuitive Portrait Image Editing**

  <p align='center'>  
    <img src='https://renyurui.github.io/PIRender_web/intuitive_fast.gif' width='700'/>
  </p>
  <p align='center'>  
    <b>Intuitive Portrait Image Control</b> 
  </p>
  <p align='center'>  
    <img src='https://renyurui.github.io/PIRender_web/intuitive_editing_fast.gif' width='700'/>
  </p>
  <p align='center'>  
    <b>Pose & Expression Alignment</b> 
  </p>
  
  
* **Motion Imitation**
  <p align='center'> 
    <img src='https://user-images.githubusercontent.com/30292465/133969233-d7ce0c02-ce6a-4cef-bc5e-d8f55b709f81.gif' width='700'/>
  </p>
  <p align='center'>  
    <b>Same & Corss-identity Reenactment</b> 
  </p>
  
* **Audio-Driven Facial Reenactment**

  <p align='center'>  
    <img src='https://renyurui.github.io/PIRender_web/audio.gif' width='700'/>
  </p>
  <p align='center'>  
    <b>Audio-Driven Reenactment</b> 
  </p>

## News

* 2021.9.20 Code for PyTorch is available!



## Colab Demo

Coming soon


## Get Start

### 1). Installation

#### Requirements

* Python 3
* PyTorch 1.7.1
* CUDA 10.2

#### Conda Installation

```bash
# 1. Create a conda virtual environment.
conda create -n PIRenderer python=3.6
conda activate PIRenderer
conda install -c pytorch pytorch=1.7.1 torchvision cudatoolkit=10.2

# 2. Install other dependencies
pip install -r requirements.txt
```

### 2). Dataset

We train our model using the [VoxCeleb](https://arxiv.org/abs/1706.08612). You can download the demo dataset for inference or prepare the dataset for training and testing.

#### Download the demo dataset

The demo dataset contains all 514 test videos. You can download the dataset with the following code:

```bash
./scripts/download_demo_dataset.sh
```

Or you can choose to download the resources with these links: 

​	[Google Driven](https://drive.google.com/drive/folders/16Yn2r46b4cV6ZozOH6a8SdFz_iG7BQk1?usp=sharing) & [BaiDu Driven](https://pan.baidu.com/s/1e615bBHvM4Wz-2snk-86Xw) with extraction passwords ”p9ab“

Then unzip and save the files to `./dataset`

#### Prepare the dataset

1. The dataset is preprocessed follow the method used in [First-Order](https://github.com/AliaksandrSiarohin/video-preprocessing). You can follow the instructions in their repo to download and crop videos for training and testing.

2. After obtaining the VoxCeleb videos, we extract 3DMM parameters using [Deep3DFaceReconstruction](https://github.com/microsoft/Deep3DFaceReconstruction). 

   The folder are with format as:

   ```
   ${DATASET_ROOT_FOLDER}
   └───path_to_videos
       └───train
           └───xxx.mp4
           └───xxx.mp4
           ...
       └───test
           └───xxx.mp4
           └───xxx.mp4
           ...
   └───path_to_3dmm_coeff
       └───train
           └───xxx.mat
           └───xxx.mat
           ...
       └───test
           └───xxx.mat
           └───xxx.mat
           ...
   ```
   
   **News**: We provide Scripts for extracting 3dmm coeffs from videos. Please check the [DatasetHelper](./DatasetHelper.md) for more details.
   
3. We save the video and 3DMM parameters in a lmdb file. Please run the following code to do this 

   ```bash
   python scripts/prepare_vox_lmdb.py \
   --path path_to_videos \
   --coeff_3dmm_path path_to_3dmm_coeff \
   --out path_to_output_dir
   ```

### 3). Training and Inference

#### Inference

The trained weights can be downloaded by running the following code:

```bash
./scripts/download_weights.sh
```

Or you can choose to download the resources with these links: 

[Google Driven](https://drive.google.com/file/d/1-0xOf6g58OmtKtEWJlU3VlnfRqPN9Uq7/view?usp=sharing) & [Baidu Driven](https://pan.baidu.com/s/18B3xfKMXnm4tOqlFSB8ntg) with extraction passwards "4sy1".

Then unzip and save the files to `./result/face`.

**Reenactment**

Run the demo for face reenactment:

```bash
# same identity
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 inference.py \
--config ./config/face_demo.yaml \
--name face \
--no_resume \
--output_dir ./vox_result/face_reenactment

# cross identity
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 inference.py \
--config ./config/face_demo.yaml \
--name face \
--no_resume \
--output_dir ./vox_result/face_reenactment_cross \
--cross_id
```

The output results are saved at `./vox_result/face_reenactment` and `./vox_result/face_reenactment_cross`

**Intuitive Control**

Our model can generate results by providing intuitive controlling coefficients. 
We provide the following code for this task. Please note that you need to build the environment of [DeepFaceRecon](https://github.com/sicxu/Deep3DFaceRecon_pytorch/tree/73d491102af6731bded9ae6b3cc7466c3b2e9e48) first.

```bash
# 1. Copy the provided scrips to the folder `Deep3DFaceRecon_pytorch`.
cp scripts/face_recon_videos.py ./Deep3DFaceRecon_pytorch
cp scripts/extract_kp_videos.py ./Deep3DFaceRecon_pytorch
cp scripts/coeff_detector.py ./Deep3DFaceRecon_pytorch
cp scripts/inference_options.py ./Deep3DFaceRecon_pytorch/options

cd Deep3DFaceRecon_pytorch

# 2. Extracte the 3dmm coefficients of the demo images.
python coeff_detector.py \
--input_dir ../demo_images \
--keypoint_dir ../demo_images \
--output_dir ../demo_images \
--name=model_name \
--epoch=20 \
--model facerecon   

# 3. control the source image with our model
cd ..
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 intuitive_control.py \
--config ./config/face_demo.yaml \
--name face \
--no_resume \
--output_dir ./vox_result/face_intuitive \
--input_name ./demo_images
```


#### Train

Our model can be trained with the following code

```bash
python -m torch.distributed.launch --nproc_per_node=4 --master_port 12345 train.py \
--config ./config/face.yaml \
--name face
```


## Citation

If you find this code is helpful, please cite our paper

```tex
@misc{ren2021pirenderer,
      title={PIRenderer: Controllable Portrait Image Generation via Semantic Neural Rendering}, 
      author={Yurui Ren and Ge Li and Yuanqi Chen and Thomas H. Li and Shan Liu},
      year={2021},
      eprint={2109.08379},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```

## Acknowledgement 

We build our project base on [imaginaire](https://github.com/NVlabs/imaginaire). Some dataset preprocessing methods are derived from [video-preprocessing](https://github.com/AliaksandrSiarohin/video-preprocessing).



================================================
FILE: config/face.yaml
================================================
# How often do you want to log the training stats.
# network_list: 
#     gen: gen_optimizer
#     dis: dis_optimizer

distributed: True
image_to_tensorboard: True
snapshot_save_iter: 40000
snapshot_save_epoch: 20
snapshot_save_start_iter: 20000
snapshot_save_start_epoch: 10
image_save_iter: 1000
max_epoch: 200
logging_iter: 100
results_dir: ./eval_results

gen_optimizer:
    type: adam
    lr: 0.0001
    adam_beta1: 0.5
    adam_beta2: 0.999
    lr_policy:
        iteration_mode: True
        type: step
        step_size: 300000
        gamma: 0.2

trainer:
    type: trainers.face_trainer::FaceTrainer
    pretrain_warp_iteration: 200000
    loss_weight:
      weight_perceptual_warp: 2.5
      weight_perceptual_final: 4
    vgg_param_warp:
      network: vgg19
      layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1']
      use_style_loss: False
      num_scales: 4
    vgg_param_final:
      network: vgg19
      layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1']
      use_style_loss: True
      num_scales: 4      
      style_to_perceptual: 250
    init:
      type: 'normal'
      gain: 0.02
gen:
    type: generators.face_model::FaceGenerator
    param:
      mapping_net:
        coeff_nc: 73
        descriptor_nc: 256
        layer: 3
      warpping_net:
        encoder_layer: 5
        decoder_layer: 3
        base_nc: 32
      editing_net:
        layer: 3
        num_res_blocks: 2
        base_nc: 64
      common:
        image_nc: 3
        descriptor_nc: 256
        max_nc: 256
        use_spect: False
                

# Data options.
data:
    type: data.vox_dataset::VoxDataset
    path: ./dataset/vox_lmdb
    resolution: 256
    semantic_radius: 13
    train:
      batch_size: 5
      distributed: True
    val:
      batch_size: 8
      distributed: True




================================================
FILE: config/face_demo.yaml
================================================
# How often do you want to log the training stats.
# network_list: 
#     gen: gen_optimizer
#     dis: dis_optimizer

distributed: True
image_to_tensorboard: True
snapshot_save_iter: 40000
snapshot_save_epoch: 20
snapshot_save_start_iter: 20000
snapshot_save_start_epoch: 10
image_save_iter: 1000
max_epoch: 200
logging_iter: 100
results_dir: ./eval_results

gen_optimizer:
    type: adam
    lr: 0.0001
    adam_beta1: 0.5
    adam_beta2: 0.999
    lr_policy:
        iteration_mode: True
        type: step
        step_size: 300000
        gamma: 0.2

trainer:
    type: trainers.face_trainer::FaceTrainer
    pretrain_warp_iteration: 200000
    loss_weight:
      weight_perceptual_warp: 2.5
      weight_perceptual_final: 4
    vgg_param_warp:
      network: vgg19
      layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1']
      use_style_loss: False
      num_scales: 4
    vgg_param_final:
      network: vgg19
      layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1']
      use_style_loss: True
      num_scales: 4      
      style_to_perceptual: 250
    init:
      type: 'normal'
      gain: 0.02
gen:
    type: generators.face_model::FaceGenerator
    param:
      mapping_net:
        coeff_nc: 73
        descriptor_nc: 256
        layer: 3
      warpping_net:
        encoder_layer: 5
        decoder_layer: 3
        base_nc: 32
      editing_net:
        layer: 3
        num_res_blocks: 2
        base_nc: 64
      common:
        image_nc: 3
        descriptor_nc: 256
        max_nc: 256
        use_spect: False
                

# Data options.
data:
    type: data.vox_dataset::VoxDataset
    path: ./dataset/vox_lmdb_demo
    resolution: 256
    semantic_radius: 13
    train:
      batch_size: 5
      distributed: True
    val:
      batch_size: 8
      distributed: True




================================================
FILE: config.py
================================================
import collections
import functools
import os
import re

import yaml
from util.distributed import master_only_print as print


class AttrDict(dict):
    """Dict as attribute trick."""

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self
        for key, value in self.__dict__.items():
            if isinstance(value, dict):
                self.__dict__[key] = AttrDict(value)
            elif isinstance(value, (list, tuple)):
                if isinstance(value[0], dict):
                    self.__dict__[key] = [AttrDict(item) for item in value]
                else:
                    self.__dict__[key] = value

    def yaml(self):
        """Convert object to yaml dict and return."""
        yaml_dict = {}
        for key, value in self.__dict__.items():
            if isinstance(value, AttrDict):
                yaml_dict[key] = value.yaml()
            elif isinstance(value, list):
                if isinstance(value[0], AttrDict):
                    new_l = []
                    for item in value:
                        new_l.append(item.yaml())
                    yaml_dict[key] = new_l
                else:
                    yaml_dict[key] = value
            else:
                yaml_dict[key] = value
        return yaml_dict

    def __repr__(self):
        """Print all variables."""
        ret_str = []
        for key, value in self.__dict__.items():
            if isinstance(value, AttrDict):
                ret_str.append('{}:'.format(key))
                child_ret_str = value.__repr__().split('\n')
                for item in child_ret_str:
                    ret_str.append('    ' + item)
            elif isinstance(value, list):
                if isinstance(value[0], AttrDict):
                    ret_str.append('{}:'.format(key))
                    for item in value:
                        # Treat as AttrDict above.
                        child_ret_str = item.__repr__().split('\n')
                        for item in child_ret_str:
                            ret_str.append('    ' + item)
                else:
                    ret_str.append('{}: {}'.format(key, value))
            else:
                ret_str.append('{}: {}'.format(key, value))
        return '\n'.join(ret_str)


class Config(AttrDict):
    r"""Configuration class. This should include every human specifiable
    hyperparameter values for your training."""

    def __init__(self, filename=None, args=None, verbose=False, is_train=True):
        super(Config, self).__init__()
        # Set default parameters.
        # Logging.

        large_number = 1000000000
        self.snapshot_save_iter = large_number
        self.snapshot_save_epoch = large_number
        self.snapshot_save_start_iter = 0
        self.snapshot_save_start_epoch = 0
        self.image_save_iter = large_number
        self.eval_epoch = large_number
        self.start_eval_epoch = large_number
        self.eval_epoch = large_number
        self.max_epoch = large_number
        self.max_iter = large_number
        self.logging_iter = 100
        self.image_to_tensorboard=False
        self.which_iter = args.which_iter
        self.resume = not args.no_resume


        self.checkpoints_dir = args.checkpoints_dir
        self.name = args.name
        self.phase = 'train' if is_train else 'test'

        # Networks.
        self.gen = AttrDict(type='generators.dummy')
        self.dis = AttrDict(type='discriminators.dummy')

        # Optimizers.
        self.gen_optimizer = AttrDict(type='adam',
                                    lr=0.0001,
                                    adam_beta1=0.0,
                                    adam_beta2=0.999,
                                    eps=1e-8,
                                    lr_policy=AttrDict(iteration_mode=False,
                                                    type='step',
                                                    step_size=large_number,
                                                    gamma=1))
        self.dis_optimizer = AttrDict(type='adam',
                                lr=0.0001,
                                adam_beta1=0.0,
                                adam_beta2=0.999,
                                eps=1e-8,
                                lr_policy=AttrDict(iteration_mode=False,
                                                   type='step',
                                                   step_size=large_number,
                                                   gamma=1))
        # Data.
        self.data = AttrDict(name='dummy',
                             type='datasets.images',
                             num_workers=0)
        self.test_data = AttrDict(name='dummy',
                                  type='datasets.images',
                                  num_workers=0,
                                  test=AttrDict(is_lmdb=False,
                                                roots='',
                                                batch_size=1))
        self.trainer = AttrDict(
            model_average=False,
            model_average_beta=0.9999,
            model_average_start_iteration=1000,
            model_average_batch_norm_estimation_iteration=30,
            model_average_remove_sn=True,
            image_to_tensorboard=False,
            hparam_to_tensorboard=False,
            distributed_data_parallel='pytorch',
            delay_allreduce=True,
            gan_relativistic=False,
            gen_step=1,
            dis_step=1)

        # # Cudnn.
        self.cudnn = AttrDict(deterministic=False,
                              benchmark=True)

        # Others.
        self.pretrained_weight = ''
        self.inference_args = AttrDict()


        # Update with given configurations.
        assert os.path.exists(filename), 'File {} not exist.'.format(filename)
        loader = yaml.SafeLoader
        loader.add_implicit_resolver(
            u'tag:yaml.org,2002:float',
            re.compile(u'''^(?:
             [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
            |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
            |\\.[0-9_]+(?:[eE][-+][0-9]+)?
            |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
            |[-+]?\\.(?:inf|Inf|INF)
            |\\.(?:nan|NaN|NAN))$''', re.X),
            list(u'-+0123456789.'))
        try:
            with open(filename, 'r') as f:
                cfg_dict = yaml.load(f, Loader=loader)
        except EnvironmentError:
            print('Please check the file with name of "%s"', filename)
        recursive_update(self, cfg_dict)

        # Put common opts in both gen and dis.
        if 'common' in cfg_dict:
            self.common = AttrDict(**cfg_dict['common'])
            self.gen.common = self.common
            self.dis.common = self.common


        if verbose:
            print(' config '.center(80, '-'))
            print(self.__repr__())
            print(''.center(80, '-'))


def rsetattr(obj, attr, val):
    """Recursively find object and set value"""
    pre, _, post = attr.rpartition('.')
    return setattr(rgetattr(obj, pre) if pre else obj, post, val)


def rgetattr(obj, attr, *args):
    """Recursively find object and return value"""

    def _getattr(obj, attr):
        r"""Get attribute."""
        return getattr(obj, attr, *args)

    return functools.reduce(_getattr, [obj] + attr.split('.'))


def recursive_update(d, u):
    """Recursively update AttrDict d with AttrDict u"""
    for key, value in u.items():
        if isinstance(value, collections.abc.Mapping):
            d.__dict__[key] = recursive_update(d.get(key, AttrDict({})), value)
        elif isinstance(value, (list, tuple)):
            if isinstance(value[0], dict):
                d.__dict__[key] = [AttrDict(item) for item in value]
            else:
                d.__dict__[key] = value
        else:
            d.__dict__[key] = value
    return d


================================================
FILE: data/__init__.py
================================================
import importlib

import torch.utils.data
from util.distributed import master_only_print as print

def find_dataset_using_name(dataset_name):
    dataset_filename = dataset_name
    module, target = dataset_name.split('::')
    datasetlib = importlib.import_module(module)
    dataset = None
    for name, cls in datasetlib.__dict__.items():
        if name == target:
            dataset = cls
            
    if dataset is None:
        raise ValueError("In %s.py, there should be a class "
                         "with class name that matches %s in lowercase." %
                         (dataset_filename, target))

    return dataset


def get_option_setter(dataset_name):    
    dataset_class = find_dataset_using_name(dataset_name)
    return dataset_class.modify_commandline_options


def create_dataloader(opt, is_inference):
    dataset = find_dataset_using_name(opt.type)
    instance = dataset(opt, is_inference)
    phase = 'val' if is_inference else 'training'
    batch_size = opt.val.batch_size if is_inference else opt.train.batch_size
    print("%s dataset [%s] of size %d was created" %
          (phase, opt.type, len(instance)))
    dataloader = torch.utils.data.DataLoader(
        instance,
        batch_size=batch_size,
        sampler=data_sampler(instance, shuffle=not is_inference, distributed=opt.train.distributed),
        drop_last=not is_inference,
        num_workers=getattr(opt, 'num_workers', 0),
    )          

    return dataloader


def data_sampler(dataset, shuffle, distributed):
    if distributed:
        return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
    if shuffle:
        return torch.utils.data.RandomSampler(dataset)
    else:
        return torch.utils.data.SequentialSampler(dataset)


def get_dataloader(opt, is_inference=False):
    dataset = create_dataloader(opt, is_inference=is_inference)
    return dataset


def get_train_val_dataloader(opt):
    val_dataset = create_dataloader(opt, is_inference=True)
    train_dataset = create_dataloader(opt, is_inference=False)
    return val_dataset, train_dataset


================================================
FILE: data/image_dataset.py
================================================
import os
import glob
import time
import numpy as np
from PIL import Image

import torch
import torchvision.transforms.functional as F



class ImageDataset():
    def __init__(self, opt, input_name):
        self.opt = opt
        self.IMAGEEXT = ['png', 'jpg']
        self.input_image_list, self.coeff_list = self.obtain_inputs(input_name)
        self.index = -1
        # load image dataset opt
        self.resolution = opt.resolution
        self.semantic_radius = opt.semantic_radius

    def next_image(self):
        self.index += 1
        image_name = self.input_image_list[self.index]
        coeff_name = self.coeff_list[self.index]
        img = Image.open(image_name)
        input_image = self.trans_image(img)

        coeff_3dmm = np.loadtxt(coeff_name).astype(np.float32)
        coeff_3dmm = self.transform_semantic(coeff_3dmm)
        
        return {
            'source_image': input_image[None],
            'target_semantics': coeff_3dmm[None],
            'name': os.path.splitext(os.path.basename(image_name))[0]
        }

    def obtain_inputs(self, root):
        filenames = list()

        IMAGE_EXTENSIONS_LOWERCASE = {'jpg', 'png', 'jpeg', 'webp'}
        IMAGE_EXTENSIONS = IMAGE_EXTENSIONS_LOWERCASE.union({f.upper() for f in IMAGE_EXTENSIONS_LOWERCASE})
        extensions = IMAGE_EXTENSIONS

        for ext in extensions:
            filenames += glob.glob(f'{root}/*.{ext}', recursive=True)
        filenames = sorted(filenames)
        coeffnames = sorted(glob.glob(f'{root}/*_3dmm_coeff.txt'))     

        return filenames, coeffnames

    def transform_semantic(self, semantic):
        semantic = semantic[None].repeat(self.semantic_radius*2+1, 0)
        ex_coeff = semantic[:,80:144] #expression
        angles = semantic[:,224:227] #euler angles for pose
        translation = semantic[:,254:257] #translation
        crop = semantic[:,259:262] #crop param

        coeff_3dmm = np.concatenate([ex_coeff, angles, translation, crop], 1)
        return torch.Tensor(coeff_3dmm).permute(1,0)   

    def trans_image(self, image):
        image = F.resize(
            image, size=self.resolution, interpolation=Image.BICUBIC)
        image = F.to_tensor(image)
        image = F.normalize(image, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        return image
        
    def __len__(self):
        return len(self.input_image_list)

        


================================================
FILE: data/vox_dataset.py
================================================
import os
import lmdb
import random
import collections
import numpy as np
from PIL import Image
from io import BytesIO

import torch
from torch.utils.data import Dataset
from torchvision import transforms

def format_for_lmdb(*args):
    key_parts = []
    for arg in args:
        if isinstance(arg, int):
            arg = str(arg).zfill(7)
        key_parts.append(arg)
    return '-'.join(key_parts).encode('utf-8')

class VoxDataset(Dataset):
    def __init__(self, opt, is_inference):
        path = opt.path
        self.env = lmdb.open(
            os.path.join(path, str(opt.resolution)),
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)
        list_file = "test_list.txt" if is_inference else "train_list.txt"
        list_file = os.path.join(path, list_file)
        with open(list_file, 'r') as f:
            lines = f.readlines()
            videos = [line.replace('\n', '') for line in lines]

        self.resolution = opt.resolution
        self.semantic_radius = opt.semantic_radius
        self.video_items, self.person_ids = self.get_video_index(videos)
        self.idx_by_person_id = self.group_by_key(self.video_items, key='person_id')
        self.person_ids = self.person_ids * 100

        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
            ])

    def get_video_index(self, videos):
        video_items = []
        for video in videos:
            video_items.append(self.Video_Item(video))

        person_ids = sorted(list({video.split('#')[0] for video in videos}))

        return video_items, person_ids            

    def group_by_key(self, video_list, key):
        return_dict = collections.defaultdict(list)
        for index, video_item in enumerate(video_list):
            return_dict[video_item[key]].append(index)
        return return_dict  
    
    def Video_Item(self, video_name):
        video_item = {}
        video_item['video_name'] = video_name
        video_item['person_id'] = video_name.split('#')[0]
        with self.env.begin(write=False) as txn:
            key = format_for_lmdb(video_item['video_name'], 'length')
            length = int(txn.get(key).decode('utf-8'))
        video_item['num_frame'] = length
        
        return video_item

    def __len__(self):
        return len(self.person_ids)

    def __getitem__(self, index):
        data={}
        person_id = self.person_ids[index]
        video_item = self.video_items[random.choices(self.idx_by_person_id[person_id], k=1)[0]]
        frame_source, frame_target = self.random_select_frames(video_item)

        with self.env.begin(write=False) as txn:
            key = format_for_lmdb(video_item['video_name'], frame_source)
            img_bytes_1 = txn.get(key) 
            key = format_for_lmdb(video_item['video_name'], frame_target)
            img_bytes_2 = txn.get(key) 
            semantics_key = format_for_lmdb(video_item['video_name'], 'coeff_3dmm')
            semantics_numpy = np.frombuffer(txn.get(semantics_key), dtype=np.float32)
            semantics_numpy = semantics_numpy.reshape((video_item['num_frame'],-1))

        img1 = Image.open(BytesIO(img_bytes_1))
        data['source_image'] = self.transform(img1)

        img2 = Image.open(BytesIO(img_bytes_2))
        data['target_image'] = self.transform(img2) 

        data['target_semantics'] = self.transform_semantic(semantics_numpy, frame_target)
        data['source_semantics'] = self.transform_semantic(semantics_numpy, frame_source)
    
        return data
    
    def random_select_frames(self, video_item):
        num_frame = video_item['num_frame']
        frame_idx = random.choices(list(range(num_frame)), k=2)
        return frame_idx[0], frame_idx[1]

    def transform_semantic(self, semantic, frame_index):
        index = self.obtain_seq_index(frame_index, semantic.shape[0])
        coeff_3dmm = semantic[index,...]
        # id_coeff = coeff_3dmm[:,:80] #identity
        ex_coeff = coeff_3dmm[:,80:144] #expression
        # tex_coeff = coeff_3dmm[:,144:224] #texture
        angles = coeff_3dmm[:,224:227] #euler angles for pose
        # gamma = coeff_3dmm[:,227:254] #lighting
        translation = coeff_3dmm[:,254:257] #translation
        crop = coeff_3dmm[:,257:260] #crop param

        coeff_3dmm = np.concatenate([ex_coeff, angles, translation, crop], 1)
        return torch.Tensor(coeff_3dmm).permute(1,0)

    def obtain_seq_index(self, index, num_frames):
        seq = list(range(index-self.semantic_radius, index+self.semantic_radius+1))
        seq = [ min(max(item, 0), num_frames-1) for item in seq ]
        return seq





================================================
FILE: data/vox_video_dataset.py
================================================
import os
import lmdb
import random
import collections
import numpy as np
from PIL import Image
from io import BytesIO

import torch

from data.vox_dataset import VoxDataset
from data.vox_dataset import format_for_lmdb

class VoxVideoDataset(VoxDataset):
    def __init__(self, opt, is_inference):
        super(VoxVideoDataset, self).__init__(opt, is_inference)
        self.video_index = -1
        self.cross_id = opt.cross_id
        # whether normalize the crop parameters when performing cross_id reenactments
        # set it as "True" always brings better performance
        self.norm_crop_param = True

    def __len__(self):
        return len(self.video_items)

    def load_next_video(self):
        data={}
        self.video_index += 1
        video_item = self.video_items[self.video_index]
        source_video_item = self.random_video(video_item) if self.cross_id else video_item 

        with self.env.begin(write=False) as txn:
            key = format_for_lmdb(source_video_item['video_name'], 0)
            img_bytes_1 = txn.get(key) 
            img1 = Image.open(BytesIO(img_bytes_1))
            data['source_image'] = self.transform(img1)

            semantics_key = format_for_lmdb(video_item['video_name'], 'coeff_3dmm')
            semantics_numpy = np.frombuffer(txn.get(semantics_key), dtype=np.float32)
            semantics_numpy = semantics_numpy.reshape((video_item['num_frame'],-1))
            if self.cross_id and self.norm_crop_param:
                semantics_source_key = format_for_lmdb(source_video_item['video_name'], 'coeff_3dmm')
                semantics_source_numpy = np.frombuffer(txn.get(semantics_source_key), dtype=np.float32)
                semantic_source_numpy = semantics_source_numpy.reshape((source_video_item['num_frame'],-1))[0:1]
                crop_norm_ratio = self.find_crop_norm_ratio(semantic_source_numpy, semantics_numpy)
            else:
                crop_norm_ratio = None            

            data['target_image'], data['target_semantics'] = [], []
            for frame_index in range(video_item['num_frame']):
                key = format_for_lmdb(video_item['video_name'], frame_index)
                img_bytes_1 = txn.get(key) 
                img1 = Image.open(BytesIO(img_bytes_1))
                data['target_image'].append(self.transform(img1))
                data['target_semantics'].append(
                    self.transform_semantic(semantics_numpy, frame_index, crop_norm_ratio)
                )
            data['video_name'] = self.obtain_name(video_item['video_name'], source_video_item['video_name'])
        return data  
    
    def random_video(self, target_video_item):
        target_person_id = target_video_item['person_id']
        assert len(self.person_ids) > 1 
        source_person_id = np.random.choice(self.person_ids)
        if source_person_id == target_person_id:
            source_person_id = np.random.choice(self.person_ids)
        source_video_index = np.random.choice(self.idx_by_person_id[source_person_id])
        source_video_item = self.video_items[source_video_index]
        return source_video_item

    def find_crop_norm_ratio(self, source_coeff, target_coeffs):
        alpha = 0.3
        exp_diff = np.mean(np.abs(target_coeffs[:,80:144] - source_coeff[:,80:144]), 1)
        angle_diff = np.mean(np.abs(target_coeffs[:,224:227] - source_coeff[:,224:227]), 1)
        index = np.argmin(alpha*exp_diff + (1-alpha)*angle_diff)
        crop_norm_ratio = source_coeff[:,-3] / target_coeffs[index:index+1, -3]
        return crop_norm_ratio
   
    def transform_semantic(self, semantic, frame_index, crop_norm_ratio):
        index = self.obtain_seq_index(frame_index, semantic.shape[0])
        coeff_3dmm = semantic[index,...]
        # id_coeff = coeff_3dmm[:,:80] #identity
        ex_coeff = coeff_3dmm[:,80:144] #expression
        # tex_coeff = coeff_3dmm[:,144:224] #texture
        angles = coeff_3dmm[:,224:227] #euler angles for pose
        # gamma = coeff_3dmm[:,227:254] #lighting
        translation = coeff_3dmm[:,254:257] #translation
        crop = coeff_3dmm[:,257:300] #crop param

        if self.cross_id and self.norm_crop_param:
            crop[:, -3] = crop[:, -3] * crop_norm_ratio

        coeff_3dmm = np.concatenate([ex_coeff, angles, translation, crop], 1)
        return torch.Tensor(coeff_3dmm).permute(1,0)   

    def obtain_name(self, target_name, source_name):
        if not self.cross_id:
            return target_name
        else:
            source_name = os.path.splitext(os.path.basename(source_name))[0]
            return source_name+'_to_'+target_name

================================================
FILE: generators/base_function.py
================================================
import sys
import math

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm


class LayerNorm2d(nn.Module):
    def __init__(self, n_out, affine=True):
        super(LayerNorm2d, self).__init__()
        self.n_out = n_out
        self.affine = affine

        if self.affine:
          self.weight = nn.Parameter(torch.ones(n_out, 1, 1))
          self.bias = nn.Parameter(torch.zeros(n_out, 1, 1))

    def forward(self, x):
        normalized_shape = x.size()[1:]
        if self.affine:
          return F.layer_norm(x, normalized_shape, \
              self.weight.expand(normalized_shape), 
              self.bias.expand(normalized_shape))
              
        else:
          return F.layer_norm(x, normalized_shape)  

class ADAINHourglass(nn.Module):
    def __init__(self, image_nc, pose_nc, ngf, img_f, encoder_layers, decoder_layers, nonlinearity, use_spect):
        super(ADAINHourglass, self).__init__()
        self.encoder = ADAINEncoder(image_nc, pose_nc, ngf, img_f, encoder_layers, nonlinearity, use_spect)
        self.decoder = ADAINDecoder(pose_nc, ngf, img_f, encoder_layers, decoder_layers, True, nonlinearity, use_spect)
        self.output_nc = self.decoder.output_nc

    def forward(self, x, z):
        return self.decoder(self.encoder(x, z), z)                 



class ADAINEncoder(nn.Module):
    def __init__(self, image_nc, pose_nc, ngf, img_f, layers, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(ADAINEncoder, self).__init__()
        self.layers = layers
        self.input_layer = nn.Conv2d(image_nc, ngf, kernel_size=7, stride=1, padding=3)
        for i in range(layers):
            in_channels = min(ngf * (2**i), img_f)
            out_channels = min(ngf *(2**(i+1)), img_f)
            model = ADAINEncoderBlock(in_channels, out_channels, pose_nc, nonlinearity, use_spect)
            setattr(self, 'encoder' + str(i), model)
        self.output_nc = out_channels
        
    def forward(self, x, z):
        out = self.input_layer(x)
        out_list = [out]
        for i in range(self.layers):
            model = getattr(self, 'encoder' + str(i))
            out = model(out, z)
            out_list.append(out)
        return out_list
        
class ADAINDecoder(nn.Module):
    """docstring for ADAINDecoder"""
    def __init__(self, pose_nc, ngf, img_f, encoder_layers, decoder_layers, skip_connect=True, 
                 nonlinearity=nn.LeakyReLU(), use_spect=False):

        super(ADAINDecoder, self).__init__()
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.skip_connect = skip_connect
        use_transpose = True

        for i in range(encoder_layers-decoder_layers, encoder_layers)[::-1]:
            in_channels = min(ngf * (2**(i+1)), img_f)
            in_channels = in_channels*2 if i != (encoder_layers-1) and self.skip_connect else in_channels
            out_channels = min(ngf * (2**i), img_f)
            model = ADAINDecoderBlock(in_channels, out_channels, out_channels, pose_nc, use_transpose, nonlinearity, use_spect)
            setattr(self, 'decoder' + str(i), model)

        self.output_nc = out_channels*2 if self.skip_connect else out_channels

    def forward(self, x, z):
        out = x.pop() if self.skip_connect else x
        for i in range(self.encoder_layers-self.decoder_layers, self.encoder_layers)[::-1]:
            model = getattr(self, 'decoder' + str(i))
            out = model(out, z)
            out = torch.cat([out, x.pop()], 1) if self.skip_connect else out
        return out

class ADAINEncoderBlock(nn.Module):       
    def __init__(self, input_nc, output_nc, feature_nc, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(ADAINEncoderBlock, self).__init__()
        kwargs_down = {'kernel_size': 4, 'stride': 2, 'padding': 1}
        kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1}

        self.conv_0 = spectral_norm(nn.Conv2d(input_nc,  output_nc, **kwargs_down), use_spect)
        self.conv_1 = spectral_norm(nn.Conv2d(output_nc, output_nc, **kwargs_fine), use_spect)


        self.norm_0 = ADAIN(input_nc, feature_nc)
        self.norm_1 = ADAIN(output_nc, feature_nc)
        self.actvn = nonlinearity

    def forward(self, x, z):
        x = self.conv_0(self.actvn(self.norm_0(x, z)))
        x = self.conv_1(self.actvn(self.norm_1(x, z)))
        return x

class ADAINDecoderBlock(nn.Module):
    def __init__(self, input_nc, output_nc, hidden_nc, feature_nc, use_transpose=True, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(ADAINDecoderBlock, self).__init__()        
        # Attributes
        self.actvn = nonlinearity
        hidden_nc = min(input_nc, output_nc) if hidden_nc is None else hidden_nc

        kwargs_fine = {'kernel_size':3, 'stride':1, 'padding':1}
        if use_transpose:
            kwargs_up = {'kernel_size':3, 'stride':2, 'padding':1, 'output_padding':1}
        else:
            kwargs_up = {'kernel_size':3, 'stride':1, 'padding':1}

        # create conv layers
        self.conv_0 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, **kwargs_fine), use_spect)
        if use_transpose:
            self.conv_1 = spectral_norm(nn.ConvTranspose2d(hidden_nc, output_nc, **kwargs_up), use_spect)
            self.conv_s = spectral_norm(nn.ConvTranspose2d(input_nc, output_nc, **kwargs_up), use_spect)
        else:
            self.conv_1 = nn.Sequential(spectral_norm(nn.Conv2d(hidden_nc, output_nc, **kwargs_up), use_spect),
                                        nn.Upsample(scale_factor=2))
            self.conv_s = nn.Sequential(spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_up), use_spect),
                                        nn.Upsample(scale_factor=2))
        # define normalization layers
        self.norm_0 = ADAIN(input_nc, feature_nc)
        self.norm_1 = ADAIN(hidden_nc, feature_nc)
        self.norm_s = ADAIN(input_nc, feature_nc)
        
    def forward(self, x, z):
        x_s = self.shortcut(x, z)
        dx = self.conv_0(self.actvn(self.norm_0(x, z)))
        dx = self.conv_1(self.actvn(self.norm_1(dx, z)))
        out = x_s + dx
        return out

    def shortcut(self, x, z):
        x_s = self.conv_s(self.actvn(self.norm_s(x, z)))
        return x_s              


def spectral_norm(module, use_spect=True):
    """use spectral normal layer to stable the training process"""
    if use_spect:
        return SpectralNorm(module)
    else:
        return module


class ADAIN(nn.Module):
    def __init__(self, norm_nc, feature_nc):
        super().__init__()

        self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)

        nhidden = 128
        use_bias=True

        self.mlp_shared = nn.Sequential(
            nn.Linear(feature_nc, nhidden, bias=use_bias),            
            nn.ReLU()
        )
        self.mlp_gamma = nn.Linear(nhidden, norm_nc, bias=use_bias)    
        self.mlp_beta = nn.Linear(nhidden, norm_nc, bias=use_bias)    

    def forward(self, x, feature):

        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on feature
        feature = feature.view(feature.size(0), -1)
        actv = self.mlp_shared(feature)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        gamma = gamma.view(*gamma.size()[:2], 1,1)
        beta = beta.view(*beta.size()[:2], 1,1)
        out = normalized * (1 + gamma) + beta
        return out


class FineEncoder(nn.Module):
    """docstring for Encoder"""
    def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineEncoder, self).__init__()
        self.layers = layers
        self.first = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
        for i in range(layers):
            in_channels = min(ngf*(2**i), img_f)
            out_channels = min(ngf*(2**(i+1)), img_f)
            model = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
            setattr(self, 'down' + str(i), model)
        self.output_nc = out_channels

    def forward(self, x):
        x = self.first(x)
        out=[x]
        for i in range(self.layers):
            model = getattr(self, 'down'+str(i))
            x = model(x)
            out.append(x)
        return out

class FineDecoder(nn.Module):
    """docstring for FineDecoder"""
    def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineDecoder, self).__init__()
        self.layers = layers
        for i in range(layers)[::-1]:
            in_channels = min(ngf*(2**(i+1)), img_f)
            out_channels = min(ngf*(2**i), img_f)
            up = UpBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
            res = FineADAINResBlocks(num_block, in_channels, feature_nc, norm_layer, nonlinearity, use_spect)
            jump = Jump(out_channels, norm_layer, nonlinearity, use_spect)

            setattr(self, 'up' + str(i), up)
            setattr(self, 'res' + str(i), res)            
            setattr(self, 'jump' + str(i), jump)

        self.final = FinalBlock2d(out_channels, image_nc, use_spect, 'tanh')

        self.output_nc = out_channels

    def forward(self, x, z):
        out = x.pop()
        for i in range(self.layers)[::-1]:
            res_model = getattr(self, 'res' + str(i))
            up_model = getattr(self, 'up' + str(i))
            jump_model = getattr(self, 'jump' + str(i))
            out = res_model(out, z)
            out = up_model(out)
            out = jump_model(x.pop()) + out
        out_image = self.final(out)
        return out_image

class FirstBlock2d(nn.Module):
    """
    Downsampling block for use in encoder.
    """
    def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FirstBlock2d, self).__init__()
        kwargs = {'kernel_size': 7, 'stride': 1, 'padding': 3}
        conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)

        if type(norm_layer) == type(None):
            self.model = nn.Sequential(conv, nonlinearity)
        else:
            self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)


    def forward(self, x):
        out = self.model(x)
        return out  

class DownBlock2d(nn.Module):
    def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(DownBlock2d, self).__init__()


        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
        pool = nn.AvgPool2d(kernel_size=(2, 2))

        if type(norm_layer) == type(None):
            self.model = nn.Sequential(conv, nonlinearity, pool)
        else:
            self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity, pool)

    def forward(self, x):
        out = self.model(x)
        return out 

class UpBlock2d(nn.Module):
    def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(UpBlock2d, self).__init__()
        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
        if type(norm_layer) == type(None):
            self.model = nn.Sequential(conv, nonlinearity)
        else:
            self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)

    def forward(self, x):
        out = self.model(F.interpolate(x, scale_factor=2))
        return out

class FineADAINResBlocks(nn.Module):
    def __init__(self, num_block, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineADAINResBlocks, self).__init__()                                
        self.num_block = num_block
        for i in range(num_block):
            model = FineADAINResBlock2d(input_nc, feature_nc, norm_layer, nonlinearity, use_spect)
            setattr(self, 'res'+str(i), model)

    def forward(self, x, z):
        for i in range(self.num_block):
            model = getattr(self, 'res'+str(i))
            x = model(x, z)
        return x     

class Jump(nn.Module):
    def __init__(self, input_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(Jump, self).__init__()
        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        conv = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)

        if type(norm_layer) == type(None):
            self.model = nn.Sequential(conv, nonlinearity)
        else:
            self.model = nn.Sequential(conv, norm_layer(input_nc), nonlinearity)

    def forward(self, x):
        out = self.model(x)
        return out          

class FineADAINResBlock2d(nn.Module):
    """
    Define an Residual block for different types
    """
    def __init__(self, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineADAINResBlock2d, self).__init__()

        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}

        self.conv1 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
        self.conv2 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
        self.norm1 = ADAIN(input_nc, feature_nc)
        self.norm2 = ADAIN(input_nc, feature_nc)

        self.actvn = nonlinearity


    def forward(self, x, z):
        dx = self.actvn(self.norm1(self.conv1(x), z))
        dx = self.norm2(self.conv2(x), z)
        out = dx + x
        return out        

class FinalBlock2d(nn.Module):
    """
    Define the output layer
    """
    def __init__(self, input_nc, output_nc, use_spect=False, tanh_or_sigmoid='tanh'):
        super(FinalBlock2d, self).__init__()

        kwargs = {'kernel_size': 7, 'stride': 1, 'padding':3}
        conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)

        if tanh_or_sigmoid == 'sigmoid':
            out_nonlinearity = nn.Sigmoid()
        else:
            out_nonlinearity = nn.Tanh()            

        self.model = nn.Sequential(conv, out_nonlinearity)
    def forward(self, x):
        out = self.model(x)
        return out          

================================================
FILE: generators/face_model.py
================================================
import functools
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from util import flow_util
from generators.base_function import LayerNorm2d, ADAINHourglass, FineEncoder, FineDecoder

class FaceGenerator(nn.Module):
    def __init__(
        self, 
        mapping_net, 
        warpping_net, 
        editing_net, 
        common
        ):  
        super(FaceGenerator, self).__init__()
        self.mapping_net = MappingNet(**mapping_net)
        self.warpping_net = WarpingNet(**warpping_net, **common)
        self.editing_net = EditingNet(**editing_net, **common)
 
    def forward(
        self, 
        input_image, 
        driving_source, 
        stage=None
        ):
        if stage == 'warp':
            descriptor = self.mapping_net(driving_source)
            output = self.warpping_net(input_image, descriptor)
        else:
            descriptor = self.mapping_net(driving_source)
            output = self.warpping_net(input_image, descriptor)
            output['fake_image'] = self.editing_net(input_image, output['warp_image'], descriptor)
        return output

class MappingNet(nn.Module):
    def __init__(self, coeff_nc, descriptor_nc, layer):
        super( MappingNet, self).__init__()

        self.layer = layer
        nonlinearity = nn.LeakyReLU(0.1)

        self.first = nn.Sequential(
            torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True))

        for i in range(layer):
            net = nn.Sequential(nonlinearity,
                torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3))
            setattr(self, 'encoder' + str(i), net)   

        self.pooling = nn.AdaptiveAvgPool1d(1)
        self.output_nc = descriptor_nc

    def forward(self, input_3dmm):
        out = self.first(input_3dmm)
        for i in range(self.layer):
            model = getattr(self, 'encoder' + str(i))
            out = model(out) + out[:,:,3:-3]
        out = self.pooling(out)
        return out   

class WarpingNet(nn.Module):
    def __init__(
        self, 
        image_nc, 
        descriptor_nc, 
        base_nc, 
        max_nc, 
        encoder_layer, 
        decoder_layer, 
        use_spect
        ):
        super( WarpingNet, self).__init__()

        nonlinearity = nn.LeakyReLU(0.1)
        norm_layer = functools.partial(LayerNorm2d, affine=True) 
        kwargs = {'nonlinearity':nonlinearity, 'use_spect':use_spect}

        self.descriptor_nc = descriptor_nc 
        self.hourglass = ADAINHourglass(image_nc, self.descriptor_nc, base_nc,
                                       max_nc, encoder_layer, decoder_layer, **kwargs)

        self.flow_out = nn.Sequential(norm_layer(self.hourglass.output_nc), 
                                      nonlinearity,
                                      nn.Conv2d(self.hourglass.output_nc, 2, kernel_size=7, stride=1, padding=3))

        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, input_image, descriptor):
        final_output={}
        output = self.hourglass(input_image, descriptor)
        final_output['flow_field'] = self.flow_out(output)

        deformation = flow_util.convert_flow_to_deformation(final_output['flow_field'])
        final_output['warp_image'] = flow_util.warp_image(input_image, deformation)
        return final_output


class EditingNet(nn.Module):
    def __init__(
        self, 
        image_nc, 
        descriptor_nc, 
        layer, 
        base_nc, 
        max_nc, 
        num_res_blocks, 
        use_spect):  
        super(EditingNet, self).__init__()

        nonlinearity = nn.LeakyReLU(0.1)
        norm_layer = functools.partial(LayerNorm2d, affine=True) 
        kwargs = {'norm_layer':norm_layer, 'nonlinearity':nonlinearity, 'use_spect':use_spect}
        self.descriptor_nc = descriptor_nc

        # encoder part
        self.encoder = FineEncoder(image_nc*2, base_nc, max_nc, layer, **kwargs)
        self.decoder = FineDecoder(image_nc, self.descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **kwargs)

    def forward(self, input_image, warp_image, descriptor):
        x = torch.cat([input_image, warp_image], 1)
        x = self.encoder(x)
        gen_image = self.decoder(x, descriptor)
        return gen_image


================================================
FILE: inference.py
================================================
import os
import cv2 
import lmdb
import math
import argparse
import numpy as np
from io import BytesIO
from PIL import Image

import torch
import torchvision.transforms.functional as F
import torchvision.transforms as transforms

from util.logging import init_logging, make_logging_dir
from util.distributed import init_dist
from util.trainer import get_model_optimizer_and_scheduler, set_random_seed, get_trainer
from util.distributed import master_only_print as print
from data.vox_video_dataset import VoxVideoDataset
from config import Config


def parse_args():
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--config', default='./config/face.yaml')
    parser.add_argument('--name', default=None)
    parser.add_argument('--checkpoints_dir', default='result',
                        help='Dir for saving logs and models.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--cross_id', action='store_true')
    parser.add_argument('--which_iter', type=int, default=None)
    parser.add_argument('--no_resume', action='store_true')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--single_gpu', action='store_true')
    parser.add_argument('--output_dir', type=str)


    args = parser.parse_args()
    return args

def write2video(results_dir, *video_list):
    cat_video=None

    for video in video_list:
        video_numpy = video[:,:3,:,:].cpu().float().detach().numpy()
        video_numpy = (np.transpose(video_numpy, (0, 2, 3, 1)) + 1) / 2.0 * 255.0
        video_numpy = video_numpy.astype(np.uint8)
        cat_video = np.concatenate([cat_video, video_numpy], 2) if cat_video is not None else video_numpy

    image_array=[]
    for i in range(cat_video.shape[0]):
        image_array.append(cat_video[i]) 

    out_name = results_dir+'.mp4' 
    _, height, width, layers = cat_video.shape
    size = (width,height)
    out = cv2.VideoWriter(out_name, cv2.VideoWriter_fourcc(*'mp4v'), 15, size)

    for i in range(len(image_array)):
        out.write(image_array[i][:,:,::-1])
    out.release() 

if __name__ == '__main__':
    args = parse_args()
    set_random_seed(args.seed)
    opt = Config(args.config, args, is_train=False)

    if not args.single_gpu:
        opt.local_rank = args.local_rank
        init_dist(opt.local_rank)    
        opt.device = torch.cuda.current_device()
    # create a visualizer
    date_uid, logdir = init_logging(opt)
    opt.logdir = logdir
    make_logging_dir(logdir, date_uid)

    # create a model
    net_G, net_G_ema, opt_G, sch_G \
        = get_model_optimizer_and_scheduler(opt)

    trainer = get_trainer(opt, net_G, net_G_ema, \
                          opt_G, sch_G, None)

    current_epoch, current_iteration = trainer.load_checkpoint(
        opt, args.which_iter)                          
    net_G = trainer.net_G_ema.eval()

    output_dir = os.path.join(
        args.output_dir, 
        'epoch_{:05}_iteration_{:09}'.format(current_epoch, current_iteration)
        )
    os.makedirs(output_dir, exist_ok=True)
    opt.data.cross_id = args.cross_id
    dataset = VoxVideoDataset(opt.data, is_inference=True)
    with torch.no_grad():
        for video_index in range(dataset.__len__()):
            data = dataset.load_next_video()
            input_source = data['source_image'][None].cuda()
            name = data['video_name']

            output_images, gt_images, warp_images = [],[],[]
            for frame_index in range(len(data['target_semantics'])):
                target_semantic = data['target_semantics'][frame_index][None].cuda()
                output_dict = net_G(input_source, target_semantic)
                output_images.append(
                    output_dict['fake_image'].cpu().clamp_(-1, 1)
                    )
                warp_images.append(
                    output_dict['warp_image'].cpu().clamp_(-1, 1)
                    )                    
                gt_images.append(
                    data['target_image'][frame_index][None]
                    )
            
            gen_images = torch.cat(output_images, 0)
            gt_images = torch.cat(gt_images, 0)
            warp_images = torch.cat(warp_images, 0)

            write2video("{}/{}".format(output_dir, name), gt_images, warp_images, gen_images)
            print("write results to video {}/{}".format(output_dir, name))



================================================
FILE: intuitive_control.py
================================================
import os
import math
import argparse
import numpy as np
from scipy.io import savemat,loadmat

import torch
import torchvision.transforms.functional as F
import torchvision.transforms as transforms

from config import Config
from util.logging import init_logging, make_logging_dir
from util.distributed import init_dist
from util.trainer import get_model_optimizer_and_scheduler, set_random_seed, get_trainer
from util.distributed import master_only_print as print
from data.image_dataset import ImageDataset
from inference import write2video


def parse_args():
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--config', default='./config/face.yaml')
    parser.add_argument('--name', default=None)
    parser.add_argument('--checkpoints_dir', default='result',
                        help='Dir for saving logs and models.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--which_iter', type=int, default=None)
    parser.add_argument('--no_resume', action='store_true')
    parser.add_argument('--input_name', type=str)
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--single_gpu', action='store_true')
    parser.add_argument('--output_dir', type=str)

    args = parser.parse_args()
    return args

def get_control(input_name):
    control_dict = {}
    control_dict['rotation_center'] = torch.tensor([0,0,0,0,0,0.45])
    control_dict['rotation_left_x'] = torch.tensor([0,0,math.pi/10,0,0,0.45])
    control_dict['rotation_right_x'] = torch.tensor([0,0,-math.pi/10,0,0,0.45])

    control_dict['rotation_left_y'] = torch.tensor([math.pi/10,0,0,0,0,0.45])
    control_dict['rotation_right_y'] = torch.tensor([-math.pi/10,0,0,0,0,0.45])        

    control_dict['rotation_left_z'] = torch.tensor([0,math.pi/8,0,0,0,0.45])
    control_dict['rotation_right_z'] = torch.tensor([0,-math.pi/8,0,0,0,0.45])   

    expession = loadmat('{}/expression.mat'.format(input_name))

    for item in ['expression_center', 'expression_mouth', 'expression_eyebrow', 'expression_eyes']:
        control_dict[item] = torch.tensor(expession[item])[0]

    sort_rot_control = [
                        'rotation_left_x',  'rotation_center', 
                        'rotation_right_x', 'rotation_center',
                        'rotation_left_y',  'rotation_center',
                        'rotation_right_y', 'rotation_center',
                        'rotation_left_z',  'rotation_center',
                        'rotation_right_z', 'rotation_center'
                        ]
    
    sort_exp_control = [
                        'expression_center', 'expression_mouth',
                        'expression_center', 'expression_eyebrow',
                        'expression_center', 'expression_eyes',
                        ]
    return control_dict, sort_rot_control, sort_exp_control

if __name__ == '__main__':
    args = parse_args()
    set_random_seed(args.seed)
    opt = Config(args.config, args, is_train=False)

    if not args.single_gpu:
        opt.local_rank = args.local_rank
        init_dist(opt.local_rank)    
        opt.device = torch.cuda.current_device()

    # create a visualizer
    date_uid, logdir = init_logging(opt)
    opt.logdir = logdir
    make_logging_dir(logdir, date_uid)

    # create a model
    net_G, net_G_ema, opt_G, sch_G \
        = get_model_optimizer_and_scheduler(opt)

    trainer = get_trainer(opt, net_G, net_G_ema, \
                          opt_G, sch_G, None)

    current_epoch, current_iteration = trainer.load_checkpoint(
        opt, args.which_iter)                          
    net_G = trainer.net_G_ema.eval()

    output_dir = os.path.join(
        args.output_dir, 
        'epoch_{:05}_iteration_{:09}'.format(current_epoch, current_iteration)
        )

    os.makedirs(output_dir, exist_ok=True)
    image_dataset = ImageDataset(opt.data, args.input_name)

    control_dict, sort_rot_control, sort_exp_control = get_control(args.input_name)
    for _ in range(image_dataset.__len__()):
        with torch.no_grad():
            data = image_dataset.next_image()
            num = 10
            output_images = []     
            # rotation control
            current = control_dict['rotation_center']
            for control in sort_rot_control: 
                for i in range(num):
                    rotation = (control_dict[control]-current)*i/(num-1)+current
                    data['target_semantics'][:, 64:70, :] = rotation[None, :, None]
                    output_dict = net_G(data['source_image'].cuda(), data['target_semantics'].cuda())
                    output_images.append(
                        output_dict['fake_image'].cpu().clamp_(-1, 1)
                        )    
                current = rotation

            # expression control
            current = data['target_semantics'][0, :64, 0]
            for control in sort_exp_control: 
                for i in range(num):
                    expression = (control_dict[control]-current)*i/(num-1)+current
                    data['target_semantics'][:, :64, :] = expression[None, :, None]
                    output_dict = net_G(data['source_image'].cuda(), data['target_semantics'].cuda())
                    output_images.append(
                        output_dict['fake_image'].cpu().clamp_(-1, 1)
                        )    
                current = expression
            output_images = torch.cat(output_images, 0)   
            print('write results to file {}/{}'.format(output_dir, data['name']))
            write2video('{}/{}'.format(output_dir, data['name']), output_images)



================================================
FILE: loss/perceptual.py
================================================
import torch
import torch.nn.functional as F
import torchvision
from torch import nn

from util.distributed import master_only_print as print

def apply_imagenet_normalization(input):
    r"""Normalize using ImageNet mean and std.

    Args:
        input (4D tensor NxCxHxW): The input images, assuming to be [-1, 1].

    Returns:
        Normalized inputs using the ImageNet normalization.
    """
    # normalize the input back to [0, 1]
    normalized_input = (input + 1) / 2
    # normalize the input using the ImageNet mean and std
    mean = normalized_input.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = normalized_input.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    output = (normalized_input - mean) / std
    return output

class PerceptualLoss(nn.Module):
    r"""Perceptual loss initialization.

    Args:
        network (str) : The name of the loss network: 'vgg16' | 'vgg19'.
        layers (str or list of str) : The layers used to compute the loss.
        weights (float or list of float : The loss weights of each layer.
        criterion (str): The type of distance function: 'l1' | 'l2'.
        resize (bool) : If ``True``, resize the input images to 224x224.
        resize_mode (str): Algorithm used for resizing.
        instance_normalized (bool): If ``True``, applies instance normalization
            to the feature maps before computing the distance.
        num_scales (int): The loss will be evaluated at original size and
            this many times downsampled sizes.
    """

    def __init__(self, network='vgg19', layers='relu_4_1', weights=None,
                 criterion='l1', resize=False, resize_mode='bilinear',
                 instance_normalized=False, num_scales=1, 
                 use_style_loss=False, weight_style_to_perceptual=0):
        super().__init__()
        if isinstance(layers, str):
            layers = [layers]
        if weights is None:
            weights = [1.] * len(layers)
        elif isinstance(layers, float) or isinstance(layers, int):
            weights = [weights]

        assert len(layers) == len(weights), \
            'The number of layers (%s) must be equal to ' \
            'the number of weights (%s).' % (len(layers), len(weights))
        if network == 'vgg19':
            self.model = _vgg19(layers)
        elif network == 'vgg16':
            self.model = _vgg16(layers)
        elif network == 'alexnet':
            self.model = _alexnet(layers)
        elif network == 'inception_v3':
            self.model = _inception_v3(layers)
        elif network == 'resnet50':
            self.model = _resnet50(layers)
        elif network == 'robust_resnet50':
            self.model = _robust_resnet50(layers)
        elif network == 'vgg_face_dag':
            self.model = _vgg_face_dag(layers)
        else:
            raise ValueError('Network %s is not recognized' % network)

        self.num_scales = num_scales
        self.layers = layers
        self.weights = weights
        if criterion == 'l1':
            self.criterion = nn.L1Loss()
        elif criterion == 'l2' or criterion == 'mse':
            self.criterion = nn.MSELoss()
        else:
            raise ValueError('Criterion %s is not recognized' % criterion)
        self.resize = resize
        self.resize_mode = resize_mode
        self.instance_normalized = instance_normalized
        self.use_style_loss = use_style_loss
        self.weight_style = weight_style_to_perceptual

        print('Perceptual loss:')
        print('\tMode: {}'.format(network))

    def forward(self, inp, target, mask=None):
        r"""Perceptual loss forward.

        Args:
           inp (4D tensor) : Input tensor.
           target (4D tensor) : Ground truth tensor, same shape as the input.

        Returns:
           (scalar tensor) : The perceptual loss.
        """
        # Perceptual loss should operate in eval mode by default.
        self.model.eval()
        inp, target = \
            apply_imagenet_normalization(inp), \
            apply_imagenet_normalization(target)
        if self.resize:
            inp = F.interpolate(
                inp, mode=self.resize_mode, size=(224, 224),
                align_corners=False)
            target = F.interpolate(
                target, mode=self.resize_mode, size=(224, 224),
                align_corners=False)

        # Evaluate perceptual loss at each scale.
        loss = 0
        style_loss=0
        for scale in range(self.num_scales):
            input_features, target_features = \
                self.model(inp), self.model(target)
            for layer, weight in zip(self.layers, self.weights):
                # Example per-layer VGG19 loss values after applying
                # [0.03125, 0.0625, 0.125, 0.25, 1.0] weighting.
                # relu_1_1, 0.014698
                # relu_2_1, 0.085817
                # relu_3_1, 0.349977
                # relu_4_1, 0.544188
                # relu_5_1, 0.906261
                input_feature = input_features[layer]
                target_feature = target_features[layer].detach()
                if self.instance_normalized:
                    input_feature = F.instance_norm(input_feature)
                    target_feature = F.instance_norm(target_feature)

                if mask is not None:
                    mask_ = F.interpolate(mask, input_feature.shape[2:],
                                          mode='bilinear',
                                          align_corners=False)
                    input_feature = input_feature * mask_  
                    target_feature = target_feature * mask_ 
                    # print('mask',mask_.shape) 


                loss += weight * self.criterion(input_feature,
                                                target_feature)
                if self.use_style_loss and scale==0:
                    style_loss += self.criterion(self.compute_gram(input_feature),
                                                 self.compute_gram(target_feature))

            # Downsample the input and target.
            if scale != self.num_scales - 1:
                inp = F.interpolate(
                    inp, mode=self.resize_mode, scale_factor=0.5,
                    align_corners=False, recompute_scale_factor=True)
                target = F.interpolate(
                    target, mode=self.resize_mode, scale_factor=0.5,
                    align_corners=False, recompute_scale_factor=True)

        if self.use_style_loss:
            return loss + style_loss*self.weight_style
        else:
            return loss


    def compute_gram(self, x):
        b, ch, h, w = x.size()
        f = x.view(b, ch, w * h)
        f_T = f.transpose(1, 2)
        G = f.bmm(f_T) / (h * w * ch)
        return G

class _PerceptualNetwork(nn.Module):
    r"""The network that extracts features to compute the perceptual loss.

    Args:
        network (nn.Sequential) : The network that extracts features.
        layer_name_mapping (dict) : The dictionary that
            maps a layer's index to its name.
        layers (list of str): The list of layer names that we are using.
    """

    def __init__(self, network, layer_name_mapping, layers):
        super().__init__()
        assert isinstance(network, nn.Sequential), \
            'The network needs to be of type "nn.Sequential".'
        self.network = network
        self.layer_name_mapping = layer_name_mapping
        self.layers = layers
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        r"""Extract perceptual features."""
        output = {}
        for i, layer in enumerate(self.network):
            x = layer(x)
            layer_name = self.layer_name_mapping.get(i, None)
            if layer_name in self.layers:
                # If the current layer is used by the perceptual loss.
                output[layer_name] = x
        return output


def _vgg19(layers):
    r"""Get vgg19 layers"""
    network = torchvision.models.vgg19(pretrained=True).features
    layer_name_mapping = {1: 'relu_1_1',
                          3: 'relu_1_2',
                          6: 'relu_2_1',
                          8: 'relu_2_2',
                          11: 'relu_3_1',
                          13: 'relu_3_2',
                          15: 'relu_3_3',
                          17: 'relu_3_4',
                          20: 'relu_4_1',
                          22: 'relu_4_2',
                          24: 'relu_4_3',
                          26: 'relu_4_4',
                          29: 'relu_5_1'}
    return _PerceptualNetwork(network, layer_name_mapping, layers)


def _vgg16(layers):
    r"""Get vgg16 layers"""
    network = torchvision.models.vgg16(pretrained=True).features
    layer_name_mapping = {1: 'relu_1_1',
                          3: 'relu_1_2',
                          6: 'relu_2_1',
                          8: 'relu_2_2',
                          11: 'relu_3_1',
                          13: 'relu_3_2',
                          15: 'relu_3_3',
                          18: 'relu_4_1',
                          20: 'relu_4_2',
                          22: 'relu_4_3',
                          25: 'relu_5_1'}
    return _PerceptualNetwork(network, layer_name_mapping, layers)


def _alexnet(layers):
    r"""Get alexnet layers"""
    network = torchvision.models.alexnet(pretrained=True).features
    layer_name_mapping = {0: 'conv_1',
                          1: 'relu_1',
                          3: 'conv_2',
                          4: 'relu_2',
                          6: 'conv_3',
                          7: 'relu_3',
                          8: 'conv_4',
                          9: 'relu_4',
                          10: 'conv_5',
                          11: 'relu_5'}
    return _PerceptualNetwork(network, layer_name_mapping, layers)


def _inception_v3(layers):
    r"""Get inception v3 layers"""
    inception = torchvision.models.inception_v3(pretrained=True)
    network = nn.Sequential(inception.Conv2d_1a_3x3,
                            inception.Conv2d_2a_3x3,
                            inception.Conv2d_2b_3x3,
                            nn.MaxPool2d(kernel_size=3, stride=2),
                            inception.Conv2d_3b_1x1,
                            inception.Conv2d_4a_3x3,
                            nn.MaxPool2d(kernel_size=3, stride=2),
                            inception.Mixed_5b,
                            inception.Mixed_5c,
                            inception.Mixed_5d,
                            inception.Mixed_6a,
                            inception.Mixed_6b,
                            inception.Mixed_6c,
                            inception.Mixed_6d,
                            inception.Mixed_6e,
                            inception.Mixed_7a,
                            inception.Mixed_7b,
                            inception.Mixed_7c,
                            nn.AdaptiveAvgPool2d(output_size=(1, 1)))
    layer_name_mapping = {3: 'pool_1',
                          6: 'pool_2',
                          14: 'mixed_6e',
                          18: 'pool_3'}
    return _PerceptualNetwork(network, layer_name_mapping, layers)


def _resnet50(layers):
    r"""Get resnet50 layers"""
    resnet50 = torchvision.models.resnet50(pretrained=True)
    network = nn.Sequential(resnet50.conv1,
                            resnet50.bn1,
                            resnet50.relu,
                            resnet50.maxpool,
                            resnet50.layer1,
                            resnet50.layer2,
                            resnet50.layer3,
                            resnet50.layer4,
                            resnet50.avgpool)
    layer_name_mapping = {4: 'layer_1',
                          5: 'layer_2',
                          6: 'layer_3',
                          7: 'layer_4'}
    return _PerceptualNetwork(network, layer_name_mapping, layers)


def _robust_resnet50(layers):
    r"""Get robust resnet50 layers"""
    resnet50 = torchvision.models.resnet50(pretrained=False)
    state_dict = torch.utils.model_zoo.load_url(
        'http://andrewilyas.com/ImageNet.pt')
    new_state_dict = {}
    for k, v in state_dict['model'].items():
        if k.startswith('module.model.'):
            new_state_dict[k[13:]] = v
    resnet50.load_state_dict(new_state_dict)
    network = nn.Sequential(resnet50.conv1,
                            resnet50.bn1,
                            resnet50.relu,
                            resnet50.maxpool,
                            resnet50.layer1,
                            resnet50.layer2,
                            resnet50.layer3,
                            resnet50.layer4,
                            resnet50.avgpool)
    layer_name_mapping = {4: 'layer_1',
                          5: 'layer_2',
                          6: 'layer_3',
                          7: 'layer_4'}
    return _PerceptualNetwork(network, layer_name_mapping, layers)


def _vgg_face_dag(layers):
    r"""Get vgg face layers"""
    network = torchvision.models.vgg16(num_classes=2622)
    state_dict = torch.utils.model_zoo.load_url(
        'http://www.robots.ox.ac.uk/~albanie/models/pytorch-mcn/'
        'vgg_face_dag.pth')
    feature_layer_name_mapping = {
        0: 'conv1_1',
        2: 'conv1_2',
        5: 'conv2_1',
        7: 'conv2_2',
        10: 'conv3_1',
        12: 'conv3_2',
        14: 'conv3_3',
        17: 'conv4_1',
        19: 'conv4_2',
        21: 'conv4_3',
        24: 'conv5_1',
        26: 'conv5_2',
        28: 'conv5_3'}
    new_state_dict = {}
    for k, v in feature_layer_name_mapping.items():
        new_state_dict['features.' + str(k) + '.weight'] =\
            state_dict[v + '.weight']
        new_state_dict['features.' + str(k) + '.bias'] = \
            state_dict[v + '.bias']

    classifier_layer_name_mapping = {
        0: 'fc6',
        3: 'fc7',
        6: 'fc8'}
    for k, v in classifier_layer_name_mapping.items():
        new_state_dict['classifier.' + str(k) + '.weight'] = \
            state_dict[v + '.weight']
        new_state_dict['classifier.' + str(k) + '.bias'] = \
            state_dict[v + '.bias']

    network.load_state_dict(new_state_dict)

    class Flatten(nn.Module):
        r"""Flatten the tensor"""

        def forward(self, x):
            r"""Flatten it"""
            return x.view(x.shape[0], -1)

    layer_name_mapping = {
        1: 'avgpool',
        3: 'fc6',
        4: 'relu_6',
        6: 'fc7',
        7: 'relu_7',
        9: 'fc8'}
    seq_layers = [network.features, network.avgpool, Flatten()]
    for i in range(7):
        seq_layers += [network.classifier[i]]
    network = nn.Sequential(*seq_layers)
    return _PerceptualNetwork(network, layer_name_mapping, layers)


================================================
FILE: requirements.txt
================================================
absl-py==0.13.0
backcall==0.2.0
cachetools==4.2.2
certifi==2021.5.30
charset-normalizer==2.0.6
cycler==0.10.0
dataclasses==0.8
decorator==4.4.2
filelock==3.0.12
gdown==3.13.1
google-auth==1.35.0
google-auth-oauthlib==0.4.6
grpcio==1.40.0
idna==3.2
imageio==2.9.0
importlib-metadata==4.8.1
ipython==7.16.1
ipython-genutils==0.2.0
jedi==0.18.0
kiwisolver==1.3.1
lmdb==1.2.1
Markdown==3.3.4
matplotlib==3.3.4
mkl-fft==1.3.0
mkl-random==1.1.1
mkl-service==2.3.0
networkx==2.5.1
numpy==1.19.2
oauthlib==3.1.1
olefile==0.46
opencv-python==4.5.3.56
parso==0.8.2
pexpect==4.8.0
pickleshare==0.7.5
Pillow==8.3.1
pip==21.2.2
prompt-toolkit==3.0.20
protobuf==3.18.0
ptyprocess==0.7.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
Pygments==2.10.0
pyparsing==2.4.7
PySocks==1.7.1
python-dateutil==2.8.2
PyWavelets==1.1.1
PyYAML==5.4.1
requests==2.26.0
requests-oauthlib==1.3.0
rsa==4.7.2
scikit-image==0.17.2
scipy==1.5.4
setuptools==58.0.4
six==1.16.0
tensorboard==2.6.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tifffile==2020.9.3
torch==1.7.1
torchvision==0.8.2
tqdm==4.62.2
traitlets==4.3.3
typing-extensions==3.10.0.2
urllib3==1.26.6
wcwidth==0.2.5
Werkzeug==2.0.1
wheel==0.37.0
zipp==3.5.0


================================================
FILE: scripts/coeff_detector.py
================================================
import os
import glob
import numpy as np
from os import makedirs, name
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn

from options.inference_options import InferenceOptions
from models import create_model
from util.preprocess import align_img
from util.load_mats import load_lm3d
from extract_kp_videos import KeypointExtractor


class CoeffDetector(nn.Module):
    def __init__(self, opt):
        super().__init__()

        self.model = create_model(opt)
        self.model.setup(opt)
        self.model.device = 'cuda'
        self.model.parallelize()
        self.model.eval()

        self.lm3d_std = load_lm3d(opt.bfm_folder) 

    def forward(self, img, lm):
        
        img, trans_params = self.image_transform(img, lm)

        data_input = {                
                'imgs': img[None],
                }        
        self.model.set_input(data_input)  
        self.model.test()
        pred_coeff = {key:self.model.pred_coeffs_dict[key].cpu().numpy() for key in self.model.pred_coeffs_dict}
        pred_coeff = np.concatenate([
            pred_coeff['id'], 
            pred_coeff['exp'], 
            pred_coeff['tex'], 
            pred_coeff['angle'],
            pred_coeff['gamma'],
            pred_coeff['trans'],
            trans_params[None],
            ], 1)
        
        return {'coeff_3dmm':pred_coeff, 
                'crop_img': Image.fromarray((img.cpu().permute(1, 2, 0).numpy()*255).astype(np.uint8))}

    def image_transform(self, images, lm):
        """
        param:
            images:          -- PIL image 
            lm:              -- numpy array
        """
        W,H = images.size
        if np.mean(lm) == -1:
            lm = (self.lm3d_std[:, :2]+1)/2.
            lm = np.concatenate(
                [lm[:, :1]*W, lm[:, 1:2]*H], 1
            )
        else:
            lm[:, -1] = H - 1 - lm[:, -1]

        trans_params, img, lm, _ = align_img(images, lm, self.lm3d_std)        
        img = torch.tensor(np.array(img)/255., dtype=torch.float32).permute(2, 0, 1)
        trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)])
        trans_params = torch.tensor(trans_params.astype(np.float32))
        return img, trans_params        

def get_data_path(root, keypoint_root):
    filenames = list()
    keypoint_filenames = list()

    IMAGE_EXTENSIONS_LOWERCASE = {'jpg', 'png', 'jpeg', 'webp'}
    IMAGE_EXTENSIONS = IMAGE_EXTENSIONS_LOWERCASE.union({f.upper() for f in IMAGE_EXTENSIONS_LOWERCASE})
    extensions = IMAGE_EXTENSIONS

    for ext in extensions:
        filenames += glob.glob(f'{root}/*.{ext}', recursive=True)
    filenames = sorted(filenames)
    for filename in filenames:
        name = os.path.splitext(os.path.basename(filename))[0]
        keypoint_filenames.append(
            os.path.join(keypoint_root, name + '.txt')
        )
    return filenames, keypoint_filenames


if __name__ == "__main__":
    opt = InferenceOptions().parse() 
    coeff_detector = CoeffDetector(opt)
    kp_extractor = KeypointExtractor()
    image_names, keypoint_names = get_data_path(opt.input_dir, opt.keypoint_dir)
    makedirs(opt.keypoint_dir, exist_ok=True)
    makedirs(opt.output_dir, exist_ok=True)

    for image_name, keypoint_name in tqdm(zip(image_names, keypoint_names)):
        image = Image.open(image_name)
        if not os.path.isfile(keypoint_name):
            lm = kp_extractor.extract_keypoint(image, keypoint_name)
        else:
            lm = np.loadtxt(keypoint_name).astype(np.float32)
            lm = lm.reshape([-1, 2]) 
        predicted = coeff_detector(image, lm)
        name = os.path.splitext(os.path.basename(image_name))[0]
        np.savetxt(
            "{}/{}_3dmm_coeff.txt".format(opt.output_dir, name), 
            predicted['coeff_3dmm'].reshape(-1))

        



    

================================================
FILE: scripts/download_demo_dataset.sh
================================================
gdown https://drive.google.com/uc?id=1ruuLw5-0fpm6EREexPn3I_UQPmkrBoq9
unzip -x ./vox_lmdb_demo.zip
mkdir ./dataset
mv vox_lmdb_demo ./dataset


================================================
FILE: scripts/download_weights.sh
================================================
gdown https://drive.google.com/uc?id=1-0xOf6g58OmtKtEWJlU3VlnfRqPN9Uq7
unzip -x ./face.zip
mkdir ./result
mv face ./result
rm face.zip


================================================
FILE: scripts/extract_kp_videos.py
================================================
import os
import cv2
import time
import glob
import argparse
import face_alignment
import numpy as np
from PIL import Image
from tqdm import tqdm
from itertools import cycle

from torch.multiprocessing import Pool, Process, set_start_method

class KeypointExtractor():
    def __init__(self):
        self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D)   

    def extract_keypoint(self, images, name=None):
        if isinstance(images, list):
            keypoints = []
            for image in images:
                current_kp = self.extract_keypoint(image)
                if np.mean(current_kp) == -1 and keypoints:
                    keypoints.append(keypoints[-1])
                else:
                    keypoints.append(current_kp[None])

            keypoints = np.concatenate(keypoints, 0)
            np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
            return keypoints
        else:
            while True:
                try:
                    keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
                    break
                except RuntimeError as e:
                    if str(e).startswith('CUDA'):
                        print("Warning: out of memory, sleep for 1s")
                        time.sleep(1)
                    else:
                        print(e)
                        break    
                except TypeError:
                    print('No face detected in this image')
                    shape = [68, 2]
                    keypoints = -1. * np.ones(shape)                    
                    break
            if name is not None:
                np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
            return keypoints

def read_video(filename):
    frames = []
    cap = cv2.VideoCapture(filename)
    while cap.isOpened():
        ret, frame = cap.read()
        if ret:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
            frames.append(frame)
        else:
            break
    cap.release()
    return frames

def run(data):
    filename, opt, device = data
    os.environ['CUDA_VISIBLE_DEVICES'] = device
    kp_extractor = KeypointExtractor()
    images = read_video(filename)
    name = filename.split('/')[-2:]
    os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
    kp_extractor.extract_keypoint(
        images, 
        name=os.path.join(opt.output_dir, name[-2], name[-1])
    )

if __name__ == '__main__':
    set_start_method('spawn')
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input_dir', type=str, help='the folder of the input files')
    parser.add_argument('--output_dir', type=str, help='the folder of the output files')
    parser.add_argument('--device_ids', type=str, default='0,1')
    parser.add_argument('--workers', type=int, default=4)

    opt = parser.parse_args()
    filenames = list()
    VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
    VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
    extensions = VIDEO_EXTENSIONS
    for ext in extensions:
        filenames = sorted(glob.glob(f'{opt.input_dir}/**/*.{ext}'))
    print('Total number of videos:', len(filenames))
    pool = Pool(opt.workers)
    args_list = cycle([opt])
    device_ids = opt.device_ids.split(",")
    device_ids = cycle(device_ids)
    for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
        None


================================================
FILE: scripts/face_recon_images.py
================================================
import os
import glob
import numpy as np
from PIL import Image
from tqdm import tqdm
from scipy.io import savemat

import torch 

from models import create_model
from options.inference_options import InferenceOptions
from util.preprocess import align_img
from util.load_mats import load_lm3d
from util.util import tensor2im, save_image


def get_data_path(root, keypoint_root):
    filenames = list()
    keypoint_filenames = list()

    IMAGE_EXTENSIONS_LOWERCASE = {'jpg', 'png', 'jpeg', 'webp'}
    IMAGE_EXTENSIONS = IMAGE_EXTENSIONS_LOWERCASE.union({f.upper() for f in IMAGE_EXTENSIONS_LOWERCASE})
    extensions = IMAGE_EXTENSIONS

    for ext in extensions:
        filenames += glob.glob(f'{root}/*.{ext}', recursive=True)
    filenames = sorted(filenames)
    for filename in filenames:
        name = os.path.splitext(os.path.basename(filename))[0]
        keypoint_filenames.append(
            os.path.join(keypoint_root, name + '.txt')
        )
    return filenames, keypoint_filenames


class ImagePathDataset(torch.utils.data.Dataset):
    def __init__(self, filenames, txt_filenames, bfm_folder):
        self.filenames = filenames
        self.txt_filenames = txt_filenames
        self.lm3d_std = load_lm3d(bfm_folder) 

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, i):
        filename = self.filenames[i]
        txt_filename = self.txt_filenames[i]
        imgs, _, trans_params = self.read_data(filename, txt_filename)
        return {
            'imgs':imgs,
            'trans_param':trans_params,
            'filename': filename
        }

    def image_transform(self, images, lm):
        W,H = images.size
        if np.mean(lm) == -1:
            lm = (self.lm3d_std[:, :2]+1)/2.
            lm = np.concatenate(
                [lm[:, :1]*W, lm[:, 1:2]*H], 1
            )
        else:
            lm[:, -1] = H - 1 - lm[:, -1]

        trans_params, img, lm, _ = align_img(images, lm, self.lm3d_std)        
        img = torch.tensor(np.array(img)/255., dtype=torch.float32).permute(2, 0, 1)
        lm = torch.tensor(lm)
        trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)])
        trans_params = torch.tensor(trans_params.astype(np.float32))
        return img, lm, trans_params        

    def read_data(self, filename, txt_filename):
        images = Image.open(filename).convert('RGB')
        lm = np.loadtxt(txt_filename).astype(np.float32)
        lm = lm.reshape([-1, 2]) 
        imgs, lms, trans_params = self.image_transform(images, lm)
        return imgs, lms, trans_params


def main(opt, model):
    import torch.multiprocessing
    torch.multiprocessing.set_sharing_strategy('file_system')
    filenames, keypoint_filenames = get_data_path(opt.input_dir, opt.keypoint_dir)
        
    dataset = ImagePathDataset(filenames, keypoint_filenames, opt.bfm_folder)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.inference_batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=8,
    ) 
    pred_coeffs, pred_trans_params = [], []
    print('nums of images:', dataset.__len__())
    for iteration, data in tqdm(enumerate(dataloader)):
        data_input = {                
                'imgs': data['imgs'],
                }
        
        model.set_input(data_input)  
        model.test()
        pred_coeff = {key:model.pred_coeffs_dict[key].cpu().numpy() for key in model.pred_coeffs_dict}
        pred_coeff = np.concatenate([
            pred_coeff['id'], 
            pred_coeff['exp'], 
            pred_coeff['tex'], 
            pred_coeff['angle'],
            pred_coeff['gamma'],
            pred_coeff['trans']], 1)
        pred_coeffs.append(pred_coeff) 
        trans_param = data['trans_param'].cpu().numpy()
        pred_trans_params.append(trans_param)
        if opt.save_split_files:
            for index, filename in enumerate(data['filename']):
                basename = os.path.splitext(os.path.basename(filename))[0]
                output_path = os.path.join(opt.output_dir, basename+'.mat')
                savemat(
                    output_path, 
                    {'coeff':pred_coeff[index], 
                    'transform_params':trans_param[index]}
                )
        # visuals = model.get_current_visuals()  # get image results
        # for name in visuals:
        #     images = visuals[name]
        #     for i in range(images.shape[0]):
        #         image_numpy = tensor2im(images[i])
        #         save_image(image_numpy, os.path.basename(data['filename'][i])+'.png')                

    pred_coeffs = np.concatenate(pred_coeffs, 0)
    pred_trans_params = np.concatenate(pred_trans_params, 0)
    savemat(os.path.join(opt.output_dir, 'ffhq.mat'), {'coeff':pred_coeffs, 'transform_params':pred_trans_params})


if __name__ == '__main__':
    opt = InferenceOptions().parse()  # get test options
    model = create_model(opt)
    model.setup(opt)
    model.device = 'cuda:0'
    model.parallelize()
    model.eval()
    lm3d_std = load_lm3d(opt.bfm_folder) 
    main(opt, model)




================================================
FILE: scripts/face_recon_videos.py
================================================
import os
import cv2
import glob
import numpy as np
from PIL import Image
from tqdm import tqdm
from scipy.io import savemat

import torch 

from models import create_model
from options.inference_options import InferenceOptions
from util.preprocess import align_img
from util.load_mats import load_lm3d
from util.util import mkdirs, tensor2im, save_image


def get_data_path(root, keypoint_root):
    filenames = list()
    keypoint_filenames = list()

    VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
    VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
    extensions = VIDEO_EXTENSIONS

    for ext in extensions:
        filenames += glob.glob(f'{root}/**/*.{ext}', recursive=True)
    filenames = sorted(filenames)
    keypoint_filenames = sorted(glob.glob(f'{keypoint_root}/**/*.txt', recursive=True))
    assert len(filenames) == len(keypoint_filenames)

    return filenames, keypoint_filenames

class VideoPathDataset(torch.utils.data.Dataset):
    def __init__(self, filenames, txt_filenames, bfm_folder):
        self.filenames = filenames
        self.txt_filenames = txt_filenames
        self.lm3d_std = load_lm3d(bfm_folder) 

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        filename = self.filenames[index]
        txt_filename = self.txt_filenames[index]
        frames = self.read_video(filename)
        lm = np.loadtxt(txt_filename).astype(np.float32)
        lm = lm.reshape([len(frames), -1, 2]) 
        out_images, out_trans_params = list(), list()
        for i in range(len(frames)):
            out_img, _, out_trans_param \
                = self.image_transform(frames[i], lm[i])
            out_images.append(out_img[None])
            out_trans_params.append(out_trans_param[None])
        return {
            'imgs': torch.cat(out_images, 0),
            'trans_param':torch.cat(out_trans_params, 0),
            'filename': filename
        }
        
    def read_video(self, filename):
        frames = list()
        cap = cv2.VideoCapture(filename)
        while cap.isOpened():
            ret, frame = cap.read()
            if ret:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = Image.fromarray(frame)
                frames.append(frame)
            else:
                break
        cap.release()
        return frames

    def image_transform(self, images, lm):
        W,H = images.size
        if np.mean(lm) == -1:
            lm = (self.lm3d_std[:, :2]+1)/2.
            lm = np.concatenate(
                [lm[:, :1]*W, lm[:, 1:2]*H], 1
            )
        else:
            lm[:, -1] = H - 1 - lm[:, -1]

        trans_params, img, lm, _ = align_img(images, lm, self.lm3d_std)        
        img = torch.tensor(np.array(img)/255., dtype=torch.float32).permute(2, 0, 1)
        lm = torch.tensor(lm)
        trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)])
        trans_params = torch.tensor(trans_params.astype(np.float32))
        return img, lm, trans_params        

def main(opt, model):
    import torch.multiprocessing
    torch.multiprocessing.set_sharing_strategy('file_system')
    filenames, keypoint_filenames = get_data_path(opt.input_dir, opt.keypoint_dir)
    dataset = VideoPathDataset(filenames, keypoint_filenames, opt.bfm_folder)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1, # can noly set to one here!
        shuffle=False,
        drop_last=False,
        num_workers=8,
    )     
    batch_size = opt.inference_batch_size
    for data in tqdm(dataloader):
        num_batch = data['imgs'][0].shape[0] // batch_size + 1
        pred_coeffs = list()
        for index in range(num_batch):
            data_input = {                
                'imgs': data['imgs'][0,index*batch_size:(index+1)*batch_size],
            }
            model.set_input(data_input)  
            model.test()
            pred_coeff = {key:model.pred_coeffs_dict[key].cpu().numpy() for key in model.pred_coeffs_dict}
            pred_coeff = np.concatenate([
                pred_coeff['id'], 
                pred_coeff['exp'], 
                pred_coeff['tex'], 
                pred_coeff['angle'],
                pred_coeff['gamma'],
                pred_coeff['trans']], 1)
            pred_coeffs.append(pred_coeff) 
            visuals = model.get_current_visuals()  # get image results
            if False: # debug
                for name in visuals:
                    images = visuals[name]
                    for i in range(images.shape[0]):
                        image_numpy = tensor2im(images[i])
                        save_image(
                            image_numpy, 
                            os.path.join(
                                opt.output_dir,
                                os.path.basename(data['filename'][0])+str(i).zfill(5)+'.jpg')
                            )
                exit()

        pred_coeffs = np.concatenate(pred_coeffs, 0)
        pred_trans_params = data['trans_param'][0].cpu().numpy()
        name = data['filename'][0].split('/')[-2:]
        name[-1] = os.path.splitext(name[-1])[0] + '.mat'
        os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
        savemat(
            os.path.join(opt.output_dir, name[-2], name[-1]), 
            {'coeff':pred_coeffs, 'transform_params':pred_trans_params}
        )

if __name__ == '__main__':
    opt = InferenceOptions().parse()  # get test options
    model = create_model(opt)
    model.setup(opt)
    model.device = 'cuda:0'
    model.parallelize()
    model.eval()

    main(opt, model)




================================================
FILE: scripts/inference_options.py
================================================
from .base_options import BaseOptions


class InferenceOptions(BaseOptions):
    """This class includes test options.

    It also includes shared options defined in BaseOptions.
    """

    def initialize(self, parser):
        parser = BaseOptions.initialize(self, parser)  # define shared options
        parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
        parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]')

        parser.add_argument('--input_dir', type=str, help='the folder of the input files')
        parser.add_argument('--keypoint_dir', type=str, help='the folder of the keypoint files')
        parser.add_argument('--output_dir', type=str, default='mp4', help='the output dir to save the extracted coefficients')
        parser.add_argument('--save_split_files', action='store_true', help='save split files or not')
        parser.add_argument('--inference_batch_size', type=int, default=8)
        
        # Dropout and Batchnorm has different behavior during training and test.
        self.isTrain = False
        return parser


================================================
FILE: scripts/prepare_vox_lmdb.py
================================================
import os
import cv2
import lmdb
import argparse
import multiprocessing
import numpy as np

from glob import glob
from io import BytesIO
from tqdm import tqdm
from PIL import Image
from scipy.io import loadmat
from torchvision.transforms import functional as trans_fn

def format_for_lmdb(*args):
    key_parts = []
    for arg in args:
        if isinstance(arg, int):
            arg = str(arg).zfill(7)
        key_parts.append(arg)
    return '-'.join(key_parts).encode('utf-8')

class Resizer:
    def __init__(self, size, kp_root, coeff_3dmm_root, img_format):
        self.size = size
        self.kp_root = kp_root
        self.coeff_3dmm_root = coeff_3dmm_root
        self.img_format = img_format

    def get_resized_bytes(self, img, img_format='jpeg'):
        img = trans_fn.resize(img, (self.size, self.size), interpolation=Image.BICUBIC)
        buf = BytesIO()
        img.save(buf, format=img_format)
        img_bytes = buf.getvalue()
        return img_bytes

    def prepare(self, filename):
        frames = {'img':[], 'kp':None, 'coeff_3dmm':None}
        cap = cv2.VideoCapture(filename)
        while cap.isOpened():
            ret, frame = cap.read()
            if ret:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                img_pil = Image.fromarray(frame)
                img_bytes = self.get_resized_bytes(img_pil, self.img_format)
                frames['img'].append(img_bytes)
            else:
                break
        cap.release()
        video_name = os.path.splitext(os.path.basename(filename))[0]
        keypoint_byte = get_others(self.kp_root, video_name, 'keypoint')
        coeff_3dmm_byte = get_others(self.coeff_3dmm_root, video_name, 'coeff_3dmm')
        frames['kp'] = keypoint_byte
        frames['coeff_3dmm'] = coeff_3dmm_byte
        return frames

    def __call__(self, index_filename):
        index, filename = index_filename
        result = self.prepare(filename)
        return index, result, filename

def get_others(root, video_name, data_type):
    if root is None:
        return
    else:
        assert data_type in ('keypoint', 'coeff_3dmm')
    if os.path.isfile(os.path.join(root, 'train', video_name+'.mat')):
        file_path = os.path.join(root, 'train', video_name+'.mat')
    else:
        file_path = os.path.join(root, 'test', video_name+'.mat')
    
    if data_type == 'keypoint':
        return_byte = convert_kp(file_path)
    else:
        return_byte = convert_3dmm(file_path)
    return return_byte

def convert_kp(file_path):
    file_mat = loadmat(file_path)
    kp_byte = file_mat['landmark'].tobytes()
    return kp_byte

def convert_3dmm(file_path):
    file_mat = loadmat(file_path)
    coeff_3dmm = file_mat['coeff']
    crop_param = file_mat['transform_params']
    _, _, ratio, t0, t1 = np.hsplit(crop_param.astype(np.float32), 5)
    crop_param = np.concatenate([ratio, t0, t1], 1)
    coeff_3dmm_cat = np.concatenate([coeff_3dmm, crop_param], 1) 
    coeff_3dmm_byte = coeff_3dmm_cat.tobytes()
    return coeff_3dmm_byte


def prepare_data(path, keypoint_path, coeff_3dmm_path, out, n_worker, sizes, chunksize, img_format):
    filenames = list()
    VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
    VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
    extensions = VIDEO_EXTENSIONS
    for ext in extensions:
        filenames += glob(f'{path}/**/*.{ext}', recursive=True)
    train_video, test_video = [], []
    for item in filenames:
        if "/train/" in item:
            train_video.append(item)
        else:
            test_video.append(item)
    print(len(train_video), len(test_video))
    with open(os.path.join(out, 'train_list.txt'),'w') as f:
        for item in train_video:
            item = os.path.splitext(os.path.basename(item))[0]
            f.write(item + '\n')

    with open(os.path.join(out, 'test_list.txt'),'w') as f:
        for item in test_video:
            item = os.path.splitext(os.path.basename(item))[0]
            f.write(item + '\n')      


    filenames = sorted(filenames)
    total = len(filenames)
    os.makedirs(out, exist_ok=True)
    for size in sizes:
        lmdb_path = os.path.join(out, str(size))
        with lmdb.open(lmdb_path, map_size=1024 ** 4, readahead=False) as env:
            with env.begin(write=True) as txn:
                txn.put(format_for_lmdb('length'), format_for_lmdb(total))
                resizer = Resizer(size, keypoint_path, coeff_3dmm_path, img_format)
                with multiprocessing.Pool(n_worker) as pool:
                    for idx, result, filename in tqdm(
                            pool.imap_unordered(resizer, enumerate(filenames), chunksize=chunksize),
                            total=total):
                        filename = os.path.basename(filename)
                        video_name = os.path.splitext(filename)[0]
                        txn.put(format_for_lmdb(video_name, 'length'), format_for_lmdb(len(result['img'])))

                        for frame_idx, frame in enumerate(result['img']):
                            txn.put(format_for_lmdb(video_name, frame_idx), frame)

                        if result['kp']:
                            txn.put(format_for_lmdb(video_name, 'keypoint'), result['kp'])
                        if result['coeff_3dmm']:
                            txn.put(format_for_lmdb(video_name, 'coeff_3dmm'), result['coeff_3dmm'])


if __name__ == '__main__':
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--path', type=str, help='a path to input directiory')
    parser.add_argument('--keypoint_path', type=str, help='a path to output directory', default=None)
    parser.add_argument('--coeff_3dmm_path', type=str, help='a path to output directory', default=None)
    parser.add_argument('--out', type=str, help='a path to output directory')
    parser.add_argument('--sizes', type=int, nargs='+', default=(256,))
    parser.add_argument('--n_worker', type=int, help='number of worker processes', default=8)
    parser.add_argument('--chunksize', type=int, help='approximate chunksize for each worker', default=10)
    parser.add_argument('--img_format', type=str, default='jpeg')
    args = parser.parse_args()
    prepare_data(**vars(args))

================================================
FILE: third_part/PerceptualSimilarity/models/__init__.py
================================================


================================================
FILE: third_part/PerceptualSimilarity/models/base_model.py
================================================
import os
import torch
from torch.autograd import Variable
from pdb import set_trace as st
from IPython import embed

class BaseModel():
    def __init__(self):
        pass;
        
    def name(self):
        return 'BaseModel'

    def initialize(self, use_gpu=True):
        self.use_gpu = use_gpu
        self.Tensor = torch.cuda.FloatTensor if self.use_gpu else torch.Tensor
        # self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

    def forward(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, path, network_label, epoch_label):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        # embed()
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        print('Loading network from %s'%save_path)
        network.load_state_dict(torch.load(save_path))

    def update_learning_rate():
        pass

    def get_image_paths(self):
        return self.image_paths

    def save_done(self, flag=False):
        np.save(os.path.join(self.save_dir, 'done_flag'),flag)
        np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')



================================================
FILE: third_part/PerceptualSimilarity/models/dist_model.py
================================================

from __future__ import absolute_import

import sys
sys.path.append('..')
sys.path.append('.')
import numpy as np
import torch
from torch import nn
import os
from collections import OrderedDict
from torch.autograd import Variable
import itertools
from .base_model import BaseModel
from scipy.ndimage import zoom
import fractions
import functools
import skimage.transform
from IPython import embed

from . import networks_basic as networks
from third_part.PerceptualSimilarity.util import util
# from util import util

class DistModel(BaseModel):
    def name(self):
        return self.model_name

    def initialize(self, model='net-lin', net='alex', pnet_rand=False, pnet_tune=False, model_path=None, colorspace='Lab', use_gpu=True, printNet=False, spatial=False, spatial_shape=None, spatial_order=1, spatial_factor=None, is_train=False, lr=.0001, beta1=0.5, version='0.1'):
        '''
        INPUTS
            model - ['net-lin'] for linearly calibrated network
                    ['net'] for off-the-shelf network
                    ['L2'] for L2 distance in Lab colorspace
                    ['SSIM'] for ssim in RGB colorspace
            net - ['squeeze','alex','vgg']
            model_path - if None, will look in weights/[NET_NAME].pth
            colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
            use_gpu - bool - whether or not to use a GPU
            printNet - bool - whether or not to print network architecture out
            spatial - bool - whether to output an array containing varying distances across spatial dimensions
            spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
            spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
            spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
            is_train - bool - [True] for training mode
            lr - float - initial learning rate
            beta1 - float - initial momentum term for adam
            version - 0.1 for latest, 0.0 was original
        '''
        BaseModel.initialize(self, use_gpu=use_gpu)

        self.model = model
        self.net = net
        self.use_gpu = use_gpu
        self.is_train = is_train
        self.spatial = spatial
        self.spatial_shape = spatial_shape
        self.spatial_order = spatial_order
        self.spatial_factor = spatial_factor

        self.model_name = '%s [%s]'%(model,net)
        if(self.model == 'net-lin'): # pretrained net + linear layer
            self.net = networks.PNetLin(use_gpu=use_gpu,pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,use_dropout=True,spatial=spatial,version=version)
            kw = {}
            if not use_gpu:
                kw['map_location'] = 'cpu'
            if(model_path is None):
                import inspect
                # model_path = './PerceptualSimilarity/weights/v%s/%s.pth'%(version,net)
                model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', '..', 'weights/v%s/%s.pth'%(version,net)))

            if(not is_train):
                print('Loading model from: %s'%model_path)
                self.net.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))

        elif(self.model=='net'): # pretrained network
            assert not self.spatial, 'spatial argument not supported yet for uncalibrated networks'
            self.net = networks.PNet(use_gpu=use_gpu,pnet_type=net)
            self.is_fake_net = True
        elif(self.model in ['L2','l2']):
            self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing
            self.model_name = 'L2'
        elif(self.model in ['DSSIM','dssim','SSIM','ssim']):
            self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace)
            self.model_name = 'SSIM'
        else:
            raise ValueError("Model [%s] not recognized." % self.model)

        self.parameters = list(self.net.parameters())

        if self.is_train: # training mode
            # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
            self.rankLoss = networks.BCERankingLoss(use_gpu=use_gpu)
            self.parameters+=self.rankLoss.parameters
            self.lr = lr
            self.old_lr = lr
            self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
        else: # test mode
            self.net.eval()

        if(printNet):
            print('---------- Networks initialized -------------')
            networks.print_network(self.net)
            print('-----------------------------------------------')

    def forward_pair(self,in1,in2,retPerLayer=False):
        if(retPerLayer):
            return self.net.forward(in1,in2, retPerLayer=True)
        else:
            return self.net.forward(in1,in2)

    def forward(self, in0, in1, retNumpy=True):
        ''' Function computes the distance between image patches in0 and in1
        INPUTS
            in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
            retNumpy - [False] to return as torch.Tensor, [True] to return as numpy array
        OUTPUT
            computed distances between in0 and in1
        '''

        self.input_ref = in0
        self.input_p0 = in1

        if(self.use_gpu):
            self.input_ref = self.input_ref.cuda()
            self.input_p0 = self.input_p0.cuda()

        self.var_ref = Variable(self.input_ref,requires_grad=True)
        self.var_p0 = Variable(self.input_p0,requires_grad=True)

        self.d0 = self.forward_pair(self.var_ref, self.var_p0)
        self.loss_total = self.d0

        def convert_output(d0):
            if(retNumpy):
                ans = d0.cpu().data.numpy()
                if not self.spatial:
                    ans = ans.flatten()
                else:
                    assert(ans.shape[0] == 1 and len(ans.shape) == 4)
                    return ans[0,...].transpose([1, 2, 0])                  # Reshape to usual numpy image format: (height, width, channels)
                return ans
            else:
                return d0

        if self.spatial:
            L = [convert_output(x) for x in self.d0]
            spatial_shape = self.spatial_shape
            if spatial_shape is None:
                if(self.spatial_factor is None):
                    spatial_shape = (in0.size()[2],in0.size()[3])
                else:
                    spatial_shape = (max([x.shape[0] for x in L])*self.spatial_factor, max([x.shape[1] for x in L])*self.spatial_factor)
            
            L = [skimage.transform.resize(x, spatial_shape, order=self.spatial_order, mode='edge') for x in L]
            
            L = np.mean(np.concatenate(L, 2) * len(L), 2)
            return L
        else:
            return convert_output(self.d0)

    # ***** TRAINING FUNCTIONS *****
    def optimize_parameters(self):
        self.forward_train()
        self.optimizer_net.zero_grad()
        self.backward_train()
        self.optimizer_net.step()
        self.clamp_weights()

    def clamp_weights(self):
        for module in self.net.modules():
            if(hasattr(module, 'weight') and module.kernel_size==(1,1)):
                module.weight.data = torch.clamp(module.weight.data,min=0)

    def set_input(self, data):
        self.input_ref = data['ref']
        self.input_p0 = data['p0']
        self.input_p1 = data['p1']
        self.input_judge = data['judge']

        if(self.use_gpu):
            self.input_ref = self.input_ref.cuda()
            self.input_p0 = self.input_p0.cuda()
            self.input_p1 = self.input_p1.cuda()
            self.input_judge = self.input_judge.cuda()

        self.var_ref = Variable(self.input_ref,requires_grad=True)
        self.var_p0 = Variable(self.input_p0,requires_grad=True)
        self.var_p1 = Variable(self.input_p1,requires_grad=True)

    def forward_train(self): # run forward pass
        self.d0 = self.forward_pair(self.var_ref, self.var_p0)
        self.d1 = self.forward_pair(self.var_ref, self.var_p1)
        self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)

        # var_judge
        self.var_judge = Variable(1.*self.input_judge).view(self.d0.size())

        self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.)
        return self.loss_total

    def backward_train(self):
        torch.mean(self.loss_total).backward()

    def compute_accuracy(self,d0,d1,judge):
        ''' d0, d1 are Variables, judge is a Tensor '''
        d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten()
        judge_per = judge.cpu().numpy().flatten()
        return d1_lt_d0*judge_per + (1-d1_lt_d0)*(1-judge_per)

    def get_current_errors(self):
        retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
                            ('acc_r', self.acc_r)])

        for key in retDict.keys():
            retDict[key] = np.mean(retDict[key])

        return retDict

    def get_current_visuals(self):
        zoom_factor = 256/self.var_ref.data.size()[2]

        ref_img = util.tensor2im(self.var_ref.data)
        p0_img = util.tensor2im(self.var_p0.data)
        p1_img = util.tensor2im(self.var_p1.data)

        ref_img_vis = zoom(ref_img,[zoom_factor, zoom_factor, 1],order=0)
        p0_img_vis = zoom(p0_img,[zoom_factor, zoom_factor, 1],order=0)
        p1_img_vis = zoom(p1_img,[zoom_factor, zoom_factor, 1],order=0)

        return OrderedDict([('ref', ref_img_vis),
                            ('p0', p0_img_vis),
                            ('p1', p1_img_vis)])

    def save(self, path, label):
        self.save_network(self.net, path, '', label)
        self.save_network(self.rankLoss.net, path, 'rank', label)

    def update_learning_rate(self,nepoch_decay):
        lrd = self.lr / nepoch_decay
        lr = self.old_lr - lrd

        for param_group in self.optimizer_net.param_groups:
            param_group['lr'] = lr

        print('update lr [%s] decay: %f -> %f' % (type,self.old_lr, lr))
        self.old_lr = lr



def score_2afc_dataset(data_loader,func):
    ''' Function computes Two Alternative Forced Choice (2AFC) score using
        distance function 'func' in dataset 'data_loader'
    INPUTS
        data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
        func - callable distance function - calling d=func(in0,in1) should take 2
            pytorch tensors with shape Nx3xXxY, and return numpy array of length N
    OUTPUTS
        [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
        [1] - dictionary with following elements
            d0s,d1s - N arrays containing distances between reference patch to perturbed patches 
            gts - N array in [0,1], preferred patch selected by human evaluators
                (closer to "0" for left patch p0, "1" for right patch p1,
                "0.6" means 60pct people preferred right patch, 40pct preferred left)
            scores - N array in [0,1], corresponding to what percentage function agreed with humans
    CONSTS
        N - number of test triplets in data_loader
    '''

    d0s = []
    d1s = []
    gts = []

    # bar = pb.ProgressBar(max_value=data_loader.load_data().__len__())
    for (i,data) in enumerate(data_loader.load_data()):
        d0s+=func(data['ref'],data['p0']).tolist()
        d1s+=func(data['ref'],data['p1']).tolist()
        gts+=data['judge'].cpu().numpy().flatten().tolist()
        # bar.update(i)

    d0s = np.array(d0s)
    d1s = np.array(d1s)
    gts = np.array(gts)
    scores = (d0s<d1s)*(1.-gts) + (d1s<d0s)*gts + (d1s==d0s)*.5

    return(np.mean(scores), dict(d0s=d0s,d1s=d1s,gts=gts,scores=scores))

def score_jnd_dataset(data_loader,func):
    ''' Function computes JND score using distance function 'func' in dataset 'data_loader'
    INPUTS
        data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
        func - callable distance function - calling d=func(in0,in1) should take 2
            pytorch tensors with shape Nx3xXxY, and return numpy array of length N
    OUTPUTS
        [0] - JND score in [0,1], mAP score (area under precision-recall curve)
        [1] - dictionary with following elements
            ds - N array containing distances between two patches shown to human evaluator
            sames - N array containing fraction of people who thought the two patches were identical
    CONSTS
        N - number of test triplets in data_loader
    '''

    ds = []
    gts = []

    # bar = pb.ProgressBar(max_value=data_loader.load_data().__len__())
    for (i,data) in enumerate(data_loader.load_data()):
        ds+=func(data['p0'],data['p1']).tolist()
        gts+=data['same'].cpu().numpy().flatten().tolist()
        # bar.update(i)

    sames = np.array(gts)
    ds = np.array(ds)

    sorted_inds = np.argsort(ds)
    ds_sorted = ds[sorted_inds]
    sames_sorted = sames[sorted_inds]

    TPs = np.cumsum(sames_sorted)
    FPs = np.cumsum(1-sames_sorted)
    FNs = np.sum(sames_sorted)-TPs

    precs = TPs/(TPs+FPs)
    recs = TPs/(TPs+FNs)
    score = util.voc_ap(recs,precs)

    return(score, dict(ds=ds,sames=sames))


================================================
FILE: third_part/PerceptualSimilarity/models/models.py
================================================
from __future__ import absolute_import

def create_model(opt):
    model = None
    print(opt.model)
    from .siam_model import *
    model = DistModel()
    model.initialize(opt, opt.batchSize, )
    print("model [%s] was created" % (model.name()))
    return model



================================================
FILE: third_part/PerceptualSimilarity/models/networks_basic.py
================================================

from __future__ import absolute_import

import sys
sys.path.append('..')
sys.path.append('.')
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
import numpy as np
from pdb import set_trace as st
from skimage import color
from IPython import embed
from . import pretrained_networks as pn

# from .PerceptualSimilarity.util import util
from ..util import util

# Off-the-shelf deep network
class PNet(nn.Module):
    '''Pre-trained network with all channels equally weighted by default'''
    def __init__(self, pnet_type='vgg', pnet_rand=False, use_gpu=True):
        super(PNet, self).__init__()

        self.use_gpu = use_gpu

        self.pnet_type = pnet_type
        self.pnet_rand = pnet_rand

        self.shift = torch.autograd.Variable(torch.Tensor([-.030, -.088, -.188]).view(1,3,1,1))
        self.scale = torch.autograd.Variable(torch.Tensor([.458, .448, .450]).view(1,3,1,1))
        
        if(self.pnet_type in ['vgg','vgg16']):
            self.net = pn.vgg16(pretrained=not self.pnet_rand,requires_grad=False)
        elif(self.pnet_type=='alex'):
            self.net = pn.alexnet(pretrained=not self.pnet_rand,requires_grad=False)
        elif(self.pnet_type[:-2]=='resnet'):
            self.net = pn.resnet(pretrained=not self.pnet_rand,requires_grad=False, num=int(self.pnet_type[-2:]))
        elif(self.pnet_type=='squeeze'):
            self.net = pn.squeezenet(pretrained=not self.pnet_rand,requires_grad=False)

        self.L = self.net.N_slices

        if(use_gpu):
            self.net.cuda()
            self.shift = self.shift.cuda()
            self.scale = self.scale.cuda()

    def forward(self, in0, in1, retPerLayer=False):
        in0_sc = (in0 - self.shift.expand_as(in0))/self.scale.expand_as(in0)
        in1_sc = (in1 - self.shift.expand_as(in0))/self.scale.expand_as(in0)

        outs0 = self.net.forward(in0_sc)
        outs1 = self.net.forward(in1_sc)

        if(retPerLayer):
            all_scores = []
        for (kk,out0) in enumerate(outs0):
            cur_score = (1.-util.cos_sim(outs0[kk],outs1[kk]))
            if(kk==0):
                val = 1.*cur_score
            else:
                # val = val + self.lambda_feat_layers[kk]*cur_score
                val = val + cur_score
            if(retPerLayer):
                all_scores+=[cur_score]

        if(retPerLayer):
            return (val, all_scores)
        else:
            return val

# Learned perceptual metric
class PNetLin(nn.Module):
    def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, use_gpu=True, spatial=False, version='0.1'):
        super(PNetLin, self).__init__()

        self.use_gpu = use_gpu
        self.pnet_type = pnet_type
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.version = version

        if(self.pnet_type in ['vgg','vgg16']):
            net_type = pn.vgg16
            self.chns = [64,128,256,512,512]
        elif(self.pnet_type=='alex'):
            net_type = pn.alexnet
            self.chns = [64,192,384,256,256]
        elif(self.pnet_type=='squeeze'):
            net_type = pn.squeezenet
            self.chns = [64,128,256,384,384,512,512]

        if(self.pnet_tune):
            self.net = net_type(pretrained=not self.pnet_rand,requires_grad=True)
        else:
            self.net = [net_type(pretrained=not self.pnet_rand,requires_grad=False),]

        self.lin0 = NetLinLayer(self.chns[0],use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1],use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2],use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3],use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4],use_dropout=use_dropout)
        self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
        if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
            self.lin5 = NetLinLayer(self.chns[5],use_dropout=use_dropout)
            self.lin6 = NetLinLayer(self.chns[6],use_dropout=use_dropout)
            self.lins+=[self.lin5,self.lin6]

        self.shift = torch.autograd.Variable(torch.Tensor([-.030, -.088, -.188]).view(1,3,1,1))
        self.scale = torch.autograd.Variable(torch.Tensor([.458, .448, .450]).view(1,3,1,1))

        if(use_gpu):
            if(self.pnet_tune):
                self.net.cuda()
            else:
                self.net[0].cuda()
            self.shift = self.shift.cuda()
            self.scale = self.scale.cuda()
            self.lin0.cuda()
            self.lin1.cuda()
            self.lin2.cuda()
            self.lin3.cuda()
            self.lin4.cuda()
            if(self.pnet_type=='squeeze'):
                self.lin5.cuda()
                self.lin6.cuda()

    def forward(self, in0, in1):
        in0_sc = (in0 - self.shift.expand_as(in0))/self.scale.expand_as(in0)
        in1_sc = (in1 - self.shift.expand_as(in0))/self.scale.expand_as(in0)

        if(self.version=='0.0'):
            # v0.0 - original release had a bug, where input was not scaled
            in0_input = in0
            in1_input = in1
        else:
            # v0.1
            in0_input = in0_sc
            in1_input = in1_sc

        if(self.pnet_tune):
            outs0 = self.net.forward(in0_input)
            outs1 = self.net.forward(in1_input)
        else:
            outs0 = self.net[0].forward(in0_input)
            outs1 = self.net[0].forward(in1_input)

        feats0 = {}
        feats1 = {}
        diffs = [0]*len(outs0)

        for (kk,out0) in enumerate(outs0):
            feats0[kk] = util.normalize_tensor(outs0[kk])
            feats1[kk] = util.normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk]-feats1[kk])**2

        if self.spatial:
            lin_models = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
            if(self.pnet_type=='squeeze'):
                lin_models.extend([self.lin5, self.lin6])
            res = [lin_models[kk].model(diffs[kk]) for kk in range(len(diffs))]
            return res
			
        val = torch.mean(torch.mean(self.lin0.model(diffs[0]),dim=3),dim=2)
        val = val + torch.mean(torch.mean(self.lin1.model(diffs[1]),dim=3),dim=2)
        val = val + torch.mean(torch.mean(self.lin2.model(diffs[2]),dim=3),dim=2)
        val = val + torch.mean(torch.mean(self.lin3.model(diffs[3]),dim=3),dim=2)
        val = val + torch.mean(torch.mean(self.lin4.model(diffs[4]),dim=3),dim=2)
        if(self.pnet_type=='squeeze'):
            val = val + torch.mean(torch.mean(self.lin5.model(diffs[5]),dim=3),dim=2)
            val = val + torch.mean(torch.mean(self.lin6.model(diffs[6]),dim=3),dim=2)

        val = val.view(val.size()[0],val.size()[1],1,1)

        return val

class Dist2LogitLayer(nn.Module):
    ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
    def __init__(self, chn_mid=32,use_sigmoid=True):
        super(Dist2LogitLayer, self).__init__()
        layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
        layers += [nn.LeakyReLU(0.2,True),]
        layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
        layers += [nn.LeakyReLU(0.2,True),]
        layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
        if(use_sigmoid):
            layers += [nn.Sigmoid(),]
        self.model = nn.Sequential(*layers)

    def forward(self,d0,d1,eps=0.1):
        return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))

class BCERankingLoss(nn.Module):
    def __init__(self, use_gpu=True, chn_mid=32):
        super(BCERankingLoss, self).__init__()
        self.use_gpu = use_gpu
        self.net = Dist2LogitLayer(chn_mid=chn_mid)
        self.parameters = list(self.net.parameters())
        self.loss = torch.nn.BCELoss()
        self.model = nn.Sequential(*[self.net])

        if(self.use_gpu):
            self.net.cuda()

    def forward(self, d0, d1, judge):
        per = (judge+1.)/2.
        if(self.use_gpu):
            per = per.cuda()
        self.logit = self.net.forward(d0,d1)
        return self.loss(self.logit, per)

class NetLinLayer(nn.Module):
    ''' A single linear layer which does a 1x1 conv '''
    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()

        layers = [nn.Dropout(),] if(use_dropout) else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
        self.model = nn.Sequential(*layers)


# L2, DSSIM metrics
class FakeNet(nn.Module):
    def __init__(self, use_gpu=True, colorspace='Lab'):
        super(FakeNet, self).__init__()
        self.use_gpu = use_gpu
        self.colorspace=colorspace

class L2(FakeNet):

    def forward(self, in0, in1):
        assert(in0.size()[0]==1) # currently only supports batchSize 1

        if(self.colorspace=='RGB'):
            (N,C,X,Y) = in0.size()
            value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
            return value
        elif(self.colorspace=='Lab'):
            value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), 
                util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
            ret_var = Variable( torch.Tensor((value,) ) )
            if(self.use_gpu):
                ret_var = ret_var.cuda()
            return ret_var

class DSSIM(FakeNet):

    def forward(self, in0, in1):
        assert(in0.size()[0]==1) # currently only supports batchSize 1

        if(self.colorspace=='RGB'):
            value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
        elif(self.colorspace=='Lab'):
            value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), 
                util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
        ret_var = Variable( torch.Tensor((value,) ) )
        if(self.use_gpu):
            ret_var = ret_var.cuda()
        return ret_var

def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print('Network',net)
    print('Total number of parameters: %d' % num_params)


================================================
FILE: third_part/PerceptualSimilarity/models/pretrained_networks.py
================================================
from collections import namedtuple
import torch
from torchvision import models
from IPython import embed

class squeezenet(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(squeezenet, self).__init__()
        pretrained_features = models.squeezenet1_1(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.slice6 = torch.nn.Sequential()
        self.slice7 = torch.nn.Sequential()
        self.N_slices = 7
        for x in range(2):
            self.slice1.add_module(str(x), pretrained_features[x])
        for x in range(2,5):
            self.slice2.add_module(str(x), pretrained_features[x])
        for x in range(5, 8):
            self.slice3.add_module(str(x), pretrained_features[x])
        for x in range(8, 10):
            self.slice4.add_module(str(x), pretrained_features[x])
        for x in range(10, 11):
            self.slice5.add_module(str(x), pretrained_features[x])
        for x in range(11, 12):
            self.slice6.add_module(str(x), pretrained_features[x])
        for x in range(12, 13):
            self.slice7.add_module(str(x), pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1 = h
        h = self.slice2(h)
        h_relu2 = h
        h = self.slice3(h)
        h_relu3 = h
        h = self.slice4(h)
        h_relu4 = h
        h = self.slice5(h)
        h_relu5 = h
        h = self.slice6(h)
        h_relu6 = h
        h = self.slice7(h)
        h_relu7 = h
        vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
        out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)

        return out


class alexnet(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(alexnet, self).__init__()
        alexnet_pretrained_features = models.alexnet(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(2):
            self.slice1.add_module(str(x), alexnet_pretrained_features[x])
        for x in range(2, 5):
            self.slice2.add_module(str(x), alexnet_pretrained_features[x])
        for x in range(5, 8):
            self.slice3.add_module(str(x), alexnet_pretrained_features[x])
        for x in range(8, 10):
            self.slice4.add_module(str(x), alexnet_pretrained_features[x])
        for x in range(10, 12):
            self.slice5.add_module(str(x), alexnet_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1 = h
        h = self.slice2(h)
        h_relu2 = h
        h = self.slice3(h)
        h_relu3 = h
        h = self.slice4(h)
        h_relu4 = h
        h = self.slice5(h)
        h_relu5 = h
        alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
        out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)

        return out

class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)

        return out



class resnet(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True, num=18):
        super(resnet, self).__init__()
        if(num==18):
            self.net = models.resnet18(pretrained=pretrained)
        elif(num==34):
            self.net = models.resnet34(pretrained=pretrained)
        elif(num==50):
            self.net = models.resnet50(pretrained=pretrained)
        elif(num==101):
            self.net = models.resnet101(pretrained=pretrained)
        elif(num==152):
            self.net = models.resnet152(pretrained=pretrained)
        self.N_slices = 5

        self.conv1 = self.net.conv1
        self.bn1 = self.net.bn1
        self.relu = self.net.relu
        self.maxpool = self.net.maxpool
        self.layer1 = self.net.layer1
        self.layer2 = self.net.layer2
        self.layer3 = self.net.layer3
        self.layer4 = self.net.layer4

    def forward(self, X):
        h = self.conv1(X)
        h = self.bn1(h)
        h = self.relu(h)
        h_relu1 = h
        h = self.maxpool(h)
        h = self.layer1(h)
        h_conv2 = h
        h = self.layer2(h)
        h_conv3 = h
        h = self.layer3(h)
        h_conv4 = h
        h = self.layer4(h)
        h_conv5 = h

        outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
        out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)

        return out


================================================
FILE: third_part/PerceptualSimilarity/util/__init__.py
================================================


================================================
FILE: third_part/PerceptualSimilarity/util/html.py
================================================
import dominate
from dominate.tags import *
import os


class HTML:
    def __init__(self, web_dir, title, image_subdir='', reflesh=0):
        self.title = title
        self.web_dir = web_dir
        # self.img_dir = os.path.join(self.web_dir, )
        self.img_subdir = image_subdir
        self.img_dir = os.path.join(self.web_dir, image_subdir)
        if not os.path.exists(self.web_dir):
            os.makedirs(self.web_dir)
        if not os.path.exists(self.img_dir):
            os.makedirs(self.img_dir)
        # print(self.img_dir)

        self.doc = dominate.document(title=title)
        if reflesh > 0:
            with self.doc.head:
                meta(http_equiv="reflesh", content=str(reflesh))

    def get_image_dir(self):
        return self.img_dir

    def add_header(self, str):
        with self.doc:
            h3(str)

    def add_table(self, border=1):
        self.t = table(border=border, style="table-layout: fixed;")
        self.doc.add(self.t)

    def add_images(self, ims, txts, links, width=400):
        self.add_table()
        with self.t:
            with tr():
                for im, txt, link in zip(ims, txts, links):
                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
                        with p():
                            with a(href=os.path.join(link)):
                                img(style="width:%dpx" % width, src=os.path.join(im))
                            br()
                            p(txt)

    def save(self,file='index'):
        html_file = '%s/%s.html' % (self.web_dir,file)
        f = open(html_file, 'wt')
        f.write(self.doc.render())
        f.close()


if __name__ == '__main__':
    html = HTML('web/', 'test_html')
    html.add_header('hello world')

    ims = []
    txts = []
    links = []
    for n in range(4):
        ims.append('image_%d.png' % n)
        txts.append('text_%d' % n)
        links.append('image_%d.png' % n)
    html.add_images(ims, txts, links)
    html.save()


================================================
FILE: third_part/PerceptualSimilarity/util/util.py
================================================
from __future__ import print_function

import numpy as np
from PIL import Image
import inspect
import re
import numpy as np
import os
import collections
import matplotlib.pyplot as plt
from scipy.ndimage.interpolation import zoom
from skimage.measure import compare_ssim
import torch
from IPython import embed
import cv2
from datetime import datetime

def datetime_str():
    now = datetime.now()
    return '%04d-%02d-%02d-%02d-%02d-%02d'%(now.year,now.month,now.day,now.hour,now.minute,now.second)

def read_text_file(in_path):
    fid = open(in_path,'r')

    vals = []
    cur_line = fid.readline()
    while(cur_line!=''):
        vals.append(float(cur_line))
        cur_line = fid.readline()

    fid.close()
    return np.array(vals)

def bootstrap(in_vec,num_samples=100,bootfunc=np.mean):
    from astropy import stats
    return stats.bootstrap(np.array(in_vec),bootnum=num_samples,bootfunc=bootfunc)

def rand_flip(input1,input2):
    if(np.random.binomial(1,.5)==1):
        return (input1,input2)
    else:
        return (input2,input1)

def l2(p0, p1, range=255.):
    return .5*np.mean((p0 / range - p1 / range)**2)

def psnr(p0, p1, peak=255.):
    return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))

def dssim(p0, p1, range=255.):
    # embed()
    return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.

def rgb2lab(in_img,mean_cent=False):
    from skimage import color
    img_lab = color.rgb2lab(in_img)
    if(mean_cent):
        img_lab[:,:,0] = img_lab[:,:,0]-50
    return img_lab

def normalize_blob(in_feat,eps=1e-10):
    norm_factor = np.sqrt(np.sum(in_feat**2,axis=1,keepdims=True))
    return in_feat/(norm_factor+eps)

def cos_sim_blob(in0,in1):
    in0_norm = normalize_blob(in0)
    in1_norm = normalize_blob(in1)
    (N,C,X,Y) = in0_norm.shape

    return np.mean(np.mean(np.sum(in0_norm*in1_norm,axis=1),axis=1),axis=1)

def normalize_tensor(in_feat,eps=1e-10):
    # norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1)).view(in_feat.size()[0],1,in_feat.size()[2],in_feat.size()[3]).repeat(1,in_feat.size()[1],1,1)
    norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1)).view(in_feat.size()[0],1,in_feat.size()[2],in_feat.size()[3])
    return in_feat/(norm_factor.expand_as(in_feat)+eps)

def cos_sim(in0,in1):
    in0_norm = normalize_tensor(in0)
    in1_norm = normalize_tensor(in1)
    N = in0.size()[0]
    X = in0.size()[2]
    Y = in0.size()[3]

    return torch.mean(torch.mean(torch.sum(in0_norm*in1_norm,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)

# Converts a Tensor into a Numpy array
# |imtype|: the desired type of the conve

def tensor2np(tensor_obj):
    # change dimension of a tensor object into a numpy array
    return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))

def np2tensor(np_obj):
     # change dimenion of np array into tensor array
    return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))

def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
    # image tensor to lab tensor
    from skimage import color

    img = tensor2im(image_tensor)
    # print('img_rgb',img.flatten())
    img_lab = color.rgb2lab(img)
    # print('img_lab',img_lab.flatten())
    if(mc_only):
        img_lab[:,:,0] = img_lab[:,:,0]-50
    if(to_norm and not mc_only):
        img_lab[:,:,0] = img_lab[:,:,0]-50
        img_lab = img_lab/100.

    return np2tensor(img_lab)

def tensorlab2tensor(lab_tensor,return_inbnd=False):
    from skimage import color
    import warnings
    warnings.filterwarnings("ignore")

    lab = tensor2np(lab_tensor)*100.
    lab[:,:,0] = lab[:,:,0]+50
    # print('lab',lab)

    rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
    # print('rgb',rgb_back)
    if(return_inbnd):
        # convert back to lab, see if we match
        lab_back = color.rgb2lab(rgb_back.astype('uint8'))
        # print('lab_back',lab_back)
        # print('lab==lab_back',np.isclose(lab_back,lab,atol=1.))
        # print('lab-lab_back',np.abs(lab-lab_back))
        mask = 1.*np.isclose(lab_back,lab,atol=2.)
        mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis])
        return (im2tensor(rgb_back),mask)
    else:
        return im2tensor(rgb_back)

def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
# def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
    image_numpy = image_tensor[0].cpu().float().numpy()
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
    return image_numpy.astype(imtype)

def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
# def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
    return torch.Tensor((image / factor - cent)
                        [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))

def tensor2vec(vector_tensor):
    return vector_tensor.data.cpu().numpy()[:, :, 0, 0]

def diagnose_network(net, name='network'):
    mean = 0.0
    count = 0
    for param in net.parameters():
        if param.grad is not None:
            mean += torch.mean(torch.abs(param.grad.data))
            count += 1
    if count > 0:
        mean = mean / count
    print(name)
    print(mean)

def grab_patch(img_in, P, yy, xx):
    return img_in[yy:yy+P,xx:xx+P,:]

def load_image(path):
    if(path[-3:] == 'dng'):
        import rawpy
        with rawpy.imread(path) as raw:
            img = raw.postprocess()
        # img = plt.imread(path)
    elif(path[-3:]=='bmp' or path[-3:]=='jpg' or path[-3:]=='png'):
        import cv2
        return cv2.imread(path)[:,:,::-1]
    else:
        img = (255*plt.imread(path)[:,:,:3]).astype('uint8')

    return img


def resize_image(img, max_size=256):
    [Y, X] = img.shape[:2]

    # resize
    max_dim = max([Y, X])
    zoom_factor = 1. * max_size / max_dim
    img = zoom(img, [zoom_factor, zoom_factor, 1])

    return img

def resize_image_zoom(img, zoom_factor=1., order=3):
    if(zoom_factor==1):
        return img
    else:
        return zoom(img, [zoom_factor, zoom_factor, 1], order=order)

def save_image(image_numpy, image_path, ):
    image_pil = Image.fromarray(image_numpy)
    image_pil.save(image_path)


def prep_display_image(img, dtype='uint8'):
    if(dtype == 'uint8'):
        return np.clip(img, 0, 255).astype('uint8')
    else:
        return np.clip(img, 0, 1.)


def info(object, spacing=10, collapse=1):
    """Print methods and doc strings.
    Takes module, class, list, dictionary, or string."""
    methodList = [
        e for e in dir(object) if isinstance(
            getattr(
                object,
                e),
            collections.Callable)]
    processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s)
    print("\n".join(["%s %s" %
                     (method.ljust(spacing),
                      processFunc(str(getattr(object, method).__doc__)))
                     for method in methodList]))


def varname(p):
    for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]:
        m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line)
        if m:
            return m.group(1)


def print_numpy(x, val=True, shp=False):
    x = x.astype(np.float64)
    if shp:
        print('shape,', x.shape)
    if val:
        x = x.flatten()
        print(
            'mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' %
            (np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))


def mkdirs(paths):
    if isinstance(paths, list) and not isinstance(paths, str):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)


def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)


def rgb2lab(input):
    from skimage import color
    return color.rgb2lab(input / 255.)


def montage(
    imgs,
    PAD=5,
    RATIO=16 / 9.,
    EXTRA_PAD=(
        False,
        False),
        MM=-1,
        NN=-1,
        primeDir=0,
        verbose=False,
        returnGridPos=False,
        backClr=np.array(
            (0,
             0,
             0))):
    # INPUTS
    #   imgs        YxXxMxN or YxXxN
    #   PAD         scalar              number of pixels in between
    #   RATIO       scalar              target ratio of cols/rows
    #   MM          scalar              # rows, if specified, overrides RATIO
    #   NN          scalar              # columns, if specified, overrides RATIO
    #   primeDir    scalar              0 for top-to-bottom, 1 for left-to-right
    # OUTPUTS
    #   mont_imgs   MM*Y x NN*X x M     big image with everything montaged
    # def montage(imgs, PAD=5, RATIO=16/9., MM=-1, NN=-1, primeDir=0,
    # verbose=False, forceFloat=False):
    if(imgs.ndim == 3):
        toExp = True
        imgs = imgs[:, :, np.newaxis, :]
    else:
        toExp = False

    Y = imgs.shape[0]
    X = imgs.shape[1]
    M = imgs.shape[2]
    N = imgs.shape[3]

    PADS = np.array((PAD))
    if(PADS.flatten().size == 1):
        PADY = PADS
        PADX = PADS
    else:
        PADY = PADS[0]
        PADX = PADS[1]

    if(MM == -1 and NN == -1):
        NN = np.ceil(np.sqrt(1.0 * N * RATIO))
        MM = np.ceil(1.0 * N / NN)
        NN = np.ceil(1.0 * N / MM)
    elif(MM == -1):
        MM = np.ceil(1.0 * N / NN)
    elif(NN == -1):
        NN = np.ceil(1.0 * N / MM)

    if(primeDir == 0):  # write top-to-bottom
        [grid_mm, grid_nn] = np.meshgrid(
            np.arange(MM, dtype='uint'), np.arange(NN, dtype='uint'))
    elif(primeDir == 1):  # write left-to-right
        [grid_nn, grid_mm] = np.meshgrid(
            np.arange(NN, dtype='uint'), np.arange(MM, dtype='uint'))

    grid_mm = np.uint(grid_mm.flatten()[0:N])
    grid_nn = np.uint(grid_nn.flatten()[0:N])

    EXTRA_PADY = EXTRA_PAD[0] * PADY
    EXTRA_PADX = EXTRA_PAD[0] * PADX

    # mont_imgs = np.zeros(((Y+PAD)*MM-PAD, (X+PAD)*NN-PAD, M), dtype=use_dtype)
    mont_imgs = np.zeros(
        (np.uint(
            (Y + PADY) * MM - PADY + EXTRA_PADY),
            np.uint(
            (X + PADX) * NN - PADX + EXTRA_PADX),
            M),
        dtype=imgs.dtype)
    mont_imgs = mont_imgs + \
        backClr.flatten()[np.newaxis, np.newaxis, :].astype(mont_imgs.dtype)

    for ii in np.random.permutation(N):
        # print imgs[:,:,:,ii].shape
        # mont_imgs[grid_mm[ii]*(Y+PAD):(grid_mm[ii]*(Y+PAD)+Y), grid_nn[ii]*(X+PAD):(grid_nn[ii]*(X+PAD)+X),:]
        mont_imgs[np.uint(grid_mm[ii] *
                          (Y +
                           PADY)):np.uint((grid_mm[ii] *
                                           (Y +
                                            PADY) +
                                           Y)), np.uint(grid_nn[ii] *
                                                        (X +
                                                         PADX)):np.uint((grid_nn[ii] *
                                                                         (X +
                                                                          PADX) +
                                                                         X)), :] = imgs[:, :, :, ii]

    if(M == 1):
        imgs = imgs.reshape(imgs.shape[0], imgs.shape[1], imgs.shape[3])

    if(toExp):
        mont_imgs = mont_imgs[:, :, 0]

    if(returnGridPos):
        # return (mont_imgs,np.concatenate((grid_mm[:,:,np.newaxis]*(Y+PAD),
        # grid_nn[:,:,np.newaxis]*(X+PAD)),axis=2))
        return (mont_imgs, np.concatenate(
            (grid_mm[:, np.newaxis] * (Y + PADY), grid_nn[:, np.newaxis] * (X + PADX)), axis=1))
        # return (mont_imgs, (grid_mm,grid_nn))
    else:
        return mont_imgs

class zeroClipper(object):
    def __init__(self, frequency=1):
        self.frequency = frequency

    def __call__(self, module):
        embed()
        if hasattr(module, 'weight'):
            # module.weight.data = torch.max(module.weight.data, 0)
            module.weight.data = torch.max(module.weight.data, 0) + 100

def flatten_nested_list(nested_list):
    # only works for list of list
    accum = []
    for sublist in nested_list:
        for item in sublist:
            accum.append(item)
    return accum

def read_file(in_path,list_lines=False):
    agg_str = ''
    f = open(in_path,'r')
    cur_line = f.readline()
    while(cur_line!=''):
        agg_str+=cur_line
        cur_line = f.readline()
    f.close()
    if(list_lines==False):
        return agg_str.replace('\n','')
    else:
        line_list = agg_str.split('\n')
        ret_list = []
        for item in line_list:
            if(item!=''):
                ret_list.append(item)
        return ret_list

def read_csv_file_as_text(in_path):
    agg_str = []
    f = open(in_path,'r')
    cur_line = f.readline()
    while(cur_line!=''):
        agg_str.append(cur_line)
        cur_line = f.readline()
    f.close()
    return agg_str

def random_swap(obj0,obj1):
    if(np.random.rand() < .5):
        return (obj0,obj1,0)
    else:
        return (obj1,obj0,1)

def voc_ap(rec, prec, use_07_metric=False):
    """ ap = voc_ap(rec, prec, [use_07_metric])
    Compute VOC AP given precision and recall.
    If use_07_metric is true, uses the
    VOC 07 11 point method (default:False).
    """
    if use_07_metric:
        # 11 point metric
        ap = 0.
        for t in np.arange(0., 1.1, 0.1):
            if np.sum(rec >= t) == 0:
                p = 0
            else:
                p = np.max(prec[rec >= t])
            ap = ap + p / 11.
    else:
        # correct AP calculation
        # first append sentinel values at the end
        mrec = np.concatenate(([0.], rec, [1.]))
        mpre = np.concatenate(([0.], prec, [0.]))

        # compute the precision envelope
        for i in range(mpre.size - 1, 0, -1):
            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

        # to calculate area under PR curve, look for points
        # where X axis (recall) changes value
        i = np.where(mrec[1:] != mrec[:-1])[0]

        # and sum (\Delta recall) * prec
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap


================================================
FILE: third_part/PerceptualSimilarity/util/visualizer.py
================================================
import numpy as np
import os
import time
from . import util
from . import html
# from pdb import set_trace as st
import matplotlib.pyplot as plt
import math
# from IPython import embed

def zoom_to_res(img,res=256,order=0,axis=0):
    # img   3xXxX
    from scipy.ndimage import zoom
    zoom_factor = res/img.shape[1]
    if(axis==0):
        return zoom(img,[1,zoom_factor,zoom_factor],order=order)
    elif(axis==2):
        return zoom(img,[zoom_factor,zoom_factor,1],order=order)

class Visualizer():
    def __init__(self, opt):
        # self.opt = opt
        self.display_id = opt.display_id
        # self.use_html = opt.is_train and not opt.no_html
        self.win_size = opt.display_winsize
        self.name = opt.name
        self.display_cnt = 0 # display_current_results counter
        self.display_cnt_high = 0
        self.use_html = opt.use_html

        if self.display_id > 0:
            import visdom
            self.vis = visdom.Visdom(port = opt.display_port)

        self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
        util.mkdirs([self.web_dir,])
        if self.use_html:
            self.img_dir = os.path.join(self.web_dir, 'images')
            print('create web directory %s...' % self.web_dir)
            util.mkdirs([self.img_dir,])

    # |visuals|: dictionary of images to display or save
    def display_current_results(self, visuals, epoch, nrows=None, res=256):
        if self.display_id > 0: # show images in the browser
            title = self.name
            if(nrows is None):
                nrows = int(math.ceil(len(visuals.items()) / 2.0))
            images = []
            idx = 0
            for label, image_numpy in visuals.items():
                title += " | " if idx % nrows == 0 else ", "
                title += label
                img = image_numpy.transpose([2, 0, 1])
                img = zoom_to_res(img,res=res,order=0)
                images.append(img)
                idx += 1
            if len(visuals.items()) % 2 != 0:
                white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255
                white_image = zoom_to_res(white_image,res=res,order=0)
                images.append(white_image)
            self.vis.images(images, nrow=nrows, win=self.display_id + 1,
                            opts=dict(title=title))

        if self.use_html: # save images to a html file
            for label, image_numpy in visuals.items():
                img_path = os.path.join(self.img_dir, 'epoch%.3d_cnt%.6d_%s.png' % (epoch, self.display_cnt, label))
                util.save_image(zoom_to_res(image_numpy, res=res, axis=2), img_path)

            self.display_cnt += 1
            self.display_cnt_high = np.maximum(self.display_cnt_high, self.display_cnt)

            # update website
            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
            for n in range(epoch, 0, -1):
                webpage.add_header('epoch [%d]' % n)
                if(n==epoch):
                    high = self.display_cnt
                else:
                    high = self.display_cnt_high
                for c in range(high-1,-1,-1):
                    ims = []
                    txts = []
                    links = []

                    for label, image_numpy in visuals.items():
                        img_path = 'epoch%.3d_cnt%.6d_%s.png' % (n, c, label)
                        ims.append(os.path.join('images',img_path))
                        txts.append(label)
                        links.append(os.path.join('images',img_path))
                    webpage.add_images(ims, txts, links, width=self.win_size)
            webpage.save()

    # save errors into a directory
    def plot_current_errors_save(self, epoch, counter_ratio, opt, errors,keys='+ALL',name='loss', to_plot=False):
        if not hasattr(self, 'plot_data'):
            self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())}
        self.plot_data['X'].append(epoch + counter_ratio)
        self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])

        # embed()
        if(keys=='+ALL'):
            plot_keys = self.plot_data['legend']
        else:
            plot_keys = keys

        if(to_plot):
            (f,ax) = plt.subplots(1,1)
        for (k,kname) in enumerate(plot_keys):
            kk = np.where(np.array(self.plot_data['legend'])==kname)[0][0]
            x = self.plot_data['X']
            y = np.array(self.plot_data['Y'])[:,kk]
            if(to_plot):
                ax.plot(x, y, 'o-', label=kname)
            np.save(os.path.join(self.web_dir,'%s_x')%kname,x)
            np.save(os.path.join(self.web_dir,'%s_y')%kname,y)

        if(to_plot):
            plt.legend(loc=0,fontsize='small')
            plt.xlabel('epoch')
            plt.ylabel('Value')
            f.savefig(os.path.join(self.web_dir,'%s.png'%name))
            f.clf()
            plt.close()

    # errors: dictionary of error labels and values
    def plot_current_errors(self, epoch, counter_ratio, opt, errors):
        if not hasattr(self, 'plot_data'):
            self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())}
        self.plot_data['X'].append(epoch + counter_ratio)
        self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
        self.vis.line(
            X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1),
            Y=np.array(self.plot_data['Y']),
            opts={
                'title': self.name + ' loss over time',
                'legend': self.plot_data['legend'],
                'xlabel': 'epoch',
                'ylabel': 'loss'},
            win=self.display_id)

    # errors: same format as |errors| of plotCurrentErrors
    def print_current_errors(self, epoch, i, errors, t, t2=-1, t2o=-1, fid=None):
        message = '(ep: %d, it: %d, t: %.3f[s], ept: %.2f/%.2f[h]) ' % (epoch, i, t, t2o, t2)
        message += (', ').join(['%s: %.3f' % (k, v) for k, v in errors.items()])

        print(message)
        if(fid is not None):
            fid.write('%s\n'%message)


    # save image to the disk
    def save_images_simple(self, webpage, images, names, in_txts, prefix='', res=256):
        image_dir = webpage.get_image_dir()
        ims = []
        txts = []
        links = []

        for name, image_numpy, txt in zip(names, images, in_txts):
            image_name = '%s_%s.png' % (prefix, name)
            save_path = os.path.join(image_dir, image_name)
            if(res is not None):
                util.save_image(zoom_to_res(image_numpy,res=res,axis=2), save_path)
            else:
                util.save_image(image_numpy, save_path)

            ims.append(os.path.join(webpage.img_subdir,image_name))
            # txts.append(name)
            txts.append(txt)
            links.append(os.path.join(webpage.img_subdir,image_name))
        # embed()
        webpage.add_images(ims, txts, links, width=self.win_size)

    # save image to the disk
    def save_images(self, webpage, images, names, image_path, title=''):
        image_dir = webpage.get_image_dir()
        # short_path = ntpath.basename(image_path)
        # name = os.path.splitext(short_path)[0]
        # name = short_path
        # webpage.add_header('%s, %s' % (name, title))
        ims = []
        txts = []
        links = []

        for label, image_numpy in zip(names, images):
            image_name = '%s.jpg' % (label,)
            save_path = os.path.join(image_dir, image_name)
            util.save_image(image_numpy, save_path)

            ims.append(image_name)
            txts.append(label)
            links.append(image_name)
        webpage.add_images(ims, txts, links, width=self.win_size)

    # save image to the disk
    # def save_images(self, webpage, visuals, image_path, short=False):
    #     image_dir = webpage.get_image_dir()
    #     if short:
    #         short_path = ntpath.basename(image_path)
    #         name = os.path.splitext(short_path)[0]
    #     else:
    #         name = image_path

    #     webpage.add_header(name)
    #     ims = []
    #     txts = []
    #     links = []

    #     for label, image_numpy in visuals.items():
    #         image_name = '%s_%s.png' % (name, label)
    #         save_path = os.path.join(image_dir, image_name)
    #         util.save_image(image_numpy, save_path)

    #         ims.append(image_name)
    #         txts.append(label)
    #         links.append(image_name)
    #     webpage.add_images(ims, txts, links, width=self.win_size)


================================================
FILE: train.py
================================================
import argparse

import data as Dataset
from config import Config
from util.logging import init_logging, make_logging_dir
from util.trainer import get_model_optimizer_and_scheduler, set_random_seed, get_trainer
from util.distributed import init_dist
from util.distributed import master_only_print as print


def parse_args():
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--config', default='./config/face.yaml')
    parser.add_argument('--name', default=None)
    parser.add_argument('--checkpoints_dir', default='result',
                        help='Dir for saving logs and models.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--which_iter', type=int, default=None)
    parser.add_argument('--no_resume', action='store_true')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--single_gpu', action='store_true')
    parser.add_argument('--debug', action='store_true')

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    # get training options
    args = parse_args()
    set_random_seed(args.seed)
    opt = Config(args.config, args, is_train=True)

    if not args.single_gpu:
        opt.local_rank = args.local_rank
        init_dist(opt.local_rank)    
        opt.device = opt.local_rank
    
    # create a visualizer
    date_uid, logdir = init_logging(opt)
    opt.logdir = logdir
    make_logging_dir(logdir, date_uid)
    # create a dataset
    val_dataset, train_dataset = Dataset.get_train_val_dataloader(opt.data)

    # create a model
    net_G, net_G_ema, opt_G, sch_G \
        = get_model_optimizer_and_scheduler(opt)

    trainer = get_trainer(opt, net_G, net_G_ema, opt_G, sch_G, train_dataset)

    current_epoch, current_iteration = trainer.load_checkpoint(opt, args.which_iter)   
    # training flag
    max_epoch = opt.max_epoch

    if args.debug:
        trainer.test_everything(train_dataset, val_dataset, current_epoch, current_iteration)
        exit()
    # Start training.
    for epoch in range(current_epoch, opt.max_epoch):
        print('Epoch {} ...'.format(epoch))
        if not args.single_gpu:
            train_dataset.sampler.set_epoch(current_epoch)
        trainer.start_of_epoch(current_epoch)
        for it, data in enumerate(train_dataset):
            data = trainer.start_of_iteration(data, current_iteration)
            trainer.optimize_parameters(data)
            current_iteration += 1
            trainer.end_of_iteration(data, current_epoch, current_iteration)
 
            if current_iteration >= opt.max_iter:
                print('Done with training!!!')
                break
        current_epoch += 1
        trainer.end_of_epoch(data, val_dataset, current_epoch, current_iteration)


================================================
FILE: trainers/__init__.py
================================================


================================================
FILE: trainers/base.py
================================================
import os
import time
import glob
from tqdm import tqdm

import torch
import torchvision
from torch import nn

from util.distributed import is_master, master_only
from util.distributed import master_only_print as print
from util.meters import Meter, add_hparams
from util.misc import to_cuda, to_device, requires_grad
from util.lpips import LPIPS



class BaseTrainer(object):
    r"""Base trainer. We expect that all trainers inherit this class.

    Args:
        opt (obj): Global configuration.
        net_G (obj): Generator network.
        net_D (obj): Discriminator network.
        opt_G (obj): Optimizer for the generator network.
        opt_D (obj): Optimizer for the discriminator network.
        sch_G (obj): Scheduler for the generator optimizer.
        sch_D (obj): Scheduler for the discriminator optimizer.
        train_data_loader (obj): Train data loader.
        val_data_loader (obj): Validation data loader.
    """

    def __init__(self,
                 opt,
                 net_G,
                 net_G_ema,
                 opt_G,
                 sch_G,
                 train_data_loader,
                 val_data_loader=None):
        super(BaseTrainer, self).__init__()
        print('Setup trainer.')

        # Initialize models and data loaders.
        self.opt = opt
        self.net_G = net_G
        if opt.distributed:
            self.net_G_module = self.net_G.module
        else:
            self.net_G_module = self.net_G

        self.is_inference = train_data_loader is None
        self.net_G_ema = net_G_ema
        self.opt_G = opt_G
        self.sch_G = sch_G
        self.train_data_loader = train_data_loader

        self.criteria = nn.ModuleDict()
        self.weights = dict()
        self.losses = dict(gen_update=dict(), dis_update=dict())
        self.gen_losses = self.losses['gen_update']
        self._init_loss(opt)
        for loss_name, loss_weight in self.weights.items():
            print("Loss {:<20} Weight {}".format(loss_name, loss_weight))
            if loss_name in self.criteria.keys() and \
                    self.criteria[loss_name] is not None:
                self.criteria[loss_name].to('cuda')

        if self.is_inference:
            # The initialization steps below can be skipped during inference.
            return

        # Initialize logging attributes.
        self.current_iteration = 0
        self.current_epoch = 0
        self.start_iteration_time = None
        self.start_epoch_time = None
        self.elapsed_iteration_time = 0
        self.time_iteration = -1
        self.time_epoch = -1
        if getattr(self.opt, 'speed_benchmark', False):
            self.accu_gen_forw_iter_time = 0
            self.accu_gen_loss_iter_time = 0
            self.accu_gen_back_iter_time = 0
            self.accu_gen_step_iter_time = 0
            self.accu_gen_avg_iter_time = 0

        # Initialize tensorboard and hparams.
        self._init_tensorboard()
        self._init_hparams()
        self.lpips = LPIPS()
        self.best_lpips = None

    def _init_tensorboard(self):
        r"""Initialize the tensorboard. Different algorithms might require
        different performance metrics. Hence, custom tensorboard
        initialization might be necessary.
        """
        # Logging frequency: self.opt.logging_iter
        self.meters = {}
        names = ['optim/gen_lr', 'time/iteration', 'time/epoch', 
                 'metric/best_lpips', 'metric/lpips']
        for name in names:
            self.meters[name] = Meter(name)

        # Logging frequency: self.opt.image_display_iter
        self.image_meter = Meter('images')

        # Logging frequency: self.opt.snapshot_save_iter
        # self.meters['metric/lpips'] = Meter('metric/lpips')


    def _init_hparams(self):
        r"""Initialize a dictionary of hyperparameters that we want to monitor
        in the HParams dashboard in tensorBoard.
        """
        self.hparam_dict = {}

    def _write_tensorboard(self):
        r"""Write values to tensorboard. By default, we will log the time used
        per iteration, time used per epoch, generator learning rate, and
        discriminator learning rate. We will log all the losses as well as
        custom meters.
        """
        # Logs that are shared by all models.
        self._write_to_meters({'time/iteration': self.time_iteration,
                               'time/epoch': self.time_epoch,
                               'optim/gen_lr': self.sch_G.get_last_lr()[0]},
                                self.meters)
        # Logs for loss values. Different models have different losses.
        self._write_loss_meters()
        # Other custom logs.
        self._write_custom_meters()
        # Write all logs to tensorboard.
        self._flush_meters(self.meters)

    def _write_loss_meters(self):
        r"""Write all loss values to tensorboard."""
        for loss_name, loss in self.gen_losses.items():
            full_loss_name = 'gen_update' + '/' + loss_name
            if full_loss_name not in self.meters.keys():
                # Create a new meter if it doesn't exist.
                self.meters[full_loss_name] = Meter(full_loss_name)
            self.meters[full_loss_name].write(loss.item())

    def test_everything(self, train_dataset, val_dataset, current_epoch, current_iteration):
        r"""Test the functions defined in the models. by default, we will test the 
        training function, the inference function, the visualization function.
        """        
        self._set_custom_debug_parameter()
        self.start_of_epoch(current_epoch)
        print('Start testing your functions')
        for it in tqdm(range(30)):
            data = iter(train_dataset).next()
            data = self.start_of_iteration(data, current_iteration)
            self.optimize_parameters(data)
            current_iteration += 1
            self.end_of_iteration(data, current_epoch, current_iteration)
            
        self.save_image(self._get_save_path('image', 'jpg'), data)
        self._write_tensorboard()
        self._print_current_errors()
        self.write_metrics(data)
        self.end_of_epoch(data, val_dataset, current_epoch, current_iteration)
        print('End debugging')
        

    def _set_custom_debug_parameter(self):
        r"""Set custom debug parame.
        """
        self.opt.logging_iter = 10
        self.opt.image_save_iter = 10
        

    def _write_custom_meters(self):
        r"""Dummy member function to be overloaded by the child class.
        In the child class, you can write down whatever you want to track.
        """
        pass

    @staticmethod
    def _write_to_meters(data, meters):
        r"""Write values to meters."""
        for key, value in data.items():
            meters[key].write(value)

    def _flush_meters(self, meters):
        r"""Flush all meters using the current iteration."""
        for meter in meters.values():
            meter.flush(self.current_iteration)

    def _pre_save_checkpoint(self):
        r"""Implement the things you want to do before saving a checkpoint.
        For example, you can compute the K-mean features (pix2pixHD) before
        saving the model weights to a checkpoint.
        """
        pass

    def save_checkpoint(self, current_epoch, current_iteration):
        r"""Save network weights, optimizer parameters, scheduler parameters
        to a checkpoint.
        """
        self._pre_save_checkpoint()
        _save_checkpoint(self.opt,
                         self.net_G, self.net_G_ema, 
                         self.opt_G, self.sch_G,
                         current_epoch, current_iteration)

    def load_checkpoint(self, opt, which_iter=None):
        if which_iter is not None:
            model_path = os.path.join(
                opt.logdir, '*_iteration_{:09}_checkpoint.pt'.format(which_iter))
            latest_checkpoint_path = glob.glob(model_path)
            assert len(latest_checkpoint_path) <= 1, "please check the saved model {}".format(
                model_path)
            if len(latest_checkpoint_path) == 0:
                current_epoch = 0
                current_iteration = 0
                print('No checkpoint found at iteration {}.'.format(which_iter))
                return current_epoch, current_iteration
            checkpoint_path = latest_checkpoint_path[0]

        elif os.path.exists(os.path.join(opt.logdir, 'latest_checkpoint.txt')):
            with open(os.path.join(opt.logdir, 'latest_checkpoint.txt'), 'r') as f:
                line = f.readlines()[0].replace('\n', '')
                checkpoint_path = os.path.join(opt.logdir, line.split(' ')[-1])
        else:
            current_epoch = 0
            current_iteration = 0
            print('No checkpoint found.')
            return current_epoch, current_iteration
        resume = opt.phase == 'train' and opt.resume
        current_epoch, current_iteration = self._load_checkpoint(
            checkpoint_path, resume)
        return current_epoch, current_iteration

    def _load_checkpoint(self, checkpoint_path, resume=True):
        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        self.net_G.load_state_dict(checkpoint['net_G'], strict=False)
        self.net_G_ema.load_state_dict(checkpoint['net_G_ema'], strict=False)
        print('load [net_G] and [net_G_ema] from {}'.format(checkpoint_path))
        if self.opt.phase == 'train' and resume:
            # the checkpoint we provided does not contains 
            # the parameters of the optimizer and schdule 
            # because we train the model use another code 
            # which does not save these parameters
            self.opt_G.load_state_dict(checkpoint['opt_G'])
            self.sch_G.load_state_dict(checkpoint['sch_G'])
            print('load optimizers and schdules from {}'.format(checkpoint_path))

        if resume or self.opt.phase == 'test':
            current_epoch = checkpoint['current_epoch']
            current_iteration = checkpoint['current_iteration']
        else:
            current_epoch = 0
            current_iteration = 0
        print('Done with loading the checkpoint.')
        return current_epoch, current_iteration        

    def start_of_epoch(self, current_epoch):
        r"""Things to do before an epoch.

        Args:
            current_epoch (int): Current number of epoch.
        """
        self._start_of_epoch(current_epoch)
        self.current_epoch = current_epoch
        self.start_epoch_time = time.time()

    def start_of_iteration(self, data, current_iteration):
        r"""Things to do before an iteration.

        Args:
            data (dict): Data used for the current iteration.
            current_iteration (int): Current number of iteration.
        """
        data = self._start_of_iteration(data, current_iteration)
        data = to_cuda(data)
        self.current_iteration = current_iteration
        if not self.is_inference:
            self.net_G.train()
        self.start_iteration_time = time.time()
        return data

    def end_of_iteration(self, data, current_epoch, current_iteration):
        r"""Things to do after an iteration.

        Args:
            data (dict): Data used for the current iteration.
            current_epoch (int): Current number of epoch.
            current_iteration (int): Current number of iteration.
        """
        self.current_iteration = current_iteration
        self.current_epoch = current_epoch
        # Update the learning rate policy for the generator if operating in the
        # iteration mode.
        if self.opt.gen_optimizer.lr_policy.iteration_mode:
            self.sch_G.step()

        # Accumulate time
        # torch.cuda.synchronize()
        self.elapsed_iteration_time += time.time() - self.start_iteration_time
        # Logging.
        if current_iteration % self.opt.logging_iter == 0:
            ave_t = self.elapsed_iteration_time / self.opt.logging_iter
            self.time_i
Download .txt
gitextract_uz8rj1lb/

├── .gitmodules
├── DatasetHelper.md
├── LICENSE.md
├── README.md
├── config/
│   ├── face.yaml
│   └── face_demo.yaml
├── config.py
├── data/
│   ├── __init__.py
│   ├── image_dataset.py
│   ├── vox_dataset.py
│   └── vox_video_dataset.py
├── demo_images/
│   └── expression.mat
├── generators/
│   ├── base_function.py
│   └── face_model.py
├── inference.py
├── intuitive_control.py
├── loss/
│   └── perceptual.py
├── requirements.txt
├── scripts/
│   ├── coeff_detector.py
│   ├── download_demo_dataset.sh
│   ├── download_weights.sh
│   ├── extract_kp_videos.py
│   ├── face_recon_images.py
│   ├── face_recon_videos.py
│   ├── inference_options.py
│   └── prepare_vox_lmdb.py
├── third_part/
│   └── PerceptualSimilarity/
│       ├── models/
│       │   ├── __init__.py
│       │   ├── base_model.py
│       │   ├── dist_model.py
│       │   ├── models.py
│       │   ├── networks_basic.py
│       │   └── pretrained_networks.py
│       ├── util/
│       │   ├── __init__.py
│       │   ├── html.py
│       │   ├── util.py
│       │   └── visualizer.py
│       └── weights/
│           ├── v0.0/
│           │   ├── alex.pth
│           │   ├── squeeze.pth
│           │   └── vgg.pth
│           └── v0.1/
│               ├── alex.pth
│               ├── squeeze.pth
│               └── vgg.pth
├── train.py
├── trainers/
│   ├── __init__.py
│   ├── base.py
│   └── face_trainer.py
└── util/
    ├── cudnn.py
    ├── distributed.py
    ├── flow_util.py
    ├── init_weight.py
    ├── io.py
    ├── logging.py
    ├── lpips.py
    ├── meters.py
    ├── misc.py
    └── trainer.py
Download .txt
SYMBOL INDEX (389 symbols across 37 files)

FILE: config.py
  class AttrDict (line 10) | class AttrDict(dict):
    method __init__ (line 13) | def __init__(self, *args, **kwargs):
    method yaml (line 25) | def yaml(self):
    method __repr__ (line 43) | def __repr__(self):
  class Config (line 67) | class Config(AttrDict):
    method __init__ (line 71) | def __init__(self, filename=None, args=None, verbose=False, is_train=T...
  function rsetattr (line 186) | def rsetattr(obj, attr, val):
  function rgetattr (line 192) | def rgetattr(obj, attr, *args):
  function recursive_update (line 202) | def recursive_update(d, u):

FILE: data/__init__.py
  function find_dataset_using_name (line 6) | def find_dataset_using_name(dataset_name):
  function get_option_setter (line 23) | def get_option_setter(dataset_name):
  function create_dataloader (line 28) | def create_dataloader(opt, is_inference):
  function data_sampler (line 46) | def data_sampler(dataset, shuffle, distributed):
  function get_dataloader (line 55) | def get_dataloader(opt, is_inference=False):
  function get_train_val_dataloader (line 60) | def get_train_val_dataloader(opt):

FILE: data/image_dataset.py
  class ImageDataset (line 12) | class ImageDataset():
    method __init__ (line 13) | def __init__(self, opt, input_name):
    method next_image (line 22) | def next_image(self):
    method obtain_inputs (line 38) | def obtain_inputs(self, root):
    method transform_semantic (line 52) | def transform_semantic(self, semantic):
    method trans_image (line 62) | def trans_image(self, image):
    method __len__ (line 69) | def __len__(self):

FILE: data/vox_dataset.py
  function format_for_lmdb (line 13) | def format_for_lmdb(*args):
  class VoxDataset (line 21) | class VoxDataset(Dataset):
    method __init__ (line 22) | def __init__(self, opt, is_inference):
    method get_video_index (line 53) | def get_video_index(self, videos):
    method group_by_key (line 62) | def group_by_key(self, video_list, key):
    method Video_Item (line 68) | def Video_Item(self, video_name):
    method __len__ (line 79) | def __len__(self):
    method __getitem__ (line 82) | def __getitem__(self, index):
    method random_select_frames (line 108) | def random_select_frames(self, video_item):
    method transform_semantic (line 113) | def transform_semantic(self, semantic, frame_index):
    method obtain_seq_index (line 127) | def obtain_seq_index(self, index, num_frames):

FILE: data/vox_video_dataset.py
  class VoxVideoDataset (line 14) | class VoxVideoDataset(VoxDataset):
    method __init__ (line 15) | def __init__(self, opt, is_inference):
    method __len__ (line 23) | def __len__(self):
    method load_next_video (line 26) | def load_next_video(self):
    method random_video (line 61) | def random_video(self, target_video_item):
    method find_crop_norm_ratio (line 71) | def find_crop_norm_ratio(self, source_coeff, target_coeffs):
    method transform_semantic (line 79) | def transform_semantic(self, semantic, frame_index, crop_norm_ratio):
    method obtain_name (line 96) | def obtain_name(self, target_name, source_name):

FILE: generators/base_function.py
  class LayerNorm2d (line 11) | class LayerNorm2d(nn.Module):
    method __init__ (line 12) | def __init__(self, n_out, affine=True):
    method forward (line 21) | def forward(self, x):
  class ADAINHourglass (line 31) | class ADAINHourglass(nn.Module):
    method __init__ (line 32) | def __init__(self, image_nc, pose_nc, ngf, img_f, encoder_layers, deco...
    method forward (line 38) | def forward(self, x, z):
  class ADAINEncoder (line 43) | class ADAINEncoder(nn.Module):
    method __init__ (line 44) | def __init__(self, image_nc, pose_nc, ngf, img_f, layers, nonlinearity...
    method forward (line 55) | def forward(self, x, z):
  class ADAINDecoder (line 64) | class ADAINDecoder(nn.Module):
    method __init__ (line 66) | def __init__(self, pose_nc, ngf, img_f, encoder_layers, decoder_layers...
    method forward (line 84) | def forward(self, x, z):
  class ADAINEncoderBlock (line 92) | class ADAINEncoderBlock(nn.Module):
    method __init__ (line 93) | def __init__(self, input_nc, output_nc, feature_nc, nonlinearity=nn.Le...
    method forward (line 106) | def forward(self, x, z):
  class ADAINDecoderBlock (line 111) | class ADAINDecoderBlock(nn.Module):
    method __init__ (line 112) | def __init__(self, input_nc, output_nc, hidden_nc, feature_nc, use_tra...
    method forward (line 139) | def forward(self, x, z):
    method shortcut (line 146) | def shortcut(self, x, z):
  function spectral_norm (line 151) | def spectral_norm(module, use_spect=True):
  class ADAIN (line 159) | class ADAIN(nn.Module):
    method __init__ (line 160) | def __init__(self, norm_nc, feature_nc):
    method forward (line 175) | def forward(self, x, feature):
  class FineEncoder (line 193) | class FineEncoder(nn.Module):
    method __init__ (line 195) | def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNo...
    method forward (line 206) | def forward(self, x):
  class FineDecoder (line 215) | class FineDecoder(nn.Module):
    method __init__ (line 217) | def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block...
    method forward (line 235) | def forward(self, x, z):
  class FirstBlock2d (line 247) | class FirstBlock2d(nn.Module):
    method __init__ (line 251) | def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, non...
    method forward (line 262) | def forward(self, x):
  class DownBlock2d (line 266) | class DownBlock2d(nn.Module):
    method __init__ (line 267) | def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, non...
    method forward (line 280) | def forward(self, x):
  class UpBlock2d (line 284) | class UpBlock2d(nn.Module):
    method __init__ (line 285) | def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, non...
    method forward (line 294) | def forward(self, x):
  class FineADAINResBlocks (line 298) | class FineADAINResBlocks(nn.Module):
    method __init__ (line 299) | def __init__(self, num_block, input_nc, feature_nc, norm_layer=nn.Batc...
    method forward (line 306) | def forward(self, x, z):
  class Jump (line 312) | class Jump(nn.Module):
    method __init__ (line 313) | def __init__(self, input_nc, norm_layer=nn.BatchNorm2d, nonlinearity=n...
    method forward (line 323) | def forward(self, x):
  class FineADAINResBlock2d (line 327) | class FineADAINResBlock2d(nn.Module):
    method __init__ (line 331) | def __init__(self, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, no...
    method forward (line 344) | def forward(self, x, z):
  class FinalBlock2d (line 350) | class FinalBlock2d(nn.Module):
    method __init__ (line 354) | def __init__(self, input_nc, output_nc, use_spect=False, tanh_or_sigmo...
    method forward (line 366) | def forward(self, x):

FILE: generators/face_model.py
  class FaceGenerator (line 11) | class FaceGenerator(nn.Module):
    method __init__ (line 12) | def __init__(
    method forward (line 24) | def forward(
  class MappingNet (line 39) | class MappingNet(nn.Module):
    method __init__ (line 40) | def __init__(self, coeff_nc, descriptor_nc, layer):
    method forward (line 57) | def forward(self, input_3dmm):
  class WarpingNet (line 65) | class WarpingNet(nn.Module):
    method __init__ (line 66) | def __init__(
    method forward (line 92) | def forward(self, input_image, descriptor):
  class EditingNet (line 102) | class EditingNet(nn.Module):
    method __init__ (line 103) | def __init__(
    method forward (line 123) | def forward(self, input_image, warp_image, descriptor):

FILE: inference.py
  function parse_args (line 22) | def parse_args():
  function write2video (line 40) | def write2video(results_dir, *video_list):

FILE: intuitive_control.py
  function parse_args (line 20) | def parse_args():
  function get_control (line 37) | def get_control(input_name):

FILE: loss/perceptual.py
  function apply_imagenet_normalization (line 8) | def apply_imagenet_normalization(input):
  class PerceptualLoss (line 25) | class PerceptualLoss(nn.Module):
    method __init__ (line 41) | def __init__(self, network='vgg19', layers='relu_4_1', weights=None,
    method forward (line 91) | def forward(self, inp, target, mask=None):
    method compute_gram (line 164) | def compute_gram(self, x):
  class _PerceptualNetwork (line 171) | class _PerceptualNetwork(nn.Module):
    method __init__ (line 181) | def __init__(self, network, layer_name_mapping, layers):
    method forward (line 191) | def forward(self, x):
  function _vgg19 (line 203) | def _vgg19(layers):
  function _vgg16 (line 222) | def _vgg16(layers):
  function _alexnet (line 239) | def _alexnet(layers):
  function _inception_v3 (line 255) | def _inception_v3(layers):
  function _resnet50 (line 284) | def _resnet50(layers):
  function _robust_resnet50 (line 303) | def _robust_resnet50(layers):
  function _vgg_face_dag (line 329) | def _vgg_face_dag(layers):

FILE: scripts/coeff_detector.py
  class CoeffDetector (line 18) | class CoeffDetector(nn.Module):
    method __init__ (line 19) | def __init__(self, opt):
    method forward (line 30) | def forward(self, img, lm):
    method image_transform (line 53) | def image_transform(self, images, lm):
  function get_data_path (line 74) | def get_data_path(root, keypoint_root):

FILE: scripts/extract_kp_videos.py
  class KeypointExtractor (line 14) | class KeypointExtractor():
    method __init__ (line 15) | def __init__(self):
    method extract_keypoint (line 18) | def extract_keypoint(self, images, name=None):
  function read_video (line 52) | def read_video(filename):
  function run (line 66) | def run(data):

FILE: scripts/face_recon_images.py
  function get_data_path (line 17) | def get_data_path(root, keypoint_root):
  class ImagePathDataset (line 36) | class ImagePathDataset(torch.utils.data.Dataset):
    method __init__ (line 37) | def __init__(self, filenames, txt_filenames, bfm_folder):
    method __len__ (line 42) | def __len__(self):
    method __getitem__ (line 45) | def __getitem__(self, i):
    method image_transform (line 55) | def image_transform(self, images, lm):
    method read_data (line 72) | def read_data(self, filename, txt_filename):
  function main (line 80) | def main(opt, model):

FILE: scripts/face_recon_videos.py
  function get_data_path (line 18) | def get_data_path(root, keypoint_root):
  class VideoPathDataset (line 34) | class VideoPathDataset(torch.utils.data.Dataset):
    method __init__ (line 35) | def __init__(self, filenames, txt_filenames, bfm_folder):
    method __len__ (line 40) | def __len__(self):
    method __getitem__ (line 43) | def __getitem__(self, index):
    method read_video (line 61) | def read_video(self, filename):
    method image_transform (line 75) | def image_transform(self, images, lm):
  function main (line 92) | def main(opt, model):

FILE: scripts/inference_options.py
  class InferenceOptions (line 4) | class InferenceOptions(BaseOptions):
    method initialize (line 10) | def initialize(self, parser):

FILE: scripts/prepare_vox_lmdb.py
  function format_for_lmdb (line 15) | def format_for_lmdb(*args):
  class Resizer (line 23) | class Resizer:
    method __init__ (line 24) | def __init__(self, size, kp_root, coeff_3dmm_root, img_format):
    method get_resized_bytes (line 30) | def get_resized_bytes(self, img, img_format='jpeg'):
    method prepare (line 37) | def prepare(self, filename):
    method __call__ (line 57) | def __call__(self, index_filename):
  function get_others (line 62) | def get_others(root, video_name, data_type):
  function convert_kp (line 78) | def convert_kp(file_path):
  function convert_3dmm (line 83) | def convert_3dmm(file_path):
  function prepare_data (line 94) | def prepare_data(path, keypoint_path, coeff_3dmm_path, out, n_worker, si...

FILE: third_part/PerceptualSimilarity/models/base_model.py
  class BaseModel (line 7) | class BaseModel():
    method __init__ (line 8) | def __init__(self):
    method name (line 11) | def name(self):
    method initialize (line 14) | def initialize(self, use_gpu=True):
    method forward (line 19) | def forward(self):
    method get_image_paths (line 22) | def get_image_paths(self):
    method optimize_parameters (line 25) | def optimize_parameters(self):
    method get_current_visuals (line 28) | def get_current_visuals(self):
    method get_current_errors (line 31) | def get_current_errors(self):
    method save (line 34) | def save(self, label):
    method save_network (line 38) | def save_network(self, network, path, network_label, epoch_label):
    method load_network (line 44) | def load_network(self, network, network_label, epoch_label):
    method update_learning_rate (line 51) | def update_learning_rate():
    method get_image_paths (line 54) | def get_image_paths(self):
    method save_done (line 57) | def save_done(self, flag=False):

FILE: third_part/PerceptualSimilarity/models/dist_model.py
  class DistModel (line 25) | class DistModel(BaseModel):
    method name (line 26) | def name(self):
    method initialize (line 29) | def initialize(self, model='net-lin', net='alex', pnet_rand=False, pne...
    method forward_pair (line 106) | def forward_pair(self,in1,in2,retPerLayer=False):
    method forward (line 112) | def forward(self, in0, in1, retNumpy=True):
    method optimize_parameters (line 163) | def optimize_parameters(self):
    method clamp_weights (line 170) | def clamp_weights(self):
    method set_input (line 175) | def set_input(self, data):
    method forward_train (line 191) | def forward_train(self): # run forward pass
    method backward_train (line 202) | def backward_train(self):
    method compute_accuracy (line 205) | def compute_accuracy(self,d0,d1,judge):
    method get_current_errors (line 211) | def get_current_errors(self):
    method get_current_visuals (line 220) | def get_current_visuals(self):
    method save (line 235) | def save(self, path, label):
    method update_learning_rate (line 239) | def update_learning_rate(self,nepoch_decay):
  function score_2afc_dataset (line 251) | def score_2afc_dataset(data_loader,func):
  function score_jnd_dataset (line 288) | def score_jnd_dataset(data_loader,func):

FILE: third_part/PerceptualSimilarity/models/models.py
  function create_model (line 3) | def create_model(opt):

FILE: third_part/PerceptualSimilarity/models/networks_basic.py
  class PNet (line 21) | class PNet(nn.Module):
    method __init__ (line 23) | def __init__(self, pnet_type='vgg', pnet_rand=False, use_gpu=True):
    method forward (line 50) | def forward(self, in0, in1, retPerLayer=False):
  class PNetLin (line 75) | class PNetLin(nn.Module):
    method __init__ (line 76) | def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, ...
    method forward (line 131) | def forward(self, in0, in1):
  class Dist2LogitLayer (line 180) | class Dist2LogitLayer(nn.Module):
    method __init__ (line 182) | def __init__(self, chn_mid=32,use_sigmoid=True):
    method forward (line 193) | def forward(self,d0,d1,eps=0.1):
  class BCERankingLoss (line 196) | class BCERankingLoss(nn.Module):
    method __init__ (line 197) | def __init__(self, use_gpu=True, chn_mid=32):
    method forward (line 208) | def forward(self, d0, d1, judge):
  class NetLinLayer (line 215) | class NetLinLayer(nn.Module):
    method __init__ (line 217) | def __init__(self, chn_in, chn_out=1, use_dropout=False):
  class FakeNet (line 226) | class FakeNet(nn.Module):
    method __init__ (line 227) | def __init__(self, use_gpu=True, colorspace='Lab'):
  class L2 (line 232) | class L2(FakeNet):
    method forward (line 234) | def forward(self, in0, in1):
  class DSSIM (line 249) | class DSSIM(FakeNet):
    method forward (line 251) | def forward(self, in0, in1):
  function print_network (line 264) | def print_network(net):

FILE: third_part/PerceptualSimilarity/models/pretrained_networks.py
  class squeezenet (line 6) | class squeezenet(torch.nn.Module):
    method __init__ (line 7) | def __init__(self, requires_grad=False, pretrained=True):
    method forward (line 36) | def forward(self, X):
  class alexnet (line 57) | class alexnet(torch.nn.Module):
    method __init__ (line 58) | def __init__(self, requires_grad=False, pretrained=True):
    method forward (line 81) | def forward(self, X):
  class vgg16 (line 97) | class vgg16(torch.nn.Module):
    method __init__ (line 98) | def __init__(self, requires_grad=False, pretrained=True):
    method forward (line 121) | def forward(self, X):
  class resnet (line 139) | class resnet(torch.nn.Module):
    method __init__ (line 140) | def __init__(self, requires_grad=False, pretrained=True, num=18):
    method forward (line 163) | def forward(self, X):

FILE: third_part/PerceptualSimilarity/util/html.py
  class HTML (line 6) | class HTML:
    method __init__ (line 7) | def __init__(self, web_dir, title, image_subdir='', reflesh=0):
    method get_image_dir (line 24) | def get_image_dir(self):
    method add_header (line 27) | def add_header(self, str):
    method add_table (line 31) | def add_table(self, border=1):
    method add_images (line 35) | def add_images(self, ims, txts, links, width=400):
    method save (line 47) | def save(self,file='index'):

FILE: third_part/PerceptualSimilarity/util/util.py
  function datetime_str (line 18) | def datetime_str():
  function read_text_file (line 22) | def read_text_file(in_path):
  function bootstrap (line 34) | def bootstrap(in_vec,num_samples=100,bootfunc=np.mean):
  function rand_flip (line 38) | def rand_flip(input1,input2):
  function l2 (line 44) | def l2(p0, p1, range=255.):
  function psnr (line 47) | def psnr(p0, p1, peak=255.):
  function dssim (line 50) | def dssim(p0, p1, range=255.):
  function rgb2lab (line 54) | def rgb2lab(in_img,mean_cent=False):
  function normalize_blob (line 61) | def normalize_blob(in_feat,eps=1e-10):
  function cos_sim_blob (line 65) | def cos_sim_blob(in0,in1):
  function normalize_tensor (line 72) | def normalize_tensor(in_feat,eps=1e-10):
  function cos_sim (line 77) | def cos_sim(in0,in1):
  function tensor2np (line 89) | def tensor2np(tensor_obj):
  function np2tensor (line 93) | def np2tensor(np_obj):
  function tensor2tensorlab (line 97) | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
  function tensorlab2tensor (line 113) | def tensorlab2tensor(lab_tensor,return_inbnd=False):
  function tensor2im (line 136) | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
  function im2tensor (line 142) | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
  function tensor2vec (line 147) | def tensor2vec(vector_tensor):
  function diagnose_network (line 150) | def diagnose_network(net, name='network'):
  function grab_patch (line 162) | def grab_patch(img_in, P, yy, xx):
  function load_image (line 165) | def load_image(path):
  function resize_image (line 180) | def resize_image(img, max_size=256):
  function resize_image_zoom (line 190) | def resize_image_zoom(img, zoom_factor=1., order=3):
  function save_image (line 196) | def save_image(image_numpy, image_path, ):
  function prep_display_image (line 201) | def prep_display_image(img, dtype='uint8'):
  function info (line 208) | def info(object, spacing=10, collapse=1):
  function varname (line 224) | def varname(p):
  function print_numpy (line 231) | def print_numpy(x, val=True, shp=False):
  function mkdirs (line 242) | def mkdirs(paths):
  function mkdir (line 250) | def mkdir(path):
  function rgb2lab (line 255) | def rgb2lab(input):
  function montage (line 260) | def montage(
  class zeroClipper (line 369) | class zeroClipper(object):
    method __init__ (line 370) | def __init__(self, frequency=1):
    method __call__ (line 373) | def __call__(self, module):
  function flatten_nested_list (line 379) | def flatten_nested_list(nested_list):
  function read_file (line 387) | def read_file(in_path,list_lines=False):
  function read_csv_file_as_text (line 405) | def read_csv_file_as_text(in_path):
  function random_swap (line 415) | def random_swap(obj0,obj1):
  function voc_ap (line 421) | def voc_ap(rec, prec, use_07_metric=False):

FILE: third_part/PerceptualSimilarity/util/visualizer.py
  function zoom_to_res (line 11) | def zoom_to_res(img,res=256,order=0,axis=0):
  class Visualizer (line 20) | class Visualizer():
    method __init__ (line 21) | def __init__(self, opt):
    method display_current_results (line 43) | def display_current_results(self, visuals, epoch, nrows=None, res=256):
    method plot_current_errors_save (line 94) | def plot_current_errors_save(self, epoch, counter_ratio, opt, errors,k...
    method plot_current_errors (line 126) | def plot_current_errors(self, epoch, counter_ratio, opt, errors):
    method print_current_errors (line 142) | def print_current_errors(self, epoch, i, errors, t, t2=-1, t2o=-1, fid...
    method save_images_simple (line 152) | def save_images_simple(self, webpage, images, names, in_txts, prefix='...
    method save_images (line 174) | def save_images(self, webpage, images, names, image_path, title=''):

FILE: train.py
  function parse_args (line 11) | def parse_args():

FILE: trainers/base.py
  class BaseTrainer (line 18) | class BaseTrainer(object):
    method __init__ (line 33) | def __init__(self,
    method _init_tensorboard (line 94) | def _init_tensorboard(self):
    method _init_hparams (line 113) | def _init_hparams(self):
    method _write_tensorboard (line 119) | def _write_tensorboard(self):
    method _write_loss_meters (line 137) | def _write_loss_meters(self):
    method test_everything (line 146) | def test_everything(self, train_dataset, val_dataset, current_epoch, c...
    method _set_custom_debug_parameter (line 168) | def _set_custom_debug_parameter(self):
    method _write_custom_meters (line 175) | def _write_custom_meters(self):
    method _write_to_meters (line 182) | def _write_to_meters(data, meters):
    method _flush_meters (line 187) | def _flush_meters(self, meters):
    method _pre_save_checkpoint (line 192) | def _pre_save_checkpoint(self):
    method save_checkpoint (line 199) | def save_checkpoint(self, current_epoch, current_iteration):
    method load_checkpoint (line 209) | def load_checkpoint(self, opt, which_iter=None):
    method _load_checkpoint (line 237) | def _load_checkpoint(self, checkpoint_path, resume=True):
    method start_of_epoch (line 260) | def start_of_epoch(self, current_epoch):
    method start_of_iteration (line 270) | def start_of_iteration(self, data, current_iteration):
    method end_of_iteration (line 285) | def end_of_iteration(self, data, current_epoch, current_iteration):
    method _print_current_errors (line 348) | def _print_current_errors(self):
    method end_of_epoch (line 360) | def end_of_epoch(self, data, val_dataset, current_epoch, current_itera...
    method write_data_tensorboard (line 403) | def write_data_tensorboard(self, data, epoch, iteration):
    method save_image (line 422) | def save_image(self, path, data):
    method write_metrics (line 442) | def write_metrics(self, data):
    method _get_save_path (line 458) | def _get_save_path(self, subdir, ext):
    method _compute_metrics (line 476) | def _compute_metrics(self, data, current_iteration):
    method _start_of_epoch (line 482) | def _start_of_epoch(self, current_epoch):
    method _start_of_iteration (line 490) | def _start_of_iteration(self, data, current_iteration):
    method _end_of_iteration (line 502) | def _end_of_iteration(self, data, current_epoch, current_iteration):
    method _end_of_epoch (line 513) | def _end_of_epoch(self, data, current_epoch, current_iteration):
    method _get_visualizations (line 523) | def _get_visualizations(self, data):
    method _init_loss (line 531) | def _init_loss(self, opt):
    method gen_forward (line 535) | def gen_forward(self, data):
    method test (line 540) | def test(self, data_loader, output_dir, current_iteration):
  function _save_checkpoint (line 643) | def _save_checkpoint(opt,

FILE: trainers/face_trainer.py
  class FaceTrainer (line 9) | class FaceTrainer(BaseTrainer):
    method __init__ (line 21) | def __init__(self, opt, net_G, opt_G, sch_G,
    method _init_loss (line 27) | def _init_loss(self, opt):
    method _assign_criteria (line 50) | def _assign_criteria(self, name, criterion, weight):
    method optimize_parameters (line 54) | def optimize_parameters(self, data):
    method _start_of_iteration (line 87) | def _start_of_iteration(self, data, current_iteration):
    method reset_trainer (line 93) | def reset_trainer(self):
    method _get_visualizations (line 96) | def _get_visualizations(self, data):
    method test (line 119) | def test(self, data_loader, output_dir, current_iteration=-1):
    method _compute_metrics (line 122) | def _compute_metrics(self, data, current_iteration):

FILE: util/cudnn.py
  function init_cudnn (line 6) | def init_cudnn(deterministic, benchmark):

FILE: util/distributed.py
  function init_dist (line 6) | def init_dist(local_rank, backend='nccl', **kwargs):
  function get_rank (line 15) | def get_rank():
  function get_world_size (line 24) | def get_world_size():
  function master_only (line 33) | def master_only(func):
  function is_master (line 45) | def is_master():
  function master_only_print (line 51) | def master_only_print(*args):
  function dist_reduce_tensor (line 56) | def dist_reduce_tensor(tensor):
  function dist_all_reduce_tensor (line 68) | def dist_all_reduce_tensor(tensor):
  function dist_all_gather_tensor (line 79) | def dist_all_gather_tensor(tensor):

FILE: util/flow_util.py
  function convert_flow_to_deformation (line 3) | def convert_flow_to_deformation(flow):
  function make_coordinate_grid (line 17) | def make_coordinate_grid(flow):
  function warp_image (line 41) | def warp_image(source_image, deformation):

FILE: util/init_weight.py
  function weights_init (line 4) | def weights_init(init_type='normal', gain=0.02, bias=None):

FILE: util/io.py
  function save_pilimage_in_jpeg (line 10) | def save_pilimage_in_jpeg(fullname, output_img):
  function save_intermediate_training_results (line 22) | def save_intermediate_training_results(
  function download_file_from_google_drive (line 44) | def download_file_from_google_drive(file_id, destination):
  function get_confirm_token (line 64) | def get_confirm_token(response):
  function save_response_content (line 79) | def save_response_content(response, destination):
  function get_checkpoint (line 96) | def get_checkpoint(checkpoint_path, url=''):

FILE: util/logging.py
  function get_date_uid (line 8) | def get_date_uid():
  function init_logging (line 16) | def init_logging(opt):
  function make_logging_dir (line 26) | def make_logging_dir(logdir, date_uid):

FILE: util/lpips.py
  function get_image_list (line 10) | def get_image_list(flist):
  function preprocess_path_for_deform_task (line 29) | def preprocess_path_for_deform_task(gt_path, distorted_path):
  class LPIPS (line 45) | class LPIPS():
    method __init__ (line 46) | def __init__(self, use_gpu=True):
    method __call__ (line 51) | def __call__(self, image_1, image_2):
    method calculate_from_disk (line 59) | def calculate_from_disk(self, gt_path, distorted_path, batch_size=64, ...

FILE: util/meters.py
  function sn_reshape_weight_to_matrix (line 16) | def sn_reshape_weight_to_matrix(weight):
  function get_weight_stats (line 28) | def get_weight_stats(mod, cfg, loss_id):
  function set_summary_writer (line 51) | def set_summary_writer(log_dir):
  function write_summary (line 63) | def write_summary(name, summary, step, hist=False):
  function add_hparams (line 77) | def add_hparams(hparam_dict=None, metric_dict=None):
  class Meter (line 103) | class Meter(object):
    method __init__ (line 114) | def __init__(self, name):
    method reset (line 119) | def reset(self):
    method write (line 124) | def write(self, value):
    method flush (line 129) | def flush(self, step):
    method write_image (line 144) | def write_image(self, img_grid, step):

FILE: util/misc.py
  function split_labels (line 11) | def split_labels(labels, label_lengths):
  function requires_grad (line 36) | def requires_grad(model, require=True):
  function to_device (line 50) | def to_device(data, device):
  function to_cuda (line 70) | def to_cuda(data):
  function to_cpu (line 79) | def to_cpu(data):
  function to_half (line 88) | def to_half(data):
  function to_float (line 106) | def to_float(data):
  function get_and_setattr (line 124) | def get_and_setattr(cfg, name, default):
  function get_nested_attr (line 141) | def get_nested_attr(cfg, attr_name, default):
  function gradient_norm (line 162) | def gradient_norm(model):
  function random_shift (line 177) | def random_shift(x, offset=0.05, mode='bilinear', padding_mode='reflecti...
  function truncated_gaussian (line 200) | def truncated_gaussian(threshold, size, seed=None, device=None):
  function apply_imagenet_normalization (line 215) | def apply_imagenet_normalization(input):

FILE: util/trainer.py
  function accumulate (line 12) | def accumulate(model1, model2, decay=0.999):
  function set_random_seed (line 19) | def set_random_seed(seed):
  function get_trainer (line 34) | def get_trainer(opt, net_G, net_G_ema, opt_G, sch_G, train_dataset):
  function get_model_optimizer_and_scheduler (line 42) | def get_model_optimizer_and_scheduler(opt):
  function _calculate_model_size (line 77) | def _calculate_model_size(model):
  function get_scheduler (line 89) | def get_scheduler(opt_opt, opt):
  function get_optimizer (line 112) | def get_optimizer(opt_opt, net):
  function get_optimizer_for_params (line 116) | def get_optimizer_for_params(opt_opt, params):
Condensed preview — 56 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (258K chars).
[
  {
    "path": ".gitmodules",
    "chars": 126,
    "preview": "[submodule \"Deep3DFaceRecon_pytorch\"]\n\tpath = Deep3DFaceRecon_pytorch\n\turl = https://github.com/sicxu/Deep3DFaceRecon_py"
  },
  {
    "path": "DatasetHelper.md",
    "chars": 1178,
    "preview": "### Extract 3DMM Coefficients for Videos\n\nWe provide scripts for extracting 3dmm coefficients for videos by using [DeepF"
  },
  {
    "path": "LICENSE.md",
    "chars": 17654,
    "preview": "## creative commons\n\n# Attribution-NonCommercial 4.0 International\n\nCreative Commons Corporation (“Creative Commons”) is"
  },
  {
    "path": "README.md",
    "chars": 7066,
    "preview": "<p align='center'>\n  <b>\n    <a href=\"https://renyurui.github.io/PIRender_web/\"> Website</a>\n    | \n    <a href=\"https:/"
  },
  {
    "path": "config/face.yaml",
    "chars": 1830,
    "preview": "# How often do you want to log the training stats.\n# network_list: \n#     gen: gen_optimizer\n#     dis: dis_optimizer\n\nd"
  },
  {
    "path": "config/face_demo.yaml",
    "chars": 1835,
    "preview": "# How often do you want to log the training stats.\n# network_list: \n#     gen: gen_optimizer\n#     dis: dis_optimizer\n\nd"
  },
  {
    "path": "config.py",
    "chars": 7994,
    "preview": "import collections\nimport functools\nimport os\nimport re\n\nimport yaml\nfrom util.distributed import master_only_print as p"
  },
  {
    "path": "data/__init__.py",
    "chars": 2105,
    "preview": "import importlib\n\nimport torch.utils.data\nfrom util.distributed import master_only_print as print\n\ndef find_dataset_usin"
  },
  {
    "path": "data/image_dataset.py",
    "chars": 2392,
    "preview": "import os\nimport glob\nimport time\nimport numpy as np\nfrom PIL import Image\n\nimport torch\nimport torchvision.transforms.f"
  },
  {
    "path": "data/vox_dataset.py",
    "chars": 4895,
    "preview": "import os\nimport lmdb\nimport random\nimport collections\nimport numpy as np\nfrom PIL import Image\nfrom io import BytesIO\n\n"
  },
  {
    "path": "data/vox_video_dataset.py",
    "chars": 4641,
    "preview": "import os\nimport lmdb\nimport random\nimport collections\nimport numpy as np\nfrom PIL import Image\nfrom io import BytesIO\n\n"
  },
  {
    "path": "generators/base_function.py",
    "chars": 14659,
    "preview": "import sys\nimport math\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch.autograd impor"
  },
  {
    "path": "generators/face_model.py",
    "chars": 4296,
    "preview": "import functools\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom util impor"
  },
  {
    "path": "inference.py",
    "chars": 4445,
    "preview": "import os\nimport cv2 \nimport lmdb\nimport math\nimport argparse\nimport numpy as np\nfrom io import BytesIO\nfrom PIL import "
  },
  {
    "path": "intuitive_control.py",
    "chars": 5670,
    "preview": "import os\nimport math\nimport argparse\nimport numpy as np\nfrom scipy.io import savemat,loadmat\n\nimport torch\nimport torch"
  },
  {
    "path": "loss/perceptual.py",
    "chars": 14861,
    "preview": "import torch\nimport torch.nn.functional as F\nimport torchvision\nfrom torch import nn\n\nfrom util.distributed import maste"
  },
  {
    "path": "requirements.txt",
    "chars": 1193,
    "preview": "absl-py==0.13.0\nbackcall==0.2.0\ncachetools==4.2.2\ncertifi==2021.5.30\ncharset-normalizer==2.0.6\ncycler==0.10.0\ndataclasse"
  },
  {
    "path": "scripts/coeff_detector.py",
    "chars": 3860,
    "preview": "import os\nimport glob\nimport numpy as np\nfrom os import makedirs, name\nfrom PIL import Image\nfrom tqdm import tqdm\n\nimpo"
  },
  {
    "path": "scripts/download_demo_dataset.sh",
    "chars": 143,
    "preview": "gdown https://drive.google.com/uc?id=1ruuLw5-0fpm6EREexPn3I_UQPmkrBoq9\nunzip -x ./vox_lmdb_demo.zip\nmkdir ./dataset\nmv v"
  },
  {
    "path": "scripts/download_weights.sh",
    "chars": 135,
    "preview": "gdown https://drive.google.com/uc?id=1-0xOf6g58OmtKtEWJlU3VlnfRqPN9Uq7\nunzip -x ./face.zip\nmkdir ./result\nmv face ./resu"
  },
  {
    "path": "scripts/extract_kp_videos.py",
    "chars": 3610,
    "preview": "import os\nimport cv2\nimport time\nimport glob\nimport argparse\nimport face_alignment\nimport numpy as np\nfrom PIL import Im"
  },
  {
    "path": "scripts/face_recon_images.py",
    "chars": 5136,
    "preview": "import os\nimport glob\nimport numpy as np\nfrom PIL import Image\nfrom tqdm import tqdm\nfrom scipy.io import savemat\n\nimpor"
  },
  {
    "path": "scripts/face_recon_videos.py",
    "chars": 5693,
    "preview": "import os\nimport cv2\nimport glob\nimport numpy as np\nfrom PIL import Image\nfrom tqdm import tqdm\nfrom scipy.io import sav"
  },
  {
    "path": "scripts/inference_options.py",
    "chars": 1157,
    "preview": "from .base_options import BaseOptions\n\n\nclass InferenceOptions(BaseOptions):\n    \"\"\"This class includes test options.\n\n "
  },
  {
    "path": "scripts/prepare_vox_lmdb.py",
    "chars": 6330,
    "preview": "import os\nimport cv2\nimport lmdb\nimport argparse\nimport multiprocessing\nimport numpy as np\n\nfrom glob import glob\nfrom i"
  },
  {
    "path": "third_part/PerceptualSimilarity/models/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "third_part/PerceptualSimilarity/models/base_model.py",
    "chars": 1723,
    "preview": "import os\nimport torch\nfrom torch.autograd import Variable\nfrom pdb import set_trace as st\nfrom IPython import embed\n\ncl"
  },
  {
    "path": "third_part/PerceptualSimilarity/models/dist_model.py",
    "chars": 13451,
    "preview": "\nfrom __future__ import absolute_import\n\nimport sys\nsys.path.append('..')\nsys.path.append('.')\nimport numpy as np\nimport"
  },
  {
    "path": "third_part/PerceptualSimilarity/models/models.py",
    "chars": 269,
    "preview": "from __future__ import absolute_import\n\ndef create_model(opt):\n    model = None\n    print(opt.model)\n    from .siam_mode"
  },
  {
    "path": "third_part/PerceptualSimilarity/models/networks_basic.py",
    "chars": 10421,
    "preview": "\nfrom __future__ import absolute_import\n\nimport sys\nsys.path.append('..')\nsys.path.append('.')\nimport torch\nimport torch"
  },
  {
    "path": "third_part/PerceptualSimilarity/models/pretrained_networks.py",
    "chars": 6559,
    "preview": "from collections import namedtuple\nimport torch\nfrom torchvision import models\nfrom IPython import embed\n\nclass squeezen"
  },
  {
    "path": "third_part/PerceptualSimilarity/util/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "third_part/PerceptualSimilarity/util/html.py",
    "chars": 2023,
    "preview": "import dominate\nfrom dominate.tags import *\nimport os\n\n\nclass HTML:\n    def __init__(self, web_dir, title, image_subdir="
  },
  {
    "path": "third_part/PerceptualSimilarity/util/util.py",
    "chars": 14036,
    "preview": "from __future__ import print_function\n\nimport numpy as np\nfrom PIL import Image\nimport inspect\nimport re\nimport numpy as"
  },
  {
    "path": "third_part/PerceptualSimilarity/util/visualizer.py",
    "chars": 8602,
    "preview": "import numpy as np\nimport os\nimport time\nfrom . import util\nfrom . import html\n# from pdb import set_trace as st\nimport "
  },
  {
    "path": "train.py",
    "chars": 2815,
    "preview": "import argparse\n\nimport data as Dataset\nfrom config import Config\nfrom util.logging import init_logging, make_logging_di"
  },
  {
    "path": "trainers/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "trainers/base.py",
    "chars": 27096,
    "preview": "import os\nimport time\nimport glob\nfrom tqdm import tqdm\n\nimport torch\nimport torchvision\nfrom torch import nn\n\nfrom util"
  },
  {
    "path": "trainers/face_trainer.py",
    "chars": 6224,
    "preview": "import math\n\nimport torch\n\nfrom trainers.base import BaseTrainer\nfrom util.trainer import accumulate, get_optimizer\nfrom"
  },
  {
    "path": "util/cudnn.py",
    "chars": 689,
    "preview": "import torch.backends.cudnn as cudnn\n\nfrom util.distributed import master_only_print as print\n\n\ndef init_cudnn(determini"
  },
  {
    "path": "util/distributed.py",
    "chars": 2217,
    "preview": "import functools\n\nimport torch\nimport torch.distributed as dist\n\ndef init_dist(local_rank, backend='nccl', **kwargs):\n  "
  },
  {
    "path": "util/flow_util.py",
    "chars": 1809,
    "preview": "import torch\n\ndef convert_flow_to_deformation(flow):\n    r\"\"\"convert flow fields to deformations.\n\n    Args:\n        flo"
  },
  {
    "path": "util/init_weight.py",
    "chars": 2214,
    "preview": "from torch.nn import init\n\n\ndef weights_init(init_type='normal', gain=0.02, bias=None):\n    r\"\"\"Initialize weights in th"
  },
  {
    "path": "util/io.py",
    "chars": 3597,
    "preview": "import os\n\nimport requests\nimport torch.distributed as dist\nimport torchvision.utils\n\nfrom util.distributed import is_ma"
  },
  {
    "path": "util/logging.py",
    "chars": 1411,
    "preview": "import os\nimport datetime\n\nfrom util.meters import set_summary_writer\nfrom util.distributed import master_only_print as "
  },
  {
    "path": "util/lpips.py",
    "chars": 3994,
    "preview": "import os \nimport glob\nimport numpy as np\nfrom imageio import imread\n\nimport torch\n\nfrom third_part.PerceptualSimilarity"
  },
  {
    "path": "util/meters.py",
    "chars": 4622,
    "preview": "import math\n\nimport torch\nfrom torch.utils.tensorboard import SummaryWriter\nfrom torch.utils.tensorboard.summary import "
  },
  {
    "path": "util/misc.py",
    "chars": 6676,
    "preview": "\"\"\"Miscellaneous utils.\"\"\"\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch\nimport torch.nn.function"
  },
  {
    "path": "util/trainer.py",
    "chars": 4023,
    "preview": "import random\nimport importlib\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nfrom torch.optim import Adam, lr_s"
  }
]

// ... and 7 more files (download for full content)

About this extraction

This page contains the full source code of the RenYurui/PIRender GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 56 files (241.6 KB), approximately 60.7k tokens, and a symbol index with 389 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!