[
  {
    "path": ".gitignore",
    "content": "experiments/\nresults/\ntb_logger/\n*.pyc\n.vscode/\ndownload\ndownload/*\n*.sh\n"
  },
  {
    "path": "README.md",
    "content": "# Talk-to-Edit (ICCV2021)\n\n![Python 3.7](https://img.shields.io/badge/python-3.7-green.svg?style=plastic)\n![pytorch 1.6.0](https://img.shields.io/badge/pytorch-1.6.0-green.svg?style=plastic)\n\nThis repository contains the implementation of the following paper:\n> **Talk-to-Edit: Fine-Grained Facial Editing via Dialog**<br>\n> Yuming Jiang<sup>∗</sup>, Ziqi Huang<sup>∗</sup>, Xingang Pan, Chen Change Loy, Ziwei Liu<br>\n> IEEE International Conference on Computer Vision (**ICCV**), 2021<br>\n\n[[Paper](https://arxiv.org/abs/2109.04425)]\n[[Project Page](https://www.mmlab-ntu.com/project/talkedit/)]\n[[CelebA-Dialog Dataset](https://github.com/ziqihuangg/CelebA-Dialog)]\n[[Poster](https://drive.google.com/file/d/1KaojezBNqDrkwcT0yOkvAgqW1grwUDed/view?usp=sharing)]\n[[Video](https://www.youtube.com/watch?v=ZKMkQhkMXPI)]\n\nYou can try our colab demo here. Enjoy!\n1. Editing with dialog: <a href=\"https://colab.research.google.com/drive/14inhJjrNIj_SdhIA7NEtGS2kKOWXXSjb?usp=sharing\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"google colab logo\"></a>\n1. Editing without dialog: <a href=\"https://colab.research.google.com/drive/1mO5NmlPi4YV359cPkLZnOpG_kShQi_hN?usp=sharing\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"google colab logo\"></a>\n\n## Overview\n![overall_structure](./assets/teaser.png)\n\n\n\n## Dependencies and Installation\n\n1. Clone Repo\n\n   ```bash\n   git clone git@github.com:yumingj/Talk-to-Edit.git\n   ```\n\n1. Create Conda Environment and Install Dependencies\n\n   ```bash\n   conda env create -f environment.yml\n   conda activate talk_edit\n   ```\n   - Python >= 3.7\n   - PyTorch >= 1.6\n   - CUDA 10.1\n   - GCC 5.4.0\n\n\n## Get Started\n\n## Editing\n\nWe provide scripts for editing using our pretrained models.\n\n1. 
First, download the pretrained models from this [link](https://drive.google.com/drive/folders/1W9dvjz8bUolEIG524o8ZvM62uEWKJ5do?usp=sharing) and put them under `./download/pretrained_models` as follows:\n   ```\n   ./download/pretrained_models\n   ├── 1024_field\n   │   ├── Bangs.pth\n   │   ├── Eyeglasses.pth\n   │   ├── No_Beard.pth\n   │   ├── Smiling.pth\n   │   └── Young.pth\n   ├── 128_field\n   │   ├── Bangs.pth\n   │   ├── Eyeglasses.pth\n   │   ├── No_Beard.pth\n   │   ├── Smiling.pth\n   │   └── Young.pth\n   ├── arcface_resnet18_110.pth\n   ├── language_encoder.pth.tar\n   ├── predictor_1024.pth.tar\n   ├── predictor_128.pth.tar\n   ├── stylegan2_1024.pth\n   ├── stylegan2_128.pt\n   ├── StyleGAN2_FFHQ1024_discriminator.pth\n   └── eval_predictor.pth.tar\n   ```\n\n1. You can try pure image editing without dialog instructions:\n   ```bash\n   python editing_wo_dialog.py \\\n      --opt ./configs/editing/editing_wo_dialog.yml \\\n      --attr 'Bangs' \\\n      --target_val 5\n   ```\n   The editing results will be saved in `./results`.\n\n   You can change `attr` to one of the following attributes: `Bangs`, `Eyeglasses`, `Beard`, `Smiling`, and `Young` (i.e., Age). The `target_val` can be any integer in `[0, 1, 2, 3, 4, 5]`.\n\n1. You can also try dialog-based editing, where you talk to the system through the command prompt:\n   ```bash\n   python editing_with_dialog.py --opt ./configs/editing/editing_with_dialog.yml\n   ```\n   The editing results will be saved in `./results`.\n\n   **How to talk to the system:**\n   * Our system is able to edit five facial attributes: `Bangs`, `Eyeglasses`, `Beard`, `Smiling`, and `Young` (i.e., Age).\n   * When prompted with `\"Enter your request (Press enter when you finish):\"`, you can enter an editing request about one of the five attributes. For example, you can say `\"Make the bangs longer.\"`\n   * To respond to the system's feedback, just talk as if you were talking to a real person. 
For example, if the system asks `\"Is the length of the bangs just right?\"` after one round of editing, you can say things like `\"Yes.\"` / `\"No.\"` / `\"Yes, and I also want her to smile more happily.\"`\n   * To end the conversation, just tell the system things like `\"That's all\"` / `\"Nothing else, thank you.\"`\n\n1. By default, the above editing will be performed on the teaser image. You may change the image to be edited in two ways: 1) change `line 11: latent_code_index` to other values ranging from `0` to `99`; 2) set `line 10: latent_code_path` to `~`, so that an image will be randomly generated.\n\n1. If you want to try editing on real images, you may download the real images from this [link](https://drive.google.com/drive/folders/1BunrwvlwCBZJnb9QqeUp_uIXMxeXXJrY?usp=sharing) and put them under `./download/real_images`. You could also provide other real images of your choice. Set `line 12: img_path` in `editing_with_dialog.yml` or `editing_wo_dialog.yml` to the path of the real image, and set `line 11: is_real_image` to `True`.\n\n1. You can switch the default image size to `128 x 128` by setting `line 3: img_res` to `128` in the config files.\n\n\n## Train the Semantic Field\n\n\n1. To train the Semantic Field, we first prepare a number of sampled latent codes and then use the attribute predictor to predict the facial attributes of their corresponding images. The attribute predictor is trained using the fine-grained annotations in the [CelebA-Dialog](https://github.com/ziqihuangg/CelebA-Dialog) dataset. Here, we provide the latent codes we used. 
You can download the train data from this [link](https://drive.google.com/drive/folders/1CYBpLIwts3ZVFiFAPb4TTnqYH3NBR63p?usp=sharing) and put them under `./download/train_data` as follows:\n   ```\n   ./download/train_data\n   ├── 1024\n   │   ├── Bangs\n   │   ├── Eyeglasses\n   │   ├── No_Beard\n   │   ├── Smiling\n   │   └── Young\n   └── 128\n       ├── Bangs\n       ├── Eyeglasses\n       ├── No_Beard\n       ├── Smiling\n       └── Young\n   ```\n\n1. We will also use some editing latent codes to monitor the training phase. You can download the editing latent codes from this [link](https://drive.google.com/drive/folders/1G-0srCePEXcPq9HY38Il_4FTVHX_rOa-?usp=sharing) and put them under `./download/editing_data` as follows:\n\n   ```\n   ./download/editing_data\n   ├── 1024\n   │   ├── Bangs.npz.npy\n   │   ├── Eyeglasses.npz.npy\n   │   ├── No_Beard.npz.npy\n   │   ├── Smiling.npz.npy\n   │   └── Young.npz.npy\n   └── 128\n       ├── Bangs.npz.npy\n       ├── Eyeglasses.npz.npy\n       ├── No_Beard.npz.npy\n       ├── Smiling.npz.npy\n       └── Young.npz.npy\n   ```\n\n1. All logging files generated in the training process, *e.g.*, log messages, checkpoints, and snapshots, will be saved to the `./experiments` and `./tb_logger` directories.\n\n1. There are 10 configuration files under `./configs/train`, named in the format of `field_<IMAGE_RESOLUTION>_<ATTRIBUTE_NAME>`.\nChoose the configuration file corresponding to the attribute and resolution you want.\n\n1. For example, to train the semantic field which edits the attribute `Bangs` at `128x128` image resolution, simply run:\n   ```bash\n   python train.py --opt ./configs/train/field_128_Bangs.yml\n   ```\n\n\n## Quantitative Results\n\nWe provide code for the quantitative results shown in Table 1. Here we use `Bangs` at `128x128` resolution as an example.\n\n1. 
Use the trained semantic field to edit images.\n   ```bash\n   python editing_quantitative.py \\\n   --opt ./configs/train/field_128_bangs.yml \\\n   --pretrained_path ./download/pretrained_models/128_field/Bangs.pth\n   ```\n\n2. Evaluate the edited images using quantitative metrics. Change `image_num` according to the attribute: `Bangs: 148`, `Eyeglasses: 82`, `Beard: 129`, `Smiling: 140`, `Young: 61`.\n   ```bash\n   python quantitative_results.py \\\n   --attribute Bangs \\\n   --work_dir ./results/field_128_bangs \\\n   --image_dir ./results/field_128_bangs/visualization \\\n   --image_num 148\n   ```\n\n## Qualitative Results\n\n![result](./assets/1024_results_updated.png)\n\n\n## CelebA-Dialog Dataset\n\n![result](./assets/celeba_dialog.png)\n\nOur [**CelebA-Dialog Dataset**](https://github.com/ziqihuangg/CelebA-Dialog) is available for [Download](https://drive.google.com/drive/folders/18nejI_hrwNzWyoF6SW8bL27EYnM4STAs?usp=sharing).\n\n**CelebA-Dialog** is a large-scale visual-language face dataset with the following features:\n- Facial images are annotated with rich **fine-grained labels**, which classify one attribute into multiple degrees according to its semantic meaning.\n- Each image is accompanied by **captions** describing the attributes and a **user request** sample.\n\n\n![result](./assets/dataset.png)\n\nThe dataset can be employed as the training and test sets for the following computer vision tasks: fine-grained facial attribute recognition, fine-grained facial manipulation, text-based facial generation and manipulation, face image captioning, and broader natural-language-based facial recognition and manipulation tasks.\n\n\n## Citation\n\n   If you find our repo useful for your research, please consider citing our paper:\n\n   ```bibtex\n   @inproceedings{jiang2021talk,\n     title={Talk-to-Edit: Fine-Grained Facial Editing via Dialog},\n     author={Jiang, Yuming and Huang, Ziqi and Pan, Xingang and Loy, Chen Change and Liu, 
Ziwei},\n     booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},\n     pages={13799--13808},\n     year={2021}\n   }\n\n   @article{jiang2023talk,\n     title={Talk-to-edit: Fine-grained 2d and 3d facial editing via dialog},\n     author={Jiang, Yuming and Huang, Ziqi and Wu, Tianxing and Pan, Xingang and Loy, Chen Change and Liu, Ziwei},\n     journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},\n     year={2023},\n     publisher={IEEE}\n   }\n   ```\n\n## Contact\n\nIf you have any questions, please feel free to contact us via `yuming002@ntu.edu.sg` or `hu0007qi@ntu.edu.sg`.\n\n## Acknowledgement\n\nThe codebase is maintained by [Yuming Jiang](https://yumingj.github.io/) and [Ziqi Huang](https://ziqihuangg.github.io/).\n\nPart of the code is borrowed from [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch), [IEP](https://github.com/facebookresearch/clevr-iep) and [face-attribute-prediction](https://github.com/d-li14/face-attribute-prediction).\n\n"
  },
  {
    "path": "configs/attributes_5.json",
    "content": "\n{\n    \"attr_info\":{\n        \"6\": {\n            \"name\": \"Bangs\",\n            \"value\":[0, 1, 2, 3, 4, 5],\n            \"idx_scale\": 1,\n            \"idx_bias\": 0\n        },\n        \"16\": {\n            \"name\": \"Eyeglasses\",\n            \"value\":[0, 1, 2, 3, 4, 5],\n            \"idx_scale\": 1,\n            \"idx_bias\": 0\n        },    \n        \"25\": {\n            \"name\": \"No_Beard\",\n            \"value\":[0, 1, 2, 3, 4, 5],\n            \"idx_scale\": -1,\n            \"idx_bias\": 5\n        },\n        \"32\": {\n            \"name\": \"Smiling\",\n            \"value\":[0, 1, 2, 3, 4, 5],\n            \"idx_scale\": 1,\n            \"idx_bias\": 0\n        },\n        \"40\": {\n            \"name\": \"Young\",\n            \"value\":[0, 1, 2, 3, 4, 5],\n            \"idx_scale\": -1,\n            \"idx_bias\": 5\n        }\n    },\n    \"newIdx_to_attrIdx\":{\n        \"0\": \"6\",\n        \"1\": \"16\",\n        \"2\": \"25\",\n        \"3\": \"32\",\n        \"4\": \"40\"\n    },\n    \"newIdx_to_attrName\":{\n        \"0\": \"Bangs\",\n        \"1\": \"Eyeglasses\",\n        \"2\": \"No_Beard\",\n        \"3\": \"Smiling\",\n        \"4\": \"Young\"\n    },   \n    \"attrName_to_newIdx\":{\n        \"Bangs\": \"0\",\n        \"Eyeglasses\": \"1\",\n        \"No_Beard\": \"2\",\n        \"Smiling\": \"3\",\n        \"Young\": \"4\"\n    },   \n    \"attrIdx_to_newIdx\":{\n        \"6\": 0,\n        \"16\": 1,\n        \"25\": 2,\n        \"32\": 3,\n        \"40\": 4\n    }    \n}"
  },
  {
    "path": "configs/editing/editing_with_dialog.yml",
    "content": "name: dialog_editing\n\nimg_res: 1024 # 128\n\n# latent code\nlatent_code_path: ./download/editing_data/teaser_latent_code.npz.npy\nlatent_code_index: 38\n\n# inversion\ninversion:\n  is_real_image: False # False\n  img_path: ./download/real_images/annehathaway.png\n  crop_img: True\n  device: cuda\n  img_mse_weight: 1.0\n  step: 600\n  noise: 0.05\n  noise_ramp: 0.75\n  lr: 0.1\n  lr_gen: !!float 1e-4\n\nuse_tb_logger: true\nset_CUDA_VISIBLE_DEVICES: ~\ngpu_ids: [3]\n\nattribute: Eyeglasses\n\nmodel_type: FieldFunctionModel\nfix_layers: true\nreplaced_layers_128: 8\nreplaced_layers_1024: 10\n\nmanual_seed: 2021\n\n# editing configs\nconfidence_thresh: 0\nmax_cls_num: 5\nmin_cls_num: 0\nmax_trials_num: 100\nprint_every: False\ntransform_z_to_w: False\n\n# field_function configs\nnum_layer: 8\nhidden_dim: 512\nleaky_relu_neg_slope: 0.2\n\n# predictor\nattr_file: ./configs/attributes_5.json\nbaseline: classification\nuse_sigmoid: True\ngt_remapping_file: ~\npredictor_ckpt_128: ./download/pretrained_models/predictor_128.pth.tar\npredictor_ckpt_1024: ./download/pretrained_models/predictor_1024.pth.tar\n\n# stylegan configs\nlatent_dim: 512\nn_mlp: 8\nchannel_multiplier_128: 1\nchannel_multiplier_1024: 2\ngenerator_ckpt_128: ./download/pretrained_models/stylegan2_128.pt\ngenerator_ckpt_1024: ./download/pretrained_models/stylegan2_1024.pth\nlatent_space: w\n\n# ---------- Dialog Editing -----------\n\nhas_dialog: True\ndevice_name: gpu\n\n# pretrained field\npretrained_field_128:\n  Bangs: ./download/pretrained_models/128_field/Bangs.pth\n  Eyeglasses: ./download/pretrained_models/128_field/Eyeglasses.pth\n  No_Beard: ./download/pretrained_models/128_field/No_Beard.pth\n  Smiling: ./download/pretrained_models/128_field/Smiling.pth\n  Young: ./download/pretrained_models/128_field/Young.pth\npretrained_field_1024:\n  Bangs: ./download/pretrained_models/1024_field/Bangs.pth\n  Eyeglasses: ./download/pretrained_models/1024_field/Eyeglasses.pth\n  No_Beard: 
./download/pretrained_models/1024_field/No_Beard.pth\n  Smiling: ./download/pretrained_models/1024_field/Smiling.pth\n  Young: ./download/pretrained_models/1024_field/Young.pth\n\nattr_to_idx:\n  Bangs: 0\n  Eyeglasses: 1\n  No_Beard: 2\n  Smiling: 3\n  Young: 4\n\n# language template files set up\nfeedback_templates_file: ./language/templates/feedback.json\nmetadata_file: ./language/templates/metadata_fsm.json\npool_file: ./language/templates/pool.json\nsystem_mode_file: ./language/templates/system_mode.json\ninput_vocab_file: ./language/templates/vocab.json\n# dialog setting\npostfix_prob: 0.3\nwhether_enough_general_prob: 0.2\nallow_unknown: 1\nverbose: 0\n\n# pretrained language encoder\npretrained_language_encoder: ./download/pretrained_models/language_encoder.pth.tar\nlanguage_encoder:\n  word_embedding_dim: 300\n  text_embed_size: 1024\n  linear_hidden_size: 256\n  linear_dropout_rate: 0\n"
  },
  {
    "path": "configs/editing/editing_wo_dialog.yml",
    "content": "name: editing_wo_dialog\n\nimg_res: 1024 # 128\n\n# latent code\nlatent_code_path: ./download/editing_data/teaser_latent_code.npz.npy\nlatent_code_index: 38\n\n# inversion\ninversion:\n  is_real_image: False # False\n  img_path: ./download/real_images/annehathaway.png\n  crop_img: True\n  device: cuda\n  img_mse_weight: 1.0\n  step: 600\n  noise: 0.05\n  noise_ramp: 0.75\n  lr: 0.1\n  lr_gen: !!float 1e-4\n\nuse_tb_logger: true\nset_CUDA_VISIBLE_DEVICES: ~\ngpu_ids: [3]\n\nattribute: Eyeglasses\n\nmodel_type: FieldFunctionModel\nfix_layers: true\nreplaced_layers_128: 8\nreplaced_layers_1024: 10\n\nmanual_seed: 2021\n\n# editing configs\nconfidence_thresh: 0\nmax_cls_num: 5\nmin_cls_num: 0\nmax_trials_num: 100\nprint_every: False\ntransform_z_to_w: False\n\n# field_function configs\nnum_layer: 8\nhidden_dim: 512\nleaky_relu_neg_slope: 0.2\n\n# predictor\nattr_file: ./configs/attributes_5.json\nbaseline: classification\nuse_sigmoid: True\ngt_remapping_file: ~\npredictor_ckpt_128: ./download/pretrained_models/predictor_128.pth.tar\npredictor_ckpt_1024: ./download/pretrained_models/predictor_1024.pth.tar\n\n# stylegan configs\nlatent_dim: 512\nn_mlp: 8\nchannel_multiplier_128: 1\nchannel_multiplier_1024: 2\ngenerator_ckpt_128: ./download/pretrained_models/stylegan2_128.pt\ngenerator_ckpt_1024: ./download/pretrained_models/stylegan2_1024.pth\nlatent_space: w\n\n# ---------- Dialog Editing -----------\nhas_dialog: False\ndevice_name: gpu\n\n# pretrained field\npretrained_field_128:\n  Bangs: ./download/pretrained_models/128_field/Bangs.pth\n  Eyeglasses: ./download/pretrained_models/128_field/Eyeglasses.pth\n  No_Beard: ./download/pretrained_models/128_field/No_Beard.pth\n  Smiling: ./download/pretrained_models/128_field/Smiling.pth\n  Young: ./download/pretrained_models/128_field/Young.pth\npretrained_field_1024:\n  Bangs: ./download/pretrained_models/1024_field/Bangs.pth\n  Eyeglasses: ./download/pretrained_models/1024_field/Eyeglasses.pth\n  No_Beard: 
./download/pretrained_models/1024_field/No_Beard.pth\n  Smiling: ./download/pretrained_models/1024_field/Smiling.pth\n  Young: ./download/pretrained_models/1024_field/Young.pth\n\nattr_to_idx:\n  Bangs: 0\n  Eyeglasses: 1\n  No_Beard: 2\n  Smiling: 3\n  Young: 4\n"
  },
  {
    "path": "configs/train/field_1024_bangs.yml",
    "content": "name: field_1024_bangs\nuse_tb_logger: true\nset_CUDA_VISIBLE_DEVICES: ~\ngpu_ids: [3]\n\nattribute: Bangs\n\nmodel_type: FieldFunctionModel\nfix_layers: true\nreplaced_layers: 10\n\n# dataset configs\nbatch_size: 8\nnum_workers: 8\ninput_latent_dir: ./download/train_data/1024/Bangs\nediting_latent_code_path: ./download/editing_data/1024/Bangs.npz.npy\nnum_attr: 5\nval_on_train_subset: true\nval_on_valset: true\n\n# training configs\nval_freq: 1\nprint_freq: 100\nweight_decay: 0\nmanual_seed: 2021\nnum_epochs: 500\nlr: !!float 1e-4\nlr_decay: step\ngamma: 0.1\nstep: 100\n\n# editing configs\nconfidence_thresh: 0.8\nmax_cls_num: 5\nmax_trials_num: 100\nprint_every: False\n\n# field_function configs\nnum_layer: 8\nhidden_dim: 512\nleaky_relu_neg_slope: 0.2\n\n# loss configs\n# predictor loss\nedited_attribute_weight: 1.0\nattr_file: ./configs/attributes_5.json\npredictor_ckpt: ./download/pretrained_models/predictor_1024.pth.tar\n\n# arcface loss\npretrained_arcface: ./download/pretrained_models/arcface_resnet18_110.pth\narcface_weight: 5.0\narcface_loss_type: l1\n# discriminator loss\ndisc_weight: 1.0\ndiscriminator_ckpt: ./download/pretrained_models/StyleGAN2_FFHQ1024_discriminator.pth\n\n# stylegan configs\nimg_res: 1024\nlatent_dim: 512\nn_mlp: 8\nchannel_multiplier: 2\ngenerator_ckpt: ./download/pretrained_models/stylegan2_1024.pth\nlatent_space: w"
  },
  {
    "path": "configs/train/field_1024_beard.yml",
    "content": "name: field_1024_beard\nuse_tb_logger: true\nset_CUDA_VISIBLE_DEVICES: ~\ngpu_ids: [3]\n\nattribute: No_Beard\n\nmodel_type: FieldFunctionModel\nfix_layers: true\nreplaced_layers: 10\n\n# dataset configs\nbatch_size: 8\nnum_workers: 8\ninput_latent_dir: ./download/train_data/1024/No_Beard\nediting_latent_code_path: ./download/editing_data/1024/No_Beard.npz.npy\nnum_attr: 5\nval_on_train_subset: true\nval_on_valset: true\n\n# training configs\nval_freq: 1\nprint_freq: 100\nweight_decay: 0\nmanual_seed: 2021\nnum_epochs: 30\nlr: !!float 1e-4\nlr_decay: step\ngamma: 0.1\nstep: 100\n\n# editing configs\nconfidence_thresh: 0.8\nmax_cls_num: 5\nmax_trials_num: 100\nprint_every: False\n\n# field_function configs\nnum_layer: 8\nhidden_dim: 512\nleaky_relu_neg_slope: 0.2\n\n# loss configs\n# predictor loss\nedited_attribute_weight: 1.0\nattr_file: ./configs/attributes_5.json\npredictor_ckpt: ./download/pretrained_models/predictor_1024.pth.tar\n\n# arcface loss\npretrained_arcface: ./download/pretrained_models/arcface_resnet18_110.pth\narcface_weight: 10.0\narcface_loss_type: l1\n# discriminator loss\ndisc_weight: 1.0\ndiscriminator_ckpt: ./download/pretrained_models/StyleGAN2_FFHQ1024_discriminator.pth\n\n# stylegan configs\nimg_res: 1024\nlatent_dim: 512\nn_mlp: 8\nchannel_multiplier: 2\ngenerator_ckpt: ./download/pretrained_models/stylegan2_1024.pth\nlatent_space: w"
  },
  {
    "path": "configs/train/field_1024_eyeglasses.yml",
    "content": "name: field_1024_eyeglasses\nuse_tb_logger: true\nset_CUDA_VISIBLE_DEVICES: ~\ngpu_ids: [3]\n\nattribute: Eyeglasses\n\nmodel_type: FieldFunctionModel\nfix_layers: true\nreplaced_layers: 10\n\n# dataset configs\nbatch_size: 8\nnum_workers: 8\ninput_latent_dir: ./download/train_data/1024/Eyeglasses\nediting_latent_code_path: ./download/editing_data/1024/Eyeglasses.npz.npy\nnum_attr: 5\nval_on_train_subset: true\nval_on_valset: true\n\n# training configs\nval_freq: 1\nprint_freq: 100\nweight_decay: 0\nmanual_seed: 2021\nnum_epochs: 30\nlr: !!float 1e-4\nlr_decay: step\ngamma: 0.1\nstep: 100\n\n# editing configs\nconfidence_thresh: 0.8\nmax_cls_num: 5\nmax_trials_num: 100\nprint_every: False\n\n# field_function configs\nnum_layer: 8\nhidden_dim: 512\nleaky_relu_neg_slope: 0.2\n\n# loss configs\n# predictor loss\nedited_attribute_weight: 1.0\nattr_file: ./configs/attributes_5.json\npredictor_ckpt: ./download/pretrained_models/predictor_1024.pth.tar\n\n# arcface loss\npretrained_arcface: ./download/pretrained_models/arcface_resnet18_110.pth\narcface_weight: 10.0\narcface_loss_type: l1\n# discriminator loss\ndisc_weight: 1.0\ndiscriminator_ckpt: ./download/pretrained_models/StyleGAN2_FFHQ1024_discriminator.pth\n\n# stylegan configs\nimg_res: 1024\nlatent_dim: 512\nn_mlp: 8\nchannel_multiplier: 2\ngenerator_ckpt: ./download/pretrained_models/stylegan2_1024.pth\nlatent_space: w\n"
  },
  {
    "path": "configs/train/field_1024_smiling.yml",
    "content": "name: field_1024_smiling\nuse_tb_logger: true\nset_CUDA_VISIBLE_DEVICES: ~\ngpu_ids: [3]\n\nattribute: Smiling\n\nmodel_type: FieldFunctionModel\nfix_layers: true\nreplaced_layers: 10\n\n# dataset configs\nbatch_size: 8\nnum_workers: 8\ninput_latent_dir: ./download/train_data/1024/Smiling\nediting_latent_code_path: ./download/editing_data/1024/Smiling.npz.npy\nnum_attr: 5\nval_on_train_subset: true\nval_on_valset: true\n\n# training configs\nval_freq: 1\nprint_freq: 100\nweight_decay: 0\nmanual_seed: 2021\nnum_epochs: 30\nlr: !!float 1e-4\nlr_decay: step\ngamma: 0.1\nstep: 100\n\n# editing configs\nconfidence_thresh: 0.8\nmax_cls_num: 5\nmax_trials_num: 100\nprint_every: False\n\n# field_function configs\nnum_layer: 8\nhidden_dim: 512\nleaky_relu_neg_slope: 0.2\n\n# loss configs\n# predictor loss\nedited_attribute_weight: 1.0\nattr_file: ./configs/attributes_5.json\npredictor_ckpt: ./download/pretrained_models/predictor_1024.pth.tar\n\n# arcface loss\npretrained_arcface: ./download/pretrained_models/arcface_resnet18_110.pth\narcface_weight: 5.0\narcface_loss_type: l1\n# discriminator loss\ndisc_weight: 1.0\ndiscriminator_ckpt: ./download/pretrained_models/StyleGAN2_FFHQ1024_discriminator.pth\n\n# stylegan configs\nimg_res: 1024\nlatent_dim: 512\nn_mlp: 8\nchannel_multiplier: 2\ngenerator_ckpt: ./download/pretrained_models/stylegan2_1024.pth\nlatent_space: w"
  },
  {
    "path": "configs/train/field_1024_young.yml",
    "content": "name: field_1024_young\nuse_tb_logger: true\nset_CUDA_VISIBLE_DEVICES: ~\ngpu_ids: [3]\n\nattribute: Young\n\nmodel_type: FieldFunctionModel\nfix_layers: true\nreplaced_layers: 10\n\n# dataset configs\nbatch_size: 8\nnum_workers: 8\ninput_latent_dir: ./download/train_data/1024/Young\nediting_latent_code_path: ./download/editing_data/1024/Young.npz.npy\nnum_attr: 5\nval_on_train_subset: true\nval_on_valset: true\n\n# training configs\nval_freq: 1\nprint_freq: 100\nweight_decay: 0\nmanual_seed: 2021\nnum_epochs: 30\nlr: !!float 1e-4\nlr_decay: step\ngamma: 0.1\nstep: 100\n\n# editing configs\nconfidence_thresh: 0.8\nmax_cls_num: 5\nmax_trials_num: 100\nprint_every: False\n\n# field_function configs\nnum_layer: 8\nhidden_dim: 512\nleaky_relu_neg_slope: 0.2\n\n# loss configs\n# predictor loss\nedited_attribute_weight: 1.0\nattr_file: ./configs/attributes_5.json\npredictor_ckpt: ./download/pretrained_models/predictor_1024.pth.tar\n\n# arcface loss\npretrained_arcface: ./download/pretrained_models/arcface_resnet18_110.pth\narcface_weight: 10.0\narcface_loss_type: l1\n# discriminator loss\ndisc_weight: 1.0\ndiscriminator_ckpt: ./download/pretrained_models/StyleGAN2_FFHQ1024_discriminator.pth\n\n# stylegan configs\nimg_res: 1024\nlatent_dim: 512\nn_mlp: 8\nchannel_multiplier: 2\ngenerator_ckpt: ./download/pretrained_models/stylegan2_1024.pth\nlatent_space: w"
  },
  {
    "path": "configs/train/field_128_bangs.yml",
    "content": "name: field_128_bangs\nuse_tb_logger: true\nset_CUDA_VISIBLE_DEVICES: ~\ngpu_ids: [3]\n\nattribute: Bangs\n\nmodel_type: FieldFunctionModel\nfix_layers: true\nreplaced_layers: 8\n\n# dataset configs\nbatch_size: 32\nnum_workers: 8\ninput_latent_dir: ./download/train_data/128/Bangs\nediting_latent_code_path: ./download/editing_data/128/Bangs.npz.npy\nnum_attr: 5\nval_on_train_subset: true\nval_on_valset: true\n\n# training configs\nval_freq: 1\nprint_freq: 100\nweight_decay: 0\nmanual_seed: 2021\nnum_epochs: 30\nlr: !!float 1e-4\nlr_decay: step\ngamma: 0.1\nstep: 100\n\n# editing configs\nconfidence_thresh: 0.8\nmax_cls_num: 5\nmax_trials_num: 100\nprint_every: False\n\n# field_function configs\nnum_layer: 8\nhidden_dim: 512\nleaky_relu_neg_slope: 0.2\n\n# loss configs\n# predictor loss\nedited_attribute_weight: 1.0\nattr_file: ./configs/attributes_5.json\npredictor_ckpt: ./download/pretrained_models/predictor_128.pth.tar\n\n# arcface loss\npretrained_arcface: ./download/pretrained_models/arcface_resnet18_110.pth\narcface_weight: 5.0\narcface_loss_type: l1\n# discriminator loss\ndisc_weight: 1.0\ndiscriminator_ckpt: ./download/pretrained_models/stylegan2_128.pt\n\n# stylegan configs\nimg_res: 128\nlatent_dim: 512\nn_mlp: 8\nchannel_multiplier: 1\ngenerator_ckpt: ./download/pretrained_models/stylegan2_128.pt\nlatent_space: w"
  },
  {
    "path": "configs/train/field_128_beard.yml",
    "content": "name: field_128_beard\nuse_tb_logger: true\nset_CUDA_VISIBLE_DEVICES: ~\ngpu_ids: [3]\n\nattribute: No_Beard\n\nmodel_type: FieldFunctionModel\nfix_layers: true\nreplaced_layers: 8\n\n# dataset configs\nbatch_size: 32\nnum_workers: 8\ninput_latent_dir: ./download/train_data/128/No_Beard\nediting_latent_code_path: ./download/editing_data/128/No_Beard.npz.npy\nnum_attr: 5\nval_on_train_subset: true\nval_on_valset: true\n\n# training configs\nval_freq: 1\nprint_freq: 100\nweight_decay: 0\nmanual_seed: 2021\nnum_epochs: 30\nlr: !!float 1e-4\nlr_decay: step\ngamma: 0.1\nstep: 100\n\n# editing configs\nconfidence_thresh: 0.8\nmax_cls_num: 5\nmax_trials_num: 100\nprint_every: False\n\n# field_function configs\nnum_layer: 8\nhidden_dim: 512\nleaky_relu_neg_slope: 0.2\n\n# loss configs\n# predictor loss\nedited_attribute_weight: 1.0\nattr_file: ./configs/attributes_5.json\npredictor_ckpt: ./download/pretrained_models/predictor_128.pth.tar\n\n# arcface loss\npretrained_arcface: ./download/pretrained_models/arcface_resnet18_110.pth\narcface_weight: 5.0\narcface_loss_type: l1\n# discriminator loss\ndisc_weight: 1.0\ndiscriminator_ckpt: ./download/pretrained_models/stylegan2_128.pt\n\n# stylegan configs\nimg_res: 128\nlatent_dim: 512\nn_mlp: 8\nchannel_multiplier: 1\ngenerator_ckpt: ./download/pretrained_models/stylegan2_128.pt\nlatent_space: w"
  },
  {
    "path": "configs/train/field_128_eyeglasses.yml",
    "content": "name: field_128_eyeglasses\nuse_tb_logger: true\nset_CUDA_VISIBLE_DEVICES: ~\ngpu_ids: [3]\n\nattribute: Eyeglasses\n\nmodel_type: FieldFunctionModel\nfix_layers: true\nreplaced_layers: 8\n\n# dataset configs\nbatch_size: 32\nnum_workers: 8\ninput_latent_dir: ./download/train_data/128/Eyeglasses\nediting_latent_code_path: ./download/editing_data/128/Eyeglasses.npz.npy\nnum_attr: 5\nval_on_train_subset: true\nval_on_valset: true\n\n# training configs\nval_freq: 1\nprint_freq: 100\nweight_decay: 0\nmanual_seed: 2021\nnum_epochs: 30\nlr: !!float 1e-4\nlr_decay: step\ngamma: 0.1\nstep: 100\n\n# editing configs\nconfidence_thresh: 0.8\nmax_cls_num: 5\nmax_trials_num: 100\nprint_every: False\n\n# field_function configs\nnum_layer: 8\nhidden_dim: 512\nleaky_relu_neg_slope: 0.2\n\n# loss configs\n# predictor loss\nedited_attribute_weight: 1.0\nattr_file: ./configs/attributes_5.json\npredictor_ckpt: ./download/pretrained_models/predictor_128.pth.tar\n\n# arcface loss\npretrained_arcface: ./download/pretrained_models/arcface_resnet18_110.pth\narcface_weight: 5.0\narcface_loss_type: l1\n# discriminator loss\ndisc_weight: 1.0\ndiscriminator_ckpt: ./download/pretrained_models/stylegan2_128.pt\n\n# stylegan configs\nimg_res: 128\nlatent_dim: 512\nn_mlp: 8\nchannel_multiplier: 1\ngenerator_ckpt: ./download/pretrained_models/stylegan2_128.pt\nlatent_space: w"
  },
  {
    "path": "configs/train/field_128_smiling.yml",
    "content": "name: field_128_smiling\nuse_tb_logger: true\nset_CUDA_VISIBLE_DEVICES: ~\ngpu_ids: [3]\n\nattribute: Smiling\n\nmodel_type: FieldFunctionModel\nfix_layers: true\nreplaced_layers: 8\n\n# dataset configs\nbatch_size: 32\nnum_workers: 8\ninput_latent_dir: ./download/train_data/128/Smiling\nediting_latent_code_path: ./download/editing_data/128/Smiling.npz.npy\nnum_attr: 5\nval_on_train_subset: true\nval_on_valset: true\n\n# training configs\nval_freq: 1\nprint_freq: 100\nweight_decay: 0\nmanual_seed: 2021\nnum_epochs: 30\nlr: !!float 1e-4\nlr_decay: step\ngamma: 0.1\nstep: 100\n\n# editing configs\nconfidence_thresh: 0.8\nmax_cls_num: 5\nmax_trials_num: 100\nprint_every: False\n\n# field_function configs\nnum_layer: 8\nhidden_dim: 512\nleaky_relu_neg_slope: 0.2\n\n# loss configs\n# predictor loss\nedited_attribute_weight: 1.0\nattr_file: ./configs/attributes_5.json\npredictor_ckpt: ./download/pretrained_models/predictor_128.pth.tar\n\n# arcface loss\npretrained_arcface: ./download/pretrained_models/arcface_resnet18_110.pth\narcface_weight: 5.0\narcface_loss_type: l1\n# discriminator loss\ndisc_weight: 1.0\ndiscriminator_ckpt: ./download/pretrained_models/stylegan2_128.pt\n\n# stylegan configs\nimg_res: 128\nlatent_dim: 512\nn_mlp: 8\nchannel_multiplier: 1\ngenerator_ckpt: ./download/pretrained_models/stylegan2_128.pt\nlatent_space: w"
  },
  {
    "path": "configs/train/field_128_young.yml",
    "content": "name: field_128_young\nuse_tb_logger: true\nset_CUDA_VISIBLE_DEVICES: ~\ngpu_ids: [3]\n\nattribute: Young\n\nmodel_type: FieldFunctionModel\nfix_layers: true\nreplaced_layers: 8\n\n# dataset configs\nbatch_size: 32\nnum_workers: 8\ninput_latent_dir: ./download/train_data/128/Young\nediting_latent_code_path: ./download/editing_data/128/Young.npz.npy\nnum_attr: 5\nval_on_train_subset: true\nval_on_valset: true\n\n# training configs\nval_freq: 1\nprint_freq: 100\nweight_decay: 0\nmanual_seed: 2021\nnum_epochs: 30\nlr: !!float 1e-4\nlr_decay: step\ngamma: 0.1\nstep: 100\n\n# editing configs\nconfidence_thresh: 0.5\nmax_cls_num: 5\nmax_trials_num: 100\nprint_every: False\n\n# field_function configs\nnum_layer: 8\nhidden_dim: 512\nleaky_relu_neg_slope: 0.2\n\n# loss configs\n# predictor loss\nedited_attribute_weight: 1.0\nattr_file: ./configs/attributes_5.json\npredictor_ckpt: ./download/pretrained_models/predictor_128.pth.tar\n\n# arcface loss\npretrained_arcface: ./download/pretrained_models/arcface_resnet18_110.pth\narcface_weight: 5.0\narcface_loss_type: l1\n# discriminator loss\ndisc_weight: 1.0\ndiscriminator_ckpt: ./download/pretrained_models/stylegan2_128.pt\n\n# stylegan configs\nimg_res: 128\nlatent_dim: 512\nn_mlp: 8\nchannel_multiplier: 1\ngenerator_ckpt: ./download/pretrained_models/stylegan2_128.pt\nlatent_space: w"
  },
  {
    "path": "data/__init__.py",
    "content": ""
  },
  {
    "path": "data/latent_code_dataset.py",
    "content": "\"\"\"\nDataset for field function\n\"\"\"\n\nimport os\nimport os.path\nimport random\n\nimport numpy as np\nimport torch\nimport torch.utils.data as data\n\n\nclass LatentCodeDataset(data.Dataset):\n\n    def __init__(self, input_dir, subset_samples=None):\n\n        assert os.path.exists(input_dir)\n        self.latent_codes = np.load(\n            os.path.join(input_dir, 'selected_latent_code.npy')).astype(float)\n        self.labels = np.load(\n            os.path.join(input_dir, 'selected_pred_class.npy')).astype(int)\n        self.scores = np.load(\n            os.path.join(input_dir, 'selected_pred_scores.npy')).astype(float)\n\n        self.latent_codes = torch.FloatTensor(self.latent_codes)\n        self.labels = torch.LongTensor(self.labels)\n        self.scores = torch.FloatTensor(self.scores)\n\n        # select a subset from train set\n        if subset_samples is not None and len(\n                self.latent_codes) > subset_samples:\n            idx = list(range(len(self.latent_codes)))\n            selected_idx = random.sample(idx, subset_samples)\n            self.latent_codes = [self.latent_codes[i] for i in selected_idx]\n            self.labels = [self.labels[i] for i in selected_idx]\n            self.scores = [self.scores[i] for i in selected_idx]\n\n        assert len(self.latent_codes) == len(self.labels)\n        assert len(self.labels) == len(self.scores)\n\n    def __getitem__(self, index):\n        return (self.latent_codes[index], self.labels[index],\n                self.scores[index])\n\n    def __len__(self):\n        return len(self.latent_codes)\n"
  },
  {
    "path": "editing_quantitative.py",
    "content": "import argparse\nimport logging\nimport os\n\nimport numpy as np\n\nfrom models import create_model\nfrom utils.logger import get_root_logger\nfrom utils.numerical_metrics import compute_num_metrics\nfrom utils.options import dict2str, dict_to_nonedict, parse\nfrom utils.util import make_exp_dirs\n\n\ndef main():\n    # options\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--opt', type=str, help='Path to option YAML file.')\n    parser.add_argument(\n        '--pretrained_path', type=str, help='Path to pretrained field model')\n    args = parser.parse_args()\n    opt = parse(args.opt, is_train=False)\n\n    # mkdir and loggers\n    make_exp_dirs(opt)\n\n    # convert to NoneDict, which returns None for missing keys\n    opt = dict_to_nonedict(opt)\n\n    # load editing latent code\n    editing_latent_codes = np.load(opt['editing_latent_code_path'])\n    num_latent_codes = editing_latent_codes.shape[0]\n\n    save_path = f'{opt[\"path\"][\"visualization\"]}'\n    os.makedirs(save_path)\n    editing_logger = get_root_logger(\n        logger_name='editing',\n        log_level=logging.INFO,\n        log_file=f'{save_path}/editing.log')\n\n    editing_logger.info(dict2str(opt))\n\n    field_model = create_model(opt)\n\n    field_model.load_network(args.pretrained_path)\n\n    field_model.continuous_editing(editing_latent_codes, save_path,\n                                   editing_logger)\n\n    _, _ = compute_num_metrics(save_path, num_latent_codes,\n                               opt['pretrained_arcface'], opt['attr_file'],\n                               opt['predictor_ckpt'],\n                               opt['attr_dict'][opt['attribute']],\n                               editing_logger)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "editing_with_dialog.py",
    "content": "import argparse\nimport json\nimport logging\nimport os.path\n\nimport numpy as np\nimport torch\n\nfrom models import create_model\nfrom utils.dialog_edit_utils import dialog_with_real_user\nfrom utils.inversion_utils import inversion\nfrom utils.logger import get_root_logger\nfrom utils.options import (dict2str, dict_to_nonedict, parse,\n                           parse_args_from_opt, parse_opt_wrt_resolution)\nfrom utils.util import make_exp_dirs\n\n\ndef parse_args():\n    \"\"\"Parses arguments.\"\"\"\n    parser = argparse.ArgumentParser(description='')\n    parser.add_argument(\n        '--opt', default=None, type=str, help='Path to option YAML file.')\n    return parser.parse_args()\n\n\ndef main():\n\n    # ---------- Set up -----------\n    args = parse_args()\n    opt = parse(args.opt, is_train=False)\n    opt = parse_opt_wrt_resolution(opt)\n    args = parse_args_from_opt(args, opt)\n    make_exp_dirs(opt)\n\n    # convert to NoneDict, which returns None for missing keys\n    opt = dict_to_nonedict(opt)\n\n    # set up logger\n    save_log_path = f'{opt[\"path\"][\"log\"]}'\n    dialog_logger = get_root_logger(\n        logger_name='dialog',\n        log_level=logging.INFO,\n        log_file=f'{save_log_path}/dialog.log')\n    dialog_logger.info(dict2str(opt))\n\n    save_image_path = f'{opt[\"path\"][\"visualization\"]}'\n    os.makedirs(save_image_path)\n\n    # ---------- Load files -----------\n    dialog_logger.info('loading template files')\n    with open(opt['feedback_templates_file'], 'r') as f:\n        args.feedback_templates = json.load(f)\n        args.feedback_replacement = args.feedback_templates['replacement']\n    with open(opt['pool_file'], 'r') as f:\n        pool = json.load(f)\n        args.synonyms_dict = pool[\"synonyms\"]\n\n    # ---------- create model ----------\n    field_model = create_model(opt)\n\n    # ---------- load latent code ----------\n    if opt['inversion']['is_real_image']:\n        latent_code = 
inversion(opt, field_model)\n    else:\n        if opt['latent_code_path'] is None:\n            latent_code = torch.randn(1, 512, device=torch.device('cuda'))\n            with torch.no_grad():\n                latent_code = field_model.stylegan_gen.get_latent(latent_code)\n            latent_code = latent_code.cpu().numpy()\n            np.save(f'{opt[\"path\"][\"visualization\"]}/latent_code.npz.npy',\n                    latent_code)\n        else:\n            i = opt['latent_code_index']\n            latent_code = np.load(\n                opt['latent_code_path'],\n                allow_pickle=True).item()[f\"{str(i).zfill(7)}.png\"]\n            latent_code = torch.from_numpy(latent_code).to(\n                torch.device('cuda'))\n            with torch.no_grad():\n                latent_code = field_model.stylegan_gen.get_latent(latent_code)\n            latent_code = latent_code.cpu().numpy()\n\n    np.save(f'{opt[\"path\"][\"visualization\"]}/latent_code.npz.npy', latent_code)\n\n    # ---------- Perform dialog-based editing with user -----------\n    dialog_overall_log = dialog_with_real_user(field_model, latent_code, opt,\n                                               args, dialog_logger)\n\n    # ---------- Log the dialog history -----------\n    for (key, value) in dialog_overall_log.items():\n        dialog_logger.info(f'{key}: {value}')\n    dialog_logger.info('successfully end.')\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "editing_wo_dialog.py",
    "content": "import argparse\nimport logging\nimport os\n\nimport numpy as np\nimport torch\n\nfrom models import create_model\nfrom models.utils import save_image\nfrom utils.editing_utils import edit_target_attribute\nfrom utils.inversion_utils import inversion\nfrom utils.logger import get_root_logger\nfrom utils.options import (dict2str, dict_to_nonedict, parse,\n                           parse_opt_wrt_resolution)\nfrom utils.util import make_exp_dirs\n\n\ndef parse_args():\n    \"\"\"Parses arguments.\"\"\"\n    parser = argparse.ArgumentParser(description='')\n    parser.add_argument('--opt', type=str, help='Path to option YAML file.')\n    parser.add_argument('--attr', type=str, help='Attribute to be edited.')\n    parser.add_argument(\n        '--target_val', type=int, help='Target Attribute Value.')\n\n    return parser.parse_args()\n\n\ndef main():\n    # ---------- Set up -----------\n    args = parse_args()\n\n    opt = parse(args.opt, is_train=False)\n    opt = parse_opt_wrt_resolution(opt)\n    # args = parse_args_from_opt(args, opt)\n    make_exp_dirs(opt)\n\n    # convert to NoneDict, which returns None for missing keys\n    opt = dict_to_nonedict(opt)\n\n    # set up logger\n    save_log_path = f'{opt[\"path\"][\"log\"]}'\n    editing_logger = get_root_logger(\n        logger_name='editing',\n        log_level=logging.INFO,\n        log_file=f'{save_log_path}/editing.log')\n    editing_logger.info(dict2str(opt))\n\n    save_image_path = f'{opt[\"path\"][\"visualization\"]}'\n    os.makedirs(save_image_path)\n\n    # ---------- create model ----------\n    field_model = create_model(opt)\n\n    # ---------- load latent code ----------\n    if opt['inversion']['is_real_image']:\n        latent_code = inversion(opt, field_model)\n    else:\n        if opt['latent_code_path'] is None:\n            latent_code = torch.randn(1, 512, device=torch.device('cuda'))\n            with torch.no_grad():\n                latent_code = 
field_model.stylegan_gen.get_latent(latent_code)\n            latent_code = latent_code.cpu().numpy()\n            np.save(f'{opt[\"path\"][\"visualization\"]}/latent_code.npz.npy',\n                    latent_code)\n        else:\n            i = opt['latent_code_index']\n            latent_code = np.load(\n                opt['latent_code_path'],\n                allow_pickle=True).item()[f\"{str(i).zfill(7)}.png\"]\n            latent_code = torch.from_numpy(latent_code).to(\n                torch.device('cuda'))\n            with torch.no_grad():\n                latent_code = field_model.stylegan_gen.get_latent(latent_code)\n            latent_code = latent_code.cpu().numpy()\n\n    # ---------- synthesize images ----------\n    with torch.no_grad():\n        start_image, start_label, start_score = \\\n            field_model.synthesize_and_predict(torch.from_numpy(latent_code).to(torch.device('cuda'))) # noqa\n\n    save_image(start_image, f'{opt[\"path\"][\"visualization\"]}/start_image.png')\n\n    # initialize attribute_dict\n    attribute_dict = {\n        \"Bangs\": start_label[0],\n        \"Eyeglasses\": start_label[1],\n        \"No_Beard\": start_label[2],\n        \"Smiling\": start_label[3],\n        \"Young\": start_label[4],\n    }\n\n    edit_label = {'attribute': args.attr, 'target_score': args.target_val}\n\n    edited_latent_code = None\n    print_intermediate_result = True\n    round_idx = 0\n\n    attribute_dict, exception_mode, latent_code, edited_latent_code = edit_target_attribute(\n        opt, attribute_dict, edit_label, round_idx, latent_code,\n        edited_latent_code, field_model, editing_logger,\n        print_intermediate_result)\n\n    if exception_mode != 'normal':\n        if exception_mode == 'already_at_target_class':\n            editing_logger.info(\"This attribute is already at the degree that you want. 
Let's try a different attribute degree or another attribute.\")\n        elif exception_mode == 'max_edit_num_reached':\n            editing_logger.info(\"Sorry, we are unable to edit this attribute. Perhaps we can try something else.\")\n\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "environment.yml",
    "content": "name: talk_edit\nchannels:\n  - pytorch\n  - conda-forge\n  - anaconda\n  - defaults\ndependencies:\n  - _libgcc_mutex=0.1=main\n  - absl-py=0.11.0=pyhd3eb1b0_1\n  - aiohttp=3.7.3=py37h27cfd23_1\n  - async-timeout=3.0.1=py37h06a4308_0\n  - attrs=20.3.0=pyhd3eb1b0_0\n  - backcall=0.2.0=py_0\n  - blas=1.0=mkl\n  - blinker=1.4=py37h06a4308_0\n  - blosc=1.21.0=h8c45485_0\n  - brotli=1.0.9=he6710b0_2\n  - brotlipy=0.7.0=py37h27cfd23_1003\n  - brunsli=0.1=h2531618_0\n  - bzip2=1.0.8=h7b6447c_0\n  - c-ares=1.17.1=h27cfd23_0\n  - ca-certificates=2021.7.5=h06a4308_1\n  - cachetools=4.2.1=pyhd3eb1b0_0\n  - certifi=2021.5.30=py37h06a4308_0\n  - cffi=1.14.4=py37h261ae71_0\n  - chardet=3.0.4=py37h06a4308_1003\n  - charls=2.1.0=he6710b0_2\n  - click=7.1.2=pyhd3eb1b0_0\n  - cloudpickle=1.6.0=py_0\n  - cryptography=2.9.2=py37h1ba5d50_0\n  - cudatoolkit=10.1.243=h6bb024c_0\n  - cycler=0.10.0=py_2\n  - cytoolz=0.11.0=py37h7b6447c_0\n  - dask-core=2021.3.0=pyhd3eb1b0_0\n  - decorator=4.4.2=pyhd3eb1b0_0\n  - freetype=2.10.4=h5ab3b9f_0\n  - giflib=5.1.4=h14c3975_1\n  - google-auth=1.24.0=pyhd3eb1b0_0\n  - google-auth-oauthlib=0.4.2=pyhd3eb1b0_2\n  - grpcio=1.31.0=py37hf8bcb03_0\n  - icu=67.1=he1b5a44_0\n  - idna=2.10=pyhd3eb1b0_0\n  - imagecodecs=2021.1.11=py37h581e88b_1\n  - imageio=2.9.0=py_0\n  - intel-openmp=2020.2=254\n  - ipython=7.18.1=py37h5ca1d4c_0\n  - ipython_genutils=0.2.0=py37_0\n  - jedi=0.18.0=py37h06a4308_1\n  - joblib=1.0.0=pyhd3eb1b0_0\n  - jpeg=9b=h024ee3a_2\n  - jxrlib=1.1=h7b6447c_2\n  - kiwisolver=1.3.1=py37hc928c03_0\n  - lcms2=2.11=h396b838_0\n  - ld_impl_linux-64=2.33.1=h53a641e_7\n  - lerc=2.2.1=h2531618_0\n  - libaec=1.0.4=he6710b0_1\n  - libdeflate=1.7=h27cfd23_5\n  - libedit=3.1.20191231=h14c3975_1\n  - libffi=3.3=he6710b0_2\n  - libgcc-ng=9.1.0=hdf63c60_0\n  - libgfortran-ng=7.3.0=hdf63c60_0\n  - libpng=1.6.37=hbc83047_0\n  - libprotobuf=3.13.0.1=h8b12597_0\n  - libstdcxx-ng=9.1.0=hdf63c60_0\n  - libtiff=4.1.0=h2733197_1\n  - 
libwebp=1.0.1=h8e7db2f_0\n  - libzopfli=1.0.3=he6710b0_0\n  - lz4-c=1.9.3=h2531618_0\n  - markdown=3.3.3=py37h06a4308_0\n  - matplotlib=3.2.2=1\n  - matplotlib-base=3.2.2=py37h1d35a4c_1\n  - mkl=2020.2=256\n  - mkl-service=2.3.0=py37he8ac12f_0\n  - mkl_fft=1.2.0=py37h23d657b_0\n  - mkl_random=1.1.1=py37h0573a6f_0\n  - multidict=4.7.6=py37h7b6447c_1\n  - ncurses=6.2=he6710b0_1\n  - networkx=2.5=py_0\n  - ninja=1.10.2=py37hff7bd54_0\n  - numpy=1.19.2=py37h54aff64_0\n  - numpy-base=1.19.2=py37hfa32c7d_0\n  - oauthlib=3.1.0=py_0\n  - olefile=0.46=py37_0\n  - openjpeg=2.3.0=h05c96fa_1\n  - openssl=1.1.1k=h27cfd23_0\n  - parso=0.8.0=py_0\n  - pexpect=4.8.0=py37_1\n  - pickleshare=0.7.5=py37_1001\n  - pillow=8.2.0=py37he98fc37_0\n  - pip=20.3.3=py37h06a4308_0\n  - prompt-toolkit=3.0.8=py_0\n  - protobuf=3.13.0.1=py37he6710b0_1\n  - ptyprocess=0.6.0=py37_0\n  - pyasn1=0.4.8=py_0\n  - pyasn1-modules=0.2.8=py_0\n  - pycparser=2.20=py_2\n  - pygments=2.7.1=py_0\n  - pyjwt=2.0.1=py37h06a4308_0\n  - pyopenssl=20.0.1=pyhd3eb1b0_1\n  - pyparsing=2.4.7=pyh9f0ad1d_0\n  - pysocks=1.7.1=py37_1\n  - python=3.7.9=h7579374_0\n  - python-dateutil=2.8.1=py_0\n  - python_abi=3.7=1_cp37m\n  - pytorch=1.6.0=py3.7_cuda10.1.243_cudnn7.6.3_0\n  - pywavelets=1.1.1=py37h7b6447c_2\n  - pyyaml=5.4.1=py37h27cfd23_1\n  - readline=8.0=h7b6447c_0\n  - requests=2.25.1=pyhd3eb1b0_0\n  - requests-oauthlib=1.3.0=py_0\n  - rsa=4.7=pyhd3eb1b0_1\n  - scikit-image=0.17.2=py37hdf5156a_0\n  - scikit-learn=0.23.2=py37h0573a6f_0\n  - scipy=1.6.2=py37h91f5cce_0\n  - setuptools=52.0.0=py37h06a4308_0\n  - six=1.15.0=py37h06a4308_0\n  - snappy=1.1.8=he6710b0_0\n  - sqlite=3.33.0=h62c20be_0\n  - tensorboard=2.3.0=pyh4dce500_0\n  - tensorboard-plugin-wit=1.6.0=py_0\n  - tensorboardx=2.1=py_0\n  - threadpoolctl=2.1.0=pyh5ca1d4c_0\n  - tifffile=2021.3.5=pyhd3eb1b0_1\n  - tk=8.6.10=hbc83047_0\n  - toolz=0.11.1=pyhd3eb1b0_0\n  - torchvision=0.7.0=py37_cu101\n  - tornado=6.1=py37h4abf009_0\n  - tqdm=4.55.1=pyhd3eb1b0_0\n  - 
traitlets=5.0.5=py_0\n  - typing-extensions=3.7.4.3=hd3eb1b0_0\n  - typing_extensions=3.7.4.3=pyh06a4308_0\n  - urllib3=1.26.3=pyhd3eb1b0_0\n  - wcwidth=0.2.5=py_0\n  - werkzeug=1.0.1=pyhd3eb1b0_0\n  - wheel=0.36.2=pyhd3eb1b0_0\n  - xz=5.2.5=h7b6447c_0\n  - yaml=0.2.5=h7b6447c_0\n  - yarl=1.5.1=py37h7b6447c_0\n  - zfp=0.5.5=h2531618_4\n  - zipp=3.4.0=pyhd3eb1b0_0\n  - zlib=1.2.11=h7b6447c_3\n  - zstd=1.4.5=h9ceee32_0\n  - pip:\n    - cmake==3.21.2\n    - dlib==19.22.1\n    - facenet-pytorch==2.5.2\n    - flake8==3.8.4\n    - future==0.18.2\n    - importlib-metadata==3.4.0\n    - isort==5.7.0\n    - lpips==0.1.4\n    - mccabe==0.6.1\n    - opencv-python==4.5.1.48\n    - pycodestyle==2.6.0\n    - pyflakes==2.2.0\n    - yapf==0.30.0\n"
  },
  {
    "path": "language/accuracy.py",
    "content": "import torch\n\n\ndef head_accuracy(output, target, unlabeled_value=999):\n    \"\"\"\n    Computes the precision@k for the specified values of k\n    output: batch_size * num_cls (for a specific attribute)\n    target: batch_size * 1 (for a specific attribute)\n    return res: res = 100 * num_correct / batch_size, for a specific attribute\n    for a batch\n    \"\"\"\n\n    with torch.no_grad():\n        batch_size = target.size(0)\n\n        # _ = the largest score, pred = cls_idx with the largest score\n        _, pred = output.topk(1, 1, True, True)\n        pred = pred.reshape(-1)\n\n        # acc = float(torch.sum(pred == target)) / float(batch_size) * 100\n        return_dict = {}\n\n        if unlabeled_value is not None:\n\n            correct_count = torch.sum(\n                (target != unlabeled_value) * (pred == target))\n            labeled_count = torch.sum(target != unlabeled_value)\n\n            if labeled_count:\n\n                labeled_acc = float(correct_count) / float(labeled_count) * 100\n            else:\n                labeled_acc = 0\n\n            return_dict['acc'] = labeled_acc\n            return_dict['labeled_count'] = labeled_count\n        else:\n\n            return_dict['acc'] = acc  # noqa\n            return_dict['labeled_count'] = batch_size\n\n        return return_dict\n"
  },
  {
    "path": "language/build_vocab.py",
    "content": "import argparse\nimport json\nimport os\nimport sys\n\nsys.path.append('.')\nfrom language_utils import *  # noqa\n\"\"\"\nBuild vocabulary from all instantiated templates\n\"\"\"\n\n\ndef parse_args():\n    \"\"\"Parses arguments.\"\"\"\n\n    parser = argparse.ArgumentParser(description='Build vocabulary')\n    parser.add_argument(\n        '--input_data_path',\n        required=True,\n        type=str,\n        help='path to the input data file')\n    parser.add_argument(\n        '--output_dir',\n        required=True,\n        type=str,\n        help='folder to save the output vocabulary file')\n\n    return parser.parse_args()\n\n\ndef main():\n    args = parse_args()\n\n    # prepare output directory\n    if not os.path.isdir(args.output_dir):\n        os.makedirs(args.output_dir, exist_ok=True)\n\n    # load text data\n    print(\"Loading text data from\", args.input_data_path)\n    with open(args.input_data_path, 'r') as f:\n        input_data = json.load(f)\n\n    # gather a list of text\n    print(\"Building vocabulary from\", len(input_data), \"text data samples\")\n    text_list = []\n    for idx, data_sample in enumerate(input_data):\n        if idx % 10000 == 0:\n            print('loaded', idx, '/', len(input_data))\n        text = data_sample['text']\n        text_list.append(text)\n\n    # build vocabulary\n    text_token_to_idx = build_vocab(text_list=text_list)  # noqa\n    vocab = {\n        'text_token_to_idx': text_token_to_idx,\n    }\n\n    # save vocabulary\n    print(\"Saving vocabulary file to\",\n          os.path.join(args.output_dir, 'vocab.json'))\n    with open(os.path.join(args.output_dir, 'vocab.json'), 'w') as f:\n        json.dump(vocab, f, indent=4)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "language/dataset.py",
    "content": "import os.path\n\nimport numpy as np\nfrom torch.utils.data import Dataset\n\n\nclass EncoderDataset(Dataset):\n\n    def __init__(self, preprocessed_dir):\n\n        # load text\n        text_path = os.path.join(preprocessed_dir, 'text.npy')\n        self.text = np.load(text_path)\n        # load system_mode\n        system_mode_path = os.path.join(preprocessed_dir, 'system_mode.npy')\n        self.system_mode = np.load(system_mode_path)\n        # load labels\n        labels_path = os.path.join(preprocessed_dir, 'labels.npy')\n        self.labels = np.load(labels_path)\n\n    def __getitem__(self, index):\n        # retrieve text\n        text = self.text[index]\n        # retrieve system_mode\n        system_mode = self.system_mode[index]\n        # retrieve labels\n        labels = self.labels[index]\n\n        return text, system_mode, labels\n\n    def __len__(self):\n        return len(self.text)\n\n\ndef main():\n    \"\"\" Testing the Dataset\"\"\"\n\n    encoderdataset = EncoderDataset(\n        preprocessed_dir=  # noqa\n        ''  # noqa\n    )\n    print('len(encoderdataset):', len(encoderdataset))\n    print('encoderdataset[0]:', encoderdataset[0])\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "language/generate_feedback.py",
    "content": "import argparse\nimport json\nimport os.path\nimport random\n\nimport numpy as np\n\nfrom .language_utils import proper_capitalize\n\n\ndef parse_args():\n    \"\"\"Parses arguments.\"\"\"\n\n    parser = argparse.ArgumentParser(description='')\n\n    parser.add_argument(\n        '--feedback_templates_file',\n        default='./templates/feedback.json',\n        type=str,\n        help='directory to the request templates file')\n    parser.add_argument(\n        '--pool_file',\n        default='./templates/pool.json',\n        type=str,\n        help='directory to the word pool file')\n    parser.add_argument(\n        '--num_feedback',\n        default=100,\n        type=int,\n        help='number of feedback data to generate')\n    parser.add_argument(\n        '--output_file_dir',\n        required=True,\n        type=str,\n        help='folder to save the output request file')\n    parser.add_argument(\n        '--output_file_name',\n        required=True,\n        type=str,\n        help='name of the output request file')\n    parser.add_argument(\n        '--whether_enough_general_prob',\n        default=0.2,\n        type=float,\n        help='probability of using general templates in whether_enough mode')\n    return parser.parse_args()\n\n\ndef main():\n    args = parse_args()\n\n    if not os.path.isdir(args.output_file_dir):\n        os.makedirs(args.output_file_dir, exist_ok=True)\n\n    # load template files\n    print('loading template files')\n    with open(args.feedback_templates_file, 'r') as f:\n        args.feedback_templates = json.load(f)\n        args.feedback_replacement = args.feedback_templates['replacement']\n    with open(args.pool_file, 'r') as f:\n        pool = json.load(f)\n        args.synonyms_dict = pool[\"synonyms\"]\n\n    system_mode_list = ['whats_next', 'whether_enough', 'suggestion']\n    attribute_list = ['Bangs', \"Eyeglasses\", \"No_Beard\", \"Smiling\", \"Young\"]\n\n    feedback_list = []\n    output_txt 
= []\n\n    # instantiate feedback\n    for index in range(args.num_feedback):\n\n        if index % 1000 == 0:\n            print('generated', index, '/', args.num_feedback, 'feedback')\n\n        # initialize feedback parameters\n        attribute = None\n\n        # randomly choose the feedback parameters\n        system_mode = random.choice(system_mode_list)\n        if system_mode == 'whether_enough' or system_mode == 'suggestion':\n            attribute = random.choice(attribute_list)\n\n        feedback = instantiate_feedback(\n            args, system_mode=system_mode, attribute=attribute)\n\n        feedback['index'] = index\n        feedback_list.append(feedback)\n        output_txt.append(feedback['text'])\n\n    # save feedback dataset\n\n    with open(os.path.join(args.output_file_dir, args.output_file_name),\n              'w') as f:\n        json.dump(feedback_list, f, indent=4)\n    np.savetxt(\n        os.path.join(args.output_file_dir, \"feedback.txt\"),\n        output_txt,\n        fmt='%s',\n        delimiter='\\t')\n\n    print('successfully saved.')\n\n\ndef instantiate_feedback(args,\n                         system_mode=None,\n                         attribute=None,\n                         exception_mode='normal'):\n    \"\"\"\n    Given the feedback mode (i.e. 
system_mode) and the attribute (if any),\n    return a feedback.\n    \"\"\"\n\n    if exception_mode != 'normal':\n        candidate_templates = args.feedback_templates[exception_mode]\n        template = random.choice(candidate_templates)\n        attribute = attribute\n    else:\n        # ---------- STEP 1: 1st part of feedback: 'ok' template ----------\n\n        # instantiate the feedback prefix like \"ok\"\n        ok_distribution_prob = random.uniform(0, 1)\n        ok_template = ''\n\n        if ok_distribution_prob < 0.7:\n            ok_templates = args.feedback_templates['ok']\n            for idx, templates in enumerate(ok_templates):\n                if 0.3 < ok_distribution_prob < 0.7 and (idx == 0 or idx == 1):\n                    continue\n                ok_template += random.choice(templates)\n            ok_template += ' '\n            ok_template = ok_template[0].capitalize() + ok_template[1:]\n\n        # ---------- STEP 2: 2nd part of feedback: content template ----------\n\n        # feedback is trivial like \"what's next?\"\n        if system_mode == 'whats_next':\n            candidate_templates = args.feedback_templates['whats_next']\n            template = random.choice(candidate_templates)\n        # feedback asks whether the editing extent is enough\n        elif system_mode == 'whether_enough':\n            whether_enough_general_prob = random.uniform(0, 1)\n            if whether_enough_general_prob < args.whether_enough_general_prob \\\n                or args.feedback_templates[\n                    'whether_enough'][attribute] == []:\n                candidate_templates = args.feedback_templates[\n                    'whether_enough']['general']\n            else:\n                candidate_templates = args.feedback_templates[\n                    'whether_enough'][attribute]\n            template = random.choice(candidate_templates)\n        # feedback provides suggestion on the next edit\n        elif system_mode == 
'suggestion':\n            candidate_templates = args.feedback_templates['suggestion']\n            template = random.choice(candidate_templates)\n        else:\n            raise KeyError('System mode \"%s\" not recognized' % system_mode)\n\n        # ---------- STEP 3: Postprocess the instantiated template sentence ---------- # noqa\n\n        # replace the <xxx> in the template with\n        # proper attribute-specific words.\n        # this is not applicable to 'whats_next' type of feedback\n        if system_mode != 'whats_next':\n            for word in args.feedback_replacement:\n                new_word_dict = args.feedback_replacement[word]\n                new_word = new_word_dict[attribute]\n                template = template.replace(word, new_word)\n\n    # to lower case\n    template = template.lower()\n\n    # randomly replace words with synonyms\n    for word in args.synonyms_dict:\n        replacing_word = random.choice(args.synonyms_dict[word])\n        template = template.replace(word, replacing_word)\n\n    # capitalize\n    template = proper_capitalize(template)\n\n    if exception_mode != 'normal':\n        # after given feedback of cannot_edit\n        # encode user request by pretending that\n        # the system_mode is 'whats_next'\n        system_mode = 'whats_next'\n    else:\n        template = ok_template + template\n\n    # ---------- STEP 4: Return the feedback and its annotations ----------\n\n    feedback = {\n        \"text\": template,\n        \"system_mode\": system_mode,\n        \"attribute\": attribute\n    }\n    return feedback\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "language/generate_training_request.py",
    "content": "import argparse\nimport json\nimport os.path\nimport random\nimport sys\n\nsys.path.append('.')\nfrom language_utils import proper_capitalize  # noqa\n\n\ndef parse_args():\n    \"\"\"Parses arguments.\"\"\"\n\n    parser = argparse.ArgumentParser(description='')\n\n    parser.add_argument(\n        '--num_request',\n        default=100,\n        type=int,\n        help='number of request data to generate')\n\n    # template files\n    parser.add_argument(\n        '--user_templates_file',\n        type=str,\n        default='./templates/user_fsm.json',\n        help='directory to the request templates file')\n    parser.add_argument(\n        '--pool_file',\n        type=str,\n        default='./templates/pool.json',\n        help='directory to the word pool file')\n    parser.add_argument(\n        '--metadata_file',\n        type=str,\n        default='./templates/metadata_fsm.json',\n        help='directory to the metadata file')\n    parser.add_argument(\n        '--system_mode_file',\n        type=str,\n        default='./templates/system_mode.json',\n        help='directory to the system_mode file')\n    # output\n    parser.add_argument(\n        '--output_file_dir',\n        required=True,\n        type=str,\n        help='folder to save the output request file')\n\n    return parser.parse_args()\n\n\ndef main():\n    args = parse_args()\n\n    if not os.path.isdir(args.output_file_dir):\n        os.makedirs(args.output_file_dir, exist_ok=False)\n\n    # load template files\n    print('loading template files')\n    with open(args.user_templates_file, 'r') as f:\n        args.user_templates = json.load(f)\n    with open(args.pool_file, 'r') as f:\n        pool = json.load(f)\n        args.synonyms_dict = pool[\"synonyms\"]\n        args.postfix_list = pool[\"postfix\"]\n    with open(args.metadata_file, 'r') as f:\n        args.metadata = json.load(f)\n    with open(args.system_mode_file, 'r') as f:\n        args.system_mode_dict = 
json.load(f)\n    args.system_mode_list = []\n    for key, value in args.system_mode_dict.items():\n        args.system_mode_list.append(key)\n\n    attribute_list = ['Bangs', \"Eyeglasses\", \"No_Beard\", \"Smiling\", \"Young\"]\n    target_score_list = [0, 1, 2, 3, 4, 5]\n    score_change_direction_list = ['positive', 'negative']\n    score_change_value_list = [1, 2, 3, 4, 5]\n\n    request_list = []\n\n    # instantiate requests\n    for index in range(args.num_request):\n\n        if index % 1000 == 0:\n            print('generated', index, '/', args.num_request, 'requests')\n\n        # randomly choose the semantic editing parameters\n        system_mode = random.choice(args.system_mode_list)\n        user_mode_list = list(args.metadata[system_mode].keys())\n        user_mode = random.choice(user_mode_list)\n        attribute = random.choice(attribute_list)\n        score_change_value = random.choice(score_change_value_list)\n        score_change_direction = random.choice(score_change_direction_list)\n        target_score = random.choice(target_score_list)\n\n        # instantiate a request according to the\n        # chosen semantic editing parameters\n        request = instantiate_training_request(\n            args,\n            attribute=attribute,\n            user_mode=user_mode,\n            score_change_direction=score_change_direction,\n            score_change_value=score_change_value,\n            target_score=target_score)\n\n        request['system_mode'] = system_mode\n\n        # assign each system_mode's user_mode\n        for mode in args.system_mode_list:\n            if system_mode == mode:\n                request[mode] = request['user_mode']\n            else:\n                request[mode] = None\n\n        request['index'] = index\n        request_list.append(request)\n\n    # save request dataset\n    if not os.path.isdir(args.output_file_dir):\n        os.makedirs(args.output_file_dir, exist_ok=True)\n    with open(\n            
os.path.join(args.output_file_dir, 'training_request.json'),\n            'w') as f:\n        json.dump(request_list, f, indent=4)\n\n    print('successfully saved.')\n\n\ndef instantiate_training_request(\n    args,\n    attribute=None,\n    user_mode=None,\n    score_change_direction=None,\n    score_change_value=None,\n    target_score=None,\n):\n    \"\"\"\n    Given semantic editing parameters, instantiate the request\n    using the request templates.\n    \"\"\"\n\n    request_mode = None\n\n    instantiated_sentence = ''\n    user_sub_mode_list = user_mode.split('_')\n\n    for user_sub_mode_idx, user_sub_mode in enumerate(user_sub_mode_list):\n\n        sub_mode_template = ''\n        if user_sub_mode != 'pureRequest':\n            sub_mode_templates = args.user_templates[user_sub_mode]\n            for templates in sub_mode_templates:\n                sub_mode_template += random.choice(templates)\n        else:\n            request_mode = random.choice(\n                ['target', 'change_definite', 'change_indefinite'])\n\n            request_templates = args.user_templates['pureRequest']\n            attribute_templates = request_templates[attribute]\n\n            # request is the score change direction and value\n            if request_mode == 'change_definite':\n                assert score_change_direction is not None\n                assert score_change_value is not None\n                target_score = None\n                candidate_templates = attribute_templates['change'][\n                    score_change_direction]['definite'][str(\n                        score_change_value)]\n            # request is the score change direction without value\n            elif request_mode == 'change_indefinite':\n                assert score_change_direction is not None\n                score_change_value = None\n                target_score = None\n                candidate_templates = attribute_templates['change'][\n                    
score_change_direction]['indefinite']\n            # request is the edit target\n            elif request_mode == 'target':\n                score_change_direction = None\n                score_change_value = None\n                assert target_score is not None\n                candidate_templates = attribute_templates['target'][str(\n                    target_score)]\n            else:\n                raise KeyError('Request mode \"%s\" not recognized' %\n                               request_mode)\n\n            # randomly choose one request template\n            sub_mode_template = random.choice(candidate_templates)\n\n        if user_sub_mode_idx >= 1:\n            instantiated_sentence += ' '\n        instantiated_sentence += sub_mode_template\n\n    if 'pureRequest' not in user_sub_mode_list:\n        score_change_direction = None\n        score_change_value = None\n        target_score = None\n        attribute = None\n\n    # to lower case\n    instantiated_sentence = instantiated_sentence.lower()\n\n    # randomly replace words with synonyms\n    for word in args.synonyms_dict:\n        new_word = random.choice(args.synonyms_dict[word])\n        instantiated_sentence = instantiated_sentence.replace(word, new_word)\n\n    # capitalize\n    instantiated_sentence = proper_capitalize(instantiated_sentence)\n\n    request = {\n        \"text\": instantiated_sentence,\n        \"user_mode\": user_mode,\n        'request_mode': request_mode,\n        \"attribute\": attribute,\n        \"score_change_direction\": score_change_direction,\n        \"score_change_value\": score_change_value,\n        \"target_score\": target_score,\n    }\n\n    return request\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "language/language_utils.py",
    "content": "import numpy as np\nimport torch\n\n# global variables\nPUNCTUATION_TO_KEEP = ['?', ';']\nPUNCTUATION_TO_REMOVE = ['.', '!', ',']\nSPECIAL_TOKENS = {\n    '<NULL>': 0,\n    '<START>': 1,\n    '<END>': 2,\n    '<UNK>': 3,\n}\n\n\ndef build_vocab(text_list,\n                min_token_count=1,\n                delimiter=' ',\n                punct_to_keep=None,\n                punct_to_remove=None,\n                print_every=10000):\n    \"\"\"\n    Build token to index mapping from a list of text strings\n    -- Input: a list of text string\n    -- Output: a dict which is a mapping from token to index,\n    \"\"\"\n\n    token_to_count = {}\n    # tokenize text and add tokens to token_to_count dict\n    for text_idx, text in enumerate(text_list):\n        if text_idx % print_every == 0:\n            print('tokenized', text_idx, '/', len(text_list))\n        text_tokens = tokenize(text=text, delimiter=delimiter)\n        for token in text_tokens:\n            if token in token_to_count:\n                token_to_count[token] += 1\n            else:\n                token_to_count[token] = 1\n\n    token_to_idx = {}\n    print('Mapping tokens to indices')\n\n    # reserve indices for special tokens (must-have tokens)\n    for token, idx in SPECIAL_TOKENS.items():\n        token_to_idx[token] = idx\n\n    # assign indices to tokens\n    for token, count in sorted(token_to_count.items()):\n        if count >= min_token_count:\n            token_to_idx[token] = len(token_to_idx)\n\n    return token_to_idx\n\n\ndef tokenize(text,\n             delimiter=' ',\n             add_start_token=False,\n             add_end_token=False,\n             punctuation_to_keep=PUNCTUATION_TO_KEEP,\n             punctuation_to_remove=PUNCTUATION_TO_REMOVE):\n    \"\"\"\n    Tokenize a text string\n    -- Input: a text string\n    -- Output: a list of tokens,\n       each token is still a string (usually an english word)\n    \"\"\"\n\n    # (1) Optionally keep or remove 
certain punctuation\n    if punctuation_to_keep is not None:\n        for punctuation in punctuation_to_keep:\n            text = text.replace(punctuation, '%s%s' % (delimiter, punctuation))\n    if punctuation_to_remove is not None:\n        for punctuation in punctuation_to_remove:\n            text = text.replace(punctuation, '')\n\n    # (2) Split the text string into a list of tokens\n    text = text.lower()\n    tokens = text.split(delimiter)\n\n    # (3) Optionally add start and end tokens\n    if add_start_token:\n        tokens.insert(0, '<START>')\n    if add_end_token:\n        tokens.append('<END>')\n\n    return tokens\n\n\ndef encode(text_tokens, token_to_idx, allow_unk=False):\n    text_encoded = []\n    for token in text_tokens:\n        if token not in token_to_idx:\n            if allow_unk:\n                token = '<UNK>'\n            else:\n                raise KeyError('Token \"%s\" not in vocab' % token)\n        text_encoded.append(token_to_idx[token])\n    return text_encoded\n\n\ndef decode(seq_idx, idx_to_token, delim=None, stop_at_end=True):\n    tokens = []\n    for idx in seq_idx:\n        tokens.append(idx_to_token[idx])\n        if stop_at_end and tokens[-1] == '<END>':\n            break\n    if delim is None:\n        return tokens\n    else:\n        return delim.join(tokens)\n\n\ndef reverse_dict(input_dict):\n\n    reversed_dict = {}\n    for key in input_dict.keys():\n        val = input_dict[key]\n        reversed_dict[val] = key\n    return reversed_dict\n\n\ndef to_long_tensor(dset):\n    arr = np.asarray(dset, dtype=np.int64)\n    tensor = torch.LongTensor(arr)\n    return tensor\n\n\ndef proper_capitalize(text):\n    if len(text) > 0:\n        text = text.lower()\n        text = text[0].capitalize() + text[1:]\n        for idx, char in enumerate(text):\n            if char in ['.', '!', '?'] and (idx + 2) < len(text):\n                text = text[:idx + 2] + text[idx + 2].capitalize() + text[idx +\n                        
                                                  3:]\n        text = text.replace(' i ', ' I ')\n        text = text.replace(',i ', ',I ')\n        text = text.replace('.i ', '.I ')\n        text = text.replace('!i ', '!I ')\n    return text\n"
  },
  {
    "path": "language/lstm.py",
    "content": "\"\"\"\nLSTM\n\nInput: batch_size x max_text_length (tokenized questions)\nOutput: batch_size x lstm_hidden_size (question embedding)\n\nDetails:\nTokenized text are first word-embedded (300-D), then passed to\n2-layer LSTM, where each cell has is 1024-D. For each text,\noutput the hidden state of the last non-null token.\n\"\"\"\n\nfrom __future__ import print_function\n\nimport json\n\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Variable\n\n\nclass Encoder(nn.Module):\n\n    def __init__(self,\n                 token_to_idx,\n                 word_embedding_dim=300,\n                 text_embed_size=1024,\n                 metadata_file='./templates/metadata_fsm.json',\n                 linear_hidden_size=256,\n                 linear_dropout_rate=0):\n\n        super(Encoder, self).__init__()\n\n        # LSTM (shared)\n        self.lstm = LSTM(\n            token_to_idx=token_to_idx,\n            word_embedding_dim=word_embedding_dim,\n            lstm_hidden_size=text_embed_size)\n\n        # classifiers (not shared)\n        with open(metadata_file, 'r') as f:\n            self.metadata = json.load(f)\n        self.classifier_names = []\n        for idx, (key, val) in enumerate(self.metadata.items()):\n            num_val = len(val.items())\n            classifier_name = key\n            self.classifier_names.append(classifier_name)\n            setattr(\n                self, classifier_name,\n                nn.Sequential(\n                    fc_block(text_embed_size, linear_hidden_size,\n                             linear_dropout_rate),\n                    nn.Linear(linear_hidden_size, num_val)))\n\n    def forward(self, text):\n\n        # LSTM (shared)\n        # Input: batch_size x max_text_length\n        # Output: batch_size x text_embed_size\n        text_embedding = self.lstm(text)\n\n        # classifiers (not shared)\n        output = []\n        for classifier_name in self.classifier_names:\n            
classifier = getattr(self, classifier_name)\n            output.append(classifier(text_embedding))\n\n        return output\n\n\nclass LSTM(nn.Module):\n\n    def __init__(self,\n                 token_to_idx,\n                 word_embedding_dim=300,\n                 lstm_hidden_size=1024,\n                 lstm_num_layers=2,\n                 lstm_dropout=0):\n\n        super(LSTM, self).__init__()\n\n        # token\n        self.token_to_idx = token_to_idx\n        self.NULL = token_to_idx['<NULL>']\n        self.START = token_to_idx['<START>']\n        self.END = token_to_idx['<END>']\n\n        # word embedding\n        self.word2vec = nn.Embedding(\n            num_embeddings=len(token_to_idx), embedding_dim=word_embedding_dim)\n\n        # LSTM\n        self.rnn = nn.LSTM(\n            input_size=word_embedding_dim,\n            hidden_size=lstm_hidden_size,\n            num_layers=lstm_num_layers,\n            bias=True,\n            batch_first=True,\n            dropout=lstm_dropout,\n            bidirectional=False)\n\n    def forward(self, x):\n\n        batch_size, max_text_length = x.size()\n\n        # Find the last non-null element in each sequence, store in idx\n        idx = torch.LongTensor(batch_size).fill_(max_text_length - 1)\n        x_cpu = x.data.cpu()\n        for text_idx in range(batch_size):\n            for token_idx in range(max_text_length - 1):\n                if (x_cpu[text_idx, token_idx] != self.NULL\n                    ) and x_cpu[text_idx, token_idx + 1] == self.NULL:  # noqa\n                    idx[text_idx] = token_idx\n                    break\n        idx = idx.type_as(x.data).long()\n        idx = Variable(idx, requires_grad=False)\n\n        # reduce memory access time\n        self.rnn.flatten_parameters()\n\n        # hs: all hidden states\n        #      [batch_size x max_text_length x hidden_size]\n        # h_n: [2 x batch_size x hidden_size]\n        # c_n: [2 x batch_size x hidden_size]\n        
hidden_states, (_, _) = self.rnn(self.word2vec(x))\n\n        idx = idx.view(batch_size, 1, 1).expand(batch_size, 1,\n                                                hidden_states.size(2))\n        hidden_size = hidden_states.size(2)\n\n        # only retrieve the hidden state of the last non-null element\n        # [batch_size x 1 x hidden_size]\n        hidden_state_at_last_token = hidden_states.gather(1, idx)\n\n        # [batch_size x hidden_size]\n        hidden_state_at_last_token = hidden_state_at_last_token.view(\n            batch_size, hidden_size)\n\n        return hidden_state_at_last_token\n\n\nclass fc_block(nn.Module):\n\n    def __init__(self, inplanes, planes, drop_rate=0.15):\n        super(fc_block, self).__init__()\n        self.fc = nn.Linear(inplanes, planes)\n        self.bn = nn.BatchNorm1d(planes)\n        if drop_rate > 0:\n            self.dropout = nn.Dropout(drop_rate)\n        self.relu = nn.ReLU(inplace=True)\n        self.drop_rate = drop_rate\n\n    def forward(self, x):\n        x = self.fc(x)\n        x = self.bn(x)\n        if self.drop_rate > 0:\n            x = self.dropout(x)\n        x = self.relu(x)\n        return x\n\n\ndef main():\n    \"\"\" Test Code \"\"\"\n    # ################### LSTM #########################\n    question_token_to_idx = {\n        \".\": 4,\n        \"missing\": 34,\n        \"large\": 28,\n        \"is\": 26,\n        \"cubes\": 19,\n        \"cylinder\": 21,\n        \"what\": 54,\n        \"<START>\": 1,\n        \"green\": 24,\n        \"<END>\": 2,\n        \"object\": 35,\n        \"things\": 51,\n        \"<UNK>\": 3,\n        \"matte\": 31,\n        \"rubber\": 41,\n        \"tiny\": 52,\n        \"yellow\": 55,\n        \"red\": 40,\n        \"visible\": 53,\n        \"color\": 17,\n        \"size\": 44,\n        \"balls\": 11,\n        \"the\": 48,\n        \"any\": 8,\n        \"blocks\": 14,\n        \"ball\": 10,\n        \"a\": 6,\n        \"it\": 27,\n        \"an\": 7,\n        
\"one\": 38,\n        \"purple\": 39,\n        \"how\": 25,\n        \"thing\": 50,\n        \"?\": 5,\n        \"objects\": 36,\n        \"blue\": 15,\n        \"block\": 13,\n        \"small\": 45,\n        \"shiny\": 43,\n        \"material\": 30,\n        \"cylinders\": 22,\n        \"<NULL>\": 0,\n        \"many\": 29,\n        \"of\": 37,\n        \"cube\": 18,\n        \"metallic\": 33,\n        \"gray\": 23,\n        \"brown\": 16,\n        \"spheres\": 47,\n        \"there\": 49,\n        \"sphere\": 46,\n        \"shape\": 42,\n        \"are\": 9,\n        \"metal\": 32,\n        \"cyan\": 20,\n        \"big\": 12\n    },\n\n    batch_size = 64\n    print('batch size:', batch_size)\n\n    # questions=torch.ones(batch_size, 15, dtype=torch.long)\n    questions = torch.randint(0, 10, (batch_size, 15), dtype=torch.long)\n    print('intput size:', questions.size())\n\n    lstm = LSTM(token_to_idx=question_token_to_idx[0])\n    output = lstm(questions)\n    print('output size:', output.size())\n\n    # ################### Language Encoder #########################\n\n    encoder = Encoder(\n        token_to_idx=question_token_to_idx[0],\n        metadata_file='./templates/metadata_fsm.json')\n    output = encoder(questions)\n    print('output length:', len(output))\n    for classifier in output:\n        print('classifier.size():', classifier.size())\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "language/preprocess_request.py",
    "content": "import argparse\nimport json\nimport os\nimport sys\n\nimport numpy as np\n\nsys.path.append('.')\nfrom language_utils import *  # noqa\n\"\"\"\nPreprocess the text\n\"\"\"\n\n\ndef parse_args():\n    \"\"\"Parses arguments.\"\"\"\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        '--input_vocab_path',\n        required=True,\n        type=str,\n        help='path to the input vocabulary file')\n    parser.add_argument(\n        '--input_data_path',\n        required=True,\n        type=str,\n        help='path to the input data file')\n    parser.add_argument(\n        '--metadata_file',\n        type=str,\n        default='./templates/metadata_fsm.json',\n        help='directory to the metadata file')\n    parser.add_argument(\n        '--system_mode_file',\n        type=str,\n        default='./templates/system_mode.json',\n        help='directory to the system_mode file')\n    parser.add_argument(\n        '--allow_unknown',\n        default=0,\n        type=int,\n        help='whether allow unknown tokens (i.e. 
words)')\n    parser.add_argument(\n        '--expand_vocab',\n        default=0,\n        type=int,\n        help='whether expand vocabularies')\n    parser.add_argument(\n        '--output_dir',\n        required=True,\n        type=str,\n        help='folder to save the output vocabulary file')\n    parser.add_argument(\n        '--unlabeled_value',\n        default=999,\n        type=int,\n        help='value to represent unlabeled value')\n\n    return parser.parse_args()\n\n\ndef main():\n    args = parse_args()\n\n    if not os.path.isdir(args.output_dir):\n        os.makedirs(args.output_dir, exist_ok=False)\n\n    # load vocabulary\n    print(\"Loading vocab\")\n    with open(args.input_vocab_path, 'r') as f:\n        vocab = json.load(f)\n    text_token_to_idx = vocab['text_token_to_idx']\n\n    # load metadata file\n    with open(args.metadata_file, 'r') as f:\n        metadata = json.load(f)\n\n    # load system_mode file\n    with open(args.system_mode_file, 'r') as f:\n        system_mode_file = json.load(f)\n\n    # load input data\n    with open(args.input_data_path, 'r') as f:\n        input_data = json.load(f)\n\n    # initialize lists to store encoded data\n    text_encoded_list = []\n    system_mode_encoded_list = []\n    labels_encoded_list = []\n\n    print('Encoding')\n\n    for idx, data_sample in enumerate(input_data):\n\n        # encode text\n        text = data_sample['text']\n        text_tokens = tokenize(text=text)  # noqa\n        text_encoded = encode(  # noqa\n            text_tokens=text_tokens,\n            token_to_idx=text_token_to_idx,\n            allow_unk=args.allow_unknown)\n        text_encoded_list.append(text_encoded)\n\n        # encode system_mode\n        system_mode = data_sample['system_mode']\n        system_mode_encoded = system_mode_file[system_mode]\n        system_mode_encoded_list.append(system_mode_encoded)\n\n        # encode labels\n        labels_encoded = []\n        for idx, (key, val) in 
enumerate(metadata.items()):\n            label = data_sample[key]\n            if label is None:\n                # use args.unlabeled_value to represent missing labels\n                label_encoded = args.unlabeled_value\n            else:\n                label_encoded = val[str(label)]\n            labels_encoded.append(label_encoded)\n        labels_encoded_list.append(labels_encoded)\n\n    # Pad encoded text to equal length\n    print('Padding tokens')\n    text_encoded_padded_list = []\n    max_text_length = max(len(text) for text in text_encoded_list)\n    for text_encoded in text_encoded_list:\n        while len(text_encoded) < max_text_length:\n            text_encoded.append(text_token_to_idx['<NULL>'])\n        text_encoded_padded_list.append(text_encoded)\n\n    # save processed text\n    np.save(\n        os.path.join(args.output_dir, 'text.npy'), text_encoded_padded_list)\n    np.savetxt(\n        os.path.join(args.output_dir, 'text.txt'),\n        text_encoded_padded_list,\n        fmt='%.0f')\n\n    # save processed system_mode\n    np.save(\n        os.path.join(args.output_dir, 'system_mode.npy'),\n        system_mode_encoded_list)\n    np.savetxt(\n        os.path.join(args.output_dir, 'system_mode.txt'),\n        system_mode_encoded_list,\n        fmt='%.0f')\n\n    # save processed labels\n    np.save(os.path.join(args.output_dir, 'labels.npy'), labels_encoded_list)\n    np.savetxt(\n        os.path.join(args.output_dir, 'labels.txt'),\n        labels_encoded_list,\n        fmt='%.0f')\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "language/run_encoder.py",
    "content": "import argparse\nimport json\nimport random\n\nimport torch\n\nfrom .language_utils import *  # noqa\nfrom .lstm import Encoder\n\n\ndef parse_args():\n    \"\"\"Parses arguments.\"\"\"\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        '--input_vocab_file',\n        required=True,\n        type=str,\n        help='path to the input vocabulary file')\n    parser.add_argument(\n        '--allow_unknown',\n        default=1,\n        type=int,\n        help='whether allow unknown tokens (i.e. words)')\n    parser.add_argument(\n        '--pretrained_checkpoint',\n        default='',\n        type=str,\n        help='The pretrained network weights for testing')\n    parser.add_argument(\n        '--metadata_file',\n        default='./templates/metadata_fsm.json',\n        type=str,\n        help='path to metadata file.')\n    parser.add_argument(\n        '--system_mode_file',\n        default='./templates/system_mode.json',\n        type=str,\n        help='path to system_mode file.')\n    parser.add_argument(\n        '--device_name',\n        default='gpu',\n        type=str,\n    )\n    parser.add_argument(\n        '--verbose',\n        default=0,\n        type=int,\n    )\n\n    # LSTM hyperparameter\n    parser.add_argument('--word_embedding_dim', default=300, type=int)\n    parser.add_argument('--text_embed_size', default=1024, type=int)\n    parser.add_argument('--linear_hidden_size', default=256, type=int)\n    parser.add_argument('--linear_dropout_rate', default=0, type=float)\n\n    return parser.parse_args()\n\n\ndef main():\n\n    args = parse_args()\n    encode_request(args)\n\n\ndef encode_request(args, system_mode=None, dialog_logger=None):\n\n    # set up\n    if args.device_name == 'cpu':\n        args.device = torch.device('cpu')\n    elif args.device_name == 'gpu':\n        args.device = torch.device('cuda')\n\n    if dialog_logger is None:\n        output_function = print\n    else:\n        # 
output_function = dialog_logger.info\n        def output_function(*args):\n            # suppress output when called by other scripts\n            pass\n\n        compulsory_output_function = dialog_logger.info\n\n    if dialog_logger is None:\n        # no logger given: fall back to plain printing\n        compulsory_output_function = print\n\n    # ---------------- STEP 1: Input the Request ----------------\n\n    # choose system_mode\n    with open(args.system_mode_file, 'r') as f:\n        system_mode_dict = json.load(f)\n    system_mode_list = []\n    for (mode, mode_idx) in system_mode_dict.items():\n        system_mode_list.append(mode)\n\n    if __name__ == '__main__':\n        assert system_mode is None\n        system_mode = random.choice(system_mode_list)\n        output_function('      PREDEFINED system_mode: ' + system_mode)\n    else:\n        assert system_mode is not None\n\n    # input request\n    compulsory_output_function(\n        'Enter your request (Press enter when you finish):')\n    input_text = input()\n    # input_text = 'make the bangs slightly longer.'\n    compulsory_output_function('USER INPUT >>> ' + input_text)\n\n    # ---------------- STEP 2: Preprocess Request ----------------\n\n    # output_function(\"      The system is trying to understand your request:\")\n    # output_function(\"      ########################################\")\n\n    # load vocabulary\n    with open(args.input_vocab_file, 'r') as f:\n        vocab = json.load(f)\n    text_token_to_idx = vocab['text_token_to_idx']\n\n    text_tokens = tokenize(text=input_text)  # noqa\n    text_encoded = encode(  # noqa\n        text_tokens=text_tokens,\n        token_to_idx=text_token_to_idx,\n        allow_unk=args.allow_unknown)\n    text_encoded = to_long_tensor([text_encoded]).to(args.device)  # noqa\n\n    # ---------------- STEP 3: Encode Request ----------------\n\n    # prepare encoder\n    encoder = Encoder(\n        token_to_idx=text_token_to_idx,\n        word_embedding_dim=args.word_embedding_dim,\n        
text_embed_size=args.text_embed_size,\n        metadata_file=args.metadata_file,\n        linear_hidden_size=args.linear_hidden_size,\n        linear_dropout_rate=args.linear_dropout_rate)\n    encoder = encoder.to(args.device)\n    checkpoint = torch.load(\n        args.pretrained_checkpoint, map_location=args.device)\n    encoder.load_state_dict(checkpoint['state_dict'], True)\n    encoder.eval()\n\n    # forward pass\n    output = encoder(text_encoded)\n\n    # ---------------- STEP 4: Process Encoder Output ----------------\n\n    output_labels = []\n    for head_idx in range(len(output)):\n        _, pred = torch.max(output[head_idx], 1)\n        head_label = pred.cpu().numpy()[0]\n        output_labels.append(head_label)\n\n    # load metadata file\n    with open(args.metadata_file, 'r') as f:\n        metadata = json.load(f)\n\n    # find mapping from value to label\n    reversed_metadata = {}\n    for idx, (key, val) in enumerate(metadata.items()):\n        reversed_val = reverse_dict(val)  # noqa\n        reversed_metadata[key] = reversed_val\n    if args.verbose:\n        output_function('reversed_metadata: ' + str(reversed_metadata))\n\n    # convert predicted values to a dict of predicted labels\n    output_semantic_labels = {}  # from LSTM output\n    valid_semantic_labels = {}  # useful information among LSTM output\n    for idx, (key, val) in enumerate(reversed_metadata.items()):\n        output_semantic_labels[key] = val[output_labels[idx]]\n        valid_semantic_labels[key] = None\n    if args.verbose:\n        output_function('output_semantic_labels: ' +\n                        str(output_semantic_labels))\n\n    # extract predicted labels\n    user_mode = output_semantic_labels[system_mode]\n    valid_semantic_labels[system_mode] = user_mode\n\n    request_mode = output_semantic_labels['request_mode']\n    attribute = output_semantic_labels['attribute']\n    score_change_direction = output_semantic_labels['score_change_direction']\n    if output_semantic_labels['score_change_value'] is None:\n        
score_change_value = None\n    else:\n        score_change_value = int(output_semantic_labels['score_change_value'])\n    if output_semantic_labels['target_score'] is None:\n        target_score = None\n    else:\n        target_score = int(output_semantic_labels['target_score'])\n\n    # print to screen\n    output_function('      ENCODED user_mode:' + ' ' + user_mode)\n    valid_semantic_labels['user_mode'] = user_mode\n    if 'pureRequest' in user_mode:\n        output_function('      ENCODED request_mode: ' + ' ' + request_mode)\n        valid_semantic_labels['request_mode'] = request_mode\n        output_function('      ENCODED attribute:' + ' ' + attribute)\n        valid_semantic_labels['attribute'] = attribute\n        # only output_function labels valid for this request_mode\n        if request_mode == 'change_definite':\n            output_function('      ENCODED score_change_direction:' + ' ' +\n                            (score_change_direction))\n            valid_semantic_labels[\n                'score_change_direction'] = score_change_direction\n            output_function('      ENCODED score_change_value:' + ' ' +\n                            str(score_change_value))\n            valid_semantic_labels['score_change_value'] = score_change_value\n        elif request_mode == 'change_indefinite':\n            output_function('      ENCODED score_change_direction:' + ' ' +\n                            score_change_direction)\n            valid_semantic_labels[\n                'score_change_direction'] = score_change_direction\n        elif request_mode == 'target':\n            output_function('      ENCODED target_score:' + ' ' +\n                            str(target_score))\n            valid_semantic_labels['target_score'] = target_score\n\n    valid_semantic_labels['text'] = input_text\n\n    if args.verbose:\n        output_function('valid_semantic_labels:' + ' ' +\n                        str(valid_semantic_labels))\n    # output_function(\" 
     ########################################\")\n\n    return valid_semantic_labels\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "language/templates/attr_wise_caption_templates.json",
    "content": "{\n    \"Bangs\": {\n        \"0\": [\n            \"<He> has no bangs at all.\",\n            \"<He> has no bangs at all and <his> forehead is visible.\",\n            \"<He> doesn't have any bangs.\",\n            \"<He> doesn't have any bangs and <his> forehead is visible.\",\n            \"<his> entire forehead is visible.\",\n            \"<his> entire forehead is visible without any bangs.\",\n            \"<He> shows <his> entire forehead without any bangs.\"\n        ],\n        \"1\": [\n            \"<He> has very short bangs which only covers a tiny portion of <his> forehead.\"\n        ],\n        \"2\": [\n            \"<He> has short bangs that covers a small portion of <his> forehead.\",\n            \"<He> has short bangs that only covers a small portion of <his> forehead.\"\n        ],\n        \"3\": [\n            \"<He> has medium bangs that covers half of <his> forehead.\",\n            \"<He> has bangs of medium length that covers half of <his> forehead.\",\n            \"<He> has bangs of medium length that leaves half of <his> forehead visible.\"\n        ],\n        \"4\": [\n            \"<He> has long bangs that almost covers all of <his> forehead.\",\n            \"<He> has long bangs that almost covers This entire forehead.\"\n        ],\n        \"5\": [\n            \"<He> has extremely long bangs that almost covers all of <his> forehead.\",\n            \"<He> has extremely long bangs that almost covers This entire forehead.\"\n        ]\n    },\n    \"Eyeglasses\": {\n        \"0\": [\n            \"<He> is not wearing any eyeglasses.\",\n            \"There is not any eyeglasses on <his> face.\"\n        ],\n        \"1\": [\n            \"<He> is wearing rimless eyeglasses.\"\n        ],\n        \"2\": [\n            \"<He> is wearing eyeglasses with thin frame.\",\n            \"<He> is wearing thin frame eyeglasses.\"\n        ],\n        \"3\": [\n            \"<He> is wearing eyeglasses with thick frame.\",\n  
          \"<He> is wearing thick frame eyeglasses.\"\n        ],\n        \"4\": [\n            \"<He> is wearing sunglasses with thin frame.\",\n            \"<He> is wearing thin frame sunglasses.\"\n        ],\n        \"5\": [\n            \"<He> is wearing sunglasses with thick frame.\",\n            \"<He> is wearing thick frame sunglasses.\"\n        ]\n    },\n    \"No_Beard\": {\n        \"0\": [\n            \"<He> doesn't have any beard.\",\n            \"<He> doesn't have any beard at all.\"\n        ],\n        \"1\": [\n            \"<his> face is covered with short pointed beard.\",\n            \"<his> face is covered with his stubble.\",\n            \"<his> face has a rough growth of stubble.\",\n            \"<He> has a rough growth of stubble.\",\n            \"There should be stubble covering <his> cheeks and chin.\"\n        ],\n        \"2\": [\n            \"<his> face is covered with short beard.\"\n        ],\n        \"3\": [\n            \"<his> face is covered with beard of medium length.\",\n            \"<He> has beard of medium length.\"\n        ],\n        \"4\": [\n            \"<He> has a big mustache on his face.\",\n            \"<his> has a bushy beard.\"\n        ],\n        \"5\": [\n            \"<his> has very long beard.\",\n            \"<his> has full beard.\",\n            \"<his> has very thick beard.\",\n            \"<his> has a very bushy beard.\"\n        ]\n    },\n    \"Smiling\": {\n        \"0\": [\n            \"<He> looks serious with no smile in <his> face.\"\n        ],\n        \"1\": [\n            \"<He> smiles with corners of the mouth turned up.\",\n            \"<He> smiles with corners of <his> mouth turned up.\",\n            \"<He> turns up the corners of <his> mouth.\"\n        ],\n        \"2\": [\n            \"This corners of <his> mouth curve up and we can see some teeth.\",\n            \"<He> smiles broadly and shows some teeth.\"\n        ],\n        \"3\": [\n            \"The entire 
face of this <man> is beamed with happiness.\",\n            \"<He> has a beaming face.\",\n            \"<He> is smiling with <his> teeth visible.\",\n            \"<his> entire face is beamed with happiness.\"\n        ],\n        \"4\": [\n            \"<He> has a big smile.\",\n            \"<He> has a big smile on <his> face.\",\n            \"<He> is smiling with <his> mouth slightly open.\",\n            \"<He> is smiling with <his> mouth slightly open and teeth visible.\"\n        ],\n        \"5\": [\n            \"This <man> in the image is laughing happily.\",\n            \"<He> has a deep rumbling laugh.\",\n            \"<He> has a very big smile.\",\n            \"<He> has a very big smile on <his> face.\",\n            \"<He> is smiling with <his> mouth wide open.\",\n            \"<He> is smiling with <his> mouth wide open and teeth visible.\"\n        ]\n    },\n    \"Young\": {\n        \"0\": [\n            \"This is a young kid.\",\n            \"This is a young child.\"\n        ],\n        \"1\": [\n            \"<He> is a teenager.\",\n            \"<He> looks very young.\"\n        ],\n        \"2\": [\n            \"<He> is a young adult.\",\n            \"<He> is in <his> thirties.\"\n        ],\n        \"3\": [\n            \"<He> is in <his> forties.\",\n            \"<He> is in <his> middle age.\"\n        ],\n        \"4\": [\n            \"<He> is in <his> sixties.\",\n            \"<He> is in <his> fifties.\",\n            \"<He> looks like an elderly.\"\n        ],\n        \"5\": [\n            \"<He> is in <his> eighties.\",\n            \"This old <man> is in <his> eighties.\",\n            \"<He> is in <his> seventies.\",\n            \"This old <man> is in <his> seventies.\",\n            \"<He> looks very old.\"\n        ]\n    }\n}"
  },
  {
    "path": "language/templates/feedback.json",
    "content": "{\n    \"replacement\": {\n        \"<ATTR_NAME>\": {\n            \"Bangs\": \"bangs\",\n            \"Eyeglasses\": \"glasses\",\n            \"No_Beard\": \"beard\",\n            \"Smiling\": \"smile\",\n            \"Young\": \"age\"\n        },\n        \"<ATTR_PRONOUN>\": {\n            \"Bangs\": \"them\",\n            \"Eyeglasses\": \"them\",\n            \"No_Beard\": \"it\",\n            \"Smiling\": \"it\",\n            \"Young\": \"it\"\n        },\n        \"<ATTR_BE>\": {\n            \"Bangs\": \"are\",\n            \"Eyeglasses\": \"are\",\n            \"No_Beard\": \"is\",\n            \"Smiling\": \"is\",\n            \"Young\": \"is\"\n        },\n        \"<ATTR_DEGREE>\": {\n            \"Bangs\": \"length\",\n            \"Eyeglasses\": \"style\",\n            \"No_Beard\": \"shape\",\n            \"Smiling\": \"degree\",\n            \"Young\": \"level\"\n        }\n    },\n    \"suggestion\": [\n        \"Do you want to try manipulating the <ATTR_NAME>?\",\n        \"Do you want to try manipulating the <ATTR_NAME> instead?\",\n        \"Do you want to try manipulating the <ATTR_NAME> as well?\",\n        \"Do you want to try editing the <ATTR_NAME>?\",\n        \"Do you want to try editing the <ATTR_NAME> instead?\",\n        \"Do you want to try editing the <ATTR_NAME> as well?\",\n        \"What about the <ATTR_NAME>? Do you want to play with <ATTR_PRONOUN>?\",\n        \"Do you want to play with the <ATTR_NAME>?\",\n        \"What about the <ATTR_NAME>? Do you want to edit <ATTR_PRONOUN>?\",\n        \"Do you want to edit the <ATTR_NAME>?\",\n        \"What about the <ATTR_NAME>? 
Do you want to manipulate <ATTR_PRONOUN>?\",\n        \"Do you want to manipulate the <ATTR_NAME>?\"\n    ],\n    \"whether_enough\": {\n        \"general\": [\n            \"Is this enough?\",\n            \"Is this good enough?\",\n            \"<ATTR_BE> the <ATTR_NAME> just right now?\",\n            \"<ATTR_BE> the <ATTR_NAME> what you want now?\",\n            \"<ATTR_BE> the <ATTR_NAME> of the person just right now?\",\n            \"<ATTR_BE> the <ATTR_NAME> of the person what you want now?\",\n            \"<ATTR_BE> the <ATTR_NAME> of proper degree now?\",\n            \"<ATTR_BE> the <ATTR_DEGREE> of the <ATTR_NAME> ok now?\",\n            \"<ATTR_BE> the <ATTR_DEGREE> of the <ATTR_NAME> okay now?\"\n        ],\n        \"Bangs\": [\n            \"Are the bangs in proper shape now?\",\n            \"Is the length of the bangs ok now?\"\n        ],\n        \"Eyeglasses\": [],\n        \"No_Beard\": [],\n        \"Smiling\": [],\n        \"Young\": [\n            \"Is the age of the person ok now?\"\n        ]\n    },\n    \"whats_next\": [\n        \"What's next?\",\n        \"What else do you want to play with?\",\n        \"What else do you want to manipulate?\",\n        \"What else do you want to edit?\",\n        \"What else do you want to change?\",\n        \"What else do you want to try?\"\n    ],\n    \"ok\": [\n        [\n            \"Okay\",\n            \"Ok\",\n            \"Well\",\n            \"Okie\"\n        ],\n        [\n            \" \",\n            \", \"\n        ],\n        [\n            \"done.\",\n            \"it's done.\",\n            \"bingo.\",\n            \"finished.\",\n            \"that's it.\",\n            \"this is it.\"\n        ]\n    ],\n    \"max_edit_num_reached\": [\n        \"It is infeasible to edit this attribute. Let's try another attribute.\",\n        \"We cannot edit this attribute. Let's try something else.\",\n        \"Oops, it is hard to edit this attribute. 
Let's try something else.\",\n        \"Sorry, we are unable to edit this attribute. Perhaps we can try something else.\"\n    ],\n    \"already_at_target_class\": [\n        \"This attribute is already at the degree that you want. Let's try a different attribute degree or another attribute.\"\n    ]\n}"
  },
  {
    "path": "language/templates/gender.json",
    "content": "{\n    \"male\": {\n        \"<man>\": [\n            \"person\",\n            \"guy\",\n            \"gentleman\"\n        ],\n        \"<he>\": [\n            \"he\",\n            \"he\",\n            \"this person\",\n            \"this guy\",\n            \"this gentleman\",\n            \"this man\"\n        ],\n        \"<his>\": [\n            \"his\",\n            \"the\"\n        ],\n        \"<him>\": [\n            \"him\"\n        ],\n        \"<boy>\": [\n            \"boy\"\n        ]\n    },\n    \"female\": {\n        \"<man>\": [\n            \"person\",\n            \"lady\",\n            \"female\"\n        ],\n        \"<he>\": [\n            \"she\",\n            \"she\",\n            \"this lady\",\n            \"this person\",\n            \"this female\",\n            \"this woman\"\n        ],\n        \"<his>\": [\n            \"her\",\n            \"the\"\n        ],\n        \"<him>\": [\n            \"her\"\n        ],\n        \"<boy>\": [\n            \"girl\"\n        ]\n    }\n}"
  },
  {
    "path": "language/templates/metadata_fsm.json",
    "content": "{\n    \"start\": {\n        \"start_pureRequest\": 0\n    },\n    \"suggestion\": {\n        \"yes\": 0,\n        \"yes_pureRequest\": 1,\n        \"no\": 2,\n        \"no_pureRequest\": 3,\n        \"no_end\": 4\n    },\n    \"whether_enough\": {\n        \"yes\": 0,\n        \"yes_pureRequest\": 1,\n        \"yes_end\": 2,\n        \"no\": 3,\n        \"no_pureRequest\": 4\n    },\n    \"whats_next\": {\n        \"pureRequest\": 0,\n        \"end\": 1\n    },\n    \"attribute\": {\n        \"Bangs\": 0,\n        \"Eyeglasses\": 1,\n        \"No_Beard\": 2,\n        \"Smiling\": 3,\n        \"Young\": 4\n    },\n    \"score_change_direction\": {\n        \"negative\": 0,\n        \"positive\": 1\n    },\n    \"score_change_value\": {\n        \"1\": 0,\n        \"2\": 1,\n        \"3\": 2,\n        \"4\": 3,\n        \"5\": 4\n    },\n    \"target_score\": {\n        \"0\": 0,\n        \"1\": 1,\n        \"2\": 2,\n        \"3\": 3,\n        \"4\": 4,\n        \"5\": 5\n    },\n    \"request_mode\": {\n        \"change_definite\": 0,\n        \"change_indefinite\": 1,\n        \"target\": 2,\n        \"end\": 3\n    }\n}"
  },
  {
    "path": "language/templates/overall_caption_templates.json",
    "content": "{\n    \"attr_order_mapping\": {\n        \"Bangs\": {\n            \"0\": [\n                \"has\",\n                \"sentence\"\n            ],\n            \"1\": [\n                \"has\"\n            ],\n            \"2\": [\n                \"has\"\n            ],\n            \"3\": [\n                \"has\"\n            ],\n            \"4\": [\n                \"has\",\n                \"sentence\"\n            ]\n        },\n        \"No_Beard\": {\n            \"0\": [\n                \"has\",\n                \"sentence\"\n            ],\n            \"1\": [\n                \"has\"\n            ],\n            \"2\": [\n                \"has\"\n            ],\n            \"3\": [\n                \"has\"\n            ],\n            \"4\": [\n                \"has\",\n                \"sentence\"\n            ]\n        },\n        \"Eyeglasses\": {\n            \"0\": [\n                \"has\",\n                \"sentence\"\n            ],\n            \"1\": [\n                \"has\"\n            ],\n            \"2\": [\n                \"has\"\n            ],\n            \"3\": [\n                \"has\"\n            ],\n            \"4\": [\n                \"has\",\n                \"sentence\"\n            ]\n        },\n        \"Smiling\": {\n            \"0\": [\n                \"has\",\n                \"sentence\"\n            ],\n            \"1\": [\n                \"has\"\n            ],\n            \"2\": [\n                \"has\"\n            ],\n            \"3\": [\n                \"has\"\n            ],\n            \"4\": [\n                \"has\",\n                \"sentence\"\n            ]\n        },\n        \"Young\": {\n            \"0\": [\n                \"start\"\n            ],\n            \"1\": [\n                \"sentence\"\n            ],\n            \"2\": [\n                \"sentence\"\n            ],\n            \"3\": [\n                \"sentence\"\n            ],\n         
   \"4\": [\n                \"sentence\"\n            ]\n        }\n    },\n    \"has\": {\n        \"Bangs\": {\n            \"0\": [\n                \"no bangs\"\n            ],\n            \"1\": [\n                \"very short bangs\",\n                \"very short bangs which only covers a tiny portion of <his> forehead\"\n            ],\n            \"2\": [\n                \"short bangs\",\n                \"short bangs that covers a small portion of <his> forehead\",\n                \"short bangs that only covers a small portion of <his> forehead\"\n            ],\n            \"3\": [\n                \"medium bangs\",\n                \"medium bangs that covers half of <his> forehead\",\n                \"bangs of medium length that covers half of <his> forehead\",\n                \"bangs of medium length that leaves half of <his> forehead visible\"\n            ],\n            \"4\": [\n                \"long bangs\",\n                \"long bangs that almost covers all of <his> forehead\",\n                \"long bangs that almost covers <his> entire forehead\"\n            ],\n            \"5\": [\n                \"extremely long bangs\",\n                \"extremely long bangs that almost covers all of <his> forehead\",\n                \"extremely long bangs that almost covers <his> entire forehead\"\n            ]\n        },\n        \"Eyeglasses\": {\n            \"0\": [\n                \"no eyeglasses\"\n            ],\n            \"1\": [\n                \"rimless eyeglasses\"\n            ],\n            \"2\": [\n                \"eyeglasses with thin frame\",\n                \"thin frame eyeglasses\"\n            ],\n            \"3\": [\n                \"eyeglasses with thick frame\",\n                \"thick frame eyeglasses\"\n            ],\n            \"4\": [\n                \"sunglasses with thin frame\",\n                \"thin frame sunglasses\"\n            ],\n            \"5\": [\n                \"sunglasses with 
thick frame\",\n                \"thick frame sunglasses\"\n            ]\n        },\n        \"No_Beard\": {\n            \"0\": [\n                \"no beard\",\n                \"no beard at all\"\n            ],\n            \"1\": [\n                \"short pointed beard\",\n                \"stubble\",\n                \"a rough growth of stubble\",\n                \"stubble covering <his> cheeks and chin\"\n            ],\n            \"2\": [\n                \"short beard\"\n            ],\n            \"3\": [\n                \"beard of medium length\"\n            ],\n            \"4\": [\n                \"a big mustache on his face\",\n                \"a bushy beard\"\n            ],\n            \"5\": [\n                \"very long beard\",\n                \"full beard\",\n                \"very thick beard\",\n                \"a very bushy beard\"\n            ]\n        },\n        \"Smiling\": {\n            \"0\": [\n                \"no smile\"\n            ],\n            \"1\": [\n                \"a very mild smile\"\n            ],\n            \"2\": [\n                \"a mild smile\"\n            ],\n            \"3\": [\n                \"a beaming face\",\n                \"a smile with <his> teeth visible\",\n                \"a face that is beamed with happiness\",\n                \"a smile\"\n            ],\n            \"4\": [\n                \"a big smile\",\n                \"a big smile on <his> face\",\n                \"a big smile with <his> mouth slightly open\",\n                \"a big smile with <his> mouth slightly open and teeth visible\"\n            ],\n            \"5\": [\n                \"a deep rumbling laugh\",\n                \"a very big smile\",\n                \"a very big smile on <his> face\",\n                \"a very big smile with <his> mouth wide open\",\n                \"a very big smile with <his> mouth wide open and teeth visible\"\n            ]\n        }\n    },\n    \"start\": {\n     
   \"Young\": {\n            \"0\": [\n                \"This young kid\",\n                \"This young child\",\n                \"This little <boy>\"\n            ],\n            \"1\": [\n                \"This teenager\",\n                \"This young <man>\",\n                \"This young <boy>\"\n            ],\n            \"2\": [\n                \"This young adult\",\n                \"This <man> in <his> thirties\"\n            ],\n            \"3\": [\n                \"This <man> in <his> forties\",\n                \"This <man> in <his> middle age\",\n                \"This middle-aged <man>\"\n            ],\n            \"4\": [\n                \"This <man> in <his> sixties\",\n                \"This <man> in <his> fifties\",\n                \"This elderly <man>\"\n            ],\n            \"5\": [\n                \"This old <man>\",\n                \"This <man> in <his> eighties\",\n                \"This old <man> in <his> eighties\",\n                \"This <man> in <his> seventies\",\n                \"This old <man> in <his> seventies\",\n                \"This very old <man>\"\n            ]\n        }\n    },\n    \"has_prefix\": [\n        \"This <man> has \",\n        \"<He> has \"\n    ]\n}"
  },
  {
    "path": "language/templates/pool.json",
    "content": "{\n    \"synonyms\": {\n        \" can \": [\n            \" can \",\n            \" could \",\n            \" should \"\n        ],\n        \"i'm\": [\n            \"i'm\",\n            \"i am\"\n        ],\n        \"it's\": [\n            \"it's\",\n            \"it is\"\n        ],\n        \"bangs\": [\n            \"bangs\",\n            \"fringe\"\n        ],\n        \"slightly\": [\n            \"slightly\",\n            \"a little bit\",\n            \"a tiny little bit\",\n            \"a little\",\n            \"a bit\",\n            \"only a little\",\n            \"just a little bit\"\n        ],\n        \"somewhat\": [\n            \"somewhat\",\n            \"relatively\",\n            \"to some extent\",\n            \"to some degree\",\n            \"moderately\",\n            \"partially\",\n            \"sort of\",\n            \"kind of\",\n            \"considerably\"\n        ],\n        \"very\": [\n            \"very\",\n            \"extremely\"\n        ],\n        \"entire\": [\n            \"entire\",\n            \"whole\",\n            \"full\"\n        ],\n        \"child\": [\n            \"child\",\n            \"schoolchild\"\n        ],\n        \"teenager\": [\n            \"teenager\",\n            \"teen\"\n        ],\n        \"beard\": [\n            \"beard\",\n            \"mustache\"\n        ],\n        \"i think\": [\n            \"i think\",\n            \"i think that\",\n            \"i feel\",\n            \"i feel that\",\n            \"i kind of think\",\n            \"i kind of think that\",\n            \"i kind of feel\",\n            \"i kind of feel that\",\n            \"i guess\",\n            \"i guess that\"\n        ],\n        \"i want\": [\n            \"i want\",\n            \"i kind of want\",\n            \"i would like\"\n        ],\n        \"let's try\": [\n            \"let's try\",\n            \"how about trying\",\n            \"what about trying\"\n        ],\n        
\"but not too much\": [\n            \"but not too much\",\n            \"just not too much\",\n            \"just that not too much\",\n            \"just don't go too much\"\n        ],\n        \"only\": [\n            \"only\",\n            \"simply\",\n            \"just\"\n        ],\n        \"eyeglasses\": [\n            \"eyeglasses\",\n            \"glasses\"\n        ],\n        \"pokerface\": [\n            \"pokerface\",\n            \"poker face\"\n        ],\n        \"what's\": [\n            \"what's\",\n            \"what is\"\n        ],\n        \"how's\": [\n            \"how's\",\n            \"how is\"\n        ],\n        \"do you want to\": [\n            \"do you want to\",\n            \"would you like to\",\n            \"perhaps you would like to\",\n            \"perhaps you might want to\",\n            \"maybe you would like to\",\n            \"maybe you might want to\"\n        ],\n        \"want to\": [\n            \"want to\",\n            \"would like to\"\n        ],\n        \"manipulate\": [\n            \"manipulate\",\n            \"edit\"\n        ],\n        \"manipulating\": [\n            \"manipulating\",\n            \"editing\",\n            \"playing with\"\n        ]\n    },\n    \"prefix\": [\n        \"Actually,\",\n        \"To be honest,\",\n        \"Well,\",\n        \"Well\",\n        \"Emm\",\n        \"Emmm\",\n        \"Emmmm\",\n        \"Emm,\",\n        \"Emmm,\",\n        \"Emmmm,\",\n        \"Hi,\",\n        \"Hello,\",\n        \"Let me think about it.\",\n        \"I'm not too sure but\",\n        \"What about this?\",\n        \"Can we try this?\",\n        \"It looks okay now but\",\n        \"It looks better now, but still,\",\n        \"It looks nice, but still,\",\n        \"Let me have a look. Well,\",\n        \"Let me have a look. Well\",\n        \"Let me have a look. Emm,\",\n        \"Let me have a look. Emmm,\",\n        \"Let me have a look. Emmmm,\",\n        \"Let me have a look. 
Emm\",\n        \"Let me have a look. Emmm\",\n        \"Let me have a look. Emmmm\",\n        \"Let me take a look. Well,\",\n        \"Let me take a look. Well\",\n        \"Let me take a look. Emm,\",\n        \"Let me take a look. Emmm,\",\n        \"Let me take a look. Emmmm,\",\n        \"Let me take a look. Emm\",\n        \"Let me take a look. Emmm\",\n        \"Let me take a look. Emmmm\"\n    ],\n    \"postfix\": [\n        \"Thanks!\",\n        \"Thank you!\",\n        \"Is that possible?\",\n        \"and emmm... well let's try this first.\",\n        \"I guess it will probably get better this way.\",\n        \"I'm not too sure, let's see how it goes first.\",\n        \"It would be nicer in that way.\",\n        \"It would be nicer in that way, I think.\",\n        \"It would be nicer in that way, I guess.\",\n        \"I think it would be nicer in that way.\",\n        \"I guess it would be nicer in that way.\",\n        \"It would be nicer this way.\",\n        \"It would be nicer this way, I think.\",\n        \"It would be nicer this way, I guess.\",\n        \"I think it would be nicer this way.\",\n        \"I guess it would be nicer this way.\",\n        \"It might be nicer in that way.\",\n        \"It might be nicer in that way, I think.\",\n        \"It might be nicer in that way, I guess.\",\n        \"I think it might be nicer in that way.\",\n        \"I guess it might be nicer in that way.\",\n        \"It might be nicer this way.\",\n        \"It might be nicer this way, I think.\",\n        \"It might be nicer this way, I guess.\",\n        \"I think it might be nicer this way.\",\n        \"I guess it might be nicer this way.\",\n        \"It would look better in that way.\",\n        \"It would look better in that way, I think.\",\n        \"It would look better in that way, I guess.\",\n        \"I think it would look better in that way.\",\n        \"I guess it would look better in that way.\",\n        \"It would look better this 
way.\",\n        \"It would look better this way, I think.\",\n        \"It would look better this way, I guess.\",\n        \"I think it would look better this way.\",\n        \"I guess it would look better this way.\",\n        \"It might look better in that way.\",\n        \"It might look better in that way, I think.\",\n        \"It might look better in that way, I guess.\",\n        \"I think it might look better in that way.\",\n        \"I guess it might look better in that way.\",\n        \"It might look better this way.\",\n        \"It might look better this way, I think.\",\n        \"It might look better this way, I guess.\",\n        \"I think it might look better this way.\",\n        \"I guess it might look better this way.\"\n    ]\n}"
  },
  {
    "path": "language/templates/system_mode.json",
    "content": "{\n    \"start\": 0,\n    \"suggestion\": 1,\n    \"whether_enough\": 2,\n    \"whats_next\": 3\n}\n"
  },
  {
    "path": "language/templates/user_fsm.json",
    "content": "{\n    \"start\": [\n        [\n            \"Hi.\",\n            \"Hello.\"\n        ]\n    ],\n    \"pureRequest\": {\n        \"Bangs\": {\n            \"target\": {\n                \"0\": [\n                    \"No bangs.\",\n                    \"Remove all the bangs.\",\n                    \"Cut off all the bangs.\",\n                    \"I don't want the bangs at all.\",\n                    \"I don't want any bangs.\",\n                    \"I don't want any bangs visible.\",\n                    \"The bangs doesn't look good, let's remove it.\",\n                    \"The bangs covers the forehead, but I want the entire forehead visible.\"\n                ],\n                \"1\": [\n                    \"Add very short bangs.\",\n                    \"I want very short bangs.\",\n                    \"Add very short bangs that leaves most of the forehead uncovered.\",\n                    \"I want very short bangs that leaves most of the forehead uncovered.\"\n                ],\n                \"2\": [\n                    \"Add short bangs.\",\n                    \"Let's try short bangs.\",\n                    \"Add short bangs that covers only a small portion of the forehead.\",\n                    \"Let's try short bangs that covers only a small portion of the forehead.\"\n                ],\n                \"3\": [\n                    \"Add medium bangs.\",\n                    \"Add bangs of medium length.\",\n                    \"Let's try bangs of medium length.\",\n                    \"Let's try bangs that leaves half of the forehead visible.\"\n                ],\n                \"4\": [\n                    \"Add long bangs.\",\n                    \"Let's try long bangs.\",\n                    \"Add long bangs but don't cover the entire forehead.\",\n                    \"Let's try long bangs but don't cover the entire forehead.\"\n                ],\n                \"5\": [\n                    \"Add extremely 
long bangs.\",\n                    \"Let's try extremely long bangs.\",\n                    \"Add extremely long bangs that covers the entire forehead.\",\n                    \"Let's try extremely long bangs that covers the entire forehead.\",\n                    \"Indeed, the bangs can be much longer. Let's cover the eyebrows.\"\n                ]\n            },\n            \"change\": {\n                \"positive\": {\n                    \"definite\": {\n                        \"1\": [\n                            \"The bangs can be slightly longer.\",\n                            \"Make the bangs slightly longer.\"\n                        ],\n                        \"2\": [\n                            \"The bangs can be somewhat longer, but not too much.\",\n                            \"Make the bangs somewhat longer, but not too much.\"\n                        ],\n                        \"3\": [\n                            \"Make the bangs longer, but not too much.\"\n                        ],\n                        \"4\": [\n                            \"The bangs can be longer.\",\n                            \"Make the bangs longer.\"\n                        ],\n                        \"5\": [\n                            \"The bangs can be much longer.\",\n                            \"Make the bangs much longer.\"\n                        ]\n                    },\n                    \"indefinite\": [\n                        \"Longer bangs.\",\n                        \"Add bangs.\",\n                        \"The bangs can be longer.\",\n                        \"Let's add some bangs.\",\n                        \"Maybe the bangs can be longer.\",\n                        \"Let's try adding longer bangs.\",\n                        \"What about adding longer bangs?\",\n                        \"Emm, I think the bangs can be longer.\",\n                        \"Let's make the bangs longer.\",\n                        \"Hi, I want to 
see how my friend looks like with some bangs.\"\n                    ]\n                },\n                \"negative\": {\n                    \"definite\": {\n                        \"1\": [\n                            \"The bangs can be slightly shorter.\",\n                            \"Make the bangs slightly shorter.\"\n                        ],\n                        \"2\": [\n                            \"The bangs can be somewhat shorter, but not too much.\",\n                            \"Make the bangs somewhat shorter, but not too much.\"\n                        ],\n                        \"3\": [\n                            \"The bangs can be shorter.\",\n                            \"Make the bangs shorter.\"\n                        ],\n                        \"4\": [\n                            \"The bangs can be much shorter.\",\n                            \"Make the bangs much shorter.\"\n                        ],\n                        \"5\": [\n                            \"Remove all the bangs.\",\n                            \"I don't want the bangs at all.\",\n                            \"I don't want any bangs at all.\"\n                        ]\n                    },\n                    \"indefinite\": [\n                        \"Less bangs\",\n                        \"Remove bangs.\",\n                        \"Remove the bangs.\",\n                        \"Let's cut off the bangs.\",\n                        \"Let's cut the bangs short.\",\n                        \"Let's cut the bangs off.\",\n                        \"I don't like the bangs, let's remove it.\",\n                        \"I don't like the bangs, let's cut it off.\",\n                        \"The bangs is too long, let's remove it.\",\n                        \"The bangs is too long, let's cut it off.\"\n                    ]\n                }\n            }\n        },\n        \"Eyeglasses\": {\n            \"target\": {\n                \"0\": 
[\n                    \"No eyeglass\",\n                    \"No eyeglasses please.\",\n                    \"No eyeglasses.\",\n                    \"Remove eyeglasses.\",\n                    \"Remove the eyeglasses.\",\n                    \"I don't want to see the eyeglasses.\",\n                    \"I think there shouldn't be any eyeglasses.\"\n                ],\n                \"1\": [\n                    \"The eyeglasses should be rimless.\",\n                    \"Let's try rimless eyeglasses.\"\n                ],\n                \"2\": [\n                    \"The eyeglasses should have thin frame.\",\n                    \"Let's try thin frame eyeglasses.\"\n                ],\n                \"3\": [\n                    \"The eyeglasses should have thick frame.\",\n                    \"Let's try thick frame eyeglasses.\"\n                ],\n                \"4\": [\n                    \"Let's try thin frame sunglasses.\",\n                    \"It should be sunglasses with thin frame.\"\n                ],\n                \"5\": [\n                    \"Let's try thick frame sunglasses.\",\n                    \"It should be sunglasses with thick frame.\"\n                ]\n            },\n            \"change\": {\n                \"positive\": {\n                    \"definite\": {\n                        \"1\": [\n                            \"Make the eyeglasses slightly more obvious.\",\n                            \"The eyeglasses can be slightly more obvious.\"\n                        ],\n                        \"2\": [\n                            \"Make the eyeglasses somewhat more obvious.\",\n                            \"The eyeglasses can be somewhat more obvious.\"\n                        ],\n                        \"3\": [\n                            \"Make the eyeglasses more obvious.\",\n                            \"The eyeglasses can be more obvious.\"\n                        ],\n                        \"4\": [\n  
                          \"Let's try eyeglasses with thicker frame and darker color.\"\n                        ],\n                        \"5\": [\n                            \"Let's try thick frame sunglasses.\",\n                            \"It should be sunglasses with thick frame.\"\n                        ]\n                    },\n                    \"indefinite\": [\n                        \"Add glasses\",\n                        \"Use eyeglasses\",\n                        \"Try eyeglasses.\",\n                        \"Add eyeglasses.\",\n                        \"Add eyeglasses to the face.\",\n                        \"Add eyeglasses please.\",\n                        \"Let's add eyeglasses.\",\n                        \"The eyeglasses can be more obvious.\",\n                        \"The eyeglasses are not obvious enough.\",\n                        \"I can't see the eyeglasses clearly, let's make them more obvious.\",\n                        \"The eyeglasses frame can be thicker.\",\n                        \"The glass color can be darker.\"\n                    ]\n                },\n                \"negative\": {\n                    \"definite\": {\n                        \"1\": [\n                            \"Make the eyeglasses slightly less obvious.\",\n                            \"The eyeglasses can be slightly less obvious.\"\n                        ],\n                        \"2\": [\n                            \"Make the eyeglasses somewhat less obvious.\",\n                            \"The eyeglasses can be somewhat less obvious.\"\n                        ],\n                        \"3\": [\n                            \"Make the eyeglasses less obvious.\",\n                            \"The eyeglasses can be less obvious.\"\n                        ],\n                        \"4\": [\n                            \"The eyeglasses are too obvious, let's make it much less obvious.\",\n                            \"The 
eyeglasses are too obvious, let's try to make it much less obvious.\"\n                        ],\n                        \"5\": [\n                            \"Remove eyeglasses.\",\n                            \"Remove the eyeglasses.\",\n                            \"I don't like the eyeglasses.\",\n                            \"I don't want to see the eyeglasses.\",\n                            \"There shouldn't be any eyeglasses.\"\n                        ]\n                    },\n                    \"indefinite\": [\n                        \"Remove eyeglasses.\",\n                        \"No eyeglasses.\",\n                        \"The eyeglasses can be less obvious.\",\n                        \"The eyeglasses are too obvious.\",\n                        \"Let's make the eyeglasses less obvious.\",\n                        \"The eyeglasses frame can be thinner.\",\n                        \"The glass color can be lighter.\"\n                    ]\n                }\n            }\n        },\n        \"No_Beard\": {\n            \"target\": {\n                \"0\": [\n                    \"Let's see what he looks like without his beard.\",\n                    \"Let's shave the beard off.\",\n                    \"No beard\"\n                ],\n                \"1\": [\n                    \"His face should be covered with short pointed beard.\",\n                    \"His face should be covered with the stubble.\",\n                    \"His face has a rough growth of stubble.\",\n                    \"There should be stubble covering his cheeks and chin.\"\n                ],\n                \"2\": [\n                    \"His face should be covered with short beard.\",\n                    \"Let's add short beard to his face.\",\n                    \"Let's try short beard on his face.\"\n                ],\n                \"3\": [\n                    \"His face should be covered with beard of medium length.\",\n                    \"Let's add 
medium-length beard to his face.\",\n                    \"Let's try medium-length beard on his face.\"\n                ],\n                \"4\": [\n                    \"Let's try a big mustache on his face.\",\n                    \"He should have a bushy beard.\"\n                ],\n                \"5\": [\n                    \"Let's add very long beard.\",\n                    \"Let's add a full beard.\",\n                    \"He should have very thick beard.\",\n                    \"He should have a very bushy beard.\"\n                ]\n            },\n            \"change\": {\n                \"positive\": {\n                    \"definite\": {\n                        \"1\": [\n                            \"The beard can be slightly longer.\",\n                            \"Make the beard slightly longer.\",\n                            \"Slightly add more beard.\"\n                        ],\n                        \"2\": [\n                            \"The beard can be somewhat longer, but not too much.\",\n                            \"Make the beard somewhat longer, but not too much.\"\n                        ],\n                        \"3\": [\n                            \"The beard can be longer.\",\n                            \"Make the beard longer.\"\n                        ],\n                        \"4\": [\n                            \"The beard can be much longer.\",\n                            \"Make the beard much longer.\"\n                        ],\n                        \"5\": [\n                            \"Let's add very long beard.\",\n                            \"Let's add a full beard.\",\n                            \"He should have very thick beard\",\n                            \"He has a very bushy beard.\"\n                        ]\n                    },\n                    \"indefinite\": [\n                        \"Add beard.\",\n                        \"Add some beard.\",\n                        
\"Longer beard.\",\n                        \"Let's add more beard.\",\n                        \"I want some more beard on the face.\"\n                    ]\n                },\n                \"negative\": {\n                    \"definite\": {\n                        \"1\": [\n                            \"The beard can be slightly shorter.\",\n                            \"Make the beard slightly shorter.\",\n                            \"Slightly remove some beard.\"\n                        ],\n                        \"2\": [\n                            \"The beard can be somewhat shorter, but not too much.\",\n                            \"Make the beard somewhat shorter, but not too much.\"\n                        ],\n                        \"3\": [\n                            \"The beard can be shorter.\",\n                            \"Make the beard shorter.\"\n                        ],\n                        \"4\": [\n                            \"The beard can be much shorter.\",\n                            \"Make the beard much shorter.\"\n                        ],\n                        \"5\": [\n                            \"Let's see what he looks like without his beard.\",\n                            \"Let's shave the beard off.\"\n                        ]\n                    },\n                    \"indefinite\": [\n                        \"Less beard.\",\n                        \"Remove beard.\",\n                        \"Remove the beard.\",\n                        \"The beard should be gone.\",\n                        \"Let's try to remove the beard.\",\n                        \"I don't like the beard.\",\n                        \"Let's try shorter beard.\"\n                    ]\n                }\n            }\n        },\n        \"Smiling\": {\n            \"target\": {\n                \"0\": [\n                    \"I think the person shouldn't be smiling.\",\n                    \"I don't like the smile.\",\n  
                  \"I don't want the smile.\",\n                    \"No smile.\",\n                    \"Remove the smile.\"\n                ],\n                \"1\": [\n                    \"Turn up the corners of the mouth.\",\n                    \"The corners of the mouth should curve up.\"\n                ],\n                \"2\": [\n                    \"The corners of the mouth should curve up and show some teeth.\",\n                    \"Smile broadly and show some teeth.\"\n                ],\n                \"3\": [\n                    \"I want a beaming face.\",\n                    \"I want the face to be smiling with teeth visible.\",\n                    \"The entire face should beam with happiness.\"\n                ],\n                \"4\": [\n                    \"It can be a big smile.\",\n                    \"I want a big smile on the face.\",\n                    \"I want the face to be smiling with the mouth slightly open.\",\n                    \"I want the face to be smiling with the mouth slightly open. We should be able to see the teeth.\",\n                    \"I want the face to be smiling with the mouth slightly open so that we can see the teeth.\"\n                ],\n                \"5\": [\n                    \"I want a deep rumbling laugh.\",\n                    \"It can be laughing happily.\",\n                    \"It can be a very big smile.\",\n                    \"I want a very big smile on the face.\",\n                    \"I want the face to be smiling with the mouth wide open.\",\n                    \"I want the face to be smiling with the mouth wide open. 
We should be able to see the teeth.\"\n                ]\n            },\n            \"change\": {\n                \"positive\": {\n                    \"definite\": {\n                        \"1\": [\n                            \"Smile slightly more.\",\n                            \"The smile can be slightly bigger.\",\n                            \"Make the smile slightly bigger.\",\n                            \"The person can look slightly happier.\",\n                            \"The person can smile slightly more happily.\"\n                        ],\n                        \"2\": [\n                            \"The smile can be somewhat bigger, but not too much.\",\n                            \"Make the smile somewhat bigger, but not too much.\",\n                            \"The person can look somewhat happier.\",\n                            \"The person can smile somewhat more happily.\"\n                        ],\n                        \"3\": [\n                            \"Smile more.\",\n                            \"The smile can be bigger.\",\n                            \"Make the smile bigger.\",\n                            \"The person can be happier.\",\n                            \"The person can smile more happily.\"\n                        ],\n                        \"4\": [\n                            \"The smile can be much bigger.\",\n                            \"Make the smile much bigger.\",\n                            \"The person can be a lot happier.\",\n                            \"The person can smile a lot more happily.\"\n                        ],\n                        \"5\": [\n                            \"I want a deep rumbling laugh.\",\n                            \"It can be laughing happily.\",\n                            \"It can be a very big smile.\",\n                            \"I want a very big smile on the face.\",\n                            \"I want the face to be smiling with the 
mouth wide open.\",\n                            \"I want the face to be smiling with the mouth wide open. We should be able to see the teeth.\",\n                            \"The person can smile very happily.\"\n                        ]\n                    },\n                    \"indefinite\": [\n                        \"Look not so serious.\",\n                        \"Look less serious.\",\n                        \"Too serious, be happier.\",\n                        \"Add smile.\",\n                        \"Add some smiling please.\",\n                        \"The smile is not big enough.\",\n                        \"I want a bigger smile.\",\n                        \"I want the face to smile more.\",\n                        \"I want to change the poker face to a smiling face.\",\n                        \"The person can smile more happily.\",\n                        \"Can look happier.\"\n                    ]\n                },\n                \"negative\": {\n                    \"definite\": {\n                        \"1\": [\n                            \"I want the smile to be slightly less obvious.\",\n                            \"The smile can be slightly less obvious.\",\n                            \"The person can smile slightly less happily.\"\n                        ],\n                        \"2\": [\n                            \"I want the smile to be less obvious.\",\n                            \"The smile can be less obvious.\",\n                            \"The person can smile somewhat less happily.\"\n                        ],\n                        \"3\": [\n                            \"I want the smile to be much less obvious.\",\n                            \"The smile can be much less obvious.\",\n                            \"The person can smile less happily.\"\n                        ],\n                        \"4\": [\n                            \"I want to make the smile almost vanish.\",\n         
                   \"The person can smile a lot less happily.\"\n                        ],\n                        \"5\": [\n                            \"I want the smile to vanish.\",\n                            \"I don't like the smile, let's remove it.\"\n                        ]\n                    },\n                    \"indefinite\": [\n                        \"Not serious enough.\",\n                        \"More serious.\",\n                        \"No smiling.\",\n                        \"No smile.\",\n                        \"Remove smiling.\",\n                        \"Remove the smiling.\",\n                        \"Remove smile.\",\n                        \"Remove the smile.\",\n                        \"Smile less happily.\",\n                        \"Don't be so happy.\",\n                        \"The smile is too much.\",\n                        \"Can we have a gentler smile? This smile is too big.\",\n                        \"I want to change the smiling face to a poker face.\"\n                    ]\n                }\n            }\n        },\n        \"Young\": {\n            \"target\": {\n                \"0\": [\n                    \"Let's make the face a child one.\",\n                    \"Let's make the face very young.\"\n                ],\n                \"1\": [\n                    \"Let's make the face a teenager one.\",\n                    \"Let's make the face relatively young.\",\n                    \"The person should be in their twenties.\"\n                ],\n                \"2\": [\n                    \"Let's make the face a young one.\",\n                    \"It should be a young adult.\",\n                    \"The person should be in their thirties.\"\n                ],\n                \"3\": [\n                    \"Let's make the face a middle-aged one.\",\n                    \"The person should be in their forties.\"\n                ],\n                \"4\": [\n                    \"Let's make 
the face slightly older than middle age.\",\n                    \"Let's make the face the one of a senior.\",\n                    \"Let's make the face the one of an elderly person.\",\n                    \"The person should be in their sixties.\",\n                    \"The person should be in their fifties.\"\n                ],\n                \"5\": [\n                    \"Let's make the face a very old one.\",\n                    \"The person should be in their seventies.\",\n                    \"The person should be in their eighties.\"\n                ]\n            },\n            \"change\": {\n                \"positive\": {\n                    \"definite\": {\n                        \"1\": [\n                            \"The face can be slightly older.\",\n                            \"Make the face slightly older.\"\n                        ],\n                        \"2\": [\n                            \"Somewhat older.\",\n                            \"The face can be somewhat older, just not too much.\",\n                            \"Make the face somewhat older, but not too much.\"\n                        ],\n                        \"3\": [\n                            \"The face can be older, but not too much.\",\n                            \"Make the face older, but not too much.\"\n                        ],\n                        \"4\": [\n                            \"The face can be older.\",\n                            \"Make the face older.\"\n                        ],\n                        \"5\": [\n                            \"The face can be much older.\",\n                            \"Make the face much older.\",\n                            \"Let's make the face a very old one.\"\n                        ]\n                    },\n                    \"indefinite\": [\n                        \"Older.\",\n                        \"Make it older.\",\n                        \"The face can be older.\",\n                        
\"This face is too young, let's make it older.\",\n                        \"Let's make the face older.\",\n                        \"What about making the face look older?\"\n                    ]\n                },\n                \"negative\": {\n                    \"definite\": {\n                        \"1\": [\n                            \"The face can be slightly younger.\",\n                            \"Make the face slightly younger.\"\n                        ],\n                        \"2\": [\n                            \"Somewhat younger.\",\n                            \"The face can be somewhat younger, but not too much.\",\n                            \"Make the face somewhat younger, but not too much.\"\n                        ],\n                        \"3\": [\n                            \"The face can be younger.\",\n                            \"Make the face younger.\",\n                            \"Younger face.\"\n                        ],\n                        \"4\": [\n                            \"Much younger.\",\n                            \"The face can be much younger.\",\n                            \"Make the face much younger.\"\n                        ],\n                        \"5\": [\n                            \"Let's make the face a child one.\"\n                        ]\n                    },\n                    \"indefinite\": [\n                        \"Younger face.\",\n                        \"Younger.\",\n                        \"Look younger.\",\n                        \"Make it younger.\",\n                        \"Be younger.\",\n                        \"Less old.\",\n                        \"The face can be younger.\",\n                        \"This face is too old, let's make it younger.\",\n                        \"Let's make the face younger.\",\n                        \"What about making it younger?\",\n                        \"Can you make the person look younger?\"\n            
        ]\n                }\n            }\n        }\n    },\n    \"yes\": [\n        [\n            \"Yes\",\n            \"Yep\",\n            \"Yeah\",\n            \"Yep sure\",\n            \"Yes sure\",\n            \"Sure\",\n            \"Ok\"\n        ],\n        [\n            \".\"\n        ]\n    ],\n    \"no\": [\n        [\n            \"No\",\n            \"Nope\"\n        ],\n        [\n            \".\"\n        ]\n    ],\n    \"end\": [\n        [\n            \"End.\",\n            \"Nothing.\",\n            \"Nothing else.\",\n            \"Nothing else for now.\",\n            \"It's all good now.\",\n            \"I don't want any further edits.\",\n            \"Actually it's all good now.\",\n            \"No need for further edits.\",\n            \"I don't need any further edits.\",\n            \"That's all.\",\n            \"This is it.\",\n            \"That is it.\",\n            \"That is all.\",\n            \"No.\"\n        ],\n        [\n            \" Thanks!\",\n            \" Thank you!\",\n            \" Thanks a lot!\",\n            \"\"\n        ]\n    ]\n}"
  },
  {
    "path": "language/templates/user_old_templates.json",
    "content": "{\n    \"start\": [\n        [\n            \"Hi.\",\n            \"Hello.\"\n        ],\n        [\n            \" \"\n        ]\n    ],\n    \"requests\": {\n        \"Bangs\": {\n            \"target\": {\n                \"0\": [\n                    \"No bangs.\",\n                    \"Remove all the bangs.\",\n                    \"Cut off all the bangs.\",\n                    \"I don't want the bangs at all.\",\n                    \"I don't want any bangs.\",\n                    \"I don't want any bangs visible.\",\n                    \"The bangs don't look good, let's remove them.\",\n                    \"The bangs cover the forehead, but I want the entire forehead visible.\"\n                ],\n                \"1\": [\n                    \"Add very short bangs.\",\n                    \"I want very short bangs.\",\n                    \"Add very short bangs that leave most of the forehead uncovered.\",\n                    \"I want very short bangs that leave most of the forehead uncovered.\"\n                ],\n                \"2\": [\n                    \"Add short bangs.\",\n                    \"Let's try short bangs.\",\n                    \"Add short bangs that cover only a small portion of the forehead.\",\n                    \"Let's try short bangs that cover only a small portion of the forehead.\"\n                ],\n                \"3\": [\n                    \"Add medium bangs.\",\n                    \"Add bangs of medium length.\",\n                    \"Let's try bangs of medium length.\",\n                    \"Let's try bangs that leave half of the forehead visible.\"\n                ],\n                \"4\": [\n                    \"Add long bangs.\",\n                    \"Let's try long bangs.\",\n                    \"Add long bangs but don't cover the entire forehead.\",\n                    \"Let's try long bangs but don't cover the entire forehead.\"\n                ],\n                \"5\": 
[\n                    \"Add extremely long bangs.\",\n                    \"Let's try extremely long bangs.\",\n                    \"Add extremely long bangs that cover the entire forehead.\",\n                    \"Let's try extremely long bangs that cover the entire forehead.\",\n                    \"Indeed, the bangs can be much longer. Let's cover the eyebrows.\"\n                ]\n            },\n            \"change\": {\n                \"positive\": {\n                    \"definite\": {\n                        \"1\": [\n                            \"The bangs can be slightly longer.\",\n                            \"Make the bangs slightly longer.\"\n                        ],\n                        \"2\": [\n                            \"The bangs can be somewhat longer, but not too much.\",\n                            \"Make the bangs somewhat longer, but not too much.\"\n                        ],\n                        \"3\": [\n                            \"Make the bangs longer, but not too much.\"\n                        ],\n                        \"4\": [\n                            \"The bangs can be longer.\",\n                            \"Make the bangs longer.\"\n                        ],\n                        \"5\": [\n                            \"The bangs can be much longer.\",\n                            \"Make the bangs much longer.\"\n                        ]\n                    },\n                    \"indefinite\": [\n                        \"The bangs can be longer.\",\n                        \"Let's add some bangs.\",\n                        \"Maybe the bangs can be longer.\",\n                        \"Let's try adding longer bangs.\",\n                        \"What about adding longer bangs?\",\n                        \"Emm, I think the bangs can be longer.\",\n                        \"Let's make the bangs longer.\",\n                        \"Hi, I want to see what my friend looks like with some 
bangs.\"\n                    ]\n                },\n                \"negative\": {\n                    \"definite\": {\n                        \"1\": [\n                            \"The bangs can be slightly shorter.\",\n                            \"Make the bangs slightly shorter.\"\n                        ],\n                        \"2\": [\n                            \"The bangs can be somewhat shorter, but not too much.\",\n                            \"Make the bangs somewhat shorter, but not too much.\"\n                        ],\n                        \"3\": [\n                            \"The bangs can be shorter.\",\n                            \"Make the bangs shorter.\"\n                        ],\n                        \"4\": [\n                            \"The bangs can be much shorter.\",\n                            \"Make the bangs much shorter.\"\n                        ],\n                        \"5\": [\n                            \"Remove all the bangs.\",\n                            \"I don't want the bangs at all.\",\n                            \"I don't want any bangs at all.\"\n                        ]\n                    },\n                    \"indefinite\": [\n                        \"Remove bangs.\",\n                        \"Remove the bangs.\",\n                        \"Let's cut off the bangs.\",\n                        \"Let's cut the bangs short.\",\n                        \"Let's cut the bangs off.\",\n                        \"I don't like the bangs, let's remove them.\",\n                        \"I don't like the bangs, let's cut them off.\",\n                        \"The bangs are too long, let's remove them.\",\n                        \"The bangs are too long, let's cut them off.\"\n                    ]\n                }\n            }\n        },\n        \"Eyeglasses\": {\n            \"target\": {\n                \"0\": [\n                    \"No eyeglasses please.\",\n                    \"No 
eyeglasses.\",\n                    \"Remove eyeglasses.\",\n                    \"Remove the eyeglasses.\",\n                    \"I don't want to see the eyeglasses.\",\n                    \"I think there shouldn't be any eyeglasses.\"\n                ],\n                \"1\": [\n                    \"The eyeglasses should be rimless.\",\n                    \"Let's try rimless eyeglasses.\"\n                ],\n                \"2\": [\n                    \"The eyeglasses should have thin frame.\",\n                    \"Let's try thin frame eyeglasses.\"\n                ],\n                \"3\": [\n                    \"The eyeglasses should have thick frame.\",\n                    \"Let's try thick frame eyeglasses.\"\n                ],\n                \"4\": [\n                    \"Let's try thin frame sunglasses.\",\n                    \"It should be sunglasses with thin frame.\"\n                ],\n                \"5\": [\n                    \"Let's try thick frame sunglasses.\",\n                    \"It should be sunglasses with thick frame.\"\n                ]\n            },\n            \"change\": {\n                \"positive\": {\n                    \"definite\": {\n                        \"1\": [\n                            \"Make the eyeglasses slightly more obvious.\",\n                            \"The eyeglasses can be slightly more obvious.\"\n                        ],\n                        \"2\": [\n                            \"Make the eyeglasses somewhat more obvious.\",\n                            \"The eyeglasses can be somewhat more obvious.\"\n                        ],\n                        \"3\": [\n                            \"Make the eyeglasses more obvious.\",\n                            \"The eyeglasses can be more obvious.\"\n                        ],\n                        \"4\": [\n                            \"Let's try eyeglasses with thicker frame and darker color.\"\n                        
],\n                        \"5\": [\n                            \"Let's try thick frame sunglasses.\",\n                            \"It should be sunglasses with thick frame.\"\n                        ]\n                    },\n                    \"indefinite\": [\n                        \"Try eyeglasses.\",\n                        \"Add eyeglasses.\",\n                        \"Add eyeglasses to the face.\",\n                        \"Add eyeglasses please.\",\n                        \"Let's add eyeglasses.\",\n                        \"The eyeglasses can be more obvious.\",\n                        \"The eyeglasses are not obvious enough.\",\n                        \"I can't see the eyeglasses clearly, let's make them more obvious.\",\n                        \"The eyeglasses frame can be thicker.\",\n                        \"The glass color can be darker.\"\n                    ]\n                },\n                \"negative\": {\n                    \"definite\": {\n                        \"1\": [\n                            \"Make the eyeglasses slightly less obvious.\",\n                            \"The eyeglasses can be slightly less obvious.\"\n                        ],\n                        \"2\": [\n                            \"Make the eyeglasses somewhat less obvious.\",\n                            \"The eyeglasses can be somewhat less obvious.\"\n                        ],\n                        \"3\": [\n                            \"Make the eyeglasses less obvious.\",\n                            \"The eyeglasses can be less obvious.\"\n                        ],\n                        \"4\": [\n                            \"The eyeglasses are too obvious, let's make them much less obvious.\",\n                            \"The eyeglasses are too obvious, let's try to make them much less obvious.\"\n                        ],\n                        \"5\": [\n                            \"Remove eyeglasses.\",\n                    
        \"Remove the eyeglasses.\",\n                            \"I don't like the eyeglasses.\",\n                            \"I don't want to see the eyeglasses.\",\n                            \"There shouldn't be any eyeglasses.\"\n                        ]\n                    },\n                    \"indefinite\": [\n                        \"The eyeglasses can be less obvious.\",\n                        \"The eyeglasses are too obvious.\",\n                        \"Let's make the eyeglasses less obvious.\",\n                        \"The eyeglasses frame can be thinner.\",\n                        \"The glass color can be lighter.\"\n                    ]\n                }\n            }\n        },\n        \"No_Beard\": {\n            \"target\": {\n                \"0\": [\n                    \"Let's see what he looks like without his beard.\",\n                    \"Let's shave the beard off.\"\n                ],\n                \"1\": [\n                    \"His face should be covered with short pointed beard.\",\n                    \"His face should be covered with the stubble.\",\n                    \"His face has a rough growth of stubble.\",\n                    \"There should be stubble covering his cheeks and chin.\"\n                ],\n                \"2\": [\n                    \"His face should be covered with short beard.\",\n                    \"Let's add short beard to his face.\",\n                    \"Let's try short beard on his face.\"\n                ],\n                \"3\": [\n                    \"His face should be covered with beard of medium length.\",\n                    \"Let's add medium-length beard to his face.\",\n                    \"Let's try medium-length beard on his face.\"\n                ],\n                \"4\": [\n                    \"Let's try a big mustache on his face.\",\n                    \"He should have a bushy beard.\"\n                ],\n                \"5\": [\n                  
  \"Let's add very long beard.\",\n                    \"Let's add a full beard.\",\n                    \"He should have very thick beard.\",\n                    \"He should have a very bushy beard.\"\n                ]\n            },\n            \"change\": {\n                \"positive\": {\n                    \"definite\": {\n                        \"1\": [\n                            \"The beard can be slightly longer.\",\n                            \"Make the beard slightly longer.\",\n                            \"Slightly add more beard.\"\n                        ],\n                        \"2\": [\n                            \"The beard can be somewhat longer, but not too much.\",\n                            \"Make the beard somewhat longer, but not too much.\"\n                        ],\n                        \"3\": [\n                            \"The beard can be longer.\",\n                            \"Make the beard longer.\"\n                        ],\n                        \"4\": [\n                            \"The beard can be much longer.\",\n                            \"Make the beard much longer.\"\n                        ],\n                        \"5\": [\n                            \"Let's add very long beard.\",\n                            \"Let's add a full beard.\",\n                            \"He should have very thick beard.\",\n                            \"He should have a very bushy beard.\"\n                        ]\n                    },\n                    \"indefinite\": [\n                        \"Add beard.\",\n                        \"Add some beard.\",\n                        \"Longer beard.\",\n                        \"Let's add more beard.\",\n                        \"I want some more beard on the face.\"\n                    ]\n                },\n                \"negative\": {\n                    \"definite\": {\n                        \"1\": [\n                            \"The beard can be 
slightly shorter.\",\n                            \"Make the beard slightly shorter.\",\n                            \"Slightly remove some beard.\"\n                        ],\n                        \"2\": [\n                            \"The beard can be somewhat shorter, but not too much.\",\n                            \"Make the beard somewhat shorter, but not too much.\"\n                        ],\n                        \"3\": [\n                            \"The beard can be shorter.\",\n                            \"Make the beard shorter.\"\n                        ],\n                        \"4\": [\n                            \"The beard can be much shorter.\",\n                            \"Make the beard much shorter.\"\n                        ],\n                        \"5\": [\n                            \"Let's see what he looks like without his beard.\",\n                            \"Let's shave the beard off.\"\n                        ]\n                    },\n                    \"indefinite\": [\n                        \"Remove beard.\",\n                        \"Remove the beard.\",\n                        \"The beard should be gone.\",\n                        \"Let's try to remove the beard.\",\n                        \"I don't like the beard.\",\n                        \"Let's try shorter beard.\"\n                    ]\n                }\n            }\n        },\n        \"Smiling\": {\n            \"target\": {\n                \"0\": [\n                    \"I think the person shouldn't be smiling.\",\n                    \"I don't like the smile.\",\n                    \"I don't want the smile.\"\n                ],\n                \"1\": [\n                    \"Turn up the corners of the mouth.\",\n                    \"The corners of the mouth curve up.\"\n                ],\n                \"2\": [\n                    \"The corners of the mouth curve up and show some teeth.\",\n                    \"Smile broadly 
and show some teeth.\"\n                ],\n                \"3\": [\n                    \"I want a beaming face.\",\n                    \"I want the face to be smiling with teeth visible.\",\n                    \"The entire face should beam with happiness.\"\n                ],\n                \"4\": [\n                    \"It can be a big smile.\",\n                    \"I want a big smile on the face.\",\n                    \"I want the face to be smiling with the mouth slightly open.\",\n                    \"I want the face to be smiling with the mouth slightly open. We should be able to see the teeth.\",\n                    \"I want the face to be smiling with the mouth slightly open so that we can see the teeth.\"\n                ],\n                \"5\": [\n                    \"I want a deep rumbling laugh.\",\n                    \"It can be laughing happily.\",\n                    \"It can be a very big smile.\",\n                    \"I want a very big smile on the face.\",\n                    \"I want the face to be smiling with the mouth wide open.\",\n                    \"I want the face to be smiling with the mouth wide open. 
We should be able to see the teeth.\"\n                ]\n            },\n            \"change\": {\n                \"positive\": {\n                    \"definite\": {\n                        \"1\": [\n                            \"Smile slightly more.\",\n                            \"The smile can be slightly bigger.\",\n                            \"Make the smile slightly bigger.\",\n                            \"The person can look slightly happier.\",\n                            \"The person can smile slightly more happily.\"\n                        ],\n                        \"2\": [\n                            \"The smile can be somewhat bigger, but not too much.\",\n                            \"Make the smile somewhat bigger, but not too much.\",\n                            \"The person can look somewhat happier.\",\n                            \"The person can smile somewhat more happily.\"\n                        ],\n                        \"3\": [\n                            \"Smile more.\",\n                            \"The smile can be bigger.\",\n                            \"Make the smile bigger.\",\n                            \"The person can be happier.\",\n                            \"The person can smile more happily.\"\n                        ],\n                        \"4\": [\n                            \"The smile can be much bigger.\",\n                            \"Make the smile much bigger.\",\n                            \"The person can be a lot happier.\",\n                            \"The person can smile a lot more happily.\"\n                        ],\n                        \"5\": [\n                            \"I want a deep rumbling laugh.\",\n                            \"It can be laughing happily.\",\n                            \"It can be a very big smile.\",\n                            \"I want a very big smile on the face.\",\n                            \"I want the face to be smiling with the 
mouth wide open.\",\n                            \"I want the face to be smiling with the mouth wide open. We should be able to see the teeth.\",\n                            \"The person can smile very happily.\"\n                        ]\n                    },\n                    \"indefinite\": [\n                        \"Add some smiling please.\",\n                        \"The smile is not big enough.\",\n                        \"I want a bigger smile.\",\n                        \"I want the face to smile more.\",\n                        \"I want to change the pokerface to a smiling face.\",\n                        \"The person can smile more happily.\",\n                        \"Can look happier.\"\n                    ]\n                },\n                \"negative\": {\n                    \"definite\": {\n                        \"1\": [\n                            \"I want the smile to be slightly less obvious.\",\n                            \"The smile can be slightly less obvious.\",\n                            \"The person can smile slightly less happily.\"\n                        ],\n                        \"2\": [\n                            \"I want the smile to be less obvious.\",\n                            \"The smile can be less obvious.\",\n                            \"The person can smile somewhat less happily.\"\n                        ],\n                        \"3\": [\n                            \"I want the smile to be much less obvious.\",\n                            \"The smile can be much less obvious.\",\n                            \"The person can smile less happily.\"\n                        ],\n                        \"4\": [\n                            \"I want to make the smile almost vanish.\",\n                            \"The person can smile a lot less happily.\"\n                        ],\n                        \"5\": [\n                            \"I want the smile to vanish.\",\n       
                     \"I don't like the smile, let's remove it.\"\n                        ]\n                    },\n                    \"indefinite\": [\n                        \"No smiling.\",\n                        \"No smile.\",\n                        \"Remove smiling.\",\n                        \"Remove the smiling.\",\n                        \"Remove smile.\",\n                        \"Remove the smile.\",\n                        \"Smile less happily.\",\n                        \"Don't be so happy.\",\n                        \"The smile is too much.\",\n                        \"Can we have a gentler smile? This smile is too big.\",\n                        \"I want to change the smiling face to a pokerface.\"\n                    ]\n                }\n            }\n        },\n        \"Young\": {\n            \"target\": {\n                \"0\": [\n                    \"Let's make the face a child one.\",\n                    \"Let's make the face very young.\"\n                ],\n                \"1\": [\n                    \"Let's make the face a teenager one.\",\n                    \"Let's make the face relatively young.\",\n                    \"The person should be in the twenties.\"\n                ],\n                \"2\": [\n                    \"Let's make the face a young one.\",\n                    \"It should be a young adult.\",\n                    \"The person should be in the thirties.\"\n                ],\n                \"3\": [\n                    \"Let's make the face a middle age one.\",\n                    \"The person should be in the forties.\"\n                ],\n                \"4\": [\n                    \"Let's make the face slightly older than middle age.\",\n                    \"Let's make the face the one of a senior.\",\n                    \"Let's make the face the one of an elderly.\",\n                    \"The person should be in the sixties.\",\n                    \"The person should be in 
the fifties.\"\n                ],\n                \"5\": [\n                    \"Let's make the face a very old one.\",\n                    \"The person should be in the seventies.\",\n                    \"The person should be in the eighties.\"\n                ]\n            },\n            \"change\": {\n                \"positive\": {\n                    \"definite\": {\n                        \"1\": [\n                            \"The face can be slightly older.\",\n                            \"Make the face slightly older.\"\n                        ],\n                        \"2\": [\n                            \"Somewhat older.\",\n                            \"The face can be somewhat older, just not too much.\",\n                            \"Make the face somewhat older, but not too much.\"\n                        ],\n                        \"3\": [\n                            \"The face can be older, but not too much.\",\n                            \"Make the face older, but not too much.\"\n                        ],\n                        \"4\": [\n                            \"The face can be older.\",\n                            \"Make the face older.\"\n                        ],\n                        \"5\": [\n                            \"The face can be much older.\",\n                            \"Make the face much older.\",\n                            \"Let's make the face a very old one.\"\n                        ]\n                    },\n                    \"indefinite\": [\n                        \"Older.\",\n                        \"Make it older.\",\n                        \"The face can be older.\",\n                        \"This face is too young, let's make it older.\",\n                        \"Let's make the face older.\",\n                        \"What about making the face look older?\"\n                    ]\n                },\n                \"negative\": {\n                    \"definite\": {\n     
                   \"1\": [\n                            \"The face can be slightly younger.\",\n                            \"Make the face slightly younger.\"\n                        ],\n                        \"2\": [\n                            \"Somewhat younger.\",\n                            \"The face can be somewhat younger, but not too much.\",\n                            \"Make the face somewhat younger, but not too much.\"\n                        ],\n                        \"3\": [\n                            \"The face can be younger.\",\n                            \"Make the face younger.\",\n                            \"Younger face.\"\n                        ],\n                        \"4\": [\n                            \"Much younger.\",\n                            \"The face can be much younger.\",\n                            \"Make the face much younger.\"\n                        ],\n                        \"5\": [\n                            \"Let's make the face a child one.\"\n                        ]\n                    },\n                    \"indefinite\": [\n                        \"Younger face.\",\n                        \"Younger.\",\n                        \"Make it younger.\",\n                        \"Be younger.\",\n                        \"Less old.\",\n                        \"The face can be younger.\",\n                        \"This face is too old, let's make it younger.\",\n                        \"Let's make the face younger.\",\n                        \"What about making it younger?\"\n                    ]\n                }\n            }\n        }\n    },\n    \"yes_enough\": [\n        [\n            \"Emmm, yep\",\n            \"Emmm, yes\",\n            \"Emmm, yeep\",\n            \"Yes\",\n            \"Yep\",\n            \"Yeep\",\n            \"Yep sure\"\n        ],\n        [\n            \", \",\n            \". \",\n            \"! 
\"\n        ],\n        [\n            \"That's good enough now.\",\n            \"That's nice.\",\n            \"That's perfect.\",\n            \"This is great.\"\n        ],\n        [\n            \" \"\n        ]\n    ],\n    \"no_enough\": [\n        [\n            \"Actually,\",\n            \"To be honest,\",\n            \"Well,\",\n            \"Well\",\n            \"Emm\",\n            \"Emmm\",\n            \"Emmmm\",\n            \"Emm,\",\n            \"Emmm,\",\n            \"Emmmm,\",\n            \"I'm not too sure but\",\n            \"It looks okay now but\",\n            \"It looks better now, but still,\",\n            \"It looks nice, but still,\",\n            \"Let me have a look. Well,\",\n            \"Let me have a look. Well\",\n            \"Let me have a look. Emm,\",\n            \"Let me have a look. Emmm,\",\n            \"Let me have a look. Emmmm,\",\n            \"Let me have a look. Emm\",\n            \"Let me have a look. Emmm\",\n            \"Let me have a look. Emmmm\",\n            \"Let me take a look. Well,\",\n            \"Let me take a look. Well\",\n            \"Let me take a look. Emm,\",\n            \"Let me take a look. Emmm,\",\n            \"Let me take a look. Emmmm,\",\n            \"Let me take a look. Emm\",\n            \"Let me take a look. Emmm\",\n            \"Let me take a look. 
Emmmm\"\n        ],\n        [\n            \" \"\n        ]\n    ],\n    \"yes_suggestion\": [\n        [\n            \"Emmm, yep\",\n            \"Emmm, yes\",\n            \"Emmm, yeep\",\n            \"Yes\",\n            \"Yep\",\n            \"Yeep\",\n            \"Yep sure\",\n            \"Yes sure\"\n        ],\n        [\n            \",\",\n            \".\",\n            \"!\"\n        ],\n        [\n            \" \"\n        ]\n    ],\n    \"no_suggestion\": [\n        [\n            \"Well,\",\n            \"Well\",\n            \"Emm,\",\n            \"Emmm\",\n            \"Emmmm\",\n            \"Emm,\",\n            \"Emmm,\",\n            \"Emmmm,\",\n            \"I'm not too sure so\",\n            \"It looks okay now so\",\n            \"It looks nice, so,\",\n            \"Let me have a look. Well,\",\n            \"Let me have a look. Well\",\n            \"Let me have a look. Emm,\",\n            \"Let me have a look. Emmm,\",\n            \"Let me have a look. Emmmm,\",\n            \"Let me have a look. Emm\",\n            \"Let me have a look. Emmm\",\n            \"Let me have a look. Emmmm\",\n            \"Let me take a look. Well,\",\n            \"Let me take a look. Well\",\n            \"Let me take a look. Emm,\",\n            \"Let me take a look. Emmm,\",\n            \"Let me take a look. Emmmm,\",\n            \"Let me take a look. Emm\",\n            \"Let me take a look. Emmm\",\n            \"Let me take a look. 
Emmmm\"\n        ],\n        [\n            \" \"\n        ],\n        [\n            \"Not really.\",\n            \"Not really actually.\",\n            \"No actually.\"\n        ],\n        [\n            \" \"\n        ]\n    ],\n    \"end\": [\n        [\n            \"Nothing else.\",\n            \"Nothing else for now.\",\n            \"It's all good now.\",\n            \"I don't want any further edits.\",\n            \"Actually it's all good now.\",\n            \"No need for further edits.\",\n            \"I don't need any further edits.\",\n            \"That's all.\",\n            \"This is it.\",\n            \"That is it.\",\n            \"That is all.\",\n            \"No.\"\n        ],\n        [\n            \" \"\n        ],\n        [\n            \"Thanks!\",\n            \"Thank you!\",\n            \"Thanks a lot!\"\n        ]\n    ]\n}"
  },
  {
    "path": "language/templates/vocab.json",
    "content": "{\n    \"text_token_to_idx\": {\n        \"<NULL>\": 0,\n        \"<START>\": 1,\n        \"<END>\": 2,\n        \"<UNK>\": 3,\n        \"?\": 4,\n        \"a\": 5,\n        \"able\": 6,\n        \"about\": 7,\n        \"actually\": 8,\n        \"add\": 9,\n        \"adding\": 10,\n        \"adult\": 11,\n        \"age\": 12,\n        \"all\": 13,\n        \"almost\": 14,\n        \"an\": 15,\n        \"and\": 16,\n        \"any\": 17,\n        \"are\": 18,\n        \"at\": 19,\n        \"bangs\": 20,\n        \"be\": 21,\n        \"beamed\": 22,\n        \"beaming\": 23,\n        \"beard\": 24,\n        \"big\": 25,\n        \"bigger\": 26,\n        \"bit\": 27,\n        \"broadly\": 28,\n        \"bushy\": 29,\n        \"but\": 30,\n        \"can\": 31,\n        \"can't\": 32,\n        \"change\": 33,\n        \"cheeks\": 34,\n        \"child\": 35,\n        \"chin\": 36,\n        \"clearly\": 37,\n        \"color\": 38,\n        \"considerably\": 39,\n        \"corners\": 40,\n        \"could\": 41,\n        \"cover\": 42,\n        \"covered\": 43,\n        \"covering\": 44,\n        \"covers\": 45,\n        \"curve\": 46,\n        \"cut\": 47,\n        \"darker\": 48,\n        \"deep\": 49,\n        \"degree\": 50,\n        \"doesn't\": 51,\n        \"don't\": 52,\n        \"edits\": 53,\n        \"eighties\": 54,\n        \"elderly\": 55,\n        \"else\": 56,\n        \"emm\": 57,\n        \"end\": 58,\n        \"enough\": 59,\n        \"entire\": 60,\n        \"extent\": 61,\n        \"extremely\": 62,\n        \"eyebrows\": 63,\n        \"eyeglass\": 64,\n        \"eyeglasses\": 65,\n        \"face\": 66,\n        \"feel\": 67,\n        \"fifties\": 68,\n        \"for\": 69,\n        \"forehead\": 70,\n        \"forties\": 71,\n        \"frame\": 72,\n        \"friend\": 73,\n        \"fringe\": 74,\n        \"full\": 75,\n        \"further\": 76,\n        \"gentler\": 77,\n        \"glass\": 78,\n        \"glasses\": 79,\n        \"go\": 
80,\n        \"gone\": 81,\n        \"good\": 82,\n        \"growth\": 83,\n        \"guess\": 84,\n        \"half\": 85,\n        \"happier\": 86,\n        \"happily\": 87,\n        \"happiness\": 88,\n        \"happy\": 89,\n        \"has\": 90,\n        \"have\": 91,\n        \"he\": 92,\n        \"hello\": 93,\n        \"hi\": 94,\n        \"his\": 95,\n        \"how\": 96,\n        \"i\": 97,\n        \"in\": 98,\n        \"indeed\": 99,\n        \"is\": 100,\n        \"it\": 101,\n        \"it's\": 102,\n        \"just\": 103,\n        \"kind\": 104,\n        \"laugh\": 105,\n        \"laughing\": 106,\n        \"leaves\": 107,\n        \"length\": 108,\n        \"less\": 109,\n        \"let's\": 110,\n        \"lighter\": 111,\n        \"like\": 112,\n        \"little\": 113,\n        \"long\": 114,\n        \"longer\": 115,\n        \"look\": 116,\n        \"looks\": 117,\n        \"lot\": 118,\n        \"make\": 119,\n        \"making\": 120,\n        \"maybe\": 121,\n        \"medium\": 122,\n        \"medium-length\": 123,\n        \"middle\": 124,\n        \"moderately\": 125,\n        \"more\": 126,\n        \"most\": 127,\n        \"mouth\": 128,\n        \"much\": 129,\n        \"mustache\": 130,\n        \"my\": 131,\n        \"need\": 132,\n        \"no\": 133,\n        \"nope\": 134,\n        \"not\": 135,\n        \"nothing\": 136,\n        \"now\": 137,\n        \"obvious\": 138,\n        \"of\": 139,\n        \"off\": 140,\n        \"ok\": 141,\n        \"old\": 142,\n        \"older\": 143,\n        \"on\": 144,\n        \"one\": 145,\n        \"only\": 146,\n        \"open\": 147,\n        \"partially\": 148,\n        \"person\": 149,\n        \"please\": 150,\n        \"pointed\": 151,\n        \"poker\": 152,\n        \"pokerface\": 153,\n        \"portion\": 154,\n        \"relatively\": 155,\n        \"remove\": 156,\n        \"rimless\": 157,\n        \"rough\": 158,\n        \"rumbling\": 159,\n        \"schoolchild\": 160,\n        
\"see\": 161,\n        \"senior\": 162,\n        \"serious\": 163,\n        \"seventies\": 164,\n        \"shave\": 165,\n        \"short\": 166,\n        \"shorter\": 167,\n        \"should\": 168,\n        \"shouldn't\": 169,\n        \"show\": 170,\n        \"simply\": 171,\n        \"sixties\": 172,\n        \"slightly\": 173,\n        \"small\": 174,\n        \"smile\": 175,\n        \"smiling\": 176,\n        \"so\": 177,\n        \"some\": 178,\n        \"somewhat\": 179,\n        \"sort\": 180,\n        \"stubble\": 181,\n        \"sunglasses\": 182,\n        \"sure\": 183,\n        \"teen\": 184,\n        \"teenager\": 185,\n        \"teeth\": 186,\n        \"than\": 187,\n        \"thank\": 188,\n        \"thanks\": 189,\n        \"that\": 190,\n        \"that's\": 191,\n        \"the\": 192,\n        \"them\": 193,\n        \"there\": 194,\n        \"thick\": 195,\n        \"thicker\": 196,\n        \"thin\": 197,\n        \"think\": 198,\n        \"thinner\": 199,\n        \"thirties\": 200,\n        \"this\": 201,\n        \"tiny\": 202,\n        \"to\": 203,\n        \"too\": 204,\n        \"try\": 205,\n        \"trying\": 206,\n        \"turn\": 207,\n        \"twenties\": 208,\n        \"uncovered\": 209,\n        \"up\": 210,\n        \"use\": 211,\n        \"vanish\": 212,\n        \"very\": 213,\n        \"visible\": 214,\n        \"want\": 215,\n        \"we\": 216,\n        \"what\": 217,\n        \"whole\": 218,\n        \"wide\": 219,\n        \"with\": 220,\n        \"without\": 221,\n        \"would\": 222,\n        \"yeep\": 223,\n        \"yep\": 224,\n        \"yes\": 225,\n        \"you\": 226,\n        \"young\": 227,\n        \"younger\": 228\n    }\n}"
  },
  {
    "path": "language/train_encoder.py",
    "content": "import argparse\nimport json\nimport sys\nimport time\n\nimport torch\nimport torch.nn as nn\nimport torch.utils.data\n\nsys.path.append('.')\nfrom accuracy import head_accuracy  # noqa\nfrom dataset import EncoderDataset  # noqa\nfrom lstm import Encoder  # noqa\nfrom utils import AverageMeter, dict2str, save_checkpoint  # noqa\nfrom utils.setup_logger import setup_logger  # noqa\n\n\ndef parse_args():\n    \"\"\"Parses arguments.\"\"\"\n\n    parser = argparse.ArgumentParser(description='Train the language encoder')\n\n    # mode\n    parser.add_argument('--debug', type=int, default=0)\n\n    # training\n    parser.add_argument('--batch_size', type=int, default=2048)\n    parser.add_argument('--val_batch', type=int, default=1024)\n\n    # learning rate scheme\n    parser.add_argument('--num_epochs', default=20, type=int)\n    parser.add_argument('--lr', default=1e-3, type=float)\n    parser.add_argument('--weight_decay', default=0, type=float)\n\n    # LSTM hyperparameter\n    parser.add_argument('--word_embedding_dim', default=300, type=int)\n    parser.add_argument('--text_embed_size', default=1024, type=int)\n    parser.add_argument('--linear_hidden_size', default=256, type=int)\n    parser.add_argument('--linear_dropout_rate', default=0, type=float)\n\n    # input directories\n    parser.add_argument(\n        '--vocab_file', required=True, type=str, help='path to vocab file.')\n    parser.add_argument(\n        '--metadata_file',\n        default='./templates/metadata_fsm.json',\n        type=str,\n        help='path to metadata file.')\n    parser.add_argument(\n        '--train_set_dir', required=True, type=str, help='path to train data.')\n    parser.add_argument(\n        '--val_set_dir', required=True, type=str, help='path to val data.')\n    # output directories\n    parser.add_argument(\n        '--work_dir',\n        required=True,\n        type=str,\n        help='path to save checkpoint and log files.')\n\n    # misc\n    
parser.add_argument(\n        '--unlabeled_value',\n        default=999,\n        type=int,\n        help='value used to represent an unlabeled attribute')\n    parser.add_argument('--num_workers', default=8, type=int)\n\n    return parser.parse_args()\n\n\nbest_val_acc, best_epoch, current_iters = 0, 0, 0\n\n\ndef main():\n    \"\"\"Main function.\"\"\"\n\n    # ################### Set Up #######################\n    global args, best_val_acc, best_epoch\n\n    args = parse_args()\n    logger = setup_logger(\n        args.work_dir, logger_name='train.txt', debug=args.debug)\n\n    args.device = torch.device('cuda')\n\n    logger.info('Saving arguments.')\n    logger.info(dict2str(args.__dict__))\n\n    # ################### Metadata #######################\n    with open(args.metadata_file, 'r') as f:\n        args.metadata = json.load(f)\n        args.num_head = len(args.metadata)\n    logger.info(f'args.num_head: {args.num_head}.')\n    logger.info(f'args.metadata: {args.metadata}.')\n\n    # ################### Language Encoder #######################\n\n    # load vocab file\n    with open(args.vocab_file, 'r') as f:\n        vocab = json.load(f)\n    text_token_to_idx = vocab['text_token_to_idx']\n\n    encoder = Encoder(\n        token_to_idx=text_token_to_idx,\n        word_embedding_dim=args.word_embedding_dim,\n        text_embed_size=args.text_embed_size,\n        metadata_file=args.metadata_file,\n        linear_hidden_size=args.linear_hidden_size,\n        linear_dropout_rate=args.linear_dropout_rate)\n    encoder = encoder.to(args.device)\n\n    # ################### DataLoader #######################\n\n    logger.info('Preparing train_dataset')\n\n    train_dataset = EncoderDataset(preprocessed_dir=args.train_set_dir)\n    logger.info('Preparing train_loader')\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.batch_size,\n        shuffle=True,\n        num_workers=args.num_workers,\n        
pin_memory=False,\n        sampler=None)\n    logger.info('Preparing val_dataset')\n    val_dataset = EncoderDataset(preprocessed_dir=args.val_set_dir)\n    logger.info('Preparing val_loader')\n    val_loader = torch.utils.data.DataLoader(\n        val_dataset,\n        batch_size=args.val_batch,\n        shuffle=False,\n        num_workers=args.num_workers,\n        pin_memory=False)\n    logger.info(f'Number of train text: {len(train_dataset)}, '\n                f'Number of val text: {len(val_dataset)}.')\n    data_loader = {\n        'train': train_loader,\n        'val': val_loader,\n    }\n\n    # ################### Optimizer #######################\n\n    optimizer = torch.optim.Adam(\n        encoder.parameters(), args.lr, weight_decay=args.weight_decay)\n\n    # ################### Loss Function #######################\n    criterion = nn.CrossEntropyLoss(\n        reduction='mean', ignore_index=args.unlabeled_value)\n\n    # ################### Epochs #######################\n\n    for epoch in range(args.num_epochs):\n        logger.info(\n            '----------- Training: Epoch '\n            f'({epoch + 1} / {args.num_epochs}),  LR: {args.lr:.4f}. 
---------'\n        )\n        train_per_head_acc_avg, train_overall_acc = train(\n            args,\n            'train',\n            encoder,\n            data_loader['train'],\n            criterion,\n            optimizer,\n            logger,\n        )\n        logger.info(\n            'Train accuracy '\n            f'({epoch + 1} / {args.num_epochs}), '\n            f'{[str(round(i, 2))+\"%\" for i in train_per_head_acc_avg]}')\n        val_per_head_acc_avg, val_overall_acc = train(\n            args,\n            'val',\n            encoder,\n            data_loader['val'],\n            criterion,\n            optimizer,\n            logger,\n        )\n        logger.info('Validation accuracy '\n                    f'({epoch + 1} / {args.num_epochs}), '\n                    f'{[str(round(i, 2))+\"%\" for i in val_per_head_acc_avg]}')\n\n        # whether this epoch has the highest val acc so far\n        is_best = val_overall_acc > best_val_acc\n        if is_best:\n            best_epoch = epoch + 1\n            best_val_acc = val_overall_acc\n        logger.info(\n            f'Best Epoch: {best_epoch}, best acc: {best_val_acc: .4f}.')\n        save_checkpoint(\n            args, {\n                'epoch': epoch + 1,\n                'best_epoch_so_far': best_epoch,\n                'state_dict': encoder.state_dict(),\n                'best_val_acc': best_val_acc,\n                'optimizer': optimizer.state_dict(),\n            },\n            is_best,\n            checkpoint=args.work_dir)\n    logger.info('successful')\n\n\ndef train(args, phase, encoder, data_loader, criterion, optimizer, logger):\n\n    if phase == 'train':\n        encoder.train()\n    else:\n        encoder.eval()\n\n    # record time\n    batch_time = AverageMeter()\n    data_time = AverageMeter()\n    end = time.time()\n\n    # record accuracy\n    per_head_acc_list = [AverageMeter() for _ in range(args.num_head)]\n\n    for batch_idx, batch_data in enumerate(data_loader):\n 
       data_time.update(time.time() - end)\n\n        text, system_mode, labels = batch_data\n        text = text.to(args.device)\n        system_mode = system_mode.to(args.device)\n        labels = labels.to(args.device)\n\n        if phase == 'train':\n            output = encoder(text)\n        else:\n            with torch.no_grad():\n                output = encoder(text)\n        loss_list = []\n\n        # Labels: loss and acc\n        for head_idx, (key, val) in enumerate(args.metadata.items()):\n            loss = criterion(output[head_idx], labels[:, head_idx])\n            loss_list.append(loss)\n            acc_dict = head_accuracy(\n                output=output[head_idx],\n                target=labels[:, head_idx],\n                unlabeled_value=args.unlabeled_value)\n            acc = acc_dict['acc']\n            labeled_count = int(acc_dict['labeled_count'])\n            if labeled_count > 0:\n                per_head_acc_list[head_idx].update(acc, labeled_count)\n\n        loss_avg = sum(loss_list) / len(loss_list)\n\n        if phase == 'train':\n            optimizer.zero_grad()\n            loss_avg.backward()\n            optimizer.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        logger.info(\n            f'Batch: {batch_idx+1}, '\n            f'Data time: {data_time.avg:.3f}s, Batch time: {batch_time.avg:.3f}s, '  # noqa\n            f'loss: {loss_avg:.4f}.')\n\n    overall_acc = 0\n    per_head_acc_avg = []\n    for head_idx in range(args.num_head):\n        per_head_acc_avg.append(per_head_acc_list[head_idx].avg)\n        overall_acc += per_head_acc_list[head_idx].avg\n    overall_acc = overall_acc / args.num_head\n    return per_head_acc_avg, overall_acc\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "language/utils/__init__.py",
    "content": "\"\"\"Useful utils\n\"\"\"\n# progress bar\nimport os\nimport sys\n\nfrom .eval import *  # noqa\nfrom .logger import *  # noqa\nfrom .lr_schedule import *  # noqa\nfrom .misc import *  # noqa\nfrom .numerical import *  # noqa\nfrom .visualize import *  # noqa\n\nsys.path.append(os.path.join(os.path.dirname(__file__), \"progress\"))\nfrom progress.bar import Bar as Bar  # noqa\n"
  },
  {
    "path": "language/utils/eval.py",
    "content": "from __future__ import absolute_import, print_function\n\nimport torch\n\n__all__ = ['classification_accuracy', 'regression_accuracy']\n\n\ndef classification_accuracy(output,\n                            target,\n                            class_wise=False,\n                            num_cls=6,\n                            excluded_cls_idx=None):\n    \"\"\"\n    Computes the precision@k for the specified values of k\n    output: batch_size * num_cls (for a specific attribute)\n    target: batch_size * 1 (for a specific attribute)\n    return res: res = 100 * num_correct / batch_size, for a specific attribute\n    for a batch\n    \"\"\"\n\n    with torch.no_grad():\n        batch_size = target.size(0)\n\n        # _ = the largest score, pred = cls_idx with the largest score\n        _, pred = output.topk(1, 1, True, True)\n        pred = pred.reshape(-1)\n\n        acc = float(torch.sum(pred == target)) / float(batch_size) * 100\n        return_dict = {'acc': acc}\n\n        if excluded_cls_idx is not None:\n            correct_count = torch.sum(\n                (pred == target) * (target != excluded_cls_idx))\n            labeled_count = torch.sum(target != excluded_cls_idx)\n            if labeled_count:\n                labeled_acc = float(correct_count) / float(labeled_count) * 100\n            else:\n                labeled_acc = 0\n\n            return_dict['labeled_acc'] = labeled_acc\n            return_dict['labeled_count'] = labeled_count\n        else:\n            return_dict['labeled_acc'] = acc\n            return_dict['labeled_count'] = batch_size\n\n        if class_wise:\n            acc_class_wise = []\n            per_class_count = []\n            # actual number of classes <= num_cls=6\n            for i in range(num_cls):\n                total_sample_cls_i = torch.sum(target == i)\n                if total_sample_cls_i:\n                    correct_samples_cls_i = torch.sum(\n                        (pred == i) * (target 
== i))\n                    acc_class_wise.append(\n                        float(correct_samples_cls_i) /\n                        float(total_sample_cls_i) * 100)\n                else:\n                    acc_class_wise.append(0)\n                per_class_count.append(total_sample_cls_i)\n\n        return_dict['acc_class_wise'] = acc_class_wise\n        return_dict['per_class_count'] = per_class_count\n\n        return return_dict\n\n\ndef regression_accuracy(output,\n                        target,\n                        margin=0.2,\n                        uni_neg=True,\n                        class_wise=False,\n                        num_cls=6,\n                        excluded_cls_idx=None,\n                        max_cls_value=5):\n    \"\"\"\n    Computes the regression accuracy\n\n    if predicted score is less than one margin from the ground-truth score, we\n    consider it as correct otherwise it is incorrect， the acc is the\n    percentage of correct regression\n\n    class_wise: if True, then report overall accuracy and class-wise accuracy\n                else, then only report overall accuracy\n    \"\"\"\n\n    output = output.clone().reshape(-1)\n\n    if uni_neg:\n        output[(output <= 0 + margin) * (target == 0)] = 0\n        output[(output >= max_cls_value - margin) *\n               (target == max_cls_value)] = max_cls_value\n\n    distance = torch.absolute(target - output)\n    distance = distance - margin\n\n    predicted_class = torch.zeros_like(target)\n    # if distance <= 0, assign ground truth class\n    predicted_class[distance <= 0] = target[distance <= 0]\n    # if distance > 0, assign an invalid value\n    predicted_class[distance > 0] = -1\n\n    acc = float(torch.sum(predicted_class == target)) / float(\n        target.size(0)) * 100\n\n    return_dict = {'acc': acc}\n\n    if excluded_cls_idx is not None:\n        correct_count = torch.sum(\n            (predicted_class == target) * (target != excluded_cls_idx))\n      
  labeled_count = torch.sum(target != excluded_cls_idx)\n        if labeled_count:\n            labeled_acc = float(correct_count) / float(labeled_count) * 100\n        else:\n            labeled_acc = 0\n        return_dict['labeled_acc'] = labeled_acc\n        return_dict['labeled_count'] = labeled_count\n    else:\n        labeled_acc = acc\n        return_dict['labeled_acc'] = acc\n        return_dict['labeled_count'] = target.size(0)\n\n    if class_wise:\n        acc_class_wise = []\n        per_class_count = []\n        for i in range(num_cls):\n            total_sample_cls_i = torch.sum(target == i)\n            if total_sample_cls_i:\n                correct_samples_cls_i = torch.sum(\n                    (predicted_class == i) * (target == i))\n                acc_class_wise.append(\n                    float(correct_samples_cls_i) / float(total_sample_cls_i) *\n                    100)\n            else:\n                acc_class_wise.append(0)\n            per_class_count.append(total_sample_cls_i)\n\n        return_dict['acc_class_wise'] = acc_class_wise\n        return_dict['per_class_count'] = per_class_count\n\n    return return_dict\n\n\ndef main():\n\n    l1 = [\n        0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 2, 2, 2, 1.7, 0, 3, 3, 2.79, 3.3, 0, 4,\n        2, 5, 3, 0, 6, 6, 4.78, 6, 0\n    ]\n    l2 = [\n        0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4,\n        4, 5, 5, 5, 5, 5\n    ]\n\n    output = torch.FloatTensor(l1)\n    target = torch.LongTensor(l2)\n    acc = regression_accuracy(output, target, margin=0.2)\n    print('acc:', acc)\n    print()\n    acc, acc_class_wise_list, per_class_count = regression_accuracy(\n        output, target, margin=0.2, class_wise=True)\n    print('acc:', acc)\n    print('acc_class_wise_list:', acc_class_wise_list)\n    print('per_class_count: ', per_class_count)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "language/utils/logger.py",
    "content": "from __future__ import absolute_import\n\nimport datetime\nimport logging\nimport time\n\nimport matplotlib.pyplot as plt\nimport matplotlib.ticker as plticker\nimport numpy as np\n\n# from mmcv.runner import get_dist_info, master_only\n\n__all__ = [\n    'Logger', 'LoggerMonitor', 'savefig', 'MessageLogger', 'init_tb_logger',\n    'get_root_logger', 'dict2str'\n]\n\n\ndef savefig(fname, dpi=None):\n    dpi = 150 if dpi is None else dpi\n    plt.savefig(fname, dpi=dpi)\n\n\ndef plot_overlap(logger, names=None):\n    names = logger.names if names is None else names\n    numbers = logger.numbers\n    for _, name in enumerate(names):\n        x = np.arange(len(numbers[name]))\n        plt.plot(x, np.asarray(numbers[name]))\n    return [logger.title + '(' + name + ')' for name in names]\n\n\nclass Logger(object):\n    '''Save training process to log file with simple plot function.'''\n\n    def __init__(self, fpath, title=None, resume=False):\n        self.file = None\n        self.resume = resume\n        self.title = '' if title is None else title\n        if fpath is not None:\n            if resume:\n                self.file = open(fpath, 'r')\n                name = self.file.readline()\n                self.names = name.rstrip().split('\\t')\n                self.numbers = {}\n                for _, name in enumerate(self.names):\n                    self.numbers[name] = []\n\n                for numbers in self.file:\n                    numbers = numbers.rstrip().split('\\t')\n                    for i in range(0, len(numbers)):\n                        self.numbers[self.names[i]].append(numbers[i])\n                self.file.close()\n                self.file = open(fpath, 'a')\n            else:\n                self.file = open(fpath, 'w')\n\n    def set_names(self, names):\n        if self.resume:\n            pass\n        # initialize numbers as empty list\n        self.numbers = {}\n        self.names = names\n        for _, name in 
enumerate(self.names):\n            self.file.write(name)\n            self.file.write('\\t')\n            self.numbers[name] = []\n        self.file.write('\\n')\n        self.file.flush()\n\n    def append(self, numbers):\n        assert len(self.names) == len(numbers), 'Numbers do not match names'\n        for index, num in enumerate(numbers):\n            if isinstance(num, float):\n                self.file.write('{0:.6f}'.format(num))\n            else:  # int or str\n                self.file.write(str(num))\n            self.file.write('\\t')\n            self.numbers[self.names[index]].append(num)\n        self.file.write('\\n')\n        self.file.flush()\n\n    def plot(self, out_file, names=None, label_points=False):\n        names = self.names if names is None else names\n        numbers = self.numbers\n        fig, ax = plt.subplots(1, 1)\n        for _, name in enumerate(names):\n            x = np.arange(len(numbers[name]))\n            ax.plot(x, numbers[name])\n\n            # optionally add a data label to each point in the plot\n            if label_points:\n                for i in range(len(x)):\n                    y = numbers[name][i]\n                    if isinstance(y, (int, float)):\n                        text = round(y, 2)\n                    else:\n                        text = y\n                    ax.text(x[i], y, text)\n\n        ax.legend([self.title + '(' + name + ')' for name in names])\n        # this locator puts ticks at regular intervals\n        loc = plticker.MultipleLocator(base=1.0)\n        ax.xaxis.set_major_locator(loc)\n        ax.grid(True)\n        plt.savefig(out_file)\n        plt.close()\n\n    def close(self):\n        if self.file is not None:\n            self.file.close()\n\n    def get_numbers(self):\n        stats = {}\n        for name in self.names:
\n            stats[name] = self.numbers[name]\n        return stats\n\n\nclass LoggerMonitor(object):\n    '''Load and visualize multiple logs.'''\n\n    def __init__(self, paths):\n        '''paths is a dictionary with {name: filepath} pairs'''\n        self.loggers = []\n        for title, path in paths.items():\n            logger = Logger(path, title=title, resume=True)\n            self.loggers.append(logger)\n\n    def plot(self, names=None):\n        plt.figure()\n        plt.subplot(121)\n        legend_text = []\n        for logger in self.loggers:\n            legend_text += plot_overlap(logger, names)\n        plt.legend(\n            legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)\n        plt.grid(True)\n\n\nclass MessageLogger():\n    \"\"\"Message logger for printing.\n\n    Args:\n        opt (dict): Config. It contains the following keys:\n            name (str): Exp name.\n            logger (dict): Contains 'print_freq' (int) for logger interval.\n            train (dict): Contains 'niter' (int) for total iters.\n            use_tb_logger (bool): Use tensorboard logger.\n        start_iter (int): Start iter. Default: 1.\n        tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
\n    \"\"\"\n\n    def __init__(self, opt, start_iter=1, tb_logger=None):\n        self.exp_name = opt['name']\n        self.interval = opt['logger']['print_freq']\n        self.start_iter = start_iter\n        self.max_iters = opt['train']['niter']\n        self.use_tb_logger = opt['use_tb_logger']\n        self.tb_logger = tb_logger\n        self.start_time = time.time()\n        self.logger = get_root_logger()\n\n    # @master_only\n    def __call__(self, log_vars):\n        \"\"\"Format logging message.\n\n        Args:\n            log_vars (dict): It contains the following keys:\n                epoch (int): Epoch number.\n                iter (int): Current iter.\n                lrs (list): List for learning rates.\n\n                time (float): Iter time.\n                data_time (float): Data time for each iter.\n        \"\"\"\n        # epoch, iter, learning rates\n        epoch = log_vars.pop('epoch')\n        current_iter = log_vars.pop('iter')\n        lrs = log_vars.pop('lrs')\n\n        message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, '\n                   f'iter:{current_iter:8,d}, lr:(')\n        for v in lrs:\n            message += f'{v:.3e},'\n        message += ')] '\n\n        # time and estimated time\n        if 'time' in log_vars.keys():\n            iter_time = log_vars.pop('time')\n            data_time = log_vars.pop('data_time')\n\n            total_time = time.time() - self.start_time\n            time_sec_avg = total_time / (current_iter - self.start_iter + 1)\n            eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)\n            eta_str = str(datetime.timedelta(seconds=int(eta_sec)))\n            message += f'[eta: {eta_str}, '\n            message += f'time: {iter_time:.3f}, data_time: {data_time:.3f}] '\n\n        # other items, especially losses\n        for k, v in log_vars.items():\n            message += f'{k}: {v:.4e} '\n            # tensorboard logger
\n            if self.use_tb_logger and 'debug' not in self.exp_name:\n                self.tb_logger.add_scalar(k, v, current_iter)\n\n        self.logger.info(message)\n\n\n# @master_only\ndef init_tb_logger(log_dir):\n    from torch.utils.tensorboard import SummaryWriter\n    tb_logger = SummaryWriter(log_dir=log_dir)\n    return tb_logger\n\n\ndef get_root_logger(logger_name='base', log_level=logging.INFO, log_file=None):\n    \"\"\"Get the root logger.\n\n    The logger will be initialized if it has not been initialized. By default a\n    StreamHandler will be added. If `log_file` is specified, a FileHandler will\n    also be added.\n\n    Args:\n        logger_name (str): root logger name. Default: base.\n        log_file (str | None): The log filename. If specified, a FileHandler\n            will be added to the root logger.\n        log_level (int): The root logger level. Note that only the process of\n            rank 0 is affected, while other processes will set the level to\n            \"Error\" and be silent most of the time.\n\n    Returns:\n        logging.Logger: The root logger.\n    \"\"\"\n    logger = logging.getLogger(logger_name)\n    # if the logger has been initialized, just return it\n    if logger.hasHandlers():\n        return logger\n\n    format_str = '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s'\n    logging.basicConfig(format=format_str, level=log_level)\n    if log_file is not None:\n        file_handler = logging.FileHandler(log_file, 'w')\n        file_handler.setFormatter(logging.Formatter(format_str))\n        file_handler.setLevel(log_level)\n        logger.addHandler(file_handler)\n\n    return logger\n\n\ndef dict2str(opt, indent_level=1):\n    \"\"\"dict to string for printing options.\n\n    Args:\n        opt (dict): Option dict.\n        indent_level (int): Indent level. Default: 1.
\n\n    Return:\n        (str): Option string for printing.\n    \"\"\"\n    msg = ''\n    for k, v in opt.items():\n        if isinstance(v, dict):\n            msg += ' ' * (indent_level * 2) + k + ':[\\n'\n            msg += dict2str(v, indent_level + 1)\n            msg += ' ' * (indent_level * 2) + ']\\n'\n        else:\n            msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\\n'\n    return msg\n"
  },
  {
    "path": "language/utils/lr_schedule.py",
"content": "import math\n\n__all__ = ['adjust_learning_rate']\n\n\ndef adjust_learning_rate(args, optimizer, epoch):\n    \"\"\"\n    Sets the learning rate for the current epoch following the decay\n    schedule selected by args.lr_decay.\n    \"\"\"\n    lr = optimizer.param_groups[0]['lr']\n    if args.lr_decay == 'step':\n        lr = args.lr * (args.gamma**(epoch // args.step))\n    elif args.lr_decay == 'cos':\n        lr = args.lr * (1 + math.cos(math.pi * epoch / args.epochs)) / 2\n    elif args.lr_decay == 'linear':\n        lr = args.lr * (1 - epoch / args.epochs)\n    elif args.lr_decay == 'linear2exp':\n        if epoch < args.turning_point + 1:\n            # decay linearly so that lr reaches 5% of the initial value\n            # at the turning point (1 / 1.0526 = 0.95)\n            lr = args.lr * (1 - epoch / int(args.turning_point * 1.0526))\n        else:\n            lr *= args.gamma\n    elif args.lr_decay == 'schedule':\n        if epoch in args.schedule:\n            lr *= args.gamma\n    else:\n        raise ValueError('Unknown lr mode {}'.format(args.lr_decay))\n\n    for param_group in optimizer.param_groups:\n        param_group['lr'] = lr\n    return lr\n"
  },
  {
    "path": "language/utils/misc.py",
"content": "'''Some helper functions for PyTorch, including:\n    - get_mean_and_std: calculate the mean and std value of a dataset.\n    - init_params: net parameter initialization.\n    - mkdir_p: make a directory if it does not exist.\n'''\nimport errno\nimport os\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.init as init\n\n__all__ = [\n    'get_mean_and_std', 'init_params', 'mkdir_p', 'save_checkpoint',\n    'AverageMeter'\n]\n\n\ndef get_mean_and_std(dataset):\n    '''Compute the mean and std value of dataset.'''\n    dataloader = torch.utils.data.DataLoader(\n        dataset, batch_size=1, shuffle=True, num_workers=2)\n\n    mean = torch.zeros(3)\n    std = torch.zeros(3)\n    print('==> Computing mean and std..')\n    for inputs, targets in dataloader:\n        for i in range(3):\n            mean[i] += inputs[:, i, :, :].mean()\n            std[i] += inputs[:, i, :, :].std()\n    mean.div_(len(dataset))\n    std.div_(len(dataset))\n    return mean, std\n\n\ndef init_params(net):\n    '''Init layer parameters.'''\n    for m in net.modules():\n        if isinstance(m, nn.Conv2d):\n            init.kaiming_normal_(m.weight, mode='fan_out')\n            if m.bias is not None:\n                init.constant_(m.bias, 0)\n        elif isinstance(m, nn.BatchNorm2d):\n            init.constant_(m.weight, 1)\n            init.constant_(m.bias, 0)\n        elif isinstance(m, nn.Linear):\n            init.normal_(m.weight, std=1e-3)\n            if m.bias is not None:\n                init.constant_(m.bias, 0)\n\n\ndef mkdir_p(path):\n    '''make dir if not exist'''\n    try:\n        os.makedirs(path)\n    except OSError as exc:  # Python >2.5\n        if exc.errno == errno.EEXIST and os.path.isdir(path):\n            pass\n        else:\n            raise\n\n\ndef save_checkpoint(args,\n                    state,\n                    is_best,\n                    checkpoint='checkpoint',\n                    filename='checkpoint.pth.tar'):
\n    epoch = str(state['epoch']).zfill(2)\n    save_every_epoch = True\n    if not os.path.exists(os.path.join(args.work_dir, 'checkpoints')):\n        os.makedirs(os.path.join(args.work_dir, 'checkpoints'))\n    if save_every_epoch:\n        filename = 'checkpoint_' + epoch + '.pth.tar'\n        filepath = os.path.join(checkpoint, 'checkpoints', filename)\n        torch.save(state, filepath)\n    if is_best:\n        filename = 'model_best.pth.tar'\n        filepath = os.path.join(checkpoint, 'checkpoints', filename)\n        torch.save(state, filepath)\n        # shutil.copyfile(filepath, os.path.join(checkpoint, \\\n        # 'model_best_'+epoch+'.pth.tar'))\n\n\nclass AverageMeter(object):\n    \"\"\"\n    Computes and stores the average and current value\n    Imported from\n    https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262\n    \"\"\"\n\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0  # running average = running sum / running count\n        self.sum = 0  # running sum\n        self.count = 0  # running count\n\n    def update(self, val, n=1):\n        # n = batch_size\n\n        # val = batch accuracy for an attribute\n        # self.val = val\n\n        # sum = 100 * accumulated correct predictions for this attribute\n        self.sum += val * n\n\n        # count = total samples so far\n        self.count += n\n\n        # avg = 100 * avg accuracy for this attribute\n        # for all the batches so far\n        self.avg = self.sum / self.count\n"
  },
  {
    "path": "language/utils/numerical.py",
"content": "import json\n\nimport numpy as np\n\n__all__ = ['get_weight', 'transpose_and_format']\n\n\ndef get_weight(args):\n    \"\"\"\n    Read the attribute class distribution file stats.txt and return the\n    normalized per-class weights.\n    \"\"\"\n\n    # read counts from stats file\n    stats_f = open(args.stats_file, 'r')\n\n    # each list [] in the count_list is for one attribute\n    # each value in [] is the number of training samples\n    # for that attribute value\n    count_list = []\n    for i in range(args.num_attr):\n        count_list.append([])\n    for row_idx, row in enumerate(stats_f):\n        # row 0 is attr names, row 1 is unlabeled statistics\n        if row_idx == 0 or row_idx == 1:\n            continue\n        # [:-1] because the last value is the new line character\n        row = row.split(' ')[:-1]\n        for new_idx_in_row, attr_val in enumerate(row):\n            if new_idx_in_row == 0:\n                continue\n            new_idx = new_idx_in_row - 1\n            count_list[new_idx].append(int(attr_val))\n    stats_f.close()\n\n    # weight for gt_remapping case\n    count_list = np.array(count_list)\n    num_attr = count_list.shape[0]\n    num_cls = count_list.shape[1]\n\n    if args.gt_remapping:\n        remap_count_list = np.zeros((num_attr, num_cls))\n        for attr_idx in range(num_attr):\n            for cls_idx in range(num_cls):\n                new_cls_idx = int(args.gt_remapping[attr_idx][cls_idx])\n                remap_count_list[attr_idx][new_cls_idx] += count_list[\n                    attr_idx][cls_idx]\n        count_list = remap_count_list\n\n    # For each attribute, invert the per-class counts and normalize them\n    value_weights = []\n    for attr_idx in range(num_attr):\n        weight_l = np.zeros(num_cls)\n        for cls_idx in range(num_cls):\n            weight_l[cls_idx] = (1 / count_list[attr_idx][cls_idx]
\n                                 ) if count_list[attr_idx][cls_idx] else 0\n\n        # normalize weight_l so that the weights sum to 1\n        normalized_weight_l = np.zeros(num_cls)\n        for cls_idx in range(num_cls):\n            normalized_weight_l[cls_idx] = weight_l[cls_idx] / sum(weight_l)\n        value_weights.append(normalized_weight_l)\n\n    # Among attributes, weight inversion and normalization\n    # count_sum_list = []\n    # for a_list in count_list:\n    #     count_sum_list.append(sum(a_list))\n    # count_sum = sum(count_sum_list)\n    # attribute_weights = []\n    # for i in range(len(count_sum_list)):\n    #     attribute_weight = count_sum / count_sum_list[i]\n    #     attribute_weights.append(attribute_weight)\n    # # normalize attribute_weights so that their average value is 1\n    # normalized_attribute_weights = []\n    # for i in range(len(attribute_weights)):\n    #     normalized_attribute_weights.append(attribute_weights[i] /\n    #                                         sum(attribute_weights) *\n    #                                         len(attribute_weights))\n\n    weights = {'value_weights': value_weights}\n\n    return weights\n\n\ndef transpose_and_format(args, table):\n    \"\"\"\n    table = [\n        [#, #, #, #, #, #],\n        [#, #, #, #, #, #],\n        [#, #, #, #, #, #]\n    ]\n    where the outer loop is over attributes and\n    the inner loop is over class labels\n\n    new_f:\n\n    attr_val Bangs Smiling Young\n    0 # # #\n    1 # # #\n    2 # # #\n    3 # # #\n    4 # # #\n    5 # # #\n    \"\"\"\n\n    with open(args.attr_file, 'r') as f:\n        attr_f = json.load(f)\n    attr_info = attr_f['attr_info']\n    attr_list = ['attr_val']\n    for key, val in attr_info.items():\n        attr_list.append(val[\"name\"])\n\n    # new_f stores the output\n    new_f = []\n\n    # first line is the header\n    new_f.append(attr_list)\n    for i in range(len(table[0])):\n        row = []\n        row.append(i)\n        for j in range(args.num_attr):
\n            row.append(round(table[j][i].item(), 2))\n            # use this line instead if the entries are plain numbers:\n            # row.append(round(table[j][i], 2))\n        new_f.append(row)\n    return new_f\n"
  },
  {
    "path": "language/utils/progress/.gitignore",
    "content": ""
  },
  {
    "path": "language/utils/progress/LICENSE",
    "content": "# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>\n#\n# Permission to use, copy, modify, and distribute this software for any\n# purpose with or without fee is hereby granted, provided that the above\n# copyright notice and this permission notice appear in all copies.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\" AND THE AUTHOR DISCLAIMS ALL WARRANTIES\n# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF\n# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR\n# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES\n# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN\n# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF\n# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.\n"
  },
  {
    "path": "language/utils/progress/MANIFEST.in",
    "content": "include README.rst LICENSE\n"
  },
  {
    "path": "language/utils/progress/README.rst",
    "content": "Easy progress reporting for Python\n==================================\n\n|pypi|\n\n|demo|\n\n.. |pypi| image:: https://img.shields.io/pypi/v/progress.svg\n.. |demo| image:: https://raw.github.com/verigak/progress/master/demo.gif\n   :alt: Demo\n\nBars\n----\n\nThere are 7 progress bars to choose from:\n\n- ``Bar``\n- ``ChargingBar``\n- ``FillingSquaresBar``\n- ``FillingCirclesBar``\n- ``IncrementalBar``\n- ``PixelBar``\n- ``ShadyBar``\n\nTo use them, just call ``next`` to advance and ``finish`` to finish:\n\n.. code-block:: python\n\n    from progress.bar import Bar\n\n    bar = Bar('Processing', max=20)\n    for i in range(20):\n        # Do some work\n        bar.next()\n    bar.finish()\n\nThe result will be a bar like the following: ::\n\n    Processing |#############                   | 42/100\n\nTo simplify the common case where the work is done in an iterator, you can\nuse the ``iter`` method:\n\n.. code-block:: python\n\n    for i in Bar('Processing').iter(it):\n        # Do some work\n\nProgress bars are very customizable, you can change their width, their fill\ncharacter, their suffix and more:\n\n.. 
code-block:: python\n\n    bar = Bar('Loading', fill='@', suffix='%(percent)d%%')\n\nThis will produce a bar like the following: ::\n\n    Loading |@@@@@@@@@@@@@                   | 42%\n\nYou can use a number of template arguments in ``message`` and ``suffix``:\n\n==========  ================================\nName        Value\n==========  ================================\nindex       current value\nmax         maximum value\nremaining   max - index\nprogress    index / max\npercent     progress * 100\navg         simple moving average time per item (in seconds)\nelapsed     elapsed time in seconds\nelapsed_td  elapsed as a timedelta (useful for printing as a string)\neta         avg * remaining\neta_td      eta as a timedelta (useful for printing as a string)\n==========  ================================\n\nInstead of passing all configuration options on instantiation, you can create\nyour custom subclass:\n\n.. code-block:: python\n\n    class FancyBar(Bar):\n        message = 'Loading'\n        fill = '*'\n        suffix = '%(percent).1f%% - %(eta)ds'\n\nYou can also override any of the arguments or create your own:\n\n.. code-block:: python\n\n    class SlowBar(Bar):\n        suffix = '%(remaining_hours)d hours remaining'\n        @property\n        def remaining_hours(self):\n            return self.eta // 3600\n\n\nSpinners\n========\n\nFor actions with an unknown number of steps you can use a spinner:\n\n.. code-block:: python\n\n    from progress.spinner import Spinner\n\n    spinner = Spinner('Loading ')\n    while state != 'FINISHED':\n        # Do some work\n        spinner.next()\n\nThere are 5 predefined spinners:\n\n- ``Spinner``\n- ``PieSpinner``\n- ``MoonSpinner``\n- ``LineSpinner``\n- ``PixelSpinner``\n\n\nOther\n=====\n\nThere are a number of other classes available too; please check the source or\nsubclass one of them to create your own.\n\n\nLicense\n=======\n\nprogress is licensed under ISC\n"
  },
  {
    "path": "language/utils/progress/progress/__init__.py",
    "content": "# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>\n#\n# Permission to use, copy, modify, and distribute this software for any\n# purpose with or without fee is hereby granted, provided that the above\n# copyright notice and this permission notice appear in all copies.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\" AND THE AUTHOR DISCLAIMS ALL WARRANTIES\n# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF\n# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR\n# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES\n# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN\n# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF\n# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.\n\nfrom __future__ import division\n\nfrom collections import deque\nfrom datetime import timedelta\nfrom math import ceil\nfrom sys import stderr\nfrom time import time\n\n\n__version__ = '1.3'\n\n\nclass Infinite(object):\n    file = stderr\n    sma_window = 10         # Simple Moving Average window\n\n    def __init__(self, *args, **kwargs):\n        self.index = 0\n        self.start_ts = time()\n        self.avg = 0\n        self._ts = self.start_ts\n        self._xput = deque(maxlen=self.sma_window)\n        for key, val in kwargs.items():\n            setattr(self, key, val)\n\n    def __getitem__(self, key):\n        if key.startswith('_'):\n            return None\n        return getattr(self, key, None)\n\n    @property\n    def elapsed(self):\n        return int(time() - self.start_ts)\n\n    @property\n    def elapsed_td(self):\n        return timedelta(seconds=self.elapsed)\n\n    def update_avg(self, n, dt):\n        if n > 0:\n            self._xput.append(dt / n)\n            self.avg = sum(self._xput) / len(self._xput)\n\n    def update(self):\n        pass\n\n    def start(self):\n        pass\n\n    def finish(self):\n        pass\n\n    def next(self, 
n=1):\n        now = time()\n        dt = now - self._ts\n        self.update_avg(n, dt)\n        self._ts = now\n        self.index = self.index + n\n        self.update()\n\n    def iter(self, it):\n        try:\n            for x in it:\n                yield x\n                self.next()\n        finally:\n            self.finish()\n\n\nclass Progress(Infinite):\n    def __init__(self, *args, **kwargs):\n        super(Progress, self).__init__(*args, **kwargs)\n        self.max = kwargs.get('max', 100)\n\n    @property\n    def eta(self):\n        return int(ceil(self.avg * self.remaining))\n\n    @property\n    def eta_td(self):\n        return timedelta(seconds=self.eta)\n\n    @property\n    def percent(self):\n        return self.progress * 100\n\n    @property\n    def progress(self):\n        return min(1, self.index / self.max)\n\n    @property\n    def remaining(self):\n        return max(self.max - self.index, 0)\n\n    def start(self):\n        self.update()\n\n    def goto(self, index):\n        incr = index - self.index\n        self.next(incr)\n\n    def iter(self, it):\n        try:\n            self.max = len(it)\n        except TypeError:\n            pass\n\n        try:\n            for x in it:\n                yield x\n                self.next()\n        finally:\n            self.finish()\n"
  },
  {
    "path": "language/utils/progress/progress/bar.py",
    "content": "# -*- coding: utf-8 -*-\n\n# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>\n#\n# Permission to use, copy, modify, and distribute this software for any\n# purpose with or without fee is hereby granted, provided that the above\n# copyright notice and this permission notice appear in all copies.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\" AND THE AUTHOR DISCLAIMS ALL WARRANTIES\n# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF\n# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR\n# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES\n# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN\n# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF\n# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.\n\nfrom __future__ import unicode_literals\nfrom . import Progress\nfrom .helpers import WritelnMixin\n\n\nclass Bar(WritelnMixin, Progress):\n    width = 32\n    message = ''\n    suffix = '%(index)d/%(max)d'\n    bar_prefix = ' |'\n    bar_suffix = '| '\n    empty_fill = ' '\n    fill = '#'\n    hide_cursor = True\n\n    def update(self):\n        filled_length = int(self.width * self.progress)\n        empty_length = self.width - filled_length\n\n        message = self.message % self\n        bar = self.fill * filled_length\n        empty = self.empty_fill * empty_length\n        suffix = self.suffix % self\n        line = ''.join([message, self.bar_prefix, bar, empty, self.bar_suffix,\n                        suffix])\n        self.writeln(line)\n\n\nclass ChargingBar(Bar):\n    suffix = '%(percent)d%%'\n    bar_prefix = ' '\n    bar_suffix = ' '\n    empty_fill = '∙'\n    fill = '█'\n\n\nclass FillingSquaresBar(ChargingBar):\n    empty_fill = '▢'\n    fill = '▣'\n\n\nclass FillingCirclesBar(ChargingBar):\n    empty_fill = '◯'\n    fill = '◉'\n\n\nclass IncrementalBar(Bar):\n    phases = (' ', '▏', '▎', '▍', '▌', '▋', '▊', '▉', '█')\n\n    def 
update(self):\n        nphases = len(self.phases)\n        filled_len = self.width * self.progress\n        nfull = int(filled_len)                      # Number of full chars\n        phase = int((filled_len - nfull) * nphases)  # Phase of last char\n        nempty = self.width - nfull                  # Number of empty chars\n\n        message = self.message % self\n        bar = self.phases[-1] * nfull\n        current = self.phases[phase] if phase > 0 else ''\n        empty = self.empty_fill * max(0, nempty - len(current))\n        suffix = self.suffix % self\n        line = ''.join([message, self.bar_prefix, bar, current, empty,\n                        self.bar_suffix, suffix])\n        self.writeln(line)\n\n\nclass PixelBar(IncrementalBar):\n    phases = ('⡀', '⡄', '⡆', '⡇', '⣇', '⣧', '⣷', '⣿')\n\n\nclass ShadyBar(IncrementalBar):\n    phases = (' ', '░', '▒', '▓', '█')\n"
  },
  {
    "path": "language/utils/progress/progress/counter.py",
    "content": "# -*- coding: utf-8 -*-\n\n# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>\n#\n# Permission to use, copy, modify, and distribute this software for any\n# purpose with or without fee is hereby granted, provided that the above\n# copyright notice and this permission notice appear in all copies.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\" AND THE AUTHOR DISCLAIMS ALL WARRANTIES\n# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF\n# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR\n# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES\n# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN\n# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF\n# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.\n\nfrom __future__ import unicode_literals\nfrom . import Infinite, Progress\nfrom .helpers import WriteMixin\n\n\nclass Counter(WriteMixin, Infinite):\n    message = ''\n    hide_cursor = True\n\n    def update(self):\n        self.write(str(self.index))\n\n\nclass Countdown(WriteMixin, Progress):\n    hide_cursor = True\n\n    def update(self):\n        self.write(str(self.remaining))\n\n\nclass Stack(WriteMixin, Progress):\n    phases = (' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█')\n    hide_cursor = True\n\n    def update(self):\n        nphases = len(self.phases)\n        i = min(nphases - 1, int(self.progress * nphases))\n        self.write(self.phases[i])\n\n\nclass Pie(Stack):\n    phases = ('○', '◔', '◑', '◕', '●')\n"
  },
  {
    "path": "language/utils/progress/progress/helpers.py",
    "content": "# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>\n#\n# Permission to use, copy, modify, and distribute this software for any\n# purpose with or without fee is hereby granted, provided that the above\n# copyright notice and this permission notice appear in all copies.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\" AND THE AUTHOR DISCLAIMS ALL WARRANTIES\n# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF\n# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR\n# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES\n# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN\n# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF\n# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.\n\nfrom __future__ import print_function\n\n\nHIDE_CURSOR = '\\x1b[?25l'\nSHOW_CURSOR = '\\x1b[?25h'\n\n\nclass WriteMixin(object):\n    hide_cursor = False\n\n    def __init__(self, message=None, **kwargs):\n        super(WriteMixin, self).__init__(**kwargs)\n        self._width = 0\n        if message:\n            self.message = message\n\n        if self.file.isatty():\n            if self.hide_cursor:\n                print(HIDE_CURSOR, end='', file=self.file)\n            print(self.message, end='', file=self.file)\n            self.file.flush()\n\n    def write(self, s):\n        if self.file.isatty():\n            b = '\\b' * self._width\n            c = s.ljust(self._width)\n            print(b + c, end='', file=self.file)\n            self._width = max(self._width, len(s))\n            self.file.flush()\n\n    def finish(self):\n        if self.file.isatty() and self.hide_cursor:\n            print(SHOW_CURSOR, end='', file=self.file)\n\n\nclass WritelnMixin(object):\n    hide_cursor = False\n\n    def __init__(self, message=None, **kwargs):\n        super(WritelnMixin, self).__init__(**kwargs)\n        if message:\n            self.message = message\n\n        if 
self.file.isatty() and self.hide_cursor:\n            print(HIDE_CURSOR, end='', file=self.file)\n\n    def clearln(self):\n        if self.file.isatty():\n            print('\\r\\x1b[K', end='', file=self.file)\n\n    def writeln(self, line):\n        if self.file.isatty():\n            self.clearln()\n            print(line, end='', file=self.file)\n            self.file.flush()\n\n    def finish(self):\n        if self.file.isatty():\n            print(file=self.file)\n            if self.hide_cursor:\n                print(SHOW_CURSOR, end='', file=self.file)\n\n\nfrom signal import signal, SIGINT\nfrom sys import exit\n\n\nclass SigIntMixin(object):\n    \"\"\"Registers a signal handler that calls finish on SIGINT\"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super(SigIntMixin, self).__init__(*args, **kwargs)\n        signal(SIGINT, self._sigint_handler)\n\n    def _sigint_handler(self, signum, frame):\n        self.finish()\n        exit(0)\n"
  },
  {
    "path": "language/utils/progress/progress/spinner.py",
    "content": "# -*- coding: utf-8 -*-\n\n# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>\n#\n# Permission to use, copy, modify, and distribute this software for any\n# purpose with or without fee is hereby granted, provided that the above\n# copyright notice and this permission notice appear in all copies.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\" AND THE AUTHOR DISCLAIMS ALL WARRANTIES\n# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF\n# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR\n# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES\n# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN\n# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF\n# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.\n\nfrom __future__ import unicode_literals\nfrom . import Infinite\nfrom .helpers import WriteMixin\n\n\nclass Spinner(WriteMixin, Infinite):\n    message = ''\n    phases = ('-', '\\\\', '|', '/')\n    hide_cursor = True\n\n    def update(self):\n        i = self.index % len(self.phases)\n        self.write(self.phases[i])\n\n\nclass PieSpinner(Spinner):\n    phases = ['◷', '◶', '◵', '◴']\n\n\nclass MoonSpinner(Spinner):\n    phases = ['◑', '◒', '◐', '◓']\n\n\nclass LineSpinner(Spinner):\n    phases = ['⎺', '⎻', '⎼', '⎽', '⎼', '⎻']\n\nclass PixelSpinner(Spinner):\n    phases = ['⣾','⣷', '⣯', '⣟', '⡿', '⢿', '⣻', '⣽']\n"
  },
  {
    "path": "language/utils/progress/setup.py",
    "content": "#!/usr/bin/env python\n\nfrom setuptools import setup\n\nimport progress\n\n\nsetup(\n    name='progress',\n    version=progress.__version__,\n    description='Easy to use progress bars',\n    long_description=open('README.rst').read(),\n    author='Giorgos Verigakis',\n    author_email='verigak@gmail.com',\n    url='http://github.com/verigak/progress/',\n    license='ISC',\n    packages=['progress'],\n    classifiers=[\n        'Environment :: Console',\n        'Intended Audience :: Developers',\n        'License :: OSI Approved :: ISC License (ISCL)',\n        'Programming Language :: Python :: 2.6',\n        'Programming Language :: Python :: 2.7',\n        'Programming Language :: Python :: 3.3',\n        'Programming Language :: Python :: 3.4',\n        'Programming Language :: Python :: 3.5',\n        'Programming Language :: Python :: 3.6',\n    ]\n)\n"
  },
  {
    "path": "language/utils/progress/test_progress.py",
    "content": "#!/usr/bin/env python\n\nfrom __future__ import print_function\n\nimport random\nimport time\n\nfrom progress.bar import (Bar, ChargingBar, FillingSquaresBar,\n                          FillingCirclesBar, IncrementalBar, PixelBar,\n                          ShadyBar)\nfrom progress.spinner import (Spinner, PieSpinner, MoonSpinner, LineSpinner,\n                              PixelSpinner)\nfrom progress.counter import Counter, Countdown, Stack, Pie\n\n\ndef sleep():\n    t = 0.01\n    t += t * random.uniform(-0.1, 0.1)  # Add some variance\n    time.sleep(t)\n\n\nfor bar_cls in (Bar, ChargingBar, FillingSquaresBar, FillingCirclesBar):\n    suffix = '%(index)d/%(max)d [%(elapsed)d / %(eta)d / %(eta_td)s]'\n    bar = bar_cls(bar_cls.__name__, suffix=suffix)\n    for i in bar.iter(range(200)):\n        sleep()\n\nfor bar_cls in (IncrementalBar, PixelBar, ShadyBar):\n    suffix = '%(percent)d%% [%(elapsed_td)s / %(eta)d / %(eta_td)s]'\n    bar = bar_cls(bar_cls.__name__, suffix=suffix)\n    for i in bar.iter(range(200)):\n        sleep()\n\nfor spin in (Spinner, PieSpinner, MoonSpinner, LineSpinner, PixelSpinner):\n    for i in spin(spin.__name__ + ' ').iter(range(100)):\n        sleep()\n    print()\n\nfor singleton in (Counter, Countdown, Stack, Pie):\n    for i in singleton(singleton.__name__ + ' ').iter(range(100)):\n        sleep()\n    print()\n\nbar = IncrementalBar('Random', suffix='%(index)d')\nfor i in range(100):\n    bar.goto(random.randint(0, 100))\n    sleep()\nbar.finish()\n"
  },
  {
    "path": "language/utils/setup_logger.py",
    "content": "# python3.7\n\"\"\"Utility functions for logging.\"\"\"\n\nimport logging\nimport os\nimport sys\n\n__all__ = ['setup_logger']\n\n\ndef setup_logger(work_dir=None,\n                 logfile_name='log.txt',\n                 logger_name='logger',\n                 debug=0):\n    \"\"\"Sets up logger from target work directory.\n\n  The function will sets up a logger with `DEBUG` log level. Two handlers will\n  be added to the logger automatically. One is the `sys.stdout` stream, with\n  `INFO` log level, which will print improtant messages on the screen. The other\n  is used to save all messages to file `$WORK_DIR/$LOGFILE_NAME`. Messages will\n  be added time stamp and log level before logged.\n\n  NOTE: If `work_dir` or `logfile_name` is empty, the file stream will be\n  skipped.\n\n  Args:\n    work_dir: The work directory. All intermediate files will be saved here.\n      (default: None)\n    logfile_name: Name of the file to save log message. (default: `log.txt`)\n    logger_name: Unique name for the logger. 
(default: `logger`)\n\n  Returns:\n    A `logging.Logger` object.\n\n  Raises:\n    SystemExit: If the work directory has already existed, of the logger with\n      specified name `logger_name` has already existed.\n  \"\"\"\n\n    logger = logging.getLogger(logger_name)\n    if logger.hasHandlers():  # Already existed\n        raise SystemExit(\n            f'Logger name `{logger_name}` has already been set up!\\n'\n            f'Please use another name, or otherwise the messages '\n            f'may be mixed between these two loggers.')\n\n    logger.setLevel(logging.DEBUG)\n    formatter = logging.Formatter(\"[%(asctime)s][%(levelname)s] %(message)s\")\n\n    # Print log message with `INFO` level or above onto the screen.\n    sh = logging.StreamHandler(stream=sys.stdout)\n    sh.setLevel(logging.INFO)\n    sh.setFormatter(formatter)\n    logger.addHandler(sh)\n\n    if not work_dir or not logfile_name:\n        return logger\n\n    if os.path.exists(work_dir) and debug == 0:\n        raise SystemExit(f'Work directory `{work_dir}` has already existed!\\n'\n                         f'Please specify another one.')\n    os.makedirs(work_dir, exist_ok=debug)\n\n    # Save log message with all levels in log file.\n    fh = logging.FileHandler(os.path.join(work_dir, logfile_name))\n    fh.setLevel(logging.DEBUG)\n    fh.setFormatter(formatter)\n    logger.addHandler(fh)\n\n    return logger\n"
  },
  {
    "path": "language/utils/visualize.py",
    "content": "import matplotlib.pyplot as plt\nimport torch\nimport torch.nn as nn\nimport torchvision\nimport torchvision.transforms as transforms\nimport numpy as np\nfrom .misc import *   \n\n__all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single']\n\n# functions to show an image\ndef make_image(img, mean=(0,0,0), std=(1,1,1)):\n    for i in range(0, 3):\n        img[i] = img[i] * std[i] + mean[i]    # unnormalize\n    npimg = img.numpy()\n    return np.transpose(npimg, (1, 2, 0))\n\ndef gauss(x,a,b,c):\n    return torch.exp(-torch.pow(torch.add(x,-b),2).div(2*c*c)).mul(a)\n\ndef colorize(x):\n    ''' Converts a one-channel grayscale image to a color heatmap image '''\n    if x.dim() == 2:\n        torch.unsqueeze(x, 0, out=x)\n    if x.dim() == 3:\n        cl = torch.zeros([3, x.size(1), x.size(2)])\n        cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3)\n        cl[1] = gauss(x,1,.5,.3)\n        cl[2] = gauss(x,1,.2,.3)\n        cl[cl.gt(1)] = 1\n    elif x.dim() == 4:\n        cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)])\n        cl[:,0,:,:] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3)\n        cl[:,1,:,:] = gauss(x,1,.5,.3)\n        cl[:,2,:,:] = gauss(x,1,.2,.3)\n    return cl\n\ndef show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):\n    images = make_image(torchvision.utils.make_grid(images), Mean, Std)\n    plt.imshow(images)\n    plt.show()\n\n\ndef show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):\n    im_size = images.size(2)\n\n    # save for adding mask\n    im_data = images.clone()\n    for i in range(0, 3):\n        im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i]    # unnormalize\n\n    images = make_image(torchvision.utils.make_grid(images), Mean, Std)\n    plt.subplot(2, 1, 1)\n    plt.imshow(images)\n    plt.axis('off')\n\n    # for b in range(mask.size(0)):\n    #     mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min())\n    mask_size = mask.size(2)\n    # print('Max %f Min %f' % 
(mask.max(), mask.min()))\n    mask = (upsampling(mask, scale_factor=im_size/mask_size))\n    # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size))\n    # for c in range(3):\n    #     mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c]\n\n    # print(mask.size())\n    mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data)))\n    # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std)\n    plt.subplot(2, 1, 2)\n    plt.imshow(mask)\n    plt.axis('off')\n\ndef show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):\n    im_size = images.size(2)\n\n    # save for adding mask\n    im_data = images.clone()\n    for i in range(0, 3):\n        im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i]    # unnormalize\n\n    images = make_image(torchvision.utils.make_grid(images), Mean, Std)\n    plt.subplot(1+len(masklist), 1, 1)\n    plt.imshow(images)\n    plt.axis('off')\n\n    for i in range(len(masklist)):\n        mask = masklist[i].data.cpu()\n        # for b in range(mask.size(0)):\n        #     mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min())\n        mask_size = mask.size(2)\n        # print('Max %f Min %f' % (mask.max(), mask.min()))\n        mask = (upsampling(mask, scale_factor=im_size/mask_size))\n        # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size))\n        # for c in range(3):\n        #     mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c]\n\n        # print(mask.size())\n        mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data)))\n        # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std)\n        plt.subplot(1+len(masklist), 1, i+2)\n        plt.imshow(mask)\n        plt.axis('off')\n\n\n\n# x = torch.zeros(1, 3, 3)\n# out = colorize(x)\n# out_im = make_image(out)\n# plt.imshow(out_im)\n# plt.show()"
  },
  {
    "path": "models/__init__.py",
    "content": "import glob\nimport importlib\nimport logging\nimport os.path as osp\n\n# automatically scan and import model modules\n# scan all the files under the 'models' folder and collect files ending with\n# '_model.py'\nmodel_folder = osp.dirname(osp.abspath(__file__))\nmodel_filenames = [\n    osp.splitext(osp.basename(v))[0]\n    for v in glob.glob(f'{model_folder}/*_model.py')\n]\n# import all the model modules\n_model_modules = [\n    importlib.import_module(f'models.{file_name}')\n    for file_name in model_filenames\n]\n\n\ndef create_model(opt):\n    \"\"\"Create model.\n\n    Args:\n        opt (dict): Configuration. It constains:\n            model_type (str): Model type.\n    \"\"\"\n    model_type = opt['model_type']\n\n    # dynamically instantiation\n    for module in _model_modules:\n        model_cls = getattr(module, model_type, None)\n        if model_cls is not None:\n            break\n    if model_cls is None:\n        raise ValueError(f'Model {model_type} is not found.')\n\n    model = model_cls(opt)\n\n    logger = logging.getLogger('base')\n    logger.info(f'Model [{model.__class__.__name__}] is created.')\n    return model\n"
  },
  {
    "path": "models/archs/__init__.py",
    "content": ""
  },
  {
    "path": "models/archs/attribute_predictor_arch.py",
    "content": "import json\n\nimport torch.nn as nn\nimport torch.utils.model_zoo as model_zoo\n\n__all__ = ['ResNet', 'resnet50']\n\nmodel_urls = {\n    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',\n    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',\n    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',\n    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',\n    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',\n}\n\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(\n        in_planes,\n        out_planes,\n        kernel_size=3,\n        stride=stride,\n        padding=1,\n        bias=False)\n\n\ndef conv1x1(in_planes, out_planes, stride=1):\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(\n        in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(BasicBlock, self).__init__()\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(Bottleneck, self).__init__()\n        
self.conv1 = conv1x1(inplanes, planes)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.conv2 = conv3x3(planes, planes, stride)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.conv3 = conv1x1(planes, planes * self.expansion)\n        self.bn3 = nn.BatchNorm2d(planes * self.expansion)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass fc_block(nn.Module):\n\n    def __init__(self, inplanes, planes, drop_rate=0.15):\n        super(fc_block, self).__init__()\n        self.fc = nn.Linear(inplanes, planes)\n        self.bn = nn.BatchNorm1d(planes)\n        if drop_rate > 0:\n            self.dropout = nn.Dropout(drop_rate)\n        self.relu = nn.ReLU(inplace=True)\n        self.drop_rate = drop_rate\n\n    def forward(self, x):\n        x = self.fc(x)\n        x = self.bn(x)\n        if self.drop_rate > 0:\n            x = self.dropout(x)\n        x = self.relu(x)\n        return x\n\n\nclass ResNet(nn.Module):\n\n    def __init__(self,\n                 block,\n                 layers,\n                 attr_file,\n                 zero_init_residual=False,\n                 dropout_rate=0):\n        super(ResNet, self).__init__()\n        self.inplanes = 64\n        self.conv1 = nn.Conv2d(\n            3, 64, kernel_size=7, stride=2, padding=3, bias=False)\n        self.bn1 = nn.BatchNorm2d(64)\n        self.relu = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 
= self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n        self.stem = fc_block(512 * block.expansion, 512, dropout_rate)\n\n        # construct classifier heads according to the number of values of\n        # each attribute\n        self.attr_file = attr_file\n        with open(self.attr_file, 'r') as f:\n            attr_f = json.load(f)\n        self.attr_info = attr_f['attr_info']\n        for idx, (key, val) in enumerate(self.attr_info.items()):\n            num_val = int(len(val[\"value\"]))\n            setattr(\n                self, 'classifier' + str(key).zfill(2) + val[\"name\"],\n                nn.Sequential(\n                    fc_block(512, 256, dropout_rate), nn.Linear(256, num_val)))\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(\n                    m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n        # Zero-initialize the last BN in each residual branch,\n        # so that the residual branch starts with zeros, and each residual\n        # block behaves like an identity.\n        # This improves the model by 0.2~0.3% according\n        # to https://arxiv.org/abs/1706.02677\n        if zero_init_residual:\n            for m in self.modules():\n                if isinstance(m, Bottleneck):\n                    nn.init.constant_(m.bn3.weight, 0)\n                elif isinstance(m, BasicBlock):\n                    nn.init.constant_(m.bn2.weight, 0)\n\n    def _make_layer(self, block, planes, blocks, stride=1):\n        downsample = None\n        if stride != 1 or 
self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                conv1x1(self.inplanes, planes * block.expansion, stride),\n                nn.BatchNorm2d(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample))\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(block(self.inplanes, planes))\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        x = x.view(x.size(0), -1)\n        x = self.stem(x)\n\n        y = []\n\n        for idx, (key, val) in enumerate(self.attr_info.items()):\n            classifier = getattr(\n                self, 'classifier' + str(key).zfill(2) + val[\"name\"])\n            y.append(classifier(x))\n\n        return y\n\n\ndef resnet50(pretrained=True, **kwargs):\n    \"\"\"Constructs a ResNet-50 model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)\n    if pretrained:\n        init_pretrained_weights(model, model_urls['resnet50'])\n    return model\n\n\ndef init_pretrained_weights(model, model_url):\n    \"\"\"\n    Initialize model with pretrained weights.\n    Layers that don't match with pretrained layers in name or size are kept\n    unchanged.\n    \"\"\"\n    pretrain_dict = model_zoo.load_url(model_url)\n    model_dict = model.state_dict()\n    pretrain_dict = {\n        k: v\n        for k, v in pretrain_dict.items()\n        if k in model_dict and model_dict[k].size() == v.size()\n    }\n    model_dict.update(pretrain_dict)\n    model.load_state_dict(model_dict)\n    
print(\n        \"Initialized model with pretrained weights from {}\".format(model_url))\n"
  },
  {
    "path": "models/archs/field_function_arch.py",
    "content": "import torch\nimport torch.nn as nn\n\n\nclass FieldFunction(nn.Module):\n\n    def __init__(\n        self,\n        num_layer=4,\n        latent_dim=512,\n        hidden_dim=512,\n        leaky_relu_neg_slope=0.2,\n    ):\n\n        super(FieldFunction, self).__init__()\n\n        layers = []\n\n        # first layer\n        linear_layer = LinearLayer(\n            in_dim=latent_dim,\n            out_dim=hidden_dim,\n            activation=True,\n            negative_slope=leaky_relu_neg_slope)\n        layers.append(linear_layer)\n\n        # hidden layers\n        for i in range(num_layer - 2):\n            linear_layer = LinearLayer(\n                in_dim=hidden_dim,\n                out_dim=hidden_dim,\n                activation=True,\n                negative_slope=leaky_relu_neg_slope)\n            layers.append(linear_layer)\n\n        # final layers\n        linear_layer = LinearLayer(\n            in_dim=hidden_dim, out_dim=latent_dim, activation=False)\n        layers.append(linear_layer)\n\n        self.field = nn.Sequential(*layers)\n\n    def forward(self, x):\n        x = self.field(x)\n        return x\n\n\nclass LinearLayer(nn.Module):\n\n    def __init__(\n        self,\n        in_dim=512,\n        out_dim=512,\n        activation=True,\n        negative_slope=0.2,\n    ):\n\n        super(LinearLayer, self).__init__()\n\n        self.Linear = nn.Linear(\n            in_features=in_dim, out_features=out_dim, bias=True)\n\n        self.activation = activation\n        if activation:\n            self.leaky_relu = nn.LeakyReLU(\n                negative_slope=negative_slope, inplace=False)\n\n    def forward(self, x):\n        x = self.Linear(x)\n        if self.activation:\n            x = self.leaky_relu(x)\n        return x\n\n\nclass Normalization(nn.Module):\n\n    def __init__(self, ):\n\n        super(Normalization, self).__init__()\n\n        self.mean = torch.tensor([0.485, 0.456, 0.406\n                                
  ]).unsqueeze(-1).unsqueeze(-1).to('cuda')\n        self.std = torch.tensor([0.229, 0.224,\n                                 0.225]).unsqueeze(-1).unsqueeze(-1).to('cuda')\n\n    def forward(self, x):\n        x = torch.sub(x, self.mean)\n        x = torch.div(x, self.std)\n        return x\n"
  },
  {
    "path": "models/archs/stylegan2/.gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\nwandb/\n*.lmdb/\n*.pkl\n"
  },
  {
    "path": "models/archs/stylegan2/LICENSE",
    "content": "MIT License\n\nCopyright (c) 2019 Kim Seonghyeon\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "models/archs/stylegan2/LICENSE-FID",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "models/archs/stylegan2/LICENSE-LPIPS",
    "content": "Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang\r\nAll rights reserved.\r\n\r\nRedistribution and use in source and binary forms, with or without\r\nmodification, are permitted provided that the following conditions are met:\r\n\r\n* Redistributions of source code must retain the above copyright notice, this\r\n  list of conditions and the following disclaimer.\r\n\r\n* Redistributions in binary form must reproduce the above copyright notice,\r\n  this list of conditions and the following disclaimer in the documentation\r\n  and/or other materials provided with the distribution.\r\n\r\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\r\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\r\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\r\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\r\nFOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\r\nDAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\r\nSERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\r\nCAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\r\nOR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\r\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\r\n\r\n"
  },
  {
    "path": "models/archs/stylegan2/LICENSE-NVIDIA",
    "content": "Copyright (c) 2019, NVIDIA Corporation. All rights reserved.\r\n\r\n\r\nNvidia Source Code License-NC\r\n\r\n=======================================================================\r\n\r\n1. Definitions\r\n\r\n\"Licensor\" means any person or entity that distributes its Work.\r\n\r\n\"Software\" means the original work of authorship made available under\r\nthis License.\r\n\r\n\"Work\" means the Software and any additions to or derivative works of\r\nthe Software that are made available under this License.\r\n\r\n\"Nvidia Processors\" means any central processing unit (CPU), graphics\r\nprocessing unit (GPU), field-programmable gate array (FPGA),\r\napplication-specific integrated circuit (ASIC) or any combination\r\nthereof designed, made, sold, or provided by Nvidia or its affiliates.\r\n\r\nThe terms \"reproduce,\" \"reproduction,\" \"derivative works,\" and\r\n\"distribution\" have the meaning as provided under U.S. copyright law;\r\nprovided, however, that for the purposes of this License, derivative\r\nworks shall not include works that remain separable from, or merely\r\nlink (or bind by name) to the interfaces of, the Work.\r\n\r\nWorks, including the Software, are \"made available\" under this License\r\nby including in or with the Work either (a) a copyright notice\r\nreferencing the applicability of this License to the Work, or (b) a\r\ncopy of this License.\r\n\r\n2. License Grants\r\n\r\n    2.1 Copyright Grant. Subject to the terms and conditions of this\r\n    License, each Licensor grants to you a perpetual, worldwide,\r\n    non-exclusive, royalty-free, copyright license to reproduce,\r\n    prepare derivative works of, publicly display, publicly perform,\r\n    sublicense and distribute its Work and any resulting derivative\r\n    works in any form.\r\n\r\n3. Limitations\r\n\r\n    3.1 Redistribution. 
You may reproduce or distribute the Work only\r\n    if (a) you do so under this License, (b) you include a complete\r\n    copy of this License with your distribution, and (c) you retain\r\n    without modification any copyright, patent, trademark, or\r\n    attribution notices that are present in the Work.\r\n\r\n    3.2 Derivative Works. You may specify that additional or different\r\n    terms apply to the use, reproduction, and distribution of your\r\n    derivative works of the Work (\"Your Terms\") only if (a) Your Terms\r\n    provide that the use limitation in Section 3.3 applies to your\r\n    derivative works, and (b) you identify the specific derivative\r\n    works that are subject to Your Terms. Notwithstanding Your Terms,\r\n    this License (including the redistribution requirements in Section\r\n    3.1) will continue to apply to the Work itself.\r\n\r\n    3.3 Use Limitation. The Work and any derivative works thereof only\r\n    may be used or intended for use non-commercially. The Work or\r\n    derivative works thereof may be used or intended for use by Nvidia\r\n    or its affiliates commercially or non-commercially. As used herein,\r\n    \"non-commercially\" means for research or evaluation purposes only.\r\n\r\n    3.4 Patent Claims. If you bring or threaten to bring a patent claim\r\n    against any Licensor (including any claim, cross-claim or\r\n    counterclaim in a lawsuit) to enforce any patents that you allege\r\n    are infringed by any Work, then your rights under this License from\r\n    such Licensor (including the grants in Sections 2.1 and 2.2) will\r\n    terminate immediately.\r\n\r\n    3.5 Trademarks. This License does not grant any rights to use any\r\n    Licensor's or its affiliates' names, logos, or trademarks, except\r\n    as necessary to reproduce the notices described in this License.\r\n\r\n    3.6 Termination. 
If you violate any term of this License, then your\r\n    rights under this License (including the grants in Sections 2.1 and\r\n    2.2) will terminate immediately.\r\n\r\n4. Disclaimer of Warranty.\r\n\r\nTHE WORK IS PROVIDED \"AS IS\" WITHOUT WARRANTIES OR CONDITIONS OF ANY\r\nKIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF\r\nMERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR\r\nNON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER\r\nTHIS LICENSE. \r\n\r\n5. Limitation of Liability.\r\n\r\nEXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL\r\nTHEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE\r\nSHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,\r\nINDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF\r\nOR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK\r\n(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,\r\nLOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER\r\nCOMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF\r\nTHE POSSIBILITY OF SUCH DAMAGES.\r\n\r\n=======================================================================\r\n"
  },
  {
    "path": "models/archs/stylegan2/__init__.py",
    "content": ""
  },
  {
    "path": "models/archs/stylegan2/apply_factor.py",
    "content": "import argparse\n\nimport torch\nfrom torchvision import utils\n\nfrom model import Generator\n\n\nif __name__ == \"__main__\":\n    torch.set_grad_enabled(False)\n\n    parser = argparse.ArgumentParser(description=\"Apply closed form factorization\")\n\n    parser.add_argument(\n        \"-i\", \"--index\", type=int, default=0, help=\"index of eigenvector\"\n    )\n    parser.add_argument(\n        \"-d\",\n        \"--degree\",\n        type=float,\n        default=5,\n        help=\"scalar factors for moving latent vectors along eigenvector\",\n    )\n    parser.add_argument(\n        \"--channel_multiplier\",\n        type=int,\n        default=2,\n        help='channel multiplier factor. config-f = 2, else = 1',\n    )\n    parser.add_argument(\"--ckpt\", type=str, required=True, help=\"stylegan2 checkpoints\")\n    parser.add_argument(\n        \"--size\", type=int, default=256, help=\"output image size of the generator\"\n    )\n    parser.add_argument(\n        \"-n\", \"--n_sample\", type=int, default=7, help=\"number of samples created\"\n    )\n    parser.add_argument(\n        \"--truncation\", type=float, default=0.7, help=\"truncation factor\"\n    )\n    parser.add_argument(\n        \"--device\", type=str, default=\"cuda\", help=\"device to run the model\"\n    )\n    parser.add_argument(\n        \"--out_prefix\",\n        type=str,\n        default=\"factor\",\n        help=\"filename prefix to result samples\",\n    )\n    parser.add_argument(\n        \"factor\",\n        type=str,\n        help=\"name of the closed form factorization result factor file\",\n    )\n\n    args = parser.parse_args()\n\n    eigvec = torch.load(args.factor)[\"eigvec\"].to(args.device)\n    ckpt = torch.load(args.ckpt)\n    g = Generator(args.size, 512, 8, channel_multiplier=args.channel_multiplier).to(args.device)\n    g.load_state_dict(ckpt[\"g_ema\"], strict=False)\n\n    trunc = g.mean_latent(4096)\n\n    latent = torch.randn(args.n_sample, 512, 
device=args.device)\n    latent = g.get_latent(latent)\n\n    direction = args.degree * eigvec[:, args.index].unsqueeze(0)\n\n    img, _ = g(\n        [latent],\n        truncation=args.truncation,\n        truncation_latent=trunc,\n        input_is_latent=True,\n    )\n    img1, _ = g(\n        [latent + direction],\n        truncation=args.truncation,\n        truncation_latent=trunc,\n        input_is_latent=True,\n    )\n    img2, _ = g(\n        [latent - direction],\n        truncation=args.truncation,\n        truncation_latent=trunc,\n        input_is_latent=True,\n    )\n\n    # utils.save_image writes the grid to disk and returns None\n    utils.save_image(\n        torch.cat([img1, img, img2], 0),\n        f\"{args.out_prefix}_index-{args.index}_degree-{args.degree}.png\",\n        normalize=True,\n        range=(-1, 1),\n        nrow=args.n_sample,\n    )\n"
  },
  {
    "path": "models/archs/stylegan2/calc_inception.py",
    "content": "import argparse\nimport pickle\nimport os\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch.utils.data import DataLoader\nfrom torchvision import transforms\nfrom torchvision.models import inception_v3, Inception3\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom inception import InceptionV3\nfrom dataset import MultiResolutionDataset\n\n\nclass Inception3Feature(Inception3):\n    def forward(self, x):\n        if x.shape[2] != 299 or x.shape[3] != 299:\n            x = F.interpolate(x, size=(299, 299), mode=\"bilinear\", align_corners=True)\n\n        x = self.Conv2d_1a_3x3(x)  # 299 x 299 x 3\n        x = self.Conv2d_2a_3x3(x)  # 149 x 149 x 32\n        x = self.Conv2d_2b_3x3(x)  # 147 x 147 x 32\n        x = F.max_pool2d(x, kernel_size=3, stride=2)  # 147 x 147 x 64\n\n        x = self.Conv2d_3b_1x1(x)  # 73 x 73 x 64\n        x = self.Conv2d_4a_3x3(x)  # 73 x 73 x 80\n        x = F.max_pool2d(x, kernel_size=3, stride=2)  # 71 x 71 x 192\n\n        x = self.Mixed_5b(x)  # 35 x 35 x 192\n        x = self.Mixed_5c(x)  # 35 x 35 x 256\n        x = self.Mixed_5d(x)  # 35 x 35 x 288\n\n        x = self.Mixed_6a(x)  # 35 x 35 x 288\n        x = self.Mixed_6b(x)  # 17 x 17 x 768\n        x = self.Mixed_6c(x)  # 17 x 17 x 768\n        x = self.Mixed_6d(x)  # 17 x 17 x 768\n        x = self.Mixed_6e(x)  # 17 x 17 x 768\n\n        x = self.Mixed_7a(x)  # 17 x 17 x 768\n        x = self.Mixed_7b(x)  # 8 x 8 x 1280\n        x = self.Mixed_7c(x)  # 8 x 8 x 2048\n\n        x = F.avg_pool2d(x, kernel_size=8)  # 8 x 8 x 2048\n\n        return x.view(x.shape[0], x.shape[1])  # 1 x 1 x 2048\n\n\ndef load_patched_inception_v3():\n    # inception = inception_v3(pretrained=True)\n    # inception_feat = Inception3Feature()\n    # inception_feat.load_state_dict(inception.state_dict())\n    inception_feat = InceptionV3([3], normalize_input=False)\n\n    return inception_feat\n\n\n@torch.no_grad()\ndef extract_features(loader, 
inception, device):\n    pbar = tqdm(loader)\n\n    feature_list = []\n\n    for img in pbar:\n        img = img.to(device)\n        feature = inception(img)[0].view(img.shape[0], -1)\n        feature_list.append(feature.to(\"cpu\"))\n\n    features = torch.cat(feature_list, 0)\n\n    return features\n\n\nif __name__ == \"__main__\":\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n    parser = argparse.ArgumentParser(\n        description=\"Calculate Inception v3 features for datasets\"\n    )\n    parser.add_argument(\n        \"--size\",\n        type=int,\n        default=256,\n        help=\"image sizes used for embedding calculation\",\n    )\n    parser.add_argument(\n        \"--batch\", default=64, type=int, help=\"batch size for inception networks\"\n    )\n    parser.add_argument(\n        \"--n_sample\",\n        type=int,\n        default=50000,\n        help=\"number of samples used for embedding calculation\",\n    )\n    parser.add_argument(\n        \"--flip\", action=\"store_true\", help=\"apply random flipping to real images\"\n    )\n    parser.add_argument(\"path\", metavar=\"PATH\", help=\"path to dataset lmdb file\")\n\n    args = parser.parse_args()\n\n    inception = load_patched_inception_v3()\n    inception = nn.DataParallel(inception).eval().to(device)\n\n    transform = transforms.Compose(\n        [\n            transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0),\n            transforms.ToTensor(),\n            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),\n        ]\n    )\n\n    dset = MultiResolutionDataset(args.path, transform=transform, resolution=args.size)\n    loader = DataLoader(dset, batch_size=args.batch, num_workers=4)\n\n    features = extract_features(loader, inception, device).numpy()\n\n    features = features[: args.n_sample]\n\n    print(f\"extracted {features.shape[0]} features\")\n\n    mean = np.mean(features, 0)\n    cov = np.cov(features, rowvar=False)\n\n    name = os.path.splitext(os.path.basename(args.path))[0]\n\n    with open(f\"inception_{name}.pkl\", \"wb\") as f:\n        pickle.dump({\"mean\": mean, \"cov\": cov, \"size\": args.size, \"path\": args.path}, f)\n"
  },
  {
    "path": "models/archs/stylegan2/checkpoint/.gitignore",
    "content": "*.pt\n"
  },
  {
    "path": "models/archs/stylegan2/convert_weight.py",
    "content": "import argparse\nimport math\nimport os\nimport pickle\nimport sys\n\nimport numpy as np\nimport torch\nfrom torchvision import utils\n\nfrom model import Discriminator, Generator\n\n\ndef convert_modconv(vars, source_name, target_name, flip=False):\n    weight = vars[source_name + \"/weight\"].value().eval()\n    mod_weight = vars[source_name + \"/mod_weight\"].value().eval()\n    mod_bias = vars[source_name + \"/mod_bias\"].value().eval()\n    noise = vars[source_name + \"/noise_strength\"].value().eval()\n    bias = vars[source_name + \"/bias\"].value().eval()\n\n    dic = {\n        \"conv.weight\": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0),\n        \"conv.modulation.weight\": mod_weight.transpose((1, 0)),\n        \"conv.modulation.bias\": mod_bias + 1,\n        \"noise.weight\": np.array([noise]),\n        \"activate.bias\": bias,\n    }\n\n    dic_torch = {}\n\n    for k, v in dic.items():\n        dic_torch[target_name + \".\" + k] = torch.from_numpy(v)\n\n    if flip:\n        dic_torch[target_name + \".conv.weight\"] = torch.flip(\n            dic_torch[target_name + \".conv.weight\"], [3, 4])\n\n    return dic_torch\n\n\ndef convert_conv(vars, source_name, target_name, bias=True, start=0):\n    weight = vars[source_name + \"/weight\"].value().eval()\n\n    dic = {\"weight\": weight.transpose((3, 2, 0, 1))}\n\n    if bias:\n        dic[\"bias\"] = vars[source_name + \"/bias\"].value().eval()\n\n    dic_torch = {}\n\n    dic_torch[target_name + f\".{start}.weight\"] = torch.from_numpy(\n        dic[\"weight\"])\n\n    if bias:\n        dic_torch[target_name + f\".{start + 1}.bias\"] = torch.from_numpy(\n            dic[\"bias\"])\n\n    return dic_torch\n\n\ndef convert_torgb(vars, source_name, target_name):\n    weight = vars[source_name + \"/weight\"].value().eval()\n    mod_weight = vars[source_name + \"/mod_weight\"].value().eval()\n    mod_bias = vars[source_name + \"/mod_bias\"].value().eval()\n    bias = vars[source_name + 
\"/bias\"].value().eval()\n\n    dic = {\n        \"conv.weight\": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0),\n        \"conv.modulation.weight\": mod_weight.transpose((1, 0)),\n        \"conv.modulation.bias\": mod_bias + 1,\n        \"bias\": bias.reshape((1, 3, 1, 1)),\n    }\n\n    dic_torch = {}\n\n    for k, v in dic.items():\n        dic_torch[target_name + \".\" + k] = torch.from_numpy(v)\n\n    return dic_torch\n\n\ndef convert_dense(vars, source_name, target_name):\n    weight = vars[source_name + \"/weight\"].value().eval()\n    bias = vars[source_name + \"/bias\"].value().eval()\n\n    dic = {\"weight\": weight.transpose((1, 0)), \"bias\": bias}\n\n    dic_torch = {}\n\n    for k, v in dic.items():\n        dic_torch[target_name + \".\" + k] = torch.from_numpy(v)\n\n    return dic_torch\n\n\ndef update(state_dict, new):\n    for k, v in new.items():\n        if k not in state_dict:\n            raise KeyError(k + \" is not found\")\n\n        if v.shape != state_dict[k].shape:\n            raise ValueError(\n                f\"Shape mismatch: {v.shape} vs {state_dict[k].shape}\")\n\n        state_dict[k] = v\n\n\ndef discriminator_fill_statedict(statedict, vars, size):\n    log_size = int(math.log(size, 2))\n\n    update(statedict, convert_conv(vars, f\"{size}x{size}/FromRGB\", \"convs.0\"))\n\n    conv_i = 1\n\n    for i in range(log_size - 2, 0, -1):\n        reso = 4 * 2**i\n        update(\n            statedict,\n            convert_conv(vars, f\"{reso}x{reso}/Conv0\",\n                         f\"convs.{conv_i}.conv1\"),\n        )\n        update(\n            statedict,\n            convert_conv(\n                vars,\n                f\"{reso}x{reso}/Conv1_down\",\n                f\"convs.{conv_i}.conv2\",\n                start=1),\n        )\n        update(\n            statedict,\n            convert_conv(\n                vars,\n                f\"{reso}x{reso}/Skip\",\n                f\"convs.{conv_i}.skip\",\n                
start=1,\n                bias=False),\n        )\n        conv_i += 1\n\n    update(statedict, convert_conv(vars, f\"4x4/Conv\", \"final_conv\"))\n    update(statedict, convert_dense(vars, f\"4x4/Dense0\", \"final_linear.0\"))\n    update(statedict, convert_dense(vars, f\"Output\", \"final_linear.1\"))\n\n    return statedict\n\n\ndef fill_statedict(state_dict, vars, size, n_mlp):\n    log_size = int(math.log(size, 2))\n\n    for i in range(n_mlp):\n        update(state_dict,\n               convert_dense(vars, f\"G_mapping/Dense{i}\", f\"style.{i + 1}\"))\n\n    update(\n        state_dict,\n        {\n            \"input.input\":\n            torch.from_numpy(\n                vars[\"G_synthesis/4x4/Const/const\"].value().eval())\n        },\n    )\n\n    update(state_dict, convert_torgb(vars, \"G_synthesis/4x4/ToRGB\", \"to_rgb1\"))\n\n    for i in range(log_size - 2):\n        reso = 4 * 2**(i + 1)\n        update(\n            state_dict,\n            convert_torgb(vars, f\"G_synthesis/{reso}x{reso}/ToRGB\",\n                          f\"to_rgbs.{i}\"),\n        )\n\n    update(state_dict, convert_modconv(vars, \"G_synthesis/4x4/Conv\", \"conv1\"))\n\n    conv_i = 0\n\n    for i in range(log_size - 2):\n        reso = 4 * 2**(i + 1)\n        update(\n            state_dict,\n            convert_modconv(\n                vars,\n                f\"G_synthesis/{reso}x{reso}/Conv0_up\",\n                f\"convs.{conv_i}\",\n                flip=True,\n            ),\n        )\n        update(\n            state_dict,\n            convert_modconv(vars, f\"G_synthesis/{reso}x{reso}/Conv1\",\n                            f\"convs.{conv_i + 1}\"),\n        )\n        conv_i += 2\n\n    for i in range(0, (log_size - 2) * 2 + 1):\n        update(\n            state_dict,\n            {\n                f\"noises.noise_{i}\":\n                torch.from_numpy(vars[f\"G_synthesis/noise{i}\"].value().eval())\n            },\n        )\n\n    return state_dict\n\n\nif 
__name__ == \"__main__\":\n    device = \"cuda\"\n\n    parser = argparse.ArgumentParser(\n        description=\"Tensorflow to pytorch model checkpoint converter\")\n    parser.add_argument(\n        \"--repo\",\n        type=str,\n        required=True,\n        help=\"path to the official StyleGAN2 repository with dnnlib/ folder\",\n    )\n    parser.add_argument(\n        \"--gen\", action=\"store_true\", help=\"convert the generator weights\")\n    parser.add_argument(\n        \"--disc\",\n        action=\"store_true\",\n        help=\"convert the discriminator weights\")\n    parser.add_argument(\n        \"--channel_multiplier\",\n        type=int,\n        default=2,\n        help=\"channel multiplier factor. config-f = 2, else = 1\",\n    )\n    parser.add_argument(\n        \"path\", metavar=\"PATH\", help=\"path to the tensorflow weights\")\n\n    args = parser.parse_args()\n\n    sys.path.append(args.repo)\n\n    import dnnlib\n    from dnnlib import tflib\n\n    tflib.init_tf()\n\n    with open(args.path, \"rb\") as f:\n        generator, discriminator, g_ema = pickle.load(f)\n\n    size = g_ema.output_shape[2]\n    print(size)\n\n    n_mlp = 0\n    mapping_layers_names = g_ema.__getstate__(\n    )['components']['mapping'].list_layers()\n    for layer in mapping_layers_names:\n        if layer[0].startswith('Dense'):\n            n_mlp += 1\n\n    g = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier)\n    state_dict = g.state_dict()\n    state_dict = fill_statedict(state_dict, g_ema.vars, size, n_mlp)\n\n    g.load_state_dict(state_dict)\n\n    latent_avg = torch.from_numpy(g_ema.vars[\"dlatent_avg\"].value().eval())\n\n    ckpt = {\"g_ema\": state_dict, \"latent_avg\": latent_avg}\n\n    if args.gen:\n        g_train = Generator(\n            size, 512, n_mlp, channel_multiplier=args.channel_multiplier)\n        g_train_state = g_train.state_dict()\n        g_train_state = fill_statedict(g_train_state, generator.vars, size, n_mlp)\n        ckpt[\"g\"] = g_train_state\n\n    if args.disc:\n        disc = Discriminator(size, channel_multiplier=args.channel_multiplier)\n        d_state = disc.state_dict()\n        d_state = discriminator_fill_statedict(d_state, discriminator.vars,\n                                               size)\n        ckpt[\"d\"] = d_state\n\n    name = os.path.splitext(os.path.basename(args.path))[0]\n    torch.save(ckpt, name + \".pt\")\n\n    batch_size = {256: 16, 512: 9, 1024: 4}\n    n_sample = batch_size.get(size, 25)\n\n    g = g.to(device)\n\n    z = np.random.RandomState(0).randn(n_sample, 512).astype(\"float32\")\n\n    with torch.no_grad():\n        img_pt, _ = g(\n            [torch.from_numpy(z).to(device)],\n            truncation=0.5,\n            truncation_latent=latent_avg.to(device),\n            randomize_noise=False,\n        )\n\n    Gs_kwargs = dnnlib.EasyDict()\n    Gs_kwargs.randomize_noise = False\n    img_tf = g_ema.run(z, None, **Gs_kwargs)\n    img_tf = torch.from_numpy(img_tf).to(device)\n\n    img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - (\n        (img_tf.to(device) + 1) / 2).clamp(0.0, 1.0)\n\n    img_concat = torch.cat((img_tf, img_pt, img_diff), dim=0)\n\n    print(img_diff.abs().max())\n\n    utils.save_image(\n        img_concat,\n        name + \".png\",\n        nrow=n_sample,\n        normalize=True,\n        range=(-1, 1))\n"
  },
  {
    "path": "models/archs/stylegan2/dataset.py",
    "content": "from io import BytesIO\n\nimport lmdb\nfrom PIL import Image\nfrom torch.utils.data import Dataset\n\n\nclass MultiResolutionDataset(Dataset):\n    def __init__(self, path, transform, resolution=256):\n        self.env = lmdb.open(\n            path,\n            max_readers=32,\n            readonly=True,\n            lock=False,\n            readahead=False,\n            meminit=False,\n        )\n\n        if not self.env:\n            raise IOError('Cannot open lmdb dataset', path)\n\n        with self.env.begin(write=False) as txn:\n            self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))\n\n        self.resolution = resolution\n        self.transform = transform\n\n    def __len__(self):\n        return self.length\n\n    def __getitem__(self, index):\n        with self.env.begin(write=False) as txn:\n            key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')\n            img_bytes = txn.get(key)\n\n        buffer = BytesIO(img_bytes)\n        img = Image.open(buffer)\n        img = self.transform(img)\n\n        return img\n"
  },
  {
    "path": "models/archs/stylegan2/distributed.py",
    "content": "import math\nimport pickle\n\nimport torch\nfrom torch import distributed as dist\nfrom torch.utils.data.sampler import Sampler\n\n\ndef get_rank():\n    if not dist.is_available():\n        return 0\n\n    if not dist.is_initialized():\n        return 0\n\n    return dist.get_rank()\n\n\ndef synchronize():\n    if not dist.is_available():\n        return\n\n    if not dist.is_initialized():\n        return\n\n    world_size = dist.get_world_size()\n\n    if world_size == 1:\n        return\n\n    dist.barrier()\n\n\ndef get_world_size():\n    if not dist.is_available():\n        return 1\n\n    if not dist.is_initialized():\n        return 1\n\n    return dist.get_world_size()\n\n\ndef reduce_sum(tensor):\n    if not dist.is_available():\n        return tensor\n\n    if not dist.is_initialized():\n        return tensor\n\n    tensor = tensor.clone()\n    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)\n\n    return tensor\n\n\ndef gather_grad(params):\n    world_size = get_world_size()\n    \n    if world_size == 1:\n        return\n\n    for param in params:\n        if param.grad is not None:\n            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)\n            param.grad.data.div_(world_size)\n\n\ndef all_gather(data):\n    world_size = get_world_size()\n\n    if world_size == 1:\n        return [data]\n\n    buffer = pickle.dumps(data)\n    storage = torch.ByteStorage.from_buffer(buffer)\n    tensor = torch.ByteTensor(storage).to('cuda')\n\n    local_size = torch.IntTensor([tensor.numel()]).to('cuda')\n    size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]\n    dist.all_gather(size_list, local_size)\n    size_list = [int(size.item()) for size in size_list]\n    max_size = max(size_list)\n\n    tensor_list = []\n    for _ in size_list:\n        tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))\n\n    if local_size != max_size:\n        padding = torch.ByteTensor(size=(max_size - 
local_size,)).to('cuda')\n        tensor = torch.cat((tensor, padding), 0)\n\n    dist.all_gather(tensor_list, tensor)\n\n    data_list = []\n\n    for size, tensor in zip(size_list, tensor_list):\n        buffer = tensor.cpu().numpy().tobytes()[:size]\n        data_list.append(pickle.loads(buffer))\n\n    return data_list\n\n\ndef reduce_loss_dict(loss_dict):\n    world_size = get_world_size()\n\n    if world_size < 2:\n        return loss_dict\n\n    with torch.no_grad():\n        keys = []\n        losses = []\n\n        for k in sorted(loss_dict.keys()):\n            keys.append(k)\n            losses.append(loss_dict[k])\n\n        losses = torch.stack(losses, 0)\n        dist.reduce(losses, dst=0)\n\n        if dist.get_rank() == 0:\n            losses /= world_size\n\n        reduced_losses = {k: v for k, v in zip(keys, losses)}\n\n    return reduced_losses\n"
  },
  {
    "path": "models/archs/stylegan2/fid.py",
    "content": "import argparse\nimport pickle\n\nimport torch\nfrom torch import nn\nimport numpy as np\nfrom scipy import linalg\nfrom tqdm import tqdm\n\nfrom model import Generator\nfrom calc_inception import load_patched_inception_v3\n\n\n@torch.no_grad()\ndef extract_feature_from_samples(\n    generator, inception, truncation, truncation_latent, batch_size, n_sample, device\n):\n    n_batch = n_sample // batch_size\n    resid = n_sample - (n_batch * batch_size)\n    # avoid appending an empty batch when n_sample divides evenly\n    batch_sizes = [batch_size] * n_batch + ([resid] if resid > 0 else [])\n    features = []\n\n    for batch in tqdm(batch_sizes):\n        latent = torch.randn(batch, 512, device=device)\n        img, _ = generator([latent], truncation=truncation, truncation_latent=truncation_latent)\n        feat = inception(img)[0].view(img.shape[0], -1)\n        features.append(feat.to(\"cpu\"))\n\n    features = torch.cat(features, 0)\n\n    return features\n\n\ndef calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):\n    cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False)\n\n    if not np.isfinite(cov_sqrt).all():\n        print(\"product of cov matrices is singular\")\n        offset = np.eye(sample_cov.shape[0]) * eps\n        cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset))\n\n    if np.iscomplexobj(cov_sqrt):\n        if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):\n            m = np.max(np.abs(cov_sqrt.imag))\n\n            raise ValueError(f\"Imaginary component {m}\")\n\n        cov_sqrt = cov_sqrt.real\n\n    mean_diff = sample_mean - real_mean\n    mean_norm = mean_diff @ mean_diff\n\n    trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt)\n\n    fid = mean_norm + trace\n\n    return fid\n\n\nif __name__ == \"__main__\":\n    device = \"cuda\"\n\n    parser = argparse.ArgumentParser(description=\"Calculate FID scores\")\n\n    parser.add_argument(\"--truncation\", type=float, default=1, help=\"truncation factor\")\n    parser.add_argument(\n        
\"--truncation_mean\",\n        type=int,\n        default=4096,\n        help=\"number of samples to calculate mean for truncation\",\n    )\n    parser.add_argument(\n        \"--batch\", type=int, default=64, help=\"batch size for the generator\"\n    )\n    parser.add_argument(\n        \"--n_sample\",\n        type=int,\n        default=50000,\n        help=\"number of the samples for calculating FID\",\n    )\n    parser.add_argument(\n        \"--size\", type=int, default=256, help=\"image sizes for generator\"\n    )\n    parser.add_argument(\n        \"--inception\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"path to precomputed inception embedding\",\n    )\n    parser.add_argument(\n        \"ckpt\", metavar=\"CHECKPOINT\", help=\"path to generator checkpoint\"\n    )\n\n    args = parser.parse_args()\n\n    ckpt = torch.load(args.ckpt)\n\n    g = Generator(args.size, 512, 8).to(device)\n    g.load_state_dict(ckpt[\"g_ema\"])\n    g = nn.DataParallel(g)\n    g.eval()\n\n    if args.truncation < 1:\n        with torch.no_grad():\n            mean_latent = g.mean_latent(args.truncation_mean)\n\n    else:\n        mean_latent = None\n\n    inception = nn.DataParallel(load_patched_inception_v3()).to(device)\n    inception.eval()\n\n    features = extract_feature_from_samples(\n        g, inception, args.truncation, mean_latent, args.batch, args.n_sample, device\n    ).numpy()\n    print(f\"extracted {features.shape[0]} features\")\n\n    sample_mean = np.mean(features, 0)\n    sample_cov = np.cov(features, rowvar=False)\n\n    with open(args.inception, \"rb\") as f:\n        embeds = pickle.load(f)\n        real_mean = embeds[\"mean\"]\n        real_cov = embeds[\"cov\"]\n\n    fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)\n\n    print(\"fid:\", fid)\n"
  },
  {
    "path": "models/archs/stylegan2/generate.py",
    "content": "import argparse\nimport os\nimport sys\n\nimport numpy as np\nimport torch\nfrom torchvision import utils\nfrom tqdm import tqdm\n\nsys.path.append('..')\nfrom stylegan2_pytorch.model import Generator\n\n\ndef generate(args, g_ema, device, mean_latent):\n\n    if not os.path.exists(args.synthetic_image_dir):\n        os.makedirs(args.synthetic_image_dir)\n\n    latent_code = {}\n    w_space_code = {}\n    with torch.no_grad():\n        g_ema.eval()\n        for i in tqdm(range(args.pics)):\n            sample_z = torch.randn(args.sample, args.latent, device=device)\n\n            sample, w_space = g_ema([sample_z],\n                                    truncation=args.truncation,\n                                    truncation_latent=mean_latent,\n                                    return_latents=True,\n                                    randomize_noise=False)\n\n            utils.save_image(\n                sample,\n                os.path.join(args.synthetic_image_dir,\n                             f\"{str(i).zfill(7)}.png\"),\n                nrow=1,\n                normalize=True,\n                range=(-1, 1),\n            )\n            latent_code[f\"{str(i).zfill(7)}.png\"] = sample_z.cpu().numpy()\n            w_space_code[f\"{str(i).zfill(7)}.png\"] = w_space.cpu().numpy()\n\n    # save latent code\n    np.save(f'{args.synthetic_image_dir}/latent_code.npz', latent_code)\n    np.save(f'{args.synthetic_image_dir}/w_space_code.npz', w_space_code)\n\n\nif __name__ == \"__main__\":\n    device = \"cuda\"\n\n    parser = argparse.ArgumentParser(\n        description=\"Generate samples from the generator\")\n\n    parser.add_argument(\n        \"--size\",\n        type=int,\n        default=1024,\n        help=\"output image size of the generator\")\n    parser.add_argument(\n        \"--sample\",\n        type=int,\n        default=1,\n        help=\"number of samples to be generated for each image\",\n    )\n    parser.add_argument(\n       
 \"--pics\",\n        type=int,\n        default=20,\n        help=\"number of images to be generated\")\n    parser.add_argument(\n        \"--truncation\", type=float, default=1, help=\"truncation ratio\")\n    parser.add_argument(\n        \"--truncation_mean\",\n        type=int,\n        default=4096,\n        help=\"number of vectors to calculate mean for the truncation\",\n    )\n    parser.add_argument(\n        \"--ckpt\",\n        type=str,\n        default=\"stylegan2-ffhq-config-f.pt\",\n        help=\"path to the model checkpoint\",\n    )\n    parser.add_argument(\n        \"--channel_multiplier\",\n        type=int,\n        default=2,\n        help=\"channel multiplier of the generator. config-f = 2, else = 1\",\n    )\n    parser.add_argument(\n        \"--synthetic_image_dir\",\n        default='',\n        help=\"directory to save the generated images and latent codes\",\n    )\n    args = parser.parse_args()\n\n    args.latent = 512\n    args.n_mlp = 8\n\n    g_ema = Generator(\n        args.size,\n        args.latent,\n        args.n_mlp,\n        channel_multiplier=args.channel_multiplier).to(device)\n    checkpoint = torch.load(args.ckpt)\n\n    g_ema.load_state_dict(checkpoint[\"g_ema\"])\n\n    if args.truncation < 1:\n        with torch.no_grad():\n            mean_latent = g_ema.mean_latent(args.truncation_mean)\n    else:\n        mean_latent = None\n\n    generate(args, g_ema, device, mean_latent)\n"
  },
  {
    "path": "models/archs/stylegan2/inception.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torchvision import models\n\ntry:\n    from torchvision.models.utils import load_state_dict_from_url\nexcept ImportError:\n    from torch.utils.model_zoo import load_url as load_state_dict_from_url\n\n# Inception weights ported to Pytorch from\n# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz\nFID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'\n\n\nclass InceptionV3(nn.Module):\n    \"\"\"Pretrained InceptionV3 network returning feature maps\"\"\"\n\n    # Index of default block of inception to return,\n    # corresponds to output of final average pooling\n    DEFAULT_BLOCK_INDEX = 3\n\n    # Maps feature dimensionality to their output blocks indices\n    BLOCK_INDEX_BY_DIM = {\n        64: 0,   # First max pooling features\n        192: 1,  # Second max pooling features\n        768: 2,  # Pre-aux classifier features\n        2048: 3  # Final average pooling features\n    }\n\n    def __init__(self,\n                 output_blocks=[DEFAULT_BLOCK_INDEX],\n                 resize_input=True,\n                 normalize_input=True,\n                 requires_grad=False,\n                 use_fid_inception=True):\n        \"\"\"Build pretrained InceptionV3\n\n        Parameters\n        ----------\n        output_blocks : list of int\n            Indices of blocks to return features of. Possible values are:\n                - 0: corresponds to output of first max pooling\n                - 1: corresponds to output of second max pooling\n                - 2: corresponds to output which is fed to aux classifier\n                - 3: corresponds to output of final average pooling\n        resize_input : bool\n            If true, bilinearly resizes input to width and height 299 before\n            feeding input to model. 
As the network without fully connected\n            layers is fully convolutional, it should be able to handle inputs\n            of arbitrary size, so resizing might not be strictly needed\n        normalize_input : bool\n            If true, scales the input from range (0, 1) to the range the\n            pretrained Inception network expects, namely (-1, 1)\n        requires_grad : bool\n            If true, parameters of the model require gradients. Possibly useful\n            for finetuning the network\n        use_fid_inception : bool\n            If true, uses the pretrained Inception model used in Tensorflow's\n            FID implementation. If false, uses the pretrained Inception model\n            available in torchvision. The FID Inception model has different\n            weights and a slightly different structure from torchvision's\n            Inception model. If you want to compute FID scores, you are\n            strongly advised to set this parameter to true to get comparable\n            results.\n        \"\"\"\n        super(InceptionV3, self).__init__()\n\n        self.resize_input = resize_input\n        self.normalize_input = normalize_input\n        self.output_blocks = sorted(output_blocks)\n        self.last_needed_block = max(output_blocks)\n\n        assert self.last_needed_block <= 3, \\\n            'Last possible output block index is 3'\n\n        self.blocks = nn.ModuleList()\n\n        if use_fid_inception:\n            inception = fid_inception_v3()\n        else:\n            inception = models.inception_v3(pretrained=True)\n\n        # Block 0: input to maxpool1\n        block0 = [\n            inception.Conv2d_1a_3x3,\n            inception.Conv2d_2a_3x3,\n            inception.Conv2d_2b_3x3,\n            nn.MaxPool2d(kernel_size=3, stride=2)\n        ]\n        self.blocks.append(nn.Sequential(*block0))\n\n        # Block 1: maxpool1 to maxpool2\n        if self.last_needed_block >= 1:\n            block1 = [\n                
inception.Conv2d_3b_1x1,\n                inception.Conv2d_4a_3x3,\n                nn.MaxPool2d(kernel_size=3, stride=2)\n            ]\n            self.blocks.append(nn.Sequential(*block1))\n\n        # Block 2: maxpool2 to aux classifier\n        if self.last_needed_block >= 2:\n            block2 = [\n                inception.Mixed_5b,\n                inception.Mixed_5c,\n                inception.Mixed_5d,\n                inception.Mixed_6a,\n                inception.Mixed_6b,\n                inception.Mixed_6c,\n                inception.Mixed_6d,\n                inception.Mixed_6e,\n            ]\n            self.blocks.append(nn.Sequential(*block2))\n\n        # Block 3: aux classifier to final avgpool\n        if self.last_needed_block >= 3:\n            block3 = [\n                inception.Mixed_7a,\n                inception.Mixed_7b,\n                inception.Mixed_7c,\n                nn.AdaptiveAvgPool2d(output_size=(1, 1))\n            ]\n            self.blocks.append(nn.Sequential(*block3))\n\n        for param in self.parameters():\n            param.requires_grad = requires_grad\n\n    def forward(self, inp):\n        \"\"\"Get Inception feature maps\n\n        Parameters\n        ----------\n        inp : torch.autograd.Variable\n            Input tensor of shape Bx3xHxW. 
Values are expected to be in\n            range (0, 1)\n\n        Returns\n        -------\n        List of torch.autograd.Variable, corresponding to the selected output\n        block, sorted ascending by index\n        \"\"\"\n        outp = []\n        x = inp\n\n        if self.resize_input:\n            x = F.interpolate(x,\n                              size=(299, 299),\n                              mode='bilinear',\n                              align_corners=False)\n\n        if self.normalize_input:\n            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)\n\n        for idx, block in enumerate(self.blocks):\n            x = block(x)\n            if idx in self.output_blocks:\n                outp.append(x)\n\n            if idx == self.last_needed_block:\n                break\n\n        return outp\n\n\ndef fid_inception_v3():\n    \"\"\"Build pretrained Inception model for FID computation\n\n    The Inception model for FID computation uses a different set of weights\n    and has a slightly different structure than torchvision's Inception.\n\n    This method first constructs torchvision's Inception and then patches the\n    necessary parts that are different in the FID Inception model.\n    \"\"\"\n    inception = models.inception_v3(num_classes=1008,\n                                    aux_logits=False,\n                                    pretrained=False)\n    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)\n    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)\n    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)\n    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)\n    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)\n    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)\n    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)\n    inception.Mixed_7b = FIDInceptionE_1(1280)\n    inception.Mixed_7c = FIDInceptionE_2(2048)\n\n    state_dict = 
load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)\n    inception.load_state_dict(state_dict)\n    return inception\n\n\nclass FIDInceptionA(models.inception.InceptionA):\n    \"\"\"InceptionA block patched for FID computation\"\"\"\n    def __init__(self, in_channels, pool_features):\n        super(FIDInceptionA, self).__init__(in_channels, pool_features)\n\n    def forward(self, x):\n        branch1x1 = self.branch1x1(x)\n\n        branch5x5 = self.branch5x5_1(x)\n        branch5x5 = self.branch5x5_2(branch5x5)\n\n        branch3x3dbl = self.branch3x3dbl_1(x)\n        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)\n        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)\n\n        # Patch: Tensorflow's average pool does not use the padded zeros in\n        # its average calculation\n        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,\n                                   count_include_pad=False)\n        branch_pool = self.branch_pool(branch_pool)\n\n        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]\n        return torch.cat(outputs, 1)\n\n\nclass FIDInceptionC(models.inception.InceptionC):\n    \"\"\"InceptionC block patched for FID computation\"\"\"\n    def __init__(self, in_channels, channels_7x7):\n        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)\n\n    def forward(self, x):\n        branch1x1 = self.branch1x1(x)\n\n        branch7x7 = self.branch7x7_1(x)\n        branch7x7 = self.branch7x7_2(branch7x7)\n        branch7x7 = self.branch7x7_3(branch7x7)\n\n        branch7x7dbl = self.branch7x7dbl_1(x)\n        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)\n        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)\n        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)\n        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)\n\n        # Patch: Tensorflow's average pool does not use the padded zeros in\n        # its average calculation\n        branch_pool = F.avg_pool2d(x, 
kernel_size=3, stride=1, padding=1,\n                                   count_include_pad=False)\n        branch_pool = self.branch_pool(branch_pool)\n\n        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]\n        return torch.cat(outputs, 1)\n\n\nclass FIDInceptionE_1(models.inception.InceptionE):\n    \"\"\"First InceptionE block patched for FID computation\"\"\"\n    def __init__(self, in_channels):\n        super(FIDInceptionE_1, self).__init__(in_channels)\n\n    def forward(self, x):\n        branch1x1 = self.branch1x1(x)\n\n        branch3x3 = self.branch3x3_1(x)\n        branch3x3 = [\n            self.branch3x3_2a(branch3x3),\n            self.branch3x3_2b(branch3x3),\n        ]\n        branch3x3 = torch.cat(branch3x3, 1)\n\n        branch3x3dbl = self.branch3x3dbl_1(x)\n        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)\n        branch3x3dbl = [\n            self.branch3x3dbl_3a(branch3x3dbl),\n            self.branch3x3dbl_3b(branch3x3dbl),\n        ]\n        branch3x3dbl = torch.cat(branch3x3dbl, 1)\n\n        # Patch: Tensorflow's average pool does not use the padded zeros in\n        # its average calculation\n        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,\n                                   count_include_pad=False)\n        branch_pool = self.branch_pool(branch_pool)\n\n        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]\n        return torch.cat(outputs, 1)\n\n\nclass FIDInceptionE_2(models.inception.InceptionE):\n    \"\"\"Second InceptionE block patched for FID computation\"\"\"\n    def __init__(self, in_channels):\n        super(FIDInceptionE_2, self).__init__(in_channels)\n\n    def forward(self, x):\n        branch1x1 = self.branch1x1(x)\n\n        branch3x3 = self.branch3x3_1(x)\n        branch3x3 = [\n            self.branch3x3_2a(branch3x3),\n            self.branch3x3_2b(branch3x3),\n        ]\n        branch3x3 = torch.cat(branch3x3, 1)\n\n        branch3x3dbl = 
self.branch3x3dbl_1(x)\n        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)\n        branch3x3dbl = [\n            self.branch3x3dbl_3a(branch3x3dbl),\n            self.branch3x3dbl_3b(branch3x3dbl),\n        ]\n        branch3x3dbl = torch.cat(branch3x3dbl, 1)\n\n        # Patch: The FID Inception model uses max pooling instead of average\n        # pooling. This is likely an error in this specific Inception\n        # implementation, as other Inception models use average pooling here\n        # (which matches the description in the paper).\n        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)\n        branch_pool = self.branch_pool(branch_pool)\n\n        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]\n        return torch.cat(outputs, 1)\n"
  },
  {
    "path": "models/archs/stylegan2/inversion.py",
    "content": "import argparse\nimport math\nimport os\n\nimport numpy as np\nimport torch\nfrom PIL import Image\nfrom torch import optim\nfrom torch.nn import functional as F\nfrom torchvision import transforms\nfrom tqdm import tqdm\n\nimport lpips\nfrom model import Generator\n\n\ndef noise_regularize(noises):\n    loss = 0\n\n    for noise in noises:\n        size = noise.shape[2]\n\n        while True:\n            loss = (\n                loss +\n                (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2) +\n                (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2))\n\n            if size <= 8:\n                break\n\n            noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2])\n            noise = noise.mean([3, 5])\n            size //= 2\n\n    return loss\n\n\ndef noise_normalize_(noises):\n    for noise in noises:\n        mean = noise.mean()\n        std = noise.std()\n\n        noise.data.add_(-mean).div_(std)\n\n\ndef get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):\n    lr_ramp = min(1, (1 - t) / rampdown)\n    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)\n    lr_ramp = lr_ramp * min(1, t / rampup)\n\n    return initial_lr * lr_ramp\n\n\ndef latent_noise(latent, strength):\n    noise = torch.randn_like(latent) * strength\n\n    return latent + noise\n\n\ndef make_image(tensor):\n    return (tensor.detach().clamp_(min=-1, max=1).add(1).div_(2).mul(255).type(\n        torch.uint8).permute(0, 2, 3, 1).to(\"cpu\").numpy())\n\n\nif __name__ == \"__main__\":\n    device = \"cuda\"\n\n    parser = argparse.ArgumentParser(\n        description=\"Image projector to the generator latent spaces\")\n    parser.add_argument(\n        \"--ckpt\", type=str, required=True, help=\"path to the model checkpoint\")\n    parser.add_argument(\n        \"--size\",\n        type=int,\n        default=256,\n        help=\"output image sizes of the generator\")\n    parser.add_argument(\n        \"--lr_rampup\",\n        
type=float,\n        default=0.05,\n        help=\"duration of the learning rate warmup\",\n    )\n    parser.add_argument(\n        \"--lr_rampdown\",\n        type=float,\n        default=0.25,\n        help=\"duration of the learning rate decay\",\n    )\n    parser.add_argument(\"--lr\", type=float, default=0.1, help=\"learning rate\")\n    parser.add_argument(\n        \"--noise\",\n        type=float,\n        default=0.05,\n        help=\"strength of the noise level\")\n    parser.add_argument(\n        \"--noise_ramp\",\n        type=float,\n        default=0.75,\n        help=\"duration of the noise level decay\",\n    )\n    parser.add_argument(\n        \"--step\", type=int, default=1000, help=\"number of optimization iterations\")\n    parser.add_argument(\n        \"--noise_regularize\",\n        type=float,\n        default=1e5,\n        help=\"weight of the noise regularization\",\n    )\n    parser.add_argument(\"--randomise_noise\", type=int, default=1)\n    parser.add_argument(\n        \"--img_mse_weight\",\n        type=float,\n        default=0,\n        help=\"weight of the mse loss\")\n    parser.add_argument(\n        \"files\",\n        metavar=\"FILES\",\n        nargs=\"+\",\n        help=\"path to image files to be projected\")\n    parser.add_argument(\"--output_dir\", type=str, required=True)\n\n    parser.add_argument(\n        \"--w_plus\",\n        action=\"store_true\",\n        help=\"use a distinct latent code for each layer\",\n    )\n    parser.add_argument(\n        \"--postfix\", default='', type=str, help='postfix for filenames')\n    parser.add_argument(\n        \"--latent_type\",\n        required=True,\n        type=str,\n        help='z or w, not case sensitive')\n    parser.add_argument(\n        \"--w_path\", default='', type=str, help='path to w latent code')\n    parser.add_argument('--w_mse_weight', default=0, type=float)\n    parser.add_argument('--w_loss_type', default='mse', type=str)\n\n    args = 
parser.parse_args()\n\n    # latent space type\n    args.latent_type = args.latent_type.lower()\n    if args.latent_type == 'z':\n        args.input_is_latent = False\n    elif args.latent_type == 'w':\n        args.input_is_latent = True\n    else:\n        assert False, \"Unrecognized args.latent_type\"\n\n    n_mean_latent = 10000\n\n    resize = min(args.size, 256)\n\n    transform = transforms.Compose([\n        transforms.Resize(resize),\n        transforms.CenterCrop(resize),\n        transforms.ToTensor(),\n        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),\n    ])\n\n    imgs = []\n\n    for imgfile in args.files:\n        img = transform(Image.open(imgfile).convert(\"RGB\"))\n        imgs.append(img)\n\n    imgs = torch.stack(imgs, 0).to(device)\n\n    if args.w_mse_weight:\n        assert args.latent_type == 'z'\n        w_latent_code = np.load(args.w_path)\n        w_latent_code = torch.tensor(w_latent_code).to(device)\n\n    # g_ema = Generator(args.size, 512, 8) # ziqi modified\n    g_ema = Generator(args.size, 512, 8, 1)\n\n    g_ema.load_state_dict(torch.load(args.ckpt)[\"g_ema\"], strict=False)\n    g_ema.eval()\n    g_ema = g_ema.to(device)\n\n    with torch.no_grad():\n        noise_sample = torch.randn(n_mean_latent, 512, device=device)\n        latent_out = g_ema.style(noise_sample)\n\n        latent_mean = latent_out.mean(0)\n        latent_std = ((latent_out - latent_mean).pow(2).sum() /\n                      n_mean_latent)**0.5\n\n    percept = lpips.PerceptualLoss(\n        model=\"net-lin\", net=\"vgg\", use_gpu=device.startswith(\"cuda\"))\n\n    if args.latent_type == 'w':\n        latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(\n            imgs.shape[0], 1)\n    elif args.latent_type == 'z':\n        latent_in = noise_sample.mean(0).detach().clone().unsqueeze(0).repeat(\n            imgs.shape[0], 1)\n\n    if args.w_plus:\n        latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)\n\n    
latent_in.requires_grad = True\n\n    if args.randomise_noise:\n        print('Noise term will be optimized together.')\n        noises_single = g_ema.make_noise()\n        noises = []\n        for noise in noises_single:\n            noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())\n        for noise in noises:\n            noise.requires_grad = True\n        # Adam expects an iterable of Tensors, not a generator object wrapped in a list\n        optimizer = optim.Adam(\n            [latent_in] + noises + list(g_ema.parameters()), lr=args.lr)\n    else:\n        optim_params = []\n        for v in g_ema.parameters():\n            if v.requires_grad:\n                optim_params.append(v)\n        optimizer = optim.Adam([{\n            'params': [latent_in]\n        }, {\n            'params': optim_params,\n            'lr': 1e-4\n        }],\n                               lr=args.lr)\n\n    pbar = tqdm(range(args.step))\n    latent_path = []\n\n    for i in pbar:\n        t = i / args.step\n        lr = get_lr(t, args.lr, args.lr_rampdown, args.lr_rampup)\n        optimizer.param_groups[0][\"lr\"] = lr\n        noise_strength = latent_std * args.noise * max(\n            0, 1 - t / args.noise_ramp)**2\n        if args.latent_type == 'z':\n            latent_w = g_ema.style(latent_in)\n            latent_n = latent_noise(latent_w, noise_strength.item())\n        else:\n            latent_n = latent_noise(latent_in, noise_strength.item())\n\n        if args.randomise_noise:\n            img_gen, _ = g_ema([latent_n], input_is_latent=True, noise=noises)\n        else:\n            img_gen, _ = g_ema([latent_n],\n                               input_is_latent=True,\n                               randomize_noise=False)\n\n        batch, channel, height, width = img_gen.shape\n\n        if height > 256:\n            factor = height // 256\n\n            img_gen = img_gen.reshape(batch, channel, height // factor, factor,\n                                      width // factor, factor)\n            img_gen = img_gen.mean([3, 5])\n\n        p_loss = percept(img_gen, imgs).sum()\n   
     mse_loss = F.mse_loss(img_gen, imgs)\n        if args.randomise_noise:\n            n_loss = noise_regularize(noises)\n        else:\n            n_loss = 0\n\n        loss = p_loss + args.noise_regularize * n_loss + args.img_mse_weight * mse_loss\n\n        if args.w_mse_weight > 0:\n            # this loss is only applicable to z space\n            assert args.latent_type == 'z'\n            if args.w_loss_type == 'mse':\n                w_mse_loss = F.mse_loss(latent_w, w_latent_code)\n            elif args.w_loss_type == 'l1':\n                w_mse_loss = F.l1_loss(latent_w, w_latent_code)\n            loss += args.w_mse_weight * w_mse_loss\n        else:\n            w_mse_loss = 0\n\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        if args.randomise_noise:\n            noise_normalize_(noises)\n\n        if (i + 1) % 100 == 0:\n            latent_path.append(latent_in.detach().clone())\n\n        pbar.set_description((\n            f\"total: {loss:.4f}; perceptual: {p_loss:.4f}; noise regularize: {n_loss:.4f};\"\n            f\" mse: {mse_loss:.4f}; w_mse_loss: {w_mse_loss:.4f}; lr: {lr:.4f}\"\n        ))\n\n    if args.randomise_noise:\n        img_gen, _ = g_ema([latent_path[-1]],\n                           input_is_latent=args.input_is_latent,\n                           noise=noises)\n    else:\n        img_gen, _ = g_ema([latent_path[-1]],\n                           input_is_latent=args.input_is_latent,\n                           randomize_noise=False)\n\n    filename = os.path.splitext(os.path.basename(args.files[0]))[0] + \".pt\"\n\n    img_ar = make_image(img_gen)\n\n    result_file = {}\n    for i, input_name in enumerate(args.files):\n        result_file[input_name] = {\"img\": img_gen[i], \"latent\": latent_in[i]}\n        if args.randomise_noise:\n            noise_single = []\n            for noise in noises:\n                noise_single.append(noise[i:i + 1])\n            
result_file[input_name][\"noise\"] = noise_single\n\n        img_name = os.path.splitext(\n            os.path.basename(input_name)\n        )[0] + '_' + args.postfix + '-' + args.latent_type + \"-project.png\"\n        pil_img = Image.fromarray(img_ar[i])\n\n        # save image\n        if not os.path.isdir(os.path.join(args.output_dir, 'recovered_image')):\n            os.makedirs(\n                os.path.join(args.output_dir, 'recovered_image'),\n                exist_ok=False)\n        pil_img.save(\n            os.path.join(args.output_dir, 'recovered_image', img_name))\n\n        latent_code = latent_in[i].cpu()\n        latent_code = latent_code.detach().numpy()\n        latent_code = np.expand_dims(latent_code, axis=0)\n        print('latent_code:', len(latent_code), len(latent_code[0]))\n        # save latent code\n        if not os.path.isdir(os.path.join(args.output_dir, 'latent_codes')):\n            os.makedirs(\n                os.path.join(args.output_dir, 'latent_codes'), exist_ok=False)\n        np.save(\n            f'{args.output_dir}/latent_codes/{img_name}_{args.latent_type}.npz.npy',\n            latent_code)\n\n        if not os.path.isdir(os.path.join(args.output_dir, 'checkpoint')):\n            os.makedirs(\n                os.path.join(args.output_dir, 'checkpoint'), exist_ok=False)\n        torch.save(\n            {\n                \"g_ema\": g_ema.state_dict(),\n            },\n            f\"{os.path.join(args.output_dir, 'checkpoint')}/{img_name}_{args.latent_type}.pt\",\n        )\n\n    # save info (joining args.output_dir once; the double join wrote to a\n    # non-existent output_dir/output_dir/pt directory)\n    if not os.path.isdir(os.path.join(args.output_dir, 'pt')):\n        os.makedirs(os.path.join(args.output_dir, 'pt'), exist_ok=False)\n    torch.save(\n        result_file,\n        os.path.join(args.output_dir, 'pt',\n                     filename + '_' + args.latent_type))\n"
  },
  {
    "path": "models/archs/stylegan2/lpips/__init__.py",
    "content": "from __future__ import absolute_import, division, print_function\n\nimport numpy as np\nimport torch\nfrom models.archs.stylegan2.lpips import dist_model\nfrom skimage.measure import compare_ssim\n\n\nclass PerceptualLoss(torch.nn.Module):\n\n    def __init__(\n        self,\n        model='net-lin',\n        net='alex',\n        colorspace='rgb',\n        spatial=False,\n        use_gpu=True,\n        gpu_ids=[\n            0\n        ]):  # VGG using our perceptually-learned weights (LPIPS metric)\n        # def __init__(self, model='net', net='vgg', use_gpu=True): # \"default\" way of using VGG as a perceptual loss\n        super(PerceptualLoss, self).__init__()\n        self.use_gpu = use_gpu\n        self.spatial = spatial\n        self.gpu_ids = gpu_ids\n        self.model = dist_model.DistModel()\n        self.model.initialize(\n            model=model,\n            net=net,\n            use_gpu=use_gpu,\n            colorspace=colorspace,\n            spatial=self.spatial,\n            gpu_ids=gpu_ids)\n\n    def forward(self, pred, target, normalize=False):\n        \"\"\"\n        Pred and target are Variables.\n        If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]\n        If normalize is False, assumes the images are already between [-1,+1]\n\n        Inputs pred and target are Nx3xHxW\n        Output pytorch Variable N long\n        \"\"\"\n\n        if normalize:\n            target = 2 * target - 1\n            pred = 2 * pred - 1\n\n        return self.model.forward(target, pred)\n\n\ndef normalize_tensor(in_feat, eps=1e-10):\n    norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True))\n    return in_feat / (norm_factor + eps)\n\n\ndef l2(p0, p1, range=255.):\n    return .5 * np.mean((p0 / range - p1 / range)**2)\n\n\ndef psnr(p0, p1, peak=255.):\n    return 10 * np.log10(peak**2 / np.mean((1. * p0 - 1. 
* p1)**2))\n\n\ndef dssim(p0, p1, range=255.):\n    return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.\n\n\ndef tensor2np(tensor_obj):\n    # change dimension of a tensor object into a numpy array\n    return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0))\n\n\ndef np2tensor(np_obj):\n    # change dimension of np array into tensor array\n    return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))\n\n\ndef tensor2tensorlab(image_tensor, to_norm=True, mc_only=False):\n    # image tensor to lab tensor\n    from skimage import color\n\n    img = tensor2im(image_tensor)\n    img_lab = color.rgb2lab(img)\n    if (mc_only):\n        img_lab[:, :, 0] = img_lab[:, :, 0] - 50\n    if (to_norm and not mc_only):\n        img_lab[:, :, 0] = img_lab[:, :, 0] - 50\n        img_lab = img_lab / 100.\n\n    return np2tensor(img_lab)\n\n\ndef tensorlab2tensor(lab_tensor, return_inbnd=False):\n    import warnings\n\n    from skimage import color\n    warnings.filterwarnings(\"ignore\")\n\n    lab = tensor2np(lab_tensor) * 100.\n    lab[:, :, 0] = lab[:, :, 0] + 50\n\n    rgb_back = 255. * np.clip(color.lab2rgb(lab.astype('float')), 0, 1)\n    if (return_inbnd):\n        # convert back to lab, see if we match\n        lab_back = color.rgb2lab(rgb_back.astype('uint8'))\n        mask = 1. * np.isclose(lab_back, lab, atol=2.)\n        mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis])\n        return (im2tensor(rgb_back), mask)\n    else:\n        return im2tensor(rgb_back)\n\n\ndef rgb2lab(input):\n    from skimage import color\n    return color.rgb2lab(input / 255.)\n\n\ndef tensor2vec(vector_tensor):\n    return vector_tensor.data.cpu().numpy()[:, :, 0, 0]\n\n\ndef voc_ap(rec, prec, use_07_metric=False):\n    \"\"\" ap = voc_ap(rec, prec, [use_07_metric])\n    Compute VOC AP given precision and recall.\n    If use_07_metric is true, uses the\n    VOC 07 11 point method (default:False).\n    \"\"\"\n    if use_07_metric:\n        # 11 point metric\n        ap = 0.\n        for t in np.arange(0., 1.1, 0.1):\n            if np.sum(rec >= t) == 0:\n                p = 0\n            else:\n                p = np.max(prec[rec >= t])\n            ap = ap + p / 11.\n    else:\n        # correct AP calculation\n        # first append sentinel values at the end\n        mrec = np.concatenate(([0.], rec, [1.]))\n        mpre = np.concatenate(([0.], prec, [0.]))\n\n        # compute the precision envelope\n        for i in range(mpre.size - 1, 0, -1):\n            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])\n\n        # to calculate area under PR curve, look for points\n        # where X axis (recall) changes value\n        i = np.where(mrec[1:] != mrec[:-1])[0]\n\n        # and sum (\\Delta recall) * prec\n        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])\n    return ap\n\n\ndef tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.):\n    image_numpy = image_tensor[0].cpu().float().numpy()\n    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor\n    return image_numpy.astype(imtype)\n\n\ndef im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.):\n    return torch.Tensor((image / factor - cent)[:, :, :, np.newaxis].transpose(\n        (3, 2, 0, 1)))\n"
  },
  {
    "path": "models/archs/stylegan2/lpips/base_model.py",
    "content": "import os\n\nimport numpy as np\nimport torch\n\n\nclass BaseModel():\n\n    def __init__(self):\n        pass\n\n    def name(self):\n        return 'BaseModel'\n\n    def initialize(self, use_gpu=True, gpu_ids=[0]):\n        self.use_gpu = use_gpu\n        self.gpu_ids = gpu_ids\n\n    def forward(self):\n        pass\n\n    def optimize_parameters(self):\n        pass\n\n    def get_current_visuals(self):\n        return self.input\n\n    def get_current_errors(self):\n        return {}\n\n    def save(self, label):\n        pass\n\n    # helper saving function that can be used by subclasses\n    def save_network(self, network, path, network_label, epoch_label):\n        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)\n        save_path = os.path.join(path, save_filename)\n        torch.save(network.state_dict(), save_path)\n\n    # helper loading function that can be used by subclasses\n    def load_network(self, network, network_label, epoch_label):\n        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)\n        save_path = os.path.join(self.save_dir, save_filename)\n        print('Loading network from %s' % save_path)\n        network.load_state_dict(torch.load(save_path))\n\n    def update_learning_rate(self):\n        pass\n\n    def get_image_paths(self):\n        return self.image_paths\n\n    def save_done(self, flag=False):\n        np.save(os.path.join(self.save_dir, 'done_flag'), flag)\n        np.savetxt(\n            os.path.join(self.save_dir, 'done_flag'), [\n                flag,\n            ], fmt='%i')\n"
  },
  {
    "path": "models/archs/stylegan2/lpips/dist_model.py",
    "content": "from __future__ import absolute_import\n\nimport os\nfrom collections import OrderedDict\n\nimport models.archs.stylegan2.lpips as util\nimport numpy as np\nimport torch\nfrom scipy.ndimage import zoom\nfrom torch.autograd import Variable\nfrom tqdm import tqdm\n\nfrom . import networks_basic as networks\nfrom .base_model import BaseModel\n\n\nclass DistModel(BaseModel):\n\n    def name(self):\n        return self.model_name\n\n    def initialize(self,\n                   model='net-lin',\n                   net='alex',\n                   colorspace='Lab',\n                   pnet_rand=False,\n                   pnet_tune=False,\n                   model_path=None,\n                   use_gpu=True,\n                   printNet=False,\n                   spatial=False,\n                   is_train=False,\n                   lr=.0001,\n                   beta1=0.5,\n                   version='0.1',\n                   gpu_ids=[0]):\n        '''\n        INPUTS\n            model - ['net-lin'] for linearly calibrated network\n                    ['net'] for off-the-shelf network\n                    ['L2'] for L2 distance in Lab colorspace\n                    ['SSIM'] for ssim in RGB colorspace\n            net - ['squeeze','alex','vgg']\n            model_path - if None, will look in weights/[NET_NAME].pth\n            colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM\n            use_gpu - bool - whether or not to use a GPU\n            printNet - bool - whether or not to print network architecture out\n            spatial - bool - whether to output an array containing varying distances across spatial dimensions\n            spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).\n            spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. 
if None then resized to size of input images.\n            spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).\n            is_train - bool - [True] for training mode\n            lr - float - initial learning rate\n            beta1 - float - initial momentum term for adam\n            version - 0.1 for latest, 0.0 was original (with a bug)\n            gpu_ids - int array - [0] by default, gpus to use\n        '''\n        BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)\n\n        self.model = model\n        self.net = net\n        self.is_train = is_train\n        self.spatial = spatial\n        self.gpu_ids = gpu_ids\n        self.model_name = '%s [%s]' % (model, net)\n\n        if (self.model == 'net-lin'):  # pretrained net + linear layer\n            self.net = networks.PNetLin(\n                pnet_rand=pnet_rand,\n                pnet_tune=pnet_tune,\n                pnet_type=net,\n                use_dropout=True,\n                spatial=spatial,\n                version=version,\n                lpips=True)\n            kw = {}\n            if not use_gpu:\n                kw['map_location'] = 'cpu'\n            if (model_path is None):\n                import inspect\n                model_path = os.path.abspath(\n                    os.path.join(\n                        inspect.getfile(self.initialize), '..',\n                        'weights/v%s/%s.pth' % (version, net)))\n\n            if (not is_train):\n                print('Loading model from: %s' % model_path)\n                self.net.load_state_dict(\n                    torch.load(model_path, **kw), strict=False)\n\n        elif (self.model == 'net'):  # pretrained network\n            self.net = networks.PNetLin(\n                pnet_rand=pnet_rand, pnet_type=net, lpips=False)\n        elif (self.model in ['L2', 'l2']):\n            self.net = networks.L2(\n                use_gpu=use_gpu, colorspace=colorspace\n            )  # 
not really a network, only for testing\n            self.model_name = 'L2'\n        elif (self.model in ['DSSIM', 'dssim', 'SSIM', 'ssim']):\n            self.net = networks.DSSIM(use_gpu=use_gpu, colorspace=colorspace)\n            self.model_name = 'SSIM'\n        else:\n            raise ValueError(\"Model [%s] not recognized.\" % self.model)\n\n        self.parameters = list(self.net.parameters())\n\n        if self.is_train:  # training mode\n            # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)\n            self.rankLoss = networks.BCERankingLoss()\n            self.parameters += list(self.rankLoss.net.parameters())\n            self.lr = lr\n            self.old_lr = lr\n            self.optimizer_net = torch.optim.Adam(\n                self.parameters, lr=lr, betas=(beta1, 0.999))\n        else:  # test mode\n            self.net.eval()\n\n        if (use_gpu):\n            self.net.to(gpu_ids[0])\n            self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)\n            if (self.is_train):\n                self.rankLoss = self.rankLoss.to(\n                    device=gpu_ids[0])  # just put this on GPU0\n\n        if (printNet):\n            print('---------- Networks initialized -------------')\n            networks.print_network(self.net)\n            print('-----------------------------------------------')\n\n    def forward(self, in0, in1, retPerLayer=False):\n        ''' Function computes the distance between image patches in0 and in1\n        INPUTS\n            in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]\n        OUTPUT\n            computed distances between in0 and in1\n        '''\n\n        return self.net.forward(in0, in1, retPerLayer=retPerLayer)\n\n    # ***** TRAINING FUNCTIONS *****\n    def optimize_parameters(self):\n        self.forward_train()\n        self.optimizer_net.zero_grad()\n        self.backward_train()\n        
self.optimizer_net.step()\n        self.clamp_weights()\n\n    def clamp_weights(self):\n        for module in self.net.modules():\n            if (hasattr(module, 'weight') and module.kernel_size == (1, 1)):\n                module.weight.data = torch.clamp(module.weight.data, min=0)\n\n    def set_input(self, data):\n        self.input_ref = data['ref']\n        self.input_p0 = data['p0']\n        self.input_p1 = data['p1']\n        self.input_judge = data['judge']\n\n        if (self.use_gpu):\n            self.input_ref = self.input_ref.to(device=self.gpu_ids[0])\n            self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])\n            self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])\n            self.input_judge = self.input_judge.to(device=self.gpu_ids[0])\n\n        self.var_ref = Variable(self.input_ref, requires_grad=True)\n        self.var_p0 = Variable(self.input_p0, requires_grad=True)\n        self.var_p1 = Variable(self.input_p1, requires_grad=True)\n\n    def forward_train(self):  # run forward pass\n        # print(self.net.module.scaling_layer.shift)\n        # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())\n\n        self.d0 = self.forward(self.var_ref, self.var_p0)\n        self.d1 = self.forward(self.var_ref, self.var_p1)\n        self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge)\n\n        self.var_judge = Variable(1. * self.input_judge).view(self.d0.size())\n\n        self.loss_total = self.rankLoss.forward(self.d0, self.d1,\n                                                self.var_judge * 2. 
- 1.)\n\n        return self.loss_total\n\n    def backward_train(self):\n        torch.mean(self.loss_total).backward()\n\n    def compute_accuracy(self, d0, d1, judge):\n        ''' d0, d1 are Variables, judge is a Tensor '''\n        d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten()\n        judge_per = judge.cpu().numpy().flatten()\n        return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per)\n\n    def get_current_errors(self):\n        retDict = OrderedDict([('loss_total',\n                                self.loss_total.data.cpu().numpy()),\n                               ('acc_r', self.acc_r)])\n\n        for key in retDict.keys():\n            retDict[key] = np.mean(retDict[key])\n\n        return retDict\n\n    def get_current_visuals(self):\n        zoom_factor = 256 / self.var_ref.data.size()[2]\n\n        ref_img = util.tensor2im(self.var_ref.data)\n        p0_img = util.tensor2im(self.var_p0.data)\n        p1_img = util.tensor2im(self.var_p1.data)\n\n        ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0)\n        p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0)\n        p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0)\n\n        return OrderedDict([('ref', ref_img_vis), ('p0', p0_img_vis),\n                            ('p1', p1_img_vis)])\n\n    def save(self, path, label):\n        if (self.use_gpu):\n            self.save_network(self.net.module, path, '', label)\n        else:\n            self.save_network(self.net, path, '', label)\n        self.save_network(self.rankLoss.net, path, 'rank', label)\n\n    def update_learning_rate(self, nepoch_decay):\n        lrd = self.lr / nepoch_decay\n        lr = self.old_lr - lrd\n\n        for param_group in self.optimizer_net.param_groups:\n            param_group['lr'] = lr\n\n        print('update lr [%s] decay: %f -> %f' %\n              (self.model_name, self.old_lr, lr))\n        self.old_lr = lr\n\n\ndef score_2afc_dataset(data_loader, func, name=''):\n    ''' 
Function computes Two Alternative Forced Choice (2AFC) score using\n        distance function 'func' in dataset 'data_loader'\n    INPUTS\n        data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside\n        func - callable distance function - calling d=func(in0,in1) should take 2\n            pytorch tensors with shape Nx3xXxY, and return numpy array of length N\n    OUTPUTS\n        [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators\n        [1] - dictionary with following elements\n            d0s,d1s - N arrays containing distances between reference patch to perturbed patches\n            gts - N array in [0,1], preferred patch selected by human evaluators\n                (closer to \"0\" for left patch p0, \"1\" for right patch p1,\n                \"0.6\" means 60pct people preferred right patch, 40pct preferred left)\n            scores - N array in [0,1], corresponding to what percentage function agreed with humans\n    CONSTS\n        N - number of test triplets in data_loader\n    '''\n\n    d0s = []\n    d1s = []\n    gts = []\n\n    for data in tqdm(data_loader.load_data(), desc=name):\n        d0s += func(data['ref'],\n                    data['p0']).data.cpu().numpy().flatten().tolist()\n        d1s += func(data['ref'],\n                    data['p1']).data.cpu().numpy().flatten().tolist()\n        gts += data['judge'].cpu().numpy().flatten().tolist()\n\n    d0s = np.array(d0s)\n    d1s = np.array(d1s)\n    gts = np.array(gts)\n    scores = (d0s < d1s) * (1. 
- gts) + (d1s < d0s) * gts + (d1s == d0s) * .5\n\n    return (np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores))\n\n\ndef score_jnd_dataset(data_loader, func, name=''):\n    ''' Function computes JND score using distance function 'func' in dataset 'data_loader'\n    INPUTS\n        data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside\n        func - callable distance function - calling d=func(in0,in1) should take 2\n            pytorch tensors with shape Nx3xXxY, and return pytorch array of length N\n    OUTPUTS\n        [0] - JND score in [0,1], mAP score (area under precision-recall curve)\n        [1] - dictionary with following elements\n            ds - N array containing distances between two patches shown to human evaluator\n            sames - N array containing fraction of people who thought the two patches were identical\n    CONSTS\n        N - number of test triplets in data_loader\n    '''\n\n    ds = []\n    gts = []\n\n    for data in tqdm(data_loader.load_data(), desc=name):\n        ds += func(data['p0'], data['p1']).data.cpu().numpy().tolist()\n        gts += data['same'].cpu().numpy().flatten().tolist()\n\n    sames = np.array(gts)\n    ds = np.array(ds)\n\n    sorted_inds = np.argsort(ds)\n    ds_sorted = ds[sorted_inds]\n    sames_sorted = sames[sorted_inds]\n\n    TPs = np.cumsum(sames_sorted)\n    FPs = np.cumsum(1 - sames_sorted)\n    FNs = np.sum(sames_sorted) - TPs\n\n    precs = TPs / (TPs + FPs)\n    recs = TPs / (TPs + FNs)\n    score = util.voc_ap(recs, precs)\n\n    return (score, dict(ds=ds, sames=sames))\n"
  },
  {
    "path": "models/archs/stylegan2/lpips/networks_basic.py",
    "content": "from __future__ import absolute_import\n\nimport models.archs.stylegan2.lpips as util\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Variable\n\nfrom . import pretrained_networks as pn\n\n\ndef spatial_average(in_tens, keepdim=True):\n    return in_tens.mean([2, 3], keepdim=keepdim)\n\n\ndef upsample(in_tens, out_H=64):  # assumes scale factor is same for H and W\n    in_H = in_tens.shape[2]\n    scale_factor = 1. * out_H / in_H\n\n    return nn.Upsample(\n        scale_factor=scale_factor, mode='bilinear', align_corners=False)(\n            in_tens)\n\n\n# Learned perceptual metric\nclass PNetLin(nn.Module):\n\n    def __init__(self,\n                 pnet_type='vgg',\n                 pnet_rand=False,\n                 pnet_tune=False,\n                 use_dropout=True,\n                 spatial=False,\n                 version='0.1',\n                 lpips=True):\n        super(PNetLin, self).__init__()\n\n        self.pnet_type = pnet_type\n        self.pnet_tune = pnet_tune\n        self.pnet_rand = pnet_rand\n        self.spatial = spatial\n        self.lpips = lpips\n        self.version = version\n        self.scaling_layer = ScalingLayer()\n\n        if (self.pnet_type in ['vgg', 'vgg16']):\n            net_type = pn.vgg16\n            self.chns = [64, 128, 256, 512, 512]\n        elif (self.pnet_type == 'alex'):\n            net_type = pn.alexnet\n            self.chns = [64, 192, 384, 256, 256]\n        elif (self.pnet_type == 'squeeze'):\n            net_type = pn.squeezenet\n            self.chns = [64, 128, 256, 384, 384, 512, 512]\n        self.L = len(self.chns)\n\n        self.net = net_type(\n            pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)\n\n        if (lpips):\n            self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)\n            self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)\n            self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)\n 
           self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)\n            self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)\n            self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]\n            if (self.pnet_type == 'squeeze'):  # 7 layers for squeezenet\n                self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)\n                self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)\n                self.lins += [self.lin5, self.lin6]\n\n    def forward(self, in0, in1, retPerLayer=False):\n        # v0.0 - original release had a bug, where input was not scaled\n        in0_input, in1_input = (\n            self.scaling_layer(in0),\n            self.scaling_layer(in1)) if self.version == '0.1' else (in0, in1)\n        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)\n        feats0, feats1, diffs = {}, {}, {}\n\n        for kk in range(self.L):\n            feats0[kk], feats1[kk] = util.normalize_tensor(\n                outs0[kk]), util.normalize_tensor(outs1[kk])\n            diffs[kk] = (feats0[kk] - feats1[kk])**2\n\n        if (self.lpips):\n            if (self.spatial):\n                res = [\n                    upsample(\n                        self.lins[kk].model(diffs[kk]), out_H=in0.shape[2])\n                    for kk in range(self.L)\n                ]\n            else:\n                res = [\n                    spatial_average(\n                        self.lins[kk].model(diffs[kk]), keepdim=True)\n                    for kk in range(self.L)\n                ]\n        else:\n            if (self.spatial):\n                res = [\n                    upsample(\n                        diffs[kk].sum(dim=1, keepdim=True), out_H=in0.shape[2])\n                    for kk in range(self.L)\n                ]\n            else:\n                res = [\n                    spatial_average(\n                        diffs[kk].sum(dim=1, keepdim=True), 
keepdim=True)\n                    for kk in range(self.L)\n                ]\n\n        val = res[0]\n        for l in range(1, self.L):\n            val += res[l]\n\n        if (retPerLayer):\n            return (val, res)\n        else:\n            return val\n\n\nclass ScalingLayer(nn.Module):\n\n    def __init__(self):\n        super(ScalingLayer, self).__init__()\n        self.register_buffer(\n            'shift',\n            torch.Tensor([-.030, -.088, -.188])[None, :, None, None])\n        self.register_buffer(\n            'scale',\n            torch.Tensor([.458, .448, .450])[None, :, None, None])\n\n    def forward(self, inp):\n        return (inp - self.shift) / self.scale\n\n\nclass NetLinLayer(nn.Module):\n    ''' A single linear layer which does a 1x1 conv '''\n\n    def __init__(self, chn_in, chn_out=1, use_dropout=False):\n        super(NetLinLayer, self).__init__()\n\n        layers = [\n            nn.Dropout(),\n        ] if (use_dropout) else []\n        layers += [\n            nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),\n        ]\n        self.model = nn.Sequential(*layers)\n\n\nclass Dist2LogitLayer(nn.Module):\n    ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''\n\n    def __init__(self, chn_mid=32, use_sigmoid=True):\n        super(Dist2LogitLayer, self).__init__()\n\n        layers = [\n            nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),\n        ]\n        layers += [\n            nn.LeakyReLU(0.2, True),\n        ]\n        layers += [\n            nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),\n        ]\n        layers += [\n            nn.LeakyReLU(0.2, True),\n        ]\n        layers += [\n            nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),\n        ]\n        if (use_sigmoid):\n            layers += [\n                nn.Sigmoid(),\n            ]\n        self.model = nn.Sequential(*layers)\n\n    
def forward(self, d0, d1, eps=0.1):\n        return self.model.forward(\n            torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)),\n                      dim=1))\n\n\nclass BCERankingLoss(nn.Module):\n\n    def __init__(self, chn_mid=32):\n        super(BCERankingLoss, self).__init__()\n        self.net = Dist2LogitLayer(chn_mid=chn_mid)\n        # self.parameters = list(self.net.parameters())\n        self.loss = torch.nn.BCELoss()\n\n    def forward(self, d0, d1, judge):\n        per = (judge + 1.) / 2.\n        self.logit = self.net.forward(d0, d1)\n        return self.loss(self.logit, per)\n\n\n# L2, DSSIM metrics\nclass FakeNet(nn.Module):\n\n    def __init__(self, use_gpu=True, colorspace='Lab'):\n        super(FakeNet, self).__init__()\n        self.use_gpu = use_gpu\n        self.colorspace = colorspace\n\n\nclass L2(FakeNet):\n\n    def forward(self, in0, in1, retPerLayer=None):\n        assert (in0.size()[0] == 1)  # currently only supports batchSize 1\n\n        if (self.colorspace == 'RGB'):\n            (N, C, X, Y) = in0.size()\n            value = torch.mean(\n                torch.mean(\n                    torch.mean((in0 - in1)**2, dim=1).view(N, 1, X, Y),\n                    dim=2).view(N, 1, 1, Y),\n                dim=3).view(N)\n            return value\n        elif (self.colorspace == 'Lab'):\n            value = util.l2(\n                util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)),\n                util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)),\n                range=100.).astype('float')\n            ret_var = Variable(torch.Tensor((value, )))\n            if (self.use_gpu):\n                ret_var = ret_var.cuda()\n            return ret_var\n\n\nclass DSSIM(FakeNet):\n\n    def forward(self, in0, in1, retPerLayer=None):\n        assert (in0.size()[0] == 1)  # currently only supports batchSize 1\n\n        if (self.colorspace == 'RGB'):\n            value = util.dssim(\n               
 1. * util.tensor2im(in0.data),\n                1. * util.tensor2im(in1.data),\n                range=255.).astype('float')\n        elif (self.colorspace == 'Lab'):\n            value = util.dssim(\n                util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)),\n                util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)),\n                range=100.).astype('float')\n        ret_var = Variable(torch.Tensor((value, )))\n        if (self.use_gpu):\n            ret_var = ret_var.cuda()\n        return ret_var\n\n\ndef print_network(net):\n    num_params = 0\n    for param in net.parameters():\n        num_params += param.numel()\n    print('Network', net)\n    print('Total number of parameters: %d' % num_params)\n"
  },
  {
    "path": "models/archs/stylegan2/lpips/pretrained_networks.py",
    "content": "from collections import namedtuple\n\nimport torch\nfrom torchvision import models as tv\n\n\nclass squeezenet(torch.nn.Module):\n\n    def __init__(self, requires_grad=False, pretrained=True):\n        super(squeezenet, self).__init__()\n        pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        self.slice6 = torch.nn.Sequential()\n        self.slice7 = torch.nn.Sequential()\n        self.N_slices = 7\n        for x in range(2):\n            self.slice1.add_module(str(x), pretrained_features[x])\n        for x in range(2, 5):\n            self.slice2.add_module(str(x), pretrained_features[x])\n        for x in range(5, 8):\n            self.slice3.add_module(str(x), pretrained_features[x])\n        for x in range(8, 10):\n            self.slice4.add_module(str(x), pretrained_features[x])\n        for x in range(10, 11):\n            self.slice5.add_module(str(x), pretrained_features[x])\n        for x in range(11, 12):\n            self.slice6.add_module(str(x), pretrained_features[x])\n        for x in range(12, 13):\n            self.slice7.add_module(str(x), pretrained_features[x])\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        h = self.slice1(X)\n        h_relu1 = h\n        h = self.slice2(h)\n        h_relu2 = h\n        h = self.slice3(h)\n        h_relu3 = h\n        h = self.slice4(h)\n        h_relu4 = h\n        h = self.slice5(h)\n        h_relu5 = h\n        h = self.slice6(h)\n        h_relu6 = h\n        h = self.slice7(h)\n        h_relu7 = h\n        vgg_outputs = namedtuple(\n            \"SqueezeOutputs\",\n            ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 
'relu6', 'relu7'])\n        out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6,\n                          h_relu7)\n\n        return out\n\n\nclass alexnet(torch.nn.Module):\n\n    def __init__(self, requires_grad=False, pretrained=True):\n        super(alexnet, self).__init__()\n        alexnet_pretrained_features = tv.alexnet(\n            pretrained=pretrained).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        self.N_slices = 5\n        for x in range(2):\n            self.slice1.add_module(str(x), alexnet_pretrained_features[x])\n        for x in range(2, 5):\n            self.slice2.add_module(str(x), alexnet_pretrained_features[x])\n        for x in range(5, 8):\n            self.slice3.add_module(str(x), alexnet_pretrained_features[x])\n        for x in range(8, 10):\n            self.slice4.add_module(str(x), alexnet_pretrained_features[x])\n        for x in range(10, 12):\n            self.slice5.add_module(str(x), alexnet_pretrained_features[x])\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        h = self.slice1(X)\n        h_relu1 = h\n        h = self.slice2(h)\n        h_relu2 = h\n        h = self.slice3(h)\n        h_relu3 = h\n        h = self.slice4(h)\n        h_relu4 = h\n        h = self.slice5(h)\n        h_relu5 = h\n        alexnet_outputs = namedtuple(\n            \"AlexnetOutputs\", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])\n        out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)\n\n        return out\n\n\nclass vgg16(torch.nn.Module):\n\n    def __init__(self, requires_grad=False, pretrained=True):\n        super(vgg16, self).__init__()\n        vgg_pretrained_features = 
tv.vgg16(pretrained=pretrained).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        self.N_slices = 5\n        for x in range(4):\n            self.slice1.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(4, 9):\n            self.slice2.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(9, 16):\n            self.slice3.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(16, 23):\n            self.slice4.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(23, 30):\n            self.slice5.add_module(str(x), vgg_pretrained_features[x])\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        h = self.slice1(X)\n        h_relu1_2 = h\n        h = self.slice2(h)\n        h_relu2_2 = h\n        h = self.slice3(h)\n        h_relu3_3 = h\n        h = self.slice4(h)\n        h_relu4_3 = h\n        h = self.slice5(h)\n        h_relu5_3 = h\n        vgg_outputs = namedtuple(\n            \"VggOutputs\",\n            ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])\n        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3,\n                          h_relu5_3)\n\n        return out\n\n\nclass resnet(torch.nn.Module):\n\n    def __init__(self, requires_grad=False, pretrained=True, num=18):\n        super(resnet, self).__init__()\n        if (num == 18):\n            self.net = tv.resnet18(pretrained=pretrained)\n        elif (num == 34):\n            self.net = tv.resnet34(pretrained=pretrained)\n        elif (num == 50):\n            self.net = tv.resnet50(pretrained=pretrained)\n        elif (num == 101):\n            self.net = tv.resnet101(pretrained=pretrained)\n        elif (num == 
152):\n            self.net = tv.resnet152(pretrained=pretrained)\n        self.N_slices = 5\n\n        self.conv1 = self.net.conv1\n        self.bn1 = self.net.bn1\n        self.relu = self.net.relu\n        self.maxpool = self.net.maxpool\n        self.layer1 = self.net.layer1\n        self.layer2 = self.net.layer2\n        self.layer3 = self.net.layer3\n        self.layer4 = self.net.layer4\n\n    def forward(self, X):\n        h = self.conv1(X)\n        h = self.bn1(h)\n        h = self.relu(h)\n        h_relu1 = h\n        h = self.maxpool(h)\n        h = self.layer1(h)\n        h_conv2 = h\n        h = self.layer2(h)\n        h_conv3 = h\n        h = self.layer3(h)\n        h_conv4 = h\n        h = self.layer4(h)\n        h_conv5 = h\n\n        outputs = namedtuple(\"Outputs\",\n                             ['relu1', 'conv2', 'conv3', 'conv4', 'conv5'])\n        out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)\n\n        return out\n"
  },
  {
    "path": "models/archs/stylegan2/model.py",
    "content": "import functools\nimport math\nimport operator\nimport random\nimport sys\n\nimport torch\nfrom models.archs.stylegan2.op import (FusedLeakyReLU, fused_leaky_relu,\n                                       upfirdn2d)\nfrom torch import nn\nfrom torch.autograd import Function\nfrom torch.nn import functional as F\n\n\nclass PixelNorm(nn.Module):\n\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, input):\n        return input * torch.rsqrt(\n            torch.mean(input**2, dim=1, keepdim=True) + 1e-8)\n\n\ndef make_kernel(k):\n    k = torch.tensor(k, dtype=torch.float32)\n\n    if k.ndim == 1:\n        k = k[None, :] * k[:, None]\n\n    k /= k.sum()\n\n    return k\n\n\nclass Upsample(nn.Module):\n\n    def __init__(self, kernel, factor=2):\n        super().__init__()\n\n        self.factor = factor\n        kernel = make_kernel(kernel) * (factor**2)\n        self.register_buffer(\"kernel\", kernel)\n\n        p = kernel.shape[0] - factor\n\n        pad0 = (p + 1) // 2 + factor - 1\n        pad1 = p // 2\n\n        self.pad = (pad0, pad1)\n\n    def forward(self, input):\n        out = upfirdn2d(\n            input, self.kernel, up=self.factor, down=1, pad=self.pad)\n\n        return out\n\n\nclass Downsample(nn.Module):\n\n    def __init__(self, kernel, factor=2):\n        super().__init__()\n\n        self.factor = factor\n        kernel = make_kernel(kernel)\n        self.register_buffer(\"kernel\", kernel)\n\n        p = kernel.shape[0] - factor\n\n        pad0 = (p + 1) // 2\n        pad1 = p // 2\n\n        self.pad = (pad0, pad1)\n\n    def forward(self, input):\n        out = upfirdn2d(\n            input, self.kernel, up=1, down=self.factor, pad=self.pad)\n\n        return out\n\n\nclass Blur(nn.Module):\n\n    def __init__(self, kernel, pad, upsample_factor=1):\n        super().__init__()\n\n        kernel = make_kernel(kernel)\n\n        if upsample_factor > 1:\n            kernel = kernel * 
(upsample_factor**2)\n\n        self.register_buffer(\"kernel\", kernel)\n\n        self.pad = pad\n\n    def forward(self, input):\n        out = upfirdn2d(input, self.kernel, pad=self.pad)\n\n        return out\n\n\nclass EqualConv2d(nn.Module):\n\n    def __init__(self,\n                 in_channel,\n                 out_channel,\n                 kernel_size,\n                 stride=1,\n                 padding=0,\n                 bias=True):\n        super().__init__()\n\n        self.weight = nn.Parameter(\n            torch.randn(out_channel, in_channel, kernel_size, kernel_size))\n        self.scale = 1 / math.sqrt(in_channel * kernel_size**2)\n\n        self.stride = stride\n        self.padding = padding\n\n        if bias:\n            self.bias = nn.Parameter(torch.zeros(out_channel))\n\n        else:\n            self.bias = None\n\n    def forward(self, input):\n        out = F.conv2d(\n            input,\n            self.weight * self.scale,\n            bias=self.bias,\n            stride=self.stride,\n            padding=self.padding,\n        )\n\n        return out\n\n    def __repr__(self):\n        return (\n            f\"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},\"\n            f\" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})\"\n        )\n\n\nclass EqualLinear(nn.Module):\n\n    def __init__(self,\n                 in_dim,\n                 out_dim,\n                 bias=True,\n                 bias_init=0,\n                 lr_mul=1,\n                 activation=None):\n        super().__init__()\n\n        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))\n\n        if bias:\n            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))\n\n        else:\n            self.bias = None\n\n        self.activation = activation\n\n        self.scale = (1 / math.sqrt(in_dim)) * lr_mul\n        self.lr_mul = lr_mul\n\n    def forward(self, input):\n       
 if self.activation:\n            out = F.linear(input, self.weight * self.scale)\n            out = fused_leaky_relu(out, self.bias * self.lr_mul)\n\n        else:\n            out = F.linear(\n                input, self.weight * self.scale, bias=self.bias * self.lr_mul)\n\n        return out\n\n    def __repr__(self):\n        return (\n            f\"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})\"\n        )\n\n\nclass ModulatedConv2d(nn.Module):\n\n    def __init__(\n        self,\n        in_channel,\n        out_channel,\n        kernel_size,\n        style_dim,\n        demodulate=True,\n        upsample=False,\n        downsample=False,\n        blur_kernel=[1, 3, 3, 1],\n    ):\n        super().__init__()\n\n        self.eps = 1e-8\n        self.kernel_size = kernel_size\n        self.in_channel = in_channel\n        self.out_channel = out_channel\n        self.upsample = upsample\n        self.downsample = downsample\n\n        if upsample:\n            factor = 2\n            p = (len(blur_kernel) - factor) - (kernel_size - 1)\n            pad0 = (p + 1) // 2 + factor - 1\n            pad1 = p // 2 + 1\n\n            self.blur = Blur(\n                blur_kernel, pad=(pad0, pad1), upsample_factor=factor)\n\n        if downsample:\n            factor = 2\n            p = (len(blur_kernel) - factor) + (kernel_size - 1)\n            pad0 = (p + 1) // 2\n            pad1 = p // 2\n\n            self.blur = Blur(blur_kernel, pad=(pad0, pad1))\n\n        fan_in = in_channel * kernel_size**2\n        self.scale = 1 / math.sqrt(fan_in)\n        self.padding = kernel_size // 2\n\n        self.weight = nn.Parameter(\n            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size))\n\n        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)\n\n        self.demodulate = demodulate\n\n    def __repr__(self):\n        return (\n            f\"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, 
{self.kernel_size}, \"\n            f\"upsample={self.upsample}, downsample={self.downsample})\")\n\n    def forward(self, input, style):\n        batch, in_channel, height, width = input.shape\n\n        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)\n        weight = self.scale * self.weight * style\n\n        if self.demodulate:\n            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)\n            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)\n\n        weight = weight.view(batch * self.out_channel, in_channel,\n                             self.kernel_size, self.kernel_size)\n\n        if self.upsample:\n            input = input.view(1, batch * in_channel, height, width)\n            weight = weight.view(batch, self.out_channel, in_channel,\n                                 self.kernel_size, self.kernel_size)\n            weight = weight.transpose(1, 2).reshape(batch * in_channel,\n                                                    self.out_channel,\n                                                    self.kernel_size,\n                                                    self.kernel_size)\n            out = F.conv_transpose2d(\n                input, weight, padding=0, stride=2, groups=batch)\n            _, _, height, width = out.shape\n            out = out.view(batch, self.out_channel, height, width)\n            out = self.blur(out)\n\n        elif self.downsample:\n            input = self.blur(input)\n            _, _, height, width = input.shape\n            input = input.view(1, batch * in_channel, height, width)\n            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)\n            _, _, height, width = out.shape\n            out = out.view(batch, self.out_channel, height, width)\n\n        else:\n            input = input.view(1, batch * in_channel, height, width)\n            out = F.conv2d(input, weight, padding=self.padding, groups=batch)\n            _, _, height, width = 
out.shape\n            out = out.view(batch, self.out_channel, height, width)\n\n        return out\n\n\nclass NoiseInjection(nn.Module):\n\n    def __init__(self):\n        super().__init__()\n\n        self.weight = nn.Parameter(torch.zeros(1))\n\n    def forward(self, image, noise=None):\n        if noise is None:\n            batch, _, height, width = image.shape\n            noise = image.new_empty(batch, 1, height, width).normal_()\n\n        return image + self.weight * noise\n\n\nclass ConstantInput(nn.Module):\n\n    def __init__(self, channel, size=4):\n        super().__init__()\n\n        self.input = nn.Parameter(torch.randn(1, channel, size, size))\n\n    def forward(self, input):\n        batch = input.shape[0]\n        out = self.input.repeat(batch, 1, 1, 1)\n\n        return out\n\n\nclass StyledConv(nn.Module):\n\n    def __init__(\n        self,\n        in_channel,\n        out_channel,\n        kernel_size,\n        style_dim,\n        upsample=False,\n        blur_kernel=[1, 3, 3, 1],\n        demodulate=True,\n    ):\n        super().__init__()\n\n        self.conv = ModulatedConv2d(\n            in_channel,\n            out_channel,\n            kernel_size,\n            style_dim,\n            upsample=upsample,\n            blur_kernel=blur_kernel,\n            demodulate=demodulate,\n        )\n\n        self.noise = NoiseInjection()\n        # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))\n        # self.activate = ScaledLeakyReLU(0.2)\n        self.activate = FusedLeakyReLU(out_channel)\n\n    def forward(self, input, style, noise=None):\n        out = self.conv(input, style)\n        out = self.noise(out, noise=noise)\n        # out = out + self.bias\n        out = self.activate(out)\n\n        return out\n\n\nclass ToRGB(nn.Module):\n\n    def __init__(self,\n                 in_channel,\n                 style_dim,\n                 upsample=True,\n                 blur_kernel=[1, 3, 3, 1]):\n        
super().__init__()\n\n        if upsample:\n            self.upsample = Upsample(blur_kernel)\n\n        self.conv = ModulatedConv2d(\n            in_channel, 3, 1, style_dim, demodulate=False)\n        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))\n\n    def forward(self, input, style, skip=None):\n        out = self.conv(input, style)\n        out = out + self.bias\n\n        if skip is not None:\n            skip = self.upsample(skip)\n\n            out = out + skip\n\n        return out\n\n\nclass Generator(nn.Module):\n\n    def __init__(\n        self,\n        size,\n        style_dim,\n        n_mlp,\n        channel_multiplier=2,\n        blur_kernel=[1, 3, 3, 1],\n        lr_mlp=0.01,\n    ):\n        super().__init__()\n\n        self.size = size\n\n        self.style_dim = style_dim\n\n        layers = [PixelNorm()]\n\n        for i in range(n_mlp):\n            layers.append(\n                EqualLinear(\n                    style_dim,\n                    style_dim,\n                    lr_mul=lr_mlp,\n                    activation=\"fused_lrelu\"))\n\n        # self.style = nn.Sequential(*layers)\n        self.style = nn.ModuleList(layers)\n\n        self.channels = {\n            4: 512,\n            8: 512,\n            16: 512,\n            32: 512,\n            64: 256 * channel_multiplier,\n            128: 128 * channel_multiplier,\n            256: 64 * channel_multiplier,\n            512: 32 * channel_multiplier,\n            1024: 16 * channel_multiplier,\n        }\n\n        self.input = ConstantInput(self.channels[4])\n        self.conv1 = StyledConv(\n            self.channels[4],\n            self.channels[4],\n            3,\n            style_dim,\n            blur_kernel=blur_kernel)\n        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)\n\n        self.log_size = int(math.log(size, 2))\n        self.num_layers = (self.log_size - 2) * 2 + 1\n\n        self.convs = nn.ModuleList()\n        self.upsamples = 
nn.ModuleList()\n        self.to_rgbs = nn.ModuleList()\n        self.noises = nn.Module()\n\n        in_channel = self.channels[4]\n\n        for layer_idx in range(self.num_layers):\n            res = (layer_idx + 5) // 2\n            shape = [1, 1, 2**res, 2**res]\n            self.noises.register_buffer(f\"noise_{layer_idx}\",\n                                        torch.randn(*shape))\n\n        for i in range(3, self.log_size + 1):\n            out_channel = self.channels[2**i]\n\n            self.convs.append(\n                StyledConv(\n                    in_channel,\n                    out_channel,\n                    3,\n                    style_dim,\n                    upsample=True,\n                    blur_kernel=blur_kernel,\n                ))\n\n            self.convs.append(\n                StyledConv(\n                    out_channel,\n                    out_channel,\n                    3,\n                    style_dim,\n                    blur_kernel=blur_kernel))\n\n            self.to_rgbs.append(ToRGB(out_channel, style_dim))\n\n            in_channel = out_channel\n\n        self.n_latent = self.log_size * 2 - 2\n\n    def make_noise(self):\n        device = self.input.input.device\n\n        noises = [torch.randn(1, 1, 2**2, 2**2, device=device)]\n\n        for i in range(3, self.log_size + 1):\n            for _ in range(2):\n                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))\n\n        return noises\n\n    def mean_latent(self, n_latent):\n        latent_in = torch.randn(\n            n_latent, self.style_dim, device=self.input.input.device)\n        latent = self.style_forward(latent_in).mean(0, keepdim=True)\n\n        return latent\n\n    def get_latent(self, input):\n        out = input\n        for i, layer in enumerate(self.style):\n            out = layer(out)\n        return out\n\n    def style_forward(self, input, skip_norm=False):\n        out = input\n        for i, layer in 
enumerate(self.style):\n            if i == 0 and skip_norm:\n                continue\n            out = layer(out)\n        return out\n\n    def forward(\n        self,\n        styles,\n        return_latents=False,\n        inject_index=None,\n        truncation=1,\n        truncation_latent=None,\n        input_is_latent=False,\n        noise=None,\n        randomize_noise=True,\n    ):\n        if not input_is_latent:\n            styles = [self.style_forward(s) for s in styles]\n\n        if noise is None:\n            if randomize_noise:\n                noise = [None] * self.num_layers\n            else:\n                noise = [\n                    getattr(self.noises, f\"noise_{i}\")\n                    for i in range(self.num_layers)\n                ]\n\n        if truncation < 1:\n            style_t = []\n\n            for style in styles:\n                style_t.append(truncation_latent + truncation *\n                               (style - truncation_latent))\n\n            styles = style_t\n\n        if len(styles) < 2:\n            inject_index = self.n_latent\n\n            if styles[0].ndim < 3:\n                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)\n\n            else:\n                latent = styles[0]\n\n        else:\n            if inject_index is None:\n                inject_index = random.randint(1, self.n_latent - 1)\n\n            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)\n            latent2 = styles[1].unsqueeze(1).repeat(\n                1, self.n_latent - inject_index, 1)\n\n            latent = torch.cat([latent, latent2], 1)\n\n        out = self.input(latent)\n        out = self.conv1(out, latent[:, 0], noise=noise[0])\n\n        skip = self.to_rgb1(out, latent[:, 1])\n\n        i = 1\n        for conv1, conv2, noise1, noise2, to_rgb in zip(\n                self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2],\n                self.to_rgbs):\n            out = conv1(out, 
latent[:, i], noise=noise1)\n            out = conv2(out, latent[:, i + 1], noise=noise2)\n            skip = to_rgb(out, latent[:, i + 2], skip)\n\n            i += 2\n\n        image = skip\n\n        if return_latents:\n            return image, latent\n\n        else:\n            return image, None\n\n\nclass ConvLayer(nn.Sequential):\n\n    def __init__(\n        self,\n        in_channel,\n        out_channel,\n        kernel_size,\n        downsample=False,\n        blur_kernel=[1, 3, 3, 1],\n        bias=True,\n        activate=True,\n    ):\n        layers = []\n\n        if downsample:\n            factor = 2\n            p = (len(blur_kernel) - factor) + (kernel_size - 1)\n            pad0 = (p + 1) // 2\n            pad1 = p // 2\n\n            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))\n\n            stride = 2\n            self.padding = 0\n\n        else:\n            stride = 1\n            self.padding = kernel_size // 2\n\n        layers.append(\n            EqualConv2d(\n                in_channel,\n                out_channel,\n                kernel_size,\n                padding=self.padding,\n                stride=stride,\n                bias=bias and not activate,\n            ))\n\n        if activate:\n            layers.append(FusedLeakyReLU(out_channel, bias=bias))\n\n        super().__init__(*layers)\n\n\nclass ResBlock(nn.Module):\n\n    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):\n        super().__init__()\n\n        self.conv1 = ConvLayer(in_channel, in_channel, 3)\n        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)\n\n        self.skip = ConvLayer(\n            in_channel,\n            out_channel,\n            1,\n            downsample=True,\n            activate=False,\n            bias=False)\n\n    def forward(self, input):\n        out = self.conv1(input)\n        out = self.conv2(out)\n\n        skip = self.skip(input)\n        out = (out + skip) / 
math.sqrt(2)\n\n        return out\n\n\nclass Discriminator(nn.Module):\n\n    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):\n        super().__init__()\n\n        channels = {\n            4: 512,\n            8: 512,\n            16: 512,\n            32: 512,\n            64: 256 * channel_multiplier,\n            128: 128 * channel_multiplier,\n            256: 64 * channel_multiplier,\n            512: 32 * channel_multiplier,\n            1024: 16 * channel_multiplier,\n        }\n\n        convs = [ConvLayer(3, channels[size], 1)]\n\n        log_size = int(math.log(size, 2))\n\n        in_channel = channels[size]\n\n        for i in range(log_size, 2, -1):\n            out_channel = channels[2**(i - 1)]\n\n            convs.append(ResBlock(in_channel, out_channel, blur_kernel))\n\n            in_channel = out_channel\n\n        self.convs = nn.Sequential(*convs)\n\n        self.stddev_group = 4\n        self.stddev_feat = 1\n\n        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)\n        self.final_linear = nn.Sequential(\n            EqualLinear(\n                channels[4] * 4 * 4, channels[4], activation=\"fused_lrelu\"),\n            EqualLinear(channels[4], 1),\n        )\n\n    def forward(self, input):\n        out = self.convs(input)\n\n        batch, channel, height, width = out.shape\n        group = min(batch, self.stddev_group)\n        stddev = out.view(group, -1, self.stddev_feat,\n                          channel // self.stddev_feat, height, width)\n        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)\n        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)\n        stddev = stddev.repeat(group, 1, height, width)\n        out = torch.cat([out, stddev], 1)\n\n        out = self.final_conv(out)\n\n        out = out.view(batch, -1)\n        out = self.final_linear(out)\n\n        return out\n"
  },
  {
    "path": "models/archs/stylegan2/non_leaking.py",
    "content": "import math\r\n\r\nimport torch\r\nfrom torch.nn import functional as F\r\n\r\nfrom distributed import reduce_sum\r\nfrom models.archs.stylegan2.op import upfirdn2d\r\n\r\n\r\nclass AdaptiveAugment:\r\n    def __init__(self, ada_aug_target, ada_aug_len, update_every, device):\r\n        self.ada_aug_target = ada_aug_target\r\n        self.ada_aug_len = ada_aug_len\r\n        self.update_every = update_every\r\n\r\n        self.ada_aug_buf = torch.tensor([0.0, 0.0], device=device)\r\n        self.r_t_stat = 0\r\n        self.ada_aug_p = 0\r\n\r\n    @torch.no_grad()\r\n    def tune(self, real_pred):\r\n        ada_aug_data = torch.tensor(\r\n            (torch.sign(real_pred).sum().item(), real_pred.shape[0]),\r\n            device=real_pred.device,\r\n        )\r\n        self.ada_aug_buf += reduce_sum(ada_aug_data)\r\n\r\n        if self.ada_aug_buf[1] > self.update_every - 1:\r\n            pred_signs, n_pred = self.ada_aug_buf.tolist()\r\n\r\n            self.r_t_stat = pred_signs / n_pred\r\n\r\n            if self.r_t_stat > self.ada_aug_target:\r\n                sign = 1\r\n\r\n            else:\r\n                sign = -1\r\n\r\n            self.ada_aug_p += sign * n_pred / self.ada_aug_len\r\n            self.ada_aug_p = min(1, max(0, self.ada_aug_p))\r\n            self.ada_aug_buf.mul_(0)\r\n\r\n        return self.ada_aug_p\r\n\r\n\r\nSYM6 = (\r\n    0.015404109327027373,\r\n    0.0034907120842174702,\r\n    -0.11799011114819057,\r\n    -0.048311742585633,\r\n    0.4910559419267466,\r\n    0.787641141030194,\r\n    0.3379294217276218,\r\n    -0.07263752278646252,\r\n    -0.021060292512300564,\r\n    0.04472490177066578,\r\n    0.0017677118642428036,\r\n    -0.007800708325034148,\r\n)\r\n\r\n\r\ndef translate_mat(t_x, t_y):\r\n    batch = t_x.shape[0]\r\n\r\n    mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1)\r\n    translate = torch.stack((t_x, t_y), 1)\r\n    mat[:, :2, 2] = translate\r\n\r\n    return mat\r\n\r\n\r\ndef rotate_mat(theta):\r\n    batch = 
theta.shape[0]\r\n\r\n    mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1)\r\n    sin_t = torch.sin(theta)\r\n    cos_t = torch.cos(theta)\r\n    rot = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2)\r\n    mat[:, :2, :2] = rot\r\n\r\n    return mat\r\n\r\n\r\ndef scale_mat(s_x, s_y):\r\n    batch = s_x.shape[0]\r\n\r\n    mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1)\r\n    mat[:, 0, 0] = s_x\r\n    mat[:, 1, 1] = s_y\r\n\r\n    return mat\r\n\r\n\r\ndef translate3d_mat(t_x, t_y, t_z):\r\n    batch = t_x.shape[0]\r\n\r\n    mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)\r\n    translate = torch.stack((t_x, t_y, t_z), 1)\r\n    mat[:, :3, 3] = translate\r\n\r\n    return mat\r\n\r\n\r\ndef rotate3d_mat(axis, theta):\r\n    batch = theta.shape[0]\r\n\r\n    u_x, u_y, u_z = axis\r\n\r\n    eye = torch.eye(3).unsqueeze(0)\r\n    cross = torch.tensor([(0, -u_z, u_y), (u_z, 0, -u_x), (-u_y, u_x, 0)]).unsqueeze(0)\r\n    outer = torch.tensor(axis)\r\n    outer = (outer.unsqueeze(1) * outer).unsqueeze(0)\r\n\r\n    sin_t = torch.sin(theta).view(-1, 1, 1)\r\n    cos_t = torch.cos(theta).view(-1, 1, 1)\r\n\r\n    rot = cos_t * eye + sin_t * cross + (1 - cos_t) * outer\r\n\r\n    eye_4 = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)\r\n    eye_4[:, :3, :3] = rot\r\n\r\n    return eye_4\r\n\r\n\r\ndef scale3d_mat(s_x, s_y, s_z):\r\n    batch = s_x.shape[0]\r\n\r\n    mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)\r\n    mat[:, 0, 0] = s_x\r\n    mat[:, 1, 1] = s_y\r\n    mat[:, 2, 2] = s_z\r\n\r\n    return mat\r\n\r\n\r\ndef luma_flip_mat(axis, i):\r\n    batch = i.shape[0]\r\n\r\n    eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)\r\n    axis = torch.tensor(axis + (0,))\r\n    flip = 2 * torch.ger(axis, axis) * i.view(-1, 1, 1)\r\n\r\n    return eye - flip\r\n\r\n\r\ndef saturation_mat(axis, i):\r\n    batch = i.shape[0]\r\n\r\n    eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)\r\n    axis = torch.tensor(axis + (0,))\r\n    axis = 
torch.ger(axis, axis)\r\n    saturate = axis + (eye - axis) * i.view(-1, 1, 1)\r\n\r\n    return saturate\r\n\r\n\r\ndef lognormal_sample(size, mean=0, std=1):\r\n    return torch.empty(size).log_normal_(mean=mean, std=std)\r\n\r\n\r\ndef category_sample(size, categories):\r\n    category = torch.tensor(categories)\r\n    sample = torch.randint(high=len(categories), size=(size,))\r\n\r\n    return category[sample]\r\n\r\n\r\ndef uniform_sample(size, low, high):\r\n    return torch.empty(size).uniform_(low, high)\r\n\r\n\r\ndef normal_sample(size, mean=0, std=1):\r\n    return torch.empty(size).normal_(mean, std)\r\n\r\n\r\ndef bernoulli_sample(size, p):\r\n    return torch.empty(size).bernoulli_(p)\r\n\r\n\r\ndef random_mat_apply(p, transform, prev, eye):\r\n    size = transform.shape[0]\r\n    select = bernoulli_sample(size, p).view(size, 1, 1)\r\n    select_transform = select * transform + (1 - select) * eye\r\n\r\n    return select_transform @ prev\r\n\r\n\r\ndef sample_affine(p, size, height, width):\r\n    G = torch.eye(3).unsqueeze(0).repeat(size, 1, 1)\r\n    eye = G\r\n\r\n    # flip\r\n    param = category_sample(size, (0, 1))\r\n    Gc = scale_mat(1 - 2.0 * param, torch.ones(size))\r\n    G = random_mat_apply(p, Gc, G, eye)\r\n    # print('flip', G, scale_mat(1 - 2.0 * param, torch.ones(size)), sep='\\n')\r\n\r\n    # 90 rotate\r\n    param = category_sample(size, (0, 3))\r\n    Gc = rotate_mat(-math.pi / 2 * param)\r\n    G = random_mat_apply(p, Gc, G, eye)\r\n    # print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\\n')\r\n\r\n    # integer translate\r\n    param = uniform_sample(size, -0.125, 0.125)\r\n    param_height = torch.round(param * height) / height\r\n    param_width = torch.round(param * width) / width\r\n    Gc = translate_mat(param_width, param_height)\r\n    G = random_mat_apply(p, Gc, G, eye)\r\n    # print('integer translate', G, translate_mat(param_width, param_height), sep='\\n')\r\n\r\n    # isotropic scale\r\n    param = 
lognormal_sample(size, std=0.2 * math.log(2))\r\n    Gc = scale_mat(param, param)\r\n    G = random_mat_apply(p, Gc, G, eye)\r\n    # print('isotropic scale', G, scale_mat(param, param), sep='\\n')\r\n\r\n    p_rot = 1 - math.sqrt(1 - p)\r\n\r\n    # pre-rotate\r\n    param = uniform_sample(size, -math.pi, math.pi)\r\n    Gc = rotate_mat(-param)\r\n    G = random_mat_apply(p_rot, Gc, G, eye)\r\n    # print('pre-rotate', G, rotate_mat(-param), sep='\\n')\r\n\r\n    # anisotropic scale\r\n    param = lognormal_sample(size, std=0.2 * math.log(2))\r\n    Gc = scale_mat(param, 1 / param)\r\n    G = random_mat_apply(p, Gc, G, eye)\r\n    # print('anisotropic scale', G, scale_mat(param, 1 / param), sep='\\n')\r\n\r\n    # post-rotate\r\n    param = uniform_sample(size, -math.pi, math.pi)\r\n    Gc = rotate_mat(-param)\r\n    G = random_mat_apply(p_rot, Gc, G, eye)\r\n    # print('post-rotate', G, rotate_mat(-param), sep='\\n')\r\n\r\n    # fractional translate\r\n    param = normal_sample(size, std=0.125)\r\n    Gc = translate_mat(param, param)\r\n    G = random_mat_apply(p, Gc, G, eye)\r\n    # print('fractional translate', G, translate_mat(param, param), sep='\\n')\r\n\r\n    return G\r\n\r\n\r\ndef sample_color(p, size):\r\n    C = torch.eye(4).unsqueeze(0).repeat(size, 1, 1)\r\n    eye = C\r\n    axis_val = 1 / math.sqrt(3)\r\n    axis = (axis_val, axis_val, axis_val)\r\n\r\n    # brightness\r\n    param = normal_sample(size, std=0.2)\r\n    Cc = translate3d_mat(param, param, param)\r\n    C = random_mat_apply(p, Cc, C, eye)\r\n\r\n    # contrast\r\n    param = lognormal_sample(size, std=0.5 * math.log(2))\r\n    Cc = scale3d_mat(param, param, param)\r\n    C = random_mat_apply(p, Cc, C, eye)\r\n\r\n    # luma flip\r\n    param = category_sample(size, (0, 1))\r\n    Cc = luma_flip_mat(axis, param)\r\n    C = random_mat_apply(p, Cc, C, eye)\r\n\r\n    # hue rotation\r\n    param = uniform_sample(size, -math.pi, math.pi)\r\n    Cc = rotate3d_mat(axis, param)\r\n    C = 
random_mat_apply(p, Cc, C, eye)\r\n\r\n    # saturation\r\n    param = lognormal_sample(size, std=1 * math.log(2))\r\n    Cc = saturation_mat(axis, param)\r\n    C = random_mat_apply(p, Cc, C, eye)\r\n\r\n    return C\r\n\r\n\r\ndef make_grid(shape, x0, x1, y0, y1, device):\r\n    n, c, h, w = shape\r\n    grid = torch.empty(n, h, w, 3, device=device)\r\n    grid[:, :, :, 0] = torch.linspace(x0, x1, w, device=device)\r\n    grid[:, :, :, 1] = torch.linspace(y0, y1, h, device=device).unsqueeze(-1)\r\n    grid[:, :, :, 2] = 1\r\n\r\n    return grid\r\n\r\n\r\ndef affine_grid(grid, mat):\r\n    n, h, w, _ = grid.shape\r\n    return (grid.view(n, h * w, 3) @ mat.transpose(1, 2)).view(n, h, w, 2)\r\n\r\n\r\ndef get_padding(G, height, width):\r\n    extreme = (\r\n        G[:, :2, :]\r\n        @ torch.tensor([(-1.0, -1, 1), (-1, 1, 1), (1, -1, 1), (1, 1, 1)]).t()\r\n    )\r\n\r\n    size = torch.tensor((width, height))\r\n\r\n    pad_low = (\r\n        ((extreme.min(-1).values + 1) * size)\r\n        .clamp(max=0)\r\n        .abs()\r\n        .ceil()\r\n        .max(0)\r\n        .values.to(torch.int64)\r\n        .tolist()\r\n    )\r\n    pad_high = (\r\n        (extreme.max(-1).values * size - size)\r\n        .clamp(min=0)\r\n        .ceil()\r\n        .max(0)\r\n        .values.to(torch.int64)\r\n        .tolist()\r\n    )\r\n\r\n    return pad_low[0], pad_high[0], pad_low[1], pad_high[1]\r\n\r\n\r\ndef try_sample_affine_and_pad(img, p, pad_k, G=None):\r\n    batch, _, height, width = img.shape\r\n\r\n    G_try = G\r\n\r\n    while True:\r\n        if G is None:\r\n            G_try = sample_affine(p, batch, height, width)\r\n\r\n        pad_x1, pad_x2, pad_y1, pad_y2 = get_padding(\r\n            torch.inverse(G_try), height, width\r\n        )\r\n\r\n        try:\r\n            img_pad = F.pad(\r\n                img,\r\n                (pad_x1 + pad_k, pad_x2 + pad_k, pad_y1 + pad_k, pad_y2 + pad_k),\r\n                mode=\"reflect\",\r\n            )\r\n\r\n   
     except RuntimeError:\r\n            continue\r\n\r\n        break\r\n\r\n    return img_pad, G_try, (pad_x1, pad_x2, pad_y1, pad_y2)\r\n\r\n\r\ndef random_apply_affine(img, p, G=None, antialiasing_kernel=SYM6):\r\n    kernel = antialiasing_kernel\r\n    len_k = len(kernel)\r\n    pad_k = (len_k + 1) // 2\r\n\r\n    kernel = torch.as_tensor(kernel)\r\n    kernel = torch.ger(kernel, kernel).to(img)\r\n    kernel_flip = torch.flip(kernel, (0, 1))\r\n\r\n    img_pad, G, (pad_x1, pad_x2, pad_y1, pad_y2) = try_sample_affine_and_pad(\r\n        img, p, pad_k, G\r\n    )\r\n\r\n    p_ux1 = pad_x1\r\n    p_ux2 = pad_x2 + 1\r\n    p_uy1 = pad_y1\r\n    p_uy2 = pad_y2 + 1\r\n    w_p = img_pad.shape[3] - len_k + 1\r\n    h_p = img_pad.shape[2] - len_k + 1\r\n    h_o = img.shape[2]\r\n    w_o = img.shape[3]\r\n\r\n    img_2x = upfirdn2d(img_pad, kernel_flip, up=2)\r\n\r\n    grid = make_grid(\r\n        img_2x.shape,\r\n        -2 * p_ux1 / w_o - 1,\r\n        2 * (w_p - p_ux1) / w_o - 1,\r\n        -2 * p_uy1 / h_o - 1,\r\n        2 * (h_p - p_uy1) / h_o - 1,\r\n        device=img_2x.device,\r\n    ).to(img_2x)\r\n    grid = affine_grid(grid, torch.inverse(G)[:, :2, :].to(img_2x))\r\n    grid = grid * torch.tensor(\r\n        [w_o / w_p, h_o / h_p], device=grid.device\r\n    ) + torch.tensor(\r\n        [(w_o + 2 * p_ux1) / w_p - 1, (h_o + 2 * p_uy1) / h_p - 1], device=grid.device\r\n    )\r\n\r\n    img_affine = F.grid_sample(\r\n        img_2x, grid, mode=\"bilinear\", align_corners=False, padding_mode=\"zeros\"\r\n    )\r\n\r\n    img_down = upfirdn2d(img_affine, kernel, down=2)\r\n\r\n    end_y = -pad_y2 - 1\r\n    if end_y == 0:\r\n        end_y = img_down.shape[2]\r\n\r\n    end_x = -pad_x2 - 1\r\n    if end_x == 0:\r\n        end_x = img_down.shape[3]\r\n\r\n    img = img_down[:, :, pad_y1:end_y, pad_x1:end_x]\r\n\r\n    return img, G\r\n\r\n\r\ndef apply_color(img, mat):\r\n    batch = img.shape[0]\r\n    img = img.permute(0, 2, 3, 1)\r\n    mat_mul = mat[:, :3, 
:3].transpose(1, 2).view(batch, 1, 3, 3)\r\n    mat_add = mat[:, :3, 3].view(batch, 1, 1, 3)\r\n    img = img @ mat_mul + mat_add\r\n    img = img.permute(0, 3, 1, 2)\r\n\r\n    return img\r\n\r\n\r\ndef random_apply_color(img, p, C=None):\r\n    if C is None:\r\n        C = sample_color(p, img.shape[0])\r\n\r\n    img = apply_color(img, C.to(img))\r\n\r\n    return img, C\r\n\r\n\r\ndef augment(img, p, transform_matrix=(None, None)):\r\n    img, G = random_apply_affine(img, p, transform_matrix[0])\r\n    img, C = random_apply_color(img, p, transform_matrix[1])\r\n\r\n    return img, (G, C)\r\n"
  },
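The `apply_color` routine in the file above applies a homogeneous color matrix per image: each RGB pixel is mapped through the linear part of the matrix plus an additive offset. A minimal numpy sketch of the same matrix application (the function name `apply_color_ref` and the example matrices are illustrative, not part of the repo):

```python
import numpy as np

# Sketch of the homogeneous color transform used by apply_color():
# each RGB pixel is mapped as rgb' = M[:3, :3] @ rgb + M[:3, 3].
def apply_color_ref(img, mat):
    """img: (H, W, 3) float array; mat: (4, 4) homogeneous color matrix."""
    lin = mat[:3, :3]   # linear part (contrast, hue rotation, saturation)
    off = mat[:3, 3]    # additive part (brightness)
    return img @ lin.T + off

img = np.random.rand(4, 4, 3)
eye = np.eye(4)              # identity: leaves the image unchanged
bright = np.eye(4)
bright[:3, 3] = 0.1          # brightness_mat-style additive offset
```

The full pipeline in `sample_color` composes brightness, contrast, luma-flip, hue, and saturation matrices the same way, so a single 4x4 matrix captures the whole color augmentation.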
  {
    "path": "models/archs/stylegan2/op/__init__.py",
    "content": "from .fused_act import FusedLeakyReLU, fused_leaky_relu\nfrom .upfirdn2d import upfirdn2d\n"
  },
  {
    "path": "models/archs/stylegan2/op/fused_act.py",
    "content": "import os\r\n\r\nimport torch\r\nfrom torch import nn\r\nfrom torch.nn import functional as F\r\nfrom torch.autograd import Function\r\nfrom torch.utils.cpp_extension import load\r\n\r\n\r\nmodule_path = os.path.dirname(__file__)\r\nfused = load(\r\n    \"fused\",\r\n    sources=[\r\n        os.path.join(module_path, \"fused_bias_act.cpp\"),\r\n        os.path.join(module_path, \"fused_bias_act_kernel.cu\"),\r\n    ],\r\n)\r\n\r\n\r\nclass FusedLeakyReLUFunctionBackward(Function):\r\n    @staticmethod\r\n    def forward(ctx, grad_output, out, bias, negative_slope, scale):\r\n        ctx.save_for_backward(out)\r\n        ctx.negative_slope = negative_slope\r\n        ctx.scale = scale\r\n\r\n        empty = grad_output.new_empty(0)\r\n\r\n        grad_input = fused.fused_bias_act(\r\n            grad_output, empty, out, 3, 1, negative_slope, scale\r\n        )\r\n\r\n        dim = [0]\r\n\r\n        if grad_input.ndim > 2:\r\n            dim += list(range(2, grad_input.ndim))\r\n\r\n        if bias:\r\n            grad_bias = grad_input.sum(dim).detach()\r\n\r\n        else:\r\n            grad_bias = empty\r\n\r\n        return grad_input, grad_bias\r\n\r\n    @staticmethod\r\n    def backward(ctx, gradgrad_input, gradgrad_bias):\r\n        out, = ctx.saved_tensors\r\n        gradgrad_out = fused.fused_bias_act(\r\n            gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale\r\n        )\r\n\r\n        return gradgrad_out, None, None, None, None\r\n\r\n\r\nclass FusedLeakyReLUFunction(Function):\r\n    @staticmethod\r\n    def forward(ctx, input, bias, negative_slope, scale):\r\n        empty = input.new_empty(0)\r\n\r\n        ctx.bias = bias is not None\r\n\r\n        if bias is None:\r\n            bias = empty\r\n\r\n        out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)\r\n        ctx.save_for_backward(out)\r\n        ctx.negative_slope = negative_slope\r\n        ctx.scale = scale\r\n\r\n     
   return out\r\n\r\n    @staticmethod\r\n    def backward(ctx, grad_output):\r\n        out, = ctx.saved_tensors\r\n\r\n        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(\r\n            grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale\r\n        )\r\n\r\n        if not ctx.bias:\r\n            grad_bias = None\r\n\r\n        return grad_input, grad_bias, None, None\r\n\r\n\r\nclass FusedLeakyReLU(nn.Module):\r\n    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):\r\n        super().__init__()\r\n\r\n        if bias:\r\n            self.bias = nn.Parameter(torch.zeros(channel))\r\n\r\n        else:\r\n            self.bias = None\r\n\r\n        self.negative_slope = negative_slope\r\n        self.scale = scale\r\n\r\n    def forward(self, input):\r\n        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)\r\n\r\n\r\ndef fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):\r\n    if input.device.type == \"cpu\":\r\n        if bias is not None:\r\n            rest_dim = [1] * (input.ndim - bias.ndim - 1)\r\n            return (\r\n                F.leaky_relu(\r\n                    input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2\r\n                )\r\n                * scale\r\n            )\r\n\r\n        else:\r\n            return F.leaky_relu(input, negative_slope=0.2) * scale\r\n\r\n    else:\r\n        return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)\r\n"
  },
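The CPU fallback in `fused_leaky_relu` above makes the fused op's math explicit: add the per-channel bias, apply a leaky ReLU with slope 0.2, then rescale by sqrt(2). A numpy sketch of that path (`fused_leaky_relu_ref` is an illustrative name):

```python
import numpy as np

# Numpy sketch of the CPU fallback in fused_leaky_relu(): bias add,
# leaky ReLU (slope 0.2), then a sqrt(2) gain that compensates for the
# signal magnitude lost to the nonlinearity.
def fused_leaky_relu_ref(x, bias, negative_slope=0.2, scale=2 ** 0.5):
    """x: (N, C, H, W); bias: (C,)"""
    y = x + bias.reshape(1, -1, 1, 1)            # broadcast bias over N, H, W
    y = np.where(y >= 0, y, y * negative_slope)  # leaky ReLU
    return y * scale

x = np.array([[[[-1.0, 2.0]]]])   # shape (1, 1, 1, 2)
out = fused_leaky_relu_ref(x, np.zeros(1))
```

On GPU the same computation runs in one CUDA kernel (`fused.fused_bias_act`), avoiding the intermediate tensors this sketch materializes.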
  {
    "path": "models/archs/stylegan2/op/fused_bias_act.cpp",
    "content": "#include <torch/extension.h>\r\n\r\n\r\ntorch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,\r\n    int act, int grad, float alpha, float scale);\r\n\r\n#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x \" must be a CUDA tensor\")\r\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\r\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\r\n\r\ntorch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,\r\n    int act, int grad, float alpha, float scale) {\r\n    CHECK_CUDA(input);\r\n    CHECK_CUDA(bias);\r\n\r\n    return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);\r\n}\r\n\r\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\r\n    m.def(\"fused_bias_act\", &fused_bias_act, \"fused bias act (CUDA)\");\r\n}"
  },
  {
    "path": "models/archs/stylegan2/op/fused_bias_act_kernel.cu",
    "content": "// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.\r\n//\r\n// This work is made available under the Nvidia Source Code License-NC.\r\n// To view a copy of this license, visit\r\n// https://nvlabs.github.io/stylegan2/license.html\r\n\r\n#include <torch/types.h>\r\n\r\n#include <ATen/ATen.h>\r\n#include <ATen/AccumulateType.h>\r\n#include <ATen/cuda/CUDAContext.h>\r\n#include <ATen/cuda/CUDAApplyUtils.cuh>\r\n\r\n#include <cuda.h>\r\n#include <cuda_runtime.h>\r\n\r\n\r\ntemplate <typename scalar_t>\r\nstatic __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,\r\n    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {\r\n    int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;\r\n\r\n    scalar_t zero = 0.0;\r\n\r\n    for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {\r\n        scalar_t x = p_x[xi];\r\n\r\n        if (use_bias) {\r\n            x += p_b[(xi / step_b) % size_b];\r\n        }\r\n\r\n        scalar_t ref = use_ref ? p_ref[xi] : zero;\r\n\r\n        scalar_t y;\r\n\r\n        switch (act * 10 + grad) {\r\n            default:\r\n            case 10: y = x; break;\r\n            case 11: y = x; break;\r\n            case 12: y = 0.0; break;\r\n\r\n            case 30: y = (x > 0.0) ? x : x * alpha; break;\r\n            case 31: y = (ref > 0.0) ? 
x : x * alpha; break;\r\n            case 32: y = 0.0; break;\r\n        }\r\n\r\n        out[xi] = y * scale;\r\n    }\r\n}\r\n\r\n\r\ntorch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,\r\n    int act, int grad, float alpha, float scale) {\r\n    int curDevice = -1;\r\n    cudaGetDevice(&curDevice);\r\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);\r\n\r\n    auto x = input.contiguous();\r\n    auto b = bias.contiguous();\r\n    auto ref = refer.contiguous();\r\n\r\n    int use_bias = b.numel() ? 1 : 0;\r\n    int use_ref = ref.numel() ? 1 : 0;\r\n\r\n    int size_x = x.numel();\r\n    int size_b = b.numel();\r\n    int step_b = 1;\r\n\r\n    for (int i = 1 + 1; i < x.dim(); i++) {\r\n        step_b *= x.size(i);\r\n    }\r\n\r\n    int loop_x = 4;\r\n    int block_size = 4 * 32;\r\n    int grid_size = (size_x - 1) / (loop_x * block_size) + 1;\r\n\r\n    auto y = torch::empty_like(x);\r\n\r\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), \"fused_bias_act_kernel\", [&] {\r\n        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(\r\n            y.data_ptr<scalar_t>(),\r\n            x.data_ptr<scalar_t>(),\r\n            b.data_ptr<scalar_t>(),\r\n            ref.data_ptr<scalar_t>(),\r\n            act,\r\n            grad,\r\n            alpha,\r\n            scale,\r\n            loop_x,\r\n            size_x,\r\n            step_b,\r\n            size_b,\r\n            use_bias,\r\n            use_ref\r\n        );\r\n    });\r\n\r\n    return y;\r\n}"
  },
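The `switch (act * 10 + grad)` in the kernel above encodes both the activation and which derivative to compute: `act` 3 is leaky ReLU, and `grad` 0/1/2 select forward, first derivative (gated on the saved forward output `ref`), and second derivative (zero, since the function is piecewise linear). A scalar Python sketch of that dispatch (`fused_bias_act_ref` is an illustrative name):

```python
import math

# Sketch of the act * 10 + grad dispatch in fused_bias_act_kernel.
def fused_bias_act_ref(x, ref, act, grad, alpha, scale):
    key = act * 10 + grad
    if key == 30:        # leaky ReLU forward
        y = x if x > 0 else x * alpha
    elif key == 31:      # first derivative: pass through where forward was positive
        y = x if ref > 0 else x * alpha
    elif key == 32:      # second derivative of a piecewise-linear function
        y = 0.0
    else:                # act 1 = linear: identity for grad 0/1, zero for grad 2
        y = 0.0 if key == 12 else x
    return y * scale
```

Feeding the saved forward output as `ref` is what lets the backward pass reuse the same kernel (`grad=1`) instead of a separate gradient implementation.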
  {
    "path": "models/archs/stylegan2/op/upfirdn2d.cpp",
    "content": "#include <torch/extension.h>\r\n\r\n\r\ntorch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,\r\n                            int up_x, int up_y, int down_x, int down_y,\r\n                            int pad_x0, int pad_x1, int pad_y0, int pad_y1);\r\n\r\n#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x \" must be a CUDA tensor\")\r\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\r\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\r\n\r\ntorch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,\r\n                        int up_x, int up_y, int down_x, int down_y,\r\n                        int pad_x0, int pad_x1, int pad_y0, int pad_y1) {\r\n    CHECK_CUDA(input);\r\n    CHECK_CUDA(kernel);\r\n\r\n    return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);\r\n}\r\n\r\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\r\n    m.def(\"upfirdn2d\", &upfirdn2d, \"upfirdn2d (CUDA)\");\r\n}"
  },
  {
    "path": "models/archs/stylegan2/op/upfirdn2d.py",
    "content": "import os\r\n\r\nimport torch\r\nfrom torch.nn import functional as F\r\nfrom torch.autograd import Function\r\nfrom torch.utils.cpp_extension import load\r\n\r\n\r\nmodule_path = os.path.dirname(__file__)\r\nupfirdn2d_op = load(\r\n    \"upfirdn2d\",\r\n    sources=[\r\n        os.path.join(module_path, \"upfirdn2d.cpp\"),\r\n        os.path.join(module_path, \"upfirdn2d_kernel.cu\"),\r\n    ],\r\n)\r\n\r\n\r\nclass UpFirDn2dBackward(Function):\r\n    @staticmethod\r\n    def forward(\r\n        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size\r\n    ):\r\n\r\n        up_x, up_y = up\r\n        down_x, down_y = down\r\n        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad\r\n\r\n        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)\r\n\r\n        grad_input = upfirdn2d_op.upfirdn2d(\r\n            grad_output,\r\n            grad_kernel,\r\n            down_x,\r\n            down_y,\r\n            up_x,\r\n            up_y,\r\n            g_pad_x0,\r\n            g_pad_x1,\r\n            g_pad_y0,\r\n            g_pad_y1,\r\n        )\r\n        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])\r\n\r\n        ctx.save_for_backward(kernel)\r\n\r\n        pad_x0, pad_x1, pad_y0, pad_y1 = pad\r\n\r\n        ctx.up_x = up_x\r\n        ctx.up_y = up_y\r\n        ctx.down_x = down_x\r\n        ctx.down_y = down_y\r\n        ctx.pad_x0 = pad_x0\r\n        ctx.pad_x1 = pad_x1\r\n        ctx.pad_y0 = pad_y0\r\n        ctx.pad_y1 = pad_y1\r\n        ctx.in_size = in_size\r\n        ctx.out_size = out_size\r\n\r\n        return grad_input\r\n\r\n    @staticmethod\r\n    def backward(ctx, gradgrad_input):\r\n        kernel, = ctx.saved_tensors\r\n\r\n        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)\r\n\r\n        gradgrad_out = upfirdn2d_op.upfirdn2d(\r\n            gradgrad_input,\r\n            kernel,\r\n            ctx.up_x,\r\n            
ctx.up_y,\r\n            ctx.down_x,\r\n            ctx.down_y,\r\n            ctx.pad_x0,\r\n            ctx.pad_x1,\r\n            ctx.pad_y0,\r\n            ctx.pad_y1,\r\n        )\r\n        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])\r\n        gradgrad_out = gradgrad_out.view(\r\n            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]\r\n        )\r\n\r\n        return gradgrad_out, None, None, None, None, None, None, None, None\r\n\r\n\r\nclass UpFirDn2d(Function):\r\n    @staticmethod\r\n    def forward(ctx, input, kernel, up, down, pad):\r\n        up_x, up_y = up\r\n        down_x, down_y = down\r\n        pad_x0, pad_x1, pad_y0, pad_y1 = pad\r\n\r\n        kernel_h, kernel_w = kernel.shape\r\n        batch, channel, in_h, in_w = input.shape\r\n        ctx.in_size = input.shape\r\n\r\n        input = input.reshape(-1, in_h, in_w, 1)\r\n\r\n        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))\r\n\r\n        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1\r\n        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1\r\n        ctx.out_size = (out_h, out_w)\r\n\r\n        ctx.up = (up_x, up_y)\r\n        ctx.down = (down_x, down_y)\r\n        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)\r\n\r\n        g_pad_x0 = kernel_w - pad_x0 - 1\r\n        g_pad_y0 = kernel_h - pad_y0 - 1\r\n        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1\r\n        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1\r\n\r\n        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)\r\n\r\n        out = upfirdn2d_op.upfirdn2d(\r\n            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1\r\n        )\r\n        # out = out.view(major, out_h, out_w, minor)\r\n        out = out.view(-1, channel, out_h, out_w)\r\n\r\n        return out\r\n\r\n    @staticmethod\r\n    def backward(ctx, grad_output):\r\n        
kernel, grad_kernel = ctx.saved_tensors\r\n\r\n        grad_input = UpFirDn2dBackward.apply(\r\n            grad_output,\r\n            kernel,\r\n            grad_kernel,\r\n            ctx.up,\r\n            ctx.down,\r\n            ctx.pad,\r\n            ctx.g_pad,\r\n            ctx.in_size,\r\n            ctx.out_size,\r\n        )\r\n\r\n        return grad_input, None, None, None, None\r\n\r\n\r\ndef upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):\r\n    if input.device.type == \"cpu\":\r\n        out = upfirdn2d_native(\r\n            input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]\r\n        )\r\n\r\n    else:\r\n        out = UpFirDn2d.apply(\r\n            input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])\r\n        )\r\n\r\n    return out\r\n\r\n\r\ndef upfirdn2d_native(\r\n    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1\r\n):\r\n    _, channel, in_h, in_w = input.shape\r\n    input = input.reshape(-1, in_h, in_w, 1)\r\n\r\n    _, in_h, in_w, minor = input.shape\r\n    kernel_h, kernel_w = kernel.shape\r\n\r\n    out = input.view(-1, in_h, 1, in_w, 1, minor)\r\n    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])\r\n    out = out.view(-1, in_h * up_y, in_w * up_x, minor)\r\n\r\n    out = F.pad(\r\n        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]\r\n    )\r\n    out = out[\r\n        :,\r\n        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),\r\n        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),\r\n        :,\r\n    ]\r\n\r\n    out = out.permute(0, 3, 1, 2)\r\n    out = out.reshape(\r\n        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]\r\n    )\r\n    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)\r\n    out = F.conv2d(out, w)\r\n    out = out.reshape(\r\n        -1,\r\n        minor,\r\n        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,\r\n        in_w * up_x + pad_x0 + pad_x1 - 
kernel_w + 1,\r\n    )\r\n    out = out.permute(0, 2, 3, 1)\r\n    out = out[:, ::down_y, ::down_x, :]\r\n\r\n    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1\r\n    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1\r\n\r\n    return out.view(-1, channel, out_h, out_w)\r\n"
  },
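`upfirdn2d_native` above fuses three stages: zero-insertion upsampling, FIR filtering (valid convolution with the kernel, implemented as correlation with the flipped kernel via `torch.flip` + `F.conv2d`), and strided downsampling. A single-channel numpy sketch of the same pipeline, assuming non-negative padding (`upfirdn2d_ref` is an illustrative name):

```python
import numpy as np

# Single-channel sketch of the upsample -> FIR -> downsample pipeline.
def upfirdn2d_ref(x, k, up=1, down=1, pad=(0, 0)):
    """x: (H, W); k: (kh, kw); pad: (pad0, pad1) applied on both axes."""
    h, w = x.shape
    kh, kw = k.shape
    # 1) upsample by zero insertion
    up_img = np.zeros((h * up, w * up), dtype=np.float64)
    up_img[::up, ::up] = x
    # 2) zero-pad (this sketch assumes pad0, pad1 >= 0)
    p0, p1 = pad
    up_img = np.pad(up_img, ((p0, p1), (p0, p1)))
    # 3) valid convolution with k == correlation with the flipped kernel
    kf = k[::-1, ::-1]
    out_h = up_img.shape[0] - kh + 1
    out_w = up_img.shape[1] - kw + 1
    out = np.empty((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = (up_img[i:i + kh, j:j + kw] * kf).sum()
    # 4) downsample by striding; the final size per axis is
    #    (in * up + pad0 + pad1 - k) // down + 1, as in UpFirDn2d.forward
    return out[::down, ::down]
```

The CUDA path computes the same result without materializing the upsampled intermediate, which is why the fused op is worth a custom kernel.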
  {
    "path": "models/archs/stylegan2/op/upfirdn2d_kernel.cu",
    "content": "// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.\r\n//\r\n// This work is made available under the Nvidia Source Code License-NC.\r\n// To view a copy of this license, visit\r\n// https://nvlabs.github.io/stylegan2/license.html\r\n\r\n#include <torch/types.h>\r\n\r\n#include <ATen/ATen.h>\r\n#include <ATen/AccumulateType.h>\r\n#include <ATen/cuda/CUDAApplyUtils.cuh>\r\n#include <ATen/cuda/CUDAContext.h>\r\n\r\n#include <cuda.h>\r\n#include <cuda_runtime.h>\r\n\r\nstatic __host__ __device__ __forceinline__ int floor_div(int a, int b) {\r\n  int c = a / b;\r\n\r\n  if (c * b > a) {\r\n    c--;\r\n  }\r\n\r\n  return c;\r\n}\r\n\r\nstruct UpFirDn2DKernelParams {\r\n  int up_x;\r\n  int up_y;\r\n  int down_x;\r\n  int down_y;\r\n  int pad_x0;\r\n  int pad_x1;\r\n  int pad_y0;\r\n  int pad_y1;\r\n\r\n  int major_dim;\r\n  int in_h;\r\n  int in_w;\r\n  int minor_dim;\r\n  int kernel_h;\r\n  int kernel_w;\r\n  int out_h;\r\n  int out_w;\r\n  int loop_major;\r\n  int loop_x;\r\n};\r\n\r\ntemplate <typename scalar_t>\r\n__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,\r\n                                       const scalar_t *kernel,\r\n                                       const UpFirDn2DKernelParams p) {\r\n  int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;\r\n  int out_y = minor_idx / p.minor_dim;\r\n  minor_idx -= out_y * p.minor_dim;\r\n  int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;\r\n  int major_idx_base = blockIdx.z * p.loop_major;\r\n\r\n  if (out_x_base >= p.out_w || out_y >= p.out_h ||\r\n      major_idx_base >= p.major_dim) {\r\n    return;\r\n  }\r\n\r\n  int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;\r\n  int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);\r\n  int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;\r\n  int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;\r\n\r\n  for (int loop_major = 0, major_idx = major_idx_base;\r\n   
    loop_major < p.loop_major && major_idx < p.major_dim;\r\n       loop_major++, major_idx++) {\r\n    for (int loop_x = 0, out_x = out_x_base;\r\n         loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {\r\n      int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;\r\n      int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);\r\n      int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;\r\n      int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;\r\n\r\n      const scalar_t *x_p =\r\n          &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +\r\n                 minor_idx];\r\n      const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];\r\n      int x_px = p.minor_dim;\r\n      int k_px = -p.up_x;\r\n      int x_py = p.in_w * p.minor_dim;\r\n      int k_py = -p.up_y * p.kernel_w;\r\n\r\n      scalar_t v = 0.0f;\r\n\r\n      for (int y = 0; y < h; y++) {\r\n        for (int x = 0; x < w; x++) {\r\n          v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);\r\n          x_p += x_px;\r\n          k_p += k_px;\r\n        }\r\n\r\n        x_p += x_py - w * x_px;\r\n        k_p += k_py - w * k_px;\r\n      }\r\n\r\n      out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +\r\n          minor_idx] = v;\r\n    }\r\n  }\r\n}\r\n\r\ntemplate <typename scalar_t, int up_x, int up_y, int down_x, int down_y,\r\n          int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>\r\n__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,\r\n                                 const scalar_t *kernel,\r\n                                 const UpFirDn2DKernelParams p) {\r\n  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;\r\n  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;\r\n\r\n  __shared__ volatile float sk[kernel_h][kernel_w];\r\n  __shared__ volatile float sx[tile_in_h][tile_in_w];\r\n\r\n  
int minor_idx = blockIdx.x;\r\n  int tile_out_y = minor_idx / p.minor_dim;\r\n  minor_idx -= tile_out_y * p.minor_dim;\r\n  tile_out_y *= tile_out_h;\r\n  int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;\r\n  int major_idx_base = blockIdx.z * p.loop_major;\r\n\r\n  if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |\r\n      major_idx_base >= p.major_dim) {\r\n    return;\r\n  }\r\n\r\n  for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;\r\n       tap_idx += blockDim.x) {\r\n    int ky = tap_idx / kernel_w;\r\n    int kx = tap_idx - ky * kernel_w;\r\n    scalar_t v = 0.0;\r\n\r\n    if (kx < p.kernel_w & ky < p.kernel_h) {\r\n      v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];\r\n    }\r\n\r\n    sk[ky][kx] = v;\r\n  }\r\n\r\n  for (int loop_major = 0, major_idx = major_idx_base;\r\n       loop_major < p.loop_major & major_idx < p.major_dim;\r\n       loop_major++, major_idx++) {\r\n    for (int loop_x = 0, tile_out_x = tile_out_x_base;\r\n         loop_x < p.loop_x & tile_out_x < p.out_w;\r\n         loop_x++, tile_out_x += tile_out_w) {\r\n      int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;\r\n      int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;\r\n      int tile_in_x = floor_div(tile_mid_x, up_x);\r\n      int tile_in_y = floor_div(tile_mid_y, up_y);\r\n\r\n      __syncthreads();\r\n\r\n      for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;\r\n           in_idx += blockDim.x) {\r\n        int rel_in_y = in_idx / tile_in_w;\r\n        int rel_in_x = in_idx - rel_in_y * tile_in_w;\r\n        int in_x = rel_in_x + tile_in_x;\r\n        int in_y = rel_in_y + tile_in_y;\r\n\r\n        scalar_t v = 0.0;\r\n\r\n        if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {\r\n          v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *\r\n                        p.minor_dim +\r\n                    minor_idx];\r\n        }\r\n\r\n        sx[rel_in_y][rel_in_x] = 
v;\r\n      }\r\n\r\n      __syncthreads();\r\n      for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;\r\n           out_idx += blockDim.x) {\r\n        int rel_out_y = out_idx / tile_out_w;\r\n        int rel_out_x = out_idx - rel_out_y * tile_out_w;\r\n        int out_x = rel_out_x + tile_out_x;\r\n        int out_y = rel_out_y + tile_out_y;\r\n\r\n        int mid_x = tile_mid_x + rel_out_x * down_x;\r\n        int mid_y = tile_mid_y + rel_out_y * down_y;\r\n        int in_x = floor_div(mid_x, up_x);\r\n        int in_y = floor_div(mid_y, up_y);\r\n        int rel_in_x = in_x - tile_in_x;\r\n        int rel_in_y = in_y - tile_in_y;\r\n        int kernel_x = (in_x + 1) * up_x - mid_x - 1;\r\n        int kernel_y = (in_y + 1) * up_y - mid_y - 1;\r\n\r\n        scalar_t v = 0.0;\r\n\r\n#pragma unroll\r\n        for (int y = 0; y < kernel_h / up_y; y++)\r\n#pragma unroll\r\n          for (int x = 0; x < kernel_w / up_x; x++)\r\n            v += sx[rel_in_y + y][rel_in_x + x] *\r\n                 sk[kernel_y + y * up_y][kernel_x + x * up_x];\r\n\r\n        if (out_x < p.out_w & out_y < p.out_h) {\r\n          out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +\r\n              minor_idx] = v;\r\n        }\r\n      }\r\n    }\r\n  }\r\n}\r\n\r\ntorch::Tensor upfirdn2d_op(const torch::Tensor &input,\r\n                           const torch::Tensor &kernel, int up_x, int up_y,\r\n                           int down_x, int down_y, int pad_x0, int pad_x1,\r\n                           int pad_y0, int pad_y1) {\r\n  int curDevice = -1;\r\n  cudaGetDevice(&curDevice);\r\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);\r\n\r\n  UpFirDn2DKernelParams p;\r\n\r\n  auto x = input.contiguous();\r\n  auto k = kernel.contiguous();\r\n\r\n  p.major_dim = x.size(0);\r\n  p.in_h = x.size(1);\r\n  p.in_w = x.size(2);\r\n  p.minor_dim = x.size(3);\r\n  p.kernel_h = k.size(0);\r\n  p.kernel_w = k.size(1);\r\n  p.up_x = up_x;\r\n  
p.up_y = up_y;\r\n  p.down_x = down_x;\r\n  p.down_y = down_y;\r\n  p.pad_x0 = pad_x0;\r\n  p.pad_x1 = pad_x1;\r\n  p.pad_y0 = pad_y0;\r\n  p.pad_y1 = pad_y1;\r\n\r\n  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /\r\n            p.down_y;\r\n  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /\r\n            p.down_x;\r\n\r\n  auto out =\r\n      at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());\r\n\r\n  int mode = -1;\r\n\r\n  int tile_out_h = -1;\r\n  int tile_out_w = -1;\r\n\r\n  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&\r\n      p.kernel_h <= 4 && p.kernel_w <= 4) {\r\n    mode = 1;\r\n    tile_out_h = 16;\r\n    tile_out_w = 64;\r\n  }\r\n\r\n  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&\r\n      p.kernel_h <= 3 && p.kernel_w <= 3) {\r\n    mode = 2;\r\n    tile_out_h = 16;\r\n    tile_out_w = 64;\r\n  }\r\n\r\n  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&\r\n      p.kernel_h <= 4 && p.kernel_w <= 4) {\r\n    mode = 3;\r\n    tile_out_h = 16;\r\n    tile_out_w = 64;\r\n  }\r\n\r\n  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&\r\n      p.kernel_h <= 2 && p.kernel_w <= 2) {\r\n    mode = 4;\r\n    tile_out_h = 16;\r\n    tile_out_w = 64;\r\n  }\r\n\r\n  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&\r\n      p.kernel_h <= 4 && p.kernel_w <= 4) {\r\n    mode = 5;\r\n    tile_out_h = 8;\r\n    tile_out_w = 32;\r\n  }\r\n\r\n  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&\r\n      p.kernel_h <= 2 && p.kernel_w <= 2) {\r\n    mode = 6;\r\n    tile_out_h = 8;\r\n    tile_out_w = 32;\r\n  }\r\n\r\n  dim3 block_size;\r\n  dim3 grid_size;\r\n\r\n  if (tile_out_h > 0 && tile_out_w > 0) {\r\n    p.loop_major = (p.major_dim - 1) / 16384 + 1;\r\n    p.loop_x = 1;\r\n    block_size = dim3(32 * 8, 1, 1);\r\n    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * 
p.minor_dim,\r\n                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,\r\n                     (p.major_dim - 1) / p.loop_major + 1);\r\n  } else {\r\n    p.loop_major = (p.major_dim - 1) / 16384 + 1;\r\n    p.loop_x = 4;\r\n    block_size = dim3(4, 32, 1);\r\n    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,\r\n                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,\r\n                     (p.major_dim - 1) / p.loop_major + 1);\r\n  }\r\n\r\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), \"upfirdn2d_cuda\", [&] {\r\n    switch (mode) {\r\n    case 1:\r\n      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>\r\n          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),\r\n                                                 x.data_ptr<scalar_t>(),\r\n                                                 k.data_ptr<scalar_t>(), p);\r\n\r\n      break;\r\n\r\n    case 2:\r\n      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>\r\n          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),\r\n                                                 x.data_ptr<scalar_t>(),\r\n                                                 k.data_ptr<scalar_t>(), p);\r\n\r\n      break;\r\n\r\n    case 3:\r\n      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>\r\n          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),\r\n                                                 x.data_ptr<scalar_t>(),\r\n                                                 k.data_ptr<scalar_t>(), p);\r\n\r\n      break;\r\n\r\n    case 4:\r\n      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>\r\n          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),\r\n                                                 x.data_ptr<scalar_t>(),\r\n                                                 k.data_ptr<scalar_t>(), p);\r\n\r\n      break;\r\n\r\n    case 5:\r\n      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 
4, 8, 32>\r\n          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),\r\n                                                 x.data_ptr<scalar_t>(),\r\n                                                 k.data_ptr<scalar_t>(), p);\r\n\r\n      break;\r\n\r\n    case 6:\r\n      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>\r\n          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),\r\n                                                 x.data_ptr<scalar_t>(),\r\n                                                 k.data_ptr<scalar_t>(), p);\r\n\r\n      break;\r\n\r\n    default:\r\n      upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(\r\n          out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),\r\n          k.data_ptr<scalar_t>(), p);\r\n    }\r\n  });\r\n\r\n  return out;\r\n}"
  },
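The chain of `if` blocks in `upfirdn2d_op` above selects a specialized tiled kernel when the (up, down, kernel-size) configuration matches a known case, with later matches overriding earlier ones, and falls back to the generic `upfirdn2d_kernel_large` otherwise (`mode == -1`). A Python sketch of that dispatch, assuming symmetric up/down factors as in the C++ conditions (`select_mode` is an illustrative name):

```python
# Sketch of the specialized-kernel dispatch in upfirdn2d_op. Each later
# condition overrides the previous one, so the tightest match wins.
def select_mode(up, down, kh, kw):
    mode = -1  # generic "large" kernel
    if up == 1 and down == 1 and kh <= 4 and kw <= 4:
        mode = 1
    if up == 1 and down == 1 and kh <= 3 and kw <= 3:
        mode = 2
    if up == 2 and down == 1 and kh <= 4 and kw <= 4:
        mode = 3
    if up == 2 and down == 1 and kh <= 2 and kw <= 2:
        mode = 4
    if up == 1 and down == 2 and kh <= 4 and kw <= 4:
        mode = 5
    if up == 1 and down == 2 and kh <= 2 and kw <= 2:
        mode = 6
    return mode
```

Anything outside these cases (e.g. simultaneous up- and downsampling) takes the slower generic path, which is correct for arbitrary parameters.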
  {
    "path": "models/archs/stylegan2/ppl.py",
    "content": "import argparse\r\n\r\nimport torch\r\nfrom torch.nn import functional as F\r\nimport numpy as np\r\nfrom tqdm import tqdm\r\n\r\nimport lpips\r\nfrom model import Generator\r\n\r\n\r\ndef normalize(x):\r\n    return x / torch.sqrt(x.pow(2).sum(-1, keepdim=True))\r\n\r\n\r\ndef slerp(a, b, t):\r\n    a = normalize(a)\r\n    b = normalize(b)\r\n    d = (a * b).sum(-1, keepdim=True)\r\n    p = t * torch.acos(d)\r\n    c = normalize(b - d * a)\r\n    d = a * torch.cos(p) + c * torch.sin(p)\r\n\r\n    return normalize(d)\r\n\r\n\r\ndef lerp(a, b, t):\r\n    return a + (b - a) * t\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    device = \"cuda\"\r\n\r\n    parser = argparse.ArgumentParser(description=\"Perceptual Path Length calculator\")\r\n\r\n    parser.add_argument(\r\n        \"--space\", choices=[\"z\", \"w\"], help=\"space that PPL calculated with\"\r\n    )\r\n    parser.add_argument(\r\n        \"--batch\", type=int, default=64, help=\"batch size for the models\"\r\n    )\r\n    parser.add_argument(\r\n        \"--n_sample\",\r\n        type=int,\r\n        default=5000,\r\n        help=\"number of the samples for calculating PPL\",\r\n    )\r\n    parser.add_argument(\r\n        \"--size\", type=int, default=256, help=\"output image sizes of the generator\"\r\n    )\r\n    parser.add_argument(\r\n        \"--eps\", type=float, default=1e-4, help=\"epsilon for numerical stability\"\r\n    )\r\n    parser.add_argument(\r\n        \"--crop\", action=\"store_true\", help=\"apply center crop to the images\"\r\n    )\r\n    parser.add_argument(\r\n        \"--sampling\",\r\n        default=\"end\",\r\n        choices=[\"end\", \"full\"],\r\n        help=\"set endpoint sampling method\",\r\n    )\r\n    parser.add_argument(\r\n        \"ckpt\", metavar=\"CHECKPOINT\", help=\"path to the model checkpoints\"\r\n    )\r\n\r\n    args = parser.parse_args()\r\n\r\n    latent_dim = 512\r\n\r\n    ckpt = torch.load(args.ckpt)\r\n\r\n    g = 
Generator(args.size, latent_dim, 8).to(device)\r\n    g.load_state_dict(ckpt[\"g_ema\"])\r\n    g.eval()\r\n\r\n    percept = lpips.PerceptualLoss(\r\n        model=\"net-lin\", net=\"vgg\", use_gpu=device.startswith(\"cuda\")\r\n    )\r\n\r\n    distances = []\r\n\r\n    n_batch = args.n_sample // args.batch\r\n    resid = args.n_sample - (n_batch * args.batch)\r\n    batch_sizes = [args.batch] * n_batch + [resid]\r\n\r\n    with torch.no_grad():\r\n        for batch in tqdm(batch_sizes):\r\n            noise = g.make_noise()\r\n\r\n            inputs = torch.randn([batch * 2, latent_dim], device=device)\r\n            if args.sampling == \"full\":\r\n                lerp_t = torch.rand(batch, device=device)\r\n            else:\r\n                lerp_t = torch.zeros(batch, device=device)\r\n\r\n            if args.space == \"w\":\r\n                latent = g.get_latent(inputs)\r\n                latent_t0, latent_t1 = latent[::2], latent[1::2]\r\n                latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None])\r\n                latent_e1 = lerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps)\r\n                latent_e = torch.stack([latent_e0, latent_e1], 1).view(*latent.shape)\r\n\r\n            image, _ = g([latent_e], input_is_latent=True, noise=noise)\r\n\r\n            if args.crop:\r\n                c = image.shape[2] // 8\r\n                image = image[:, :, c * 3 : c * 7, c * 2 : c * 6]\r\n\r\n            factor = image.shape[2] // 256\r\n\r\n            if factor > 1:\r\n                image = F.interpolate(\r\n                    image, size=(256, 256), mode=\"bilinear\", align_corners=False\r\n                )\r\n\r\n            dist = percept(image[::2], image[1::2]).view(image.shape[0] // 2) / (\r\n                args.eps ** 2\r\n            )\r\n            distances.append(dist.to(\"cpu\").numpy())\r\n\r\n    distances = np.concatenate(distances, 0)\r\n\r\n    lo = np.percentile(distances, 1, interpolation=\"lower\")\r\n    hi = 
np.percentile(distances, 99, interpolation=\"higher\")\r\n    filtered_dist = np.extract(\r\n        np.logical_and(lo <= distances, distances <= hi), distances\r\n    )\r\n\r\n    print(\"ppl:\", filtered_dist.mean())\r\n"
  },
  {
    "path": "models/archs/stylegan2/sample/.gitignore",
    "content": "*.png\n"
  },
  {
    "path": "models/archs/stylegan2/train.py",
    "content": "import argparse\r\nimport math\r\nimport random\r\nimport os\r\n\r\nimport numpy as np\r\nimport torch\r\nfrom torch import nn, autograd, optim\r\nfrom torch.nn import functional as F\r\nfrom torch.utils import data\r\nimport torch.distributed as dist\r\nfrom torchvision import transforms, utils\r\nfrom tqdm import tqdm\r\n\r\ntry:\r\n    import wandb\r\n\r\nexcept ImportError:\r\n    wandb = None\r\n\r\nfrom model import Generator, Discriminator\r\nfrom dataset import MultiResolutionDataset\r\nfrom distributed import (\r\n    get_rank,\r\n    synchronize,\r\n    reduce_loss_dict,\r\n    reduce_sum,\r\n    get_world_size,\r\n)\r\nfrom non_leaking import augment, AdaptiveAugment\r\n\r\n\r\ndef data_sampler(dataset, shuffle, distributed):\r\n    if distributed:\r\n        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)\r\n\r\n    if shuffle:\r\n        return data.RandomSampler(dataset)\r\n\r\n    else:\r\n        return data.SequentialSampler(dataset)\r\n\r\n\r\ndef requires_grad(model, flag=True):\r\n    for p in model.parameters():\r\n        p.requires_grad = flag\r\n\r\n\r\ndef accumulate(model1, model2, decay=0.999):\r\n    par1 = dict(model1.named_parameters())\r\n    par2 = dict(model2.named_parameters())\r\n\r\n    for k in par1.keys():\r\n        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)\r\n\r\n\r\ndef sample_data(loader):\r\n    while True:\r\n        for batch in loader:\r\n            yield batch\r\n\r\n\r\ndef d_logistic_loss(real_pred, fake_pred):\r\n    real_loss = F.softplus(-real_pred)\r\n    fake_loss = F.softplus(fake_pred)\r\n\r\n    return real_loss.mean() + fake_loss.mean()\r\n\r\n\r\ndef d_r1_loss(real_pred, real_img):\r\n    grad_real, = autograd.grad(\r\n        outputs=real_pred.sum(), inputs=real_img, create_graph=True\r\n    )\r\n    grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()\r\n\r\n    return grad_penalty\r\n\r\n\r\ndef 
g_nonsaturating_loss(fake_pred):\r\n    loss = F.softplus(-fake_pred).mean()\r\n\r\n    return loss\r\n\r\n\r\ndef g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):\r\n    noise = torch.randn_like(fake_img) / math.sqrt(\r\n        fake_img.shape[2] * fake_img.shape[3]\r\n    )\r\n    grad, = autograd.grad(\r\n        outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True\r\n    )\r\n    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))\r\n\r\n    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)\r\n\r\n    path_penalty = (path_lengths - path_mean).pow(2).mean()\r\n\r\n    return path_penalty, path_mean.detach(), path_lengths\r\n\r\n\r\ndef make_noise(batch, latent_dim, n_noise, device):\r\n    if n_noise == 1:\r\n        return torch.randn(batch, latent_dim, device=device)\r\n\r\n    noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)\r\n\r\n    return noises\r\n\r\n\r\ndef mixing_noise(batch, latent_dim, prob, device):\r\n    if prob > 0 and random.random() < prob:\r\n        return make_noise(batch, latent_dim, 2, device)\r\n\r\n    else:\r\n        return [make_noise(batch, latent_dim, 1, device)]\r\n\r\n\r\ndef set_grad_none(model, targets):\r\n    for n, p in model.named_parameters():\r\n        if n in targets:\r\n            p.grad = None\r\n\r\n\r\ndef train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):\r\n    loader = sample_data(loader)\r\n\r\n    pbar = range(args.iter)\r\n\r\n    if get_rank() == 0:\r\n        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)\r\n\r\n    mean_path_length = 0\r\n\r\n    d_loss_val = 0\r\n    r1_loss = torch.tensor(0.0, device=device)\r\n    g_loss_val = 0\r\n    path_loss = torch.tensor(0.0, device=device)\r\n    path_lengths = torch.tensor(0.0, device=device)\r\n    mean_path_length_avg = 0\r\n    loss_dict = {}\r\n\r\n    if args.distributed:\r\n        g_module = 
generator.module\r\n        d_module = discriminator.module\r\n\r\n    else:\r\n        g_module = generator\r\n        d_module = discriminator\r\n\r\n    accum = 0.5 ** (32 / (10 * 1000))\r\n    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0\r\n    r_t_stat = 0\r\n\r\n    if args.augment and args.augment_p == 0:\r\n        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device)\r\n\r\n    sample_z = torch.randn(args.n_sample, args.latent, device=device)\r\n\r\n    for idx in pbar:\r\n        i = idx + args.start_iter\r\n\r\n        if i > args.iter:\r\n            print(\"Done!\")\r\n\r\n            break\r\n\r\n        real_img = next(loader)\r\n        real_img = real_img.to(device)\r\n\r\n        requires_grad(generator, False)\r\n        requires_grad(discriminator, True)\r\n\r\n        noise = mixing_noise(args.batch, args.latent, args.mixing, device)\r\n        fake_img, _ = generator(noise)\r\n\r\n        if args.augment:\r\n            real_img_aug, _ = augment(real_img, ada_aug_p)\r\n            fake_img, _ = augment(fake_img, ada_aug_p)\r\n\r\n        else:\r\n            real_img_aug = real_img\r\n\r\n        fake_pred = discriminator(fake_img)\r\n        real_pred = discriminator(real_img_aug)\r\n        d_loss = d_logistic_loss(real_pred, fake_pred)\r\n\r\n        loss_dict[\"d\"] = d_loss\r\n        loss_dict[\"real_score\"] = real_pred.mean()\r\n        loss_dict[\"fake_score\"] = fake_pred.mean()\r\n\r\n        discriminator.zero_grad()\r\n        d_loss.backward()\r\n        d_optim.step()\r\n\r\n        if args.augment and args.augment_p == 0:\r\n            ada_aug_p = ada_augment.tune(real_pred)\r\n            r_t_stat = ada_augment.r_t_stat\r\n\r\n        d_regularize = i % args.d_reg_every == 0\r\n\r\n        if d_regularize:\r\n            real_img.requires_grad = True\r\n            real_pred = discriminator(real_img)\r\n            r1_loss = d_r1_loss(real_pred, real_img)\r\n\r\n            
discriminator.zero_grad()\r\n            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()\r\n\r\n            d_optim.step()\r\n\r\n        loss_dict[\"r1\"] = r1_loss\r\n\r\n        requires_grad(generator, True)\r\n        requires_grad(discriminator, False)\r\n\r\n        noise = mixing_noise(args.batch, args.latent, args.mixing, device)\r\n        fake_img, _ = generator(noise)\r\n\r\n        if args.augment:\r\n            fake_img, _ = augment(fake_img, ada_aug_p)\r\n\r\n        fake_pred = discriminator(fake_img)\r\n        g_loss = g_nonsaturating_loss(fake_pred)\r\n\r\n        loss_dict[\"g\"] = g_loss\r\n\r\n        generator.zero_grad()\r\n        g_loss.backward()\r\n        g_optim.step()\r\n\r\n        g_regularize = i % args.g_reg_every == 0\r\n\r\n        if g_regularize:\r\n            path_batch_size = max(1, args.batch // args.path_batch_shrink)\r\n            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)\r\n            fake_img, latents = generator(noise, return_latents=True)\r\n\r\n            path_loss, mean_path_length, path_lengths = g_path_regularize(\r\n                fake_img, latents, mean_path_length\r\n            )\r\n\r\n            generator.zero_grad()\r\n            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss\r\n\r\n            if args.path_batch_shrink:\r\n                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]\r\n\r\n            weighted_path_loss.backward()\r\n\r\n            g_optim.step()\r\n\r\n            mean_path_length_avg = (\r\n                reduce_sum(mean_path_length).item() / get_world_size()\r\n            )\r\n\r\n        loss_dict[\"path\"] = path_loss\r\n        loss_dict[\"path_length\"] = path_lengths.mean()\r\n\r\n        accumulate(g_ema, g_module, accum)\r\n\r\n        loss_reduced = reduce_loss_dict(loss_dict)\r\n\r\n        d_loss_val = loss_reduced[\"d\"].mean().item()\r\n        g_loss_val = 
loss_reduced[\"g\"].mean().item()\r\n        r1_val = loss_reduced[\"r1\"].mean().item()\r\n        path_loss_val = loss_reduced[\"path\"].mean().item()\r\n        real_score_val = loss_reduced[\"real_score\"].mean().item()\r\n        fake_score_val = loss_reduced[\"fake_score\"].mean().item()\r\n        path_length_val = loss_reduced[\"path_length\"].mean().item()\r\n\r\n        if get_rank() == 0:\r\n            pbar.set_description(\r\n                (\r\n                    f\"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; \"\r\n                    f\"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; \"\r\n                    f\"augment: {ada_aug_p:.4f}\"\r\n                )\r\n            )\r\n\r\n            if wandb and args.wandb:\r\n                wandb.log(\r\n                    {\r\n                        \"Generator\": g_loss_val,\r\n                        \"Discriminator\": d_loss_val,\r\n                        \"Augment\": ada_aug_p,\r\n                        \"Rt\": r_t_stat,\r\n                        \"R1\": r1_val,\r\n                        \"Path Length Regularization\": path_loss_val,\r\n                        \"Mean Path Length\": mean_path_length,\r\n                        \"Real Score\": real_score_val,\r\n                        \"Fake Score\": fake_score_val,\r\n                        \"Path Length\": path_length_val,\r\n                    }\r\n                )\r\n\r\n            if i % 100 == 0:\r\n                with torch.no_grad():\r\n                    g_ema.eval()\r\n                    sample, _ = g_ema([sample_z])\r\n                    utils.save_image(\r\n                        sample,\r\n                        f\"sample/{str(i).zfill(6)}.png\",\r\n                        nrow=int(args.n_sample ** 0.5),\r\n                        normalize=True,\r\n                        range=(-1, 1),\r\n                    )\r\n\r\n            if i % 10000 == 0:\r\n                torch.save(\r\n   
                 {\r\n                        \"g\": g_module.state_dict(),\r\n                        \"d\": d_module.state_dict(),\r\n                        \"g_ema\": g_ema.state_dict(),\r\n                        \"g_optim\": g_optim.state_dict(),\r\n                        \"d_optim\": d_optim.state_dict(),\r\n                        \"args\": args,\r\n                        \"ada_aug_p\": ada_aug_p,\r\n                    },\r\n                    f\"checkpoint/{str(i).zfill(6)}.pt\",\r\n                )\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    device = \"cuda\"\r\n\r\n    parser = argparse.ArgumentParser(description=\"StyleGAN2 trainer\")\r\n\r\n    parser.add_argument(\"path\", type=str, help=\"path to the lmdb dataset\")\r\n    parser.add_argument(\r\n        \"--iter\", type=int, default=800000, help=\"total training iterations\"\r\n    )\r\n    parser.add_argument(\r\n        \"--batch\", type=int, default=16, help=\"batch size for each GPU\"\r\n    )\r\n    parser.add_argument(\r\n        \"--n_sample\",\r\n        type=int,\r\n        default=64,\r\n        help=\"number of samples generated during training\",\r\n    )\r\n    parser.add_argument(\r\n        \"--size\", type=int, default=256, help=\"image size for the model\"\r\n    )\r\n    parser.add_argument(\r\n        \"--r1\", type=float, default=10, help=\"weight of the r1 regularization\"\r\n    )\r\n    parser.add_argument(\r\n        \"--path_regularize\",\r\n        type=float,\r\n        default=2,\r\n        help=\"weight of the path length regularization\",\r\n    )\r\n    parser.add_argument(\r\n        \"--path_batch_shrink\",\r\n        type=int,\r\n        default=2,\r\n        help=\"batch size reduction factor for the path length regularization (reduces memory consumption)\",\r\n    )\r\n    parser.add_argument(\r\n        \"--d_reg_every\",\r\n        type=int,\r\n        default=16,\r\n        help=\"interval for applying r1 regularization\",\r\n    )\r\n    
parser.add_argument(\r\n        \"--g_reg_every\",\r\n        type=int,\r\n        default=4,\r\n        help=\"interval for applying path length regularization\",\r\n    )\r\n    parser.add_argument(\r\n        \"--mixing\", type=float, default=0.9, help=\"probability of latent code mixing\"\r\n    )\r\n    parser.add_argument(\r\n        \"--ckpt\",\r\n        type=str,\r\n        default=None,\r\n        help=\"path to the checkpoint to resume training from\",\r\n    )\r\n    parser.add_argument(\"--lr\", type=float, default=0.002, help=\"learning rate\")\r\n    parser.add_argument(\r\n        \"--channel_multiplier\",\r\n        type=int,\r\n        default=2,\r\n        help=\"channel multiplier factor for the model. config-f = 2, else = 1\",\r\n    )\r\n    parser.add_argument(\r\n        \"--wandb\", action=\"store_true\", help=\"use Weights & Biases logging\"\r\n    )\r\n    parser.add_argument(\r\n        \"--local_rank\", type=int, default=0, help=\"local rank for distributed training\"\r\n    )\r\n    parser.add_argument(\r\n        \"--augment\", action=\"store_true\", help=\"apply non-leaking augmentation\"\r\n    )\r\n    parser.add_argument(\r\n        \"--augment_p\",\r\n        type=float,\r\n        default=0,\r\n        help=\"probability of applying augmentation. 
0 = use adaptive augmentation\",\r\n    )\r\n    parser.add_argument(\r\n        \"--ada_target\",\r\n        type=float,\r\n        default=0.6,\r\n        help=\"target augmentation probability for adaptive augmentation\",\r\n    )\r\n    parser.add_argument(\r\n        \"--ada_length\",\r\n        type=int,\r\n        default=500 * 1000,\r\n        help=\"duration to reach the target augmentation probability for adaptive augmentation\",\r\n    )\r\n    parser.add_argument(\r\n        \"--ada_every\",\r\n        type=int,\r\n        default=256,\r\n        help=\"probability update interval of the adaptive augmentation\",\r\n    )\r\n\r\n    args = parser.parse_args()\r\n\r\n    n_gpu = int(os.environ[\"WORLD_SIZE\"]) if \"WORLD_SIZE\" in os.environ else 1\r\n    args.distributed = n_gpu > 1\r\n\r\n    if args.distributed:\r\n        torch.cuda.set_device(args.local_rank)\r\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\r\n        synchronize()\r\n\r\n    args.latent = 512\r\n    args.n_mlp = 8\r\n\r\n    args.start_iter = 0\r\n\r\n    generator = Generator(\r\n        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier\r\n    ).to(device)\r\n    discriminator = Discriminator(\r\n        args.size, channel_multiplier=args.channel_multiplier\r\n    ).to(device)\r\n    g_ema = Generator(\r\n        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier\r\n    ).to(device)\r\n    g_ema.eval()\r\n    accumulate(g_ema, generator, 0)\r\n\r\n    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)\r\n    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)\r\n\r\n    g_optim = optim.Adam(\r\n        generator.parameters(),\r\n        lr=args.lr * g_reg_ratio,\r\n        betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),\r\n    )\r\n    d_optim = optim.Adam(\r\n        discriminator.parameters(),\r\n        lr=args.lr * d_reg_ratio,\r\n        betas=(0 ** d_reg_ratio, 0.99 ** 
d_reg_ratio),\r\n    )\r\n\r\n    if args.ckpt is not None:\r\n        print(\"load model:\", args.ckpt)\r\n\r\n        ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)\r\n\r\n        try:\r\n            ckpt_name = os.path.basename(args.ckpt)\r\n            args.start_iter = int(os.path.splitext(ckpt_name)[0])\r\n\r\n        except ValueError:\r\n            pass\r\n\r\n        generator.load_state_dict(ckpt[\"g\"])\r\n        discriminator.load_state_dict(ckpt[\"d\"])\r\n        g_ema.load_state_dict(ckpt[\"g_ema\"])\r\n\r\n        g_optim.load_state_dict(ckpt[\"g_optim\"])\r\n        d_optim.load_state_dict(ckpt[\"d_optim\"])\r\n\r\n    if args.distributed:\r\n        generator = nn.parallel.DistributedDataParallel(\r\n            generator,\r\n            device_ids=[args.local_rank],\r\n            output_device=args.local_rank,\r\n            broadcast_buffers=False,\r\n        )\r\n\r\n        discriminator = nn.parallel.DistributedDataParallel(\r\n            discriminator,\r\n            device_ids=[args.local_rank],\r\n            output_device=args.local_rank,\r\n            broadcast_buffers=False,\r\n        )\r\n\r\n    transform = transforms.Compose(\r\n        [\r\n            transforms.RandomHorizontalFlip(),\r\n            transforms.ToTensor(),\r\n            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),\r\n        ]\r\n    )\r\n\r\n    dataset = MultiResolutionDataset(args.path, transform, args.size)\r\n    loader = data.DataLoader(\r\n        dataset,\r\n        batch_size=args.batch,\r\n        sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),\r\n        drop_last=True,\r\n    )\r\n\r\n    if get_rank() == 0 and wandb is not None and args.wandb:\r\n        wandb.init(project=\"stylegan 2\")\r\n\r\n    train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)\r\n"
  },
  {
    "path": "models/base_model.py",
    "content": "import logging\nimport math\nfrom collections import OrderedDict\n\nimport cv2\nimport matplotlib.image as mpimg\nimport matplotlib.pyplot as plt\nimport torch\nimport torch.nn as nn\n\nfrom models.archs.attribute_predictor_arch import resnet50\nfrom models.archs.field_function_arch import FieldFunction\nfrom models.archs.stylegan2.model import Generator\nfrom models.losses.arcface_loss import ArcFaceLoss\nfrom models.losses.discriminator_loss import DiscriminatorLoss\nfrom models.utils import (postprocess, predictor_to_label, save_image,\n                          transform_image)\n\nlogger = logging.getLogger('base')\n\n\nclass BaseModel():\n    \"\"\"Base model.\n    \"\"\"\n\n    def __init__(self, opt):\n        self.opt = opt\n        self.device = torch.device('cuda')\n        self.is_train = opt['is_train']\n        self.target_attr_idx = opt['attr_dict'][opt['attribute']]\n\n        # define stylegan generator\n        self.stylegan_gen = Generator(\n            size=opt['img_res'],\n            style_dim=opt['latent_dim'],\n            n_mlp=opt['n_mlp'],\n            channel_multiplier=opt['channel_multiplier']).to(self.device)\n\n        self.truncation = 1.0\n        self.truncation_latent = None\n        self.randomize_noise = False\n        if opt['latent_space'] == 'z':\n            self.input_is_latent = False\n            self.latent_code_is_w_space = False\n        else:\n            self.input_is_latent = True\n            self.latent_code_is_w_space = True\n\n        self.transform_z_to_w = opt.get('transform_z_to_w', True)\n\n        if opt['img_res'] == 128:\n            self.w_space_channel_num = 12\n            logger.info(\n                f'Loading stylegan model from: {opt[\"generator_ckpt\"]}')\n            checkpoint = torch.load(opt['generator_ckpt'])\n            self.stylegan_gen.load_state_dict(checkpoint[\"g_ema\"], strict=True)\n            self.img_resize = False\n        elif opt['img_res'] == 1024:\n            
self.w_space_channel_num = 18\n            logger.info(\n                f'Loading stylegan model from: {opt[\"generator_ckpt\"]}')\n            checkpoint = torch.load(opt['generator_ckpt'])\n            self.stylegan_gen.load_state_dict(checkpoint, strict=True)\n            self.img_resize = True\n\n        # define attribute predictor\n        self.predictor = resnet50(attr_file=opt['attr_file'])\n        self.predictor = self.predictor.to(self.device)\n\n        logger.info(f'Loading model from: {opt[\"predictor_ckpt\"]}')\n        checkpoint = torch.load(opt['predictor_ckpt'])\n        self.predictor.load_state_dict(checkpoint['state_dict'], strict=True)\n        self.predictor.eval()\n\n        # define field function\n        self.field_function = FieldFunction(\n            num_layer=opt['num_layer'],\n            latent_dim=512,\n            hidden_dim=opt['hidden_dim'],\n            leaky_relu_neg_slope=opt['leaky_relu_neg_slope'])\n\n        self.field_function = self.field_function.to(self.device)\n\n        self.fix_layers = False\n        if self.is_train:\n            self.init_training_settings()\n            self.log_dict = OrderedDict()\n\n    def init_training_settings(self):\n        # set up optimizers\n        self.optimizer = torch.optim.Adam(\n            self.field_function.parameters(),\n            self.opt['lr'],\n            weight_decay=self.opt['weight_decay'])\n\n        # define loss functions\n        # predictor loss\n        self.criterion_predictor = nn.CrossEntropyLoss(reduction='mean')\n\n        # arcface loss\n        if self.opt['arcface_weight'] > 0:\n            self.criterion_arcface = ArcFaceLoss(\n                self.opt['pretrained_arcface'], self.opt['arcface_loss_type'])\n        else:\n            self.criterion_arcface = None\n\n        # discriminator loss\n        if self.opt['disc_weight'] > 0:\n            self.criterion_disc = DiscriminatorLoss(\n                self.opt['discriminator_ckpt'], 
self.opt['img_res'])\n        else:\n            self.criterion_disc = None\n\n    def feed_data(self, data):\n        self.original_latent_code = data[0].to(self.device)\n        self.original_label = data[1].to(self.device)\n        self.gt_label = self.original_label.clone()\n        self.gt_label[:, self.target_attr_idx] = \\\n            self.gt_label[:, self.target_attr_idx] + 1\n\n    def optimize_parameters(self):\n        self.field_function.train()\n\n        if self.latent_code_is_w_space and self.transform_z_to_w:\n            # translate original z space latent code to w space\n            with torch.no_grad():\n                original_latent_code = self.stylegan_gen.get_latent(\n                    self.original_latent_code)\n        else:\n            original_latent_code = self.original_latent_code\n\n        # modify latent code via field function\n        edited_dict = self.modify_latent_code(original_latent_code)\n        edited_image = self.synthesize_image(edited_dict['edited_latent_code'])\n        predictor_output = self.predictor(\n            transform_image(edited_image, self.img_resize))\n\n        # compute loss function\n        loss_total = 0\n\n        assert self.opt['num_attr'] == len(predictor_output)\n        loss_list = []\n\n        # iterate over each attribute\n        for attr_idx in range(self.opt['num_attr']):\n            loss_attr = self.criterion_predictor(predictor_output[attr_idx],\n                                                 self.gt_label[:, attr_idx])\n            if attr_idx == self.target_attr_idx:\n                loss_attr = loss_attr * self.opt['edited_attribute_weight']\n            loss_list.append(loss_attr)\n        predictor_loss = sum(loss_list) / len(loss_list)\n        self.log_dict['predictor_loss'] = predictor_loss\n\n        loss_total += predictor_loss\n\n        if self.criterion_arcface is not None:\n            original_image = self.synthesize_image(original_latent_code)\n            
arcface_loss = self.criterion_arcface(original_image, edited_image,\n                                                  self.img_resize)\n            loss_total += self.opt['arcface_weight'] * arcface_loss\n            self.log_dict['arcface_loss'] = arcface_loss\n\n        if self.opt['disc_weight'] > 0:\n            disc_loss = self.criterion_disc(edited_image)\n            loss_total += disc_loss * self.opt['disc_weight']\n            self.log_dict['disc_loss'] = disc_loss\n\n        self.optimizer.zero_grad()\n        loss_total.backward()\n        self.optimizer.step()\n\n        self.log_dict['loss_total'] = loss_total\n\n    def get_current_log(self):\n        return self.log_dict\n\n    def update_learning_rate(self, epoch):\n        \"\"\"Update learning rate.\n\n        Args:\n            epoch (int): Current epoch.\n        \"\"\"\n        lr = self.optimizer.param_groups[0]['lr']\n\n        if self.opt['lr_decay'] == 'step':\n            lr = self.opt['lr'] * (\n                self.opt['gamma']**(epoch // self.opt['step']))\n        elif self.opt['lr_decay'] == 'cos':\n            lr = self.opt['lr'] * (\n                1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2\n        elif self.opt['lr_decay'] == 'linear':\n            lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])\n        elif self.opt['lr_decay'] == 'linear2exp':\n            if epoch < self.opt['turning_point'] + 1:\n                # linearly decay the lr to 5% of its initial value\n                # at the turning point (1 / 0.95 = 1.0526)\n                lr = self.opt['lr'] * (\n                    1 - epoch / int(self.opt['turning_point'] * 1.0526))\n            else:\n                lr *= self.opt['gamma']\n        elif self.opt['lr_decay'] == 'schedule':\n            if epoch in self.opt['schedule']:\n                lr *= self.opt['gamma']\n        else:\n            
raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))\n        # set learning rate\n        for param_group in self.optimizer.param_groups:\n            param_group['lr'] = lr\n\n        return lr\n\n    def save_network(self, net, save_path):\n        \"\"\"Save networks.\n\n        Args:\n            net (nn.Module): Network to be saved.\n            save_path (str): Path to save the network weights.\n        \"\"\"\n        state_dict = net.state_dict()\n        torch.save(state_dict, save_path)\n\n    def load_network(self, pretrained_field):\n        checkpoint = torch.load(pretrained_field)\n\n        self.field_function.load_state_dict(checkpoint, strict=True)\n        self.field_function.eval()\n\n    def synthesize_image(self, sample_latent_code):\n        synthesized_img, _ = self.stylegan_gen(\n            [sample_latent_code],\n            truncation=self.truncation,\n            input_is_latent=self.input_is_latent,\n            truncation_latent=self.truncation_latent,\n            randomize_noise=self.randomize_noise)\n\n        return synthesized_img\n\n    def synthesize_and_predict(self, sample_latent_code):\n        synthesized_img = self.synthesize_image(sample_latent_code)\n\n        current_predictor_output = self.predictor(\n            transform_image(synthesized_img, self.img_resize))\n        predicted_label, predicted_score = predictor_to_label(\n            current_predictor_output)\n\n        return synthesized_img, predicted_label, predicted_score\n\n    def inference(self, batch_idx, epoch, save_dir):\n        self.field_function.eval()\n\n        assert self.original_latent_code.size()[0] == 1\n\n        if self.latent_code_is_w_space and self.transform_z_to_w:\n            # translate original z space latent code to w space\n            with torch.no_grad():\n                original_latent_code = self.stylegan_gen.get_latent(\n                    self.original_latent_code)\n        
else:\n            original_latent_code = self.original_latent_code\n\n        with torch.no_grad():\n            original_image = self.synthesize_image(original_latent_code)\n        original_image = postprocess(original_image.cpu().detach().numpy())\n\n        # field function mapping\n        with torch.no_grad():\n            return_dict = self.modify_latent_code(original_latent_code)\n\n        with torch.no_grad():\n            edited_image, edited_label, _ = self.synthesize_and_predict(\n                return_dict['edited_latent_code'])\n\n        edited_image = postprocess(edited_image.cpu().detach().numpy())\n        concat_images = cv2.hconcat([original_image[0], edited_image[0]])\n        save_image(\n            concat_images,\n            f'{save_dir}/{batch_idx:03d}_epoch_{epoch:03d}_{self.opt[\"exp_name\"]}_original_{self.original_label[0][self.target_attr_idx]}_edited_{edited_label[self.target_attr_idx]}.png',  # noqa\n            need_post_process=False)\n\n        self.field_function.train()\n\n    def continuous_editing(self, latent_codes, save_dir, editing_logger):\n        total_num = latent_codes.shape[0]\n\n        for sample_id in range(total_num):\n            sample_latent_code = torch.from_numpy(\n                latent_codes[sample_id:sample_id + 1]).to(\n                    torch.device('cuda'))\n\n            if self.latent_code_is_w_space and self.transform_z_to_w:\n                # translate original z space latent code to w space\n                with torch.no_grad():\n                    sample_latent_code = self.stylegan_gen.get_latent(\n                        sample_latent_code)\n\n            # synthesize\n            with torch.no_grad():\n                original_image, start_label, start_score = \\\n                    self.synthesize_and_predict(sample_latent_code)\n\n            target_attr_label = int(start_label[self.target_attr_idx])\n            target_score = start_score[self.target_attr_idx]\n\n            
save_name = f'{sample_id:03d}_num_edits_0_class_{target_attr_label}.png'  # noqa\n            save_image(original_image, f'{save_dir}/{save_name}')\n\n            editing_logger.info(f'{save_name}: {start_label}, {start_score}')\n            # skip images with low confidence\n            if target_score < self.opt['confidence_thresh']:\n                editing_logger.info(\n                    f'Sample {sample_id:03d} is not confident, skip.')\n                continue\n\n            # skip images that are already the max_cls_num\n            if target_attr_label == self.opt['max_cls_num']:\n                editing_logger.info(\n                    f'Sample {sample_id:03d} is already the max_cls_num, skip.'\n                )\n                continue\n\n            num_trials = 0\n            num_edits = 0\n\n            current_stage_scores_list = []\n            current_stage_labels_list = []\n            current_stage_images_list = []\n            current_stage_target_scores_list = []\n\n            previous_target_attr_label = target_attr_label\n\n            if self.fix_layers:\n                edited_latent_code = sample_latent_code.unsqueeze(1).repeat(\n                    1, self.w_space_channel_num, 1)\n\n            while target_attr_label < self.opt['max_cls_num']:\n                num_trials += 1\n                with torch.no_grad():\n                    # modify sampled latent code\n                    if self.fix_layers:\n                        # for fix layers, the input to the field_function is w\n                        # space, but the input to the stylegan is w plus space\n                        edited_dict = self.modify_latent_code(\n                            sample_latent_code, edited_latent_code)\n                        sample_latent_code = sample_latent_code + edited_dict[\n                            'field']\n                        edited_latent_code = edited_dict['edited_latent_code']\n                    else:\n                    
    # for other modes, the input to the field function and\n                        # stylegan are same (both w space or z space)\n                        edited_dict = self.modify_latent_code(\n                            sample_latent_code)\n                        sample_latent_code = edited_dict['edited_latent_code']\n\n                    edited_image, edited_label, edited_score = \\\n                        self.synthesize_and_predict(edited_dict['edited_latent_code']) # noqa\n\n                target_attr_label = edited_label[self.target_attr_idx]\n                target_attr_score = edited_score[self.target_attr_idx]\n                if target_attr_label != previous_target_attr_label:\n                    num_edits += 1\n\n                if num_edits > 0:\n                    if target_attr_label == previous_target_attr_label:\n                        current_stage_images_list.append(edited_image)\n                        current_stage_labels_list.append(edited_label)\n                        current_stage_scores_list.append(edited_score)\n                        current_stage_target_scores_list.append(\n                            target_attr_score)\n                    else:\n                        if num_edits > 1:\n                            # save images for previous stage\n                            max_value = max(current_stage_target_scores_list)\n                            max_index = current_stage_target_scores_list.index(\n                                max_value)\n                            saved_image = current_stage_images_list[max_index]\n                            saved_label = current_stage_labels_list[max_index]\n                            saved_score = current_stage_scores_list[max_index]\n                            save_name = f'{sample_id:03d}_num_edits_{num_edits-1}_class_{previous_target_attr_label}.png'  # noqa\n                            save_image(saved_image, f'{save_dir}/{save_name}')\n                            
editing_logger.info(\n                                f'{save_name}: {saved_label}, {saved_score}')\n\n                        current_stage_images_list = []\n                        current_stage_labels_list = []\n                        current_stage_scores_list = []\n                        current_stage_target_scores_list = []\n                        num_trials = 0\n\n                        current_stage_images_list.append(edited_image)\n                        current_stage_labels_list.append(edited_label)\n                        current_stage_scores_list.append(edited_score)\n                        current_stage_target_scores_list.append(\n                            target_attr_score)\n\n                previous_target_attr_label = target_attr_label\n                if self.opt['print_every']:\n                    save_name = f'{sample_id:03d}_num_edits_{num_edits}_num_trials_{num_trials}_class_{target_attr_label}.png'  # noqa\n                    save_image(edited_image, f'{save_dir}/{save_name}')\n                    editing_logger.info(\n                        f'{save_name}: {edited_label}, {edited_score}')\n\n                if num_trials > self.opt['max_trials_num']:\n                    editing_logger.info('Maximum number of trials reached.')\n                    break\n\n            if num_edits > 0:\n                # save images for previous stage\n                max_value = max(current_stage_target_scores_list)\n                max_index = current_stage_target_scores_list.index(max_value)\n                saved_image = current_stage_images_list[max_index]\n                saved_label = current_stage_labels_list[max_index]\n                saved_score = current_stage_scores_list[max_index]\n                save_name = f'{sample_id:03d}_num_edits_{num_edits}_class_{previous_target_attr_label}.png'  # noqa\n                save_image(saved_image, f'{save_dir}/{save_name}')\n                editing_logger.info(\n                    f'{save_name}: 
{saved_label}, {saved_score}')\n\n            editing_logger.info(f'{sample_id:03d}: Finish editing.')\n\n    def continuous_editing_with_target(self,\n                                       latent_codes,\n                                       target_cls,\n                                       save_dir,\n                                       editing_logger,\n                                       edited_latent_code,\n                                       prefix,\n                                       print_intermediate_result=False,\n                                       display_img=False):\n        total_num = latent_codes.shape[0]\n\n        for sample_id in range(total_num):\n\n            sample_latent_code = torch.from_numpy(\n                latent_codes[sample_id:sample_id + 1]).to(\n                    torch.device('cuda'))\n            start_latent_codes = sample_latent_code\n            start_edited_latent_code = edited_latent_code\n\n            exception_mode = 'normal'\n\n            # synthesize\n            if edited_latent_code is None:\n                if self.latent_code_is_w_space and self.transform_z_to_w:\n                    # translate original z space latent code to w space\n                    with torch.no_grad():\n                        sample_latent_code = self.stylegan_gen.get_latent(\n                            sample_latent_code)\n\n                with torch.no_grad():\n                    original_image, start_label, start_score = \\\n                        self.synthesize_and_predict(sample_latent_code)\n            else:\n                with torch.no_grad():\n                    original_image, start_label, start_score = \\\n                        self.synthesize_and_predict(edited_latent_code)\n\n            target_attr_label = int(start_label[self.target_attr_idx])\n            target_score = start_score[self.target_attr_idx]\n\n            # save_name = 
f'{prefix}_{sample_id:03d}_num_edits_0_class_{target_attr_label}_attr_idx_{self.target_attr_idx}.png'  # noqa\n            ### save_image(original_image, f'{save_dir}/{save_name}')\n\n            # editing_logger.info(f'{save_name}: {start_label}, {start_score}')\n            # skip images with low confidence\n            if target_score < self.opt['confidence_thresh']:\n                if editing_logger:\n                    editing_logger.info(\n                        f'Sample {sample_id:03d} is not confident, skip.')\n                continue\n\n            # skip images that are already at the target class\n            if target_attr_label == target_cls:\n                if editing_logger:\n                    editing_logger.info(\n                        f'Sample {sample_id:03d} is already at the target class, skip.'\n                    )\n                # return exactly the input image and input latent codes\n                saved_label = start_label\n                saved_latent_code = start_latent_codes\n                saved_editing_latent_code = start_edited_latent_code\n                saved_score = start_score\n                # save_name = f'{prefix}_{sample_id:03d}_num_edits_1_class_{target_attr_label}_attr_idx_{self.target_attr_idx}.png'  # noqa\n                ### save_image(original_image, f'{save_dir}/{save_name}')\n                # editing_logger.info(\n                #     f'{save_name}: {saved_label}, {saved_score}')\n                exception_mode = 'already_at_target_class'\n                continue\n            elif target_attr_label < target_cls:\n                direction = 'positive'\n                alpha = 1\n            elif target_attr_label > target_cls:\n                direction = 'negative'\n                alpha = -1\n\n            num_trials = 0\n            num_edits = 0\n\n            current_stage_scores_list = []\n            current_stage_labels_list = []\n            current_stage_images_list = []\n            current_stage_target_scores_list = []\n            current_stage_latent_code_list = []\n            current_stage_editing_latent_code_list = []\n\n            previous_target_attr_label = target_attr_label\n\n            if self.fix_layers:\n                if edited_latent_code is None:\n                    edited_latent_code = sample_latent_code.unsqueeze(\n                        1).repeat(1, self.w_space_channel_num, 1)\n\n            while ((direction == 'positive') and\n                   (target_attr_label <= target_cls) and\n                   (target_attr_label < self.opt['max_cls_num'])) or (\n                       (direction == 'negative') and\n                       (target_attr_label >= target_cls) and\n                       (target_attr_label > self.opt['min_cls_num'])):\n                num_trials += 1\n                with torch.no_grad():\n                    # modify sampled latent code\n                    if self.fix_layers:\n                        # for fix layers, the input to the field_function is w\n                        # space, but the input to the stylegan is w plus space\n                        edited_dict = self.modify_latent_code_bidirection(\n                            sample_latent_code, edited_latent_code, alpha)\n                        sample_latent_code = sample_latent_code + alpha * edited_dict[\n                            'field']\n                        edited_latent_code = edited_dict['edited_latent_code']\n                    else:\n                        # for other modes, the input to the field function and\n                        # stylegan are same (both w space or z space)\n                        edited_dict = self.modify_latent_code_bidirection(\n                            latent_code_w=sample_latent_code, alpha=alpha)\n                        sample_latent_code = edited_dict['edited_latent_code']\n\n                    edited_image, edited_label, edited_score = \\\n                        
self.synthesize_and_predict(edited_dict['edited_latent_code']) # noqa\n\n                target_attr_label = edited_label[self.target_attr_idx]\n                target_attr_score = edited_score[self.target_attr_idx]\n\n                if ((direction == 'positive') and\n                    (target_attr_label > target_cls)) or (\n                        (direction == 'negative') and\n                        (target_attr_label < target_cls)):\n                    if num_edits == 0:\n                        saved_label = edited_label\n                        saved_latent_code = sample_latent_code\n                        saved_editing_latent_code = edited_latent_code\n                        save_name = f'{prefix}_{sample_id:03d}_num_edits_{num_edits+1}_class_{target_attr_label}_attr_idx_{self.target_attr_idx}.png'  # noqa\n                        saved_image = edited_image\n                        saved_score = edited_score\n                        save_image(saved_image, f'{save_dir}/{save_name}')\n                        if display_img:\n                            plt.figure()\n                            plt.imshow(mpimg.imread(f'{save_dir}/{save_name}'))\n                            plt.axis('off')\n                            plt.show()\n                        if editing_logger:\n                            editing_logger.info(\n                                f'{save_name}: {saved_label}, {saved_score}')\n\n                    break\n\n                if target_attr_label != previous_target_attr_label:\n                    num_edits += 1\n\n                if num_edits > 0:\n                    if target_attr_label == previous_target_attr_label:\n                        current_stage_images_list.append(edited_image)\n                        current_stage_labels_list.append(edited_label)\n                        current_stage_scores_list.append(edited_score)\n                        current_stage_target_scores_list.append(\n                            
target_attr_score)\n                        current_stage_latent_code_list.append(\n                            sample_latent_code)\n                        current_stage_editing_latent_code_list.append(\n                            edited_latent_code)\n                    else:\n                        if num_edits > 1:\n                            # save images for previous stage\n                            max_value = max(current_stage_target_scores_list)\n                            max_index = current_stage_target_scores_list.index(\n                                max_value)\n                            saved_image = current_stage_images_list[max_index]\n                            saved_label = current_stage_labels_list[max_index]\n                            saved_score = current_stage_scores_list[max_index]\n                            saved_latent_code = current_stage_latent_code_list[\n                                max_index]\n                            saved_editing_latent_code = current_stage_editing_latent_code_list[\n                                max_index]\n                            save_name = f'{prefix}_{sample_id:03d}_num_edits_{num_edits-1}_class_{previous_target_attr_label}_attr_idx_{self.target_attr_idx}.png'  # noqa\n                            if print_intermediate_result:\n                                save_image(saved_image,\n                                           f'{save_dir}/{save_name}')\n                            if editing_logger:\n                                editing_logger.info(\n                                    f'{save_name}: {saved_label}, {saved_score}'\n                                )\n\n                        current_stage_images_list = []\n                        current_stage_labels_list = []\n                        current_stage_scores_list = []\n                        current_stage_target_scores_list = []\n                        current_stage_latent_code_list = []\n                        
current_stage_editing_latent_code_list = []\n                        num_trials = 0\n\n                        current_stage_images_list.append(edited_image)\n                        current_stage_labels_list.append(edited_label)\n                        current_stage_scores_list.append(edited_score)\n                        current_stage_target_scores_list.append(\n                            target_attr_score)\n                        current_stage_latent_code_list.append(\n                            sample_latent_code)\n                        current_stage_editing_latent_code_list.append(\n                            edited_latent_code)\n\n                previous_target_attr_label = target_attr_label\n\n                if num_trials > self.opt['max_trials_num']:\n                    if num_edits == 0:\n                        saved_label = start_label\n                        saved_latent_code = start_latent_codes\n                        saved_editing_latent_code = start_edited_latent_code\n                        saved_score = start_score\n                        # save_name = f'{prefix}_{sample_id:03d}_num_edits_1_class_{target_attr_label}_attr_idx_{self.target_attr_idx}.png'  # noqa\n                        ### save_image(original_image, f'{save_dir}/{save_name}')\n                        # if editing_logger:\n                        #     editing_logger.info(\n                        #         f'{save_name}: {saved_label}, {saved_score}')\n                        exception_mode = 'max_edit_num_reached'\n                    break\n\n            if num_edits > 0:\n                # save images for previous stage\n                max_value = max(current_stage_target_scores_list)\n                max_index = current_stage_target_scores_list.index(max_value)\n                saved_image = current_stage_images_list[max_index]\n                saved_label = current_stage_labels_list[max_index]\n                saved_score = 
current_stage_scores_list[max_index]\n                saved_latent_code = current_stage_latent_code_list[max_index]\n                saved_editing_latent_code = current_stage_editing_latent_code_list[\n                    max_index]\n                save_name = f'{prefix}_{sample_id:03d}_num_edits_{num_edits}_class_{previous_target_attr_label}_attr_idx_{self.target_attr_idx}.png'  # noqa\n                save_image(saved_image, f'{save_dir}/{save_name}')\n                if display_img:\n                    plt.figure()\n                    plt.imshow(mpimg.imread(f'{save_dir}/{save_name}'))\n                    plt.axis('off')\n                    plt.show()\n                if editing_logger:\n                    editing_logger.info(\n                        f'{save_name}: {saved_label}, {saved_score}')\n\n        return saved_latent_code, saved_editing_latent_code, saved_label, exception_mode\n"
  },
  {
    "path": "models/field_function_model.py",
    "content": "import logging\n\nimport torch\n\nfrom models.base_model import BaseModel\n\nlogger = logging.getLogger('base')\n\n\nclass FieldFunctionModel(BaseModel):\n\n    def __init__(self, opt):\n        super(FieldFunctionModel, self).__init__(opt)\n        self.replaced_layers = opt['replaced_layers']\n        self.fix_layers = True\n\n    def modify_latent_code(self, latent_code_w, latent_code_w_plus=None):\n        assert self.input_is_latent\n\n        return_dict = {}\n        # field function mapping\n        field = self.field_function(latent_code_w)\n        with torch.no_grad():\n            offset_w = self.stylegan_gen.style_forward(\n                torch.zeros_like(field), skip_norm=True)\n        delta_w = self.stylegan_gen.style_forward(\n            field, skip_norm=True) - offset_w\n\n        if latent_code_w_plus is None:\n            edited_latent_code = latent_code_w.unsqueeze(1).repeat(\n                1, self.w_space_channel_num, 1)\n        else:\n            edited_latent_code = latent_code_w_plus.clone()\n            return_dict['field'] = delta_w\n\n        for layer_idx in range(self.replaced_layers):\n            edited_latent_code[:, layer_idx, :] += delta_w\n\n        return_dict['edited_latent_code'] = edited_latent_code\n        return return_dict\n\n    def modify_latent_code_bidirection(self,\n                                       latent_code_w,\n                                       latent_code_w_plus=None,\n                                       alpha=1):\n        assert self.input_is_latent\n\n        return_dict = {}\n        # field function mapping\n        field = self.field_function(latent_code_w)\n        with torch.no_grad():\n            offset_w = self.stylegan_gen.style_forward(\n                torch.zeros_like(field), skip_norm=True)\n        delta_w = self.stylegan_gen.style_forward(\n            field, skip_norm=True) - offset_w\n\n        if latent_code_w_plus is None:\n            edited_latent_code = 
latent_code_w.unsqueeze(1).repeat(\n                1, self.w_space_channel_num, 1)\n        else:\n            edited_latent_code = latent_code_w_plus.clone()\n            return_dict['field'] = delta_w\n\n        for layer_idx in range(self.replaced_layers):\n            edited_latent_code[:, layer_idx, :] += alpha * delta_w\n\n        return_dict['edited_latent_code'] = edited_latent_code\n        return return_dict\n"
  },
  {
    "path": "models/losses/__init__.py",
    "content": ""
  },
  {
    "path": "models/losses/arcface_loss.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(\n        in_planes,\n        out_planes,\n        kernel_size=3,\n        stride=stride,\n        padding=1,\n        bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(BasicBlock, self).__init__()\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass IRBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self,\n                 inplanes,\n                 planes,\n                 stride=1,\n                 downsample=None,\n                 use_se=True):\n        super(IRBlock, self).__init__()\n        self.bn0 = nn.BatchNorm2d(inplanes)\n        self.conv1 = conv3x3(inplanes, inplanes)\n        self.bn1 = nn.BatchNorm2d(inplanes)\n        self.prelu = nn.PReLU()\n        self.conv2 = conv3x3(inplanes, planes, stride)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.downsample = downsample\n        self.stride = stride\n        self.use_se = use_se\n        if self.use_se:\n            self.se = SEBlock(planes)\n\n    def forward(self, x):\n        residual = x\n        out = self.bn0(x)\n        out = self.conv1(out)\n        out = 
self.bn1(out)\n        out = self.prelu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        if self.use_se:\n            out = self.se(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.prelu(out)\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(Bottleneck, self).__init__()\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(\n            planes,\n            planes,\n            kernel_size=3,\n            stride=stride,\n            padding=1,\n            bias=False)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.conv3 = nn.Conv2d(\n            planes, planes * self.expansion, kernel_size=1, bias=False)\n        self.bn3 = nn.BatchNorm2d(planes * self.expansion)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass SEBlock(nn.Module):\n\n    def __init__(self, channel, reduction=16):\n        super(SEBlock, self).__init__()\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.fc = nn.Sequential(\n            nn.Linear(channel, channel // reduction), nn.PReLU(),\n            nn.Linear(channel // reduction, channel), nn.Sigmoid())\n\n    def forward(self, x):\n        b, c, _, _ = x.size()\n        y = 
self.avg_pool(x).view(b, c)\n        y = self.fc(y).view(b, c, 1, 1)\n        return x * y\n\n\nclass ResNetFace(nn.Module):\n\n    def __init__(self, block, layers, use_se=True):\n        self.inplanes = 64\n        self.use_se = use_se\n        super(ResNetFace, self).__init__()\n        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(64)\n        self.prelu = nn.PReLU()\n        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n        self.bn4 = nn.BatchNorm2d(512)\n        self.dropout = nn.Dropout()\n        self.fc5 = nn.Linear(512 * 8 * 8, 512)\n        self.bn5 = nn.BatchNorm1d(512)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.xavier_normal_(m.weight)\n            elif isinstance(m, nn.BatchNorm2d) or isinstance(\n                    m, nn.BatchNorm1d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.xavier_normal_(m.weight)\n                nn.init.constant_(m.bias, 0)\n\n    def _make_layer(self, block, planes, blocks, stride=1):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(\n                    self.inplanes,\n                    planes * block.expansion,\n                    kernel_size=1,\n                    stride=stride,\n                    bias=False),\n                nn.BatchNorm2d(planes * block.expansion),\n            )\n        layers = []\n        layers.append(\n            block(\n                self.inplanes, 
planes, stride, downsample, use_se=self.use_se))\n        self.inplanes = planes\n        for i in range(1, blocks):\n            layers.append(block(self.inplanes, planes, use_se=self.use_se))\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.prelu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n        x = self.bn4(x)\n        x = self.dropout(x)\n        x = x.view(x.size(0), -1)\n        x = self.fc5(x)\n        x = self.bn5(x)\n\n        return x\n\n\ndef resnet_face18(use_se=True, **kwargs):\n    model = ResNetFace(IRBlock, [2, 2, 2, 2], use_se=use_se, **kwargs)\n    return model\n\n\nclass ArcFaceLoss(nn.Module):\n\n    def __init__(self, pretrained_model, loss_type, use_se=False):\n        super(ArcFaceLoss, self).__init__()\n        self.model = resnet_face18(use_se=use_se)\n        self.model = nn.DataParallel(self.model)\n        self.model.load_state_dict(torch.load(pretrained_model), strict=True)\n        self.model.to(torch.device('cuda'))\n        self.model.eval()\n\n        self.loss_type = loss_type\n\n        if self.loss_type == 'l1':\n            self.loss_func = nn.L1Loss(reduction='mean')\n        elif self.loss_type == 'l2':\n            self.loss_func = nn.MSELoss(reduction='mean')\n        elif self.loss_type == 'cosine':\n            self.loss_func = nn.CosineEmbeddingLoss(reduction='mean')\n        else:\n            raise NotImplementedError\n\n    def forward(self, original_imgs, edited_imgs, resize=False):\n        # the image range should be [-1, 1], and convert image to grayscale\n        if resize:\n            # need to resize image to [128, 128]\n            original_features = self.model(\n                F.interpolate(original_imgs, (128, 128),\n                              mode='area').mean(dim=1, keepdim=True))\n            edited_features 
= self.model(\n                F.interpolate(edited_imgs, (128, 128),\n                              mode='area').mean(dim=1, keepdim=True))\n        else:\n            # the image range should be [-1, 1], and convert image to grayscale\n            original_features = self.model(\n                original_imgs.mean(dim=1, keepdim=True))\n            edited_features = self.model(edited_imgs.mean(dim=1, keepdim=True))\n\n        if self.loss_type == 'l1' or self.loss_type == 'l2':\n            loss = self.loss_func(original_features, edited_features)\n        elif self.loss_type == 'cosine':\n            target = torch.ones(original_features.size(0)).to(\n                torch.device('cuda'))\n            loss = self.loss_func(original_features, edited_features, target)\n        else:\n            raise NotImplementedError\n\n        return loss\n"
  },
  {
    "path": "models/losses/discriminator_loss.py",
    "content": "import torch\nimport torch.nn as nn\nfrom models.archs.stylegan2.model import Discriminator\nfrom torch.nn import functional as F\n\n\nclass DiscriminatorLoss(nn.Module):\n\n    def __init__(self, pretrained_model, img_res):\n        super(DiscriminatorLoss, self).__init__()\n        if img_res == 128:\n            self.discriminator = Discriminator(\n                size=img_res, channel_multiplier=1)\n            self.discriminator.load_state_dict(\n                torch.load(pretrained_model)['d'], strict=True)\n        elif img_res == 1024:\n            self.discriminator = Discriminator(\n                size=img_res, channel_multiplier=2)\n            self.discriminator.load_state_dict(\n                torch.load(pretrained_model), strict=True)\n        self.discriminator.to(torch.device('cuda'))\n        self.discriminator.eval()\n\n    def forward(self, generated_images):\n        generated_pred = self.discriminator(generated_images)\n        loss = F.softplus(-generated_pred).mean()\n\n        return loss\n"
  },
  {
    "path": "models/utils.py",
    "content": "import random\n\nimport cv2\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\n\ndef postprocess(images, channel_order='BGR', min_val=-1.0, max_val=1.0):\n    \"\"\"Postprocesses the output images if needed.\n\n        This function assumes the input numpy array is with shape [batch_size,\n        channel, height, width]. Here, `channel = 3` for color image and\n        `channel = 1` for grayscale image. The return images are with shape\n        [batch_size, height, width, channel]. NOTE: The channel order of output\n        image will always be `RGB`.\n\n        Args:\n          images: The raw output from the generator.\n\n        Returns:\n          The postprocessed images with dtype `numpy.uint8` with range\n            [0, 255].\n\n        Raises:\n          ValueError: If the input `images` are not with type `numpy.ndarray`\n            or not with shape [batch_size, channel, height, width].\n        \"\"\"\n    if not isinstance(images, np.ndarray):\n        raise ValueError('Images should be with type `numpy.ndarray`!')\n\n    images_shape = images.shape\n    if len(images_shape) != 4 or images_shape[1] not in [1, 3]:\n        raise ValueError(f'Input should be with shape [batch_size, channel, '\n                         f'height, width], where channel equals to 1 or 3. '\n                         f'But {images_shape} is received!')\n    images = (images - min_val) * 255 / (max_val - min_val)\n    images = np.clip(images + 0.5, 0, 255).astype(np.uint8)\n    images = images.transpose(0, 2, 3, 1)\n    if channel_order == 'BGR':\n        images = images[:, :, :, ::-1]\n\n    return images\n\n\ndef transform_image(image, resize=False):\n    # transform image range to [0, 1]\n    image = (image + 1) * 255 / 2\n    # TODO: int()? 
quantization?\n    image = torch.clamp(image + 0.5, 0, 255)\n    image = image / 255.\n    if resize:\n        image = F.interpolate(image, (128, 128), mode='area')\n\n    # normalize image to imagenet range\n    img_mean = torch.Tensor([0.485, 0.456,\n                             0.406]).view(1, 3, 1, 1).to(torch.device('cuda'))\n    img_std = torch.Tensor([0.229, 0.224,\n                            0.225]).view(1, 3, 1, 1).to(torch.device('cuda'))\n    image = (image - img_mean) / img_std\n\n    return image\n\n\ndef set_random_seed(seed):\n    \"\"\"Set random seeds.\"\"\"\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n\n\ndef output_to_label(output):\n    \"\"\"\n    INPUT\n    - output: [num_attr, batch_size, num_classes]\n    OUTPUT\n    - scores: [num_attr, batch_size, num_classes] (softmaxed)\n    - label: [num_attr, batch_size]\n    \"\"\"\n    scores = []\n    labels = []\n    for attr_idx in range(len(output)):\n        _, label = torch.max(input=output[attr_idx], dim=1)\n        label = label.cpu().numpy()[0]\n        labels.append(label)\n\n        score_per_attr = output[attr_idx].cpu().numpy()[0]\n        # softmax\n        score_per_attr = (np.exp(score_per_attr) /\n                          np.sum(np.exp(score_per_attr)))[label]\n        scores.append(score_per_attr)\n\n    scores = torch.FloatTensor(scores)\n    labels = torch.LongTensor(labels)\n\n    return labels, scores\n\n\ndef predictor_to_label(predictor_output):\n\n    scores = []\n    labels = []\n    for attr_idx in range(len(predictor_output)):\n        _, label = torch.max(input=predictor_output[attr_idx], dim=1)\n        label = label.cpu().numpy()[0]\n        labels.append(label)\n\n        score_per_attr = predictor_output[attr_idx].cpu().numpy()[0]\n        # softmax\n        score_per_attr = (np.exp(score_per_attr) /\n                          
np.sum(np.exp(score_per_attr)))[label]\n        scores.append(score_per_attr)\n\n    return labels, scores\n\n\ndef save_image(img, save_path, need_post_process=True):\n    if need_post_process:\n        cv2.imwrite(save_path, postprocess(img.cpu().detach().numpy())[0])\n    else:\n        cv2.imwrite(save_path, img)\n"
  },
  {
    "path": "quantitative_results.py",
    "content": "import argparse\nimport glob\nimport logging\n\nimport cv2\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torchvision.transforms as transforms\nfrom facenet_pytorch import InceptionResnetV1\nfrom PIL import Image\n\nfrom models.archs.attribute_predictor_arch import resnet50\nfrom models.utils import output_to_label\nfrom utils.logger import get_root_logger\nfrom utils.options import dict2str\n\nattr_predictor_eval_ckpt = './download/pretrained_models/eval_predictor.pth.tar'\n\n\ndef parse_args():\n    \"\"\"Parses arguments.\"\"\"\n    parser = argparse.ArgumentParser(\n        description='Continuous image editing via field function')\n    # inference\n    parser.add_argument(\n        '--attribute',\n        type=str,\n        required=True,\n        help='[Bangs, Eyeglasses, No_Beard, Smiling, Young]')\n\n    # input and output directories\n    parser.add_argument(\n        '--work_dir',\n        required=True,\n        type=str,\n        metavar='PATH',\n        help='path to save checkpoint and log files.')\n    parser.add_argument(\n        '--image_dir',\n        required=True,\n        type=str,\n        metavar='PATH',\n        help='path to save checkpoint and log files.')\n    parser.add_argument('--image_num', type=int, required=True)\n    parser.add_argument('--debug', default=0, type=int)\n\n    return parser.parse_args()\n\n\ndef get_edited_images_list(img_dir, img_idx):\n    return_img_list = []\n    img_path_list = glob.glob(f'{img_dir}/{img_idx:03d}_*.png')\n    start_img_path = glob.glob(f'{img_dir}/{img_idx:03d}_num_edits_0_*.png')\n    assert len(start_img_path) == 1\n    return_img_list.append(start_img_path[0])\n\n    num_edits = len(img_path_list) - 1\n    if num_edits > 0:\n        for edit_idx in range(1, num_edits + 1):\n            img_path_edit_list = glob.glob(\n                f'{img_dir}/{img_idx:03d}_num_edits_{edit_idx}_*.png')\n            assert len(img_path_edit_list) == 1\n            
return_img_list.append(img_path_edit_list[0])\n\n    return return_img_list\n\n\ndef load_face_image(img_path):\n    image = cv2.imread(img_path)\n    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n    image = image.transpose((2, 0, 1))\n    image = image[np.newaxis, :, :, :]\n    image = image.astype(np.float32, copy=False)\n    image -= 127.5\n    image /= 128.0\n\n    image = torch.from_numpy(image).to(torch.device('cuda'))\n\n    return image\n\n\ndef load_image_predictor(img_path,\n                         transform=transforms.Compose([\n                             transforms.ToTensor(),\n                             transforms.Normalize(\n                                 mean=[0.485, 0.456, 0.406],\n                                 std=[0.229, 0.224, 0.225]),\n                         ])):\n    image = Image.open(img_path).convert('RGB')\n    image = transform(image)\n\n    image = image.to(torch.device('cuda')).unsqueeze(0)\n\n    return image\n\n\ndef predictor_score(predictor_output, gt_label, target_attr_idx,\n                    criterion_predictor):\n    num_attr = len(predictor_output)\n    loss_avg = 0\n    count = 0\n    for attr_idx in range(num_attr):\n        if attr_idx == target_attr_idx:\n            continue\n        loss_attr = criterion_predictor(\n            predictor_output[attr_idx],\n            gt_label[attr_idx].unsqueeze(0).to(torch.device('cuda')))\n\n        loss_avg += loss_attr\n        count += 1\n    loss_avg = loss_avg / count\n\n    return loss_avg\n\n\ndef compute_num_metrics(image_dir, image_num, target_attr_idx, logger):\n    # use different face model and predictor model from training phase\n    # define face recognition model\n    resnet = InceptionResnetV1(pretrained='vggface2').eval().to(\n        torch.device('cuda'))\n\n    # define attribute predictor model\n    predictor = resnet50(attr_file='./configs/attributes_5.json', )\n    predictor = predictor.to(torch.device('cuda'))\n\n    checkpoint = 
torch.load(attr_predictor_eval_ckpt)\n    predictor.load_state_dict(checkpoint['state_dict'], strict=True)\n    predictor.eval()\n\n    criterion_predictor = nn.CrossEntropyLoss(reduction='mean')\n\n    face_distance_dataset = 0\n    predictor_score_dataset = 0\n    count = 0\n    for img_idx in range(image_num):\n        edit_image_list = get_edited_images_list(image_dir, img_idx)\n        num_edits = len(edit_image_list) - 1\n        face_distance_img = 0\n        predictor_score_img = 0\n        if num_edits > 0:\n            # face recognition feature\n            source_img = load_face_image(edit_image_list[0])\n            source_img_feat = resnet(source_img)\n            # attribute label for predictor\n            source_img_predictor = load_image_predictor(edit_image_list[0])\n            with torch.no_grad():\n                source_predictor_output = predictor(source_img_predictor)\n            source_label, score = output_to_label(source_predictor_output)\n            for edit_idx in range(1, num_edits + 1):\n                edited_img = load_face_image(edit_image_list[edit_idx])\n                edited_img_feat = resnet(edited_img)\n                temp_face_dist = torch.norm(source_img_feat -\n                                            edited_img_feat).item()\n                face_distance_img += temp_face_dist\n                # attribute predictor score\n                edited_img_predictor = load_image_predictor(\n                    edit_image_list[edit_idx])\n                with torch.no_grad():\n                    edited_predictor_output = predictor(edited_img_predictor)\n                temp_predictor_score_img = predictor_score(\n                    edited_predictor_output, source_label, target_attr_idx,\n                    criterion_predictor)\n                predictor_score_img += temp_predictor_score_img\n\n            face_distance_img = face_distance_img / num_edits\n            face_distance_dataset += face_distance_img\n            
predictor_score_img = predictor_score_img / num_edits\n            predictor_score_dataset += predictor_score_img\n            count += 1\n            logger.info(\n                f'{img_idx:03d}: Identity Preservation: {face_distance_img: .4f}, Attribute Preservation: {predictor_score_img: .4f}.'\n            )\n        else:\n            logger.info(f'{img_idx:03d}: no available edits.')\n\n    face_distance_dataset = face_distance_dataset / count\n    predictor_score_dataset = predictor_score_dataset / count\n    logger.info(\n        f'Avg: {face_distance_dataset: .4f}, {predictor_score_dataset: .4f}.')\n\n    return face_distance_dataset, predictor_score_dataset\n\n\ndef main():\n    \"\"\"Main function.\"\"\"\n    args = parse_args()\n    args.attr_dict = {\n        'Bangs': 0,\n        'Eyeglasses': 1,\n        'No_Beard': 2,\n        'Smiling': 3,\n        'Young': 4\n    }\n\n    logger = get_root_logger(\n        logger_name='base',\n        log_level=logging.INFO,\n        log_file=f'{args.work_dir}/quantitative_results.txt')\n    logger.info(dict2str(args.__dict__))\n\n    compute_num_metrics(args.image_dir, args.image_num,\n                        args.attr_dict[args.attribute], logger)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "train.py",
    "content": "import argparse\nimport logging\nimport os\nimport os.path as osp\nimport random\nimport time\n\nimport numpy as np\nimport torch\n\nfrom data.latent_code_dataset import LatentCodeDataset\nfrom models import create_model\nfrom utils.logger import MessageLogger, get_root_logger, init_tb_logger\nfrom utils.numerical_metrics import compute_num_metrics\nfrom utils.options import dict2str, dict_to_nonedict, parse\nfrom utils.util import make_exp_dirs\n\n\ndef main():\n    # options\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--opt', type=str, help='Path to option YAML file.')\n    args = parser.parse_args()\n    opt = parse(args.opt, is_train=True)\n\n    # mkdir and loggers\n    make_exp_dirs(opt)\n    log_file = osp.join(opt['path']['log'], f\"train_{opt['name']}.log\")\n    logger = get_root_logger(\n        logger_name='base', log_level=logging.INFO, log_file=log_file)\n    logger.info(dict2str(opt))\n    # initialize tensorboard logger\n    tb_logger = None\n    if opt['use_tb_logger'] and 'debug' not in opt['name']:\n        tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])\n\n    # convert to NoneDict, which returns None for missing keys\n    opt = dict_to_nonedict(opt)\n\n    # random seed\n    seed = opt['manual_seed']\n    if seed is None:\n        seed = random.randint(1, 10000)\n    logger.info(f'Random seed: {seed}')\n\n    # set up data loader\n    logger.info(f'Loading data from{opt[\"input_latent_dir\"]}.')\n    train_latent_dataset = LatentCodeDataset(\n        input_dir=opt['dataset']['train_latent_dir'])\n    train_latent_loader = torch.utils.data.DataLoader(\n        dataset=train_latent_dataset,\n        batch_size=opt['batch_size'],\n        shuffle=True,\n        num_workers=opt['num_workers'],\n        drop_last=True)\n    logger.info(f'Number of train set: {len(train_latent_dataset)}.')\n    opt['max_iters'] = opt['num_epochs'] * len(\n        train_latent_dataset) // opt['batch_size']\n    if 
opt['val_on_train_subset']:\n        train_subset_latent_dataset = LatentCodeDataset(\n            input_dir=opt['dataset']['train_subset_latent_dir'])\n        train_subset_latent_loader = torch.utils.data.DataLoader(\n            dataset=train_subset_latent_dataset,\n            batch_size=1,\n            shuffle=False,\n            num_workers=opt['num_workers'])\n        logger.info(\n            f'Number of train subset: {len(train_subset_latent_dataset)}.')\n    if opt['val_on_valset']:\n        val_latent_dataset = LatentCodeDataset(\n            input_dir=opt['dataset']['val_latent_dir'])\n        val_latent_loader = torch.utils.data.DataLoader(\n            dataset=val_latent_dataset,\n            batch_size=1,\n            shuffle=False,\n            num_workers=opt['num_workers'])\n        logger.info(f'Number of val set: {len(val_latent_dataset)}.')\n\n    # load editing latent code\n    editing_latent_codes = np.load(opt['editing_latent_code_path'])\n    num_latent_codes = editing_latent_codes.shape[0]\n\n    current_iter = 0\n    best_metric = 10000\n    best_epoch = None\n    best_arcface = None\n    best_predictor = None\n\n    field_model = create_model(opt)\n\n    data_time, iter_time = 0, 0\n    current_iter = 0\n\n    # create message logger (formatted outputs)\n    msg_logger = MessageLogger(opt, current_iter, tb_logger)\n\n    for epoch in range(opt['num_epochs']):\n        lr = field_model.update_learning_rate(epoch)\n\n        for _, batch_data in enumerate(train_latent_loader):\n            data_time = time.time() - data_time\n\n            current_iter += 1\n\n            field_model.feed_data(batch_data)\n            field_model.optimize_parameters()\n\n            iter_time = time.time() - iter_time\n            if current_iter % opt['print_freq'] == 0:\n                log_vars = {'epoch': epoch, 'iter': current_iter}\n                log_vars.update({'lrs': [lr]})\n                log_vars.update({'time': iter_time, 'data_time': 
data_time})\n                log_vars.update(field_model.get_current_log())\n                msg_logger(log_vars)\n\n            data_time = time.time()\n            iter_time = time.time()\n\n        if epoch % opt['val_freq'] == 0:\n            if opt['val_on_valset']:\n                save_dir = f'{opt[\"path\"][\"visualization\"]}/valset/epoch_{epoch:03d}'  # noqa\n                os.makedirs(save_dir, exist_ok=opt['debug'])\n                for batch_idx, batch_data in enumerate(val_latent_loader):\n                    field_model.feed_data(batch_data)\n                    field_model.inference(batch_idx, epoch, save_dir)\n            if opt['val_on_train_subset']:\n                save_dir = f'{opt[\"path\"][\"visualization\"]}/trainsubset/epoch_{epoch:03d}'  # noqa\n                os.makedirs(save_dir, exist_ok=opt['debug'])\n                for batch_idx, batch_data in enumerate(\n                        train_subset_latent_loader):\n                    field_model.feed_data(batch_data)\n                    field_model.inference(batch_idx, epoch, save_dir)\n\n            save_path = f'{opt[\"path\"][\"visualization\"]}/continuous_editing/epoch_{epoch:03d}'  # noqa\n            os.makedirs(save_path, exist_ok=opt['debug'])\n            editing_logger = get_root_logger(\n                logger_name=f'editing_{epoch:03d}',\n                log_level=logging.INFO,\n                log_file=f'{save_path}/editing.log')\n\n            field_model.continuous_editing(editing_latent_codes, save_path,\n                                           editing_logger)\n\n            arcface_sim, predictor_score = compute_num_metrics(\n                save_path, num_latent_codes, opt['pretrained_arcface'],\n                opt['attr_file'], opt['predictor_ckpt'],\n                opt['attr_dict'][opt['attribute']], editing_logger)\n\n            logger.info(f'Epoch: {epoch}, '\n                        f'ArcFace: {arcface_sim: .4f}, '\n                        f'Predictor: 
{predictor_score: .4f}.')\n\n            metrics = 1 - arcface_sim + predictor_score\n\n            if metrics < best_metric:\n                best_epoch = epoch\n                best_metric = metrics\n                best_arcface = arcface_sim\n                best_predictor = predictor_score\n\n            logger.info(f'Best epoch: {best_epoch}, '\n                        f'ArcFace: {best_arcface: .4f}, '\n                        f'Predictor: {best_predictor: .4f}.')\n\n            # save model\n            field_model.save_network(\n                field_model.field_function,\n                f'{opt[\"path\"][\"models\"]}/ckpt_epoch{epoch}.pth')\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "utils/__init__.py",
    "content": ""
  },
  {
    "path": "utils/crop_img.py",
    "content": "\"\"\"\nbrief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)\nauthor: lzhbrian (https://lzhbrian.me)\ndate: 2020.1.5\nnote: code is heavily borrowed from\n    https://github.com/NVlabs/ffhq-dataset\n    http://dlib.net/face_landmark_detection.py.html\nrequirements:\n    apt install cmake\n    conda install Pillow numpy scipy\n    pip install dlib\n    # download face landmark model from:\n    # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2\n\"\"\"\n\nimport cv2\nimport dlib\nimport numpy as np\nimport PIL\nimport PIL.Image\nimport scipy\nimport scipy.ndimage\nfrom facenet_pytorch import MTCNN\nfrom PIL import Image\n\n\ndef crop_img(img_size, input_img_path, cropped_output_path, device='cuda'):\n    if img_size == 128:\n        return crop_img_128(input_img_path, cropped_output_path, device)\n    elif img_size == 1024:\n        return crop_img_1024(input_img_path, cropped_output_path)\n    else:\n        raise NotImplementedError\n\n\ndef crop_img_128(input_img_path, cropped_output_path, device='cuda'):\n    mtcnn = MTCNN(select_largest=True, device=device)\n\n    img = Image.open(input_img_path).convert('RGB')\n    img = np.uint8(img)\n    bboxes, _ = mtcnn.detect(img)\n    w0, h0, w1, h1 = bboxes[0]\n    hc, wc = (h0 + h1) / 2, (w0 + w1) / 2\n    crop = int(((h1 - h0) + (w1 - w0)) / 2 / 2 * 1.1)\n    h0 = int(hc - crop + crop + crop * 0.15)\n    w0 = int(wc - crop + crop)\n\n    x0, y0, w, h = w0 - crop, h0 - crop, crop * 2, crop * 2\n    im = cv2.imread(input_img_path)\n    im_pad = cv2.copyMakeBorder(\n        im, h, h, w, w,\n        cv2.BORDER_REPLICATE)  # allow cropping outside by replicating borders\n    im_crop = im_pad[y0 + h:y0 + h * 2, x0 + w:x0 + w * 2]\n\n    im_crop = cv2.resize(im_crop, (128, 128), interpolation=cv2.INTER_AREA)\n\n    cv2.imwrite(cropped_output_path, im_crop)\n\n    return True\n\n\n# download model from: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 # 
noqa\npredictor = dlib.shape_predictor(\n    './download/pretrained_models/shape_predictor_68_face_landmarks.dat'  # noqa\n)\n\n\ndef get_landmark(filepath):\n    \"\"\"get landmark with dlib\n    :return: np.array shape=(68, 2)\n    \"\"\"\n    detector = dlib.get_frontal_face_detector()\n\n    img = dlib.load_rgb_image(filepath)\n    dets = detector(img, 1)\n\n    if len(dets) < 1:\n        return False, None\n\n    # print(\"Number of faces detected: {}\".format(len(dets)))\n    for k, d in enumerate(dets):\n        # print(\"Detection {}: Left: {} Top: {} Right: {} Bottom: {}\".format(\n        # k, d.left(), d.top(), d.right(), d.bottom()))\n        # Get the landmarks/parts for the face in box d.\n        shape = predictor(img, d)\n        # print(\"Part 0: {}, Part 1: {} ...\".format(\n        # shape.part(0), shape.part(1)))\n\n    t = list(shape.parts())\n    a = []\n    for tt in t:\n        a.append([tt.x, tt.y])\n    lm = np.array(a)\n    # lm is a shape=(68,2) np.array\n    return True, lm\n\n\ndef crop_img_1024(input_img_path, cropped_output_path):\n    \"\"\"\n    :param filepath: str\n    :return: PIL Image\n    \"\"\"\n\n    success, lm = get_landmark(input_img_path)\n    if success is False:\n        return False\n\n    lm_eye_left = lm[36:42]  # left-clockwise\n    lm_eye_right = lm[42:48]  # left-clockwise\n    lm_mouth_outer = lm[48:60]  # left-clockwise\n\n    # Calculate auxiliary vectors.\n    eye_left = np.mean(lm_eye_left, axis=0)\n    eye_right = np.mean(lm_eye_right, axis=0)\n    eye_avg = (eye_left + eye_right) * 0.5\n    eye_to_eye = eye_right - eye_left\n    mouth_left = lm_mouth_outer[0]\n    mouth_right = lm_mouth_outer[6]\n    mouth_avg = (mouth_left + mouth_right) * 0.5\n    eye_to_mouth = mouth_avg - eye_avg\n\n    # Choose oriented crop rectangle.\n    x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]\n    x /= np.hypot(*x)\n    x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)\n    y = np.flipud(x) * [-1, 
1]\n    c = eye_avg + eye_to_mouth * 0.1\n    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])\n    qsize = np.hypot(*x) * 2\n\n    # read image\n    img = PIL.Image.open(input_img_path)\n\n    output_size = 1024\n    transform_size = 4096\n    enable_padding = True\n\n    # Shrink.\n    shrink = int(np.floor(qsize / output_size * 0.5))\n    if shrink > 1:\n        rsize = (int(np.rint(float(img.size[0]) / shrink)),\n                 int(np.rint(float(img.size[1]) / shrink)))\n        img = img.resize(rsize, PIL.Image.ANTIALIAS)\n        quad /= shrink\n        qsize /= shrink\n\n    # Crop.\n    border = max(int(np.rint(qsize * 0.1)), 3)\n    crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),\n            int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))\n    crop = (max(crop[0] - border, 0), max(crop[1] - border, 0),\n            min(crop[2] + border,\n                img.size[0]), min(crop[3] + border, img.size[1]))\n    if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:\n        img = img.crop(crop)\n        quad -= crop[0:2]\n\n    # Pad.\n    pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),\n           int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))\n    pad = (max(-pad[0] + border,\n               0), max(-pad[1] + border,\n                       0), max(pad[2] - img.size[0] + border,\n                               0), max(pad[3] - img.size[1] + border, 0))\n    if enable_padding and max(pad) > border - 4:\n        pad = np.maximum(pad, int(np.rint(qsize * 0.3)))\n        img = np.pad(\n            np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)),\n            'reflect')\n        h, w, _ = img.shape\n        y, x, _ = np.ogrid[:h, :w, :1]\n        mask = np.maximum(\n            1.0 -\n            np.minimum(np.float32(x) / pad[0],\n                       np.float32(w - 1 - x) / pad[2]), 1.0 -\n            np.minimum(np.float32(y) / pad[1],\n   
                    np.float32(h - 1 - y) / pad[3]))\n        blur = qsize * 0.02\n        img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) -\n                img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)\n        img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)\n        img = PIL.Image.fromarray(\n            np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')\n        quad += pad[:2]\n\n    # Transform.\n    img = img.transform((transform_size, transform_size), PIL.Image.QUAD,\n                        (quad + 0.5).flatten(), PIL.Image.BILINEAR)\n    if output_size < transform_size:\n        img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)\n\n    img.save(cropped_output_path)\n\n    return True\n"
  },
  {
    "path": "utils/dialog_edit_utils.py",
    "content": "import random\n\nimport matplotlib.image as mpimg\nimport matplotlib.pyplot as plt\nimport torch\nfrom language.generate_feedback import instantiate_feedback\nfrom language.run_encoder import encode_request\nfrom models.utils import save_image\n\nfrom utils.editing_utils import edit_target_attribute\n\n\ndef dialog_with_real_user(field_model,\n                          latent_code,\n                          opt,\n                          args,\n                          dialog_logger,\n                          display_img=False):\n\n    # initialize dialog recorder\n    state_log = ['start']\n    edit_log = []\n    system_log = [{\"text\": None, \"system_mode\": 'start', \"attribute\": None}]\n    user_log = []\n    not_used_attribute = [\n        'Bangs', \"Eyeglasses\", \"No_Beard\", \"Smiling\", \"Young\"\n    ]\n    text_log = []\n    text_image_log = []\n\n    # initialize first round's variables\n    round_idx = 0\n\n    edited_latent_code = None\n\n    with torch.no_grad():\n        start_image, start_label, start_score = \\\n            field_model.synthesize_and_predict(torch.from_numpy(latent_code).to(torch.device('cuda'))) # noqa\n\n    save_image(start_image, f'{opt[\"path\"][\"visualization\"]}/start_image.png')\n\n    if display_img:\n        plt.figure()\n        plt.imshow(\n            mpimg.imread(f'{opt[\"path\"][\"visualization\"]}/start_image.png'))\n        plt.axis('off')\n        plt.show()\n\n    # initialize attribtue_dict\n    attribute_dict = {\n        \"Bangs\": start_label[0],\n        \"Eyeglasses\": start_label[1],\n        \"No_Beard\": start_label[2],\n        \"Smiling\": start_label[3],\n        \"Young\": start_label[4],\n    }\n    dialog_logger.info('START IMAGE  >>> ' + str(attribute_dict))\n\n    while True:\n\n        dialog_logger.info('\\n---------------------------------------- Edit ' +\n                           str(round_idx) +\n                           
'----------------------------------------\\n')\n\n        # -------------------- TAKE USER INPUT --------------------\n        # understand user input\n        user_labels = encode_request(\n            args,\n            system_mode=system_log[-1]['system_mode'],\n            dialog_logger=dialog_logger)\n        text_image_log.append('USER:   ' + user_labels['text'])\n\n        # update not_used_attribute\n        if user_labels['attribute'] in not_used_attribute:\n            not_used_attribute.remove(user_labels['attribute'])\n\n        # #################### DECIDE STATE ####################\n        state = decide_next_state(\n            state=state_log[-1],\n            system_mode=system_log[-1]['system_mode'],\n            user_mode=user_labels['user_mode'])\n\n        if state == 'end':\n            user_log.append(user_labels)\n            state_log.append(state)\n            text_log.append('USER:   ' + user_labels['text'])\n            break\n\n        # #################### DECIDE EDIT ####################\n        edit_labels = decide_next_edit(\n            edit_log=edit_log,\n            system_labels=system_log[-1],\n            user_labels=user_labels,\n            state=state,\n            attribute_dict=attribute_dict,\n            dialog_logger=dialog_logger)\n\n        text_image_log.append(edit_labels)\n\n        attribute_dict, exception_mode, latent_code, edited_latent_code = edit_target_attribute(  # noqa\n            opt,\n            attribute_dict,\n            edit_labels,\n            round_idx,\n            latent_code,\n            edited_latent_code,\n            field_model,\n            display_img=display_img)\n        if state == 'no_edit':\n            dialog_logger.info('NO EDIT  >>> ' + str(attribute_dict))\n        else:\n            dialog_logger.info('UPDATED IMAGE >>> ' + str(attribute_dict))\n        text_image_log.append(attribute_dict.copy())\n\n        # #################### DECIDE SYSTEM ####################\n     
   # decide system feedback hard labels\n        temp_system_labels = decide_next_feedback(\n            system_labels=system_log[-1],\n            user_labels=user_labels,\n            state=state,\n            edit_labels=edit_labels,\n            not_used_attribute=not_used_attribute,\n            round_idx=round_idx,\n            exception_mode=exception_mode)\n\n        # instantiate feedback\n        system_labels = instantiate_feedback(\n            args,\n            system_mode=temp_system_labels['system_mode'],\n            attribute=temp_system_labels['attribute'],\n            exception_mode=exception_mode)\n\n        dialog_logger.info('SYSTEM FEEDBACK >>> ' + system_labels['text'])\n\n        # update not_used_attribute\n        if system_labels['attribute'] in not_used_attribute:\n            not_used_attribute.remove(system_labels['attribute'])\n\n        # -------------------- UPDATE LOG --------------------\n        state_log.append(state)\n        edit_log.append(edit_labels)\n        system_log.append(system_labels)\n        user_log.append(user_labels)\n        text_log.append('USER:   ' + user_labels['text'])\n        text_log.append('SYSTEM: ' + system_labels['text'])\n        text_log.append('')\n        text_image_log.append('SYSTEM: ' + system_labels['text'])\n        text_image_log.append('')\n\n        round_idx += 1\n\n    dialog_overall_log = {\n        'state_log': state_log,\n        'edit_log': edit_log,\n        'system_log': system_log,\n        'user_log': user_log,\n        'text_log': text_log,\n        'text_image_log': text_image_log\n    }\n    dialog_logger.info('Dialog successfully ended.')\n\n    return dialog_overall_log\n\n\ndef decide_next_state(state, system_mode, user_mode):\n    \"\"\"\n    Input: state, system, user\n    Output: next state\n    \"\"\"\n\n    if state == 'start':\n        assert system_mode == 'start'\n        assert user_mode == 'start_pureRequest'\n        next_state = 'edit'\n\n    elif state == 
'edit':\n        if system_mode == 'suggestion':\n            if user_mode == 'yes':\n                next_state = 'edit'\n            elif user_mode == 'yes_pureRequest':\n                next_state = 'edit'\n            elif user_mode == 'no_pureRequest':\n                next_state = 'edit'\n            elif user_mode == 'no':\n                next_state = 'no_edit'\n            elif user_mode == 'no_end':\n                next_state = 'end'\n            else:\n                raise ValueError(\"invalid user_mode\")\n        elif system_mode == 'whether_enough':\n            if user_mode == 'yes':\n                next_state = 'no_edit'\n            elif user_mode == 'yes_pureRequest':\n                next_state = 'edit'\n            elif user_mode == 'yes_end':\n                next_state = 'end'\n            elif user_mode == 'no':\n                next_state = 'edit'\n            elif user_mode == 'no_pureRequest':\n                next_state = 'edit'\n            else:\n                raise ValueError(\"invalid user_mode\")\n        elif system_mode == 'whats_next':\n            if user_mode == 'pureRequest':\n                next_state = 'edit'\n            elif user_mode == 'end':\n                next_state = 'end'\n            else:\n                raise ValueError(\"invalid user_mode\")\n        else:\n            raise ValueError(\"invalid system_mode\")\n\n    elif state == 'no_edit':\n        if system_mode == 'suggestion':\n            if user_mode == 'yes':\n                next_state = 'edit'\n            elif user_mode == 'yes_pureRequest':\n                next_state = 'edit'\n            elif user_mode == 'no_pureRequest':\n                next_state = 'edit'\n            elif user_mode == 'no':\n                next_state = 'no_edit'\n            elif user_mode == 'no_end':\n                next_state = 'end'\n            else:\n                raise ValueError(\"invalid user_mode\")\n        elif system_mode == 'whether_enough':\n            raise ValueError(\"invalid system_mode\")\n        elif system_mode == 
'whats_next':\n            if user_mode == 'pureRequest':\n                next_state = 'edit'\n            elif user_mode == 'end':\n                next_state = 'end'\n            else:\n                raise ValueError(\"invalid user_mode\")\n        else:\n            raise ValueError(\"invalid system_mode\")\n    elif state == 'end':\n        raise ValueError(\"invalid state\")\n\n    else:\n        raise ValueError(\"invalid state\")\n\n    return next_state\n\n\ndef decide_next_edit(edit_log, system_labels, user_labels, state,\n                     attribute_dict, dialog_logger):\n    \"\"\"\n    Input: previous edit, system, user, resulting state, attribute labels\n    Output: current edit\n    \"\"\"\n\n    attribute = None\n    score_change_direction = None\n    score_change_value = None\n    target_score = None\n\n    if len(edit_log) > 0:\n        edit_labels = edit_log[-1]\n\n    # ---------- decide edit_labels ----------\n    if len(edit_log) == 0:\n        # this is the first round, so edit according to the user request\n        assert 'pureRequest' in user_labels['user_mode']\n        assert state == 'edit'\n        attribute = user_labels['attribute']\n        score_change_direction = user_labels['score_change_direction']\n        score_change_value = user_labels['score_change_value']\n        target_score = user_labels['target_score']\n        if user_labels['request_mode'] == 'change_indefinite':\n            assert score_change_value is None\n            score_change_value = 1\n\n    elif 'pureRequest' in user_labels['user_mode']:\n        # edit according to user request\n        assert state == 'edit'\n        attribute = user_labels['attribute']\n        score_change_direction = user_labels['score_change_direction']\n        score_change_value = user_labels['score_change_value']\n        target_score = user_labels['target_score']\n        if user_labels['request_mode'] == 'change_indefinite':\n            assert score_change_value is None\n            score_change_value = 1\n\n    elif system_labels['system_mode'] == 
'whether_enough' and user_labels[\n            'user_mode'] == 'no':\n        # continue the previous edit\n        assert state == 'edit'\n        attribute = edit_labels['attribute']\n        score_change_direction = edit_labels['score_change_direction']\n        score_change_value = 1\n        target_score = None\n\n    elif system_labels['system_mode'] == 'suggestion' and user_labels[\n            'user_mode'] == 'yes':\n        # play with the suggested attribute, random direction\n        # (small degree --> positive direction)\n        assert state == 'edit'\n        attribute = system_labels['attribute']\n\n        if attribute_dict[attribute] <= 2:\n            score_change_direction = 'positive'\n        else:\n            score_change_direction = 'negative'\n\n        score_change_value = 1\n        target_score = None\n\n    else:\n        # no edit\n        assert (state == 'no_edit' or state == 'end')\n        attribute = None\n        score_change_direction = None\n        score_change_value = None\n        target_score = None\n\n    # --- The code below is a moderation mechanism for the language encoder ---\n    if system_labels['system_mode'] == 'suggestion' and user_labels[\n            'user_mode'] == 'yes':\n        attribute = system_labels['attribute']\n\n    # ---------- Fill in all the values in edit_labels ----------\n    if attribute is None:\n        assert score_change_direction is None\n        assert score_change_value is None\n        assert target_score is None\n    elif target_score is not None:\n        assert score_change_direction is None\n        assert score_change_value is None\n        if target_score > attribute_dict[attribute]:\n            score_change_direction = 'positive'\n        elif target_score < attribute_dict[attribute]:\n            score_change_direction = 'negative'\n        else:\n            pass\n        score_change_value = abs(target_score - attribute_dict[attribute])\n    elif score_change_direction is not None:\n        assert score_change_value is not None\n        if score_change_direction == 'positive':\n            target_score = attribute_dict[attribute] + score_change_value\n        elif score_change_direction == 'negative':\n            target_score = attribute_dict[attribute] - score_change_value\n        else:\n            raise ValueError('invalid direction')\n        # boundary value checking\n        if target_score > 5:\n            target_score = 5\n            score_change_value = abs(target_score - attribute_dict[attribute])\n        elif target_score < 0:\n            target_score = 0\n            score_change_value = abs(target_score - attribute_dict[attribute])\n\n    next_edit_labels = {\n        'attribute': attribute,\n        'score_change_direction': score_change_direction,\n        'score_change_value': score_change_value,\n        'target_score': target_score\n    }\n\n    return next_edit_labels\n\n\ndef decide_next_feedback(system_labels, user_labels, state, edit_labels,\n                         not_used_attribute, round_idx, exception_mode):\n    \"\"\"\n    Input: system, user, state, edit + others\n    Output: system\n    \"\"\"\n\n    assert (state == 'edit' or state == 'no_edit')\n\n    while True:\n\n        if exception_mode != 'normal':\n            system_mode = 'whats_next'\n            feedback_attribute = None\n            break\n\n        system_mode = None\n        feedback_attribute = None\n\n        if system_labels['system_mode'] == 'suggestion' and user_labels[\n                'user_mode'] == 'yes':\n            assert state == 'edit'\n            system_mode = 'whether_enough'\n            feedback_attribute = system_labels['attribute']\n            break\n\n        # ---------- whether_enough ----------\n        # the first round has a higher chance of 'whether_enough'\n        whether_enough_random_num = random.uniform(0, 1)\n        if round_idx == 0:\n            whether_enough_prob = 0.8\n            if whether_enough_random_num < whether_enough_prob:\n                system_mode = 'whether_enough'\n                if state == 'no_edit':\n                    continue\n                else:\n                    feedback_attribute = user_labels['attribute']\n                    assert feedback_attribute is not None\n\n        # ---------- whats_next ----------\n        # higher chance in earlier rounds\n        if system_mode is None:\n            whats_next_random_num = random.uniform(0, 1)\n            whats_next_prob_list = [0.5, 0.4, 0.3, 0.3]\n            if round_idx <= 3:\n                whats_next_prob = whats_next_prob_list[round_idx]\n            else:\n                whats_next_prob = 0.2\n            if whats_next_random_num < whats_next_prob:\n                system_mode = 'whats_next'\n                feedback_attribute = None\n\n        # ---------- suggestion ----------\n        # if many attributes have been edited, suggestion is unlikely\n        if system_mode is None:\n            suggestion_random_num = random.uniform(0, 1)\n            suggestion_prob = len(not_used_attribute) * 0.2\n            if suggestion_random_num < suggestion_prob:\n                system_mode = 'suggestion'\n                if len(not_used_attribute) > 0:\n                    feedback_attribute = random.choice(not_used_attribute)\n                else:\n                    system_mode = None\n\n        # ---------- whether_enough ----------\n        # if not chosen to be 'whats_next' or 'suggestion',\n        # then use 'whether_enough'\n        if system_mode is None:\n            system_mode = 'whether_enough'\n            if state == 'no_edit':\n                continue\n            else:\n                feedback_attribute = edit_labels['attribute']\n                assert feedback_attribute is not None\n\n        # if state is no_edit, system_mode cannot be whether_enough\n        if not (state == 'no_edit' and system_mode == 'whether_enough'):\n            break\n\n    next_system_labels = {\n        'exception_mode': exception_mode,\n        'system_mode': system_mode,\n        'attribute': feedback_attribute\n    }\n\n    return next_system_labels\n"
  },
  {
    "path": "utils/editing_utils.py",
    "content": "def edit_target_attribute(opt,\n                          attribute_dict,\n                          edit_labels,\n                          round_idx,\n                          latent_code,\n                          edited_latent_code,\n                          field_model,\n                          editing_logger=None,\n                          print_intermediate_result=False,\n                          display_img=False):\n    \"\"\"\n    Input: current attribute labels, how to edit\n    Output: updated attribute labels\n    \"\"\"\n\n    edit_attr_name = edit_labels['attribute']\n    if edit_attr_name is None:\n        # dialog_logger.info('No edit in the current round')\n        exception_mode = 'normal'\n        return attribute_dict, exception_mode, latent_code, edited_latent_code\n\n    # define network\n    field_model.target_attr_idx = int(opt['attr_to_idx'][edit_attr_name])\n    field_model.load_network(opt['pretrained_field'][edit_attr_name])\n\n    latent_code, edited_latent_code, saved_label, exception_mode = \\\n        field_model.continuous_editing_with_target(\n            latent_codes=latent_code,\n            target_cls=edit_labels['target_score'],\n            save_dir=opt['path']['visualization'],\n            editing_logger=editing_logger,\n            edited_latent_code=edited_latent_code,\n            prefix=f'edit_order_{str(round_idx)}',\n            print_intermediate_result=print_intermediate_result,\n            display_img=display_img)\n\n    latent_code = latent_code.cpu().numpy()\n\n    # update attribute_dict\n    for idx, (attr, old_label) in enumerate(list(attribute_dict.items())):\n        new_label = int(saved_label[idx])\n        if field_model.target_attr_idx != idx and new_label != old_label:\n            pass\n        attribute_dict[attr] = new_label\n\n    return attribute_dict, exception_mode, latent_code, edited_latent_code\n"
  },
  {
    "path": "utils/inversion_utils.py",
    "content": "import math\n\nimport models.archs.stylegan2.lpips as lpips\nimport numpy as np\nimport torch\nfrom PIL import Image\nfrom torch import optim\nfrom torch.nn import functional as F\nfrom torchvision import transforms\nfrom tqdm import tqdm\n\nfrom utils.crop_img import crop_img\n\n\ndef noise_regularize(noises):\n    loss = 0\n\n    for noise in noises:\n        size = noise.shape[2]\n\n        while True:\n            loss = (\n                loss +\n                (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2) +\n                (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2))\n\n            if size <= 8:\n                break\n\n            noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2])\n            noise = noise.mean([3, 5])\n            size //= 2\n\n    return loss\n\n\ndef noise_normalize_(noises):\n    for noise in noises:\n        mean = noise.mean()\n        std = noise.std()\n\n        noise.data.add_(-mean).div_(std)\n\n\ndef get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):\n    lr_ramp = min(1, (1 - t) / rampdown)\n    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)\n    lr_ramp = lr_ramp * min(1, t / rampup)\n\n    return initial_lr * lr_ramp\n\n\ndef latent_noise(latent, strength):\n    noise = torch.randn_like(latent) * strength\n\n    return latent + noise\n\n\ndef make_image(tensor):\n    return (tensor.detach().clamp_(min=-1, max=1).add(1).div_(2).mul(255).type(\n        torch.uint8).permute(0, 2, 3, 1).to(\"cpu\").numpy())\n\n\ndef inversion(opt, field_model):\n\n    inv_opt = opt['inversion']\n    device = inv_opt['device']\n\n    img_size = opt['img_res']\n\n    # inversion\n    transform = transforms.Compose([\n        transforms.Resize(img_size),\n        transforms.CenterCrop(img_size),\n        transforms.ToTensor(),\n        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),\n    ])\n\n    if inv_opt['crop_img']:\n        cropped_output_path = 
f'{opt[\"path\"][\"visualization\"]}/cropped.png'\n        crop_img(img_size, inv_opt['img_path'], cropped_output_path, device)\n        img = transform(Image.open(cropped_output_path).convert(\"RGB\"))\n    else:\n        img = transform(Image.open(inv_opt['img_path']).convert(\"RGB\"))\n\n    img = img.unsqueeze(0).to(torch.device('cuda'))\n\n    batch, channel, height, width = img.shape\n\n    if height > 256:\n        factor = height // 256\n\n        img = img.reshape(batch, channel, height // factor, factor,\n                          width // factor, factor)\n        img = img.mean([3, 5])\n\n    n_mean_latent = 10000\n    with torch.no_grad():\n        noise_sample = torch.randn(n_mean_latent, 512, device=device)\n        latent_out = field_model.stylegan_gen.style_forward(noise_sample)\n\n        latent_mean = latent_out.mean(0)\n        latent_std = ((latent_out - latent_mean).pow(2).sum() /\n                      n_mean_latent)**0.5\n\n    percept = lpips.PerceptualLoss(\n        model=\"net-lin\", net=\"vgg\", use_gpu=device.startswith(\"cuda\"))\n\n    latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(\n        img.shape[0], 1)\n\n    latent_in.requires_grad = True\n\n    optim_params = []\n\n    for v in field_model.stylegan_gen.parameters():\n        if v.requires_grad:\n            optim_params.append(v)\n    optimizer = optim.Adam([{\n        'params': [latent_in]\n    }, {\n        'params': optim_params,\n        'lr': inv_opt['lr_gen']\n    }],\n                           lr=inv_opt['lr'])\n\n    pbar = tqdm(range(inv_opt['step']))\n\n    latent_path = []\n    for i in pbar:\n        t = i / inv_opt['step']\n        lr = get_lr(t, inv_opt['lr'])\n        optimizer.param_groups[0][\"lr\"] = lr\n        noise_strength = latent_std * inv_opt['noise'] * max(\n            0, 1 - t / inv_opt['noise_ramp'])**2\n\n        latent_n = latent_noise(latent_in, noise_strength.item())\n\n        img_gen, _ = field_model.stylegan_gen([latent_n],\n   
                                           input_is_latent=True,\n                                              randomize_noise=False)\n\n        batch, channel, height, width = img_gen.shape\n\n        if height > 256:\n            factor = height // 256\n\n            img_gen = img_gen.reshape(batch, channel, height // factor, factor,\n                                      width // factor, factor)\n            img_gen = img_gen.mean([3, 5])\n\n        p_loss = percept(img_gen, img).sum()\n        mse_loss = F.mse_loss(img_gen, img)\n\n        loss = p_loss + inv_opt['img_mse_weight'] * mse_loss\n\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        if (i + 1) % 100 == 0:\n            latent_path.append(latent_in.detach().clone())\n\n        pbar.set_description((f\"total: {loss:.4f}; perceptual: {p_loss:.4f};\"\n                              f\" mse: {mse_loss:.4f}; lr: {lr:.4f}\"))\n\n    latent_code = latent_in[0].cpu().detach().numpy()\n    latent_code = np.expand_dims(latent_code, axis=0)\n\n    return latent_code\n"
  },
  {
    "path": "utils/logger.py",
    "content": "import datetime\nimport logging\nimport time\n\n\nclass MessageLogger():\n    \"\"\"Message logger for printing.\n\n    Args:\n        opt (dict): Config. It contains the following keys:\n            name (str): Exp name.\n            logger (dict): Contains 'print_freq' (str) for logger interval.\n            train (dict): Contains 'niter' (int) for total iters.\n            use_tb_logger (bool): Use tensorboard logger.\n        start_iter (int): Start iter. Default: 1.\n        tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.\n    \"\"\"\n\n    def __init__(self, opt, start_iter=1, tb_logger=None):\n        self.exp_name = opt['name']\n        self.interval = opt['print_freq']\n        self.start_iter = start_iter\n        self.max_iters = opt['max_iters']\n        self.use_tb_logger = opt['use_tb_logger']\n        self.tb_logger = tb_logger\n        self.start_time = time.time()\n        self.logger = get_root_logger()\n\n    def __call__(self, log_vars):\n        \"\"\"Format logging message.\n\n        Args:\n            log_vars (dict): It contains the following keys:\n                epoch (int): Epoch number.\n                iter (int): Current iter.\n                lrs (list): List for learning rates.\n\n                time (float): Iter time.\n                data_time (float): Data time for each iter.\n        \"\"\"\n        # epoch, iter, learning rates\n        epoch = log_vars.pop('epoch')\n        current_iter = log_vars.pop('iter')\n        lrs = log_vars.pop('lrs')\n\n        message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, '\n                   f'iter:{current_iter:8,d}, lr:(')\n        for v in lrs:\n            message += f'{v:.3e},'\n        message += ')] '\n\n        # time and estimated time\n        if 'time' in log_vars.keys():\n            iter_time = log_vars.pop('time')\n            data_time = log_vars.pop('data_time')\n\n            total_time = time.time() - self.start_time\n            time_sec_avg = total_time / (current_iter - self.start_iter + 1)\n            eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)\n            eta_str = str(datetime.timedelta(seconds=int(eta_sec)))\n            message += f'[eta: {eta_str}, '\n            message += f'time: {iter_time:.3f}, data_time: {data_time:.3f}] '\n\n        # other items, especially losses\n        for k, v in log_vars.items():\n            message += f'{k}: {v:.4e} '\n            # tensorboard logger\n            if self.use_tb_logger and 'debug' not in self.exp_name:\n                self.tb_logger.add_scalar(k, v, current_iter)\n\n        self.logger.info(message)\n\n\ndef init_tb_logger(log_dir):\n    from torch.utils.tensorboard import SummaryWriter\n    tb_logger = SummaryWriter(log_dir=log_dir)\n    return tb_logger\n\n\ndef get_root_logger(logger_name='base', log_level=logging.INFO, log_file=None):\n    \"\"\"Get the root logger.\n\n    The logger will be initialized if it has not been initialized. By default a\n    StreamHandler will be added. If `log_file` is specified, a FileHandler will\n    also be added.\n\n    Args:\n        logger_name (str): root logger name. Default: base.\n        log_level (int): The root logger level. Note that only the process of\n            rank 0 is affected, while other processes will set the level to\n            \"Error\" and be silent most of the time.\n        log_file (str | None): The log filename. If specified, a FileHandler\n            will be added to the root logger.\n\n    Returns:\n        logging.Logger: The root logger.\n    \"\"\"\n    logger = logging.getLogger(logger_name)\n    # if the logger has been initialized, just return it\n    if logger.hasHandlers():\n        return logger\n\n    format_str = '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s'\n    logging.basicConfig(format=format_str, level=log_level)\n\n    if log_file is not None:\n        file_handler = logging.FileHandler(log_file, 'w')\n        file_handler.setFormatter(logging.Formatter(format_str))\n        file_handler.setLevel(log_level)\n        logger.addHandler(file_handler)\n\n    return logger\n"
  },
  {
    "path": "utils/numerical_metrics.py",
    "content": "import argparse\nimport glob\n\nimport cv2\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision.transforms as transforms\nfrom models.archs.attribute_predictor_arch import resnet50\nfrom models.losses.arcface_loss import resnet_face18\nfrom models.utils import output_to_label\nfrom PIL import Image\n\n\ndef parse_args():\n    \"\"\"Parses arguments.\"\"\"\n    parser = argparse.ArgumentParser(\n        description='Continuous image editing via field function')\n\n    # inference\n    parser.add_argument(\n        '--attribute',\n        type=str,\n        required=True,\n        help='[Bangs, Eyeglasses, No_Beard, Smiling, Young]')\n\n    parser.add_argument('--confidence_thresh', type=float, default=0.8)\n\n    # input and output directories\n    parser.add_argument(\n        '--work_dir',\n        required=True,\n        type=str,\n        metavar='PATH',\n        help='path to save checkpoint and log files.')\n    parser.add_argument(\n        '--image_dir',\n        required=True,\n        type=str,\n        metavar='PATH',\n        help='path to save checkpoint and log files.')\n    parser.add_argument('--image_num', type=int, required=True)\n    parser.add_argument('--debug', default=0, type=int)\n\n    # predictor args\n    parser.add_argument(\n        '--attr_file',\n        required=True,\n        type=str,\n        help='directory to attribute metadata')\n    parser.add_argument(\n        '--predictor_ckpt',\n        required=True,\n        type=str,\n        help='The pretrained network weights for testing')\n    parser.add_argument('--num_attr', type=int, default=5)\n\n    # arcface loss args\n    parser.add_argument(\n        '--pretrained_arcface',\n        default=  # noqa\n        '../share_work_dirs/pretrained_arcface/arcface_resnet18_110.pth',\n        type=str)\n\n    return parser.parse_args()\n\n\ndef get_edited_images_list(img_dir, img_idx):\n    return_img_list = []\n   
 img_path_list = glob.glob(f'{img_dir}/{img_idx:03d}_*.png')\n    start_img_path = glob.glob(f'{img_dir}/{img_idx:03d}_num_edits_0_*.png')\n    assert len(start_img_path) == 1\n    return_img_list.append(start_img_path[0])\n\n    num_edits = len(img_path_list) - 1\n    if num_edits > 0:\n        for edit_idx in range(1, num_edits + 1):\n            img_path_edit_list = glob.glob(\n                f'{img_dir}/{img_idx:03d}_num_edits_{edit_idx}_*.png')\n            assert len(img_path_edit_list) == 1\n            return_img_list.append(img_path_edit_list[0])\n\n    return return_img_list\n\n\ndef load_image_predictor(img_path,\n                         transform=transforms.Compose([transforms.ToTensor()\n                                                       ])):\n    image = Image.open(img_path).convert('RGB')\n    image = transform(image)\n    image = image.to(torch.device('cuda')).unsqueeze(0)\n\n    if image.size()[-1] > 128:\n        image = F.interpolate(image, (128, 128), mode='area')\n\n    img_mean = torch.Tensor([0.485, 0.456,\n                             0.406]).view(1, 3, 1, 1).to(torch.device('cuda'))\n    img_std = torch.Tensor([0.229, 0.224,\n                            0.225]).view(1, 3, 1, 1).to(torch.device('cuda'))\n    image = (image - img_mean) / img_std\n\n    return image\n\n\ndef load_image_arcface(img_path):\n    image = cv2.imread(img_path, 0)\n    if image is None:\n        return None\n    image = image[:, :, np.newaxis]\n    image = image.transpose((2, 0, 1))\n    image = image[:, np.newaxis, :, :]\n    image = image.astype(np.float32, copy=False)\n    image -= 127.5\n    image /= 127.5\n\n    image = torch.from_numpy(image).to(torch.device('cuda'))\n\n    if image.size()[-1] > 128:\n        image = F.interpolate(image, (128, 128), mode='area')\n\n    return image\n\n\ndef cosin_metric(x1, x2):\n    return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))\n\n\ndef predictor_score(predictor_output, gt_label, target_attr_idx,\n    
                criterion_predictor):\n    num_attr = len(predictor_output)\n    loss_avg = 0\n    count = 0\n    for attr_idx in range(num_attr):\n        if attr_idx == target_attr_idx:\n            continue\n        loss_attr = criterion_predictor(\n            predictor_output[attr_idx],\n            gt_label[attr_idx].unsqueeze(0).to(torch.device('cuda')))\n\n        loss_avg += loss_attr\n        count += 1\n    loss_avg = loss_avg / count\n\n    return loss_avg\n\n\ndef compute_num_metrics(image_dir, image_num, pretrained_arcface, attr_file,\n                        pretrained_predictor, target_attr_idx, logger):\n\n    # define arcface model\n    arcface_model = resnet_face18(use_se=False)\n    arcface_model = nn.DataParallel(arcface_model)\n    arcface_model.load_state_dict(torch.load(pretrained_arcface), strict=True)\n    arcface_model.to(torch.device('cuda'))\n    arcface_model.eval()\n\n    # define predictor model\n    predictor = resnet50(attr_file=attr_file)\n    predictor = predictor.to(torch.device('cuda'))\n\n    checkpoint = torch.load(pretrained_predictor)\n    predictor.load_state_dict(checkpoint['state_dict'], strict=True)\n    predictor.eval()\n\n    criterion_predictor = nn.CrossEntropyLoss(reduction='mean')\n\n    arcface_sim_dataset = 0\n    predictor_score_dataset = 0\n    count = 0\n    for img_idx in range(image_num):\n        edit_image_list = get_edited_images_list(image_dir, img_idx)\n        num_edits = len(edit_image_list) - 1\n        arcface_sim_img = 0\n        predictor_score_img = 0\n        if num_edits > 0:\n            # read image for arcface\n            source_img_arcface = load_image_arcface(edit_image_list[0])\n            with torch.no_grad():\n                source_feature = arcface_model(\n                    source_img_arcface).cpu().numpy()\n            # read image for predictor\n            source_img_predictor = load_image_predictor(edit_image_list[0])\n            with torch.no_grad():\n                source_predictor_output = predictor(source_img_predictor)\n            source_label, score = output_to_label(source_predictor_output)\n            for edit_idx in range(1, num_edits + 1):\n                # arcface cosine similarity\n                edited_img_arcface = load_image_arcface(\n                    edit_image_list[edit_idx])\n                with torch.no_grad():\n                    edited_feature = arcface_model(\n                        edited_img_arcface).cpu().numpy()\n                temp_arcface_sim = cosin_metric(source_feature,\n                                                edited_feature.transpose(\n                                                    1, 0))[0][0]\n                arcface_sim_img += temp_arcface_sim\n                # predictor score\n                edited_img_predictor = load_image_predictor(\n                    edit_image_list[edit_idx])\n                with torch.no_grad():\n                    edited_predictor_output = predictor(edited_img_predictor)\n                temp_predictor_score_img = predictor_score(\n                    edited_predictor_output, source_label, target_attr_idx,\n                    criterion_predictor)\n                predictor_score_img += temp_predictor_score_img\n\n            arcface_sim_img = arcface_sim_img / num_edits\n            predictor_score_img = predictor_score_img / num_edits\n            arcface_sim_dataset += arcface_sim_img\n            predictor_score_dataset += predictor_score_img\n            count += 1\n            logger.info(\n                f'{img_idx:03d}: Arcface: {arcface_sim_img: .4f}, Predictor: {predictor_score_img: .4f}.'  # noqa\n            )\n        else:\n            logger.info(f'{img_idx:03d}: no available edits.')\n\n    arcface_sim_dataset = arcface_sim_dataset / count\n    predictor_score_dataset = predictor_score_dataset / count\n    logger.info(\n        f'Avg: {arcface_sim_dataset: .4f}, {predictor_score_dataset: .4f}.')\n\n    return arcface_sim_dataset, predictor_score_dataset\n"
  },
  {
    "path": "utils/options.py",
    "content": "import os\nimport os.path as osp\nfrom collections import OrderedDict\n\nimport yaml\n\n\ndef ordered_yaml():\n    \"\"\"Support OrderedDict for yaml.\n\n    Returns:\n        yaml Loader and Dumper.\n    \"\"\"\n    try:\n        from yaml import CDumper as Dumper\n        from yaml import CLoader as Loader\n    except ImportError:\n        from yaml import Dumper, Loader\n\n    _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG\n\n    def dict_representer(dumper, data):\n        return dumper.represent_dict(data.items())\n\n    def dict_constructor(loader, node):\n        return OrderedDict(loader.construct_pairs(node))\n\n    Dumper.add_representer(OrderedDict, dict_representer)\n    Loader.add_constructor(_mapping_tag, dict_constructor)\n    return Loader, Dumper\n\n\ndef parse(opt_path, is_train=True):\n    \"\"\"Parse option file.\n\n    Args:\n        opt_path (str): Option file path.\n        is_train (str): Indicate whether in training or not. Default: True.\n\n    Returns:\n        (dict): Options.\n    \"\"\"\n    with open(opt_path, mode='r') as f:\n        Loader, _ = ordered_yaml()\n        opt = yaml.load(f, Loader=Loader)\n\n    gpu_list = ','.join(str(x) for x in opt['gpu_ids'])\n    if opt.get('set_CUDA_VISIBLE_DEVICES', None):\n        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list\n        # print('export CUDA_VISIBLE_DEVICES=' + gpu_list, flush=True)\n    else:\n        pass\n        # print('gpu_list: ', gpu_list, flush=True)\n\n    opt['is_train'] = is_train\n\n    # datasets\n    if opt['is_train']:\n        input_latent_dir = opt['input_latent_dir']\n        opt['dataset'] = {}\n        opt['dataset']['train_latent_dir'] = f'{input_latent_dir}/train'\n        if opt['val_on_train_subset']:\n            opt['dataset'][\n                'train_subset_latent_dir'] = f'{input_latent_dir}/train_subset'\n        if opt['val_on_valset']:\n            opt['dataset']['val_latent_dir'] = f'{input_latent_dir}/val'\n\n    # 
paths\n    opt['path'] = {}\n    opt['path']['root'] = osp.abspath(\n        osp.join(__file__, osp.pardir, osp.pardir))\n    if is_train:\n        experiments_root = osp.join(opt['path']['root'], 'experiments',\n                                    opt['name'])\n        opt['path']['experiments_root'] = experiments_root\n        opt['path']['models'] = osp.join(experiments_root, 'models')\n        opt['path']['log'] = experiments_root\n        opt['path']['visualization'] = osp.join(experiments_root,\n                                                'visualization')\n\n        # change some options for debug mode\n        if 'debug' in opt['name']:\n            opt['debug'] = True\n            opt['val_freq'] = 1\n            opt['print_freq'] = 1\n            opt['save_checkpoint_freq'] = 1\n            opt['dataset'][\n                'train_latent_dir'] = f'{input_latent_dir}/train_subset'\n            if opt['val_on_train_subset']:\n                opt['dataset'][\n                    'train_subset_latent_dir'] = f'{input_latent_dir}/train_subset'  # noqa\n            if opt['val_on_valset']:\n                opt['dataset'][\n                    'val_latent_dir'] = f'{input_latent_dir}/train_subset'\n    else:  # test\n        results_root = osp.join(opt['path']['root'], 'results', opt['name'])\n        opt['path']['results_root'] = results_root\n        opt['path']['log'] = results_root\n        opt['path']['visualization'] = osp.join(results_root, 'visualization')\n    # some basics for editing task\n    opt['attr_list'] = ['Bangs', 'Eyeglasses', 'No_Beard', 'Smiling', 'Young']\n    opt['attr_dict'] = {\n        'Bangs': 0,\n        'Eyeglasses': 1,\n        'No_Beard': 2,\n        'Smiling': 3,\n        'Young': 4\n    }\n\n    if 'has_dialog' in opt.keys():\n        opt['path']['dialog'] = osp.join(results_root, 'dialog')\n\n    return opt\n\n\ndef dict2str(opt, indent_level=1):\n    \"\"\"dict to string for printing options.\n\n    Args:\n        opt 
(dict): Option dict.\n        indent_level (int): Indent level. Default: 1.\n\n    Return:\n        (str): Option string for printing.\n    \"\"\"\n    msg = ''\n    for k, v in opt.items():\n        if isinstance(v, dict):\n            msg += ' ' * (indent_level * 2) + k + ':[\\n'\n            msg += dict2str(v, indent_level + 1)\n            msg += ' ' * (indent_level * 2) + ']\\n'\n        else:\n            msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\\n'\n    return msg\n\n\nclass NoneDict(dict):\n    \"\"\"None dict. It will return none if key is not in the dict.\"\"\"\n\n    def __missing__(self, key):\n        return None\n\n\ndef dict_to_nonedict(opt):\n    \"\"\"Convert to NoneDict, which returns None for missing keys.\n\n    Args:\n        opt (dict): Option dict.\n\n    Returns:\n        (dict): NoneDict for options.\n    \"\"\"\n    if isinstance(opt, dict):\n        new_opt = dict()\n        for key, sub_opt in opt.items():\n            new_opt[key] = dict_to_nonedict(sub_opt)\n        return NoneDict(**new_opt)\n    elif isinstance(opt, list):\n        return [dict_to_nonedict(sub_opt) for sub_opt in opt]\n    else:\n        return opt\n\n\ndef parse_args_from_opt(args, opt):\n    '''\n    Given the opt, parse it to args,\n    since previous code for dialog and language\n    uses args to pass arguments among different scripts\n    '''\n    for (key, value) in list(opt.items()):\n        setattr(args, key, value)\n    for (key, value) in list(opt['language_encoder'].items()):\n        setattr(args, key, value)\n    args.pretrained_checkpoint = opt['pretrained_language_encoder']\n    return args\n\n\ndef parse_opt_wrt_resolution(opt):\n    if opt['img_res'] == 1024:\n        opt['channel_multiplier'] = opt['channel_multiplier_1024']\n        opt['pretrained_field'] = opt['pretrained_field_1024']\n        opt['predictor_ckpt'] = opt['predictor_ckpt_1024']\n        opt['generator_ckpt'] = opt['generator_ckpt_1024']\n        
opt['replaced_layers'] = opt['replaced_layers_1024']\n\n    elif opt['img_res'] == 128:\n        opt['channel_multiplier'] = opt['channel_multiplier_128']\n        opt['pretrained_field'] = opt['pretrained_field_128']\n        opt['predictor_ckpt'] = opt['predictor_ckpt_128']\n        opt['generator_ckpt'] = opt['generator_ckpt_128']\n        opt['replaced_layers'] = opt['replaced_layers_128']\n\n    return opt\n"
  },
  {
    "path": "utils/util.py",
    "content": "import logging\nimport os\nimport random\nimport sys\nimport time\nfrom shutil import get_terminal_size\n\nimport numpy as np\nimport torch\n\nlogger = logging.getLogger('base')\n\n\ndef make_exp_dirs(opt):\n    \"\"\"Make dirs for experiments.\"\"\"\n    path_opt = opt['path'].copy()\n    if opt['is_train']:\n        overwrite = True if 'debug' in opt['name'] else False\n        os.makedirs(path_opt.pop('experiments_root'), exist_ok=overwrite)\n        os.makedirs(path_opt.pop('models'), exist_ok=overwrite)\n    else:\n        os.makedirs(path_opt.pop('results_root'))\n\n\ndef set_random_seed(seed):\n    \"\"\"Set random seeds.\"\"\"\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n\n\nclass ProgressBar(object):\n    \"\"\"A progress bar which can print the progress.\n\n    Modified from:\n    https://github.com/hellock/cvbase/blob/master/cvbase/progress.py\n    \"\"\"\n\n    def __init__(self, task_num=0, bar_width=50, start=True):\n        self.task_num = task_num\n        max_bar_width = self._get_max_bar_width()\n        self.bar_width = (\n            bar_width if bar_width <= max_bar_width else max_bar_width)\n        self.completed = 0\n        if start:\n            self.start()\n\n    def _get_max_bar_width(self):\n        terminal_width, _ = get_terminal_size()\n        max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)\n        if max_bar_width < 10:\n            print(f'terminal width is too small ({terminal_width}), '\n                  'please consider widening the terminal for better '\n                  'progress bar visualization')\n            max_bar_width = 10\n        return max_bar_width\n\n    def start(self):\n        if self.task_num > 0:\n            sys.stdout.write(f\"[{' ' * self.bar_width}] 0/{self.task_num}, \"\n                             f'elapsed: 0s, ETA:\\nStart...\\n')\n        else:\n            sys.stdout.write('completed: 0, elapsed: 0s')\n        sys.stdout.flush()\n        self.start_time = time.time()\n\n    def update(self, msg='In progress...'):\n        self.completed += 1\n        elapsed = time.time() - self.start_time\n        fps = self.completed / elapsed\n        if self.task_num > 0:\n            percentage = self.completed / float(self.task_num)\n            eta = int(elapsed * (1 - percentage) / percentage + 0.5)\n            mark_width = int(self.bar_width * percentage)\n            bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)\n            sys.stdout.write('\\033[2F')  # cursor up 2 lines\n            sys.stdout.write(\n                '\\033[J'\n            )  # clean the output (remove extra chars since last display)\n            sys.stdout.write(\n                f'[{bar_chars}] {self.completed}/{self.task_num}, '\n                f'{fps:.1f} tasks/s, elapsed: {int(elapsed + 0.5)}s, '\n                f'ETA: {eta:5}s\\n{msg}\\n')\n        else:\n            sys.stdout.write(\n                f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s, '\n                f'{fps:.1f} tasks/s')\n        sys.stdout.flush()\n"
  }
]