[
  {
    "path": ".gitignore",
    "content": "# Compiled class file\n*.class\n\n# Log file\n*.log\n\n# BlueJ files\n*.ctxt\n\n# Mobile Tools for Java (J2ME)\n.mtj.tmp/\n\n# Package Files #\n*.jar\n*.war\n*.nar\n*.ear\n*.zip\n*.tar.gz\n*.rar\n\n# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml\nhs_err_pid*\n\n.gradle\n.vscode\n\nout/\n.idea/\n.gradle/\nbuild/\n/data/\n.idea/*\n*.iml\n\n*/.idea/*\n*/build\n*/.gradle/*\n*/out/*\n*.pyc\n*.pyd\n**/.cache/*\n*/bin\n\n__pycache__/\n./tools/\n\ntemp/\n\n*/*/bin/\n\nvenv/*\n\n# io dump to ABCI tasks\n*.o*"
  },
  {
    "path": ".python-version",
    "content": "3.10.11"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\r\n\r\nCopyright (c) 2024 pixiv Inc.\r\n\r\nPermission is hereby granted, free of charge, to any person obtaining a copy \r\nof this software and associated documentation files (the \"Software\"), to deal\r\nin the Software without restriction, including without limitation the rights\r\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\r\ncopies of the Software, and to permit persons to whom the Software is\r\nfurnished to do so, subject to the following conditions:\r\n\r\nThe above copyright notice and this permission notice shall be included in all\r\ncopies or substantial portions of the Software.\r\n\r\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\r\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\r\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\r\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\r\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\r\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\r\nSOFTWARE."
  },
  {
    "path": "README.md",
    "content": "# Demo Code for \"Talking Head(?) Anime from a Single Image 4: Improved Model and Its Distillation\"\n\nThis repository contains demo programs for the \"Talking Head(?) Anime from a Single Image 4: Improved Model and Its Distillation\" project. Roughly, the project is about a machine learning model that can animate an anime character given only one image. However, the model is too slow to run in real-time. So, it also proposes an algorithm to use the model to train a small machine learning model that is specialized to a character image that can anime the character in real time.\n\nThis demo code has two parts.\n\n* **Improved model.** This part gives a model similar to [Version 3](https://github.com/pkhungurn/talking-head-anime-3-demo) of the porject. It has one demo program:\n\n  * The `full_manual_poser` allows the user to manipulate a character's facial expression and body rotation through a graphical user interface.\n\n  There are no real-time demos because the new model is too slow for that.\n\n* **Distillation.** This part allows the user to train small models (which we will refer to as **student models**) to mimic that behavior of the full system with regards to a specific character image. It also allows the user to run these models under various interfaces. The demo programs are:\n\n  * `distill` trains a student model given a configuration file, a $512 \\times 512$ RGBA character image, and a mask of facial organs.\n  * `distiller_ui` provides a user-friendly interface to `distill`, allowing you to create training configurations and providing useful documentation.\n  * `character_model_manual_poser` allows the user to control trained student models with a graphical user interface.\n  * `character_model_ifacialmocap_puppeteer` allows the user to control trained student models with their facial movement, which is captured by the [iFacialMocap](https://www.ifacialmocap.com/) software. To run this software, you must have an iOS device and, of course, iFacialMocap.\n  *  `character_model_mediapipe_puppeteer` allows the user to control trained student models with their facial movement, which is captured a web camera and processed by the [Mediapipe FaceLandmarker](https://developers.google.com/mediapipe/solutions/vision/face_landmarker) model. To run this software, you need a web camera.\n\n## Preemptive FAQs\n\n### What is the program to control character images with my facial movement?\n\nThere is no such program in this release. If you want one, try the `ifacialmocap_puppeteer` of [Version 3](https://github.com/pkhungurn/talking-head-anime-3-demo).\n\n### OK. I'm confused. Isn't your work about easy VTubing? Are you saying this release cannot do it?\n\nNO. This release does it in a more complicated way. In order to control an image, you need to create a \"student model.\" It is a small (< 2MB) and fast machine learning model that knows how to animate that particular image. Then, the student model can be controlled with facial movement. You can find two student models in the `data/character_models` directory. The [two](https://pkhungurn.github.io/talking-head-anime-4/supplementary/webcam-demo/index.html) [demos](https://pkhungurn.github.io/talking-head-anime-4/supplementary/manual-poser-demo/index.html) on the project website feature 13 students models.\n\n### So, for this release, you can control only these few characters in real time?\n\nNo. You can create your own student models.\n\n### How do I create this student model then?\n\n1. 
You prepare your character image according to the \"Constraints on Input Images\" section below.\n2. You prepare a black-and-white mask image that covers the eyes and the mouth of the character, like [this image](data/images/lambda_00_face_mask.png). You can see how I made it with [GIMP](https://www.gimp.org/) by inspecting this [GIMP file](data/images/lambda_00_face_mask.xcf).\n3. You use `distiller_ui` to create a configuration file that specifies how the student model should be trained.\n4. You use `distiller_ui` or `distill` to start the training process.\n5. You wait several tens of hours for the student model to finish training. Last time I tried, it was about 30 hours on a computer with an Nvidia RTX A6000 GPU.\n6. After that, you can control the student model with `character_model_ifacialmocap_puppeteer` and `character_model_mediapipe_puppeteer`.\n\n### Why is this release so hard to use?\n\n[Version 3](https://github.com/pkhungurn/talking-head-anime-3-demo) is arguably easier to use because you can give it an anime character image and control it with your facial movement immediately. However, I was not satisfied with its image quality and speed.\n\nIn this release, I explore a new way of doing things. I added a new preprocessing stage (i.e., training the student models) that has to be done once per character image. It allows the image to be animated much faster at a higher image quality level.\n\nIn other words, it makes the user's life difficult but the engineer/researcher happy. Patient users who are willing to go through the steps, though, would be rewarded with faster animation.\n\n### Can I use a student model from a web browser?\n\nNo. A student model created by `distill` is a [PyTorch](https://pytorch.org/) model, which cannot run directly in the browser. It needs to be converted to the appropriate format ([TensorFlow.js](https://www.tensorflow.org/js)) first, and the [web](https://pkhungurn.github.io/talking-head-anime-4/supplementary/webcam-demo/index.html) [demos](https://pkhungurn.github.io/talking-head-anime-4/supplementary/manual-poser-demo/index.html) use the converted models. However, the conversion code is not included in this repository. I will not release it unless I change my mind.\n\n## Hardware Requirements\n\nAll programs require a recent and powerful Nvidia GPU to run. I developed the programs on a machine with an Nvidia RTX A6000. However, anything newer than the GeForce RTX 2080 should be fine.\n\nThe `character_model_ifacialmocap_puppeteer` program requires an iOS device that is capable of computing [blend shape parameters](https://developer.apple.com/documentation/arkit/arfaceanchor/2928251-blendshapes) from a video feed. This means that the device must be able to run iOS 11.0 or higher and must have a TrueDepth front-facing camera. (See [this page](https://developer.apple.com/documentation/arkit/content_anchors/tracking_and_visualizing_faces) for more info.) In other words, if you have the iPhone X or something better, you should be all set. Personally, I have used an iPhone 12 mini.\n\nThe `character_model_mediapipe_puppeteer` program requires a web camera.\n\n
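If you want to make sure that your web camera is reachable before trying the puppeteer programs, you can run a quick check like the one below once the Python environment described in the following sections is set up. (This is a minimal sketch, not a script shipped with this repository; it assumes the camera is device 0.)\n\n```python\nimport cv2\n\n# Try to open the default camera and grab a single frame.\ncapture = cv2.VideoCapture(0)\nok, frame = capture.read()\nprint(\"Camera opened:\", capture.isOpened(), \"- got a frame:\", ok)\ncapture.release()\n```\n\n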
## Software Requirements\n\n### GPU Driver and CUDA Toolkit\n\nPlease update your GPU's device driver and install the [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit) that is compatible with your GPU and is newer than the version you will be installing in the next subsection.\n\n### Python and Python Libraries\n\nAll programs are written in the [Python](https://www.python.org/) programming language. The following libraries are required:\n\n* `python` 3.10.11\n* `torch` 1.13.1 with CUDA support\n* `torchvision` 0.14.1\n* `tensorboard` 2.15.1\n* `opencv-python` 4.8.1.78\n* `wxpython` 4.2.1\n* `numpy-quaternion` 2022.4.2\n* `pillow` 9.4.0\n* `matplotlib` 3.6.3\n* `einops` 0.6.0\n* `mediapipe` 0.10.3\n* `numpy` 1.26.3\n* `scipy` 1.12.0\n* `omegaconf` 2.3.0\n\nInstead of installing these libraries yourself, you should follow the recommended method to set up a Python environment in the next section.\n\n### iFacialMocap\n\nIf you want to use `character_model_ifacialmocap_puppeteer`, you will also need an iOS app called [iFacialMocap](https://www.ifacialmocap.com/) (a 980 yen purchase in the App Store). Your iOS device and your computer must be on the same network. For example, you may connect them to the same wireless router.\n\n## Creating Python Environment\n\n### Installing Python\n\nPlease install [Python 3.10.11](https://www.python.org/downloads/release/python-31011/).\n\nI recommend using [`pyenv`](https://github.com/pyenv/pyenv) (or [`pyenv-win`](https://github.com/pyenv-win/pyenv-win) for Windows users) to manage multiple Python versions on your system. If you use `pyenv`, this repository has a `.python-version` file that indicates it uses Python 3.10.11, so you will be using Python 3.10.11 automatically once you `cd` into the repository's directory.\n\nMake sure that you can run Python from the command line.\n\n### Installing Poetry\n\nPlease install [Poetry](https://python-poetry.org/) 1.7 or later. We will use it to automatically install the required libraries. Again, make sure that you can run it from the command line.\n\n### Cloning the Repository\n\nPlease clone the repository to an arbitrary directory on your machine.\n\n### Instructions for Linux/OSX Users\n\n1. Open a shell.\n2. `cd` to the directory you just cloned the repository to.\n   ```\n   cd SOMEWHERE/talking-head-anime-4-demo\n   ```\n3. Use Python to create a virtual environment under the `venv` directory.\n   ```\n   python -m venv venv --prompt talking-head-anime-4-demo\n   ```\n4. Activate the newly created virtual environment. You can either use the script I provide:\n   ```\n   source bin/activate-venv.sh\n   ```\n   or do it yourself:\n   ```\n   source venv/bin/activate\n   ```\n5. Use Poetry to install libraries.\n   ```\n   cd poetry\n   poetry install\n   ```\n\n### Instructions for Windows Users\n\n1. Open a shell.\n2. `cd` to the directory you just cloned the repository to.\n   ```\n   cd SOMEWHERE\\talking-head-anime-4-demo\n   ```\n3. Use Python to create a virtual environment under the `venv` directory.\n   ```\n   python -m venv venv --prompt talking-head-anime-4-demo\n   ```\n4. Activate the newly created virtual environment. You can either use the script I provide:\n   ```\n   bin\\activate-venv.bat\n   ```\n   or do it yourself:\n   ```\n   venv\\Scripts\\activate\n   ```\n5. Use Poetry to install libraries.\n   ```\n   cd poetry\n   poetry install\n   ```\n\n
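After `poetry install` finishes, you can sanity-check the environment from within the activated virtual environment. (This is a quick, optional check and not an official script from this repository; the exact version strings may differ slightly.)\n\n```python\nimport torch\nimport wx\nimport mediapipe\n\n# The versions should match the ones listed under \"Software Requirements\".\nprint(\"torch\", torch.__version__)                    # expect 1.13.1 with a CUDA suffix\nprint(\"CUDA available:\", torch.cuda.is_available())  # should print True on a supported GPU\nprint(\"wxPython\", wx.__version__)\nprint(\"mediapipe\", mediapipe.__version__)\n```\n\n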
## Download the Models/Dataset Files\n\n### THA4 Models\n\nPlease download [this ZIP file](https://www.dropbox.com/scl/fi/7wec0sur7449iqgtlpi3n/tha4-models.zip?rlkey=0f9d1djmbvjjjn09469s1adx8&dl=0) hosted on Dropbox, and unzip it to the `data/tha4` directory under the repository's directory. In the end, the directory tree should look like the following diagram:\n\n```\n+ talking-head-anime-4-demo\n   + data\n      - character_models\n      - distill_examples\n      + tha4\n         - body_morpher.pt\n         - eyebrow_decomposer.pt\n         - eyebrow_morphing_combiner.pt\n         - face_morpher.pt\n         - upscaler.pt\n      - images\n      - third_party\n```\n\n### Pose Dataset\n\nIf you want to create your own student models, you also need to download a dataset of poses that is needed for the training process. Download [this `pose_dataset.pt` file](https://www.dropbox.com/scl/fi/du10e6buzr5bslbe025qu/pose_dataset.pt?rlkey=y052g4n3xb14nu2elctzouc5x&dl=0) and save it to the `data` folder. The directory tree should then look like the following diagram:\n\n```\n+ talking-head-anime-4-demo\n   + data\n      - character_models\n      - distill_examples\n      - tha4\n      - images\n      - third_party\n      - pose_dataset.pt\n```\n\n## Running the Programs\n\nThe programs are located in the `src/tha4/app` directory. You need to run them from a shell with the provided scripts.\n\n### Instructions for Linux/OSX Users\n\n1. Open a shell.\n2. `cd` to the repository's directory.\n   ```\n   cd SOMEWHERE/talking-head-anime-4-demo\n   ```\n3. Run a program.\n   ```\n   bin/run src/tha4/app/<program-file-name>\n   ```\n   where `<program-file-name>` can be replaced with:\n   \n   * `character_model_ifacialmocap_puppeteer.py`\n   * `character_model_manual_poser.py`\n   * `character_model_mediapipe_puppeteer.py`\n   * `distill.py`\n   * `distiller_ui.py`\n   * `full_manual_poser.py`\n\n### Instructions for Windows Users\n\n1. Open a shell.\n2. `cd` to the repository's directory.\n   ```\n   cd SOMEWHERE\\talking-head-anime-4-demo\n   ```\n3. Run a program.\n   ```\n   bin\\run.bat src\\tha4\\app\\<program-file-name>\n   ```\n   where `<program-file-name>` can be replaced with:\n   \n   * `character_model_ifacialmocap_puppeteer.py`\n   * `character_model_manual_poser.py`\n   * `character_model_mediapipe_puppeteer.py`\n   * `distill.py`\n   * `distiller_ui.py`\n   * `full_manual_poser.py`\n\n## Constraints on Input Images\n\nIn order for the system to work well, the input image must obey the following constraints:\n\n* It should be of resolution 512 x 512. (If the demo programs receive an input image of any other size, they will resize the image to this resolution and also produce output at this resolution.)\n* It must have an alpha channel.\n* It must contain only one humanoid character.\n* The character should be standing upright and facing forward.\n* The character's hands should be below and far from the head.\n* The head of the character should roughly be contained in the 128 x 128 box in the middle of the top half of the image.\n* The alpha channels of all pixels that do not belong to the character (i.e., background pixels) must be 0.\n\n![An example of an image that conforms to the above criteria](docs/images/input_spec.png \"An example of an image that conforms to the above criteria\")\n\n
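You can check the first few constraints programmatically before committing to a long training run. The snippet below is a small sketch using Pillow (already a dependency); `your_character.png` is a placeholder for your own file.\n\n```python\nfrom PIL import Image\n\nimage = Image.open(\"your_character.png\")\nassert image.size == (512, 512), \"image must be 512 x 512\"\nassert image.mode == \"RGBA\", \"image must have an alpha channel\"\nprint(\"Basic size/alpha checks passed.\")\n```\n\n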
## Documentation for the Tools\n\n* [`character_model_ifacialmocap_puppeteer`](docs/character_model_ifacialmocap_puppeteer.md)\n* [`character_model_manual_poser`](docs/character_model_manual_poser.md)\n* [`character_model_mediapipe_puppeteer`](docs/character_model_mediapipe_puppeteer.md)\n* [`distill`](docs/distill.md)\n* [`distiller_ui`](docs/distiller_ui.md)\n* [`full_manual_poser`](docs/full_manual_poser.md)\n\n## Disclaimer\n\nThe author is an employee of [pixiv Inc.](https://www.pixiv.co.jp/) This project is a part of his work as a researcher.\n\nHowever, this project is NOT a pixiv product. The company will NOT provide any support for this project. The author will try to support the project, but there are no Service Level Agreements (SLAs) that he will maintain.\n\nThe code is released under the [MIT license](https://github.com/pkhungurn/talking-head-anime-2-demo/blob/master/LICENSE).\nThe THA4 models and the images under the `data/images` directory are released under the [Creative Commons Attribution-NonCommercial 4.0 International](https://creativecommons.org/licenses/by-nc/4.0/deed.en) license.\n\nThis repository redistributes a version of the [Face landmark detection model](https://developers.google.com/mediapipe/solutions/vision/face_landmarker) from the [MediaPipe](https://developers.google.com/mediapipe) project. The model has been released under the [Apache License, Version 2.0](https://www.apache.org/licenses/LICENSE-2.0.html)."
  },
  {
    "path": "bin/activate-venv.bat",
    "content": "venv\\Scripts\\activate"
  },
  {
    "path": "bin/activate-venv.sh",
    "content": "#! /bin/bash\nsource venv/bin/activate"
  },
  {
    "path": "bin/run",
    "content": "#! /bin/bash\nexport PYTHONPATH=$(pwd)/src\nvenv/bin/python $@\n"
  },
  {
    "path": "bin/run.bat",
    "content": "set PYTHONPATH=%cd%\\src\r\nvenv\\Scripts\\python.exe %*\r\n"
  },
  {
    "path": "distiller-ui-doc/index.html",
    "content": "<html lang=\"en\">\r\n<head>\r\n    <title>Distiller UI Documentation</title>\r\n</head>\r\n<body>\r\n<h1>How to use Distiller UI</h1>\r\n\r\n<p>This program is called <code>distiller_ui</code>. It allows you to create and modify configurations for the process of distilling the full, but slow THA4 system to a student model that can be run in real time on computers with moderately powerful GPUs.</p>\r\n\r\n<h2>Basic Usage</h2>\r\n\r\n<p>This program manipulates YAML files that are used as configurations for the distillation process. The menus\r\n<ul>\r\n    <li><b>File &rarr; New</b></li>\r\n    <li><b>File &rarr; Open</b></li>\r\n    <li><b>File &rarr; Save</b></li>\r\n</ul>\r\ndo what they are supposed to do in typical application programs.\r\n</p>\r\n\r\n<p>You can use the UI in the middle panel to change various parameters of the configuration. If you do not understand what the meaning of a parameter, click the \"Help\" button for that parameter to learn more.</p>\r\n\r\n<p>Once you have modified the parameters to your liking, click the \"RUN\" button at the bottom of the middle panel to carry out the distillation. This will take several ten hours, so sit back and relax.</p>\r\n\r\n<p>The distillation process can be interrupted and resumed at any time. As a result, you do not have to worry that you may lose data if there's a blackout or if you need to free your GPU(s) to do something else. Resuming can be done through this program or through the <code>distill</code> script.</p>\r\n\r\n<h2>Explanation of Configuration Parameters</h2>\r\n\r\n<ul>\r\n    <li><a href=\"params/prefix.html\"><code>prefix</code></a></li>\r\n    <li><a href=\"params/character_image_file_name.html\"><code>character_image_file_name</code></a></li>\r\n    <li><a href=\"params/face_mask_image_file_name.html\"><code>face_mask_image_file_name</code></a></li>\r\n    <li><a href=\"params/num_cpu_workers.html\"><code>num_cpu_workers</code></a></li>\r\n    <li><a href=\"params/num_gpus.html\"><code>num_gpus</code></a></li>\r\n    <li><a href=\"params/face_morpher_random_seed_0.html\"><code>face_morpher_random_seed_0</code></a></li>\r\n    <li><a href=\"params/face_morpher_random_seed_1.html\"><code>face_morpher_random_seed_1</code></a></li>\r\n    <li><a href=\"params/face_morpher_batch_size.html\"><code>face_morpher_batch_size</code></a></li>\r\n    <li><a href=\"params/body_morpher_random_seed_0.html\"><code>body_morpher_random_seed_0</code></a></li>\r\n    <li><a href=\"params/body_morpher_random_seed_1.html\"><code>body_morpher_random_seed_1</code></a></li>\r\n    <li><a href=\"params/body_morpher_batch_size.html\"><code>body_morpher_batch_size</code></a></li>\r\n    <li><a href=\"params/num_training_examples_per_sample_output.html\"><code>num_training_examples_per_sample_output</code></a></li>\r\n</ul>\r\n\r\n</body>\r\n</html>"
  },
  {
    "path": "distiller-ui-doc/params/body_morpher_batch_size.html",
    "content": "<html lang=\"en\">\r\n<head>\r\n    <title>Distiller UI Documentation: body_morpher_batch_size</title>\r\n</head>\r\n<body>\r\n<h1><code>body_morpher_batch_size</code></h1>\r\n\r\n<p>The \"batch size\" is the number of training examples shown to a machine learning model in one round of parameter update. This parameter is the batch size for training the student body morpher. We recommend you set it to 8. However, if your computer does not have enough GPU RAM, you can reduce the number to any smaller positive integer.</p>\r\n\r\n<hr>\r\n<a href=\"../index.html\">Back to main documentation</a>\r\n</body>\r\n</html>"
  },
  {
    "path": "distiller-ui-doc/params/body_morpher_random_seed_0.html",
    "content": "<html lang=\"en\">\r\n<head>\r\n    <title>Distiller UI Documentation: body_morpher_random_seed_0</title>\r\n</head>\r\n<body>\r\n<h1><code>body_morpher_random_seed_0</code></h1>\r\n\r\n<p>This parameter will be used as a random seed in the process of training the student body morpher. It can be any non-negative integer from 0 to 2<sup>64</sup>-1. You can specify the number directly, or use the \"Randomize\" button to specify a random one.</p>\r\n\r\n<hr>\r\n<a href=\"../index.html\">Back to main documentation</a>\r\n</body>\r\n</html>"
  },
  {
    "path": "distiller-ui-doc/params/body_morpher_random_seed_1.html",
    "content": "<html lang=\"en\">\r\n<head>\r\n    <title>Distiller UI Documentation: body_morpher_random_seed_1</title>\r\n</head>\r\n<body>\r\n<h1><code>body_morpher_random_seed_1</code></h1>\r\n\r\n<p>This parameter will be used as a random seed in the process of training the student body morpher. It can be any non-negative integer from 0 to 2<sup>64</sup>-1. You can specify the number directly, or use the \"Randomize\" button to specify a random one.</p>\r\n\r\n<hr>\r\n<a href=\"../index.html\">Back to main documentation</a>\r\n</body>\r\n</html>"
  },
  {
    "path": "distiller-ui-doc/params/character_image_file_name.html",
    "content": "<html lang=\"en\">\r\n<head>\r\n    <title>Distiller UI Documentation: character_image_file_name</title>\r\n</head>\r\n<body>\r\n<h1><code>character_image_file_name</code></h1>\r\n<p>This is the name of the file of an image of a humanoid character. The image must conform to the following specifications.</p>\r\n\r\n<p>\r\n<ul>\r\n    <li>It MUST in the PNG format.</li>\r\n    <li>It MUST have an alpha channel.</li>\r\n    <li>It MUST be 512 x 512.</li>\r\n    <li>It MUST contain only one humanoid character.</li>\r\n    <li>The character should be standing upright and facing forward.</li>\r\n    <li>The character's hands should be below and far from the head.</li>\r\n    <li>The head of the character should roughly be contained in the 128 x 128 box in the middle of the top half of the image.</li>\r\n    <li>The alpha channels of all pixels that do not belong to the character (i.e., background pixels) must be 0.</li>\r\n</ul>\r\n</p>\r\n\r\n<p>\r\n<img src=\"../images/input_spec.png\" alt=\"\">\r\n</p>\r\n\r\n<p>Once you have chosen the image, a crop of the character face will be shown on the right side of the window. In order for the distillation process works correctly, <b>make sure that all the movable parts of the face&mdash; eyes, eyebrows, mouth, jaw line &mdash; can all be seen in this crop.</b></p>\r\n\r\n<p>\r\n<table border=\"1\" cellpadding=\"5\">\r\n    <tr>\r\n        <td align=\"center\"><img src=\"../images/face_crop_ok.png\" alt=\"\"><br><font size=\"18\" color=\"green\">&#9745;</font></td>\r\n        <td>This image is GOOD because we can see all of the eyes, eyebrows, mouth, and jaw line in the image.</td>\r\n    </tr>\r\n    <tr>\r\n        <td align=\"center\"><img src=\"../images/face_crop_not_ok_00.png\" alt=\"\"><br><font size=\"18\" color=\"red\">&#9746;</font></td>\r\n        <td>This image is NOT GOOD because we cannot see the whole of the jaw line in the image</td>\r\n    </tr>\r\n    <tr>\r\n        <td align=\"center\"><img src=\"../images/face_crop_not_ok_01.png\" alt=\"\"><br><font size=\"18\" color=\"red\">&#9746;</font></td>\r\n        <td>This image is NOT GOOD because we cannot see the whole of the right eye and eyebrow in the image.</td>\r\n    </tr>\r\n    <tr>\r\n        <td align=\"center\"><img src=\"../images/face_crop_not_ok_02.png\" alt=\"\"><br><font size=\"18\" color=\"red\">&#9746;</font></td>\r\n        <td>This image is NOT GOOD because we cannot see the whole of the eyebrows in the image.</td>\r\n    </tr>\r\n</table>\r\n</p>\r\n\r\n<p>The <code>data/images</code> directory contains two example images that conform to all the above specifications: <code>data/images/lambda_00.png</code> and <code>data/images/lambda_01.png</code>. Please use them as references.</p>\r\n\r\n<hr>\r\n\r\n<a href=\"../index.html\">Back to main documentation</a>\r\n</body>\r\n</html>"
  },
  {
    "path": "distiller-ui-doc/params/face_mask_image_file_name.html",
    "content": "<html lang=\"en\">\r\n<head>\r\n    <title>Distiller UI Documentation: face_mask_image_file_name</title>\r\n</head>\r\n<body>\r\n<h1><code>face_mask_image_file_name</code></h1>\r\n\r\n<p>This is the name of the file containing binary masks of movable facial organs of the character. It is probably the best to see an example.</p>\r\n\r\n<p>\r\n    <img src=\"../../data/images/lambda_00_face_mask.png\" alt=\"\">\r\n</p>\r\n\r\n<p>A \"face mask image\" conforms to the following specification.</p>\r\n\r\n<p>\r\n    <ul>\r\n        <li>It must be in the PNG format.</li>\r\n        <li>It must be 512 x 512.</li>\r\n        <li>It must be an RGB image (i.e., no alpha channel).</li>\r\n        <li>All pixels must be either block (0,0,0) or white (255,255,255).</li>\r\n        <li>The white pixels should cover movable parts of the face.</li>\r\n    </ul>\r\n</p>\r\n\r\n<p>We recommend creating three rectangles.</p>\r\n\r\n<p>\r\n    <ul>\r\n        <li>One covers the right eye and eyebrow.</li>\r\n        <li>One covers the left eye and eyebrow.</li>\r\n        <li>One covers the mouth and the jaw line.</li>\r\n    </ul>\r\n</p>\r\n\r\n<p>The rectangles for the eyes and the eyebrows should extend above the eyes to some extent because the eyebrows can move upward.</p>\r\n\r\n<p>Once you have specified the face mask image with the \"Change...\" button, a crop of the face area will show up on the left side of the window. If the character image has also been specified, an image of the face mask laid over the character's face will also show up. Use this image to check whether the masks are covering everything.</p>\r\n\r\n<p>\r\n    <img src=\"../images/left_panel.png\" alt=\"\">\r\n</p>\r\n\r\n<hr>\r\n<a href=\"../index.html\">Back to main documentation</a>\r\n</body>\r\n</html>"
  },
  {
    "path": "distiller-ui-doc/params/face_morpher_batch_size.html",
    "content": "<html lang=\"en\">\r\n<head>\r\n    <title>Distiller UI Documentation: face_morpher_batch_size</title>\r\n</head>\r\n<body>\r\n<h1><code>face_morpher_batch_size</code></h1>\r\n\r\n<p>The \"batch size\" is the number of training examples shown to a machine learning model in one round of parameter update. This parameter is the batch size for training the student face morpher. We recommend you set it to 8. However, if your computer does not have enough GPU RAM, you can reduce the number to any smaller positive integer.</p>\r\n\r\n<hr>\r\n<a href=\"../index.html\">Back to main documentation</a>\r\n</body>\r\n</html>"
  },
  {
    "path": "distiller-ui-doc/params/face_morpher_random_seed_0.html",
    "content": "<html lang=\"en\">\r\n<head>\r\n    <title>Distiller UI Documentation: face_morpher_random_seed_0</title>\r\n</head>\r\n<body>\r\n<h1><code>face_morpher_random_seed_0</code></h1>\r\n\r\n<p>This parameter will be used as a random seed in the process of training the student face morpher. It can be any non-negative integer from 0 to 2<sup>64</sup>-1. You can specify the number directly, or use the \"Randomize\" button to specify a random one.</p>\r\n\r\n<hr>\r\n<a href=\"../index.html\">Back to main documentation</a>\r\n</body>\r\n</html>"
  },
  {
    "path": "distiller-ui-doc/params/face_morpher_random_seed_1.html",
    "content": "<html lang=\"en\">\r\n<head>\r\n    <title>Distiller UI Documentation: face_morpher_random_seed_1</title>\r\n</head>\r\n<body>\r\n<h1><code>face_morpher_random_seed_1</code></h1>\r\n\r\n<p>This parameter will be used as a random seed in the process of training the student face morpher. It can be any non-negative integer from 0 to 2<sup>64</sup>-1. You can specify the number directly, or use the \"Randomize\" button to specify a random one.</p>\r\n\r\n<hr>\r\n<a href=\"../index.html\">Back to main documentation</a>\r\n</body>\r\n</html>"
  },
  {
    "path": "distiller-ui-doc/params/num_cpu_workers.html",
    "content": "<html lang=\"en\">\r\n<head>\r\n    <title>Distiller UI Documentation: face_mask_image_file_name</title>\r\n</head>\r\n<body>\r\n<h1><code>num_cpu_workers</code></h1>\r\n\r\n<p>This is the number of worker threads that are used to process pose data during training of the student models. Typically, 1 would be enough, but you can specify up to the number of CPUs your computer has.</p>\r\n\r\n<hr>\r\n<a href=\"../index.html\">Back to main documentation</a>\r\n</body>\r\n</html>"
  },
  {
    "path": "distiller-ui-doc/params/num_gpus.html",
    "content": "<html lang=\"en\">\r\n<head>\r\n    <title>Distiller UI Documentation: num_gpus</title>\r\n</head>\r\n<body>\r\n<h1><code>num_gpus</code></h1>\r\n\r\n<p>This is the number of GPUs that are used to to train the student models. Typically, 1 would be enough. However, you can specify up to the number of Nvidia GPUs that your PC has.</p>\r\n\r\n<hr>\r\n<a href=\"../index.html\">Back to main documentation</a>\r\n</body>\r\n</html>"
  },
  {
    "path": "distiller-ui-doc/params/num_training_examples_per_sample_output.html",
    "content": "<html lang=\"en\">\r\n<head>\r\n    <title>Distiller UI Documentation: num_training_example_per_sample_output</title>\r\n</head>\r\n<body>\r\n<h1><code>num_training_example_per_sample_output</code></h1>\r\n\r\n<p>During training of a student model, the training process would periodically create \"sample output\" produced by the model being trained in order to allow the user to see training progress and observe whether there is any anomalies.</p>\r\n\r\n<p>This parameter specifies how frequent the sample outputs are generated. You can indicate whether you want a sample output to be generated every time the trained model has beeen shown 10,000, 100,000 or 1,000,000 training examples. If you do not care about sample outputs, you can also make the process not generate any sample outputs at all.</p>\r\n\r\n<hr>\r\n<a href=\"../index.html\">Back to main documentation</a>\r\n</body>\r\n</html>"
  },
  {
    "path": "distiller-ui-doc/params/prefix.html",
    "content": "<html lang=\"en\">\r\n<head>\r\n    <title>Distiller UI Documentation: prefix</title>\r\n</head>\r\n<body>\r\n<h1><code>prefix</code></h1>\r\n\r\n<p><code>prefix</code> is the name of the directory under which the distillation process will store the trained models and other intermediate data. Please choose a directory that is a subdirectory of the directory that stores the <code>talking-head-anime-4-demo</code>'s repository.</p>\r\n\r\n<hr>\r\n<a href=\"../index.html\">Back to main documentation</a>\r\n</body>\r\n</html>"
  },
  {
    "path": "docs/character_model_ifacialmocap_puppeteer.md",
    "content": "# `character_model_ifacialmocap_puppeteer`\r\n\r\nThis program allows the user to control trained student models with their facial movement, which is captured by the [iFacialMocap](https://www.ifacialmocap.com/) software. You can purchase the software from the App Store for 980 Japanese Yen.\r\n\r\n## Invoking the Program\r\n\r\nMake sure you have (1) created a Python environment and (2) downloaded model files as instruction in the [main README file](../README.md).\r\n\r\n### Instruction for Linux/OSX Users\r\n\r\n1. Open a shell.\r\n2. `cd` to the repository's directory.\r\n   ```\r\n   cd SOMEWHERE/talking-head-anime-4-demo\r\n   ```\r\n3. Run the program.\r\n   ```\r\n   bin/run src/tha4/app/character_model_ifacialmocap_puppeteer.py\r\n   ```   \r\n\r\n### Instruction for Windows Users\r\n\r\n1. Open a shell.\r\n2. `cd` to the repository's directory.\r\n   ```\r\n   cd SOMEWHERE\\talking-head-anime-4-demo\r\n   ```\r\n3. Run the program.\r\n   ```\r\n   bin\\run.bat src\\tha4\\app\\character_model_ifacialmocap_puppeteer.py\r\n   ```\r\n\r\n## Usage\r\n\r\n1. Run iFacialMocap on your iOS device. It should show you the device's IP address. Jot it down. Keep the app open.\r\n\r\n   ![IP address in iFacialMocap screen](images/ifacialmocap_ip.jpg \"IP address in iFacialMocap screen\")\r\n\r\n2. Invoke the `character_model_ifacialmocap_puppeteer` application.\r\n\r\n3. You will see a text box with label \"Capture Device IP.\" Write the iOS device's IP address that you jotted down there.\r\n\r\n   ![Write IP address of your iOS device in the 'Capture Device IP' text box.](images/ifacialmocap-puppeteer-device-ip.png \"Write IP address of your iOS device in the 'Capture Device IP' text box.\")\r\n\r\n4. Click the \"START CAPTURE!\" button to the right.\r\n\r\n   ![Click the 'START CAPTURE!' button.](images/ifacialmocap-puppeteer-start-capture.png \"Click the 'START CAPTURE!' button.\")\r\n\r\n   If the programs are connected properly, you should see the numbers in the bottom part of the window change when you move your head.\r\n\r\n   ![The numbers in the bottom part of the window should change when you move your head.](images/ifacialmocap-puppeteer-moving-numbers.png \"The numbers in the bottom part of the window should change when you move your head.\")\r\n\r\n5. Now, you can load a student model, and the character should follow your facial movement."
  },
  {
    "path": "docs/character_model_manual_poser.md",
    "content": "# `character_model_manual_poser`\r\n\r\nThis program allows the user to control trained student models with a graphical user interface, mostly sliders.\r\n\r\n## Invoking the Program\r\n\r\nMake sure you have (1) created a Python environment and (2) downloaded model files as instruction in the [main README file](../README.md).\r\n\r\n### Instruction for Linux/OSX Users\r\n\r\n1. Open a shell.\r\n2. `cd` to the repository's directory.\r\n   ```\r\n   cd SOMEWHERE/talking-head-anime-4-demo\r\n   ```\r\n3. Run the program.\r\n   ```\r\n   bin/run src/tha4/app/character_model_manual_poser.py\r\n   ```   \r\n\r\n### Instruction for Windows Users\r\n\r\n1. Open a shell.\r\n2. `cd` to the repository's directory.\r\n   ```\r\n   cd SOMEWHERE\\talking-head-anime-4-demo\r\n   ```\r\n3. Run the program.\r\n   ```\r\n   bin\\run.bat src\\tha4\\app\\character_model_manual_poser.py\r\n   ```   \r\n"
  },
  {
    "path": "docs/character_model_mediapipe_puppeteer.md",
    "content": "# `character_model_mediapipe_puppeteer`\r\n\r\nallows the user to control trained student models with their facial movement, which is captured by a web camera and processed by the [Mediapipe FaceLandmarker](https://developers.google.com/mediapipe/solutions/vision/face_landmarker) model.\r\n\r\n## Web Camera\r\n\r\nPlease make sure that, before you invoke the program, your computer has a web camera plugged in. The program will use a web camera, but it does not allow you to specify which. In case your machine has more than one web camera, you can turn off all camera except the one that you want to use. \r\n\r\nYou can also inspect the [source code](../src/tha4/app/character_model_mediapipe_puppeteer.py) and change the \r\n\r\n```\r\n    video_capture = cv2.VideoCapture(0)\r\n```\r\n\r\nline to choose a particular camera that you want to use.\r\n\r\n## Invoking the Program\r\n\r\nMake sure you have (1) created a Python environment and (2) downloaded model files as instruction in the [main README file](../README.md).\r\n\r\n### Instruction for Linux/OSX Users\r\n\r\n1. Open a shell.\r\n2. `cd` to the repository's directory.\r\n   ```\r\n   cd SOMEWHERE/talking-head-anime-4-demo\r\n   ```\r\n3. Run the program.\r\n   ```\r\n   bin/run src/tha4/app/character_model_mediapipe_puppeteer.py\r\n   ```   \r\n\r\n### Instruction for Windows Users\r\n\r\n1. Open a shell.\r\n2. `cd` to the repository's directory.\r\n   ```\r\n   cd SOMEWHERE\\talking-head-anime-4-demo\r\n   ```\r\n3. Run the program.\r\n   ```\r\n   bin\\run.bat src\\tha4\\app\\character_model_mediapipe_puppeteer.py\r\n   ```   \r\n"
  },
  {
    "path": "docs/distill.md",
    "content": "# `distill`\r\n\r\nThis program trains a student model given a configuration file, a $512 \\times 512$ RGBA character image, and a mask of facial organs.\r\n\r\n## Invoking the Program\r\n\r\nMake sure you have (1) created a Python environment and (2) downloaded model files as instruction in the [main README file](../README.md).\r\n\r\n### Instruction for Linux/OSX Users\r\n\r\n1. Open a shell.\r\n2. `cd` to the repository's directory.\r\n   ```\r\n   cd SOMEWHERE/talking-head-anime-4-demo\r\n   ```\r\n3. Run the program.\r\n   ```\r\n   bin/run src/tha4/app/distill.py <config-file>\r\n   ```\r\n   where `<config-file>` is a configuration file for creating a student model. More on this later.\r\n\r\n### Instruction for Windows Users\r\n\r\n1. Open a shell.\r\n2. `cd` to the repository's directory.\r\n   ```\r\n   cd SOMEWHERE\\talking-head-anime-4-demo\r\n   ```\r\n3. Run the program.\r\n   ```\r\n   bin\\run.bat src\\tha4\\app\\full_manual_poser.py <config-file>\r\n   ```   \r\n   where `<config-file>` is a configuration file for creating a student model. More on this later.\r\n\r\n## Configuration File\r\n\r\nA configuration file is a [YAML](https://yaml.org/) file that specify how to create a student model. This repository comes with two valid configuration files that you can peruse:\r\n\r\n* [data/distill_examples/lambda_00/config.yaml](../data/distill_examples/lambda_00/config.yaml)\r\n* [data/distill_examples/lambda_01/config.yaml](../data/distill_examples/lambda_01/config.yaml)\r\n\r\nI recommend that you use the `distiller_ui` program to create configuration files rather than writing them yourself. Inside the program, you can see what the fields are and what they mean.\r\n\r\n## What `distill` Outputs\r\n\r\nInside the configuration file, you specify a directory where the student models should be saved to in the `prefix` field. After `distill` is done with its job, the output directory will look like this:\r\n\r\n```\r\n+ <prefix-specified-in-config-file>\r\n  + body_morpher\r\n  + face_morpher\r\n  + character_model\r\n  - config.yaml\r\n```\r\n\r\nHere:\r\n\r\n* `config.yaml` is a copy of the configuration file that you wrote. \r\n* The `character_model` directory contains a trained student model that can be used with `character_model_manual_poser.md`, `character_model_ifacialmocap_puppeteer.md`, and `character_model_mediapipe_puppeteer.md`. \r\n* `body_morpher` is a scratch directory that was used to save intermediate results during the training of a part of the student model.\r\n* `face_morpher` is a scratch directory that was used to save intermediate results during the training of another part of the student model.\r\n\r\nYou only need what is inside the `character_model` directory. As a resulit, you can delete other files after the `character_model` directory has been filled. You can move the directory out to somewhere and rename it as long as the contents inside are not modified.\r\n\r\n## The Training Process Is Interruptible\r\n\r\nInvoking `distill` on a configuration will start a rather long process of training a student model. On a machine with an A6000 GPU, it takes about 30 hours to complete. As a result, it might take several days on machines with less powerful GPUs.\r\n\r\nThe training process is robust and interruptible. You can stop it any time by closing the shell window or by typing `Ctrl+C`. 
## What `distill` Outputs\r\n\r\nInside the configuration file, you specify a directory where the student models should be saved in the `prefix` field. After `distill` is done with its job, the output directory will look like this:\r\n\r\n```\r\n+ <prefix-specified-in-config-file>\r\n  + body_morpher\r\n  + face_morpher\r\n  + character_model\r\n  - config.yaml\r\n```\r\n\r\nHere:\r\n\r\n* `config.yaml` is a copy of the configuration file that you wrote.\r\n* The `character_model` directory contains a trained student model that can be used with `character_model_manual_poser`, `character_model_ifacialmocap_puppeteer`, and `character_model_mediapipe_puppeteer`.\r\n* `body_morpher` is a scratch directory that was used to save intermediate results during the training of one part of the student model.\r\n* `face_morpher` is a scratch directory that was used to save intermediate results during the training of another part of the student model.\r\n\r\nYou only need what is inside the `character_model` directory. As a result, you can delete the other files after the `character_model` directory has been filled. You can move the directory elsewhere and rename it as long as the contents inside are not modified.\r\n\r\n## The Training Process Is Interruptible\r\n\r\nInvoking `distill` on a configuration will start a rather long process of training a student model. On a machine with an A6000 GPU, it takes about 30 hours to complete. As a result, it might take several days on machines with less powerful GPUs.\r\n\r\nThe training process is robust and interruptible. You can stop it any time by closing the shell window or by typing `Ctrl+C`. Intermediate results are periodically saved in the scratch directories, ready to be picked up at a later time when you are ready to train the student model again. To resume the process, just invoke `distill` again with the same configuration file that you started with, and the process will take care of itself."
  },
  {
    "path": "docs/distiller_ui.md",
    "content": "# `distiller_ui`\r\n\r\nThis program provides a user-friendly interface to the [`distill`](distill.md) program, allowing you to create training configurations and providing useful documentation.\r\n\r\n## Invoking the Program\r\n\r\nMake sure you have (1) created a Python environment and (2) downloaded model files as instruction in the [main README file](../README.md).\r\n\r\n### Instruction for Linux/OSX Users\r\n\r\n1. Open a shell.\r\n2. `cd` to the repository's directory.\r\n   ```\r\n   cd SOMEWHERE/talking-head-anime-4-demo\r\n   ```\r\n3. Run the program.\r\n   ```\r\n   bin/run src/tha4/app/distill_ui.py\r\n   ```   \r\n\r\n### Instruction for Windows Users\r\n\r\n1. Open a shell.\r\n2. `cd` to the repository's directory.\r\n   ```\r\n   cd SOMEWHERE\\talking-head-anime-4-demo\r\n   ```\r\n3. Run the program.\r\n   ```\r\n   bin\\run.bat src\\tha4\\app\\distill_ui.py\r\n   ```   \r\n\r\n## Usage\r\n\r\nPlease consult the documentation inside the program itself. It is available on the rightmost panel."
  },
  {
    "path": "docs/full_manual_poser.md",
    "content": "# `full_manual_poser`\r\n\r\nThis program uses the full version of the Talking Head(?) Anime 4 system to animate character images.\r\n\r\n## Invoking the Program\r\n\r\nMake sure you have (1) created a Python environment and (2) downloaded model files as instruction in the [main README file](../README.md).\r\n\r\n### Instruction for Linux/OSX Users\r\n\r\n1. Open a shell.\r\n2. `cd` to the repository's directory.\r\n   ```\r\n   cd SOMEWHERE/talking-head-anime-4-demo\r\n   ```\r\n3. Run the program.\r\n   ```\r\n   bin/run src/tha4/app/full_manual_poser.py\r\n   ```   \r\n\r\n### Instruction for Windows Users\r\n\r\n1. Open a shell.\r\n2. `cd` to the repository's directory.\r\n   ```\r\n   cd SOMEWHERE\\talking-head-anime-4-demo\r\n   ```\r\n3. Run the program.\r\n   ```\r\n   bin\\run.bat src\\tha4\\app\\full_manual_poser.py\r\n   ```"
  },
  {
    "path": "poetry/README.md",
    "content": ""
  },
  {
    "path": "poetry/pyproject.toml",
    "content": "[tool.poetry]\r\nname = \"talking-head-anime-4-demo\"\r\nversion = \"0.1.0\"\r\ndescription = \"Demo code for Talking Head(?) Anime 4\"\r\nauthors = [\"Pramook Khungurn <pong@pixiv.co.jp>\"]\r\nreadme = \"README.md\"\r\npackages = [\r\n    {include = \"tha4\", from = \"../src\"},\r\n]\r\n\r\n[tool.poetry.dependencies]\r\npython = \">=3.10, <3.11\"\r\ntorch = {version = \"1.13.1\", source = \"torch_cu117\"}\r\ntorchvision = {version = \"0.14.1\", source = \"torch_cu117\"}\r\ntensorboard = \"^2.15.1\"\r\nopencv-python = \"^4.8.1.78\"\r\nwxpython = \"^4.2.1\"\r\nnumpy-quaternion = \"^2022.4.2\"\r\npillow = \"^9.4.0\"\r\nmatplotlib = \"^3.6.3\"\r\neinops = \"^0.6.0\"\r\nmediapipe = \"^0.10.3\"\r\nnumpy = \"^1.26.3\"\r\nscipy = \"^1.12.0\"\r\nomegaconf = \"^2.3.0\"\r\n\r\n[[tool.poetry.source]]\r\nname = \"torch_cu117\"\r\nurl = \"https://download.pytorch.org/whl/cu117\"\r\npriority = \"explicit\"\r\n\r\n[build-system]\r\nrequires = [\"poetry-core\"]\r\nbuild-backend = \"poetry.core.masonry.api\""
  },
  {
    "path": "src/tha4/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/app/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/app/character_model_ifacialmocap_puppeteer.py",
    "content": "import os\r\nimport socket\r\nimport sys\r\nimport threading\r\nimport time\r\nfrom typing import Optional\r\n\r\nimport PIL.Image\r\n\r\nfrom tha4.shion.base.image_util import torch_linear_to_srgb\r\nfrom tha4.image_util import convert_linear_to_srgb\r\nfrom tha4.mocap.ifacialmocap_pose_converter_25 import create_ifacialmocap_pose_converter\r\nfrom tha4.app.full_manual_poser import resize_PIL_image\r\nfrom tha4.charmodel.character_model import CharacterModel\r\n\r\nsys.path.append(os.getcwd())\r\n\r\nfrom tha4.mocap.ifacialmocap_pose import create_default_ifacialmocap_pose\r\nfrom tha4.mocap.ifacialmocap_v2 import IFACIALMOCAP_PORT, IFACIALMOCAP_START_STRING, parse_ifacialmocap_v2_pose\r\n\r\nimport torch\r\nimport wx\r\n\r\nfrom tha4.mocap.ifacialmocap_constants import *\r\nfrom tha4.mocap.ifacialmocap_pose_converter import IFacialMocapPoseConverter\r\n\r\n\r\nclass FpsStatistics:\r\n    def __init__(self):\r\n        self.count = 100\r\n        self.fps = []\r\n\r\n    def add_fps(self, fps):\r\n        self.fps.append(fps)\r\n        while len(self.fps) > self.count:\r\n            del self.fps[0]\r\n\r\n    def get_average_fps(self):\r\n        if len(self.fps) == 0:\r\n            return 0.0\r\n        else:\r\n            return sum(self.fps) / len(self.fps)\r\n\r\n\r\nclass MainFrame(wx.Frame):\r\n    IMAGE_SIZE = 512\r\n\r\n    def __init__(self, pose_converter: IFacialMocapPoseConverter, device: torch.device):\r\n        super().__init__(None, wx.ID_ANY, \"iFacialMocap Puppeteer (Fuji)\")\r\n        self.poser = None\r\n        self.pose_converter = pose_converter\r\n        self.device = device\r\n\r\n        self.ifacialmocap_pose = create_default_ifacialmocap_pose()\r\n        self.source_image_bitmap = wx.Bitmap(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE)\r\n        self.result_image_bitmap = wx.Bitmap(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE)\r\n        self.wx_source_image = None\r\n        self.torch_source_image = None\r\n        self.last_pose = None\r\n        self.fps_statistics = FpsStatistics()\r\n        self.last_update_time = None\r\n\r\n        self.create_receiving_socket()\r\n        self.create_ui()\r\n        self.create_timers()\r\n        self.Bind(wx.EVT_CLOSE, self.on_close)\r\n\r\n        self.update_source_image_bitmap()\r\n        self.update_result_image_bitmap()\r\n\r\n    def create_receiving_socket(self):\r\n        self.receiving_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)\r\n        self.receiving_socket.bind((\"\", IFACIALMOCAP_PORT))\r\n        self.receiving_socket.setblocking(False)\r\n\r\n    def create_timers(self):\r\n        self.capture_timer = wx.Timer(self, wx.ID_ANY)\r\n        self.Bind(wx.EVT_TIMER, self.update_capture_panel, id=self.capture_timer.GetId())\r\n        self.animation_timer = wx.Timer(self, wx.ID_ANY)\r\n        self.Bind(wx.EVT_TIMER, self.update_result_image_bitmap, id=self.animation_timer.GetId())\r\n\r\n    def on_close(self, event: wx.Event):\r\n        # Stop the timers\r\n        self.animation_timer.Stop()\r\n        self.capture_timer.Stop()\r\n\r\n        # Close receiving socket\r\n        self.receiving_socket.close()\r\n\r\n        # Destroy the windows\r\n        self.Destroy()\r\n        event.Skip()\r\n\r\n    def on_start_capture(self, event: wx.Event):\r\n        capture_device_ip_address = self.capture_device_ip_text_ctrl.GetValue()\r\n        out_socket = None\r\n        try:\r\n            address = (capture_device_ip_address, IFACIALMOCAP_PORT)\r\n            
out_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)\r\n            out_socket.sendto(IFACIALMOCAP_START_STRING, address)\r\n        except Exception as e:\r\n            message_dialog = wx.MessageDialog(self, str(e), \"Error!\", wx.OK)\r\n            message_dialog.ShowModal()\r\n            message_dialog.Destroy()\r\n        finally:\r\n            if out_socket is not None:\r\n                out_socket.close()\r\n\r\n    def read_ifacialmocap_pose(self):\r\n        if not self.animation_timer.IsRunning():\r\n            return self.ifacialmocap_pose\r\n        socket_bytes = None\r\n        while True:\r\n            try:\r\n                socket_bytes = self.receiving_socket.recv(8192)\r\n            except socket.error as e:\r\n                break\r\n        if socket_bytes is not None:\r\n            socket_string = socket_bytes.decode(\"utf-8\")\r\n            self.ifacialmocap_pose = parse_ifacialmocap_v2_pose(socket_string)\r\n        return self.ifacialmocap_pose\r\n\r\n    def on_erase_background(self, event: wx.Event):\r\n        pass\r\n\r\n    def create_animation_panel(self, parent):\r\n        self.animation_panel = wx.Panel(parent, style=wx.RAISED_BORDER)\r\n        self.animation_panel_sizer = wx.BoxSizer(wx.HORIZONTAL)\r\n        self.animation_panel.SetSizer(self.animation_panel_sizer)\r\n        self.animation_panel.SetAutoLayout(1)\r\n\r\n        image_size = MainFrame.IMAGE_SIZE\r\n\r\n        if True:\r\n            self.input_panel = wx.Panel(self.animation_panel, size=(image_size, image_size + 128),\r\n                                        style=wx.SIMPLE_BORDER)\r\n            self.input_panel_sizer = wx.BoxSizer(wx.VERTICAL)\r\n            self.input_panel.SetSizer(self.input_panel_sizer)\r\n            self.input_panel.SetAutoLayout(1)\r\n            self.animation_panel_sizer.Add(self.input_panel, 0, wx.FIXED_MINSIZE)\r\n\r\n            self.source_image_panel = wx.Panel(self.input_panel, size=(image_size, image_size), style=wx.SIMPLE_BORDER)\r\n            self.source_image_panel.Bind(wx.EVT_PAINT, self.paint_source_image_panel)\r\n            self.source_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)\r\n            self.input_panel_sizer.Add(self.source_image_panel, 0, wx.FIXED_MINSIZE)\r\n\r\n            self.load_model_button = wx.Button(self.input_panel, wx.ID_ANY, \"Load Model\")\r\n            self.input_panel_sizer.Add(self.load_model_button, 1, wx.EXPAND)\r\n            self.load_model_button.Bind(wx.EVT_BUTTON, self.load_model)\r\n\r\n            self.input_panel_sizer.Fit(self.input_panel)\r\n\r\n        if True:\r\n            self.pose_converter.init_pose_converter_panel(self.animation_panel)\r\n\r\n        if True:\r\n            self.animation_left_panel = wx.Panel(self.animation_panel, style=wx.SIMPLE_BORDER)\r\n            self.animation_left_panel_sizer = wx.BoxSizer(wx.VERTICAL)\r\n            self.animation_left_panel.SetSizer(self.animation_left_panel_sizer)\r\n            self.animation_left_panel.SetAutoLayout(1)\r\n            self.animation_panel_sizer.Add(self.animation_left_panel, 0, wx.EXPAND)\r\n\r\n            self.result_image_panel = wx.Panel(self.animation_left_panel, size=(image_size, image_size),\r\n                                               style=wx.SIMPLE_BORDER)\r\n            self.result_image_panel.Bind(wx.EVT_PAINT, self.paint_result_image_panel)\r\n            self.result_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)\r\n            
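# The result panel shows the character posed by the student model;\r\n            # it is repainted from result_image_bitmap in paint_result_image_panel().\r\n            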
self.animation_left_panel_sizer.Add(self.result_image_panel, 0, wx.FIXED_MINSIZE)\r\n\r\n            separator = wx.StaticLine(self.animation_left_panel, -1, size=(256, 5))\r\n            self.animation_left_panel_sizer.Add(separator, 0, wx.EXPAND)\r\n\r\n            background_text = wx.StaticText(self.animation_left_panel, label=\"--- Background ---\",\r\n                                            style=wx.ALIGN_CENTER)\r\n            self.animation_left_panel_sizer.Add(background_text, 0, wx.EXPAND)\r\n\r\n            self.output_background_choice = wx.Choice(\r\n                self.animation_left_panel,\r\n                choices=[\r\n                    \"TRANSPARENT\",\r\n                    \"GREEN\",\r\n                    \"BLUE\",\r\n                    \"BLACK\",\r\n                    \"WHITE\"\r\n                ])\r\n            self.output_background_choice.SetSelection(0)\r\n            self.animation_left_panel_sizer.Add(self.output_background_choice, 0, wx.EXPAND)\r\n\r\n            separator = wx.StaticLine(self.animation_left_panel, -1, size=(256, 5))\r\n            self.animation_left_panel_sizer.Add(separator, 0, wx.EXPAND)\r\n\r\n            self.fps_text = wx.StaticText(self.animation_left_panel, label=\"\")\r\n            self.animation_left_panel_sizer.Add(self.fps_text, wx.SizerFlags().Border())\r\n\r\n            self.animation_left_panel_sizer.Fit(self.animation_left_panel)\r\n\r\n        self.animation_panel_sizer.Fit(self.animation_panel)\r\n\r\n    def create_ui(self):\r\n        self.main_sizer = wx.BoxSizer(wx.VERTICAL)\r\n        self.SetSizer(self.main_sizer)\r\n        self.SetAutoLayout(1)\r\n\r\n        self.capture_pose_lock = threading.Lock()\r\n\r\n        self.create_connection_panel(self)\r\n        self.main_sizer.Add(self.connection_panel, wx.SizerFlags(0).Expand().Border(wx.ALL, 5))\r\n\r\n        self.create_animation_panel(self)\r\n        self.main_sizer.Add(self.animation_panel, wx.SizerFlags(0).Expand().Border(wx.ALL, 5))\r\n\r\n        self.create_capture_panel(self)\r\n        self.main_sizer.Add(self.capture_panel, wx.SizerFlags(0).Expand().Border(wx.ALL, 5))\r\n\r\n        self.main_sizer.Fit(self)\r\n\r\n    def create_connection_panel(self, parent):\r\n        self.connection_panel = wx.Panel(parent, style=wx.RAISED_BORDER)\r\n        self.connection_panel_sizer = wx.BoxSizer(wx.HORIZONTAL)\r\n        self.connection_panel.SetSizer(self.connection_panel_sizer)\r\n        self.connection_panel.SetAutoLayout(1)\r\n\r\n        capture_device_ip_text = wx.StaticText(self.connection_panel, label=\"Capture Device IP:\", style=wx.ALIGN_RIGHT)\r\n        self.connection_panel_sizer.Add(capture_device_ip_text, wx.SizerFlags(0).FixedMinSize().Border(wx.ALL, 3))\r\n\r\n        self.capture_device_ip_text_ctrl = wx.TextCtrl(self.connection_panel, value=\"192.168.0.1\")\r\n        self.connection_panel_sizer.Add(self.capture_device_ip_text_ctrl, wx.SizerFlags(1).Expand().Border(wx.ALL, 3))\r\n\r\n        self.start_capture_button = wx.Button(self.connection_panel, label=\"START CAPTURE!\")\r\n        self.connection_panel_sizer.Add(self.start_capture_button, wx.SizerFlags(0).FixedMinSize().Border(wx.ALL, 3))\r\n        self.start_capture_button.Bind(wx.EVT_BUTTON, self.on_start_capture)\r\n\r\n    def create_capture_panel(self, parent):\r\n        self.capture_panel = wx.Panel(parent, style=wx.RAISED_BORDER)\r\n        self.capture_panel_sizer = wx.FlexGridSizer(cols=5)\r\n        for i in range(5):\r\n            
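# Let every column of the capture panel grow when the window is resized.\r\n            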
self.capture_panel_sizer.AddGrowableCol(i)\r\n        self.capture_panel.SetSizer(self.capture_panel_sizer)\r\n        self.capture_panel.SetAutoLayout(1)\r\n\r\n        self.rotation_labels = {}\r\n        self.rotation_value_labels = {}\r\n        rotation_column_0 = self.create_rotation_column(self.capture_panel, RIGHT_EYE_BONE_ROTATIONS)\r\n        self.capture_panel_sizer.Add(rotation_column_0, wx.SizerFlags(0).Expand().Border(wx.ALL, 3))\r\n        rotation_column_1 = self.create_rotation_column(self.capture_panel, LEFT_EYE_BONE_ROTATIONS)\r\n        self.capture_panel_sizer.Add(rotation_column_1, wx.SizerFlags(0).Expand().Border(wx.ALL, 3))\r\n        rotation_column_2 = self.create_rotation_column(self.capture_panel, HEAD_BONE_ROTATIONS)\r\n        self.capture_panel_sizer.Add(rotation_column_2, wx.SizerFlags(0).Expand().Border(wx.ALL, 3))\r\n\r\n    def create_rotation_column(self, parent, rotation_names):\r\n        column_panel = wx.Panel(parent, style=wx.SIMPLE_BORDER)\r\n        column_panel_sizer = wx.FlexGridSizer(cols=2)\r\n        column_panel_sizer.AddGrowableCol(1)\r\n        column_panel.SetSizer(column_panel_sizer)\r\n        column_panel.SetAutoLayout(1)\r\n\r\n        for rotation_name in rotation_names:\r\n            self.rotation_labels[rotation_name] = wx.StaticText(\r\n                column_panel, label=rotation_name, style=wx.ALIGN_RIGHT)\r\n            column_panel_sizer.Add(self.rotation_labels[rotation_name],\r\n                                   wx.SizerFlags(1).Expand().Border(wx.ALL, 3))\r\n\r\n            self.rotation_value_labels[rotation_name] = wx.TextCtrl(\r\n                column_panel, style=wx.TE_RIGHT)\r\n            self.rotation_value_labels[rotation_name].SetValue(\"0.00\")\r\n            self.rotation_value_labels[rotation_name].Disable()\r\n            column_panel_sizer.Add(self.rotation_value_labels[rotation_name],\r\n                                   wx.SizerFlags(1).Expand().Border(wx.ALL, 3))\r\n\r\n        column_panel.GetSizer().Fit(column_panel)\r\n        return column_panel\r\n\r\n    def paint_capture_panel(self, event: wx.Event):\r\n        self.update_capture_panel(event)\r\n\r\n    def update_capture_panel(self, event: wx.Event):\r\n        data = self.ifacialmocap_pose\r\n        for rotation_name in ROTATION_NAMES:\r\n            value = data[rotation_name]\r\n            self.rotation_value_labels[rotation_name].SetValue(\"%0.2f\" % value)\r\n\r\n    @staticmethod\r\n    def convert_to_100(x):\r\n        return int(max(0.0, min(1.0, x)) * 100)\r\n\r\n    def paint_source_image_panel(self, event: wx.Event):\r\n        wx.BufferedPaintDC(self.source_image_panel, self.source_image_bitmap)\r\n\r\n    def update_source_image_bitmap(self):\r\n        dc = wx.MemoryDC()\r\n        dc.SelectObject(self.source_image_bitmap)\r\n        if self.wx_source_image is None:\r\n            self.draw_nothing_yet_string(dc)\r\n        else:\r\n            dc.Clear()\r\n            dc.DrawBitmap(self.wx_source_image, 0, 0, True)\r\n        del dc\r\n\r\n    def draw_nothing_yet_string(self, dc):\r\n        dc.Clear()\r\n        font = wx.Font(wx.FontInfo(14).Family(wx.FONTFAMILY_SWISS))\r\n        dc.SetFont(font)\r\n        w, h = dc.GetTextExtent(\"Nothing yet!\")\r\n        dc.DrawText(\"Nothing yet!\", (MainFrame.IMAGE_SIZE - w) // 2, (MainFrame.IMAGE_SIZE - h) // 2)\r\n\r\n    def paint_result_image_panel(self, event: wx.Event):\r\n        wx.BufferedPaintDC(self.result_image_panel, self.result_image_bitmap)\r\n\r\n    def 
update_result_image_bitmap(self, event: Optional[wx.Event] = None):\r\n        ifacialmocap_pose = self.read_ifacialmocap_pose()\r\n        current_pose = self.pose_converter.convert(ifacialmocap_pose)\r\n        if self.last_pose is not None and self.last_pose == current_pose:\r\n            return\r\n        self.last_pose = current_pose\r\n\r\n        if self.torch_source_image is None or self.poser is None:\r\n            dc = wx.MemoryDC()\r\n            dc.SelectObject(self.result_image_bitmap)\r\n            self.draw_nothing_yet_string(dc)\r\n            del dc\r\n            return\r\n\r\n        pose = torch.tensor(current_pose, device=self.device, dtype=self.poser.get_dtype())\r\n\r\n        with torch.no_grad():\r\n            output_image = self.poser.pose(self.torch_source_image, pose)[0].float()\r\n            output_image = torch.clip((output_image + 1.0) / 2.0, 0.0, 1.0)\r\n            output_image = convert_linear_to_srgb(output_image)\r\n\r\n            background_choice = self.output_background_choice.GetSelection()\r\n            if background_choice == 0:\r\n                pass\r\n            else:\r\n                background = torch.zeros(4, output_image.shape[1], output_image.shape[2], device=self.device)\r\n                background[3, :, :] = 1.0\r\n                if background_choice == 1:\r\n                    background[1, :, :] = 1.0\r\n                    output_image = self.blend_with_background(output_image, background)\r\n                elif background_choice == 2:\r\n                    background[2, :, :] = 1.0\r\n                    output_image = self.blend_with_background(output_image, background)\r\n                elif background_choice == 3:\r\n                    output_image = self.blend_with_background(output_image, background)\r\n                else:\r\n                    background[0:3, :, :] = 1.0\r\n                    output_image = self.blend_with_background(output_image, background)\r\n\r\n            c, h, w = output_image.shape\r\n            output_image = 255.0 * torch.transpose(output_image.reshape(c, h * w), 0, 1).reshape(h, w, c)\r\n            output_image = output_image.byte()\r\n\r\n        numpy_image = output_image.detach().cpu().numpy()\r\n        wx_image = wx.ImageFromBuffer(numpy_image.shape[0],\r\n                                      numpy_image.shape[1],\r\n                                      numpy_image[:, :, 0:3].tobytes(),\r\n                                      numpy_image[:, :, 3].tobytes())\r\n        wx_bitmap = wx_image.ConvertToBitmap()\r\n\r\n        dc = wx.MemoryDC()\r\n        dc.SelectObject(self.result_image_bitmap)\r\n        dc.Clear()\r\n        dc.DrawBitmap(wx_bitmap,\r\n                      (MainFrame.IMAGE_SIZE - numpy_image.shape[0]) // 2,\r\n                      (MainFrame.IMAGE_SIZE - numpy_image.shape[1]) // 2, True)\r\n        del dc\r\n\r\n        time_now = time.time_ns()\r\n        if self.last_update_time is not None:\r\n            elapsed_time = time_now - self.last_update_time\r\n            fps = 1.0 / (elapsed_time / 10 ** 9)\r\n            if self.torch_source_image is not None:\r\n                self.fps_statistics.add_fps(fps)\r\n            self.fps_text.SetLabelText(\"FPS = %0.2f\" % self.fps_statistics.get_average_fps())\r\n        self.last_update_time = time_now\r\n\r\n        self.Refresh()\r\n\r\n    def blend_with_background(self, numpy_image, background):\r\n        alpha = numpy_image[3:4, :, :]\r\n        color = numpy_image[0:3, :, :]\r\n        new_color 
= color * alpha + (1.0 - alpha) * background[0:3, :, :]\r\n        return torch.cat([new_color, background[3:4, :, :]], dim=0)\r\n\r\n    def load_model(self, event: wx.Event):\r\n        dir_name = \"data/character_models\"\r\n        file_dialog = wx.FileDialog(self, \"Choose a model\", dir_name, \"\", \"*.yaml\", wx.FD_OPEN)\r\n        if file_dialog.ShowModal() == wx.ID_OK:\r\n            character_model_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())\r\n            try:\r\n                self.character_model = CharacterModel.load(character_model_file_name)\r\n                self.torch_source_image = self.character_model.get_character_image(self.device)\r\n                pil_image = resize_PIL_image(\r\n                    PIL.Image.open(self.character_model.character_image_file_name),\r\n                    (MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE))\r\n                w, h = pil_image.size\r\n                self.wx_source_image = wx.Bitmap.FromBufferRGBA(w, h, pil_image.convert(\"RGBA\").tobytes())\r\n                self.update_source_image_bitmap()\r\n                self.poser = self.character_model.get_poser(self.device)\r\n            except Exception:\r\n                message_dialog = wx.MessageDialog(\r\n                    self, \"Could not load character model \" + character_model_file_name, \"Poser\", wx.OK)\r\n                message_dialog.ShowModal()\r\n                message_dialog.Destroy()\r\n        file_dialog.Destroy()\r\n        self.Refresh()\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    device = torch.device('cuda:0')\r\n\r\n    pose_converter = create_ifacialmocap_pose_converter()\r\n\r\n    app = wx.App()\r\n    main_frame = MainFrame(pose_converter, device)\r\n    main_frame.Show(True)\r\n    main_frame.capture_timer.Start(10)\r\n    main_frame.animation_timer.Start(10)\r\n    app.MainLoop()\r\n"
  },
  {
    "path": "src/tha4/app/character_model_manual_poser.py",
    "content": "import logging\r\nimport os\r\nimport sys\r\nimport time\r\nfrom typing import List\r\n\r\nfrom tha4.charmodel.character_model import CharacterModel\r\nfrom tha4.image_util import resize_PIL_image, convert_output_image_from_torch_to_numpy\r\nfrom tha4.poser.modes.mode_14 import get_pose_parameters\r\n\r\nsys.path.append(os.getcwd())\r\n\r\nimport PIL.Image\r\nimport torch\r\nimport wx\r\n\r\nfrom tha4.poser.poser import PoseParameterCategory, PoseParameterGroup\r\n\r\n\r\nclass MorphCategoryControlPanel(wx.Panel):\r\n    def __init__(self,\r\n                 parent,\r\n                 title: str,\r\n                 pose_param_category: PoseParameterCategory,\r\n                 param_groups: List[PoseParameterGroup]):\r\n        super().__init__(parent, style=wx.SIMPLE_BORDER)\r\n        self.pose_param_category = pose_param_category\r\n        self.sizer = wx.BoxSizer(wx.VERTICAL)\r\n        self.SetSizer(self.sizer)\r\n        self.SetAutoLayout(1)\r\n\r\n        title_text = wx.StaticText(self, label=title, style=wx.ALIGN_CENTER)\r\n        self.sizer.Add(title_text, 0, wx.EXPAND)\r\n\r\n        self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category]\r\n        self.choice = wx.Choice(self, choices=[group.get_group_name() for group in self.param_groups])\r\n        if len(self.param_groups) > 0:\r\n            self.choice.SetSelection(0)\r\n        self.choice.Bind(wx.EVT_CHOICE, self.on_choice_updated)\r\n        self.sizer.Add(self.choice, 0, wx.EXPAND)\r\n\r\n        self.left_slider = wx.Slider(self, minValue=-1000, maxValue=1000, value=-1000, style=wx.HORIZONTAL)\r\n        self.sizer.Add(self.left_slider, 0, wx.EXPAND)\r\n\r\n        self.right_slider = wx.Slider(self, minValue=-1000, maxValue=1000, value=-1000, style=wx.HORIZONTAL)\r\n        self.sizer.Add(self.right_slider, 0, wx.EXPAND)\r\n\r\n        self.checkbox = wx.CheckBox(self, label=\"Show\")\r\n        self.checkbox.SetValue(True)\r\n        self.sizer.Add(self.checkbox, 0, wx.SHAPED | wx.ALIGN_CENTER)\r\n\r\n        self.update_ui()\r\n\r\n        self.sizer.Fit(self)\r\n\r\n    def update_ui(self):\r\n        param_group = self.param_groups[self.choice.GetSelection()]\r\n        if param_group.is_discrete():\r\n            self.left_slider.Enable(False)\r\n            self.right_slider.Enable(False)\r\n            self.checkbox.Enable(True)\r\n        elif param_group.get_arity() == 1:\r\n            self.left_slider.Enable(True)\r\n            self.right_slider.Enable(False)\r\n            self.checkbox.Enable(False)\r\n        else:\r\n            self.left_slider.Enable(True)\r\n            self.right_slider.Enable(True)\r\n            self.checkbox.Enable(False)\r\n\r\n    def on_choice_updated(self, event: wx.Event):\r\n        param_group = self.param_groups[self.choice.GetSelection()]\r\n        if param_group.is_discrete():\r\n            self.checkbox.SetValue(True)\r\n        self.update_ui()\r\n\r\n    def set_param_value(self, pose: List[float]):\r\n        if len(self.param_groups) == 0:\r\n            return\r\n        selected_morph_index = self.choice.GetSelection()\r\n        param_group = self.param_groups[selected_morph_index]\r\n        param_index = param_group.get_parameter_index()\r\n        if param_group.is_discrete():\r\n            if self.checkbox.GetValue():\r\n                for i in range(param_group.get_arity()):\r\n                    pose[param_index + i] = 1.0\r\n        else:\r\n            param_range = 
param_group.get_range()\r\n            alpha = (self.left_slider.GetValue() + 1000) / 2000.0\r\n            pose[param_index] = param_range[0] + (param_range[1] - param_range[0]) * alpha\r\n            if param_group.get_arity() == 2:\r\n                alpha = (self.right_slider.GetValue() + 1000) / 2000.0\r\n                pose[param_index + 1] = param_range[0] + (param_range[1] - param_range[0]) * alpha\r\n\r\n\r\nclass SimpleParamGroupsControlPanel(wx.Panel):\r\n    def __init__(self, parent,\r\n                 pose_param_category: PoseParameterCategory,\r\n                 param_groups: List[PoseParameterGroup]):\r\n        super().__init__(parent, style=wx.SIMPLE_BORDER)\r\n        self.sizer = wx.BoxSizer(wx.VERTICAL)\r\n        self.SetSizer(self.sizer)\r\n        self.SetAutoLayout(1)\r\n\r\n        self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category]\r\n        for param_group in self.param_groups:\r\n            assert not param_group.is_discrete()\r\n            assert param_group.get_arity() == 1\r\n\r\n        self.sliders = []\r\n        for param_group in self.param_groups:\r\n            static_text = wx.StaticText(\r\n                self,\r\n                label=\"   ------------ %s ------------   \" % param_group.get_group_name(), style=wx.ALIGN_CENTER)\r\n            self.sizer.Add(static_text, 0, wx.EXPAND)\r\n            param_range = param_group.get_range()\r\n            min_value = int(param_range[0] * 1000)\r\n            max_value = int(param_range[1] * 1000)\r\n            slider = wx.Slider(self, minValue=min_value, maxValue=max_value, value=0, style=wx.HORIZONTAL)\r\n            self.sizer.Add(slider, 0, wx.EXPAND)\r\n            self.sliders.append(slider)\r\n\r\n        self.sizer.Fit(self)\r\n\r\n    def set_param_value(self, pose: List[float]):\r\n        if len(self.param_groups) == 0:\r\n            return\r\n        for param_group_index in range(len(self.param_groups)):\r\n            param_group = self.param_groups[param_group_index]\r\n            slider = self.sliders[param_group_index]\r\n            param_range = param_group.get_range()\r\n            param_index = param_group.get_parameter_index()\r\n            alpha = (slider.GetValue() - slider.GetMin()) * 1.0 / (slider.GetMax() - slider.GetMin())\r\n            pose[param_index] = param_range[0] + (param_range[1] - param_range[0]) * alpha\r\n\r\n\r\nclass MainFrame(wx.Frame):\r\n    IMAGE_SIZE = 512\r\n    OUTPUT_LENGTH = 6\r\n    NUM_PARAMETERS = 45\r\n\r\n    def __init__(self, device: torch.device):\r\n        super().__init__(None, wx.ID_ANY, \"Poser\")\r\n        self.poser = None\r\n        self.device = device\r\n\r\n        self.wx_source_image = None\r\n        self.torch_source_image = None\r\n\r\n        self.main_sizer = wx.BoxSizer(wx.HORIZONTAL)\r\n        self.SetSizer(self.main_sizer)\r\n        self.SetAutoLayout(1)\r\n        self.init_left_panel()\r\n        self.init_control_panel()\r\n        self.init_right_panel()\r\n        self.main_sizer.Fit(self)\r\n\r\n        self.timer = wx.Timer(self, wx.ID_ANY)\r\n        self.Bind(wx.EVT_TIMER, self.update_images, self.timer)\r\n\r\n        save_image_id = wx.NewIdRef()\r\n        self.Bind(wx.EVT_MENU, self.on_save_image, id=save_image_id)\r\n        accelerator_table = wx.AcceleratorTable([\r\n            (wx.ACCEL_CTRL, ord('S'), save_image_id)\r\n        ])\r\n        self.SetAcceleratorTable(accelerator_table)\r\n\r\n        self.last_pose = None\r\n        self.last_output_index = 
self.output_index_choice.GetSelection()\r\n        self.last_output_numpy_image = None\r\n\r\n        self.wx_source_image = None\r\n        self.torch_source_image = None\r\n        self.source_image_bitmap = wx.Bitmap(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE)\r\n        self.result_image_bitmap = wx.Bitmap(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE)\r\n        self.source_image_dirty = True\r\n\r\n    def init_left_panel(self):\r\n        self.control_panel = wx.Panel(self, style=wx.SIMPLE_BORDER, size=(MainFrame.IMAGE_SIZE, -1))\r\n        self.left_panel = wx.Panel(self, style=wx.SIMPLE_BORDER)\r\n        left_panel_sizer = wx.BoxSizer(wx.VERTICAL)\r\n        self.left_panel.SetSizer(left_panel_sizer)\r\n        self.left_panel.SetAutoLayout(1)\r\n\r\n        self.source_image_panel = wx.Panel(self.left_panel, size=(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE),\r\n                                           style=wx.SIMPLE_BORDER)\r\n        self.source_image_panel.Bind(wx.EVT_PAINT, self.paint_source_image_panel)\r\n        self.source_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)\r\n        left_panel_sizer.Add(self.source_image_panel, 0, wx.FIXED_MINSIZE)\r\n\r\n        self.load_model_button = wx.Button(self.left_panel, wx.ID_ANY, \"\\nLoad Model\\n\\n\")\r\n        left_panel_sizer.Add(self.load_model_button, 1, wx.EXPAND)\r\n        self.load_model_button.Bind(wx.EVT_BUTTON, self.load_model)\r\n\r\n        left_panel_sizer.Fit(self.left_panel)\r\n        self.main_sizer.Add(self.left_panel, 0, wx.FIXED_MINSIZE)\r\n\r\n    def on_erase_background(self, event: wx.Event):\r\n        pass\r\n\r\n    def init_control_panel(self):\r\n        self.control_panel_sizer = wx.BoxSizer(wx.VERTICAL)\r\n        self.control_panel.SetSizer(self.control_panel_sizer)\r\n        self.control_panel.SetMinSize(wx.Size(256, 1))\r\n\r\n        morph_categories = [\r\n            PoseParameterCategory.EYEBROW,\r\n            PoseParameterCategory.EYE,\r\n            PoseParameterCategory.MOUTH,\r\n            PoseParameterCategory.IRIS_MORPH\r\n        ]\r\n        morph_category_titles = {\r\n            PoseParameterCategory.EYEBROW: \"   ------------ Eyebrow ------------   \",\r\n            PoseParameterCategory.EYE: \"   ------------ Eye ------------   \",\r\n            PoseParameterCategory.MOUTH: \"   ------------ Mouth ------------   \",\r\n            PoseParameterCategory.IRIS_MORPH: \"   ------------ Iris morphs ------------   \",\r\n        }\r\n        self.morph_control_panels = {}\r\n        param_groups = get_pose_parameters().get_pose_parameter_groups()\r\n        for category in morph_categories:\r\n            filtered_param_groups = [group for group in param_groups if group.get_category() == category]\r\n            if len(filtered_param_groups) == 0:\r\n                continue\r\n            control_panel = MorphCategoryControlPanel(\r\n                self.control_panel,\r\n                morph_category_titles[category],\r\n                category,\r\n                param_groups)\r\n            self.morph_control_panels[category] = control_panel\r\n            self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND)\r\n\r\n        self.non_morph_control_panels = {}\r\n        non_morph_categories = [\r\n            PoseParameterCategory.IRIS_ROTATION,\r\n            PoseParameterCategory.FACE_ROTATION,\r\n            PoseParameterCategory.BODY_ROTATION,\r\n            PoseParameterCategory.BREATHING\r\n        ]\r\n        for category in 
non_morph_categories:\r\n            filtered_param_groups = [group for group in param_groups if group.get_category() == category]\r\n            if len(filtered_param_groups) == 0:\r\n                continue\r\n            control_panel = SimpleParamGroupsControlPanel(\r\n                self.control_panel,\r\n                category,\r\n                param_groups)\r\n            self.non_morph_control_panels[category] = control_panel\r\n            self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND)\r\n\r\n        self.control_panel_sizer.Fit(self.control_panel)\r\n        self.main_sizer.Add(self.control_panel, 1, wx.FIXED_MINSIZE)\r\n\r\n    def init_right_panel(self):\r\n        self.right_panel = wx.Panel(self, style=wx.SIMPLE_BORDER)\r\n        right_panel_sizer = wx.BoxSizer(wx.VERTICAL)\r\n        self.right_panel.SetSizer(right_panel_sizer)\r\n        self.right_panel.SetAutoLayout(1)\r\n\r\n        self.result_image_panel = wx.Panel(self.right_panel,\r\n                                           size=(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE),\r\n                                           style=wx.SIMPLE_BORDER)\r\n        self.result_image_panel.Bind(wx.EVT_PAINT, self.paint_result_image_panel)\r\n        self.result_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)\r\n        self.output_index_choice = wx.Choice(\r\n            self.right_panel,\r\n            choices=[str(i) for i in range(MainFrame.OUTPUT_LENGTH)])\r\n        self.output_index_choice.SetSelection(0)\r\n        right_panel_sizer.Add(self.result_image_panel, 0, wx.FIXED_MINSIZE)\r\n        right_panel_sizer.Add(self.output_index_choice, 0, wx.EXPAND)\r\n\r\n        self.save_image_button = wx.Button(self.right_panel, wx.ID_ANY, \"\\nSave Image\\n\\n\")\r\n        right_panel_sizer.Add(self.save_image_button, 1, wx.EXPAND)\r\n        self.save_image_button.Bind(wx.EVT_BUTTON, self.on_save_image)\r\n\r\n        right_panel_sizer.Fit(self.right_panel)\r\n        self.main_sizer.Add(self.right_panel, 0, wx.FIXED_MINSIZE)\r\n\r\n    def create_param_category_choice(self, param_category: PoseParameterCategory):\r\n        params = []\r\n        for param_group in self.poser.get_pose_parameter_groups():\r\n            if param_group.get_category() == param_category:\r\n                params.append(param_group.get_group_name())\r\n        choice = wx.Choice(self.control_panel, choices=params)\r\n        if len(params) > 0:\r\n            choice.SetSelection(0)\r\n        return choice\r\n\r\n    def load_model(self, event: wx.Event):\r\n        dir_name = \"data/character_models\"\r\n        file_dialog = wx.FileDialog(self, \"Choose a model\", dir_name, \"\", \"*.yaml\", wx.FD_OPEN)\r\n        if file_dialog.ShowModal() == wx.ID_OK:\r\n            character_model_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())\r\n            try:\r\n                self.character_model = CharacterModel.load(character_model_file_name)\r\n                self.torch_source_image = self.character_model.get_character_image(self.device)\r\n                pil_image = resize_PIL_image(\r\n                    PIL.Image.open(self.character_model.character_image_file_name),\r\n                    (MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE))\r\n                w, h = pil_image.size\r\n                self.wx_source_image = wx.Bitmap.FromBufferRGBA(w, h, pil_image.convert(\"RGBA\").tobytes())\r\n                self.poser = self.character_model.get_poser(self.device)\r\n             
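   # Flag the source bitmap for redraw now that a new model is loaded.\r\n             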
   self.source_image_dirty = True\r\n                self.Refresh()\r\n                self.Update()\r\n            except Exception:\r\n                message_dialog = wx.MessageDialog(\r\n                    self, \"Could not load character model \" + character_model_file_name, \"Poser\", wx.OK)\r\n                message_dialog.ShowModal()\r\n                message_dialog.Destroy()\r\n        file_dialog.Destroy()\r\n\r\n    def paint_source_image_panel(self, event: wx.Event):\r\n        wx.BufferedPaintDC(self.source_image_panel, self.source_image_bitmap)\r\n\r\n    def paint_result_image_panel(self, event: wx.Event):\r\n        wx.BufferedPaintDC(self.result_image_panel, self.result_image_bitmap)\r\n\r\n    def draw_nothing_yet_string_to_bitmap(self, bitmap):\r\n        dc = wx.MemoryDC()\r\n        dc.SelectObject(bitmap)\r\n\r\n        dc.Clear()\r\n        font = wx.Font(wx.FontInfo(14).Family(wx.FONTFAMILY_SWISS))\r\n        dc.SetFont(font)\r\n        w, h = dc.GetTextExtent(\"Nothing yet!\")\r\n        dc.DrawText(\"Nothing yet!\", (MainFrame.IMAGE_SIZE - w) // 2, (MainFrame.IMAGE_SIZE - h) // 2)\r\n\r\n        del dc\r\n\r\n    def get_current_pose(self):\r\n        current_pose = [0.0 for i in range(MainFrame.NUM_PARAMETERS)]\r\n        for morph_control_panel in self.morph_control_panels.values():\r\n            morph_control_panel.set_param_value(current_pose)\r\n        for rotation_control_panel in self.non_morph_control_panels.values():\r\n            rotation_control_panel.set_param_value(current_pose)\r\n        return current_pose\r\n\r\n    def update_images(self, event: wx.Event):\r\n        current_pose = self.get_current_pose()\r\n        if not self.source_image_dirty \\\r\n                and self.last_pose is not None \\\r\n                and self.last_pose == current_pose \\\r\n                and self.last_output_index == self.output_index_choice.GetSelection():\r\n            return\r\n        self.last_pose = current_pose\r\n        self.last_output_index = self.output_index_choice.GetSelection()\r\n\r\n        if self.torch_source_image is None or self.poser is None:\r\n            self.draw_nothing_yet_string_to_bitmap(self.source_image_bitmap)\r\n            self.draw_nothing_yet_string_to_bitmap(self.result_image_bitmap)\r\n            self.source_image_dirty = False\r\n            self.Refresh()\r\n            self.Update()\r\n            return\r\n\r\n        if self.source_image_dirty:\r\n            dc = wx.MemoryDC()\r\n            dc.SelectObject(self.source_image_bitmap)\r\n            dc.Clear()\r\n            dc.DrawBitmap(self.wx_source_image, 0, 0)\r\n            self.source_image_dirty = False\r\n\r\n        pose = torch.tensor(current_pose, device=self.device)\r\n        output_index = self.output_index_choice.GetSelection()\r\n        with torch.no_grad():\r\n            start_cuda_event = torch.cuda.Event(enable_timing=True)\r\n            end_cuda_event = torch.cuda.Event(enable_timing=True)\r\n            start_cuda_event.record()\r\n            start_time = time.time()\r\n\r\n            output_image = self.poser.pose(self.torch_source_image, pose, output_index)[0].detach().cpu()\r\n\r\n            end_time = time.time()\r\n            end_cuda_event.record()\r\n            torch.cuda.synchronize()\r\n            print(\"cuda time (ms):\", start_cuda_event.elapsed_time(end_cuda_event))\r\n            print(\"elapsed time (ms):\", (end_time - start_time) * 1000.0)\r\n\r\n        numpy_image = 
convert_output_image_from_torch_to_numpy(output_image)\r\n        self.last_output_numpy_image = numpy_image\r\n        wx_image = wx.ImageFromBuffer(\r\n            numpy_image.shape[0],\r\n            numpy_image.shape[1],\r\n            numpy_image[:, :, 0:3].tobytes(),\r\n            numpy_image[:, :, 3].tobytes())\r\n        wx_bitmap = wx_image.ConvertToBitmap()\r\n\r\n        dc = wx.MemoryDC()\r\n        dc.SelectObject(self.result_image_bitmap)\r\n        dc.Clear()\r\n        dc.DrawBitmap(wx_bitmap,\r\n                      (MainFrame.IMAGE_SIZE - numpy_image.shape[0]) // 2,\r\n                      (MainFrame.IMAGE_SIZE - numpy_image.shape[1]) // 2,\r\n                      True)\r\n        del dc\r\n\r\n        self.Refresh()\r\n        self.Update()\r\n\r\n    def on_save_image(self, event: wx.Event):\r\n        if self.last_output_numpy_image is None:\r\n            logging.info(\"There is no output image to save!\")\r\n            return\r\n\r\n        dir_name = \"data/images\"\r\n        file_dialog = wx.FileDialog(self, \"Choose an image\", dir_name, \"\", \"*.png\", wx.FD_SAVE)\r\n        if file_dialog.ShowModal() == wx.ID_OK:\r\n            image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())\r\n            try:\r\n                if os.path.exists(image_file_name):\r\n                    message_dialog = wx.MessageDialog(self, f\"Overwrite {image_file_name}?\", \"Manual Poser\",\r\n                                                      wx.YES_NO | wx.ICON_QUESTION)\r\n                    result = message_dialog.ShowModal()\r\n                    if result == wx.ID_YES:\r\n                        self.save_last_numpy_image(image_file_name)\r\n                    message_dialog.Destroy()\r\n                else:\r\n                    self.save_last_numpy_image(image_file_name)\r\n            except Exception:\r\n                message_dialog = wx.MessageDialog(self, f\"Could not save {image_file_name}\", \"Manual Poser\", wx.OK)\r\n                message_dialog.ShowModal()\r\n                message_dialog.Destroy()\r\n        file_dialog.Destroy()\r\n\r\n    def save_last_numpy_image(self, image_file_name):\r\n        numpy_image = self.last_output_numpy_image\r\n        pil_image = PIL.Image.fromarray(numpy_image, mode='RGBA')\r\n        os.makedirs(os.path.dirname(image_file_name), exist_ok=True)\r\n        pil_image.save(image_file_name)\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    device = torch.device('cuda:0')\r\n    app = wx.App()\r\n    main_frame = MainFrame(device)\r\n    main_frame.Show(True)\r\n    main_frame.timer.Start(16)\r\n    app.MainLoop()\r\n"
  },
  {
    "path": "src/tha4/app/character_model_mediapipe_puppeteer.py",
    "content": "import os\r\nimport sys\r\nimport threading\r\nimport time\r\nfrom typing import Optional\r\nimport PIL.Image\r\n\r\nimport cv2\r\nimport mediapipe\r\nfrom scipy.spatial.transform import Rotation\r\n\r\nfrom tha4.shion.base.image_util import resize_PIL_image\r\nfrom tha4.charmodel.character_model import CharacterModel\r\nfrom tha4.image_util import convert_linear_to_srgb\r\nfrom tha4.mocap.mediapipe_constants import HEAD_ROTATIONS, HEAD_X, HEAD_Y, HEAD_Z\r\nfrom tha4.mocap.mediapipe_face_pose import MediaPipeFacePose\r\nfrom tha4.mocap.mediapipe_face_pose_converter_00 import MediaPoseFacePoseConverter00\r\n\r\nsys.path.append(os.getcwd())\r\n\r\nimport torch\r\nimport wx\r\n\r\n\r\nclass FpsStatistics:\r\n    def __init__(self):\r\n        self.count = 100\r\n        self.fps = []\r\n\r\n    def add_fps(self, fps):\r\n        self.fps.append(fps)\r\n        while len(self.fps) > self.count:\r\n            del self.fps[0]\r\n\r\n    def get_average_fps(self):\r\n        if len(self.fps) == 0:\r\n            return 0.0\r\n        else:\r\n            return sum(self.fps) / len(self.fps)\r\n\r\n\r\nclass MainFrame(wx.Frame):\r\n    IMAGE_SIZE = 512\r\n\r\n    def __init__(self,\r\n                 pose_converter: MediaPoseFacePoseConverter00,\r\n                 video_capture,\r\n                 face_landmarker,\r\n                 device: torch.device):\r\n        super().__init__(None, wx.ID_ANY, \"THA4 Character Model MediaPipe Puppeteer\")\r\n        self.face_landmarker = face_landmarker\r\n        self.video_capture = video_capture\r\n        self.pose_converter = pose_converter\r\n        self.device = device\r\n\r\n        self.source_image_bitmap = wx.Bitmap(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE)\r\n        self.result_image_bitmap = wx.Bitmap(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE)\r\n        self.webcam_capture_bitmap = wx.Bitmap(256, 192)\r\n        self.wx_source_image = None\r\n        self.torch_source_image = None\r\n        self.last_pose = None\r\n        self.mediapipe_face_pose = None\r\n        self.fps_statistics = FpsStatistics()\r\n        self.last_update_time = None\r\n        self.character_model = None\r\n        self.poser = None\r\n\r\n        self.create_ui()\r\n        self.create_timers()\r\n        self.Bind(wx.EVT_CLOSE, self.on_close)\r\n\r\n        self.update_source_image_bitmap()\r\n        self.update_result_image_bitmap()\r\n\r\n    def create_timers(self):\r\n        self.capture_timer = wx.Timer(self, wx.ID_ANY)\r\n        self.Bind(wx.EVT_TIMER, self.update_capture_panel, id=self.capture_timer.GetId())\r\n        self.animation_timer = wx.Timer(self, wx.ID_ANY)\r\n        self.Bind(wx.EVT_TIMER, self.update_result_image_bitmap, id=self.animation_timer.GetId())\r\n\r\n    def on_close(self, event: wx.Event):\r\n        # Stop the timers\r\n        self.animation_timer.Stop()\r\n        self.capture_timer.Stop()\r\n\r\n        # Destroy the windows\r\n        self.Destroy()\r\n        event.Skip()\r\n\r\n    def on_erase_background(self, event: wx.Event):\r\n        pass\r\n\r\n    def create_animation_panel(self, parent):\r\n        self.animation_panel = wx.Panel(parent, style=wx.RAISED_BORDER)\r\n        self.animation_panel_sizer = wx.BoxSizer(wx.HORIZONTAL)\r\n        self.animation_panel.SetSizer(self.animation_panel_sizer)\r\n        self.animation_panel.SetAutoLayout(1)\r\n\r\n        image_size = MainFrame.IMAGE_SIZE\r\n\r\n        if True:\r\n            self.input_panel = wx.Panel(self.animation_panel, 
size=(image_size, image_size + 128),\r\n                                        style=wx.SIMPLE_BORDER)\r\n            self.input_panel_sizer = wx.BoxSizer(wx.VERTICAL)\r\n            self.input_panel.SetSizer(self.input_panel_sizer)\r\n            self.input_panel.SetAutoLayout(1)\r\n            self.animation_panel_sizer.Add(self.input_panel, 0, wx.FIXED_MINSIZE)\r\n\r\n            self.source_image_panel = wx.Panel(self.input_panel, size=(image_size, image_size), style=wx.SIMPLE_BORDER)\r\n            self.source_image_panel.Bind(wx.EVT_PAINT, self.paint_source_image_panel)\r\n            self.source_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)\r\n            self.input_panel_sizer.Add(self.source_image_panel, 0, wx.FIXED_MINSIZE)\r\n\r\n            self.load_model_button = wx.Button(self.input_panel, wx.ID_ANY, \"Load Model\")\r\n            self.input_panel_sizer.Add(self.load_model_button, 1, wx.EXPAND)\r\n            self.load_model_button.Bind(wx.EVT_BUTTON, self.load_model)\r\n\r\n            self.input_panel_sizer.Fit(self.input_panel)\r\n\r\n        if True:\r\n            def current_pose_supplier() -> Optional[MediaPipeFacePose]:\r\n                return self.mediapipe_face_pose\r\n\r\n            self.pose_converter.init_pose_converter_panel(self.animation_panel, current_pose_supplier)\r\n\r\n        if True:\r\n            self.animation_left_panel = wx.Panel(self.animation_panel, style=wx.SIMPLE_BORDER)\r\n            self.animation_left_panel_sizer = wx.BoxSizer(wx.VERTICAL)\r\n            self.animation_left_panel.SetSizer(self.animation_left_panel_sizer)\r\n            self.animation_left_panel.SetAutoLayout(1)\r\n            self.animation_panel_sizer.Add(self.animation_left_panel, 0, wx.EXPAND)\r\n\r\n            self.result_image_panel = wx.Panel(self.animation_left_panel, size=(image_size, image_size),\r\n                                               style=wx.SIMPLE_BORDER)\r\n            self.result_image_panel.Bind(wx.EVT_PAINT, self.paint_result_image_panel)\r\n            self.result_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)\r\n            self.animation_left_panel_sizer.Add(self.result_image_panel, 0, wx.FIXED_MINSIZE)\r\n\r\n            separator = wx.StaticLine(self.animation_left_panel, -1, size=(256, 5))\r\n            self.animation_left_panel_sizer.Add(separator, 0, wx.EXPAND)\r\n\r\n            background_text = wx.StaticText(self.animation_left_panel, label=\"--- Background ---\",\r\n                                            style=wx.ALIGN_CENTER)\r\n            self.animation_left_panel_sizer.Add(background_text, 0, wx.EXPAND)\r\n\r\n            self.output_background_choice = wx.Choice(\r\n                self.animation_left_panel,\r\n                choices=[\r\n                    \"TRANSPARENT\",\r\n                    \"GREEN\",\r\n                    \"BLUE\",\r\n                    \"BLACK\",\r\n                    \"WHITE\"\r\n                ])\r\n            self.output_background_choice.SetSelection(0)\r\n            self.animation_left_panel_sizer.Add(self.output_background_choice, 0, wx.EXPAND)\r\n\r\n            separator = wx.StaticLine(self.animation_left_panel, -1, size=(256, 5))\r\n            self.animation_left_panel_sizer.Add(separator, 0, wx.EXPAND)\r\n\r\n            self.fps_text = wx.StaticText(self.animation_left_panel, label=\"\")\r\n            self.animation_left_panel_sizer.Add(self.fps_text, wx.SizerFlags().Border())\r\n\r\n            
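# Shrink-wrap the left panel around the result view and its controls.\r\n            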
self.animation_left_panel_sizer.Fit(self.animation_left_panel)\r\n\r\n        self.animation_panel_sizer.Fit(self.animation_panel)\r\n\r\n    def create_ui(self):\r\n        self.main_sizer = wx.BoxSizer(wx.VERTICAL)\r\n        self.SetSizer(self.main_sizer)\r\n        self.SetAutoLayout(1)\r\n\r\n        self.capture_pose_lock = threading.Lock()\r\n\r\n        self.create_animation_panel(self)\r\n        self.main_sizer.Add(self.animation_panel, wx.SizerFlags(0).Expand().Border(wx.ALL, 5))\r\n\r\n        self.create_capture_panel(self)\r\n        self.main_sizer.Add(self.capture_panel, wx.SizerFlags(0).Expand().Border(wx.ALL, 5))\r\n\r\n        self.main_sizer.Fit(self)\r\n\r\n    def create_capture_panel(self, parent):\r\n        self.capture_panel = wx.Panel(parent, style=wx.RAISED_BORDER)\r\n        self.capture_panel_sizer = wx.BoxSizer(wx.HORIZONTAL)\r\n        self.capture_panel.SetSizer(self.capture_panel_sizer)\r\n        self.capture_panel.SetAutoLayout(1)\r\n\r\n        self.webcam_capture_panel = wx.Panel(self.capture_panel, size=(256, 192), style=wx.SIMPLE_BORDER)\r\n        self.webcam_capture_panel.Bind(wx.EVT_PAINT, self.paint_webcam_capture_panel)\r\n        self.webcam_capture_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)\r\n        self.capture_panel_sizer.Add(self.webcam_capture_panel, wx.SizerFlags(0).FixedMinSize().Border(wx.ALL, 5))\r\n\r\n        self.rotation_labels = {}\r\n        self.rotation_value_labels = {}\r\n        rotation_column = self.create_rotation_column(self.capture_panel, HEAD_ROTATIONS)\r\n        self.capture_panel_sizer.Add(rotation_column, wx.SizerFlags(0).Expand().Border(wx.ALL, 3))\r\n\r\n    def paint_webcam_capture_panel(self, event: wx.Event):\r\n        wx.BufferedPaintDC(self.webcam_capture_panel, self.webcam_capture_bitmap)\r\n\r\n    def create_rotation_column(self, parent, rotation_names):\r\n        column_panel = wx.Panel(parent, style=wx.SIMPLE_BORDER)\r\n        column_panel_sizer = wx.FlexGridSizer(cols=2)\r\n        column_panel_sizer.AddGrowableCol(1)\r\n        column_panel.SetSizer(column_panel_sizer)\r\n        column_panel.SetAutoLayout(1)\r\n\r\n        for rotation_name in rotation_names:\r\n            self.rotation_labels[rotation_name] = wx.StaticText(\r\n                column_panel, label=rotation_name, style=wx.ALIGN_RIGHT)\r\n            column_panel_sizer.Add(self.rotation_labels[rotation_name],\r\n                                   wx.SizerFlags(1).Expand().Border(wx.ALL, 3))\r\n\r\n            self.rotation_value_labels[rotation_name] = wx.TextCtrl(\r\n                column_panel, style=wx.TE_RIGHT)\r\n            self.rotation_value_labels[rotation_name].SetValue(\"0.00\")\r\n            self.rotation_value_labels[rotation_name].Disable()\r\n            column_panel_sizer.Add(self.rotation_value_labels[rotation_name],\r\n                                   wx.SizerFlags(1).Expand().Border(wx.ALL, 3))\r\n\r\n        column_panel.GetSizer().Fit(column_panel)\r\n        return column_panel\r\n\r\n    def update_capture_panel(self, event: wx.Event):\r\n        there_is_frame, frame = self.video_capture.read()\r\n        if not there_is_frame:\r\n            dc = wx.MemoryDC()\r\n            dc.SelectObject(self.webcam_capture_bitmap)\r\n            self.draw_nothing_yet_string(dc)\r\n            del dc\r\n            return\r\n\r\n        rgb_frame = cv2.flip(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), 1)\r\n        resized_frame = cv2.resize(rgb_frame, (256, 192))\r\n        wx_image = 
wx.ImageFromBuffer(256, 192, resized_frame.tobytes())\r\n        wx_bitmap = wx_image.ConvertToBitmap()\r\n\r\n        dc = wx.MemoryDC()\r\n        dc.SelectObject(self.webcam_capture_bitmap)\r\n        dc.Clear()\r\n        dc.DrawBitmap(wx_bitmap, 0, 0, True)\r\n        del dc\r\n\r\n        self.webcam_capture_panel.Refresh()\r\n\r\n        time_ms = int(time.time() * 1000)\r\n        mediapipe_image = mediapipe.Image(image_format=mediapipe.ImageFormat.SRGB, data=rgb_frame)\r\n        detection_result = self.face_landmarker.detect_for_video(mediapipe_image, time_ms)\r\n        self.update_mediapipe_face_pose(detection_result)\r\n\r\n    def update_mediapipe_face_pose(self, detection_result):\r\n        if len(detection_result.facial_transformation_matrixes) == 0:\r\n            return\r\n\r\n        xform_matrix = detection_result.facial_transformation_matrixes[0]\r\n        blendshape_params = {}\r\n        for item in detection_result.face_blendshapes[0]:\r\n            blendshape_params[item.category_name] = item.score\r\n        M = xform_matrix[0:3, 0:3]\r\n        rot = Rotation.from_matrix(M)\r\n        euler_angles = rot.as_euler('xyz', degrees=True)\r\n\r\n        self.rotation_value_labels[HEAD_X].SetValue(\"%0.2f\" % euler_angles[0])\r\n        self.rotation_value_labels[HEAD_X].Refresh()\r\n        self.rotation_value_labels[HEAD_Y].SetValue(\"%0.2f\" % euler_angles[1])\r\n        self.rotation_value_labels[HEAD_Y].Refresh()\r\n        self.rotation_value_labels[HEAD_Z].SetValue(\"%0.2f\" % euler_angles[2])\r\n        self.rotation_value_labels[HEAD_Z].Refresh()\r\n\r\n        self.mediapipe_face_pose = MediaPipeFacePose(blendshape_params, xform_matrix)\r\n\r\n    @staticmethod\r\n    def convert_to_100(x):\r\n        return int(max(0.0, min(1.0, x)) * 100)\r\n\r\n    def paint_source_image_panel(self, event: wx.Event):\r\n        wx.BufferedPaintDC(self.source_image_panel, self.source_image_bitmap)\r\n\r\n    def update_source_image_bitmap(self):\r\n        dc = wx.MemoryDC()\r\n        dc.SelectObject(self.source_image_bitmap)\r\n        if self.wx_source_image is None:\r\n            self.draw_nothing_yet_string(dc)\r\n        else:\r\n            dc.Clear()\r\n            dc.DrawBitmap(self.wx_source_image, 0, 0, True)\r\n        del dc\r\n\r\n    def draw_nothing_yet_string(self, dc):\r\n        dc.Clear()\r\n        font = wx.Font(wx.FontInfo(14).Family(wx.FONTFAMILY_SWISS))\r\n        dc.SetFont(font)\r\n        w, h = dc.GetTextExtent(\"Nothing yet!\")\r\n        dc.DrawText(\"Nothing yet!\", (MainFrame.IMAGE_SIZE - w) // 2, (MainFrame.IMAGE_SIZE - h) // 2)\r\n\r\n    def paint_result_image_panel(self, event: wx.Event):\r\n        wx.BufferedPaintDC(self.result_image_panel, self.result_image_bitmap)\r\n\r\n    def update_result_image_bitmap(self, event: Optional[wx.Event] = None):\r\n        if self.mediapipe_face_pose is None or self.poser is None:\r\n            dc = wx.MemoryDC()\r\n            dc.SelectObject(self.result_image_bitmap)\r\n            self.draw_nothing_yet_string(dc)\r\n            del dc\r\n            return\r\n\r\n        current_pose = self.pose_converter.convert(self.mediapipe_face_pose)\r\n        if self.last_pose is not None and self.last_pose == current_pose:\r\n            return\r\n        self.last_pose = current_pose\r\n\r\n        if self.torch_source_image is None:\r\n            dc = wx.MemoryDC()\r\n            dc.SelectObject(self.result_image_bitmap)\r\n            self.draw_nothing_yet_string(dc)\r\n            del dc\r\n       
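     # No character image loaded yet, so keep the placeholder text.\r\n       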
     return\r\n\r\n        pose = torch.tensor(current_pose, device=self.device, dtype=self.poser.get_dtype())\r\n\r\n        with torch.no_grad():\r\n            output_image = self.poser.pose(self.torch_source_image, pose)[0].float()\r\n            output_image = torch.clip((output_image + 1.0) / 2.0, 0.0, 1.0)\r\n            output_image = convert_linear_to_srgb(output_image)\r\n\r\n            background_choice = self.output_background_choice.GetSelection()\r\n            # Choice indices follow the UI list: 0 TRANSPARENT, 1 GREEN, 2 BLUE, 3 BLACK, 4 WHITE.\r\n            if background_choice == 0:\r\n                pass\r\n            else:\r\n                background = torch.zeros(4, output_image.shape[1], output_image.shape[2], device=self.device)\r\n                background[3, :, :] = 1.0\r\n                if background_choice == 1:\r\n                    background[1, :, :] = 1.0\r\n                    output_image = self.blend_with_background(output_image, background)\r\n                elif background_choice == 2:\r\n                    background[2, :, :] = 1.0\r\n                    output_image = self.blend_with_background(output_image, background)\r\n                elif background_choice == 3:\r\n                    output_image = self.blend_with_background(output_image, background)\r\n                else:\r\n                    background[0:3, :, :] = 1.0\r\n                    output_image = self.blend_with_background(output_image, background)\r\n\r\n            c, h, w = output_image.shape\r\n            output_image = 255.0 * torch.transpose(output_image.reshape(c, h * w), 0, 1).reshape(h, w, c)\r\n            output_image = output_image.byte()\r\n\r\n        numpy_image = output_image.detach().cpu().numpy()\r\n        wx_image = wx.ImageFromBuffer(numpy_image.shape[0],\r\n                                      numpy_image.shape[1],\r\n                                      numpy_image[:, :, 0:3].tobytes(),\r\n                                      numpy_image[:, :, 3].tobytes())\r\n        wx_bitmap = wx_image.ConvertToBitmap()\r\n\r\n        dc = wx.MemoryDC()\r\n        dc.SelectObject(self.result_image_bitmap)\r\n        dc.Clear()\r\n        dc.DrawBitmap(wx_bitmap,\r\n                      (MainFrame.IMAGE_SIZE - numpy_image.shape[0]) // 2,\r\n                      (MainFrame.IMAGE_SIZE - numpy_image.shape[1]) // 2, True)\r\n        del dc\r\n\r\n        time_now = time.time_ns()\r\n        if self.last_update_time is not None:\r\n            elapsed_time = time_now - self.last_update_time\r\n            fps = 1.0 / (elapsed_time / 10 ** 9)\r\n            if self.torch_source_image is not None:\r\n                self.fps_statistics.add_fps(fps)\r\n            self.fps_text.SetLabelText(\"FPS = %0.2f\" % self.fps_statistics.get_average_fps())\r\n        self.last_update_time = time_now\r\n\r\n        self.Refresh()\r\n\r\n    def blend_with_background(self, numpy_image, background):\r\n        alpha = numpy_image[3:4, :, :]\r\n        color = numpy_image[0:3, :, :]\r\n        new_color = color * alpha + (1.0 - alpha) * background[0:3, :, :]\r\n        return torch.cat([new_color, background[3:4, :, :]], dim=0)\r\n\r\n    def load_model(self, event: wx.Event):\r\n        dir_name = \"data/character_models\"\r\n        file_dialog = wx.FileDialog(self, \"Choose a model\", dir_name, \"\", \"*.yaml\", wx.FD_OPEN)\r\n        if file_dialog.ShowModal() == wx.ID_OK:\r\n            character_model_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())\r\n            try:\r\n                self.character_model = CharacterModel.load(character_model_file_name)\r\n                self.torch_source_image = self.character_model.get_character_image(self.device)\r\n                pil_image = resize_PIL_image(\r\n                    PIL.Image.open(self.character_model.character_image_file_name),\r\n                    (MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE))\r\n                w, h = pil_image.size\r\n                self.wx_source_image = wx.Bitmap.FromBufferRGBA(w, h, pil_image.convert(\"RGBA\").tobytes())\r\n                self.update_source_image_bitmap()\r\n                self.poser = self.character_model.get_poser(self.device)\r\n            except Exception:\r\n                message_dialog = wx.MessageDialog(\r\n                    self, \"Could not load character model \" + character_model_file_name, \"Poser\", wx.OK)\r\n                message_dialog.ShowModal()\r\n                message_dialog.Destroy()\r\n        file_dialog.Destroy()\r\n        self.Refresh()\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    device = torch.device(\"cuda:0\")\r\n\r\n    pose_converter = MediaPoseFacePoseConverter00()\r\n\r\n    face_landmarker_base_options = mediapipe.tasks.BaseOptions(\r\n        model_asset_path='data/thirdparty/mediapipe/face_landmarker_v2_with_blendshapes.task')\r\n    options = mediapipe.tasks.vision.FaceLandmarkerOptions(\r\n        base_options=face_landmarker_base_options,\r\n        running_mode=mediapipe.tasks.vision.RunningMode.VIDEO,\r\n        output_face_blendshapes=True,\r\n        output_facial_transformation_matrixes=True,\r\n        num_faces=1)\r\n    face_landmarker = mediapipe.tasks.vision.FaceLandmarker.create_from_options(options)\r\n\r\n    video_capture = cv2.VideoCapture(0)\r\n\r\n    app = wx.App()\r\n    main_frame = MainFrame(pose_converter, video_capture, face_landmarker, device)\r\n    main_frame.Show(True)\r\n    main_frame.capture_timer.Start(30)\r\n    main_frame.animation_timer.Start(30)\r\n    app.MainLoop()\r\n"
  },
  {
    "path": "src/tha4/app/distill.py",
    "content": "import argparse\r\nimport logging\r\n\r\nfrom tha4.distiller.distiller_config import DistillerConfig\r\nfrom tha4.pytasuku.workspace import Workspace\r\n\r\n\r\ndef run_config(config_file_name: str):\r\n    config = DistillerConfig.load(config_file_name)\r\n\r\n    logging.basicConfig(level=logging.INFO, force=True)\r\n    workspace = Workspace()\r\n    config.define_tasks(workspace)\r\n\r\n    workspace.start_session()\r\n    workspace.run(f\"{config.prefix}/all\")\r\n    workspace.end_session()\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    parser = argparse.ArgumentParser(description='Training script.')\r\n    parser.add_argument(\"--config_file\", type=str, required=True,\r\n                        help=\"The name of the config file for the distillation process.\")\r\n    args = parser.parse_args()\r\n    run_config(args.config_file)\r\n"
  },
  {
    "path": "src/tha4/app/distiller_ui.py",
    "content": "import wx\r\n\r\nfrom tha4.app.distill import run_config\r\nfrom tha4.distiller.ui.distiller_ui_main_frame import DistillerUiMainFrame\r\n\r\nif __name__ == \"__main__\":\r\n    app = wx.App()\r\n    main_frame = DistillerUiMainFrame()\r\n    main_frame.Show(True)\r\n    app.MainLoop()\r\n\r\n    if main_frame.config_file_to_run is not None:\r\n        run_config(main_frame.config_file_to_run)\r\n"
  },
  {
    "path": "src/tha4/app/full_manual_poser.py",
    "content": "import logging\r\nimport os\r\nimport sys\r\nimport time\r\nfrom typing import List\r\n\r\nfrom tha4.shion.base.image_util import extract_pytorch_image_from_PIL_image, pytorch_rgba_to_numpy_image, \\\r\n    pytorch_rgb_to_numpy_image\r\nfrom tha4.image_util import grid_change_to_numpy_image, resize_PIL_image\r\n\r\nsys.path.append(os.getcwd())\r\n\r\nimport PIL.Image\r\nimport numpy\r\nimport torch\r\nimport wx\r\n\r\nfrom tha4.poser.poser import Poser, PoseParameterCategory, PoseParameterGroup\r\n\r\n\r\nclass MorphCategoryControlPanel(wx.Panel):\r\n    def __init__(self,\r\n                 parent,\r\n                 title: str,\r\n                 pose_param_category: PoseParameterCategory,\r\n                 param_groups: List[PoseParameterGroup]):\r\n        super().__init__(parent, style=wx.SIMPLE_BORDER)\r\n        self.pose_param_category = pose_param_category\r\n        self.sizer = wx.BoxSizer(wx.VERTICAL)\r\n        self.SetSizer(self.sizer)\r\n        self.SetAutoLayout(1)\r\n\r\n        title_text = wx.StaticText(self, label=title, style=wx.ALIGN_CENTER)\r\n        self.sizer.Add(title_text, 0, wx.EXPAND)\r\n\r\n        self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category]\r\n        self.choice = wx.Choice(self, choices=[group.get_group_name() for group in self.param_groups])\r\n        if len(self.param_groups) > 0:\r\n            self.choice.SetSelection(0)\r\n        self.choice.Bind(wx.EVT_CHOICE, self.on_choice_updated)\r\n        self.sizer.Add(self.choice, 0, wx.EXPAND)\r\n\r\n        self.left_slider = wx.Slider(self, minValue=-1000, maxValue=1000, value=-1000, style=wx.HORIZONTAL)\r\n        self.sizer.Add(self.left_slider, 0, wx.EXPAND)\r\n\r\n        self.right_slider = wx.Slider(self, minValue=-1000, maxValue=1000, value=-1000, style=wx.HORIZONTAL)\r\n        self.sizer.Add(self.right_slider, 0, wx.EXPAND)\r\n\r\n        self.checkbox = wx.CheckBox(self, label=\"Show\")\r\n        self.checkbox.SetValue(True)\r\n        self.sizer.Add(self.checkbox, 0, wx.SHAPED | wx.ALIGN_CENTER)\r\n\r\n        self.update_ui()\r\n\r\n        self.sizer.Fit(self)\r\n\r\n    def update_ui(self):\r\n        param_group = self.param_groups[self.choice.GetSelection()]\r\n        if param_group.is_discrete():\r\n            self.left_slider.Enable(False)\r\n            self.right_slider.Enable(False)\r\n            self.checkbox.Enable(True)\r\n        elif param_group.get_arity() == 1:\r\n            self.left_slider.Enable(True)\r\n            self.right_slider.Enable(False)\r\n            self.checkbox.Enable(False)\r\n        else:\r\n            self.left_slider.Enable(True)\r\n            self.right_slider.Enable(True)\r\n            self.checkbox.Enable(False)\r\n\r\n    def on_choice_updated(self, event: wx.Event):\r\n        param_group = self.param_groups[self.choice.GetSelection()]\r\n        if param_group.is_discrete():\r\n            self.checkbox.SetValue(True)\r\n        self.update_ui()\r\n\r\n    def set_param_value(self, pose: List[float]):\r\n        if len(self.param_groups) == 0:\r\n            return\r\n        selected_morph_index = self.choice.GetSelection()\r\n        param_group = self.param_groups[selected_morph_index]\r\n        param_index = param_group.get_parameter_index()\r\n        if param_group.is_discrete():\r\n            if self.checkbox.GetValue():\r\n                for i in range(param_group.get_arity()):\r\n                    pose[param_index + i] = 1.0\r\n        else:\r\n     
       param_range = param_group.get_range()\r\n            alpha = (self.left_slider.GetValue() + 1000) / 2000.0\r\n            pose[param_index] = param_range[0] + (param_range[1] - param_range[0]) * alpha\r\n            if param_group.get_arity() == 2:\r\n                alpha = (self.right_slider.GetValue() + 1000) / 2000.0\r\n                pose[param_index + 1] = param_range[0] + (param_range[1] - param_range[0]) * alpha\r\n\r\n\r\nclass SimpleParamGroupsControlPanel(wx.Panel):\r\n    def __init__(self, parent,\r\n                 pose_param_category: PoseParameterCategory,\r\n                 param_groups: List[PoseParameterGroup]):\r\n        super().__init__(parent, style=wx.SIMPLE_BORDER)\r\n        self.sizer = wx.BoxSizer(wx.VERTICAL)\r\n        self.SetSizer(self.sizer)\r\n        self.SetAutoLayout(1)\r\n\r\n        self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category]\r\n        for param_group in self.param_groups:\r\n            assert not param_group.is_discrete()\r\n            assert param_group.get_arity() == 1\r\n\r\n        self.sliders = []\r\n        for param_group in self.param_groups:\r\n            static_text = wx.StaticText(\r\n                self,\r\n                label=\"   ------------ %s ------------   \" % param_group.get_group_name(), style=wx.ALIGN_CENTER)\r\n            self.sizer.Add(static_text, 0, wx.EXPAND)\r\n            param_range = param_group.get_range()\r\n            min_value = int(param_range[0] * 1000)\r\n            max_value = int(param_range[1] * 1000)\r\n            slider = wx.Slider(self, minValue=min_value, maxValue=max_value, value=0, style=wx.HORIZONTAL)\r\n            self.sizer.Add(slider, 0, wx.EXPAND)\r\n            self.sliders.append(slider)\r\n\r\n        self.sizer.Fit(self)\r\n\r\n    def set_param_value(self, pose: List[float]):\r\n        if len(self.param_groups) == 0:\r\n            return\r\n        for param_group_index in range(len(self.param_groups)):\r\n            param_group = self.param_groups[param_group_index]\r\n            slider = self.sliders[param_group_index]\r\n            param_range = param_group.get_range()\r\n            param_index = param_group.get_parameter_index()\r\n            alpha = (slider.GetValue() - slider.GetMin()) * 1.0 / (slider.GetMax() - slider.GetMin())\r\n            pose[param_index] = param_range[0] + (param_range[1] - param_range[0]) * alpha\r\n\r\n\r\ndef convert_output_image_from_torch_to_numpy(output_image):\r\n    if output_image.shape[2] == 2:\r\n        h, w, c = output_image.shape\r\n        numpy_image = torch.transpose(output_image.reshape(h * w, c), 0, 1).reshape(c, h, w)\r\n    elif output_image.shape[0] == 4:\r\n        numpy_image = pytorch_rgba_to_numpy_image(output_image)\r\n    elif output_image.shape[0] == 3:\r\n        numpy_image = pytorch_rgb_to_numpy_image(output_image)\r\n    elif output_image.shape[0] == 1:\r\n        c, h, w = output_image.shape\r\n        alpha_image = torch.cat([output_image.repeat(3, 1, 1) * 2.0 - 1.0, torch.ones(1, h, w)], dim=0)\r\n        numpy_image = pytorch_rgba_to_numpy_image(alpha_image)\r\n    elif output_image.shape[0] == 2:\r\n        numpy_image = grid_change_to_numpy_image(output_image, num_channels=4)\r\n    else:\r\n        raise RuntimeError(\"Unsupported # image channels: %d\" % output_image.shape[0])\r\n    numpy_image = numpy.uint8(numpy.rint(numpy_image * 255.0))\r\n    return numpy_image\r\n\r\n\r\nclass MainFrame(wx.Frame):\r\n    def __init__(self, poser: Poser, device: 
torch.device):\r\n        super().__init__(None, wx.ID_ANY, \"Poser\")\r\n        self.poser = poser\r\n        self.dtype = self.poser.get_dtype()\r\n        self.device = device\r\n        self.image_size = self.poser.get_image_size()\r\n\r\n        self.wx_source_image = None\r\n        self.torch_source_image = None\r\n\r\n        self.main_sizer = wx.BoxSizer(wx.HORIZONTAL)\r\n        self.SetSizer(self.main_sizer)\r\n        self.SetAutoLayout(1)\r\n        self.init_left_panel()\r\n        self.init_control_panel()\r\n        self.init_right_panel()\r\n        self.main_sizer.Fit(self)\r\n\r\n        self.timer = wx.Timer(self, wx.ID_ANY)\r\n        self.Bind(wx.EVT_TIMER, self.update_images, self.timer)\r\n\r\n        save_image_id = wx.NewIdRef()\r\n        self.Bind(wx.EVT_MENU, self.on_save_image, id=save_image_id)\r\n        accelerator_table = wx.AcceleratorTable([\r\n            (wx.ACCEL_CTRL, ord('S'), save_image_id)\r\n        ])\r\n        self.SetAcceleratorTable(accelerator_table)\r\n\r\n        self.last_pose = None\r\n        self.last_output_index = self.output_index_choice.GetSelection()\r\n        self.last_output_numpy_image = None\r\n\r\n        self.wx_source_image = None\r\n        self.torch_source_image = None\r\n        self.source_image_bitmap = wx.Bitmap(self.image_size, self.image_size)\r\n        self.result_image_bitmap = wx.Bitmap(self.image_size, self.image_size)\r\n        self.source_image_dirty = True\r\n\r\n    def init_left_panel(self):\r\n        self.control_panel = wx.Panel(self, style=wx.SIMPLE_BORDER, size=(self.image_size, -1))\r\n        self.left_panel = wx.Panel(self, style=wx.SIMPLE_BORDER)\r\n        left_panel_sizer = wx.BoxSizer(wx.VERTICAL)\r\n        self.left_panel.SetSizer(left_panel_sizer)\r\n        self.left_panel.SetAutoLayout(1)\r\n\r\n        self.source_image_panel = wx.Panel(self.left_panel, size=(self.image_size, self.image_size),\r\n                                           style=wx.SIMPLE_BORDER)\r\n        self.source_image_panel.Bind(wx.EVT_PAINT, self.paint_source_image_panel)\r\n        self.source_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)\r\n        left_panel_sizer.Add(self.source_image_panel, 0, wx.FIXED_MINSIZE)\r\n\r\n        self.load_image_button = wx.Button(self.left_panel, wx.ID_ANY, \"\\nLoad Image\\n\\n\")\r\n        left_panel_sizer.Add(self.load_image_button, 1, wx.EXPAND)\r\n        self.load_image_button.Bind(wx.EVT_BUTTON, self.load_image)\r\n\r\n        left_panel_sizer.Fit(self.left_panel)\r\n        self.main_sizer.Add(self.left_panel, 0, wx.FIXED_MINSIZE)\r\n\r\n    def on_erase_background(self, event: wx.Event):\r\n        pass\r\n\r\n    def init_control_panel(self):\r\n        self.control_panel_sizer = wx.BoxSizer(wx.VERTICAL)\r\n        self.control_panel.SetSizer(self.control_panel_sizer)\r\n        self.control_panel.SetMinSize(wx.Size(256, 1))\r\n\r\n        morph_categories = [\r\n            PoseParameterCategory.EYEBROW,\r\n            PoseParameterCategory.EYE,\r\n            PoseParameterCategory.MOUTH,\r\n            PoseParameterCategory.IRIS_MORPH\r\n        ]\r\n        morph_category_titles = {\r\n            PoseParameterCategory.EYEBROW: \"   ------------ Eyebrow ------------   \",\r\n            PoseParameterCategory.EYE: \"   ------------ Eye ------------   \",\r\n            PoseParameterCategory.MOUTH: \"   ------------ Mouth ------------   \",\r\n            PoseParameterCategory.IRIS_MORPH: \"   ------------ Iris morphs ------------   \",\r\n  
      }\r\n        self.morph_control_panels = {}\r\n        for category in morph_categories:\r\n            param_groups = self.poser.get_pose_parameter_groups()\r\n            filtered_param_groups = [group for group in param_groups if group.get_category() == category]\r\n            if len(filtered_param_groups) == 0:\r\n                continue\r\n            control_panel = MorphCategoryControlPanel(\r\n                self.control_panel,\r\n                morph_category_titles[category],\r\n                category,\r\n                self.poser.get_pose_parameter_groups())\r\n            self.morph_control_panels[category] = control_panel\r\n            self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND)\r\n\r\n        self.non_morph_control_panels = {}\r\n        non_morph_categories = [\r\n            PoseParameterCategory.IRIS_ROTATION,\r\n            PoseParameterCategory.FACE_ROTATION,\r\n            PoseParameterCategory.BODY_ROTATION,\r\n            PoseParameterCategory.BREATHING\r\n        ]\r\n        for category in non_morph_categories:\r\n            param_groups = self.poser.get_pose_parameter_groups()\r\n            filtered_param_groups = [group for group in param_groups if group.get_category() == category]\r\n            if len(filtered_param_groups) == 0:\r\n                continue\r\n            control_panel = SimpleParamGroupsControlPanel(\r\n                self.control_panel,\r\n                category,\r\n                self.poser.get_pose_parameter_groups())\r\n            self.non_morph_control_panels[category] = control_panel\r\n            self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND)\r\n\r\n        self.control_panel_sizer.Fit(self.control_panel)\r\n        self.main_sizer.Add(self.control_panel, 1, wx.FIXED_MINSIZE)\r\n\r\n    def init_right_panel(self):\r\n        self.right_panel = wx.Panel(self, style=wx.SIMPLE_BORDER)\r\n        right_panel_sizer = wx.BoxSizer(wx.VERTICAL)\r\n        self.right_panel.SetSizer(right_panel_sizer)\r\n        self.right_panel.SetAutoLayout(1)\r\n\r\n        self.result_image_panel = wx.Panel(self.right_panel,\r\n                                           size=(self.image_size, self.image_size),\r\n                                           style=wx.SIMPLE_BORDER)\r\n        self.result_image_panel.Bind(wx.EVT_PAINT, self.paint_result_image_panel)\r\n        self.result_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)\r\n        self.output_index_choice = wx.Choice(\r\n            self.right_panel,\r\n            choices=[str(i) for i in range(self.poser.get_output_length())])\r\n        self.output_index_choice.SetSelection(0)\r\n        right_panel_sizer.Add(self.result_image_panel, 0, wx.FIXED_MINSIZE)\r\n        right_panel_sizer.Add(self.output_index_choice, 0, wx.EXPAND)\r\n\r\n        self.save_image_button = wx.Button(self.right_panel, wx.ID_ANY, \"\\nSave Image\\n\\n\")\r\n        right_panel_sizer.Add(self.save_image_button, 1, wx.EXPAND)\r\n        self.save_image_button.Bind(wx.EVT_BUTTON, self.on_save_image)\r\n\r\n        right_panel_sizer.Fit(self.right_panel)\r\n        self.main_sizer.Add(self.right_panel, 0, wx.FIXED_MINSIZE)\r\n\r\n    def create_param_category_choice(self, param_category: PoseParameterCategory):\r\n        params = []\r\n        for param_group in self.poser.get_pose_parameter_groups():\r\n            if param_group.get_category() == param_category:\r\n                params.append(param_group.get_group_name())\r\n        choice = 
wx.Choice(self.control_panel, choices=params)\r\n        if len(params) > 0:\r\n            choice.SetSelection(0)\r\n        return choice\r\n\r\n    def load_image(self, event: wx.Event):\r\n        dir_name = \"data/images\"\r\n        file_dialog = wx.FileDialog(self, \"Choose an image\", dir_name, \"\", \"*.png\", wx.FD_OPEN)\r\n        if file_dialog.ShowModal() == wx.ID_OK:\r\n            image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())\r\n            try:\r\n                pil_image = resize_PIL_image(PIL.Image.open(image_file_name),\r\n                                             (self.poser.get_image_size(), self.poser.get_image_size()))\r\n                w, h = pil_image.size\r\n                if pil_image.mode != 'RGBA':\r\n                    self.source_image_string = \"Image must have alpha channel!\"\r\n                    self.wx_source_image = None\r\n                    self.torch_source_image = None\r\n                else:\r\n                    self.wx_source_image = wx.Bitmap.FromBufferRGBA(w, h, pil_image.convert(\"RGBA\").tobytes())\r\n                    self.torch_source_image = extract_pytorch_image_from_PIL_image(pil_image) \\\r\n                        .to(self.device).to(self.dtype)\r\n                self.source_image_dirty = True\r\n                self.Refresh()\r\n                self.Update()\r\n            except Exception:\r\n                message_dialog = wx.MessageDialog(self, \"Could not load image \" + image_file_name, \"Poser\", wx.OK)\r\n                message_dialog.ShowModal()\r\n                message_dialog.Destroy()\r\n        file_dialog.Destroy()\r\n\r\n    def paint_source_image_panel(self, event: wx.Event):\r\n        wx.BufferedPaintDC(self.source_image_panel, self.source_image_bitmap)\r\n\r\n    def paint_result_image_panel(self, event: wx.Event):\r\n        wx.BufferedPaintDC(self.result_image_panel, self.result_image_bitmap)\r\n\r\n    def draw_nothing_yet_string_to_bitmap(self, bitmap):\r\n        dc = wx.MemoryDC()\r\n        dc.SelectObject(bitmap)\r\n\r\n        dc.Clear()\r\n        font = wx.Font(wx.FontInfo(14).Family(wx.FONTFAMILY_SWISS))\r\n        dc.SetFont(font)\r\n        w, h = dc.GetTextExtent(\"Nothing yet!\")\r\n        dc.DrawText(\"Nothing yet!\", (self.image_size - w) // 2, (self.image_size - h) // 2)\r\n\r\n        del dc\r\n\r\n    def get_current_pose(self):\r\n        current_pose = [0.0 for i in range(self.poser.get_num_parameters())]\r\n        for morph_control_panel in self.morph_control_panels.values():\r\n            morph_control_panel.set_param_value(current_pose)\r\n        for rotation_control_panel in self.non_morph_control_panels.values():\r\n            rotation_control_panel.set_param_value(current_pose)\r\n        return current_pose\r\n\r\n    def update_images(self, event: wx.Event):\r\n        current_pose = self.get_current_pose()\r\n        if not self.source_image_dirty \\\r\n                and self.last_pose is not None \\\r\n                and self.last_pose == current_pose \\\r\n                and self.last_output_index == self.output_index_choice.GetSelection():\r\n            return\r\n        self.last_pose = current_pose\r\n        self.last_output_index = self.output_index_choice.GetSelection()\r\n\r\n        if self.torch_source_image is None:\r\n            self.draw_nothing_yet_string_to_bitmap(self.source_image_bitmap)\r\n            self.draw_nothing_yet_string_to_bitmap(self.result_image_bitmap)\r\n            self.source_image_dirty = False\r\n            self.Refresh()\r\n            self.Update()\r\n            return\r\n\r\n        if self.source_image_dirty:\r\n            dc = wx.MemoryDC()\r\n            dc.SelectObject(self.source_image_bitmap)\r\n            dc.Clear()\r\n            dc.DrawBitmap(self.wx_source_image, 0, 0)\r\n            self.source_image_dirty = False\r\n\r\n        pose = torch.tensor(current_pose, device=self.device, dtype=self.dtype)\r\n        output_index = self.output_index_choice.GetSelection()\r\n        with torch.no_grad():\r\n            start_cuda_event = torch.cuda.Event(enable_timing=True)\r\n            end_cuda_event = torch.cuda.Event(enable_timing=True)\r\n            start_cuda_event.record()\r\n            start_time = time.time()\r\n\r\n            output_image = self.poser.pose(self.torch_source_image, pose, output_index)[0].detach().cpu()\r\n\r\n            end_time = time.time()\r\n            end_cuda_event.record()\r\n            torch.cuda.synchronize()\r\n            print(\"cuda time (ms):\", start_cuda_event.elapsed_time(end_cuda_event))\r\n            print(\"elapsed time (ms):\", (end_time - start_time) * 1000.0)\r\n\r\n        numpy_image = convert_output_image_from_torch_to_numpy(output_image)\r\n        self.last_output_numpy_image = numpy_image\r\n        wx_image = wx.ImageFromBuffer(\r\n            numpy_image.shape[0],\r\n            numpy_image.shape[1],\r\n            numpy_image[:, :, 0:3].tobytes(),\r\n            numpy_image[:, :, 3].tobytes())\r\n        wx_bitmap = wx_image.ConvertToBitmap()\r\n\r\n        dc = wx.MemoryDC()\r\n        dc.SelectObject(self.result_image_bitmap)\r\n        dc.Clear()\r\n        dc.DrawBitmap(wx_bitmap,\r\n                      (self.image_size - numpy_image.shape[0]) // 2,\r\n                      (self.image_size - numpy_image.shape[1]) // 2,\r\n                      True)\r\n        del dc\r\n\r\n        self.Refresh()\r\n        self.Update()\r\n\r\n    def on_save_image(self, event: wx.Event):\r\n        if self.last_output_numpy_image is None:\r\n            logging.info(\"There is no output image to save!!!\")\r\n            return\r\n\r\n        dir_name = \"data/images\"\r\n        file_dialog = wx.FileDialog(self, \"Choose an image\", dir_name, \"\", \"*.png\", wx.FD_SAVE)\r\n        if file_dialog.ShowModal() == wx.ID_OK:\r\n            image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())\r\n            try:\r\n                if os.path.exists(image_file_name):\r\n                    message_dialog = wx.MessageDialog(self, f\"Overwrite {image_file_name}?\", \"Manual Poser\",\r\n                                                      wx.YES_NO | wx.ICON_QUESTION)\r\n                    result = message_dialog.ShowModal()\r\n                    if result == wx.ID_YES:\r\n                        self.save_last_numpy_image(image_file_name)\r\n                    message_dialog.Destroy()\r\n                else:\r\n                    self.save_last_numpy_image(image_file_name)\r\n            except Exception:\r\n                message_dialog = wx.MessageDialog(self, f\"Could not save {image_file_name}\", \"Manual Poser\", wx.OK)\r\n                message_dialog.ShowModal()\r\n                message_dialog.Destroy()\r\n        file_dialog.Destroy()\r\n\r\n    def save_last_numpy_image(self, image_file_name):\r\n        numpy_image = self.last_output_numpy_image\r\n        pil_image = PIL.Image.fromarray(numpy_image, mode='RGBA')\r\n        os.makedirs(os.path.dirname(image_file_name), 
exist_ok=True)\r\n        pil_image.save(image_file_name)\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    device = torch.device('cuda:0')\r\n    try:\r\n        import tha4.poser.modes.mode_07\r\n\r\n        poser = tha4.poser.modes.mode_07.create_poser(device)\r\n    except RuntimeError as e:\r\n        print(e)\r\n        sys.exit()\r\n\r\n    app = wx.App()\r\n    main_frame = MainFrame(poser, device)\r\n    main_frame.Show(True)\r\n    main_frame.timer.Start(16)\r\n    app.MainLoop()\r\n"
  },
  {
    "path": "src/tha4/charmodel/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/charmodel/character_model.py",
    "content": "import json\r\nimport os.path\r\n\r\nimport PIL.Image\r\nimport torch\r\nfrom omegaconf import OmegaConf\r\n\r\nfrom tha4.shion.base.image_util import extract_pytorch_image_from_PIL_image\r\nfrom tha4.poser.modes.mode_14 import create_poser, KEY_FACE_MORPHER, KEY_BODY_MORPHER\r\n\r\n\r\nclass CharacterModel:\r\n    def __init__(self,\r\n                 character_image_file_name: str,\r\n                 face_morpher_file_name: str,\r\n                 body_morpher_file_name: str):\r\n        self.body_morpher_file_name = body_morpher_file_name\r\n        self.face_morpher_file_name = face_morpher_file_name\r\n        self.character_image_file_name = character_image_file_name\r\n        self.poser = None\r\n        self.character_image = None\r\n\r\n    def get_poser(self, device: torch.device):\r\n        if self.poser is not None:\r\n            self.poser.to(device)\r\n        else:\r\n            self.poser = create_poser(\r\n                device,\r\n                module_file_names={\r\n                    KEY_FACE_MORPHER: self.face_morpher_file_name,\r\n                    KEY_BODY_MORPHER: self.body_morpher_file_name\r\n                })\r\n        return self.poser\r\n\r\n    def get_character_image(self, device: torch.device):\r\n        if self.character_image is None:\r\n            pil_image = PIL.Image.open(self.character_image_file_name)\r\n            if pil_image.mode != 'RGBA':\r\n                raise RuntimeError(\"Character image is not an RGBA image!\")\r\n            self.character_image = extract_pytorch_image_from_PIL_image(pil_image)\r\n        self.character_image = self.character_image.to(device)\r\n        return self.character_image\r\n\r\n    def save(self, file_name: str):\r\n        dir = os.path.dirname(file_name)\r\n        rel_char_image_file_name = os.path.relpath(self.character_image_file_name, dir)\r\n        rel_face_morpher_file_name = os.path.relpath(self.face_morpher_file_name, dir)\r\n        rel_body_morpher_file_name = os.path.relpath(self.body_morpher_file_name, dir)\r\n        data = {\r\n            \"character_image_file_name\": rel_char_image_file_name,\r\n            \"face_morpher_file_name\": rel_face_morpher_file_name,\r\n            \"body_morpher_file_name\": rel_body_morpher_file_name,\r\n        }\r\n        conf = OmegaConf.create(data)\r\n        os.makedirs(dir, exist_ok=True)\r\n        with open(file_name, \"wt\") as fout:\r\n            fout.write(OmegaConf.to_yaml(conf))\r\n\r\n    @staticmethod\r\n    def load(file_name: str):\r\n        conf = OmegaConf.to_container(OmegaConf.load(file_name))\r\n        dir = os.path.dirname(file_name)\r\n        character_image_file_name = os.path.join(dir, conf[\"character_image_file_name\"])\r\n        face_morpher_file_name = os.path.join(dir, conf[\"face_morpher_file_name\"])\r\n        body_morpher_file_name = os.path.join(dir, conf[\"body_morpher_file_name\"])\r\n        return CharacterModel(\r\n            character_image_file_name,\r\n            face_morpher_file_name,\r\n            body_morpher_file_name)\r\n"
  },
  {
    "path": "src/tha4/dataset/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/dataset/image_poses_and_aother_images_dataset.py",
    "content": "from typing import List, Callable\r\n\r\nfrom torch import Tensor\r\nfrom torch.utils.data import Dataset\r\n\r\n\r\nclass ImagePosesAndOtherImagesDataset(Dataset):\r\n    def __init__(self,\r\n                 main_image_func: Callable[[], Tensor],\r\n                 pose_dataset: Dataset,\r\n                 other_image_funcs: List[Callable[[], Tensor]]):\r\n        self.main_image_func = main_image_func\r\n        self.other_image_funcs = other_image_funcs\r\n        self.pose_dataset = pose_dataset\r\n        self.main_image = None\r\n        self.other_images = [None for i in range(len(self.other_image_funcs))]\r\n\r\n    def get_main_image(self):\r\n        if self.main_image is None:\r\n            self.main_image = self.main_image_func()\r\n        return self.main_image\r\n\r\n    def get_other_image(self, image_index: int):\r\n        if self.other_images[image_index] is None:\r\n            self.other_images[image_index] = self.other_image_funcs[image_index]()\r\n        return self.other_images[image_index]\r\n\r\n    def __len__(self):\r\n        return len(self.pose_dataset)\r\n\r\n    def __getitem__(self, index):\r\n        main_image = self.get_main_image()\r\n        pose = self.pose_dataset[index][0]\r\n        other_images = [self.get_other_image(i) for i in range(len(self.other_image_funcs))]\r\n        return [main_image, pose] + other_images\r\n"
  },
  {
    "path": "src/tha4/distiller/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/distiller/config_based_training_tasks.py",
    "content": "import logging\r\nimport os\r\nimport sys\r\nfrom typing import Callable, List, Optional\r\n\r\nfrom tha4.pytasuku.workspace import Workspace\r\nfrom tha4.shion.core.training.distrib.distributed_trainer import DistributedTrainer\r\nfrom tha4.shion.core.training.distrib.distributed_training_states import DistributedTrainingState\r\n\r\n\r\ndef get_torchrun_executable():\r\n    return os.path.dirname(sys.executable) + os.path.sep + \"torchrun\"\r\n\r\n\r\nclass RdzvConfig:\r\n    def __init__(self, id: int, port: int):\r\n        self.port = port\r\n        self.id = id\r\n\r\n\r\ndef run_standalone_config_based_training_script(\r\n        training_script_file_name: str,\r\n        config_file_name: str,\r\n        num_proc_per_node: int,\r\n        target_checkpoint_examples: Optional[int] = None,\r\n        rdzv_config: Optional[RdzvConfig] = None):\r\n    command = f\"{get_torchrun_executable()} \" \\\r\n              f\"--nnodes=1 \" \\\r\n              f\"--nproc_per_node={num_proc_per_node} \"\r\n    if rdzv_config is not None:\r\n        command += f\"--rdzv_endpoint=localhost:{rdzv_config.port} \"\r\n        command += \"--rdzv_backend=c10d \"\r\n        command += f\"--rdzv_id={rdzv_config.id} \"\r\n    else:\r\n        command += \"--standalone \"\r\n    command += f\"{training_script_file_name} \"\r\n    if target_checkpoint_examples is not None:\r\n        command += f\"--target_checkpoint_examples {target_checkpoint_examples} \"\r\n        command += f\"--config_file={config_file_name} \"\r\n    logging.info(f\"Executing -- {command}\")\r\n    os.system(command)\r\n\r\n\r\ndef define_standalone_config_based_training_tasks(\r\n        workspace: Workspace,\r\n        distributed_trainer_func: Callable[[], DistributedTrainer],\r\n        training_script_file_name: str,\r\n        config_file_name: str,\r\n        num_proc_per_node: int,\r\n        dependencies: Optional[List[str]] = None,\r\n        rdzv_config: Optional[RdzvConfig] = None):\r\n    trainer = distributed_trainer_func()\r\n    checkpoint_examples = trainer.training_protocol.get_checkpoint_examples()\r\n    assert len(checkpoint_examples) >= 1\r\n    assert checkpoint_examples[0] > 0\r\n    checkpoint_examples = [0] + checkpoint_examples\r\n\r\n    if dependencies is None:\r\n        dependencies = []\r\n    module_file_dependencies = dependencies[:]\r\n    for module_name in trainer.pretrained_module_file_names:\r\n        module_file_dependencies.append(trainer.pretrained_module_file_names[module_name])\r\n\r\n    def create_train_func(target_checkpoint_examples: int):\r\n        return lambda: run_standalone_config_based_training_script(\r\n            training_script_file_name,\r\n            config_file_name,\r\n            num_proc_per_node,\r\n            target_checkpoint_examples,\r\n            rdzv_config=rdzv_config)\r\n\r\n    train_tasks = []\r\n    for checkpoint_index in range(0, len(checkpoint_examples)):\r\n        for module_name in trainer.module_names:\r\n            module_file_name = DistributedTrainingState.get_module_file_name(\r\n                trainer.get_checkpoint_prefix(checkpoint_index),\r\n                module_name)\r\n            workspace.create_file_task(\r\n                module_file_name,\r\n                module_file_dependencies,\r\n                create_train_func(trainer.checkpoint_examples[checkpoint_index]))\r\n        for module_name in trainer.accumulators:\r\n            accumulated_module_file_name = 
DistributedTrainingState.get_accumulated_module_file_name(\r\n                trainer.get_checkpoint_prefix(checkpoint_index),\r\n                module_name)\r\n            workspace.create_file_task(\r\n                accumulated_module_file_name,\r\n                module_file_dependencies,\r\n                create_train_func(checkpoint_examples[checkpoint_index]))\r\n        workspace.create_command_task(\r\n            trainer.get_checkpoint_prefix(checkpoint_index) + \"/train_standalone\",\r\n            module_file_dependencies,\r\n            create_train_func(checkpoint_examples[checkpoint_index]))\r\n        train_tasks.append(trainer.get_checkpoint_prefix(checkpoint_index) + \"/train_standlone\")\r\n    workspace.create_file_task(\r\n        trainer.prefix + \"/train_standalone\",\r\n        module_file_dependencies,\r\n        create_train_func(checkpoint_examples[-1]))\r\n"
  },
  {
    "path": "src/tha4/distiller/distill_body_morpher.py",
    "content": "import logging\r\n\r\nfrom tha4.shion.core.training.distrib.distributed_trainer import DistributedTrainer\r\nfrom tha4.distiller.distiller_config import DistillerConfig\r\n\r\nif __name__ == \"__main__\":\r\n    logging.basicConfig(level=logging.INFO)\r\n\r\n    parser = DistributedTrainer.get_default_arg_parser()\r\n    parser.add_argument('--config_file', type=str)\r\n    args = parser.parse_args()\r\n\r\n    config_file_name = args.config_file\r\n    config = DistillerConfig.load(config_file_name)\r\n\r\n    DistributedTrainer.run_with_args(config.get_body_morpher_trainer, args)\r\n"
  },
  {
    "path": "src/tha4/distiller/distill_face_morpher.py",
    "content": "import logging\r\n\r\nfrom tha4.shion.core.training.distrib.distributed_trainer import DistributedTrainer\r\nfrom tha4.distiller.distiller_config import DistillerConfig\r\n\r\nif __name__ == \"__main__\":\r\n    logging.basicConfig(level=logging.INFO)\r\n\r\n    parser = DistributedTrainer.get_default_arg_parser()\r\n    parser.add_argument('--config_file', type=str)\r\n    args = parser.parse_args()\r\n\r\n    config_file_name = args.config_file\r\n    config = DistillerConfig.load(config_file_name)\r\n\r\n    DistributedTrainer.run_with_args(config.get_face_morpher_trainer, args)\r\n"
  },
  {
    "path": "src/tha4/distiller/distiller_config.py",
    "content": "import os.path\r\nimport shutil\r\nimport PIL.Image\r\nfrom dataclasses import dataclass\r\nfrom typing import Optional\r\n\r\nfrom omegaconf import OmegaConf\r\nfrom tha4.charmodel.character_model import CharacterModel\r\nfrom tha4.pytasuku.workspace import Workspace, file_task\r\nfrom tha4.distiller.config_based_training_tasks import define_standalone_config_based_training_tasks\r\nfrom tha4.nn.siren.face_morpher.siren_face_morpher_00_trainer import SirenFaceMorpher00TrainerArgs\r\nfrom tha4.nn.siren.morpher.siren_morpher_03_trainer import SirenMorpher03TrainerArgs, TrainingPhases, TrainingPhase, \\\r\n    LossWeights, LossTerm\r\nfrom tha4.shion.base.image_util import pil_image_has_transparency\r\n\r\nPOSE_DATASET_FILE_NAME = 'data/pose_dataset.pt'\r\n\r\n\r\ndef copy_file(source_file_name: str, dest_file_name):\r\n    os.makedirs(os.path.dirname(dest_file_name), exist_ok=True)\r\n    shutil.copyfile(source_file_name, dest_file_name)\r\n\r\n\r\n@dataclass\r\nclass DistillerConfig:\r\n    prefix: str\r\n    character_image_file_name: str\r\n    face_mask_image_file_name: str\r\n\r\n    face_morpher_random_seed_0: int = 12771885812175595441\r\n    face_morpher_random_seed_1: int = 14367217090963479175\r\n    face_morpher_num_training_examples_per_sample_output: Optional[int] = 10_000\r\n    face_morpher_batch_size: int = 8\r\n\r\n    body_morpher_random_seed_0: int = 2892221210020292507\r\n    body_morpher_random_seed_1: int = 9998918537095922080\r\n    body_morpher_num_training_examples_per_sample_output: Optional[int] = 10_000\r\n    body_morpher_batch_size: int = 8\r\n\r\n    num_cpu_workers: int = 1\r\n    num_gpus: int = 1\r\n\r\n    def check(self):\r\n        DistillerConfig.check_prefix(self.prefix)\r\n        DistillerConfig.check_character_image_file_name(self.character_image_file_name)\r\n        DistillerConfig.check_face_mask_image_file_name(self.face_mask_image_file_name)\r\n\r\n        DistillerConfig.check_num_cpu_workers(self.num_cpu_workers)\r\n        DistillerConfig.check_num_gpus(self.num_gpus)\r\n\r\n        DistillerConfig.check_random_seed(self.face_morpher_random_seed_0, \"face_morpher_random_seed_0\")\r\n        DistillerConfig.check_random_seed(self.face_morpher_random_seed_1, \"face_morpher_random_seed_1\")\r\n        DistillerConfig.check_batch_size(self.face_morpher_batch_size, \"face_morpher_batch_size\")\r\n        DistillerConfig.check_num_training_examples_per_sample_output(\r\n            self.face_morpher_num_training_examples_per_sample_output,\r\n            \"face_morpher_num_training_examples_per_sample_output\")\r\n\r\n        DistillerConfig.check_random_seed(self.body_morpher_random_seed_0, \"body_morpher_random_seed_0\")\r\n        DistillerConfig.check_random_seed(self.body_morpher_random_seed_1, \"body_morpher_random_seed_1\")\r\n        DistillerConfig.check_batch_size(self.body_morpher_batch_size, \"body_morpher_batch_size\")\r\n        DistillerConfig.check_num_training_examples_per_sample_output(\r\n            self.body_morpher_num_training_examples_per_sample_output,\r\n            \"body_morpher_num_training_examples_per_sample_output\")\r\n\r\n    @staticmethod\r\n    def check_prefix(prefix):\r\n        assert os.path.isdir(prefix), \"The 'prefix' must be a directory.\"\r\n        assert os.path.exists(prefix), f\"The {prefix} directory does not exist.\"\r\n\r\n    @staticmethod\r\n    def check_character_image_file_name(file_name):\r\n        _, ext = os.path.splitext(file_name)\r\n        assert 
os.path.isfile(file_name), \\\r\n            f\"The specified character image file name, {file_name}, does not point to a file.\"\r\n        assert ext.lower() == \".png\", \"The character image file name must have extension '.png'.\"\r\n\r\n        image = PIL.Image.open(file_name)\r\n        assert pil_image_has_transparency(image), \"The character image must have an alpha channel.\"\r\n        assert image.width == 512 and image.height == 512, \"The character image must be 512x512.\"\r\n        image.close()\r\n\r\n    @staticmethod\r\n    def check_face_mask_image_file_name(file_name):\r\n        _, ext = os.path.splitext(file_name)\r\n        assert os.path.isfile(file_name), \\\r\n            f\"The specified face mask image file name, {file_name}, does not point to a file.\"\r\n        assert ext.lower() == \".png\", \"The face mask image file name must have extension '.png'.\"\r\n\r\n        image = PIL.Image.open(file_name)\r\n        assert image.width == 512 and image.height == 512, \"The face mask image must be 512x512.\"\r\n        assert image.mode == \"RGB\", \"The face mask image must be an RGB image.\"\r\n        for x in range(512):\r\n            for y in range(512):\r\n                r, g, b = image.getpixel((x, y))\r\n                assert (r == 0) or (r == 255), \"The R channel of the face mask image must be 0 or 255\"\r\n                assert (g == 0) or (g == 255), \"The G channel of the face mask image must be 0 or 255\"\r\n                assert (b == 0) or (b == 255), \"The B channel of the face mask image must be 0 or 255\"\r\n        image.close()\r\n\r\n    @staticmethod\r\n    def check_batch_size(value, field_name: str):\r\n        assert isinstance(value, int), f\"The {field_name} must be an integer.\"\r\n        assert value >= 1, f\"The {field_name} must be at least 1.\"\r\n        assert value <= 8, f\"The {field_name} must be at most 8.\"\r\n\r\n    @staticmethod\r\n    def check_num_cpu_workers(value):\r\n        assert value >= 1, \"The value of 'num_cpu_workers must be at least 1.\"\r\n\r\n    @staticmethod\r\n    def check_num_gpus(value):\r\n        assert value >= 1, \"The value of 'num_gpus' must be at least 1.\"\r\n\r\n    @staticmethod\r\n    def check_random_seed(value, field_name: str):\r\n        assert isinstance(value, int), f\"The {field_name} must be an integer.\"\r\n        assert value >= 0 and value <= 0x_ffff_ffff_ffff_ffff, \"A random seed must be between 0 and 2**64-1.\"\r\n\r\n    @staticmethod\r\n    def check_num_training_examples_per_sample_output(value, field_name):\r\n        assert value in [10_000, 100_000, 1_000_000,\r\n                         None], f\"The {field_name} must be 10_000, 100_00, 1_000_000_000, or None.\"\r\n\r\n    def save(self, file_name: str):\r\n        conf = OmegaConf.structured(self)\r\n        os.makedirs(self.prefix, exist_ok=True)\r\n        with open(file_name, \"wt\") as fout:\r\n            fout.write(OmegaConf.to_yaml(conf))\r\n\r\n    def config_yaml_file_name(self):\r\n        return f\"{self.prefix}/config.yaml\"\r\n\r\n    def create_config_yaml_file(self):\r\n        if os.path.exists(self.config_yaml_file_name()):\r\n            return\r\n        self.save(self.config_yaml_file_name())\r\n\r\n    @staticmethod\r\n    def load(file_name: str) -> 'DistillerConfig':\r\n        conf = OmegaConf.to_container(OmegaConf.load(file_name))\r\n        args = DistillerConfig(**conf)\r\n        args.check()\r\n        return args\r\n\r\n    def face_morpher_prefix(self):\r\n        return 
f\"{self.prefix}/face_morpher\"\r\n\r\n    def get_face_morpher_trainer(self, world_size: Optional[int] = None, backend: str = 'gloo'):\r\n        if world_size is None:\r\n            world_size = self.num_gpus\r\n        args = SirenFaceMorpher00TrainerArgs(\r\n            character_file_name=self.character_image_file_name,\r\n            face_mask_file_name=self.face_mask_image_file_name,\r\n            pose_dataset_file_name=POSE_DATASET_FILE_NAME,\r\n            total_worker=self.num_cpu_workers,\r\n            num_training_examples_per_sample_output=self.face_morpher_num_training_examples_per_sample_output,\r\n            total_batch_size=self.face_morpher_batch_size,\r\n            training_random_seed=self.face_morpher_random_seed_0,\r\n            sample_output_random_seed=self.face_morpher_random_seed_1)\r\n        return args.create_trainer(self.face_morpher_prefix(), world_size, backend)\r\n\r\n    def body_morpher_prefix(self):\r\n        return f\"{self.prefix}/body_morpher\"\r\n\r\n    def get_body_morpher_trainer(self, world_size: Optional[int] = None, backend: str = 'gloo'):\r\n        if world_size is None:\r\n            world_size = self.num_gpus\r\n        args = SirenMorpher03TrainerArgs(\r\n            character_file_name=self.character_image_file_name,\r\n            pose_dataset_file_name=POSE_DATASET_FILE_NAME,\r\n            total_worker=self.num_cpu_workers,\r\n            num_training_examples_per_sample_output=self.body_morpher_num_training_examples_per_sample_output,\r\n            training_random_seed=self.body_morpher_random_seed_0,\r\n            sample_output_random_seed=self.body_morpher_random_seed_1,\r\n            total_batch_size=self.body_morpher_batch_size,\r\n            sample_output_batch_size=1,\r\n            training_phases=TrainingPhases([\r\n                TrainingPhase(\r\n                    num_examples_upper_bound=200_000,\r\n                    learning_rate=1e-4,\r\n                    loss_weights=LossWeights(weights={\r\n                        LossTerm.full_blended: 0.25,\r\n                        LossTerm.full_warped: 0.25,\r\n                        LossTerm.full_grid_change: 0.5,\r\n                        LossTerm.full_color_change: 2.0,\r\n                    })),\r\n                TrainingPhase(\r\n                    num_examples_upper_bound=400_000,\r\n                    learning_rate=3e-5,\r\n                    loss_weights=LossWeights(weights={\r\n                        LossTerm.full_blended: 0.25,\r\n                        LossTerm.full_warped: 0.25,\r\n                        LossTerm.full_grid_change: 0.5,\r\n                        LossTerm.full_color_change: 2.0,\r\n                    })),\r\n                TrainingPhase(\r\n                    num_examples_upper_bound=600_000,\r\n                    learning_rate=3e-5,\r\n                    loss_weights=LossWeights(weights={\r\n                        LossTerm.full_blended: 1.0,\r\n                        LossTerm.full_warped: 2.5,\r\n                        LossTerm.full_grid_change: 5.0,\r\n                        LossTerm.full_color_change: 1.0,\r\n                    })),\r\n                TrainingPhase(\r\n                    num_examples_upper_bound=800_000,\r\n                    learning_rate=1e-5,\r\n                    loss_weights=LossWeights(weights={\r\n                        LossTerm.full_blended: 1.0,\r\n                        LossTerm.full_warped: 2.5,\r\n                        LossTerm.full_grid_change: 5.0,\r\n                        
LossTerm.full_color_change: 1.0,\r\n                    })),\r\n                TrainingPhase(\r\n                    num_examples_upper_bound=1_300_000,\r\n                    learning_rate=1e-5,\r\n                    loss_weights=LossWeights(weights={\r\n                        LossTerm.full_blended: 10.0,\r\n                        LossTerm.full_warped: 1.0,\r\n                        LossTerm.full_grid_change: 1.0,\r\n                        LossTerm.full_color_change: 1.0,\r\n                    })),\r\n                TrainingPhase(\r\n                    num_examples_upper_bound=1_500_000,\r\n                    learning_rate=3e-6,\r\n                    loss_weights=LossWeights(weights={\r\n                        LossTerm.full_blended: 10.0,\r\n                        LossTerm.full_warped: 1.0,\r\n                        LossTerm.full_grid_change: 1.0,\r\n                        LossTerm.full_color_change: 1.0,\r\n                    })),\r\n            ]))\r\n        return args.create_trainer(self.body_morpher_prefix(), world_size, backend)\r\n\r\n    def character_model_prefix(self):\r\n        return f\"{self.prefix}/character_model\"\r\n\r\n    def character_model_face_morpher_file_name(self):\r\n        return f\"{self.character_model_prefix()}/face_morpher.pt\"\r\n\r\n    def character_model_body_morpher_file_name(self):\r\n        return f\"{self.character_model_prefix()}/body_morpher.pt\"\r\n\r\n    def character_model_character_png_file_name(self):\r\n        return f\"{self.character_model_prefix()}/character.png\"\r\n\r\n    def character_model_yaml_file_name(self):\r\n        return f\"{self.character_model_prefix()}/character_model.yaml\"\r\n\r\n    def define_tasks(self, workspace: Workspace):\r\n        workspace.create_file_task(self.config_yaml_file_name(), [], self.create_config_yaml_file)\r\n\r\n        define_standalone_config_based_training_tasks(\r\n            workspace,\r\n            self.get_face_morpher_trainer,\r\n            \"src/tha4/distiller/distill_face_morpher.py\",\r\n            self.config_yaml_file_name(),\r\n            num_proc_per_node=self.num_gpus,\r\n            dependencies=[\r\n                self.config_yaml_file_name(),\r\n            ])\r\n\r\n        define_standalone_config_based_training_tasks(\r\n            workspace,\r\n            self.get_body_morpher_trainer,\r\n            \"src/tha4/distiller/distill_body_morpher.py\",\r\n            self.config_yaml_file_name(),\r\n            num_proc_per_node=self.num_gpus,\r\n            dependencies=[\r\n                self.config_yaml_file_name(),\r\n            ])\r\n\r\n        @file_task(workspace, self.character_model_character_png_file_name(), [self.character_image_file_name])\r\n        def copy_character_image_file_name():\r\n            copy_file(self.character_image_file_name, self.character_model_character_png_file_name())\r\n\r\n        @file_task(workspace, self.character_model_face_morpher_file_name(), [\r\n            f\"{self.face_morpher_prefix()}/checkpoint/0010/module_module.pt\",\r\n        ])\r\n        def copy_face_morpher():\r\n            copy_file(\r\n                f\"{self.face_morpher_prefix()}/checkpoint/0010/module_module.pt\",\r\n                self.character_model_face_morpher_file_name())\r\n\r\n        @file_task(workspace, self.character_model_body_morpher_file_name(), [\r\n            f\"{self.body_morpher_prefix()}/checkpoint/0015/module_module.pt\",\r\n        ])\r\n        def copy_body_morpher():\r\n            copy_file(\r\n            
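# Assumption: 0015 is the last checkpoint the six body-morpher phases above produce (1.5M examples at 100K examples per checkpoint), just as 0010 is for the face morpher.\r\n            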
    f\"{self.body_morpher_prefix()}/checkpoint/0015/module_module.pt\",\r\n                self.character_model_body_morpher_file_name())\r\n\r\n        @file_task(workspace, self.character_model_yaml_file_name(), [])\r\n        def create_character_model_yaml_file():\r\n            character_model = CharacterModel(\r\n                self.character_model_character_png_file_name(),\r\n                self.character_model_face_morpher_file_name(),\r\n                self.character_model_body_morpher_file_name())\r\n            character_model.save(self.character_model_yaml_file_name())\r\n\r\n        workspace.create_command_task(\r\n            f\"{self.prefix}/all\",\r\n            [\r\n                f\"{self.face_morpher_prefix()}/train_standalone\",\r\n                f\"{self.body_morpher_prefix()}/train_standalone\",\r\n                self.character_model_character_png_file_name(),\r\n                self.character_model_face_morpher_file_name(),\r\n                self.character_model_body_morpher_file_name(),\r\n                self.character_model_yaml_file_name(),\r\n            ])\r\n"
  },
  {
    "path": "src/tha4/distiller/ui/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/distiller/ui/distiller_config_state.py",
    "content": "import os.path\r\nfrom contextlib import contextmanager\r\nfrom pathlib import PurePath, Path\r\nfrom typing import Callable, Any, Optional\r\n\r\nfrom tha4.distiller.distiller_config import DistillerConfig\r\n\r\n\r\nclass DistillerConfigState:\r\n    def __init__(self):\r\n        self.config = DistillerConfig(prefix=\"\", character_image_file_name=\"\", face_mask_image_file_name=\"\")\r\n        self.last_saved_timestamp = None\r\n        self.dirty = False\r\n\r\n    def load(self, file_name):\r\n        self.config = DistillerConfig.load(file_name)\r\n        if os.path.exists(self.config.config_yaml_file_name()):\r\n            self.last_saved_timestamp = os.path.getmtime(self.config.config_yaml_file_name())\r\n        else:\r\n            self.last_saved_timestamp = None\r\n        self.dirty = False\r\n\r\n    def need_to_check_overwrite(self):\r\n        if self.last_saved_timestamp is None:\r\n            return True\r\n        if not os.path.exists(self.config.config_yaml_file_name()):\r\n            return False\r\n        if self.last_saved_timestamp < os.path.getmtime(self.config.config_yaml_file_name()):\r\n            return True\r\n        return False\r\n\r\n    def save(self):\r\n        self.config.save(self.config.config_yaml_file_name())\r\n        self.dirty = False\r\n        self.last_saved_timestamp = os.path.getmtime(self.config.config_yaml_file_name())\r\n\r\n    @contextmanager\r\n    def updating_value(self, value_func: Callable[[], Any]):\r\n        old_value = value_func()\r\n        yield\r\n        new_value = value_func()\r\n        if new_value != old_value:\r\n            self.dirty = True\r\n\r\n    def set_prefix(self, new_value):\r\n        with self.updating_value(lambda: self.config.prefix):\r\n            new_relative_path = self.get_relative_path_to_cwd(\r\n                new_value,\r\n                \"The prefix directory must be a subdirectory of the talking-head-anime-4-demo's source code directory.\")\r\n            DistillerConfig.check_prefix(new_relative_path)\r\n            self.config.prefix = new_relative_path\r\n\r\n    def set_character_image_file_name(self, new_value):\r\n        with self.updating_value(lambda: self.config.character_image_file_name):\r\n            new_relative_path = self.get_relative_path_to_cwd(\r\n                new_value,\r\n                \"The character image file must be under talking-head-anime-4-demo's source code directory.\")\r\n            DistillerConfig.check_character_image_file_name(new_relative_path)\r\n            self.config.character_image_file_name = new_relative_path\r\n\r\n    def set_face_mask_image_file_name(self, new_value):\r\n        with self.updating_value(lambda: self.config.face_mask_image_file_name):\r\n            new_relative_path = self.get_relative_path_to_cwd(\r\n                new_value,\r\n                \"The face mask image file must be under talking-head-anime-4-demo's source code directory.\")\r\n            DistillerConfig.check_face_mask_image_file_name(new_relative_path)\r\n            self.config.face_mask_image_file_name = new_relative_path\r\n\r\n    def set_num_cpu_workers(self, new_value: int):\r\n        with self.updating_value(lambda: self.config.num_cpu_workers):\r\n            DistillerConfig.check_num_cpu_workers(new_value)\r\n            self.config.num_cpu_workers = new_value\r\n\r\n    def set_num_gpus(self, new_value: int):\r\n        with self.updating_value(lambda: self.config.num_gpus):\r\n            
DistillerConfig.check_num_gpus(new_value)\r\n            self.config.num_gpus = new_value\r\n\r\n    def set_face_morpher_random_seed_0(self, new_value: int):\r\n        with self.updating_value(lambda: self.config.face_morpher_random_seed_0):\r\n            DistillerConfig.check_random_seed(new_value, \"face_morpher_random_seed_0\")\r\n            self.config.face_morpher_random_seed_0 = new_value\r\n\r\n    def set_face_morpher_random_seed_1(self, new_value: int):\r\n        with self.updating_value(lambda: self.config.face_morpher_random_seed_1):\r\n            DistillerConfig.check_random_seed(new_value, \"face_morpher_random_seed_1\")\r\n            self.config.face_morpher_random_seed_1 = new_value\r\n\r\n    def set_face_morpher_num_training_examples_per_sample_output(self, new_value: Optional[int]):\r\n        with self.updating_value(lambda: self.config.face_morpher_num_training_examples_per_sample_output):\r\n            DistillerConfig.check_num_training_examples_per_sample_output(\r\n                new_value, \"face_morpher_num_training_examples_per_sample_output\")\r\n            self.config.face_morpher_num_training_examples_per_sample_output = new_value\r\n\r\n    def set_face_morpher_batch_size(self, new_value: int):\r\n        with self.updating_value(lambda: self.config.face_morpher_batch_size):\r\n            DistillerConfig.check_batch_size(new_value, \"face_morpher_batch_size\")\r\n            self.config.face_morpher_batch_size = new_value\r\n\r\n    def set_body_morpher_random_seed_0(self, new_value: int):\r\n        with self.updating_value(lambda: self.config.body_morpher_random_seed_0):\r\n            DistillerConfig.check_random_seed(new_value, \"body_morpher_random_seed_0\")\r\n            self.config.body_morpher_random_seed_0 = new_value\r\n\r\n    def set_body_morpher_random_seed_1(self, new_value: int):\r\n        with self.updating_value(lambda: self.config.body_morpher_random_seed_1):\r\n            DistillerConfig.check_random_seed(new_value, \"body_morpher_random_seed_1\")\r\n            self.config.body_morpher_random_seed_1 = new_value\r\n\r\n    def set_body_morpher_num_training_examples_per_sample_output(self, new_value: Optional[int]):\r\n        with self.updating_value(lambda: self.config.body_morpher_num_training_examples_per_sample_output):\r\n            DistillerConfig.check_num_training_examples_per_sample_output(\r\n                new_value, \"body_morpher_num_training_examples_per_sample_output\")\r\n            self.config.body_morpher_num_training_examples_per_sample_output = new_value\r\n\r\n    def set_body_morpher_batch_size(self, new_value: int):\r\n        with self.updating_value(lambda: self.config.body_morpher_batch_size):\r\n            DistillerConfig.check_batch_size(new_value, \"body_morpher_batch_size\")\r\n            self.config.body_morpher_batch_size = new_value\r\n\r\n    def get_relative_path_to_cwd(self, file_name: str, message: str):\r\n        cwd = os.getcwd()\r\n        assert os.path.commonprefix([cwd, file_name]) == cwd, message\r\n        cwd_path = Path(cwd).as_posix()\r\n        new_path = Path(file_name).as_posix()\r\n        new_relative_path = os.path.relpath(str(new_path), cwd_path)\r\n        new_relative_path = str(Path(new_relative_path).as_posix())\r\n        return new_relative_path\r\n\r\n    def can_show_character_image(self):\r\n        return os.path.isfile(self.config.character_image_file_name)\r\n\r\n    def can_show_face_mask_image(self):\r\n        return 
os.path.isfile(self.config.face_mask_image_file_name)\r\n\r\n    def can_show_mask_on_face_image(self):\r\n        return self.can_show_character_image() and self.can_show_face_mask_image()\r\n\r\n    def can_save(self):\r\n        return os.path.isdir(self.config.prefix) \\\r\n            and os.path.isfile(self.config.character_image_file_name) \\\r\n            and os.path.isfile(self.config.face_mask_image_file_name)\r\n"
  },
  {
    "path": "src/tha4/distiller/ui/distiller_ui_main_frame.py",
    "content": "import multiprocessing\r\nimport random\r\nfrom contextlib import contextmanager\r\nfrom typing import Callable\r\nimport PIL.Image\r\n\r\nimport torch\r\nimport wx\r\nimport wx.html\r\nimport wx.lib.intctrl\r\nfrom tha4.distiller.ui.distiller_config_state import DistillerConfigState\r\nfrom tha4.image_util import convert_output_image_from_torch_to_numpy\r\nfrom tha4.shion.base.image_util import extract_pytorch_image_from_PIL_image\r\n\r\n\r\ndef wx_bind_event(widget, evt):\r\n    def f(handler):\r\n        widget.Bind(evt, handler)\r\n        return handler\r\n\r\n    return f\r\n\r\n\r\nclass DistillerUiMainFrame(wx.Frame):\r\n    PARAM_NAME_STATIC_TEXT_MIN_WIDTH = 400\r\n    NUM_TRAINING_EXAMPLES_PER_SAMPLE_OUTPUT_CHOICES = [\r\n        \"10_000\", \"100_000\", \"1_000_000\", \"Do not generate sample outputs\"]\r\n\r\n    def __init__(self):\r\n        super().__init__(None, wx.ID_ANY, \"Distiller UI\")\r\n\r\n        self.init_ui()\r\n        self.init_menus()\r\n        self.init_bitmaps()\r\n        self.Bind(wx.EVT_CLOSE, self.on_close)\r\n\r\n        self.state = DistillerConfigState()\r\n        self.update_ui()\r\n\r\n        self.config_file_to_run = None\r\n\r\n    def init_ui(self):\r\n        main_sizer = wx.BoxSizer(wx.HORIZONTAL)\r\n\r\n        self.SetSizer(main_sizer)\r\n        self.SetAutoLayout(1)\r\n\r\n        left_panel = self.init_left_panel(self)\r\n        main_sizer.Add(left_panel, 0, wx.FIXED_MINSIZE)\r\n\r\n        middle_panel = self.init_middle_panel(self)\r\n        main_sizer.Add(middle_panel, 0, wx.EXPAND)\r\n\r\n        right_panel = self.init_right_panel(self)\r\n        main_sizer.Add(right_panel, 1, wx.EXPAND)\r\n\r\n        main_sizer.Fit(self)\r\n\r\n    def init_menus(self):\r\n        self.file_menu = wx.Menu()\r\n\r\n        self.new_menu_id = wx.Window.NewControlId()\r\n        self.file_menu.Append(\r\n            self.new_menu_id, item=\"&New\\tCTRL+N\", helpString=\"Create a new distiller configuration.\")\r\n        self.Bind(wx.EVT_MENU, self.on_new, id=self.new_menu_id)\r\n\r\n        self.open_menu_id = wx.Window.NewControlId()\r\n        self.file_menu.Append(\r\n            self.open_menu_id, item=\"&Open\\tCTRL+O\", helpString=\"Open a distiller confuguration.\")\r\n        self.Bind(wx.EVT_MENU, self.on_open, id=self.open_menu_id)\r\n\r\n        self.save_menu_id = wx.Window.NewControlId()\r\n        self.save_menu_item = wx.MenuItem(\r\n            self.file_menu, id=self.save_menu_id, text=\"&Save\\tCTRL+S\",\r\n            helpString=\"Save the current distiller configuration. 
Error message will be shown it it is not well formed.\")\r\n        self.Bind(wx.EVT_MENU, self.on_save, id=self.save_menu_id)\r\n        self.file_menu.Append(self.save_menu_item)\r\n\r\n        self.file_menu.AppendSeparator()\r\n\r\n        self.exit_menu_id = wx.ID_EXIT\r\n        self.file_menu.Append(\r\n            self.exit_menu_id, item=\"E&xit\\tCTRL+Q\", helpString=\"Exit the application.\")\r\n        self.Bind(wx.EVT_MENU, self.on_close, id=self.exit_menu_id)\r\n\r\n        self.menu_bar = wx.MenuBar()\r\n        self.menu_bar.Append(self.file_menu, \"&File\")\r\n\r\n        self.SetMenuBar(self.menu_bar)\r\n\r\n    def init_bitmaps(self):\r\n        self.face_image_bitmap = wx.Bitmap(128, 128)\r\n        self.face_image_pytorch = None\r\n        self.face_mask_image_bitmap = wx.Bitmap(128, 128)\r\n        self.face_mask_image_pytorch = None\r\n        self.mask_on_face_image_bitmap = wx.Bitmap(128, 128)\r\n        self.draw_nothing_yet_string_to_bitmap(self.face_image_bitmap, 128, 128)\r\n        self.draw_nothing_yet_string_to_bitmap(self.face_mask_image_bitmap, 128, 128)\r\n        self.draw_nothing_yet_string_to_bitmap(self.mask_on_face_image_bitmap, 128, 128)\r\n\r\n    @contextmanager\r\n    def create_panel(self, parent, sizer, *args, **kwargs):\r\n        panel = wx.Panel(parent, *args, **kwargs)\r\n        panel.SetSizer(sizer)\r\n        panel.SetAutoLayout(1)\r\n\r\n        try:\r\n            yield panel, sizer\r\n        finally:\r\n            sizer.Fit(panel)\r\n\r\n    def init_left_panel(self, parent):\r\n        with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, sizer):\r\n            self.face_image_panel = wx.Panel(panel, size=(128, 128), style=wx.SIMPLE_BORDER)\r\n            self.face_image_panel.Bind(wx.EVT_PAINT, self.on_face_image_panel_paint)\r\n            sizer.Add(self.face_image_panel, 0, wx.EXPAND)\r\n\r\n            static_text = wx.StaticText(panel, label=\"Face\", style=wx.ALIGN_CENTER)\r\n            sizer.Add(static_text, 0, wx.EXPAND)\r\n\r\n            self.face_mask_image_panel = wx.Panel(panel, size=(128, 128), style=wx.SIMPLE_BORDER)\r\n            self.face_mask_image_panel.Bind(wx.EVT_PAINT, self.on_face_mask_image_panel_paint)\r\n            sizer.Add(self.face_mask_image_panel, 0, wx.EXPAND)\r\n\r\n            static_text = wx.StaticText(panel, label=\"Face mask\", style=wx.ALIGN_CENTER)\r\n            sizer.Add(static_text, 0, wx.EXPAND)\r\n\r\n            self.mask_on_face_image_panel = wx.Panel(panel, size=(128, 128), style=wx.SIMPLE_BORDER)\r\n            self.mask_on_face_image_panel.Bind(wx.EVT_PAINT, self.on_mask_on_face_image_panel_paint)\r\n            sizer.Add(self.mask_on_face_image_panel, 0, wx.EXPAND)\r\n\r\n            static_text = wx.StaticText(panel, label=\"Mask upon face\", style=wx.ALIGN_CENTER)\r\n            sizer.Add(static_text, 0, wx.EXPAND)\r\n\r\n        return panel\r\n\r\n    def on_erase_background(self, event):\r\n        pass\r\n\r\n    def on_face_image_panel_paint(self, event):\r\n        wx.BufferedPaintDC(self.face_image_panel, self.face_image_bitmap)\r\n\r\n    def on_face_mask_image_panel_paint(self, event):\r\n        wx.BufferedPaintDC(self.face_mask_image_panel, self.face_mask_image_bitmap)\r\n\r\n    def on_mask_on_face_image_panel_paint(self, event):\r\n        wx.BufferedPaintDC(self.mask_on_face_image_panel, self.mask_on_face_image_bitmap)\r\n\r\n    def init_middle_panel(self, parent):\r\n        with self.create_panel(parent, 
wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, sizer):\r\n            sizer.Add(self.init_prefix_panel(panel), 0, wx.EXPAND)\r\n            sizer.Add(self.init_character_image_file_name_panel(panel), 0, wx.EXPAND)\r\n            sizer.Add(self.init_face_mask_image_file_name_panel(panel), 0, wx.EXPAND)\r\n            sizer.Add(self.init_num_cpu_workers_panel(panel), 0, wx.EXPAND)\r\n            sizer.Add(self.init_num_gpus_panel(panel), 0, wx.EXPAND)\r\n            sizer.Add(self.init_face_morpher_random_seed_0_panel(panel), 0, wx.EXPAND)\r\n            sizer.Add(self.init_face_morpher_random_seed_1_panel(panel), 0, wx.EXPAND)\r\n            sizer.Add(self.init_face_morpher_batch_size_panel(panel), 0, wx.EXPAND)\r\n            sizer.Add(self.init_body_morpher_random_seed_0_panel(panel), 0, wx.EXPAND)\r\n            sizer.Add(self.init_body_morpher_random_seed_1_panel(panel), 0, wx.EXPAND)\r\n            sizer.Add(self.init_body_morpher_batch_size_panel(panel), 0, wx.EXPAND)\r\n            sizer.Add(self.init_num_training_examples_per_sample_output_panel(panel), 0, wx.EXPAND)\r\n\r\n            self.run_button = wx.Button(panel, label=\"RUN\")\r\n            self.run_button.SetMinSize((-1, 64))\r\n            self.run_button.Bind(wx.EVT_BUTTON, self.on_run)\r\n            sizer.Add(self.run_button, 1, wx.EXPAND)\r\n\r\n        return panel\r\n\r\n    def init_prefix_panel(self, parent):\r\n        with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer):\r\n            prefix_param_name_panel = self.create_param_name_panel_with_help_button(\r\n                panel,\r\n                \"prefix (i.e. project directory)\",\r\n                self.create_help_button_func(\"distiller-ui-doc/params/prefix.html\"))\r\n            panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND)\r\n\r\n            with self.create_panel(panel, wx.BoxSizer(wx.HORIZONTAL), style=wx.BORDER_NONE) \\\r\n                    as (prefix_panel, prefix_sizer):\r\n                self.prefix_text_ctrl = wx.TextCtrl(prefix_panel, value=\"\")\r\n                self.prefix_text_ctrl.SetEditable(False)\r\n                prefix_sizer.Add(self.prefix_text_ctrl, 1, wx.EXPAND)\r\n\r\n                self.prefix_change_button = wx.Button(prefix_panel, label=\"Change...\")\r\n                self.prefix_change_button.Bind(wx.EVT_BUTTON, self.on_prefix_change_button)\r\n                prefix_sizer.Add(self.prefix_change_button, 0, wx.EXPAND)\r\n            panel_sizer.Add(prefix_panel, 1, wx.EXPAND)\r\n\r\n        return panel\r\n\r\n    def on_prefix_change_button(self, event):\r\n        dir_dialog = wx.DirDialog(self, \"Choose a directory.\", style=wx.DD_DEFAULT_STYLE | wx.DD_NEW_DIR_BUTTON)\r\n        if dir_dialog.ShowModal() != wx.ID_OK:\r\n            return\r\n        prefix_value = dir_dialog.GetPath()\r\n        try:\r\n            self.state.set_prefix(prefix_value)\r\n            self.update_ui()\r\n        except Exception as e:\r\n            message_dialog = wx.MessageDialog(self, str(e), \"Error\", wx.OK | wx.ICON_ERROR)\r\n            message_dialog.ShowModal()\r\n\r\n    def init_character_image_file_name_panel(self, parent):\r\n        with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer):\r\n            prefix_param_name_panel = self.create_param_name_panel_with_help_button(\r\n                panel,\r\n                \"character_image_file_name\",\r\n                
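# Assumption: create_help_button_func, defined elsewhere in this class, wires the help button to the bundled HTML page for this parameter.\r\n                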
self.create_help_button_func(\"distiller-ui-doc/params/character_image_file_name.html\"))\r\n            panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND)\r\n\r\n            with self.create_panel(panel, wx.BoxSizer(wx.HORIZONTAL), style=wx.BORDER_NONE) as (sub_panel, sub_sizer):\r\n                self.character_image_file_name_text_ctrl = wx.TextCtrl(sub_panel, value=\"\")\r\n                self.character_image_file_name_text_ctrl.SetEditable(False)\r\n                sub_sizer.Add(self.character_image_file_name_text_ctrl, 1, wx.EXPAND)\r\n\r\n                self.character_image_change_button = wx.Button(sub_panel, label=\"Change...\")\r\n                self.character_image_change_button.Bind(wx.EVT_BUTTON, self.on_character_image_change_button)\r\n                sub_sizer.Add(self.character_image_change_button, 0, wx.EXPAND)\r\n            panel_sizer.Add(sub_panel, 1, wx.EXPAND)\r\n\r\n        return panel\r\n\r\n    def on_character_image_change_button(self, event):\r\n        file_dialog = wx.FileDialog(self, \"Choose a PNG file\", wildcard=\"*.png\", style=wx.FD_OPEN)\r\n        if file_dialog.ShowModal() != wx.ID_OK:\r\n            return\r\n        file_name = file_dialog.GetPath()\r\n        try:\r\n            self.state.set_character_image_file_name(file_name)\r\n            self.update_face_image_bitmap(file_name)\r\n            self.update_ui()\r\n        except Exception as e:\r\n            message_dialog = wx.MessageDialog(self, str(e), \"Error\", wx.OK | wx.ICON_ERROR)\r\n            message_dialog.ShowModal()\r\n\r\n    def update_face_image_bitmap(self, new_file_name: str):\r\n        pil_image = PIL.Image.open(new_file_name)\r\n        subimage = pil_image.crop((256 - 64, 80, 256 + 64, 208))\r\n        self.face_image_bitmap = wx.Bitmap.FromBufferRGBA(128, 128, subimage.convert(\"RGBA\").tobytes())\r\n        self.face_image_pytorch = extract_pytorch_image_from_PIL_image(subimage).to(torch.float)\r\n        self.update_mask_on_face_image_bitmap()\r\n\r\n    def init_face_mask_image_file_name_panel(self, parent):\r\n        with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer):\r\n            prefix_param_name_panel = self.create_param_name_panel_with_help_button(\r\n                panel,\r\n                \"face_mask_image_file_name\",\r\n                self.create_help_button_func(\"distiller-ui-doc/params/face_mask_image_file_name.html\"))\r\n            panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND)\r\n\r\n            with self.create_panel(panel, wx.BoxSizer(wx.HORIZONTAL), style=wx.BORDER_NONE) as (sub_panel, sub_sizer):\r\n                self.face_mask_image_file_name_text_ctrl = wx.TextCtrl(sub_panel, value=\"\")\r\n                self.face_mask_image_file_name_text_ctrl.SetEditable(False)\r\n                sub_sizer.Add(self.face_mask_image_file_name_text_ctrl, 1, wx.EXPAND)\r\n\r\n                self.face_mask_image_file_name_change_button = wx.Button(sub_panel, label=\"Change...\")\r\n                self.face_mask_image_file_name_change_button.Bind(wx.EVT_BUTTON, self.on_face_mask_image_change_button)\r\n                sub_sizer.Add(self.face_mask_image_file_name_change_button, 0, wx.EXPAND)\r\n            panel_sizer.Add(sub_panel, 1, wx.EXPAND)\r\n\r\n        return panel\r\n\r\n    def on_face_mask_image_change_button(self, event):\r\n        file_dialog = wx.FileDialog(self, \"Choose a PNG file\", wildcard=\"*.png\", style=wx.FD_OPEN)\r\n        if file_dialog.ShowModal() != 
wx.ID_OK:\r\n            return\r\n        file_name = file_dialog.GetPath()\r\n        try:\r\n            self.state.set_face_mask_image_file_name(file_name)\r\n            self.update_face_mask_image_bitmap(file_name)\r\n            self.update_ui()\r\n        except Exception as e:\r\n            message_dialog = wx.MessageDialog(self, str(e), \"Error\", wx.OK | wx.ICON_ERROR)\r\n            message_dialog.ShowModal()\r\n\r\n    def update_face_mask_image_bitmap(self, new_file_name):\r\n        pil_image = PIL.Image.open(new_file_name)\r\n        subimage = pil_image.crop((256 - 64, 80, 256 + 64, 208))\r\n        self.face_mask_image_bitmap = wx.Bitmap.FromBufferRGBA(128, 128, subimage.convert(\"RGBA\").tobytes())\r\n        self.face_mask_image_pytorch = extract_pytorch_image_from_PIL_image(subimage).to(torch.float)\r\n        self.face_mask_image_pytorch = self.face_mask_image_pytorch[0:1, :, :]\r\n        self.update_mask_on_face_image_bitmap()\r\n\r\n    def update_mask_on_face_image_bitmap(self):\r\n        if self.face_image_pytorch is None:\r\n            return\r\n        if self.face_mask_image_pytorch is None:\r\n            return\r\n\r\n        mask_on_face_image = (0.5 * self.face_image_pytorch) + (0.5 * self.face_mask_image_pytorch)\r\n        numpy_image = convert_output_image_from_torch_to_numpy(mask_on_face_image)\r\n        # wx.ImageFromBuffer expects (width, height); numpy images are indexed (height, width).\r\n        wx_image = wx.ImageFromBuffer(\r\n            numpy_image.shape[1],\r\n            numpy_image.shape[0],\r\n            numpy_image[:, :, 0:3].tobytes(),\r\n            numpy_image[:, :, 3].tobytes())\r\n        self.mask_on_face_image_bitmap = wx_image.ConvertToBitmap()\r\n\r\n    def init_num_cpu_workers_panel(self, parent):\r\n        with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer):\r\n            prefix_param_name_panel = self.create_param_name_panel_with_help_button(\r\n                panel,\r\n                \"num_cpu_workers\",\r\n                self.create_help_button_func(\"distiller-ui-doc/params/num_cpu_workers.html\"))\r\n            panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND)\r\n\r\n            num_cpus = multiprocessing.cpu_count()\r\n            self.num_cpu_workers_spin_ctrl = wx.SpinCtrl(panel, initial=1, min=1, max=num_cpus)\r\n\r\n            @wx_bind_event(self.num_cpu_workers_spin_ctrl, wx.EVT_SPINCTRL)\r\n            def on_num_cpu_workers_spin_ctrl(event):\r\n                self.state.set_num_cpu_workers(self.num_cpu_workers_spin_ctrl.GetValue())\r\n                self.Refresh()\r\n\r\n            panel_sizer.Add(self.num_cpu_workers_spin_ctrl, 1, wx.EXPAND)\r\n\r\n        return panel\r\n\r\n    def init_num_gpus_panel(self, parent):\r\n        with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer):\r\n            prefix_param_name_panel = self.create_param_name_panel_with_help_button(\r\n                panel,\r\n                \"num_gpus\",\r\n                self.create_help_button_func(\"distiller-ui-doc/params/num_gpus.html\"))\r\n            panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND)\r\n\r\n            num_gpus = torch.cuda.device_count()\r\n            self.num_gpus_spin_ctrl = wx.SpinCtrl(panel, initial=1, min=1, max=max(1, num_gpus))\r\n\r\n            @wx_bind_event(self.num_gpus_spin_ctrl, wx.EVT_SPINCTRL)\r\n            def on_num_gpus_spin_ctrl(event):\r\n                self.state.set_num_gpus(self.num_gpus_spin_ctrl.GetValue())\r\n                self.Refresh()\r\n\r\n  
          panel_sizer.Add(self.num_gpus_spin_ctrl, 1, wx.EXPAND)\r\n\r\n        return panel\r\n\r\n    def init_face_morpher_random_seed_0_panel(self, parent):\r\n        with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer):\r\n            prefix_param_name_panel = self.create_param_name_panel_with_help_button(\r\n                panel,\r\n                \"face_morpher_random_seed_0\",\r\n                self.create_help_button_func(\"distiller-ui-doc/params/face_morpher_random_seed_0.html\"))\r\n            panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND)\r\n\r\n            with self.create_panel(panel, wx.BoxSizer(wx.HORIZONTAL), style=wx.BORDER_NONE) as (sub_panel, sub_sizer):\r\n                initial_value = random.randint(0, 2 ** 64 - 1)\r\n                self.face_morpher_random_seed_0_int_ctrl = wx.lib.intctrl.IntCtrl(\r\n                    sub_panel, value=initial_value, min=0, max=0x_ffff_ffff_ffff_ffff)\r\n\r\n                @wx_bind_event(self.face_morpher_random_seed_0_int_ctrl, wx.EVT_TEXT)\r\n                def on_face_morpher_random_seed_0_int_ctrl_text(event):\r\n                    self.state.set_face_morpher_random_seed_0(self.face_morpher_random_seed_0_int_ctrl.GetValue())\r\n\r\n                sub_sizer.Add(self.face_morpher_random_seed_0_int_ctrl, 1, wx.EXPAND)\r\n\r\n                self.face_morpher_random_seed_0_randomize_button = wx.Button(sub_panel, label=\"Randomize\")\r\n\r\n                @wx_bind_event(self.face_morpher_random_seed_0_randomize_button, wx.EVT_BUTTON)\r\n                def on_face_morpher_random_seed_0_randomize_button(event):\r\n                    new_value = random.randint(0, 0x_ffff_ffff_ffff_ffff)\r\n                    self.face_morpher_random_seed_0_int_ctrl.SetValue(new_value)\r\n                    self.state.set_face_morpher_random_seed_0(new_value)\r\n\r\n                sub_sizer.Add(self.face_morpher_random_seed_0_randomize_button, 0, wx.EXPAND)\r\n            panel_sizer.Add(sub_panel, 1, wx.EXPAND)\r\n\r\n        return panel\r\n\r\n    def init_face_morpher_random_seed_1_panel(self, parent):\r\n        with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer):\r\n            prefix_param_name_panel = self.create_param_name_panel_with_help_button(\r\n                panel,\r\n                \"face_morpher_random_seed_1\",\r\n                self.create_help_button_func(\"distiller-ui-doc/params/face_morpher_random_seed_1.html\"))\r\n            panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND)\r\n\r\n            with self.create_panel(panel, wx.BoxSizer(wx.HORIZONTAL), style=wx.BORDER_NONE) as (sub_panel, sub_sizer):\r\n                initial_value = random.randint(0, 2 ** 64 - 1)\r\n                self.face_morpher_random_seed_1_int_ctrl = wx.lib.intctrl.IntCtrl(\r\n                    sub_panel, value=initial_value, min=0, max=0x_ffff_ffff_ffff_ffff)\r\n\r\n                @wx_bind_event(self.face_morpher_random_seed_1_int_ctrl, wx.EVT_TEXT)\r\n                def on_face_morpher_random_seed_1_int_ctrl_text(event):\r\n                    self.state.set_face_morpher_random_seed_1(self.face_morpher_random_seed_1_int_ctrl.GetValue())\r\n\r\n                sub_sizer.Add(self.face_morpher_random_seed_1_int_ctrl, 1, wx.EXPAND)\r\n\r\n                self.face_morpher_random_seed_1_randomize_button = wx.Button(sub_panel, label=\"Randomize\")\r\n\r\n                
@wx_bind_event(self.face_morpher_random_seed_1_randomize_button, wx.EVT_BUTTON)\r\n                def on_face_morpher_random_seed_1_randomize_button(event):\r\n                    new_value = random.randint(0, 0x_ffff_ffff_ffff_ffff)\r\n                    self.face_morpher_random_seed_1_int_ctrl.SetValue(new_value)\r\n                    self.state.set_face_morpher_random_seed_1(new_value)\r\n\r\n                sub_sizer.Add(self.face_morpher_random_seed_1_randomize_button, 0, wx.EXPAND)\r\n            panel_sizer.Add(sub_panel, 1, wx.EXPAND)\r\n\r\n        return panel\r\n\r\n    def init_face_morpher_batch_size_panel(self, parent):\r\n        with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer):\r\n            prefix_param_name_panel = self.create_param_name_panel_with_help_button(\r\n                panel,\r\n                \"face_morpher_batch_size\",\r\n                self.create_help_button_func(\"distiller-ui-doc/params/face_morpher_batch_size.html\"))\r\n            panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND)\r\n\r\n            self.face_morpher_batch_size_spin_ctrl = wx.SpinCtrl(panel, initial=8, min=1, max=8)\r\n\r\n            @wx_bind_event(self.face_morpher_batch_size_spin_ctrl, wx.EVT_SPINCTRL)\r\n            def on_face_morpher_batch_size_spin_ctrl(event):\r\n                self.state.set_face_morpher_batch_size(self.face_morpher_batch_size_spin_ctrl.GetValue())\r\n\r\n            panel_sizer.Add(self.face_morpher_batch_size_spin_ctrl, 1, wx.EXPAND)\r\n\r\n        return panel\r\n\r\n    def init_body_morpher_random_seed_0_panel(self, parent):\r\n        with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer):\r\n            prefix_param_name_panel = self.create_param_name_panel_with_help_button(\r\n                panel,\r\n                \"body_morpher_random_seed_0\",\r\n                self.create_help_button_func(\"distiller-ui-doc/params/body_morpher_random_seed_0.html\"))\r\n            panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND)\r\n\r\n            with self.create_panel(panel, wx.BoxSizer(wx.HORIZONTAL), style=wx.BORDER_NONE) as (sub_panel, sub_sizer):\r\n                initial_value = random.randint(0, 2 ** 64 - 1)\r\n                self.body_morpher_random_seed_0_int_ctrl = wx.lib.intctrl.IntCtrl(\r\n                    sub_panel, value=initial_value, min=0, max=0x_ffff_ffff_ffff_ffff)\r\n\r\n                @wx_bind_event(self.body_morpher_random_seed_0_int_ctrl, wx.EVT_TEXT)\r\n                def on_body_morpher_random_seed_0_int_ctrl_text(event):\r\n                    self.state.set_body_morpher_random_seed_0(self.body_morpher_random_seed_0_int_ctrl.GetValue())\r\n\r\n                sub_sizer.Add(self.body_morpher_random_seed_0_int_ctrl, 1, wx.EXPAND)\r\n\r\n                self.body_morpher_random_seed_0_randomize_button = wx.Button(sub_panel, label=\"Randomize\")\r\n\r\n                @wx_bind_event(self.body_morpher_random_seed_0_randomize_button, wx.EVT_BUTTON)\r\n                def on_body_morpher_random_seed_0_randomize_button(event):\r\n                    new_value = random.randint(0, 0x_ffff_ffff_ffff_ffff)\r\n                    self.body_morpher_random_seed_0_int_ctrl.SetValue(new_value)\r\n                    self.state.set_body_morpher_random_seed_0(new_value)\r\n\r\n                sub_sizer.Add(self.body_morpher_random_seed_0_randomize_button, 0, wx.EXPAND)\r\n            panel_sizer.Add(sub_panel, 1, 
wx.EXPAND)\r\n\r\n        return panel\r\n\r\n    def init_body_morpher_random_seed_1_panel(self, parent):\r\n        with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer):\r\n            prefix_param_name_panel = self.create_param_name_panel_with_help_button(\r\n                panel,\r\n                \"body_morpher_random_seed_1\",\r\n                self.create_help_button_func(\"distiller-ui-doc/params/body_morpher_random_seed_1.html\"))\r\n            panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND)\r\n\r\n            with self.create_panel(panel, wx.BoxSizer(wx.HORIZONTAL), style=wx.BORDER_NONE) as (sub_panel, sub_sizer):\r\n                initial_value = random.randint(0, 2 ** 64 - 1)\r\n                self.body_morpher_random_seed_1_int_ctrl = wx.lib.intctrl.IntCtrl(\r\n                    sub_panel, value=initial_value, min=0, max=0x_ffff_ffff_ffff_ffff)\r\n\r\n                @wx_bind_event(self.body_morpher_random_seed_1_int_ctrl, wx.EVT_TEXT)\r\n                def on_body_morpher_random_seed_1_int_ctrl_text(event):\r\n                    self.state.set_body_morpher_random_seed_1(self.body_morpher_random_seed_1_int_ctrl.GetValue())\r\n\r\n                sub_sizer.Add(self.body_morpher_random_seed_1_int_ctrl, 1, wx.EXPAND)\r\n\r\n                self.body_morpher_random_seed_1_randomize_button = wx.Button(sub_panel, label=\"Randomize\")\r\n\r\n                @wx_bind_event(self.body_morpher_random_seed_1_randomize_button, wx.EVT_BUTTON)\r\n                def on_body_morpher_random_seed_1_randomize_button(event):\r\n                    new_value = random.randint(0, 0x_ffff_ffff_ffff_ffff)\r\n                    self.body_morpher_random_seed_1_int_ctrl.SetValue(new_value)\r\n                    self.state.set_body_morpher_random_seed_1(new_value)\r\n\r\n                sub_sizer.Add(self.body_morpher_random_seed_1_randomize_button, 0, wx.EXPAND)\r\n            panel_sizer.Add(sub_panel, 1, wx.EXPAND)\r\n\r\n        return panel\r\n\r\n    def init_body_morpher_batch_size_panel(self, parent):\r\n        with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer):\r\n            prefix_param_name_panel = self.create_param_name_panel_with_help_button(\r\n                panel,\r\n                \"body_morpher_batch_size\",\r\n                self.create_help_button_func(\"distiller-ui-doc/params/body_morpher_batch_size.html\"))\r\n            panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND)\r\n\r\n            self.body_morpher_batch_size_spin_ctrl = wx.SpinCtrl(panel, initial=8, min=1, max=8)\r\n\r\n            @wx_bind_event(self.body_morpher_batch_size_spin_ctrl, wx.EVT_SPINCTRL)\r\n            def on_body_morpher_batch_size_spin_ctrl(event):\r\n                self.state.set_body_morpher_batch_size(self.body_morpher_batch_size_spin_ctrl.GetValue())\r\n\r\n            panel_sizer.Add(self.body_morpher_batch_size_spin_ctrl, 1, wx.EXPAND)\r\n\r\n        return panel\r\n\r\n    def init_num_training_examples_per_sample_output_panel(self, parent):\r\n        with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer):\r\n            prefix_param_name_panel = self.create_param_name_panel_with_help_button(\r\n                panel,\r\n                \"num_training_examples_per_sample_output\",\r\n                self.create_help_button_func(\"distiller-ui-doc/params/num_training_examples_per_sample_output.html\"))\r\n            
panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND)\r\n\r\n            self.num_training_examples_per_sample_output_combobox = \\\r\n                wx.ComboBox(panel,\r\n                            value=\"10_000\",\r\n                            choices=DistillerUiMainFrame.NUM_TRAINING_EXAMPLES_PER_SAMPLE_OUTPUT_CHOICES)\r\n\r\n            @wx_bind_event(self.num_training_examples_per_sample_output_combobox, wx.EVT_COMBOBOX)\r\n            def on_num_training_examples_per_sample_output_combobox(event):\r\n                index = self.num_training_examples_per_sample_output_combobox.GetSelection()\r\n                if index == 3:\r\n                    self.state.set_face_morpher_num_training_examples_per_sample_output(None)\r\n                    self.state.set_body_morpher_num_training_examples_per_sample_output(None)\r\n                else:\r\n                    selected = DistillerUiMainFrame.NUM_TRAINING_EXAMPLES_PER_SAMPLE_OUTPUT_CHOICES[index]\r\n                    new_value = int(selected)\r\n                    self.state.set_face_morpher_num_training_examples_per_sample_output(new_value)\r\n                    self.state.set_body_morpher_num_training_examples_per_sample_output(new_value)\r\n\r\n            panel_sizer.Add(self.num_training_examples_per_sample_output_combobox, 1, wx.EXPAND)\r\n\r\n        return panel\r\n\r\n    def on_close(self, event):\r\n        if self.state.dirty:\r\n            confirmation_dialog = wx.MessageDialog(\r\n                parent=self,\r\n                message=f\"You have not saved your work. Do you want to exit anyway?\",\r\n                caption=\"Confirmation\",\r\n                style=wx.YES_NO | wx.ICON_QUESTION)\r\n            result = confirmation_dialog.ShowModal()\r\n            if result == wx.ID_NO:\r\n                return\r\n\r\n        self.Destroy()\r\n\r\n    def create_help_button_func(self, html_file_name: str):\r\n        def init_help_button_func(parent):\r\n            button = wx.Button(parent, label=\"Help\")\r\n\r\n            @wx_bind_event(button, wx.EVT_BUTTON)\r\n            def on_prefix_button(event):\r\n                self.html_window.LoadPage(html_file_name)\r\n                self.Refresh()\r\n\r\n            return button\r\n\r\n        return init_help_button_func\r\n\r\n    def create_param_name_panel_with_help_button(\r\n            self, parent, param_name: str, help_button_func: Callable[[wx.Window], wx.Button]):\r\n        with self.create_panel(parent, wx.BoxSizer(wx.HORIZONTAL), style=wx.NO_BORDER) \\\r\n                as (panel, sizer):\r\n            title_text_panel = self.create_vertically_centered_text_panel(\r\n                panel, param_name, DistillerUiMainFrame.PARAM_NAME_STATIC_TEXT_MIN_WIDTH)\r\n            sizer.Add(title_text_panel, 1, wx.EXPAND)\r\n\r\n            help_button = help_button_func(panel)\r\n            sizer.Add(help_button, 0, wx.EXPAND)\r\n        return panel\r\n\r\n    def create_vertically_centered_text_panel(self, parent, text: str, min_width: int):\r\n        with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.NO_BORDER) as (panel, sizer):\r\n            sizer.AddStretchSpacer(1)\r\n            text = wx.StaticText(\r\n                panel,\r\n                label=text,\r\n                style=wx.ALIGN_CENTER)\r\n            text.SetMinSize((min_width, -1))\r\n            sizer.Add(text, 0, wx.EXPAND)\r\n            sizer.AddStretchSpacer(1)\r\n        return panel\r\n\r\n    def init_right_panel(self, parent):\r\n        with 
self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, sizer):\r\n            self.html_window = wx.html.HtmlWindow(panel)\r\n            self.html_window.SetMinSize((600, 600))\r\n            self.html_window.SetFonts(\"Times New Roman\", \"Courier New\", sizes=[10, 12, 14, 16, 18, 20, 24])\r\n            self.html_window.LoadPage(\"distiller-ui-doc/index.html\")\r\n            sizer.Add(self.html_window, 1, wx.EXPAND)\r\n\r\n            go_to_main_documentation_button = wx.Button(panel, label=\"Go to Main Documentation\")\r\n            sizer.Add(go_to_main_documentation_button, 0, wx.EXPAND)\r\n\r\n            @wx_bind_event(go_to_main_documentation_button, wx.EVT_BUTTON)\r\n            def on_go_to_main_documentation_button(event):\r\n                self.html_window.LoadPage(\"distiller-ui-doc/index.html\")\r\n                self.Refresh()\r\n\r\n        return panel\r\n\r\n    def populate_distiller_config(self):\r\n        self.state.config.prefix = self.prefix_text_ctrl.GetValue()\r\n        self.state.config.character_image_file_name = self.character_image_file_name_text_ctrl.GetValue()\r\n        self.state.config.face_mask_image_file_name = self.face_mask_image_file_name_text_ctrl.GetValue()\r\n\r\n        self.state.config.num_cpu_workers = self.num_cpu_workers_spin_ctrl.GetValue()\r\n        self.state.config.num_gpus = self.num_gpus_spin_ctrl.GetValue()\r\n\r\n        self.state.config.face_morpher_random_seed_0 = self.face_morpher_random_seed_0_int_ctrl.GetValue()\r\n        self.state.config.face_morpher_random_seed_1 = self.face_morpher_random_seed_1_int_ctrl.GetValue()\r\n        self.state.config.face_morpher_batch_size = self.face_morpher_batch_size_spin_ctrl.GetValue()\r\n\r\n        self.state.config.body_morpher_random_seed_0 = self.body_morpher_random_seed_0_int_ctrl.GetValue()\r\n        self.state.config.body_morpher_random_seed_1 = self.body_morpher_random_seed_1_int_ctrl.GetValue()\r\n        self.state.config.body_morpher_batch_size = self.body_morpher_batch_size_spin_ctrl.GetValue()\r\n\r\n        if self.num_training_examples_per_sample_output_combobox.GetValue() == \\\r\n                DistillerUiMainFrame.NUM_TRAINING_EXAMPLES_PER_SAMPLE_OUTPUT_CHOICES[-1]:\r\n            self.state.config.face_morpher_num_training_examples_per_sample_output = None\r\n            self.state.config.body_morpher_num_training_examples_per_sample_output = None\r\n        else:\r\n            value = int(self.num_training_examples_per_sample_output_combobox.GetValue())\r\n            self.state.config.face_morpher_num_training_examples_per_sample_output = value\r\n            self.state.config.body_morpher_num_training_examples_per_sample_output = value\r\n\r\n    def update_ui(self):\r\n        self.prefix_text_ctrl.SetValue(self.state.config.prefix)\r\n        self.character_image_file_name_text_ctrl.SetValue(self.state.config.character_image_file_name)\r\n        self.face_mask_image_file_name_text_ctrl.SetValue(self.state.config.face_mask_image_file_name)\r\n\r\n        if not self.state.can_show_character_image():\r\n            self.draw_nothing_yet_string_to_bitmap(self.face_image_bitmap, 128, 128)\r\n        if not self.state.can_show_face_mask_image():\r\n            self.draw_nothing_yet_string_to_bitmap(self.face_mask_image_bitmap, 128, 128)\r\n        if not self.state.can_show_mask_on_face_image():\r\n            self.draw_nothing_yet_string_to_bitmap(self.mask_on_face_image_bitmap, 128, 128)\r\n\r\n        
self.num_cpu_workers_spin_ctrl.SetValue(self.state.config.num_cpu_workers)\r\n        self.num_gpus_spin_ctrl.SetValue(self.state.config.num_gpus)\r\n\r\n        self.face_morpher_random_seed_0_int_ctrl.SetValue(self.state.config.face_morpher_random_seed_0)\r\n        self.face_morpher_random_seed_1_int_ctrl.SetValue(self.state.config.face_morpher_random_seed_1)\r\n        self.face_morpher_batch_size_spin_ctrl.SetValue(self.state.config.face_morpher_batch_size)\r\n\r\n        self.body_morpher_random_seed_0_int_ctrl.SetValue(self.state.config.body_morpher_random_seed_0)\r\n        self.body_morpher_random_seed_1_int_ctrl.SetValue(self.state.config.body_morpher_random_seed_1)\r\n        self.body_morpher_batch_size_spin_ctrl.SetValue(self.state.config.body_morpher_batch_size)\r\n\r\n        if self.state.config.body_morpher_num_training_examples_per_sample_output is None:\r\n            self.num_training_examples_per_sample_output_combobox.SetSelection(3)\r\n        else:\r\n            choices = [int(x) for x in DistillerUiMainFrame.NUM_TRAINING_EXAMPLES_PER_SAMPLE_OUTPUT_CHOICES[:-1]]\r\n            self.num_training_examples_per_sample_output_combobox.SetSelection(\r\n                choices.index(self.state.config.body_morpher_num_training_examples_per_sample_output))\r\n\r\n        self.save_menu_item.Enable(self.state.can_save())\r\n\r\n        self.Refresh()\r\n\r\n    def draw_nothing_yet_string_to_bitmap(self, bitmap, width: int, height: int):\r\n        dc = wx.MemoryDC()\r\n        dc.SelectObject(bitmap)\r\n\r\n        dc.Clear()\r\n        font = wx.Font(wx.FontInfo(14).Family(wx.FONTFAMILY_SWISS))\r\n        dc.SetFont(font)\r\n        w, h = dc.GetTextExtent(\"Nothing yet!\")\r\n        dc.DrawText(\"Nothing yet!\", (width - w) // 2, (height - h) // 2)\r\n\r\n        del dc\r\n\r\n    def try_saving(self):\r\n        if not self.state.can_save():\r\n            message_dialog = wx.MessageDialog(\r\n                self,\r\n                \"Cannot save yet! Please make sure you set the prefix, character_image_file_name, \"\r\n                \"and face_mask_image_file_name first.\",\r\n                \"Error\",\r\n                wx.OK | wx.ICON_ERROR)\r\n            message_dialog.ShowModal()\r\n            return False\r\n        else:\r\n            if self.state.need_to_check_overwrite():\r\n                confirmation_dialog = wx.MessageDialog(\r\n                    parent=self,\r\n                    message=f\"Overwrite {self.state.config.config_yaml_file_name()}?\",\r\n                    caption=\"Confirmation\",\r\n                    style=wx.YES_NO | wx.CANCEL | wx.ICON_QUESTION)\r\n                result = confirmation_dialog.ShowModal()\r\n                if result == wx.ID_YES:\r\n                    self.state.save()\r\n                    return True\r\n                else:\r\n                    # Both \"No\" and \"Cancel\" abort the save.\r\n                    return False\r\n            else:\r\n                self.state.save()\r\n                return True\r\n\r\n    def on_save(self, event):\r\n        return self.try_saving()\r\n\r\n    def on_new(self, event):\r\n        if self.state.dirty:\r\n            confirmation_dialog = wx.MessageDialog(\r\n                parent=self,\r\n                message=f\"You have not saved the current config. 
Do you want to proceed?\",\r\n                caption=\"Confirmation\",\r\n                style=wx.YES_NO | wx.ICON_QUESTION)\r\n            result = confirmation_dialog.ShowModal()\r\n            if result == wx.ID_NO:\r\n                return\r\n        self.state = DistillerConfigState()\r\n        self.update_ui()\r\n\r\n    def on_open(self, event):\r\n        if self.state.dirty:\r\n            confirmation_dialog = wx.MessageDialog(\r\n                parent=self,\r\n                message=f\"You have not saved the current config. Do you want to proceed?\",\r\n                caption=\"Confirmation\",\r\n                style=wx.YES_NO | wx.ICON_QUESTION)\r\n            result = confirmation_dialog.ShowModal()\r\n            if result == wx.ID_NO:\r\n                return\r\n\r\n        file_dialog = wx.FileDialog(self, \"Choose a YAML file\", wildcard=\"*.yaml\", style=wx.FD_OPEN)\r\n        if file_dialog.ShowModal() != wx.ID_OK:\r\n            return\r\n        file_name = file_dialog.GetPath()\r\n        try:\r\n            self.state.load(file_name)\r\n            self.face_image_pytorch = None\r\n            self.face_mask_image_pytorch = None\r\n            self.update_face_image_bitmap(self.state.config.character_image_file_name)\r\n            self.update_face_mask_image_bitmap(self.state.config.face_mask_image_file_name)\r\n            self.update_ui()\r\n        except Exception as e:\r\n            message_dialog = wx.MessageDialog(self, str(e), \"Error\", wx.OK | wx.ICON_ERROR)\r\n            message_dialog.ShowModal()\r\n\r\n    def on_run(self, event):\r\n        try:\r\n            self.state.config.check()\r\n        except Exception as e:\r\n            message_dialog = wx.MessageDialog(self, str(e), \"Error\", wx.OK | wx.ICON_ERROR)\r\n            message_dialog.ShowModal()\r\n            return\r\n\r\n        if self.state.dirty:\r\n            message_dialog = wx.MessageDialog(\r\n                self,\r\n                \"Please save the configuration first.\",\r\n                \"Error\",\r\n                wx.OK | wx.ICON_ERROR)\r\n            message_dialog.ShowModal()\r\n            return\r\n\r\n        self.config_file_to_run = self.state.config.config_yaml_file_name()\r\n        self.Destroy()\r\n"
  },
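  {
    "path": "examples/wx_bind_event_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original project.\r\n\r\ndistiller_ui.py binds event handlers with a @wx_bind_event(control, event)\r\ndecorator. The real helper is defined elsewhere in this repository; a\r\nfunctionally equivalent version is sketched here to document the pattern.\r\n\"\"\"\r\n\r\nimport wx\r\n\r\n\r\ndef wx_bind_event(widget, event):\r\n    # Bind the decorated function as the handler for `event` on `widget`,\r\n    # then return it unchanged so it can still be called directly.\r\n    def decorator(func):\r\n        widget.Bind(event, func)\r\n        return func\r\n\r\n    return decorator\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    app = wx.App()\r\n    frame = wx.Frame(None, title=\"wx_bind_event demo\")\r\n    button = wx.Button(frame, label=\"Click me\")\r\n\r\n    @wx_bind_event(button, wx.EVT_BUTTON)\r\n    def on_button(event):\r\n        print(\"clicked\")\r\n\r\n    frame.Show()\r\n    app.MainLoop()\r\n"
  },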
  {
    "path": "src/tha4/image_util.py",
    "content": "import math\r\n\r\nimport PIL.Image\r\nimport numpy\r\nimport torch\r\nfrom matplotlib import cm\r\nfrom tha4.shion.base.image_util import numpy_linear_to_srgb, pytorch_rgba_to_numpy_image, pytorch_rgb_to_numpy_image, \\\r\n    torch_linear_to_srgb\r\n\r\n\r\ndef grid_change_to_numpy_image(torch_image, num_channels=3):\r\n    height = torch_image.shape[1]\r\n    width = torch_image.shape[2]\r\n    size_image = (torch_image[0, :, :] ** 2 + torch_image[1, :, :] ** 2).sqrt().view(height, width, 1).numpy()\r\n    hsv = cm.get_cmap('hsv')\r\n    angle_image = hsv(((torch.atan2(\r\n        torch_image[0, :, :].view(height * width),\r\n        torch_image[1, :, :].view(height * width)).view(height, width) + math.pi) / (2 * math.pi)).numpy()) * 3\r\n    numpy_image = size_image * angle_image[:, :, 0:3]\r\n    rgb_image = numpy_linear_to_srgb(numpy_image)\r\n    if num_channels == 3:\r\n        return rgb_image\r\n    elif num_channels == 4:\r\n        return numpy.concatenate([rgb_image, numpy.ones_like(size_image)], axis=2)\r\n    else:\r\n        raise RuntimeError(\"Unsupported num_channels: \" + str(num_channels))\r\n\r\n\r\ndef resize_PIL_image(pil_image, size=(256, 256)):\r\n    w, h = pil_image.size\r\n    d = min(w, h)\r\n    r = ((w - d) // 2, (h - d) // 2, (w + d) // 2, (h + d) // 2)\r\n    return pil_image.resize(size, resample=PIL.Image.LANCZOS, box=r)\r\n\r\n\r\ndef convert_output_image_from_torch_to_numpy(output_image):\r\n    if output_image.shape[2] == 2:\r\n        h, w, c = output_image.shape\r\n        numpy_image = torch.transpose(output_image.reshape(h * w, c), 0, 1).reshape(c, h, w)\r\n    elif output_image.shape[0] == 4:\r\n        numpy_image = pytorch_rgba_to_numpy_image(output_image)\r\n    elif output_image.shape[0] == 3:\r\n        numpy_image = pytorch_rgb_to_numpy_image(output_image)\r\n    elif output_image.shape[0] == 1:\r\n        c, h, w = output_image.shape\r\n        alpha_image = torch.cat([output_image.repeat(3, 1, 1) * 2.0 - 1.0, torch.ones(1, h, w)], dim=0)\r\n        numpy_image = pytorch_rgba_to_numpy_image(alpha_image)\r\n    elif output_image.shape[0] == 2:\r\n        numpy_image = grid_change_to_numpy_image(output_image, num_channels=4)\r\n    else:\r\n        raise RuntimeError(\"Unsupported # image channels: %d\" % output_image.shape[0])\r\n    numpy_image = numpy.uint8(numpy.rint(numpy_image * 255.0))\r\n    return numpy_image\r\n\r\n\r\ndef convert_linear_to_srgb(image: torch.Tensor) -> torch.Tensor:\r\n    rgb_image = torch_linear_to_srgb(image[0:3, :, :])\r\n    return torch.cat([rgb_image, image[3:4, :, :]], dim=0)\r\n"
  },
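  {
    "path": "examples/image_util_demo.py",
    "content": "\"\"\"Illustrative usage sketch, not part of the original project.\r\n\r\nShows how resize_PIL_image and convert_output_image_from_torch_to_numpy from\r\nsrc/tha4/image_util.py fit together. The input file name is a placeholder, and\r\nthe [-1, 1] value range is an assumption inferred from the one-channel branch\r\nof the converter.\r\n\"\"\"\r\n\r\nimport PIL.Image\r\nimport torch\r\n\r\nfrom tha4.image_util import convert_output_image_from_torch_to_numpy, resize_PIL_image\r\n\r\nif __name__ == \"__main__\":\r\n    # Center-crop to a square and resample to 256x256 with Lanczos filtering.\r\n    pil_image = resize_PIL_image(PIL.Image.open(\"character.png\").convert(\"RGBA\"))\r\n    print(pil_image.size)  # (256, 256)\r\n\r\n    # Stand-in for a model output: a (4, 256, 256) RGBA tensor with values in [-1, 1].\r\n    torch_image = torch.rand(4, 256, 256) * 2.0 - 1.0\r\n\r\n    # The four-channel branch should yield a uint8 height-by-width-by-4 numpy image.\r\n    numpy_image = convert_output_image_from_torch_to_numpy(torch_image)\r\n    print(numpy_image.shape, numpy_image.dtype)\r\n"
  },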
  {
    "path": "src/tha4/mocap/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/mocap/ifacialmocap_constants.py",
    "content": "EYE_LOOK_IN_LEFT = \"eyeLookInLeft\"\r\nEYE_LOOK_OUT_LEFT = \"eyeLookOutLeft\"\r\nEYE_LOOK_DOWN_LEFT = \"eyeLookDownLeft\"\r\nEYE_LOOK_UP_LEFT = \"eyeLookUpLeft\"\r\nEYE_BLINK_LEFT = \"eyeBlinkLeft\"\r\nEYE_SQUINT_LEFT = \"eyeSquintLeft\"\r\nEYE_WIDE_LEFT = \"eyeWideLeft\"\r\nEYE_LOOK_IN_RIGHT = \"eyeLookInRight\"\r\nEYE_LOOK_OUT_RIGHT = \"eyeLookOutRight\"\r\nEYE_LOOK_DOWN_RIGHT = \"eyeLookDownRight\"\r\nEYE_LOOK_UP_RIGHT = \"eyeLookUpRight\"\r\nEYE_BLINK_RIGHT = \"eyeBlinkRight\"\r\nEYE_SQUINT_RIGHT = \"eyeSquintRight\"\r\nEYE_WIDE_RIGHT = \"eyeWideRight\"\r\nBROW_DOWN_LEFT = \"browDownLeft\"\r\nBROW_OUTER_UP_LEFT = \"browOuterUpLeft\"\r\nBROW_DOWN_RIGHT = \"browDownRight\"\r\nBROW_OUTER_UP_RIGHT = \"browOuterUpRight\"\r\nBROW_INNER_UP = \"browInnerUp\"\r\nNOSE_SNEER_LEFT = \"noseSneerLeft\"\r\nNOSE_SNEER_RIGHT = \"noseSneerRight\"\r\nCHEEK_SQUINT_LEFT = \"cheekSquintLeft\"\r\nCHEEK_SQUINT_RIGHT = \"cheekSquintRight\"\r\nCHEEK_PUFF = \"cheekPuff\"\r\nMOUTH_LEFT = \"mouthLeft\"\r\nMOUTH_DIMPLE_LEFT = \"mouthDimpleLeft\"\r\nMOUTH_FROWN_LEFT = \"mouthFrownLeft\"\r\nMOUTH_LOWER_DOWN_LEFT = \"mouthLowerDownLeft\"\r\nMOUTH_PRESS_LEFT = \"mouthPressLeft\"\r\nMOUTH_SMILE_LEFT = \"mouthSmileLeft\"\r\nMOUTH_STRETCH_LEFT = \"mouthStretchLeft\"\r\nMOUTH_UPPER_UP_LEFT = \"mouthUpperUpLeft\"\r\nMOUTH_RIGHT = \"mouthRight\"\r\nMOUTH_DIMPLE_RIGHT = \"mouthDimpleRight\"\r\nMOUTH_FROWN_RIGHT = \"mouthFrownRight\"\r\nMOUTH_LOWER_DOWN_RIGHT = \"mouthLowerDownRight\"\r\nMOUTH_PRESS_RIGHT = \"mouthPressRight\"\r\nMOUTH_SMILE_RIGHT = \"mouthSmileRight\"\r\nMOUTH_STRETCH_RIGHT = \"mouthStretchRight\"\r\nMOUTH_UPPER_UP_RIGHT = \"mouthUpperUpRight\"\r\nMOUTH_CLOSE = \"mouthClose\"\r\nMOUTH_FUNNEL = \"mouthFunnel\"\r\nMOUTH_PUCKER = \"mouthPucker\"\r\nMOUTH_ROLL_LOWER = \"mouthRollLower\"\r\nMOUTH_ROLL_UPPER = \"mouthRollUpper\"\r\nMOUTH_SHRUG_LOWER = \"mouthShrugLower\"\r\nMOUTH_SHRUG_UPPER = \"mouthShrugUpper\"\r\nJAW_LEFT = \"jawLeft\"\r\nJAW_RIGHT = \"jawRight\"\r\nJAW_FORWARD = \"jawForward\"\r\nJAW_OPEN = \"jawOpen\"\r\nTONGUE_OUT = \"tongueOut\"\r\n\r\nBLENDSHAPE_NAMES = [\r\n    EYE_LOOK_IN_LEFT,  # 0\r\n    EYE_LOOK_OUT_LEFT,  # 1\r\n    EYE_LOOK_DOWN_LEFT,  # 2\r\n    EYE_LOOK_UP_LEFT,  # 3\r\n    EYE_BLINK_LEFT,  # 4\r\n    EYE_SQUINT_LEFT,  # 5\r\n    EYE_WIDE_LEFT,  # 6\r\n    EYE_LOOK_IN_RIGHT,  # 7\r\n    EYE_LOOK_OUT_RIGHT,  # 8\r\n    EYE_LOOK_DOWN_RIGHT,  # 9\r\n    EYE_LOOK_UP_RIGHT,  # 10\r\n    EYE_BLINK_RIGHT,  # 11\r\n    EYE_SQUINT_RIGHT,  # 12\r\n    EYE_WIDE_RIGHT,  # 13\r\n    BROW_DOWN_LEFT,  # 14\r\n    BROW_OUTER_UP_LEFT,  # 15\r\n    BROW_DOWN_RIGHT,  # 16\r\n    BROW_OUTER_UP_RIGHT,  # 17\r\n    BROW_INNER_UP,  # 18\r\n    NOSE_SNEER_LEFT,  # 19\r\n    NOSE_SNEER_RIGHT,  # 20\r\n    CHEEK_SQUINT_LEFT,  # 21\r\n    CHEEK_SQUINT_RIGHT,  # 22\r\n    CHEEK_PUFF,  # 23\r\n    MOUTH_LEFT,  # 24\r\n    MOUTH_DIMPLE_LEFT,  # 25\r\n    MOUTH_FROWN_LEFT,  # 26\r\n    MOUTH_LOWER_DOWN_LEFT,  # 27\r\n    MOUTH_PRESS_LEFT,  # 28\r\n    MOUTH_SMILE_LEFT,  # 29\r\n    MOUTH_STRETCH_LEFT,  # 30\r\n    MOUTH_UPPER_UP_LEFT,  # 31\r\n    MOUTH_RIGHT,  # 32\r\n    MOUTH_DIMPLE_RIGHT,  # 33\r\n    MOUTH_FROWN_RIGHT,  # 34\r\n    MOUTH_LOWER_DOWN_RIGHT,  # 35\r\n    MOUTH_PRESS_RIGHT,  # 36\r\n    MOUTH_SMILE_RIGHT,  # 37\r\n    MOUTH_STRETCH_RIGHT,  # 38\r\n    MOUTH_UPPER_UP_RIGHT,  # 39\r\n    MOUTH_CLOSE,  # 40\r\n    MOUTH_FUNNEL,  # 41\r\n    MOUTH_PUCKER,  # 42\r\n    MOUTH_ROLL_LOWER,  # 43\r\n    MOUTH_ROLL_UPPER,  # 44\r\n    MOUTH_SHRUG_LOWER,  # 45\r\n    MOUTH_SHRUG_UPPER,  # 
46\r\n    JAW_LEFT,  # 47\r\n    JAW_RIGHT,  # 48\r\n    JAW_FORWARD,  # 49\r\n    JAW_OPEN,  # 50\r\n    TONGUE_OUT,  # 51\r\n]\r\n\r\nEYE_LEFT_BLENDSHAPES = [\r\n    EYE_LOOK_IN_LEFT,  # 0\r\n    EYE_LOOK_OUT_LEFT,  # 1\r\n    EYE_LOOK_DOWN_LEFT,  # 2\r\n    EYE_LOOK_UP_LEFT,  # 3\r\n    EYE_BLINK_LEFT,  # 4\r\n    EYE_SQUINT_LEFT,  # 5\r\n    EYE_WIDE_LEFT,  # 6\r\n]\r\n\r\nEYE_RIGHT_BLENDSHAPES = [\r\n    EYE_LOOK_IN_RIGHT,  # 7\r\n    EYE_LOOK_OUT_RIGHT,  # 8\r\n    EYE_LOOK_DOWN_RIGHT,  # 9\r\n    EYE_LOOK_UP_RIGHT,  # 10\r\n    EYE_BLINK_RIGHT,  # 11\r\n    EYE_SQUINT_RIGHT,  # 12\r\n    EYE_WIDE_RIGHT,  # 13\r\n]\r\n\r\nBROW_LEFT_BLENDSHAPES = [\r\n    BROW_DOWN_LEFT,  # 14\r\n    BROW_OUTER_UP_LEFT,  # 15\r\n\r\n]\r\n\r\nBROW_RIGHT_BLENDSHAPES = [\r\n    BROW_DOWN_RIGHT,  # 16\r\n    BROW_OUTER_UP_RIGHT,  # 17\r\n\r\n]\r\n\r\nBROW_BOTH_BLENDSHAPES = [\r\n    BROW_INNER_UP,  # 18\r\n]\r\n\r\nNOSE_BLENDSHAPES = [\r\n    NOSE_SNEER_LEFT,  # 19\r\n    NOSE_SNEER_RIGHT,  # 20\r\n]\r\n\r\nCHECK_BLENDSHAPES = [\r\n    CHEEK_SQUINT_LEFT,  # 21\r\n    CHEEK_SQUINT_RIGHT,  # 22\r\n    CHEEK_PUFF,  # 23\r\n]\r\n\r\nMOUTH_LEFT_BLENDSHAPES = [\r\n    MOUTH_LEFT,  # 24\r\n    MOUTH_DIMPLE_LEFT,  # 25\r\n    MOUTH_FROWN_LEFT,  # 26\r\n    MOUTH_LOWER_DOWN_LEFT,  # 27\r\n    MOUTH_PRESS_LEFT,  # 28\r\n    MOUTH_SMILE_LEFT,  # 29\r\n    MOUTH_STRETCH_LEFT,  # 30\r\n    MOUTH_UPPER_UP_LEFT,  # 31\r\n]\r\n\r\nMOUTH_RIGHT_BLENDSHAPES = [\r\n    MOUTH_RIGHT,  # 32\r\n    MOUTH_DIMPLE_RIGHT,  # 33\r\n    MOUTH_FROWN_RIGHT,  # 34\r\n    MOUTH_LOWER_DOWN_RIGHT,  # 35\r\n    MOUTH_PRESS_RIGHT,  # 36\r\n    MOUTH_SMILE_RIGHT,  # 37\r\n    MOUTH_STRETCH_RIGHT,  # 38\r\n    MOUTH_UPPER_UP_RIGHT,  # 39\r\n]\r\n\r\nMOUTH_BOTH_BLENDSHAPES = [\r\n    MOUTH_CLOSE,  # 40\r\n    MOUTH_FUNNEL,  # 41\r\n    MOUTH_PUCKER,  # 42\r\n    MOUTH_ROLL_LOWER,  # 43\r\n    MOUTH_ROLL_UPPER,  # 44\r\n    MOUTH_SHRUG_LOWER,  # 45\r\n    MOUTH_SHRUG_UPPER,  # 46\r\n]\r\n\r\nJAW_BLENDSHAPES = [\r\n    JAW_LEFT,  # 47\r\n    JAW_RIGHT,  # 48\r\n    JAW_FORWARD,  # 49\r\n    JAW_OPEN,  # 50\r\n]\r\n\r\nTONGUE_BLENDSHAPES = [\r\n    TONGUE_OUT,  # 51\r\n]\r\n\r\nCOLUMN_0_BLENDSHAPES = EYE_RIGHT_BLENDSHAPES + BROW_RIGHT_BLENDSHAPES + [NOSE_SNEER_RIGHT, CHEEK_SQUINT_RIGHT]\r\nCOLUMN_1_BLENDSHAPES = EYE_LEFT_BLENDSHAPES + BROW_LEFT_BLENDSHAPES + [NOSE_SNEER_LEFT, CHEEK_SQUINT_LEFT]\r\nCOLUMN_2_BLENDSHAPES = MOUTH_RIGHT_BLENDSHAPES + [JAW_RIGHT]\r\nCOLUMN_3_BLENDSHAPES = MOUTH_LEFT_BLENDSHAPES + [JAW_LEFT]\r\nCOLUMN_4_BLENDSHAPES = [BROW_INNER_UP, CHEEK_PUFF] + MOUTH_BOTH_BLENDSHAPES + [JAW_FORWARD, JAW_OPEN, TONGUE_OUT]\r\n\r\nBLENDSHAPE_COLUMNS = [\r\n    COLUMN_0_BLENDSHAPES,\r\n    COLUMN_1_BLENDSHAPES,\r\n    COLUMN_2_BLENDSHAPES,\r\n    COLUMN_3_BLENDSHAPES,\r\n    COLUMN_4_BLENDSHAPES,\r\n]\r\n\r\nRIGHT_EYE_BONE_X = \"rightEyeBoneX\"\r\nRIGHT_EYE_BONE_Y = \"rightEyeBoneY\"\r\nRIGHT_EYE_BONE_Z = \"rightEyeBoneZ\"\r\nRIGHT_EYE_BONE_ROTATIONS = [RIGHT_EYE_BONE_X, RIGHT_EYE_BONE_Y, RIGHT_EYE_BONE_Z]\r\n\r\nLEFT_EYE_BONE_X = \"leftEyeBoneX\"\r\nLEFT_EYE_BONE_Y = \"leftEyeBoneY\"\r\nLEFT_EYE_BONE_Z = \"leftEyeBoneZ\"\r\nLEFT_EYE_BONE_ROTATIONS = [LEFT_EYE_BONE_X, LEFT_EYE_BONE_Y, LEFT_EYE_BONE_Z]\r\n\r\nHEAD_BONE_X = \"headBoneX\"\r\nHEAD_BONE_Y = \"headBoneY\"\r\nHEAD_BONE_Z = \"headBoneZ\"\r\nHEAD_BONE_ROTATIONS = [HEAD_BONE_X, HEAD_BONE_Y, HEAD_BONE_Z]\r\n\r\nROTATION_NAMES = RIGHT_EYE_BONE_ROTATIONS + LEFT_EYE_BONE_ROTATIONS + HEAD_BONE_ROTATIONS\r\n\r\nRIGHT_EYE_BONE_QUAT = \"rightEyeBoneQuat\"\r\nLEFT_EYE_BONE_QUAT = 
\"leftEyeBoneQuat\"\r\nHEAD_BONE_QUAT = \"headBoneQuat\"\r\nQUATERNION_NAMES = [\r\n    RIGHT_EYE_BONE_QUAT,\r\n    LEFT_EYE_BONE_QUAT,\r\n    HEAD_BONE_QUAT\r\n]\r\n\r\nIFACIALMOCAP_DATETIME_FORMAT = \"%Y/%m/%d-%H:%M:%S.%f\"\r\n"
  },
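  {
    "path": "examples/blendshape_columns_check.py",
    "content": "\"\"\"Illustrative sketch, not part of the original project.\r\n\r\nSanity-checks that the five BLENDSHAPE_COLUMNS in\r\nsrc/tha4/mocap/ifacialmocap_constants.py partition BLENDSHAPE_NAMES: every one\r\nof the 52 blendshapes appears in exactly one UI column.\r\n\"\"\"\r\n\r\nfrom tha4.mocap.ifacialmocap_constants import BLENDSHAPE_COLUMNS, BLENDSHAPE_NAMES\r\n\r\nif __name__ == \"__main__\":\r\n    flattened = [name for column in BLENDSHAPE_COLUMNS for name in column]\r\n    assert sorted(flattened) == sorted(BLENDSHAPE_NAMES)\r\n    print(\"All %d blendshapes are covered by exactly one column.\" % len(BLENDSHAPE_NAMES))\r\n"
  },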
  {
    "path": "src/tha4/mocap/ifacialmocap_pose.py",
    "content": "from tha4.mocap.ifacialmocap_constants import BLENDSHAPE_NAMES, HEAD_BONE_X, HEAD_BONE_Y, HEAD_BONE_Z, \\\r\n    HEAD_BONE_QUAT, LEFT_EYE_BONE_X, LEFT_EYE_BONE_Y, LEFT_EYE_BONE_Z, LEFT_EYE_BONE_QUAT, RIGHT_EYE_BONE_X, \\\r\n    RIGHT_EYE_BONE_Y, RIGHT_EYE_BONE_Z, RIGHT_EYE_BONE_QUAT\r\n\r\n\r\ndef create_default_ifacialmocap_pose():\r\n    data = {}\r\n\r\n    for blendshape_name in BLENDSHAPE_NAMES:\r\n        data[blendshape_name] = 0.0\r\n\r\n    data[HEAD_BONE_X] = 0.0\r\n    data[HEAD_BONE_Y] = 0.0\r\n    data[HEAD_BONE_Z] = 0.0\r\n    data[HEAD_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]\r\n\r\n    data[LEFT_EYE_BONE_X] = 0.0\r\n    data[LEFT_EYE_BONE_Y] = 0.0\r\n    data[LEFT_EYE_BONE_Z] = 0.0\r\n    data[LEFT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]\r\n\r\n    data[RIGHT_EYE_BONE_X] = 0.0\r\n    data[RIGHT_EYE_BONE_Y] = 0.0\r\n    data[RIGHT_EYE_BONE_Z] = 0.0\r\n    data[RIGHT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]\r\n\r\n    return data"
  },
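  {
    "path": "examples/default_pose_demo.py",
    "content": "\"\"\"Illustrative sketch, not part of the original project.\r\n\r\nShows the layout of the dictionary returned by\r\ncreate_default_ifacialmocap_pose(): one float per blendshape, plus Euler\r\nangles (floats) and an [x, y, z, w] quaternion for the head bone and each eye\r\nbone.\r\n\"\"\"\r\n\r\nfrom tha4.mocap.ifacialmocap_constants import BLENDSHAPE_NAMES, HEAD_BONE_QUAT, HEAD_BONE_X\r\nfrom tha4.mocap.ifacialmocap_pose import create_default_ifacialmocap_pose\r\n\r\nif __name__ == \"__main__\":\r\n    pose = create_default_ifacialmocap_pose()\r\n    print(len(BLENDSHAPE_NAMES))  # 52 blendshape entries, all initialized to 0.0\r\n    print(pose[BLENDSHAPE_NAMES[0]])  # 0.0 (eyeLookInLeft)\r\n    print(pose[HEAD_BONE_X])  # 0.0\r\n    print(pose[HEAD_BONE_QUAT])  # [0.0, 0.0, 0.0, 1.0], the identity rotation\r\n"
  },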
  {
    "path": "src/tha4/mocap/ifacialmocap_pose_converter.py",
    "content": "from abc import ABC, abstractmethod\r\nfrom typing import Dict, List\r\n\r\n\r\nclass IFacialMocapPoseConverter(ABC):\r\n    @abstractmethod\r\n    def convert(self, ifacialmocap_pose: Dict[str, float]) -> List[float]:\r\n        pass\r\n\r\n    @abstractmethod\r\n    def init_pose_converter_panel(self, parent):\r\n        pass"
  },
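  {
    "path": "examples/zero_pose_converter_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original project.\r\n\r\nThe smallest useful IFacialMocapPoseConverter: it ignores the tracking data\r\nand returns an all-zero pose vector, and it contributes no settings panel. The\r\npose vector length of 45 mirrors the pose_size used by\r\nIFacialMocapPoseConverter25 and should be treated as an assumption here.\r\n\"\"\"\r\n\r\nfrom typing import Dict, List\r\n\r\nfrom tha4.mocap.ifacialmocap_pose_converter import IFacialMocapPoseConverter\r\n\r\nPOSE_SIZE = 45\r\n\r\n\r\nclass ZeroPoseConverter(IFacialMocapPoseConverter):\r\n    def convert(self, ifacialmocap_pose: Dict[str, float]) -> List[float]:\r\n        # A real converter maps blendshape and bone values to pose parameters here.\r\n        return [0.0] * POSE_SIZE\r\n\r\n    def init_pose_converter_panel(self, parent):\r\n        # This trivial converter adds no wx controls to the parent window.\r\n        pass\r\n"
  },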
  {
    "path": "src/tha4/mocap/ifacialmocap_pose_converter_25.py",
    "content": "import math\r\nimport time\r\nfrom enum import Enum\r\nfrom typing import Optional, Dict, List, Callable\r\n\r\nimport numpy\r\nimport scipy.optimize\r\nimport wx\r\n\r\nfrom tha4.mocap.ifacialmocap_constants import MOUTH_SMILE_LEFT, MOUTH_SHRUG_UPPER, MOUTH_SMILE_RIGHT, \\\r\n    BROW_INNER_UP, BROW_OUTER_UP_RIGHT, BROW_OUTER_UP_LEFT, BROW_DOWN_LEFT, BROW_DOWN_RIGHT, EYE_WIDE_LEFT, \\\r\n    EYE_WIDE_RIGHT, EYE_BLINK_LEFT, EYE_BLINK_RIGHT, CHEEK_SQUINT_LEFT, CHEEK_SQUINT_RIGHT, EYE_LOOK_IN_LEFT, \\\r\n    EYE_LOOK_OUT_LEFT, EYE_LOOK_IN_RIGHT, EYE_LOOK_OUT_RIGHT, EYE_LOOK_UP_LEFT, EYE_LOOK_UP_RIGHT, EYE_LOOK_DOWN_RIGHT, \\\r\n    EYE_LOOK_DOWN_LEFT, HEAD_BONE_X, HEAD_BONE_Y, HEAD_BONE_Z, JAW_OPEN, MOUTH_FROWN_LEFT, MOUTH_FROWN_RIGHT, \\\r\n    MOUTH_LOWER_DOWN_LEFT, MOUTH_LOWER_DOWN_RIGHT, MOUTH_FUNNEL, MOUTH_PUCKER\r\nfrom tha4.mocap.ifacialmocap_pose_converter import IFacialMocapPoseConverter\r\nfrom tha4.poser.modes.pose_parameters import get_pose_parameters\r\n\r\n\r\nclass EyebrowDownMode(Enum):\r\n    TROUBLED = 1\r\n    ANGRY = 2\r\n    LOWERED = 3\r\n    SERIOUS = 4\r\n\r\n\r\nclass WinkMode(Enum):\r\n    NORMAL = 1\r\n    RELAXED = 2\r\n\r\n\r\ndef rad_to_deg(rad):\r\n    return rad * 180.0 / math.pi\r\n\r\n\r\ndef deg_to_rad(deg):\r\n    return deg * math.pi / 180.0\r\n\r\n\r\ndef clamp(x, min_value, max_value):\r\n    return max(min_value, min(max_value, x))\r\n\r\n\r\nclass IFacialMocapPoseConverter25Args:\r\n    def __init__(self,\r\n                 smile_threshold_min: float = 0.4,\r\n                 smile_threshold_max: float = 0.6,\r\n                 eyebrow_down_mode: EyebrowDownMode = EyebrowDownMode.ANGRY,\r\n                 wink_mode: WinkMode = WinkMode.NORMAL,\r\n                 eye_surprised_max: float = 0.5,\r\n                 eye_blink_max: float = 0.8,\r\n                 eyebrow_down_max: float = 0.4,\r\n                 cheek_squint_min: float = 0.1,\r\n                 cheek_squint_max: float = 0.7,\r\n                 eye_rotation_factor: float = 1.0 / 0.75,\r\n                 jaw_open_min: float = 0.1,\r\n                 jaw_open_max: float = 0.4,\r\n                 mouth_frown_max: float = 0.6,\r\n                 mouth_funnel_min: float = 0.25,\r\n                 mouth_funnel_max: float = 0.5,\r\n                 iris_small_left=0.0,\r\n                 iris_small_right=0.0):\r\n        self.iris_small_right = iris_small_left\r\n        self.iris_small_left = iris_small_right\r\n\r\n        self.wink_mode = wink_mode\r\n\r\n        self.mouth_funnel_max = mouth_funnel_max\r\n        self.mouth_funnel_min = mouth_funnel_min\r\n        self.mouth_frown_max = mouth_frown_max\r\n\r\n        self.jaw_open_max = jaw_open_max\r\n        self.jaw_open_min = jaw_open_min\r\n\r\n        self.eye_rotation_factor = eye_rotation_factor\r\n\r\n        self.cheek_squint_max = cheek_squint_max\r\n        self.cheek_squint_min = cheek_squint_min\r\n\r\n        self.eyebrow_down_max = eyebrow_down_max\r\n\r\n        self.eye_blink_max = eye_blink_max\r\n        self.eye_surprised_max = eye_surprised_max\r\n        self.eyebrow_down_mode = eyebrow_down_mode\r\n\r\n        self.smile_threshold_min = smile_threshold_min\r\n        self.smile_threshold_max = smile_threshold_max\r\n\r\n    def set_smile_threshold_min(self, new_value: float):\r\n        self.smile_threshold_min = new_value\r\n\r\n    def set_smile_threshold_max(self, new_value: float):\r\n        self.smile_threshold_max = new_value\r\n\r\n    def set_eye_surprised_max(self, new_value: 
float):\r\n        self.eye_surprised_max = new_value\r\n\r\n    def set_eye_blink_max(self, new_value: float):\r\n        self.eye_blink_max = new_value\r\n\r\n    def set_eyebrow_down_max(self, new_value: float):\r\n        self.eyebrow_down_max = new_value\r\n\r\n    def set_cheek_squint_min(self, new_value: float):\r\n        self.cheek_squint_min = new_value\r\n\r\n    def set_cheek_squint_max(self, new_value: float):\r\n        self.cheek_squint_max = new_value\r\n\r\n    def set_jaw_open_min(self, new_value: float):\r\n        self.jaw_open_min = new_value\r\n\r\n    def set_jaw_open_max(self, new_value: float):\r\n        self.jaw_open_max = new_value\r\n\r\n    def set_mouth_frown_max(self, new_value: float):\r\n        self.mouth_frown_max = new_value\r\n\r\n    def set_mouth_funnel_min(self, new_value: float):\r\n        self.mouth_funnel_min = new_value\r\n\r\n    def set_mouth_funnel_max(self, new_value: float):\r\n        self.mouth_funnel_min = new_value\r\n\r\n\r\nclass IFacialMocapPoseConverter25(IFacialMocapPoseConverter):\r\n    def __init__(self, args: Optional[IFacialMocapPoseConverter25Args] = None):\r\n        super().__init__()\r\n        if args is None:\r\n            args = IFacialMocapPoseConverter25Args()\r\n        self.args = args\r\n        pose_parameters = get_pose_parameters()\r\n        self.pose_size = 45\r\n\r\n        self.eyebrow_troubled_left_index = pose_parameters.get_parameter_index(\"eyebrow_troubled_left\")\r\n        self.eyebrow_troubled_right_index = pose_parameters.get_parameter_index(\"eyebrow_troubled_right\")\r\n        self.eyebrow_angry_left_index = pose_parameters.get_parameter_index(\"eyebrow_angry_left\")\r\n        self.eyebrow_angry_right_index = pose_parameters.get_parameter_index(\"eyebrow_angry_right\")\r\n        self.eyebrow_happy_left_index = pose_parameters.get_parameter_index(\"eyebrow_happy_left\")\r\n        self.eyebrow_happy_right_index = pose_parameters.get_parameter_index(\"eyebrow_happy_right\")\r\n        self.eyebrow_raised_left_index = pose_parameters.get_parameter_index(\"eyebrow_raised_left\")\r\n        self.eyebrow_raised_right_index = pose_parameters.get_parameter_index(\"eyebrow_raised_right\")\r\n        self.eyebrow_lowered_left_index = pose_parameters.get_parameter_index(\"eyebrow_lowered_left\")\r\n        self.eyebrow_lowered_right_index = pose_parameters.get_parameter_index(\"eyebrow_lowered_right\")\r\n        self.eyebrow_serious_left_index = pose_parameters.get_parameter_index(\"eyebrow_serious_left\")\r\n        self.eyebrow_serious_right_index = pose_parameters.get_parameter_index(\"eyebrow_serious_right\")\r\n\r\n        self.eye_surprised_left_index = pose_parameters.get_parameter_index(\"eye_surprised_left\")\r\n        self.eye_surprised_right_index = pose_parameters.get_parameter_index(\"eye_surprised_right\")\r\n        self.eye_wink_left_index = pose_parameters.get_parameter_index(\"eye_wink_left\")\r\n        self.eye_wink_right_index = pose_parameters.get_parameter_index(\"eye_wink_right\")\r\n        self.eye_happy_wink_left_index = pose_parameters.get_parameter_index(\"eye_happy_wink_left\")\r\n        self.eye_happy_wink_right_index = pose_parameters.get_parameter_index(\"eye_happy_wink_right\")\r\n        self.eye_relaxed_left_index = pose_parameters.get_parameter_index(\"eye_relaxed_left\")\r\n        self.eye_relaxed_right_index = pose_parameters.get_parameter_index(\"eye_relaxed_right\")\r\n        self.eye_raised_lower_eyelid_left_index = 
pose_parameters.get_parameter_index(\"eye_raised_lower_eyelid_left\")\r\n        self.eye_raised_lower_eyelid_right_index = pose_parameters.get_parameter_index(\"eye_raised_lower_eyelid_right\")\r\n\r\n        self.iris_small_left_index = pose_parameters.get_parameter_index(\"iris_small_left\")\r\n        self.iris_small_right_index = pose_parameters.get_parameter_index(\"iris_small_right\")\r\n\r\n        self.iris_rotation_x_index = pose_parameters.get_parameter_index(\"iris_rotation_x\")\r\n        self.iris_rotation_y_index = pose_parameters.get_parameter_index(\"iris_rotation_y\")\r\n\r\n        self.head_x_index = pose_parameters.get_parameter_index(\"head_x\")\r\n        self.head_y_index = pose_parameters.get_parameter_index(\"head_y\")\r\n        self.neck_z_index = pose_parameters.get_parameter_index(\"neck_z\")\r\n\r\n        self.mouth_aaa_index = pose_parameters.get_parameter_index(\"mouth_aaa\")\r\n        self.mouth_iii_index = pose_parameters.get_parameter_index(\"mouth_iii\")\r\n        self.mouth_uuu_index = pose_parameters.get_parameter_index(\"mouth_uuu\")\r\n        self.mouth_eee_index = pose_parameters.get_parameter_index(\"mouth_eee\")\r\n        self.mouth_ooo_index = pose_parameters.get_parameter_index(\"mouth_ooo\")\r\n\r\n        self.mouth_lowered_corner_left_index = pose_parameters.get_parameter_index(\"mouth_lowered_corner_left\")\r\n        self.mouth_lowered_corner_right_index = pose_parameters.get_parameter_index(\"mouth_lowered_corner_right\")\r\n        self.mouth_raised_corner_left_index = pose_parameters.get_parameter_index(\"mouth_raised_corner_left\")\r\n        self.mouth_raised_corner_right_index = pose_parameters.get_parameter_index(\"mouth_raised_corner_right\")\r\n\r\n        self.body_y_index = pose_parameters.get_parameter_index(\"body_y\")\r\n        self.body_z_index = pose_parameters.get_parameter_index(\"body_z\")\r\n        self.breathing_index = pose_parameters.get_parameter_index(\"breathing\")\r\n\r\n        self.breathing_start_time = time.time()\r\n\r\n        self.panel = None\r\n\r\n    def init_pose_converter_panel(self, parent):\r\n        self.panel = wx.Panel(parent, style=wx.SIMPLE_BORDER)\r\n        self.panel_sizer = wx.BoxSizer(wx.VERTICAL)\r\n        self.panel.SetSizer(self.panel_sizer)\r\n        self.panel.SetAutoLayout(1)\r\n        parent.GetSizer().Add(self.panel, 0, wx.EXPAND)\r\n\r\n        if True:\r\n            eyebrow_down_mode_text = wx.StaticText(self.panel, label=\" --- Eyebrow Down Mode --- \",\r\n                                                   style=wx.ALIGN_CENTER)\r\n            self.panel_sizer.Add(eyebrow_down_mode_text, 0, wx.EXPAND)\r\n\r\n            self.eyebrow_down_mode_choice = wx.Choice(\r\n                self.panel,\r\n                choices=[\r\n                    \"ANGRY\",\r\n                    \"TROUBLED\",\r\n                    \"SERIOUS\",\r\n                    \"LOWERED\",\r\n                ])\r\n            self.eyebrow_down_mode_choice.SetSelection(0)\r\n            self.panel_sizer.Add(self.eyebrow_down_mode_choice, 0, wx.EXPAND)\r\n            self.eyebrow_down_mode_choice.Bind(wx.EVT_CHOICE, self.change_eyebrow_down_mode)\r\n\r\n            separator = wx.StaticLine(self.panel, -1, size=(256, 5))\r\n            self.panel_sizer.Add(separator, 0, wx.EXPAND)\r\n\r\n        if True:\r\n            wink_mode_text = wx.StaticText(self.panel, label=\" --- Wink Mode --- \", style=wx.ALIGN_CENTER)\r\n            self.panel_sizer.Add(wink_mode_text, 0, wx.EXPAND)\r\n\r\n           
 self.wink_mode_choice = wx.Choice(\r\n                self.panel,\r\n                choices=[\r\n                    \"NORMAL\",\r\n                    \"RELAXED\",\r\n                ])\r\n            self.wink_mode_choice.SetSelection(0)\r\n            self.panel_sizer.Add(self.wink_mode_choice, 0, wx.EXPAND)\r\n            self.wink_mode_choice.Bind(wx.EVT_CHOICE, self.change_wink_mode)\r\n\r\n            separator = wx.StaticLine(self.panel, -1, size=(256, 5))\r\n            self.panel_sizer.Add(separator, 0, wx.EXPAND)\r\n\r\n        if True:\r\n            iris_size_text = wx.StaticText(self.panel, label=\" --- Iris Size --- \", style=wx.ALIGN_CENTER)\r\n            self.panel_sizer.Add(iris_size_text, 0, wx.EXPAND)\r\n\r\n            self.iris_left_slider = wx.Slider(self.panel, minValue=0, maxValue=1000, value=0, style=wx.HORIZONTAL)\r\n            self.panel_sizer.Add(self.iris_left_slider, 0, wx.EXPAND)\r\n            self.iris_left_slider.Bind(wx.EVT_SLIDER, self.change_iris_size)\r\n\r\n            self.iris_right_slider = wx.Slider(self.panel, minValue=0, maxValue=1000, value=0, style=wx.HORIZONTAL)\r\n            self.panel_sizer.Add(self.iris_right_slider, 0, wx.EXPAND)\r\n            self.iris_right_slider.Bind(wx.EVT_SLIDER, self.change_iris_size)\r\n            self.iris_right_slider.Enable(False)\r\n\r\n            self.link_left_right_irises = wx.CheckBox(\r\n                self.panel, label=\"Use same value for both sides\")\r\n            self.link_left_right_irises.SetValue(True)\r\n            self.panel_sizer.Add(self.link_left_right_irises, wx.SizerFlags().CenterHorizontal().Border())\r\n            self.link_left_right_irises.Bind(wx.EVT_CHECKBOX, self.link_left_right_irises_clicked)\r\n\r\n            separator = wx.StaticLine(self.panel, -1, size=(256, 5))\r\n            self.panel_sizer.Add(separator, 0, wx.EXPAND)\r\n\r\n        if True:\r\n            breathing_frequency_text = wx.StaticText(\r\n                self.panel, label=\" --- Breathing --- \", style=wx.ALIGN_CENTER)\r\n            self.panel_sizer.Add(breathing_frequency_text, 0, wx.EXPAND)\r\n\r\n            self.restart_breathing_cycle_button = wx.Button(self.panel, label=\"Restart Breathing Cycle\")\r\n            self.restart_breathing_cycle_button.Bind(wx.EVT_BUTTON, self.restart_breathing_cycle_clicked)\r\n            self.panel_sizer.Add(self.restart_breathing_cycle_button, 0, wx.EXPAND)\r\n\r\n            self.breathing_frequency_slider = wx.Slider(\r\n                self.panel, minValue=0, maxValue=60, value=20, style=wx.HORIZONTAL)\r\n            self.panel_sizer.Add(self.breathing_frequency_slider, 0, wx.EXPAND)\r\n\r\n            self.breathing_gauge = wx.Gauge(self.panel, style=wx.GA_HORIZONTAL, range=1000)\r\n            self.panel_sizer.Add(self.breathing_gauge, 0, wx.EXPAND)\r\n\r\n        if True:\r\n            separator = wx.StaticLine(self.panel, -1, size=(256, 5))\r\n            self.panel_sizer.Add(separator, 0, wx.EXPAND)\r\n\r\n            conversion_parameters_text = wx.StaticText(\r\n                self.panel, label=\"--- Conversion Parameters ---\", style=wx.ALIGN_CENTER)\r\n            self.panel_sizer.Add(conversion_parameters_text, 0, wx.EXPAND)\r\n\r\n            conversion_param_panel = wx.Panel(self.panel)\r\n            self.panel_sizer.Add(conversion_param_panel, 0, wx.EXPAND)\r\n            conversion_panel_sizer = wx.FlexGridSizer(cols=2)\r\n            conversion_panel_sizer.AddGrowableCol(1)\r\n            conversion_param_panel.SetSizer(conversion_panel_sizer)\r\n            conversion_param_panel.SetAutoLayout(1)\r\n\r\n            self.smile_threshold_min_spin = self.create_spin_control(\r\n                conversion_param_panel,\r\n                \"Smile Threshold Min:\", self.args.smile_threshold_min, self.args.set_smile_threshold_min)\r\n            self.smile_threshold_max_spin = self.create_spin_control(\r\n                conversion_param_panel,\r\n                \"Smile Threshold Max:\", self.args.smile_threshold_max, self.args.set_smile_threshold_max)\r\n            self.eye_surprised_max_spin = self.create_spin_control(\r\n                conversion_param_panel,\r\n                \"Eye Surprised Max:\", self.args.eye_surprised_max, self.args.set_eye_surprised_max)\r\n            self.eye_blink_max_spin = self.create_spin_control(\r\n                conversion_param_panel,\r\n                \"Eye Blink Max:\", self.args.eye_blink_max, self.args.set_eye_blink_max)\r\n            self.eyebrow_down_max_spin = self.create_spin_control(\r\n                conversion_param_panel,\r\n                \"Eyebrow Down Max:\", self.args.eyebrow_down_max, self.args.set_eyebrow_down_max)\r\n            self.cheek_squint_min_spin = self.create_spin_control(\r\n                conversion_param_panel,\r\n                \"Cheek Squint Min:\", self.args.cheek_squint_min, self.args.set_cheek_squint_min)\r\n            self.cheek_squint_max_spin = self.create_spin_control(\r\n                conversion_param_panel,\r\n                \"Cheek Squint Max:\", self.args.cheek_squint_max, self.args.set_cheek_squint_max)\r\n            self.jaw_open_min_spin = self.create_spin_control(\r\n                conversion_param_panel,\r\n                \"Jaw Open Min:\", self.args.jaw_open_min, self.args.set_jaw_open_min)\r\n            self.jaw_open_max_spin = self.create_spin_control(\r\n                conversion_param_panel,\r\n                \"Jaw Open Max:\", self.args.jaw_open_max, self.args.set_jaw_open_max)\r\n            self.mouth_frown_max_spin = self.create_spin_control(\r\n                conversion_param_panel,\r\n                \"Mouth Frown Max:\", self.args.mouth_frown_max, self.args.set_mouth_frown_max)\r\n            self.mouth_funnel_min_spin = self.create_spin_control(\r\n                conversion_param_panel,\r\n                \"Mouth Funnel Min:\", self.args.mouth_funnel_min, self.args.set_mouth_funnel_min)\r\n            self.mouth_funnel_max_spin = self.create_spin_control(\r\n                conversion_param_panel,\r\n                \"Mouth Funnel Max:\", self.args.mouth_funnel_max, self.args.set_mouth_funnel_max)\r\n\r\n        self.panel_sizer.Fit(self.panel)\r\n\r\n    def create_spin_control(self, parent, label: str, initial_value: float, set_func: Callable[[float], None]):\r\n        sizer = parent.GetSizer()\r\n\r\n        text = wx.StaticText(parent, label=label)\r\n        sizer.Add(text, wx.SizerFlags().Right().Border(wx.ALL, 2))\r\n\r\n        spin_ctrl = wx.SpinCtrlDouble(\r\n            parent,\r\n            wx.ID_ANY,\r\n            min=0.0,\r\n            max=1.0,\r\n            initial=initial_value,\r\n            inc=0.01)\r\n        sizer.Add(spin_ctrl, wx.SizerFlags().Border(wx.ALL, 2).Expand())\r\n\r\n        def handler(event: wx.Event):\r\n            new_value = spin_ctrl.GetValue()\r\n            set_func(new_value)\r\n\r\n        spin_ctrl.Bind(wx.EVT_SPINCTRLDOUBLE, handler)\r\n\r\n        return spin_ctrl\r\n\r\n    def 
restart_breathing_cycle_clicked(self, event: wx.Event):\r\n        self.breathing_start_time = time.time()\r\n\r\n    def change_eyebrow_down_mode(self, event: wx.Event):\r\n        selected_index = self.eyebrow_down_mode_choice.GetSelection()\r\n        if selected_index == 0:\r\n            self.args.eyebrow_down_mode = EyebrowDownMode.ANGRY\r\n        elif selected_index == 1:\r\n            self.args.eyebrow_down_mode = EyebrowDownMode.TROUBLED\r\n        elif selected_index == 2:\r\n            self.args.eyebrow_down_mode = EyebrowDownMode.SERIOUS\r\n        else:\r\n            self.args.eyebrow_down_mode = EyebrowDownMode.LOWERED\r\n\r\n    def change_wink_mode(self, event: wx.Event):\r\n        selected_index = self.wink_mode_choice.GetSelection()\r\n        if selected_index == 0:\r\n            self.args.wink_mode = WinkMode.NORMAL\r\n        else:\r\n            self.args.wink_mode = WinkMode.RELAXED\r\n\r\n    def change_iris_size(self, event: wx.Event):\r\n        if self.link_left_right_irises.GetValue():\r\n            left_value = self.iris_left_slider.GetValue()\r\n            right_value = self.iris_right_slider.GetValue()\r\n            if left_value != right_value:\r\n                self.iris_right_slider.SetValue(left_value)\r\n            self.args.iris_small_left = left_value / 1000.0\r\n            self.args.iris_small_right = left_value / 1000.0\r\n        else:\r\n            self.args.iris_small_left = self.iris_left_slider.GetValue() / 1000.0\r\n            self.args.iris_small_right = self.iris_right_slider.GetValue() / 1000.0\r\n\r\n    def link_left_right_irises_clicked(self, event: wx.Event):\r\n        if self.link_left_right_irises.GetValue():\r\n            self.iris_right_slider.Enable(False)\r\n        else:\r\n            self.iris_right_slider.Enable(True)\r\n        self.change_iris_size(event)\r\n\r\n    def decompose_head_body_param(self, param, threshold=2.0 / 3):\r\n        if abs(param) < threshold:\r\n            return (param, 0.0)\r\n        else:\r\n            if param < 0:\r\n                sign = -1.0\r\n            else:\r\n                sign = 1.0\r\n            return (threshold * sign, (abs(param) - threshold) * sign)\r\n\r\n    def convert(self, ifacialmocap_pose: Dict[str, float]) -> List[float]:\r\n        pose = [0.0 for i in range(self.pose_size)]\r\n\r\n        smile_value = \\\r\n            (ifacialmocap_pose[MOUTH_SMILE_LEFT] + ifacialmocap_pose[MOUTH_SMILE_RIGHT]) / 2.0 \\\r\n            + ifacialmocap_pose[MOUTH_SHRUG_UPPER]\r\n        if self.args.smile_threshold_min >= self.args.smile_threshold_max:\r\n            smile_degree = 0.0\r\n        else:\r\n            if smile_value < self.args.smile_threshold_min:\r\n                smile_degree = 0.0\r\n            elif smile_value > self.args.smile_threshold_max:\r\n                smile_degree = 1.0\r\n            else:\r\n                smile_degree = (smile_value - self.args.smile_threshold_min) / (\r\n                        self.args.smile_threshold_max - self.args.smile_threshold_min)\r\n\r\n        # Eyebrow\r\n        if True:\r\n            brow_inner_up = ifacialmocap_pose[BROW_INNER_UP]\r\n            brow_outer_up_right = ifacialmocap_pose[BROW_OUTER_UP_RIGHT]\r\n            brow_outer_up_left = ifacialmocap_pose[BROW_OUTER_UP_LEFT]\r\n\r\n            brow_up_left = clamp(brow_inner_up + brow_outer_up_left, 0.0, 1.0)\r\n            brow_up_right = clamp(brow_inner_up + brow_outer_up_right, 0.0, 1.0)\r\n            pose[self.eyebrow_raised_left_index] = 
brow_up_left\r\n            pose[self.eyebrow_raised_right_index] = brow_up_right\r\n\r\n            if self.args.eyebrow_down_max <= 0.0:\r\n                brow_down_left = 0.0\r\n                brow_down_right = 0.0\r\n            else:\r\n                brow_down_left = (1.0 - smile_degree) \\\r\n                                 * clamp(ifacialmocap_pose[BROW_DOWN_LEFT] / self.args.eyebrow_down_max, 0.0, 1.0)\r\n                brow_down_right = (1.0 - smile_degree) \\\r\n                                  * clamp(ifacialmocap_pose[BROW_DOWN_RIGHT] / self.args.eyebrow_down_max, 0.0, 1.0)\r\n\r\n            if self.args.eyebrow_down_mode == EyebrowDownMode.TROUBLED:\r\n                pose[self.eyebrow_troubled_left_index] = brow_down_left\r\n                pose[self.eyebrow_troubled_right_index] = brow_down_right\r\n            elif self.args.eyebrow_down_mode == EyebrowDownMode.ANGRY:\r\n                pose[self.eyebrow_angry_left_index] = brow_down_left\r\n                pose[self.eyebrow_angry_right_index] = brow_down_right\r\n            elif self.args.eyebrow_down_mode == EyebrowDownMode.LOWERED:\r\n                pose[self.eyebrow_lowered_left_index] = brow_down_left\r\n                pose[self.eyebrow_lowered_right_index] = brow_down_right\r\n            elif self.args.eyebrow_down_mode == EyebrowDownMode.SERIOUS:\r\n                pose[self.eyebrow_serious_left_index] = brow_down_left\r\n                pose[self.eyebrow_serious_right_index] = brow_down_right\r\n\r\n            brow_happy_value = clamp(smile_value, 0.0, 1.0) * smile_degree\r\n            pose[self.eyebrow_happy_left_index] = brow_happy_value\r\n            pose[self.eyebrow_happy_right_index] = brow_happy_value\r\n\r\n        # Eye\r\n        if True:\r\n            # Surprised\r\n            if self.args.eye_surprised_max <= 0.0:\r\n                pose[self.eye_surprised_left_index] = 0.0\r\n                pose[self.eye_surprised_right_index] = 0.0\r\n            else:\r\n                pose[self.eye_surprised_left_index] = clamp(\r\n                    ifacialmocap_pose[EYE_WIDE_LEFT] / self.args.eye_surprised_max, 0.0, 1.0)\r\n                pose[self.eye_surprised_right_index] = clamp(\r\n                    ifacialmocap_pose[EYE_WIDE_RIGHT] / self.args.eye_surprised_max, 0.0, 1.0)\r\n\r\n            # Wink\r\n            if self.args.wink_mode == WinkMode.NORMAL:\r\n                wink_left_index = self.eye_wink_left_index\r\n                wink_right_index = self.eye_wink_right_index\r\n            else:\r\n                wink_left_index = self.eye_relaxed_left_index\r\n                wink_right_index = self.eye_relaxed_right_index\r\n            if self.args.eye_blink_max <= 0:\r\n                pose[wink_left_index] = 0.0\r\n                pose[wink_right_index] = 0.0\r\n                pose[self.eye_happy_wink_left_index] = 0.0\r\n                pose[self.eye_happy_wink_right_index] = 0.0\r\n            else:\r\n                pose[wink_left_index] = (1.0 - smile_degree) * clamp(\r\n                    ifacialmocap_pose[EYE_BLINK_LEFT] / self.args.eye_blink_max, 0.0, 1.0)\r\n                pose[wink_right_index] = (1.0 - smile_degree) * clamp(\r\n                    ifacialmocap_pose[EYE_BLINK_RIGHT] / self.args.eye_blink_max, 0.0, 1.0)\r\n                pose[self.eye_happy_wink_left_index] = smile_degree * clamp(\r\n                    ifacialmocap_pose[EYE_BLINK_LEFT] / self.args.eye_blink_max, 0.0, 1.0)\r\n                pose[self.eye_happy_wink_right_index] = smile_degree * 
clamp(\r\n                    ifacialmocap_pose[EYE_BLINK_RIGHT] / self.args.eye_blink_max, 0.0, 1.0)\r\n\r\n            # Lower eyelid\r\n            cheek_squint_denom = self.args.cheek_squint_max - self.args.cheek_squint_min\r\n            if cheek_squint_denom <= 0.0:\r\n                pose[self.eye_raised_lower_eyelid_left_index] = 0.0\r\n                pose[self.eye_raised_lower_eyelid_right_index] = 0.0\r\n            else:\r\n                pose[self.eye_raised_lower_eyelid_left_index] = \\\r\n                    clamp(\r\n                        (ifacialmocap_pose[CHEEK_SQUINT_LEFT] - self.args.cheek_squint_min) / cheek_squint_denom,\r\n                        0.0, 1.0)\r\n                pose[self.eye_raised_lower_eyelid_right_index] = \\\r\n                    clamp(\r\n                        (ifacialmocap_pose[CHEEK_SQUINT_RIGHT] - self.args.cheek_squint_min) / cheek_squint_denom,\r\n                        0.0, 1.0)\r\n\r\n        # Iris rotation\r\n        if True:\r\n            eye_rotation_y = (ifacialmocap_pose[EYE_LOOK_IN_LEFT]\r\n                              - ifacialmocap_pose[EYE_LOOK_OUT_LEFT]\r\n                              - ifacialmocap_pose[EYE_LOOK_IN_RIGHT]\r\n                              + ifacialmocap_pose[EYE_LOOK_OUT_RIGHT]) / 2.0 * self.args.eye_rotation_factor\r\n            pose[self.iris_rotation_y_index] = clamp(eye_rotation_y, -1.0, 1.0)\r\n\r\n            eye_rotation_x = (ifacialmocap_pose[EYE_LOOK_UP_LEFT]\r\n                              + ifacialmocap_pose[EYE_LOOK_UP_RIGHT]\r\n                              - ifacialmocap_pose[EYE_LOOK_DOWN_LEFT]\r\n                              - ifacialmocap_pose[EYE_LOOK_DOWN_RIGHT]) / 2.0 * self.args.eye_rotation_factor\r\n            pose[self.iris_rotation_x_index] = clamp(eye_rotation_x, -1.0, 1.0)\r\n\r\n        # Iris size\r\n        if True:\r\n            pose[self.iris_small_left_index] = self.args.iris_small_left\r\n            pose[self.iris_small_right_index] = self.args.iris_small_right\r\n\r\n        # Head rotation\r\n        if True:\r\n            x_param = clamp(-ifacialmocap_pose[HEAD_BONE_X] * 180.0 / math.pi, -15.0, 15.0) / 15.0\r\n            pose[self.head_x_index] = x_param\r\n\r\n            y_param = clamp(-ifacialmocap_pose[HEAD_BONE_Y] * 180.0 / math.pi, -10.0, 10.0) / 10.0\r\n            pose[self.head_y_index] = y_param\r\n            pose[self.body_y_index] = y_param\r\n\r\n            z_param = clamp(ifacialmocap_pose[HEAD_BONE_Z] * 180.0 / math.pi, -15.0, 15.0) / 15.0\r\n            pose[self.neck_z_index] = z_param\r\n            pose[self.body_z_index] = z_param\r\n\r\n        # Mouth\r\n        if True:\r\n            jaw_open_denom = self.args.jaw_open_max - self.args.jaw_open_min\r\n            if jaw_open_denom <= 0:\r\n                mouth_open = 0.0\r\n            else:\r\n                mouth_open = clamp((ifacialmocap_pose[JAW_OPEN] - self.args.jaw_open_min) / jaw_open_denom, 0.0, 1.0)\r\n            pose[self.mouth_aaa_index] = mouth_open\r\n            pose[self.mouth_raised_corner_left_index] = clamp(smile_value, 0.0, 1.0)\r\n            pose[self.mouth_raised_corner_right_index] = clamp(smile_value, 0.0, 1.0)\r\n\r\n            is_mouth_open = mouth_open > 0.0\r\n            if not is_mouth_open:\r\n                if self.args.mouth_frown_max <= 0:\r\n                    mouth_frown_value = 0.0\r\n                else:\r\n                    mouth_frown_value = clamp(\r\n                        (ifacialmocap_pose[MOUTH_FROWN_LEFT] + ifacialmocap_pose[\r\n
                           MOUTH_FROWN_RIGHT]) / self.args.mouth_frown_max, 0.0, 1.0)\r\n                pose[self.mouth_lowered_corner_left_index] = mouth_frown_value\r\n                pose[self.mouth_lowered_corner_right_index] = mouth_frown_value\r\n            else:\r\n                mouth_lower_down = clamp(\r\n                    ifacialmocap_pose[MOUTH_LOWER_DOWN_LEFT] + ifacialmocap_pose[MOUTH_LOWER_DOWN_RIGHT], 0.0, 1.0)\r\n                mouth_funnel = ifacialmocap_pose[MOUTH_FUNNEL]\r\n                mouth_pucker = ifacialmocap_pose[MOUTH_PUCKER]\r\n\r\n                mouth_point = [mouth_open, mouth_lower_down, mouth_funnel, mouth_pucker]\r\n\r\n                aaa_point = [1.0, 1.0, 0.0, 0.0]\r\n                iii_point = [0.0, 1.0, 0.0, 0.0]\r\n                uuu_point = [0.5, 0.3, 0.25, 0.75]\r\n                ooo_point = [1.0, 0.5, 0.5, 0.4]\r\n\r\n                decomp = numpy.array([0, 0, 0, 0])\r\n                M = numpy.array([\r\n                    aaa_point,\r\n                    iii_point,\r\n                    uuu_point,\r\n                    ooo_point\r\n                ])\r\n\r\n                def loss(decomp):\r\n                    return numpy.linalg.norm(numpy.matmul(decomp, M) - mouth_point) \\\r\n                        + 0.01 * numpy.linalg.norm(decomp, ord=1)\r\n\r\n                opt_result = scipy.optimize.minimize(\r\n                    loss, decomp, bounds=[(0.0, 1.0), (0.0, 1.0), (0.0, 1.0), (0.0, 1.0)])\r\n                decomp = opt_result[\"x\"]\r\n                restricted_decomp = [decomp.item(0), decomp.item(1), decomp.item(2), decomp.item(3)]\r\n                pose[self.mouth_aaa_index] = restricted_decomp[0]\r\n                pose[self.mouth_iii_index] = restricted_decomp[1]\r\n                mouth_funnel_denom = self.args.mouth_funnel_max - self.args.mouth_funnel_min\r\n                if mouth_funnel_denom <= 0:\r\n                    ooo_alpha = 0.0\r\n                    uo_value = 0.0\r\n                else:\r\n                    ooo_alpha = clamp((mouth_funnel - self.args.mouth_funnel_min) / mouth_funnel_denom, 0.0, 1.0)\r\n                    uo_value = clamp(restricted_decomp[2] + restricted_decomp[3], 0.0, 1.0)\r\n                pose[self.mouth_uuu_index] = uo_value * (1.0 - ooo_alpha)\r\n                pose[self.mouth_ooo_index] = uo_value * ooo_alpha\r\n\r\n        if self.panel is not None:\r\n            frequency = self.breathing_frequency_slider.GetValue()\r\n            if frequency == 0:\r\n                value = 0.0\r\n                pose[self.breathing_index] = value\r\n                self.breathing_start_time = time.time()\r\n            else:\r\n                period = 60.0 / frequency\r\n                now = time.time()\r\n                diff = now - self.breathing_start_time\r\n                frac = (diff % period) / period\r\n                value = (-math.cos(2 * math.pi * frac) + 1.0) / 2.0\r\n                pose[self.breathing_index] = value\r\n            self.breathing_gauge.SetValue(int(1000 * value))\r\n\r\n        return pose\r\n\r\n\r\ndef create_ifacialmocap_pose_converter(\r\n        args: Optional[IFacialMocapPoseConverter25Args] = None) -> IFacialMocapPoseConverter:\r\n    return IFacialMocapPoseConverter25(args)\r\n"
  },
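  {
    "path": "src/tha4/mocap/ifacialmocap_pose_converter_25_example.py",
    "content": "# NOTE: Illustrative sketch added by the editor; not part of the original distribution.\r\n# It shows how the iFacialMocap pose converter above turns a blendshape dictionary into\r\n# the flat pose vector consumed by the poser. The converter's module path is assumed to\r\n# be tha4.mocap.ifacialmocap_pose_converter_25; adjust the import if it differs.\r\n\r\nfrom tha4.mocap.ifacialmocap_constants import BLENDSHAPE_NAMES, HEAD_BONE_X, HEAD_BONE_Y, HEAD_BONE_Z\r\nfrom tha4.mocap.ifacialmocap_pose_converter_25 import create_ifacialmocap_pose_converter\r\n\r\nif __name__ == \"__main__\":\r\n    # A neutral capture: every blendshape at 0.0 and no head rotation. This matches the\r\n    # dictionary shape produced by parse_ifacialmocap_v2_pose.\r\n    ifacialmocap_pose = {name: 0.0 for name in BLENDSHAPE_NAMES}\r\n    ifacialmocap_pose[HEAD_BONE_X] = 0.0\r\n    ifacialmocap_pose[HEAD_BONE_Y] = 0.0\r\n    ifacialmocap_pose[HEAD_BONE_Z] = 0.0\r\n\r\n    converter = create_ifacialmocap_pose_converter()  # default args, no GUI panel\r\n    pose = converter.convert(ifacialmocap_pose)\r\n    print(len(pose))  # length of the pose vector (45 for this poser); all near zero here\r\n"
  },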
  {
    "path": "src/tha4/mocap/ifacialmocap_v2.py",
    "content": "import math\r\n\r\nfrom tha4.mocap.ifacialmocap_constants import BLENDSHAPE_NAMES, HEAD_BONE_X, HEAD_BONE_Y, HEAD_BONE_Z, \\\r\n    RIGHT_EYE_BONE_X, RIGHT_EYE_BONE_Y, RIGHT_EYE_BONE_Z, LEFT_EYE_BONE_X, LEFT_EYE_BONE_Y, LEFT_EYE_BONE_Z, \\\r\n    HEAD_BONE_QUAT, LEFT_EYE_BONE_QUAT, RIGHT_EYE_BONE_QUAT\r\n\r\nIFACIALMOCAP_PORT = 49983\r\nIFACIALMOCAP_START_STRING = \"iFacialMocap_sahuasouryya9218sauhuiayeta91555dy3719|sendDataVersion=v2\".encode('utf-8')\r\n\r\n\r\ndef parse_ifacialmocap_v2_pose(ifacialmocap_output):\r\n    output = {}\r\n    parts = ifacialmocap_output.split(\"|\")\r\n    for part in parts:\r\n        part = part.strip()\r\n        if len(part) == 0:\r\n            continue\r\n        if \"&\" in part:\r\n            components = part.split(\"&\")\r\n            assert len(components) == 2\r\n            key = components[0]\r\n            value = float(components[1]) / 100.0\r\n            if key.endswith(\"_L\"):\r\n                key = key[:-2] + \"Left\"\r\n            elif key.endswith(\"_R\"):\r\n                key = key[:-2] + \"Right\"\r\n            if key in BLENDSHAPE_NAMES:\r\n                output[key] = value\r\n        elif part.startswith(\"=head#\"):\r\n            components = part[len(\"=head#\"):].split(\",\")\r\n            assert len(components) == 6\r\n            output[HEAD_BONE_X] = float(components[0]) * math.pi / 180\r\n            output[HEAD_BONE_Y] = float(components[1]) * math.pi / 180\r\n            output[HEAD_BONE_Z] = float(components[2]) * math.pi / 180\r\n        elif part.startswith(\"rightEye#\"):\r\n            components = part[len(\"rightEye#\"):].split(\",\")\r\n            output[RIGHT_EYE_BONE_X] = float(components[0]) * math.pi / 180\r\n            output[RIGHT_EYE_BONE_Y] = float(components[1]) * math.pi / 180\r\n            output[RIGHT_EYE_BONE_Z] = float(components[2]) * math.pi / 180\r\n        elif part.startswith(\"leftEye#\"):\r\n            components = part[len(\"leftEye#\"):].split(\",\")\r\n            output[LEFT_EYE_BONE_X] = float(components[0]) * math.pi / 180\r\n            output[LEFT_EYE_BONE_Y] = float(components[1]) * math.pi / 180\r\n            output[LEFT_EYE_BONE_Z] = float(components[2]) * math.pi / 180\r\n    output[HEAD_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]\r\n    output[LEFT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]\r\n    output[RIGHT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]\r\n    return output\r\n\r\n\r\ndef parse_ifacialmocap_v1_pose(ifacialmocap_output):\r\n    output = {}\r\n    parts = ifacialmocap_output.split(\"|\")\r\n    for part in parts:\r\n        part = part.strip()\r\n        if len(part) == 0:\r\n            continue\r\n        if part.startswith(\"=head#\"):\r\n            components = part[len(\"=head#\"):].split(\",\")\r\n            assert len(components) == 6\r\n            output[HEAD_BONE_X] = float(components[0]) * math.pi / 180\r\n            output[HEAD_BONE_Y] = float(components[1]) * math.pi / 180\r\n            output[HEAD_BONE_Z] = float(components[2]) * math.pi / 180\r\n        elif part.startswith(\"rightEye#\"):\r\n            components = part[len(\"rightEye#\"):].split(\",\")\r\n            output[RIGHT_EYE_BONE_X] = float(components[0]) * math.pi / 180\r\n            output[RIGHT_EYE_BONE_Y] = float(components[1]) * math.pi / 180\r\n            output[RIGHT_EYE_BONE_Z] = float(components[2]) * math.pi / 180\r\n        elif part.startswith(\"leftEye#\"):\r\n            components = part[len(\"leftEye#\"):].split(\",\")\r\n            
output[LEFT_EYE_BONE_X] = float(components[0]) * math.pi / 180\r\n            output[LEFT_EYE_BONE_Y] = float(components[1]) * math.pi / 180\r\n            output[LEFT_EYE_BONE_Z] = float(components[2]) * math.pi / 180\r\n        else:\r\n            components = part.split(\"-\")\r\n            assert len(components) == 2\r\n            key = components[0]\r\n            value = float(components[1]) / 100.0\r\n            if key.endswith(\"_L\"):\r\n                key = key[:-2] + \"Left\"\r\n            elif key.endswith(\"_R\"):\r\n                key = key[:-2] + \"Right\"\r\n            if key in BLENDSHAPE_NAMES:\r\n                output[key] = value\r\n    output[HEAD_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]\r\n    output[LEFT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]\r\n    output[RIGHT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]\r\n    return output\r\n\r\n"
  },
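  {
    "path": "src/tha4/mocap/ifacialmocap_v2_parse_example.py",
    "content": "# NOTE: Illustrative sketch added by the editor; not part of the original distribution.\r\n# The datagram below is synthetic, but it follows the v2 wire format that\r\n# parse_ifacialmocap_v2_pose expects: '|'-separated parts, blendshapes as 'name&percentage'\r\n# (with _L/_R suffixes), head rotation plus translation after '=head#', and eye rotations\r\n# after 'rightEye#'/'leftEye#'.\r\n\r\nfrom tha4.mocap.ifacialmocap_constants import HEAD_BONE_X\r\nfrom tha4.mocap.ifacialmocap_v2 import parse_ifacialmocap_v2_pose\r\n\r\nif __name__ == \"__main__\":\r\n    datagram = \"mouthSmile_L&32|mouthSmile_R&28|jawOpen&10|=head#10.0,-5.0,3.0,0.0,0.0,0.0|rightEye#2.0,1.0,0.5|leftEye#2.0,1.0,0.5|\"\r\n    pose = parse_ifacialmocap_v2_pose(datagram)\r\n    print(pose[\"mouthSmileLeft\"])  # 0.32: percentages are divided by 100\r\n    print(pose[HEAD_BONE_X])       # ~0.1745: degrees are converted to radians\r\n"
  },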
  {
    "path": "src/tha4/mocap/mediapipe_constants.py",
    "content": "EYE_LOOK_IN_LEFT = \"eyeLookInLeft\"\r\nEYE_LOOK_OUT_LEFT = \"eyeLookOutLeft\"\r\nEYE_LOOK_DOWN_LEFT = \"eyeLookDownLeft\"\r\nEYE_LOOK_UP_LEFT = \"eyeLookUpLeft\"\r\nEYE_BLINK_LEFT = \"eyeBlinkLeft\"\r\nEYE_SQUINT_LEFT = \"eyeSquintLeft\"\r\nEYE_WIDE_LEFT = \"eyeWideLeft\"\r\nEYE_LOOK_IN_RIGHT = \"eyeLookInRight\"\r\nEYE_LOOK_OUT_RIGHT = \"eyeLookOutRight\"\r\nEYE_LOOK_DOWN_RIGHT = \"eyeLookDownRight\"\r\nEYE_LOOK_UP_RIGHT = \"eyeLookUpRight\"\r\nEYE_BLINK_RIGHT = \"eyeBlinkRight\"\r\nEYE_SQUINT_RIGHT = \"eyeSquintRight\"\r\nEYE_WIDE_RIGHT = \"eyeWideRight\"\r\nBROW_DOWN_LEFT = \"browDownLeft\"\r\nBROW_OUTER_UP_LEFT = \"browOuterUpLeft\"\r\nBROW_DOWN_RIGHT = \"browDownRight\"\r\nBROW_OUTER_UP_RIGHT = \"browOuterUpRight\"\r\nBROW_INNER_UP = \"browInnerUp\"\r\nNOSE_SNEER_LEFT = \"noseSneerLeft\"\r\nNOSE_SNEER_RIGHT = \"noseSneerRight\"\r\nCHEEK_SQUINT_LEFT = \"cheekSquintLeft\"\r\nCHEEK_SQUINT_RIGHT = \"cheekSquintRight\"\r\nCHEEK_PUFF = \"cheekPuff\"\r\nMOUTH_LEFT = \"mouthLeft\"\r\nMOUTH_DIMPLE_LEFT = \"mouthDimpleLeft\"\r\nMOUTH_FROWN_LEFT = \"mouthFrownLeft\"\r\nMOUTH_LOWER_DOWN_LEFT = \"mouthLowerDownLeft\"\r\nMOUTH_PRESS_LEFT = \"mouthPressLeft\"\r\nMOUTH_SMILE_LEFT = \"mouthSmileLeft\"\r\nMOUTH_STRETCH_LEFT = \"mouthStretchLeft\"\r\nMOUTH_UPPER_UP_LEFT = \"mouthUpperUpLeft\"\r\nMOUTH_RIGHT = \"mouthRight\"\r\nMOUTH_DIMPLE_RIGHT = \"mouthDimpleRight\"\r\nMOUTH_FROWN_RIGHT = \"mouthFrownRight\"\r\nMOUTH_LOWER_DOWN_RIGHT = \"mouthLowerDownRight\"\r\nMOUTH_PRESS_RIGHT = \"mouthPressRight\"\r\nMOUTH_SMILE_RIGHT = \"mouthSmileRight\"\r\nMOUTH_STRETCH_RIGHT = \"mouthStretchRight\"\r\nMOUTH_UPPER_UP_RIGHT = \"mouthUpperUpRight\"\r\nMOUTH_CLOSE = \"mouthClose\"\r\nMOUTH_FUNNEL = \"mouthFunnel\"\r\nMOUTH_PUCKER = \"mouthPucker\"\r\nMOUTH_ROLL_LOWER = \"mouthRollLower\"\r\nMOUTH_ROLL_UPPER = \"mouthRollUpper\"\r\nMOUTH_SHRUG_LOWER = \"mouthShrugLower\"\r\nMOUTH_SHRUG_UPPER = \"mouthShrugUpper\"\r\nJAW_LEFT = \"jawLeft\"\r\nJAW_RIGHT = \"jawRight\"\r\nJAW_FORWARD = \"jawForward\"\r\nJAW_OPEN = \"jawOpen\"\r\nNEUTRAL = \"_neutral\"\r\n\r\nBLENDSHAPE_NAMES = [\r\n    EYE_LOOK_IN_LEFT,  # 0\r\n    EYE_LOOK_OUT_LEFT,  # 1\r\n    EYE_LOOK_DOWN_LEFT,  # 2\r\n    EYE_LOOK_UP_LEFT,  # 3\r\n    EYE_BLINK_LEFT,  # 4\r\n    EYE_SQUINT_LEFT,  # 5\r\n    EYE_WIDE_LEFT,  # 6\r\n    EYE_LOOK_IN_RIGHT,  # 7\r\n    EYE_LOOK_OUT_RIGHT,  # 8\r\n    EYE_LOOK_DOWN_RIGHT,  # 9\r\n    EYE_LOOK_UP_RIGHT,  # 10\r\n    EYE_BLINK_RIGHT,  # 11\r\n    EYE_SQUINT_RIGHT,  # 12\r\n    EYE_WIDE_RIGHT,  # 13\r\n    BROW_DOWN_LEFT,  # 14\r\n    BROW_OUTER_UP_LEFT,  # 15\r\n    BROW_DOWN_RIGHT,  # 16\r\n    BROW_OUTER_UP_RIGHT,  # 17\r\n    BROW_INNER_UP,  # 18\r\n    NOSE_SNEER_LEFT,  # 19\r\n    NOSE_SNEER_RIGHT,  # 20\r\n    CHEEK_SQUINT_LEFT,  # 21\r\n    CHEEK_SQUINT_RIGHT,  # 22\r\n    CHEEK_PUFF,  # 23\r\n    MOUTH_LEFT,  # 24\r\n    MOUTH_DIMPLE_LEFT,  # 25\r\n    MOUTH_FROWN_LEFT,  # 26\r\n    MOUTH_LOWER_DOWN_LEFT,  # 27\r\n    MOUTH_PRESS_LEFT,  # 28\r\n    MOUTH_SMILE_LEFT,  # 29\r\n    MOUTH_STRETCH_LEFT,  # 30\r\n    MOUTH_UPPER_UP_LEFT,  # 31\r\n    MOUTH_RIGHT,  # 32\r\n    MOUTH_DIMPLE_RIGHT,  # 33\r\n    MOUTH_FROWN_RIGHT,  # 34\r\n    MOUTH_LOWER_DOWN_RIGHT,  # 35\r\n    MOUTH_PRESS_RIGHT,  # 36\r\n    MOUTH_SMILE_RIGHT,  # 37\r\n    MOUTH_STRETCH_RIGHT,  # 38\r\n    MOUTH_UPPER_UP_RIGHT,  # 39\r\n    MOUTH_CLOSE,  # 40\r\n    MOUTH_FUNNEL,  # 41\r\n    MOUTH_PUCKER,  # 42\r\n    MOUTH_ROLL_LOWER,  # 43\r\n    MOUTH_ROLL_UPPER,  # 44\r\n    MOUTH_SHRUG_LOWER,  # 45\r\n    MOUTH_SHRUG_UPPER,  # 
46\r\n    JAW_LEFT,  # 47\r\n    JAW_RIGHT,  # 48\r\n    JAW_FORWARD,  # 49\r\n    JAW_OPEN,  # 50\r\n    NEUTRAL,  # 51\r\n]\r\n\r\nEYE_LEFT_BLENDSHAPES = [\r\n    EYE_LOOK_IN_LEFT,  # 0\r\n    EYE_LOOK_OUT_LEFT,  # 1\r\n    EYE_LOOK_DOWN_LEFT,  # 2\r\n    EYE_LOOK_UP_LEFT,  # 3\r\n    EYE_BLINK_LEFT,  # 4\r\n    EYE_SQUINT_LEFT,  # 5\r\n    EYE_WIDE_LEFT,  # 6\r\n]\r\n\r\nEYE_RIGHT_BLENDSHAPES = [\r\n    EYE_LOOK_IN_RIGHT,  # 7\r\n    EYE_LOOK_OUT_RIGHT,  # 8\r\n    EYE_LOOK_DOWN_RIGHT,  # 9\r\n    EYE_LOOK_UP_RIGHT,  # 10\r\n    EYE_BLINK_RIGHT,  # 11\r\n    EYE_SQUINT_RIGHT,  # 12\r\n    EYE_WIDE_RIGHT,  # 13\r\n]\r\n\r\nBROW_LEFT_BLENDSHAPES = [\r\n    BROW_DOWN_LEFT,  # 14\r\n    BROW_OUTER_UP_LEFT,  # 15\r\n\r\n]\r\n\r\nBROW_RIGHT_BLENDSHAPES = [\r\n    BROW_DOWN_RIGHT,  # 16\r\n    BROW_OUTER_UP_RIGHT,  # 17\r\n\r\n]\r\n\r\nBROW_BOTH_BLENDSHAPES = [\r\n    BROW_INNER_UP,  # 18\r\n]\r\n\r\nNOSE_BLENDSHAPES = [\r\n    NOSE_SNEER_LEFT,  # 19\r\n    NOSE_SNEER_RIGHT,  # 20\r\n]\r\n\r\nCHECK_BLENDSHAPES = [\r\n    CHEEK_SQUINT_LEFT,  # 21\r\n    CHEEK_SQUINT_RIGHT,  # 22\r\n    CHEEK_PUFF,  # 23\r\n]\r\n\r\nMOUTH_LEFT_BLENDSHAPES = [\r\n    MOUTH_LEFT,  # 24\r\n    MOUTH_DIMPLE_LEFT,  # 25\r\n    MOUTH_FROWN_LEFT,  # 26\r\n    MOUTH_LOWER_DOWN_LEFT,  # 27\r\n    MOUTH_PRESS_LEFT,  # 28\r\n    MOUTH_SMILE_LEFT,  # 29\r\n    MOUTH_STRETCH_LEFT,  # 30\r\n    MOUTH_UPPER_UP_LEFT,  # 31\r\n]\r\n\r\nMOUTH_RIGHT_BLENDSHAPES = [\r\n    MOUTH_RIGHT,  # 32\r\n    MOUTH_DIMPLE_RIGHT,  # 33\r\n    MOUTH_FROWN_RIGHT,  # 34\r\n    MOUTH_LOWER_DOWN_RIGHT,  # 35\r\n    MOUTH_PRESS_RIGHT,  # 36\r\n    MOUTH_SMILE_RIGHT,  # 37\r\n    MOUTH_STRETCH_RIGHT,  # 38\r\n    MOUTH_UPPER_UP_RIGHT,  # 39\r\n]\r\n\r\nMOUTH_BOTH_BLENDSHAPES = [\r\n    MOUTH_CLOSE,  # 40\r\n    MOUTH_FUNNEL,  # 41\r\n    MOUTH_PUCKER,  # 42\r\n    MOUTH_ROLL_LOWER,  # 43\r\n    MOUTH_ROLL_UPPER,  # 44\r\n    MOUTH_SHRUG_LOWER,  # 45\r\n    MOUTH_SHRUG_UPPER,  # 46\r\n]\r\n\r\nJAW_BLENDSHAPES = [\r\n    JAW_LEFT,  # 47\r\n    JAW_RIGHT,  # 48\r\n    JAW_FORWARD,  # 49\r\n    JAW_OPEN,  # 50\r\n]\r\n\r\nNEUTRAL_BLENDSHAPES = [\r\n    NEUTRAL,  # 51\r\n]\r\n\r\nCOLUMN_0_BLENDSHAPES = EYE_RIGHT_BLENDSHAPES + BROW_RIGHT_BLENDSHAPES + [NOSE_SNEER_RIGHT, CHEEK_SQUINT_RIGHT]\r\nCOLUMN_1_BLENDSHAPES = EYE_LEFT_BLENDSHAPES + BROW_LEFT_BLENDSHAPES + [NOSE_SNEER_LEFT, CHEEK_SQUINT_LEFT]\r\nCOLUMN_2_BLENDSHAPES = MOUTH_RIGHT_BLENDSHAPES + [JAW_RIGHT]\r\nCOLUMN_3_BLENDSHAPES = MOUTH_LEFT_BLENDSHAPES + [JAW_LEFT]\r\nCOLUMN_4_BLENDSHAPES = [BROW_INNER_UP, CHEEK_PUFF] + MOUTH_BOTH_BLENDSHAPES + [JAW_FORWARD, JAW_OPEN, NEUTRAL]\r\n\r\nBLENDSHAPE_COLUMNS = [\r\n    COLUMN_0_BLENDSHAPES,\r\n    COLUMN_1_BLENDSHAPES,\r\n    COLUMN_2_BLENDSHAPES,\r\n    COLUMN_3_BLENDSHAPES,\r\n    COLUMN_4_BLENDSHAPES,\r\n]\r\n\r\nHEAD_X = \"headX\"\r\nHEAD_Y = \"headY\"\r\nHEAD_Z = \"headZ\"\r\nHEAD_ROTATIONS = [HEAD_X, HEAD_Y, HEAD_Z]"
  },
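  {
    "path": "src/tha4/mocap/mediapipe_constants_example.py",
    "content": "# NOTE: Illustrative sketch added by the editor; not part of the original distribution.\r\n# The COLUMN_*_BLENDSHAPES lists above group the 52 MediaPipe blendshape names by facial\r\n# region (presumably for laying out a five-column GUI). This sanity check verifies that\r\n# the five columns partition BLENDSHAPE_NAMES exactly.\r\n\r\nfrom tha4.mocap.mediapipe_constants import BLENDSHAPE_COLUMNS, BLENDSHAPE_NAMES\r\n\r\nif __name__ == \"__main__\":\r\n    names_in_columns = [name for column in BLENDSHAPE_COLUMNS for name in column]\r\n    assert len(names_in_columns) == len(BLENDSHAPE_NAMES)  # 52 entries in total\r\n    assert sorted(names_in_columns) == sorted(BLENDSHAPE_NAMES)  # same set, no duplicates\r\n    print(\"column sizes:\", [len(column) for column in BLENDSHAPE_COLUMNS])  # [11, 11, 9, 9, 12]\r\n"
  },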
  {
    "path": "src/tha4/mocap/mediapipe_face_pose.py",
    "content": "import json\r\nimport os\r\nfrom typing import Optional, Dict\r\n\r\nimport numpy\r\n\r\n\r\nclass MediaPipeFacePose:\r\n    KEY_BLENDSHAPE_PARAMS = \"blendshape_params\"\r\n    KEY_XFORM_MATRIX = \"xform_matrix\"\r\n\r\n    def __init__(self, blendshape_params: Optional[Dict[str, float]], xform_matrix: Optional[numpy.ndarray]):\r\n        if blendshape_params is None:\r\n            blendshape_params = {}\r\n        if xform_matrix is None:\r\n            self.xform_matrix = numpy.zeros(4, 4)\r\n            for i in range(4):\r\n                self.xform_matrix[i, i] = 1.0\r\n\r\n        self.blendshape_params = blendshape_params\r\n        self.xform_matrix = xform_matrix\r\n\r\n    def get_json(self):\r\n        return {\r\n            MediaPipeFacePose.KEY_BLENDSHAPE_PARAMS: self.blendshape_params.copy(),\r\n            MediaPipeFacePose.KEY_XFORM_MATRIX: self.xform_matrix.tolist()\r\n        }\r\n\r\n    def save(self, file_name: str):\r\n        os.makedirs(os.path.dirname(file_name), exist_ok=True)\r\n        with open(file_name, \"wt\") as fout:\r\n            fout.write(json.dumps(self.get_json()))\r\n\r\n    @staticmethod\r\n    def load(file_name: str):\r\n        with open(file_name, \"rt\") as fin:\r\n            s = fin.read()\r\n            json_data = json.loads(s)\r\n            return MediaPipeFacePose(\r\n                json_data[MediaPipeFacePose.KEY_BLENDSHAPE_PARAMS],\r\n                xform_matrix = numpy.array(json_data[MediaPipeFacePose.KEY_XFORM_MATRIX]))\r\n"
  },
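  {
    "path": "src/tha4/mocap/mediapipe_face_pose_example.py",
    "content": "# NOTE: Illustrative sketch added by the editor; not part of the original distribution.\r\n# MediaPipeFacePose couples the blendshape dictionary with the 4x4 facial transformation\r\n# matrix and can round-trip through JSON. This sketch writes a pose to a temporary file\r\n# and reads it back.\r\n\r\nimport os\r\nimport tempfile\r\n\r\nimport numpy\r\n\r\nfrom tha4.mocap.mediapipe_face_pose import MediaPipeFacePose\r\n\r\nif __name__ == \"__main__\":\r\n    pose = MediaPipeFacePose({\"jawOpen\": 0.25}, numpy.identity(4))\r\n\r\n    # save() creates the parent directory, so give it a path inside a temp directory.\r\n    file_name = os.path.join(tempfile.mkdtemp(), \"pose\", \"pose.json\")\r\n    pose.save(file_name)\r\n\r\n    loaded = MediaPipeFacePose.load(file_name)\r\n    print(loaded.blendshape_params)   # {'jawOpen': 0.25}\r\n    print(loaded.xform_matrix.shape)  # (4, 4)\r\n"
  },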
  {
    "path": "src/tha4/mocap/mediapipe_face_pose_converter.py",
    "content": "from abc import ABC, abstractmethod\r\nfrom typing import List, Callable, Optional\r\n\r\nfrom tha4.mocap.mediapipe_face_pose import MediaPipeFacePose\r\n\r\n\r\nclass MediaPipeFacePoseConverter(ABC):\r\n    @abstractmethod\r\n    def convert(self, mediapipe_face_pose: MediaPipeFacePose) -> List[float]:\r\n        pass\r\n\r\n    @abstractmethod\r\n    def init_pose_converter_panel(\r\n            self,\r\n            parent,\r\n            current_pose_supplier: Callable[[], Optional[MediaPipeFacePose]]):\r\n        pass"
  },
  {
    "path": "src/tha4/mocap/mediapipe_face_pose_converter_00.py",
    "content": "import math\r\nimport time\r\nfrom enum import Enum\r\nfrom typing import Optional, List, Callable\r\n\r\nimport numpy\r\nimport scipy.optimize\r\nimport wx\r\nfrom scipy.spatial.transform import Rotation\r\n\r\nfrom tha4.poser.modes.pose_parameters import get_pose_parameters\r\nfrom tha4.mocap.mediapipe_constants import MOUTH_SMILE_LEFT, MOUTH_SHRUG_UPPER, MOUTH_SMILE_RIGHT, \\\r\n    BROW_INNER_UP, BROW_OUTER_UP_RIGHT, BROW_OUTER_UP_LEFT, BROW_DOWN_LEFT, BROW_DOWN_RIGHT, EYE_WIDE_LEFT, \\\r\n    EYE_WIDE_RIGHT, EYE_BLINK_LEFT, EYE_BLINK_RIGHT, CHEEK_SQUINT_LEFT, CHEEK_SQUINT_RIGHT, EYE_LOOK_IN_LEFT, \\\r\n    EYE_LOOK_OUT_LEFT, EYE_LOOK_IN_RIGHT, EYE_LOOK_OUT_RIGHT, EYE_LOOK_UP_LEFT, EYE_LOOK_UP_RIGHT, EYE_LOOK_DOWN_RIGHT, \\\r\n    EYE_LOOK_DOWN_LEFT, JAW_OPEN, MOUTH_FROWN_LEFT, MOUTH_FROWN_RIGHT, \\\r\n    MOUTH_LOWER_DOWN_LEFT, MOUTH_LOWER_DOWN_RIGHT, MOUTH_FUNNEL, MOUTH_PUCKER\r\nfrom tha4.mocap.mediapipe_face_pose import MediaPipeFacePose\r\nfrom tha4.mocap.mediapipe_face_pose_converter import MediaPipeFacePoseConverter\r\n\r\n\r\nclass EyebrowDownMode(Enum):\r\n    TROUBLED = 1\r\n    ANGRY = 2\r\n    LOWERED = 3\r\n    SERIOUS = 4\r\n\r\n\r\nclass WinkMode(Enum):\r\n    NORMAL = 1\r\n    RELAXED = 2\r\n\r\n\r\ndef rad_to_deg(rad):\r\n    return rad * 180.0 / math.pi\r\n\r\n\r\ndef deg_to_rad(deg):\r\n    return deg * math.pi / 180.0\r\n\r\n\r\ndef clamp(x, min_value, max_value):\r\n    return max(min_value, min(max_value, x))\r\n\r\n\r\nclass MediaPipeFacePoseConverter00Args:\r\n    def __init__(self,\r\n                 smile_threshold_min: float = 0.4,\r\n                 smile_threshold_max: float = 0.6,\r\n                 eyebrow_down_mode: EyebrowDownMode = EyebrowDownMode.ANGRY,\r\n                 wink_mode: WinkMode = WinkMode.NORMAL,\r\n                 eye_surprised_max: float = 0.5,\r\n                 eye_blink_max: float = 0.8,\r\n                 eyebrow_down_max: float = 0.4,\r\n                 cheek_squint_min: float = 0.1,\r\n                 cheek_squint_max: float = 0.7,\r\n                 eye_rotation_factor: float = 1.0 / 0.75,\r\n                 jaw_open_min: float = 0.1,\r\n                 jaw_open_max: float = 0.4,\r\n                 mouth_frown_max: float = 0.6,\r\n                 mouth_funnel_min: float = 0.25,\r\n                 mouth_funnel_max: float = 0.5,\r\n                 iris_small_left=0.0,\r\n                 iris_small_right=0.0,\r\n                 head_x_offset=0.0,\r\n                 head_y_offset=0.0,\r\n                 head_z_offset=0.0):\r\n        self.iris_small_right = iris_small_left\r\n        self.iris_small_left = iris_small_right\r\n\r\n        self.wink_mode = wink_mode\r\n\r\n        self.mouth_funnel_max = mouth_funnel_max\r\n        self.mouth_funnel_min = mouth_funnel_min\r\n        self.mouth_frown_max = mouth_frown_max\r\n\r\n        self.jaw_open_max = jaw_open_max\r\n        self.jaw_open_min = jaw_open_min\r\n\r\n        self.eye_rotation_factor = eye_rotation_factor\r\n\r\n        self.cheek_squint_max = cheek_squint_max\r\n        self.cheek_squint_min = cheek_squint_min\r\n\r\n        self.eyebrow_down_max = eyebrow_down_max\r\n\r\n        self.eye_blink_max = eye_blink_max\r\n        self.eye_surprised_max = eye_surprised_max\r\n\r\n        self.smile_threshold_min = smile_threshold_min\r\n        self.smile_threshold_max = smile_threshold_max\r\n\r\n        self.head_z_offset = head_z_offset\r\n        self.head_y_offset = head_y_offset\r\n        self.head_x_offset = head_x_offset\r\n\r\n 
       self.eyebrow_down_mode = eyebrow_down_mode\r\n\r\n    def set_smile_threshold_min(self, new_value: float):\r\n        self.smile_threshold_min = new_value\r\n\r\n    def set_smile_threshold_max(self, new_value: float):\r\n        self.smile_threshold_max = new_value\r\n\r\n    def set_eye_surprised_max(self, new_value: float):\r\n        self.eye_surprised_max = new_value\r\n\r\n    def set_eye_blink_max(self, new_value: float):\r\n        self.eye_blink_max = new_value\r\n\r\n    def set_eyebrow_down_max(self, new_value: float):\r\n        self.eyebrow_down_max = new_value\r\n\r\n    def set_cheek_squint_min(self, new_value: float):\r\n        self.cheek_squint_min = new_value\r\n\r\n    def set_cheek_squint_max(self, new_value: float):\r\n        self.cheek_squint_max = new_value\r\n\r\n    def set_jaw_open_min(self, new_value: float):\r\n        self.jaw_open_min = new_value\r\n\r\n    def set_jaw_open_max(self, new_value: float):\r\n        self.jaw_open_max = new_value\r\n\r\n    def set_mouth_frown_max(self, new_value: float):\r\n        self.mouth_frown_max = new_value\r\n\r\n    def set_mouth_funnel_min(self, new_value: float):\r\n        self.mouth_funnel_min = new_value\r\n\r\n    def set_mouth_funnel_max(self, new_value: float):\r\n        self.mouth_funnel_min = new_value\r\n\r\n\r\nclass MediaPoseFacePoseConverter00(MediaPipeFacePoseConverter):\r\n    def __init__(self, args: Optional[MediaPipeFacePoseConverter00Args] = None):\r\n        super().__init__()\r\n        if args is None:\r\n            args = MediaPipeFacePoseConverter00Args()\r\n        self.args = args\r\n        pose_parameters = get_pose_parameters()\r\n        self.pose_size = 45\r\n\r\n        self.eyebrow_troubled_left_index = pose_parameters.get_parameter_index(\"eyebrow_troubled_left\")\r\n        self.eyebrow_troubled_right_index = pose_parameters.get_parameter_index(\"eyebrow_troubled_right\")\r\n        self.eyebrow_angry_left_index = pose_parameters.get_parameter_index(\"eyebrow_angry_left\")\r\n        self.eyebrow_angry_right_index = pose_parameters.get_parameter_index(\"eyebrow_angry_right\")\r\n        self.eyebrow_happy_left_index = pose_parameters.get_parameter_index(\"eyebrow_happy_left\")\r\n        self.eyebrow_happy_right_index = pose_parameters.get_parameter_index(\"eyebrow_happy_right\")\r\n        self.eyebrow_raised_left_index = pose_parameters.get_parameter_index(\"eyebrow_raised_left\")\r\n        self.eyebrow_raised_right_index = pose_parameters.get_parameter_index(\"eyebrow_raised_right\")\r\n        self.eyebrow_lowered_left_index = pose_parameters.get_parameter_index(\"eyebrow_lowered_left\")\r\n        self.eyebrow_lowered_right_index = pose_parameters.get_parameter_index(\"eyebrow_lowered_right\")\r\n        self.eyebrow_serious_left_index = pose_parameters.get_parameter_index(\"eyebrow_serious_left\")\r\n        self.eyebrow_serious_right_index = pose_parameters.get_parameter_index(\"eyebrow_serious_right\")\r\n\r\n        self.eye_surprised_left_index = pose_parameters.get_parameter_index(\"eye_surprised_left\")\r\n        self.eye_surprised_right_index = pose_parameters.get_parameter_index(\"eye_surprised_right\")\r\n        self.eye_wink_left_index = pose_parameters.get_parameter_index(\"eye_wink_left\")\r\n        self.eye_wink_right_index = pose_parameters.get_parameter_index(\"eye_wink_right\")\r\n        self.eye_happy_wink_left_index = pose_parameters.get_parameter_index(\"eye_happy_wink_left\")\r\n        self.eye_happy_wink_right_index = 
pose_parameters.get_parameter_index(\"eye_happy_wink_right\")\r\n        self.eye_relaxed_left_index = pose_parameters.get_parameter_index(\"eye_relaxed_left\")\r\n        self.eye_relaxed_right_index = pose_parameters.get_parameter_index(\"eye_relaxed_right\")\r\n        self.eye_raised_lower_eyelid_left_index = pose_parameters.get_parameter_index(\"eye_raised_lower_eyelid_left\")\r\n        self.eye_raised_lower_eyelid_right_index = pose_parameters.get_parameter_index(\"eye_raised_lower_eyelid_right\")\r\n\r\n        self.iris_small_left_index = pose_parameters.get_parameter_index(\"iris_small_left\")\r\n        self.iris_small_right_index = pose_parameters.get_parameter_index(\"iris_small_right\")\r\n\r\n        self.iris_rotation_x_index = pose_parameters.get_parameter_index(\"iris_rotation_x\")\r\n        self.iris_rotation_y_index = pose_parameters.get_parameter_index(\"iris_rotation_y\")\r\n\r\n        self.head_x_index = pose_parameters.get_parameter_index(\"head_x\")\r\n        self.head_y_index = pose_parameters.get_parameter_index(\"head_y\")\r\n        self.neck_z_index = pose_parameters.get_parameter_index(\"neck_z\")\r\n\r\n        self.mouth_aaa_index = pose_parameters.get_parameter_index(\"mouth_aaa\")\r\n        self.mouth_iii_index = pose_parameters.get_parameter_index(\"mouth_iii\")\r\n        self.mouth_uuu_index = pose_parameters.get_parameter_index(\"mouth_uuu\")\r\n        self.mouth_eee_index = pose_parameters.get_parameter_index(\"mouth_eee\")\r\n        self.mouth_ooo_index = pose_parameters.get_parameter_index(\"mouth_ooo\")\r\n\r\n        self.mouth_lowered_corner_left_index = pose_parameters.get_parameter_index(\"mouth_lowered_corner_left\")\r\n        self.mouth_lowered_corner_right_index = pose_parameters.get_parameter_index(\"mouth_lowered_corner_right\")\r\n        self.mouth_raised_corner_left_index = pose_parameters.get_parameter_index(\"mouth_raised_corner_left\")\r\n        self.mouth_raised_corner_right_index = pose_parameters.get_parameter_index(\"mouth_raised_corner_right\")\r\n\r\n        self.body_y_index = pose_parameters.get_parameter_index(\"body_y\")\r\n        self.body_z_index = pose_parameters.get_parameter_index(\"body_z\")\r\n        self.breathing_index = pose_parameters.get_parameter_index(\"breathing\")\r\n\r\n        self.breathing_start_time = time.time()\r\n\r\n        self.panel = None\r\n        self.current_pose_supplier = None\r\n\r\n    def init_pose_converter_panel(\r\n            self,\r\n            parent,\r\n            current_pose_supplier: Callable[[], Optional[MediaPipeFacePose]]):\r\n        self.panel = wx.Panel(parent, style=wx.SIMPLE_BORDER)\r\n        self.panel_sizer = wx.BoxSizer(wx.VERTICAL)\r\n        self.panel.SetSizer(self.panel_sizer)\r\n        self.panel.SetAutoLayout(1)\r\n        parent.GetSizer().Add(self.panel, 0, wx.EXPAND)\r\n\r\n        self.current_pose_supplier = current_pose_supplier\r\n\r\n        if True:\r\n            eyebrow_down_mode_text = wx.StaticText(self.panel, label=\" --- Eyebrow Down Mode --- \",\r\n                                                   style=wx.ALIGN_CENTER)\r\n            self.panel_sizer.Add(eyebrow_down_mode_text, 0, wx.EXPAND)\r\n\r\n            self.eyebrow_down_mode_choice = wx.Choice(\r\n                self.panel,\r\n                choices=[\r\n                    \"ANGRY\",\r\n                    \"TROUBLED\",\r\n                    \"SERIOUS\",\r\n                    \"LOWERED\",\r\n                ])\r\n            
self.eyebrow_down_mode_choice.SetSelection(0)\r\n            self.panel_sizer.Add(self.eyebrow_down_mode_choice, 0, wx.EXPAND)\r\n            self.eyebrow_down_mode_choice.Bind(wx.EVT_CHOICE, self.change_eyebrow_down_mode)\r\n\r\n        if True:\r\n            separator = wx.StaticLine(self.panel, -1, size=(256, 5))\r\n            self.panel_sizer.Add(separator, 0, wx.EXPAND)\r\n\r\n            wink_mode_text = wx.StaticText(self.panel, label=\" --- Wink Mode --- \", style=wx.ALIGN_CENTER)\r\n            self.panel_sizer.Add(wink_mode_text, 0, wx.EXPAND)\r\n\r\n            self.wink_mode_choice = wx.Choice(\r\n                self.panel,\r\n                choices=[\r\n                    \"NORMAL\",\r\n                    \"RELAXED\",\r\n                ])\r\n            self.wink_mode_choice.SetSelection(0)\r\n            self.panel_sizer.Add(self.wink_mode_choice, 0, wx.EXPAND)\r\n            self.wink_mode_choice.Bind(wx.EVT_CHOICE, self.change_wink_mode)\r\n\r\n        if True:\r\n            separator = wx.StaticLine(self.panel, -1, size=(256, 5))\r\n            self.panel_sizer.Add(separator, 0, wx.EXPAND)\r\n\r\n            iris_size_text = wx.StaticText(self.panel, label=\" --- Iris Size --- \", style=wx.ALIGN_CENTER)\r\n            self.panel_sizer.Add(iris_size_text, 0, wx.EXPAND)\r\n\r\n            self.iris_left_slider = wx.Slider(self.panel, minValue=0, maxValue=1000, value=0, style=wx.HORIZONTAL)\r\n            self.panel_sizer.Add(self.iris_left_slider, 0, wx.EXPAND)\r\n            self.iris_left_slider.Bind(wx.EVT_SLIDER, self.change_iris_size)\r\n\r\n            self.iris_right_slider = wx.Slider(self.panel, minValue=0, maxValue=1000, value=0, style=wx.HORIZONTAL)\r\n            self.panel_sizer.Add(self.iris_right_slider, 0, wx.EXPAND)\r\n            self.iris_right_slider.Bind(wx.EVT_SLIDER, self.change_iris_size)\r\n            self.iris_right_slider.Enable(False)\r\n\r\n            self.link_left_right_irises = wx.CheckBox(\r\n                self.panel, label=\"Use same value for both sides\")\r\n            self.link_left_right_irises.SetValue(True)\r\n            self.panel_sizer.Add(self.link_left_right_irises, wx.SizerFlags().CenterHorizontal().Border())\r\n            self.link_left_right_irises.Bind(wx.EVT_CHECKBOX, self.link_left_right_irises_clicked)\r\n\r\n        if True:\r\n            separator = wx.StaticLine(self.panel, -1, size=(256, 5))\r\n            self.panel_sizer.Add(separator, 0, wx.EXPAND)\r\n\r\n            breathing_frequency_text = wx.StaticText(\r\n                self.panel, label=\" --- Breathing --- \", style=wx.ALIGN_CENTER)\r\n            self.panel_sizer.Add(breathing_frequency_text, 0, wx.EXPAND)\r\n\r\n            self.restart_breathing_cycle_button = wx.Button(self.panel, label=\"Restart Breathing Cycle\")\r\n            self.restart_breathing_cycle_button.Bind(wx.EVT_BUTTON, self.restart_breathing_cycle_clicked)\r\n            self.panel_sizer.Add(self.restart_breathing_cycle_button, 0, wx.EXPAND)\r\n\r\n            self.breathing_frequency_slider = wx.Slider(\r\n                self.panel, minValue=0, maxValue=60, value=20, style=wx.HORIZONTAL)\r\n            self.panel_sizer.Add(self.breathing_frequency_slider, 0, wx.EXPAND)\r\n\r\n            self.breathing_gauge = wx.Gauge(self.panel, style=wx.GA_HORIZONTAL, range=1000)\r\n            self.panel_sizer.Add(self.breathing_gauge, 0, wx.EXPAND)\r\n\r\n        if True:\r\n            separator = wx.StaticLine(self.panel, -1, size=(256, 5))\r\n            
self.panel_sizer.Add(separator, 0, wx.EXPAND)\r\n\r\n            face_orientation_text = wx.StaticText(\r\n                self.panel, label=\"--- Face Orientation ---\", style=wx.ALIGN_CENTER)\r\n            self.panel_sizer.Add(face_orientation_text, 0, wx.EXPAND)\r\n\r\n            self.calibrate_face_orientation_button = wx.Button(self.panel, label=\"Calibrate (I'm looking forward)\")\r\n            self.calibrate_face_orientation_button.Bind(wx.EVT_BUTTON, self.calibrate_face_orientation_clicked)\r\n            self.panel_sizer.Add(self.calibrate_face_orientation_button, 0, wx.EXPAND)\r\n\r\n        if True:\r\n            separator = wx.StaticLine(self.panel, -1, size=(256, 5))\r\n            self.panel_sizer.Add(separator, 0, wx.EXPAND)\r\n\r\n            conversion_parameters_text = wx.StaticText(\r\n                self.panel, label=\"--- Conversion Parameters ---\", style=wx.ALIGN_CENTER)\r\n            self.panel_sizer.Add(conversion_parameters_text, 0, wx.EXPAND)\r\n\r\n            conversion_param_panel = wx.Panel(self.panel)\r\n            self.panel_sizer.Add(conversion_param_panel, 0, wx.EXPAND)\r\n            conversion_panel_sizer = wx.FlexGridSizer(cols=2)\r\n            conversion_panel_sizer.AddGrowableCol(1)\r\n            conversion_param_panel.SetSizer(conversion_panel_sizer)\r\n            conversion_param_panel.SetAutoLayout(1)\r\n\r\n            self.smile_threshold_min_spin = self.create_spin_control(\r\n                conversion_param_panel,\r\n                \"Smile Threshold Min:\", self.args.smile_threshold_min, self.args.set_smile_threshold_min)\r\n            self.smile_threshold_max_spin = self.create_spin_control(\r\n                conversion_param_panel,\r\n                \"Smile Threshold Max:\", self.args.smile_threshold_max, self.args.set_smile_threshold_max)\r\n            self.eye_surprised_max_spin = self.create_spin_control(\r\n                conversion_param_panel,\r\n                \"Eye Surprised Max:\", self.args.eye_surprised_max, self.args.set_eye_surprised_max)\r\n            self.eye_blink_max_spin = self.create_spin_control(\r\n                conversion_param_panel,\r\n                \"Eye Blink Max:\", self.args.eye_blink_max, self.args.set_eye_blink_max)\r\n            self.eyebrow_down_max_spin = self.create_spin_control(\r\n                conversion_param_panel,\r\n                \"Eyebrow Down Max:\", self.args.eyebrow_down_max, self.args.set_eyebrow_down_max)\r\n            self.cheek_squint_min_spin = self.create_spin_control(\r\n                conversion_param_panel,\r\n                \"Cheek Squint Min:\", self.args.cheek_squint_min, self.args.set_cheek_squint_min)\r\n            self.cheek_squint_max_spin = self.create_spin_control(\r\n                conversion_param_panel,\r\n                \"Cheek Squint Max:\", self.args.cheek_squint_max, self.args.set_cheek_squint_max)\r\n            self.jaw_open_min_spin = self.create_spin_control(\r\n                conversion_param_panel,\r\n                \"Jaw Open Min:\", self.args.jaw_open_min, self.args.set_jaw_open_min)\r\n            self.jaw_open_max_spin = self.create_spin_control(\r\n                conversion_param_panel,\r\n                \"Jaw Open Max:\", self.args.jaw_open_max, self.args.set_jaw_open_max)\r\n            self.mouth_frown_max_spin = self.create_spin_control(\r\n                conversion_param_panel,\r\n                \"Mouth Frown Max:\", self.args.mouth_frown_max, self.args.set_mouth_frown_max)\r\n            self.mouth_funnel_min_spin = 
self.create_spin_control(\r\n                conversion_param_panel,\r\n                \"Mouth Funnel Min:\", self.args.mouth_funnel_min, self.args.set_mouth_funnel_min)\r\n            self.mouth_funnel_max_spin = self.create_spin_control(\r\n                conversion_param_panel,\r\n                \"Mouth Funnel Max:\", self.args.mouth_funnel_max, self.args.set_mouth_funnel_max)\r\n\r\n        self.panel_sizer.Fit(self.panel)\r\n\r\n    def create_spin_control(self, parent, label: str, initial_value: float, set_func: Callable[[float], None]):\r\n        sizer = parent.GetSizer()\r\n\r\n        text = wx.StaticText(parent, label=label)\r\n        sizer.Add(text, wx.SizerFlags().Right().Border(wx.ALL, 2))\r\n\r\n        spin_ctrl = wx.SpinCtrlDouble(\r\n            parent,\r\n            wx.ID_ANY,\r\n            min=0.0,\r\n            max=1.0,\r\n            initial=initial_value,\r\n            inc=0.01)\r\n        sizer.Add(spin_ctrl, wx.SizerFlags().Border(wx.ALL, 2).Expand())\r\n\r\n        def handler(event: wx.Event):\r\n            new_value = spin_ctrl.GetValue()\r\n            set_func(new_value)\r\n\r\n        spin_ctrl.Bind(wx.EVT_SPINCTRLDOUBLE, handler)\r\n\r\n        return spin_ctrl\r\n\r\n    def extract_euler_angles(self, mediapipe_face_pose: MediaPipeFacePose):\r\n        M = mediapipe_face_pose.xform_matrix[0:3, 0:3]\r\n        rot = Rotation.from_matrix(M)\r\n        return rot.as_euler('xyz', degrees=False)\r\n\r\n    def calibrate_face_orientation_clicked(self, event: wx.Event):\r\n        if self.current_pose_supplier is None:\r\n            return\r\n\r\n        mediapipe_face_pose = self.current_pose_supplier()\r\n        if mediapipe_face_pose is None:\r\n            return\r\n\r\n        euler_angles = self.extract_euler_angles(mediapipe_face_pose)\r\n        self.args.head_x_offset = euler_angles[0]\r\n        self.args.head_y_offset = euler_angles[1]\r\n        self.args.head_z_offset = euler_angles[2]\r\n\r\n    def restart_breathing_cycle_clicked(self, event: wx.Event):\r\n        self.breathing_start_time = time.time()\r\n\r\n    def change_eyebrow_down_mode(self, event: wx.Event):\r\n        selected_index = self.eyebrow_down_mode_choice.GetSelection()\r\n        if selected_index == 0:\r\n            self.args.eyebrow_down_mode = EyebrowDownMode.ANGRY\r\n        elif selected_index == 1:\r\n            self.args.eyebrow_down_mode = EyebrowDownMode.TROUBLED\r\n        elif selected_index == 2:\r\n            self.args.eyebrow_down_mode = EyebrowDownMode.SERIOUS\r\n        else:\r\n            self.args.eyebrow_down_mode = EyebrowDownMode.LOWERED\r\n\r\n    def change_wink_mode(self, event: wx.Event):\r\n        selected_index = self.wink_mode_choice.GetSelection()\r\n        if selected_index == 0:\r\n            self.args.wink_mode = WinkMode.NORMAL\r\n        else:\r\n            self.args.wink_mode = WinkMode.RELAXED\r\n\r\n    def change_iris_size(self, event: wx.Event):\r\n        if self.link_left_right_irises.GetValue():\r\n            left_value = self.iris_left_slider.GetValue()\r\n            right_value = self.iris_right_slider.GetValue()\r\n            if left_value != right_value:\r\n                self.iris_right_slider.SetValue(left_value)\r\n            self.args.iris_small_left = left_value / 1000.0\r\n            self.args.iris_small_right = left_value / 1000.0\r\n        else:\r\n            self.args.iris_small_left = self.iris_left_slider.GetValue() / 1000.0\r\n            self.args.iris_small_right = 
self.iris_right_slider.GetValue() / 1000.0\r\n\r\n    def link_left_right_irises_clicked(self, event: wx.Event):\r\n        if self.link_left_right_irises.GetValue():\r\n            self.iris_right_slider.Enable(False)\r\n        else:\r\n            self.iris_right_slider.Enable(True)\r\n        self.change_iris_size(event)\r\n\r\n    def decompose_head_body_param(self, param, threshold=2.0 / 3):\r\n        if abs(param) < threshold:\r\n            return (param, 0.0)\r\n        else:\r\n            if param < 0:\r\n                sign = -1.0\r\n            else:\r\n                sign = 1.0\r\n            return (threshold * sign, (abs(param) - threshold) * sign)\r\n\r\n    def convert(self, mediapipe_face_pose: MediaPipeFacePose) -> List[float]:\r\n        pose = [0.0 for i in range(self.pose_size)]\r\n\r\n        blendshape_params = mediapipe_face_pose.blendshape_params\r\n\r\n        smile_value = \\\r\n            (blendshape_params[MOUTH_SMILE_LEFT] + blendshape_params[MOUTH_SMILE_RIGHT]) / 2.0 \\\r\n            + blendshape_params[MOUTH_SHRUG_UPPER]\r\n        if self.args.smile_threshold_min >= self.args.smile_threshold_max:\r\n            smile_degree = 0.0\r\n        else:\r\n            if smile_value < self.args.smile_threshold_min:\r\n                smile_degree = 0.0\r\n            elif smile_value > self.args.smile_threshold_max:\r\n                smile_degree = 1.0\r\n            else:\r\n                smile_degree = (smile_value - self.args.smile_threshold_min) / (\r\n                        self.args.smile_threshold_max - self.args.smile_threshold_min)\r\n\r\n        # Eyebrow\r\n        if True:\r\n            brow_inner_up = blendshape_params[BROW_INNER_UP]\r\n            brow_outer_up_right = blendshape_params[BROW_OUTER_UP_RIGHT]\r\n            brow_outer_up_left = blendshape_params[BROW_OUTER_UP_LEFT]\r\n\r\n            brow_up_left = clamp(brow_inner_up + brow_outer_up_left, 0.0, 1.0)\r\n            brow_up_right = clamp(brow_inner_up + brow_outer_up_right, 0.0, 1.0)\r\n            pose[self.eyebrow_raised_left_index] = brow_up_left\r\n            pose[self.eyebrow_raised_right_index] = brow_up_right\r\n\r\n            if self.args.eyebrow_down_max <= 0.0:\r\n                brow_down_left = 0.0\r\n                brow_down_right = 0.0\r\n            else:\r\n                brow_down_left = (1.0 - smile_degree) \\\r\n                                 * clamp(blendshape_params[BROW_DOWN_LEFT] / self.args.eyebrow_down_max, 0.0, 1.0)\r\n                brow_down_right = (1.0 - smile_degree) \\\r\n                                  * clamp(blendshape_params[BROW_DOWN_RIGHT] / self.args.eyebrow_down_max, 0.0, 1.0)\r\n\r\n            if self.args.eyebrow_down_mode == EyebrowDownMode.TROUBLED:\r\n                pose[self.eyebrow_troubled_left_index] = brow_down_left\r\n                pose[self.eyebrow_troubled_right_index] = brow_down_right\r\n            elif self.args.eyebrow_down_mode == EyebrowDownMode.ANGRY:\r\n                pose[self.eyebrow_angry_left_index] = brow_down_left\r\n                pose[self.eyebrow_angry_right_index] = brow_down_right\r\n            elif self.args.eyebrow_down_mode == EyebrowDownMode.LOWERED:\r\n                pose[self.eyebrow_lowered_left_index] = brow_down_left\r\n                pose[self.eyebrow_lowered_right_index] = brow_down_right\r\n            elif self.args.eyebrow_down_mode == EyebrowDownMode.SERIOUS:\r\n                pose[self.eyebrow_serious_left_index] = brow_down_left\r\n                
pose[self.eyebrow_serious_right_index] = brow_down_right\r\n\r\n            brow_happy_value = clamp(smile_value, 0.0, 1.0) * smile_degree\r\n            pose[self.eyebrow_happy_left_index] = brow_happy_value\r\n            pose[self.eyebrow_happy_right_index] = brow_happy_value\r\n\r\n        # Eye\r\n        if True:\r\n            # Surprised\r\n            if self.args.eye_surprised_max <= 0.0:\r\n                pose[self.eye_surprised_left_index] = 0.0\r\n                pose[self.eye_surprised_right_index] = 0.0\r\n            else:\r\n                pose[self.eye_surprised_left_index] = clamp(\r\n                    blendshape_params[EYE_WIDE_LEFT] / self.args.eye_surprised_max, 0.0, 1.0)\r\n                pose[self.eye_surprised_right_index] = clamp(\r\n                    blendshape_params[EYE_WIDE_RIGHT] / self.args.eye_surprised_max, 0.0, 1.0)\r\n\r\n            # Wink\r\n            if self.args.wink_mode == WinkMode.NORMAL:\r\n                wink_left_index = self.eye_wink_left_index\r\n                wink_right_index = self.eye_wink_right_index\r\n            else:\r\n                wink_left_index = self.eye_relaxed_left_index\r\n                wink_right_index = self.eye_relaxed_right_index\r\n            if self.args.eye_blink_max <= 0:\r\n                pose[wink_left_index] = 0.0\r\n                pose[wink_right_index] = 0.0\r\n                pose[self.eye_happy_wink_left_index] = 0.0\r\n                pose[self.eye_happy_wink_right_index] = 0.0\r\n            else:\r\n                pose[wink_left_index] = (1.0 - smile_degree) * clamp(\r\n                    blendshape_params[EYE_BLINK_LEFT] / self.args.eye_blink_max, 0.0, 1.0)\r\n                pose[wink_right_index] = (1.0 - smile_degree) * clamp(\r\n                    blendshape_params[EYE_BLINK_RIGHT] / self.args.eye_blink_max, 0.0, 1.0)\r\n                pose[self.eye_happy_wink_left_index] = smile_degree * clamp(\r\n                    blendshape_params[EYE_BLINK_LEFT] / self.args.eye_blink_max, 0.0, 1.0)\r\n                pose[self.eye_happy_wink_right_index] = smile_degree * clamp(\r\n                    blendshape_params[EYE_BLINK_RIGHT] / self.args.eye_blink_max, 0.0, 1.0)\r\n\r\n            # Lower eyelid\r\n            cheek_squint_denom = self.args.cheek_squint_max - self.args.cheek_squint_min\r\n            if cheek_squint_denom <= 0.0:\r\n                pose[self.eye_raised_lower_eyelid_left_index] = 0.0\r\n                pose[self.eye_raised_lower_eyelid_right_index] = 0.0\r\n            else:\r\n                pose[self.eye_raised_lower_eyelid_left_index] = \\\r\n                    clamp(\r\n                        (blendshape_params[CHEEK_SQUINT_LEFT] - self.args.cheek_squint_min) / cheek_squint_denom,\r\n                        0.0, 1.0)\r\n                pose[self.eye_raised_lower_eyelid_right_index] = \\\r\n                    clamp(\r\n                        (blendshape_params[CHEEK_SQUINT_RIGHT] - self.args.cheek_squint_min) / cheek_squint_denom,\r\n                        0.0, 1.0)\r\n\r\n        # Iris rotation\r\n        if True:\r\n            eye_rotation_y = (blendshape_params[EYE_LOOK_IN_LEFT]\r\n                              - blendshape_params[EYE_LOOK_OUT_LEFT]\r\n                              - blendshape_params[EYE_LOOK_IN_RIGHT]\r\n                              + blendshape_params[EYE_LOOK_OUT_RIGHT]) / 2.0 * self.args.eye_rotation_factor\r\n            pose[self.iris_rotation_y_index] = clamp(eye_rotation_y, -1.0, 1.0)\r\n\r\n            eye_rotation_x = 
(blendshape_params[EYE_LOOK_UP_LEFT]\r\n                              + blendshape_params[EYE_LOOK_UP_RIGHT]\r\n                              - blendshape_params[EYE_LOOK_DOWN_LEFT]\r\n                              - blendshape_params[EYE_LOOK_DOWN_RIGHT]) / 2.0 * self.args.eye_rotation_factor\r\n            pose[self.iris_rotation_x_index] = clamp(eye_rotation_x, -1.0, 1.0)\r\n\r\n        # Iris size\r\n        if True:\r\n            pose[self.iris_small_left_index] = self.args.iris_small_left\r\n            pose[self.iris_small_right_index] = self.args.iris_small_right\r\n\r\n        # Head rotation\r\n        if True:\r\n            euler_angles = self.extract_euler_angles(mediapipe_face_pose)\r\n            euler_angles[0] -= self.args.head_x_offset\r\n            euler_angles[1] -= self.args.head_y_offset\r\n            euler_angles[2] -= self.args.head_z_offset\r\n\r\n            x_param = clamp(-euler_angles[0] * 180.0 / math.pi, -15.0, 15.0) / 15.0\r\n            pose[self.head_x_index] = x_param\r\n\r\n            y_param = clamp(-euler_angles[1] * 180.0 / math.pi, -10.0, 10.0) / 10.0\r\n            pose[self.head_y_index] = y_param\r\n            pose[self.body_y_index] = y_param\r\n\r\n            z_param = clamp(euler_angles[2] * 180.0 / math.pi, -15.0, 15.0) / 15.0\r\n            pose[self.neck_z_index] = z_param\r\n            pose[self.body_z_index] = z_param\r\n\r\n        # Mouth\r\n        if True:\r\n            jaw_open_denom = self.args.jaw_open_max - self.args.jaw_open_min\r\n            if jaw_open_denom <= 0:\r\n                mouth_open = 0.0\r\n            else:\r\n                mouth_open = clamp((blendshape_params[JAW_OPEN] - self.args.jaw_open_min) / jaw_open_denom, 0.0, 1.0)\r\n            pose[self.mouth_aaa_index] = mouth_open\r\n            pose[self.mouth_raised_corner_left_index] = clamp(smile_value, 0.0, 1.0)\r\n            pose[self.mouth_raised_corner_right_index] = clamp(smile_value, 0.0, 1.0)\r\n\r\n            is_mouth_open = mouth_open > 0.0\r\n            if not is_mouth_open:\r\n                if self.args.mouth_frown_max <= 0:\r\n                    mouth_frown_value = 0.0\r\n                else:\r\n                    mouth_frown_value = clamp(\r\n                        (blendshape_params[MOUTH_FROWN_LEFT] + blendshape_params[\r\n                            MOUTH_FROWN_RIGHT]) / self.args.mouth_frown_max, 0.0, 1.0)\r\n                pose[self.mouth_lowered_corner_left_index] = mouth_frown_value\r\n                pose[self.mouth_lowered_corner_right_index] = mouth_frown_value\r\n            else:\r\n                mouth_lower_down = clamp(\r\n                    blendshape_params[MOUTH_LOWER_DOWN_LEFT] + blendshape_params[MOUTH_LOWER_DOWN_RIGHT], 0.0, 1.0)\r\n                mouth_funnel = blendshape_params[MOUTH_FUNNEL]\r\n                mouth_pucker = blendshape_params[MOUTH_PUCKER]\r\n\r\n                mouth_point = [mouth_open, mouth_lower_down, mouth_funnel, mouth_pucker]\r\n\r\n                aaa_point = [1.0, 1.0, 0.0, 0.0]\r\n                iii_point = [0.0, 1.0, 0.0, 0.0]\r\n                uuu_point = [0.5, 0.3, 0.25, 0.75]\r\n                ooo_point = [1.0, 0.5, 0.5, 0.4]\r\n\r\n                decomp = numpy.array([0, 0, 0, 0])\r\n                M = numpy.array([\r\n                    aaa_point,\r\n                    iii_point,\r\n                    uuu_point,\r\n                    ooo_point\r\n                ])\r\n\r\n                def loss(decomp):\r\n                    return 
numpy.linalg.norm(numpy.matmul(decomp, M) - mouth_point) \\\r\n                        + 0.01 * numpy.linalg.norm(decomp, ord=1)\r\n\r\n                opt_result = scipy.optimize.minimize(\r\n                    loss, decomp, bounds=[(0.0, 1.0), (0.0, 1.0), (0.0, 1.0), (0.0, 1.0)])\r\n                decomp = opt_result[\"x\"]\r\n                restricted_decomp = [decomp.item(0), decomp.item(1), decomp.item(2), decomp.item(3)]\r\n                pose[self.mouth_aaa_index] = restricted_decomp[0]\r\n                pose[self.mouth_iii_index] = restricted_decomp[1]\r\n                mouth_funnel_denom = self.args.mouth_funnel_max - self.args.mouth_funnel_min\r\n                if mouth_funnel_denom <= 0:\r\n                    ooo_alpha = 0.0\r\n                    uo_value = 0.0\r\n                else:\r\n                    ooo_alpha = clamp((mouth_funnel - self.args.mouth_funnel_min) / mouth_funnel_denom, 0.0, 1.0)\r\n                    uo_value = clamp(restricted_decomp[2] + restricted_decomp[3], 0.0, 1.0)\r\n                pose[self.mouth_uuu_index] = uo_value * (1.0 - ooo_alpha)\r\n                pose[self.mouth_ooo_index] = uo_value * ooo_alpha\r\n\r\n        if self.panel is not None:\r\n            frequency = self.breathing_frequency_slider.GetValue()\r\n            if frequency == 0:\r\n                value = 0.0\r\n                pose[self.breathing_index] = value\r\n                self.breathing_start_time = time.time()\r\n            else:\r\n                period = 60.0 / frequency\r\n                now = time.time()\r\n                diff = now - self.breathing_start_time\r\n                frac = (diff % period) / period\r\n                value = (-math.cos(2 * math.pi * frac) + 1.0) / 2.0\r\n                pose[self.breathing_index] = value\r\n            self.breathing_gauge.SetValue(int(1000 * value))\r\n\r\n        return pose\r\n"
  },
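  {
    "path": "examples/mouth_shape_decomposition_sketch.py",
    "content": "# An illustrative sketch (hypothetical file, not part of the original codebase):\r\n# a self-contained rerun of the mouth-shape decomposition performed by the\r\n# Mediapipe pose converter above. The basis points are copied from that\r\n# converter; the input measurement below is made up.\r\nimport numpy\r\nimport scipy.optimize\r\n\r\n# Basis mouth shapes in (open, lower_down, funnel, pucker) coordinates.\r\nAAA_POINT = [1.0, 1.0, 0.0, 0.0]\r\nIII_POINT = [0.0, 1.0, 0.0, 0.0]\r\nUUU_POINT = [0.5, 0.3, 0.25, 0.75]\r\nOOO_POINT = [1.0, 0.5, 0.5, 0.4]\r\nM = numpy.array([AAA_POINT, III_POINT, UUU_POINT, OOO_POINT])\r\n\r\n\r\ndef decompose_mouth_point(mouth_point):\r\n    # Find weights in [0, 1] over the four basis shapes that reconstruct\r\n    # mouth_point, with a small L1 penalty that keeps the decomposition\r\n    # sparse, mirroring the converter's loss.\r\n    def loss(decomp):\r\n        residual = numpy.matmul(decomp, M) - mouth_point\r\n        return numpy.linalg.norm(residual) + 0.01 * numpy.linalg.norm(decomp, ord=1)\r\n\r\n    result = scipy.optimize.minimize(loss, numpy.zeros(4), bounds=[(0.0, 1.0)] * 4)\r\n    return result.x\r\n\r\n\r\nif __name__ == '__main__':\r\n    # A made-up measurement: a mostly open mouth with some pucker.\r\n    weights = decompose_mouth_point(numpy.array([0.8, 0.6, 0.2, 0.5]))\r\n    print('aaa, iii, uuu, ooo weights:', weights)\r\n"
  },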
  {
    "path": "src/tha4/nn/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/nn/common/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/nn/common/conv_block_factory.py",
    "content": "from typing import Optional\r\n\r\nfrom tha4.nn.conv import create_conv7_block_from_block_args, create_conv3_block_from_block_args, \\\r\n    create_downsample_block_from_block_args, create_conv3\r\nfrom tha4.nn.resnet_block import ResnetBlock\r\nfrom tha4.nn.resnet_block_seperable import ResnetBlockSeparable\r\nfrom tha4.nn.separable_conv import create_separable_conv7_block, create_separable_conv3_block, \\\r\n    create_separable_downsample_block, create_separable_conv3\r\nfrom tha4.nn.util import BlockArgs\r\n\r\n\r\nclass ConvBlockFactory:\r\n    def __init__(self,\r\n                 block_args: BlockArgs,\r\n                 use_separable_convolution: bool = False):\r\n        self.use_separable_convolution = use_separable_convolution\r\n        self.block_args = block_args\r\n\r\n    def create_conv3(self,\r\n                     in_channels: int,\r\n                     out_channels: int,\r\n                     bias: bool,\r\n                     initialization_method: Optional[str] = None):\r\n        if initialization_method is None:\r\n            initialization_method = self.block_args.initialization_method\r\n        if self.use_separable_convolution:\r\n            return create_separable_conv3(\r\n                in_channels, out_channels, bias, initialization_method, self.block_args.use_spectral_norm)\r\n        else:\r\n            return create_conv3(\r\n                in_channels, out_channels, bias, initialization_method, self.block_args.use_spectral_norm)\r\n\r\n    def create_conv7_block(self, in_channels: int, out_channels: int):\r\n        if self.use_separable_convolution:\r\n            return create_separable_conv7_block(in_channels, out_channels, self.block_args)\r\n        else:\r\n            return create_conv7_block_from_block_args(in_channels, out_channels, self.block_args)\r\n\r\n    def create_conv3_block(self, in_channels: int, out_channels: int):\r\n        if self.use_separable_convolution:\r\n            return create_separable_conv3_block(in_channels, out_channels, self.block_args)\r\n        else:\r\n            return create_conv3_block_from_block_args(in_channels, out_channels, self.block_args)\r\n\r\n    def create_downsample_block(self, in_channels: int, out_channels: int, is_output_1x1: bool):\r\n        if self.use_separable_convolution:\r\n            return create_separable_downsample_block(in_channels, out_channels, is_output_1x1, self.block_args)\r\n        else:\r\n            return create_downsample_block_from_block_args(in_channels, out_channels, is_output_1x1)\r\n\r\n    def create_resnet_block(self, num_channels: int, is_1x1: bool):\r\n        if self.use_separable_convolution:\r\n            return ResnetBlockSeparable.create(num_channels, is_1x1, block_args=self.block_args)\r\n        else:\r\n            return ResnetBlock.create(num_channels, is_1x1, block_args=self.block_args)"
  },
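  {
    "path": "examples/conv_block_factory_sketch.py",
    "content": "# An illustrative sketch (hypothetical file, not part of the original codebase).\r\n# It shows how ConvBlockFactory switches between standard and depthwise-separable\r\n# blocks behind one interface. It assumes the tha4 package is importable; the\r\n# BlockArgs configuration mirrors the default used by PoserArgs00.\r\nimport torch\r\n\r\nfrom tha4.nn.common.conv_block_factory import ConvBlockFactory\r\nfrom tha4.nn.nonlinearity_factory import ReLUFactory\r\nfrom tha4.nn.normalization import InstanceNorm2dFactory\r\nfrom tha4.nn.util import BlockArgs\r\n\r\nif __name__ == '__main__':\r\n    block_args = BlockArgs(\r\n        normalization_layer_factory=InstanceNorm2dFactory(),\r\n        nonlinearity_factory=ReLUFactory(inplace=True))\r\n    for use_separable_convolution in [False, True]:\r\n        factory = ConvBlockFactory(block_args, use_separable_convolution)\r\n        block = factory.create_conv3_block(in_channels=4, out_channels=8)\r\n        x = torch.zeros(1, 4, 32, 32)\r\n        # Both variants map (1, 4, 32, 32) to (1, 8, 32, 32); the separable\r\n        # one just uses fewer parameters.\r\n        print(use_separable_convolution, block(x).shape)\r\n"
  },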
  {
    "path": "src/tha4/nn/common/poser_args.py",
    "content": "from typing import Optional\r\n\r\nfrom torch.nn import Sigmoid, Sequential, Tanh\r\n\r\nfrom tha4.nn.conv import create_conv3, create_conv3_from_block_args\r\nfrom tha4.nn.nonlinearity_factory import ReLUFactory\r\nfrom tha4.nn.normalization import InstanceNorm2dFactory\r\nfrom tha4.nn.util import BlockArgs\r\n\r\n\r\nclass PoserArgs00:\r\n    def __init__(self,\r\n                 image_size: int,\r\n                 input_image_channels: int,\r\n                 output_image_channels: int,\r\n                 start_channels: int,\r\n                 num_pose_params: int,\r\n                 block_args: Optional[BlockArgs] = None):\r\n        self.num_pose_params = num_pose_params\r\n        self.start_channels = start_channels\r\n        self.output_image_channels = output_image_channels\r\n        self.input_image_channels = input_image_channels\r\n        self.image_size = image_size\r\n        if block_args is None:\r\n            self.block_args = BlockArgs(\r\n                normalization_layer_factory=InstanceNorm2dFactory(),\r\n                nonlinearity_factory=ReLUFactory(inplace=True))\r\n        else:\r\n            self.block_args = block_args\r\n\r\n    def create_alpha_block(self):\r\n        from torch.nn import Sequential\r\n        return Sequential(\r\n            create_conv3(\r\n                in_channels=self.start_channels,\r\n                out_channels=1,\r\n                bias=True,\r\n                initialization_method=self.block_args.initialization_method,\r\n                use_spectral_norm=False),\r\n            Sigmoid())\r\n\r\n    def create_all_channel_alpha_block(self):\r\n        from torch.nn import Sequential\r\n        return Sequential(\r\n            create_conv3(\r\n                in_channels=self.start_channels,\r\n                out_channels=self.output_image_channels,\r\n                bias=True,\r\n                initialization_method=self.block_args.initialization_method,\r\n                use_spectral_norm=False),\r\n            Sigmoid())\r\n\r\n    def create_color_change_block(self):\r\n        return Sequential(\r\n            create_conv3_from_block_args(\r\n                in_channels=self.start_channels,\r\n                out_channels=self.output_image_channels,\r\n                bias=True,\r\n                block_args=self.block_args),\r\n            Tanh())\r\n\r\n    def create_grid_change_block(self):\r\n        return create_conv3(\r\n            in_channels=self.start_channels,\r\n            out_channels=2,\r\n            bias=False,\r\n            initialization_method='zero',\r\n            use_spectral_norm=False)"
  },
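  {
    "path": "examples/poser_args_heads_sketch.py",
    "content": "# An illustrative sketch (hypothetical file, not part of the original codebase):\r\n# a quick shape check of the output heads defined by PoserArgs00. Assumes the\r\n# tha4 package is importable; the sizes below are made up.\r\nimport torch\r\n\r\nfrom tha4.nn.common.poser_args import PoserArgs00\r\n\r\nif __name__ == '__main__':\r\n    args = PoserArgs00(\r\n        image_size=64,\r\n        input_image_channels=4,\r\n        output_image_channels=4,\r\n        start_channels=16,\r\n        num_pose_params=12)\r\n    feature = torch.zeros(1, args.start_channels, 64, 64)\r\n    # The alpha heads squash to [0, 1] with a sigmoid, the color head squashes\r\n    # to [-1, 1] with a tanh, and the grid-change head is zero-initialized so\r\n    # that warping starts out as the identity.\r\n    print(args.create_alpha_block()(feature).shape)              # (1, 1, 64, 64)\r\n    print(args.create_all_channel_alpha_block()(feature).shape)  # (1, 4, 64, 64)\r\n    print(args.create_color_change_block()(feature).shape)       # (1, 4, 64, 64)\r\n    print(args.create_grid_change_block()(feature).shape)        # (1, 2, 64, 64)\r\n"
  },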
  {
    "path": "src/tha4/nn/common/poser_encoder_decoder_00.py",
    "content": "import math\r\nfrom typing import Optional, List\r\n\r\nimport torch\r\nfrom torch import Tensor\r\nfrom torch.nn import ModuleList, Module\r\n\r\nfrom tha4.nn.common.poser_args import PoserArgs00\r\nfrom tha4.nn.conv import create_conv3_block_from_block_args, create_downsample_block_from_block_args, \\\r\n    create_upsample_block_from_block_args\r\nfrom tha4.nn.nonlinearity_factory import ReLUFactory\r\nfrom tha4.nn.normalization import InstanceNorm2dFactory\r\nfrom tha4.nn.resnet_block import ResnetBlock\r\nfrom tha4.nn.util import BlockArgs\r\n\r\n\r\nclass PoserEncoderDecoder00Args(PoserArgs00):\r\n    def __init__(self,\r\n                 image_size: int,\r\n                 input_image_channels: int,\r\n                 output_image_channels: int,\r\n                 num_pose_params: int ,\r\n                 start_channels: int,\r\n                 bottleneck_image_size,\r\n                 num_bottleneck_blocks,\r\n                 max_channels: int,\r\n                 block_args: Optional[BlockArgs] = None):\r\n        super().__init__(\r\n            image_size, input_image_channels, output_image_channels, start_channels, num_pose_params, block_args)\r\n        self.max_channels = max_channels\r\n        self.num_bottleneck_blocks = num_bottleneck_blocks\r\n        self.bottleneck_image_size = bottleneck_image_size\r\n        assert bottleneck_image_size > 1\r\n\r\n        if block_args is None:\r\n            self.block_args = BlockArgs(\r\n                normalization_layer_factory=InstanceNorm2dFactory(),\r\n                nonlinearity_factory=ReLUFactory(inplace=True))\r\n        else:\r\n            self.block_args = block_args\r\n\r\n\r\nclass PoserEncoderDecoder00(Module):\r\n    def __init__(self, args: PoserEncoderDecoder00Args):\r\n        super().__init__()\r\n        self.args = args\r\n\r\n        self.num_levels = int(math.log2(args.image_size // args.bottleneck_image_size)) + 1\r\n\r\n        self.downsample_blocks = ModuleList()\r\n        self.downsample_blocks.append(\r\n            create_conv3_block_from_block_args(\r\n                args.input_image_channels,\r\n                args.start_channels,\r\n                args.block_args))\r\n        current_image_size = args.image_size\r\n        current_num_channels = args.start_channels\r\n        while current_image_size > args.bottleneck_image_size:\r\n            next_image_size = current_image_size // 2\r\n            next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)\r\n            self.downsample_blocks.append(create_downsample_block_from_block_args(\r\n                in_channels=current_num_channels,\r\n                out_channels=next_num_channels,\r\n                is_output_1x1=False,\r\n                block_args=args.block_args))\r\n            current_image_size = next_image_size\r\n            current_num_channels = next_num_channels\r\n        assert len(self.downsample_blocks) == self.num_levels\r\n\r\n        self.bottleneck_blocks = ModuleList()\r\n        self.bottleneck_blocks.append(create_conv3_block_from_block_args(\r\n            in_channels=current_num_channels + args.num_pose_params,\r\n            out_channels=current_num_channels,\r\n            block_args=args.block_args))\r\n        for i in range(1, args.num_bottleneck_blocks):\r\n            self.bottleneck_blocks.append(\r\n                ResnetBlock.create(\r\n                    num_channels=current_num_channels,\r\n                    is1x1=False,\r\n                    
block_args=args.block_args))\r\n\r\n        self.upsample_blocks = ModuleList()\r\n        while current_image_size < args.image_size:\r\n            next_image_size = current_image_size * 2\r\n            next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)\r\n            self.upsample_blocks.append(create_upsample_block_from_block_args(\r\n                in_channels=current_num_channels,\r\n                out_channels=next_num_channels,\r\n                block_args=args.block_args))\r\n            current_image_size = next_image_size\r\n            current_num_channels = next_num_channels\r\n\r\n    def get_num_output_channels_from_level(self, level: int):\r\n        return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level))\r\n\r\n    def get_num_output_channels_from_image_size(self, image_size: int):\r\n        return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels)\r\n\r\n    def forward(self, image: Tensor, pose: Optional[Tensor] = None) -> List[Tensor]:\r\n        if self.args.num_pose_params != 0:\r\n            assert pose is not None\r\n        else:\r\n            assert pose is None\r\n        outputs = []\r\n        feature = image\r\n        outputs.append(feature)\r\n        for block in self.downsample_blocks:\r\n            feature = block(feature)\r\n            outputs.append(feature)\r\n        if pose is not None:\r\n            n, c = pose.shape\r\n            pose = pose.view(n, c, 1, 1).repeat(1, 1, self.args.bottleneck_image_size, self.args.bottleneck_image_size)\r\n            feature = torch.cat([feature, pose], dim=1)\r\n        for block in self.bottleneck_blocks:\r\n            feature = block(feature)\r\n            outputs.append(feature)\r\n        for block in self.upsample_blocks:\r\n            feature = block(feature)\r\n            outputs.append(feature)\r\n        outputs.reverse()\r\n        return outputs\r\n"
  },
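  {
    "path": "examples/poser_encoder_decoder_sketch.py",
    "content": "# An illustrative sketch (hypothetical file, not part of the original codebase).\r\n# It runs PoserEncoderDecoder00 on dummy inputs to show what forward returns: a\r\n# list of feature maps that, because of the final reverse, runs from the last\r\n# upsampled feature back down to the raw input image. Assumes the tha4 package\r\n# is importable; all sizes are made up.\r\nimport torch\r\n\r\nfrom tha4.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00, PoserEncoderDecoder00Args\r\n\r\nif __name__ == '__main__':\r\n    args = PoserEncoderDecoder00Args(\r\n        image_size=64,\r\n        input_image_channels=4,\r\n        output_image_channels=4,\r\n        num_pose_params=12,\r\n        start_channels=16,\r\n        bottleneck_image_size=8,\r\n        num_bottleneck_blocks=3,\r\n        max_channels=128)\r\n    module = PoserEncoderDecoder00(args)\r\n    image = torch.zeros(1, 4, 64, 64)\r\n    pose = torch.zeros(1, 12)  # broadcast spatially over the bottleneck\r\n    for output in module.forward(image, pose):\r\n        print(output.shape)\r\n"
  },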
  {
    "path": "src/tha4/nn/common/poser_encoder_decoder_00_separable.py",
    "content": "import math\r\nfrom typing import Optional, List\r\n\r\nimport torch\r\nfrom torch import Tensor\r\nfrom torch.nn import ModuleList, Module\r\n\r\nfrom tha4.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00Args\r\nfrom tha4.nn.resnet_block_seperable import ResnetBlockSeparable\r\nfrom tha4.nn.separable_conv import create_separable_conv3_block, create_separable_downsample_block, \\\r\n    create_separable_upsample_block\r\n\r\n\r\nclass PoserEncoderDecoder00Separable(Module):\r\n    def __init__(self, args: PoserEncoderDecoder00Args):\r\n        super().__init__()\r\n        self.args = args\r\n\r\n        self.num_levels = int(math.log2(args.image_size // args.bottleneck_image_size)) + 1\r\n\r\n        self.downsample_blocks = ModuleList()\r\n        self.downsample_blocks.append(\r\n            create_separable_conv3_block(\r\n                args.input_image_channels,\r\n                args.start_channels,\r\n                args.block_args))\r\n        current_image_size = args.image_size\r\n        current_num_channels = args.start_channels\r\n        while current_image_size > args.bottleneck_image_size:\r\n            next_image_size = current_image_size // 2\r\n            next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)\r\n            self.downsample_blocks.append(create_separable_downsample_block(\r\n                in_channels=current_num_channels,\r\n                out_channels=next_num_channels,\r\n                is_output_1x1=False,\r\n                block_args=args.block_args))\r\n            current_image_size = next_image_size\r\n            current_num_channels = next_num_channels\r\n        assert len(self.downsample_blocks) == self.num_levels\r\n\r\n        self.bottleneck_blocks = ModuleList()\r\n        self.bottleneck_blocks.append(create_separable_conv3_block(\r\n            in_channels=current_num_channels + args.num_pose_params,\r\n            out_channels=current_num_channels,\r\n            block_args=args.block_args))\r\n        for i in range(1, args.num_bottleneck_blocks):\r\n            self.bottleneck_blocks.append(\r\n                ResnetBlockSeparable.create(\r\n                    num_channels=current_num_channels,\r\n                    is1x1=False,\r\n                    block_args=args.block_args))\r\n\r\n        self.upsample_blocks = ModuleList()\r\n        while current_image_size < args.image_size:\r\n            next_image_size = current_image_size * 2\r\n            next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)\r\n            self.upsample_blocks.append(create_separable_upsample_block(\r\n                in_channels=current_num_channels,\r\n                out_channels=next_num_channels,\r\n                block_args=args.block_args))\r\n            current_image_size = next_image_size\r\n            current_num_channels = next_num_channels\r\n\r\n    def get_num_output_channels_from_level(self, level: int):\r\n        return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level))\r\n\r\n    def get_num_output_channels_from_image_size(self, image_size: int):\r\n        return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels)\r\n\r\n    def forward(self, image: Tensor, pose: Optional[Tensor] = None) -> List[Tensor]:\r\n        if self.args.num_pose_params != 0:\r\n            assert pose is not None\r\n        else:\r\n            assert pose is None\r\n        outputs = 
[]\r\n        feature = image\r\n        outputs.append(feature)\r\n        for block in self.downsample_blocks:\r\n            feature = block(feature)\r\n            outputs.append(feature)\r\n        if pose is not None:\r\n            n, c = pose.shape\r\n            pose = pose.view(n, c, 1, 1).repeat(1, 1, self.args.bottleneck_image_size, self.args.bottleneck_image_size)\r\n            feature = torch.cat([feature, pose], dim=1)\r\n        for block in self.bottleneck_blocks:\r\n            feature = block(feature)\r\n            outputs.append(feature)\r\n        for block in self.upsample_blocks:\r\n            feature = block(feature)\r\n            outputs.append(feature)\r\n        outputs.reverse()\r\n        return outputs\r\n"
  },
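  {
    "path": "examples/separable_parameter_count_sketch.py",
    "content": "# An illustrative sketch (hypothetical file, not part of the original codebase).\r\n# It compares parameter counts of the standard and depthwise-separable\r\n# encoder-decoders built from the same arguments, which is presumably why the\r\n# separable variant exists for the small distilled student models. Assumes the\r\n# tha4 package is importable; all sizes are made up.\r\nimport torch\r\n\r\nfrom tha4.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00, PoserEncoderDecoder00Args\r\nfrom tha4.nn.common.poser_encoder_decoder_00_separable import PoserEncoderDecoder00Separable\r\n\r\n\r\ndef count_parameters(module: torch.nn.Module) -> int:\r\n    return sum(p.numel() for p in module.parameters())\r\n\r\n\r\nif __name__ == '__main__':\r\n    args = PoserEncoderDecoder00Args(\r\n        image_size=64,\r\n        input_image_channels=4,\r\n        output_image_channels=4,\r\n        num_pose_params=12,\r\n        start_channels=16,\r\n        bottleneck_image_size=8,\r\n        num_bottleneck_blocks=3,\r\n        max_channels=128)\r\n    print('standard: ', count_parameters(PoserEncoderDecoder00(args)))\r\n    print('separable:', count_parameters(PoserEncoderDecoder00Separable(args)))\r\n"
  },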
  {
    "path": "src/tha4/nn/common/resize_conv_encoder_decoder.py",
    "content": "import math\r\nfrom typing import Optional, List\r\n\r\nimport torch\r\nfrom torch import Tensor\r\nfrom torch.nn import Module, ModuleList, Sequential, Upsample\r\n\r\nfrom tha4.nn.common.conv_block_factory import ConvBlockFactory\r\nfrom tha4.nn.nonlinearity_factory import LeakyReLUFactory\r\nfrom tha4.nn.normalization import InstanceNorm2dFactory\r\nfrom tha4.nn.util import BlockArgs\r\n\r\n\r\nclass ResizeConvEncoderDecoderArgs:\r\n    def __init__(self,\r\n                 image_size: int,\r\n                 input_channels: int,\r\n                 start_channels: int,\r\n                 bottleneck_image_size,\r\n                 num_bottleneck_blocks,\r\n                 max_channels: int,\r\n                 block_args: Optional[BlockArgs] = None,\r\n                 upsample_mode: str = 'bilinear',\r\n                 use_separable_convolution=False):\r\n        self.use_separable_convolution = use_separable_convolution\r\n        self.upsample_mode = upsample_mode\r\n        self.block_args = block_args\r\n        self.max_channels = max_channels\r\n        self.num_bottleneck_blocks = num_bottleneck_blocks\r\n        self.bottleneck_image_size = bottleneck_image_size\r\n        self.start_channels = start_channels\r\n        self.image_size = image_size\r\n        self.input_channels = input_channels\r\n\r\n\r\nclass ResizeConvEncoderDecoder(Module):\r\n    def __init__(self, args: ResizeConvEncoderDecoderArgs):\r\n        super().__init__()\r\n        self.args = args\r\n\r\n        self.num_levels = int(math.log2(args.image_size // args.bottleneck_image_size)) + 1\r\n\r\n        conv_block_factory = ConvBlockFactory(args.block_args, args.use_separable_convolution)\r\n\r\n        self.downsample_blocks = ModuleList()\r\n        self.downsample_blocks.append(conv_block_factory.create_conv7_block(args.input_channels, args.start_channels))\r\n        current_image_size = args.image_size\r\n        current_num_channels = args.start_channels\r\n        while current_image_size > args.bottleneck_image_size:\r\n            next_image_size = current_image_size // 2\r\n            next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)\r\n            self.downsample_blocks.append(conv_block_factory.create_downsample_block(\r\n                in_channels=current_num_channels,\r\n                out_channels=next_num_channels,\r\n                is_output_1x1=False))\r\n            current_image_size = next_image_size\r\n            current_num_channels = next_num_channels\r\n        assert len(self.downsample_blocks) == self.num_levels\r\n\r\n        self.bottleneck_blocks = ModuleList()\r\n        for i in range(args.num_bottleneck_blocks):\r\n            self.bottleneck_blocks.append(conv_block_factory.create_resnet_block(current_num_channels, is_1x1=False))\r\n\r\n        self.output_image_sizes = [current_image_size]\r\n        self.output_num_channels = [current_num_channels]\r\n        self.upsample_blocks = ModuleList()\r\n        if args.upsample_mode == 'nearest':\r\n            align_corners = None\r\n        else:\r\n            align_corners = False\r\n        while current_image_size < args.image_size:\r\n            next_image_size = current_image_size * 2\r\n            next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)\r\n            self.upsample_blocks.append(\r\n                Sequential(\r\n                    Upsample(scale_factor=2, mode=args.upsample_mode, align_corners=align_corners),\r\n   
                 conv_block_factory.create_conv3_block(\r\n                        in_channels=current_num_channels, out_channels=next_num_channels)))\r\n            current_image_size = next_image_size\r\n            current_num_channels = next_num_channels\r\n            self.output_image_sizes.append(current_image_size)\r\n            self.output_num_channels.append(current_num_channels)\r\n\r\n    def get_num_output_channels_from_level(self, level: int):\r\n        return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level))\r\n\r\n    def get_num_output_channels_from_image_size(self, image_size: int):\r\n        return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels)\r\n\r\n    def forward(self, feature: Tensor) -> List[Tensor]:\r\n        outputs = []\r\n        for block in self.downsample_blocks:\r\n            feature = block(feature)\r\n        for block in self.bottleneck_blocks:\r\n            feature = block(feature)\r\n        outputs.append(feature)\r\n        for block in self.upsample_blocks:\r\n            feature = block(feature)\r\n            outputs.append(feature)\r\n        return outputs"
  },
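  {
    "path": "examples/resize_conv_encoder_decoder_sketch.py",
    "content": "# An illustrative sketch (hypothetical file, not part of the original codebase).\r\n# 'Resize convolution' here means Upsample followed by a 3x3 convolution block,\r\n# a common alternative to transposed convolution that avoids checkerboard\r\n# artifacts. Assumes the tha4 package is importable; sizes are made up.\r\nimport torch\r\n\r\nfrom tha4.nn.common.resize_conv_encoder_decoder import ResizeConvEncoderDecoder, ResizeConvEncoderDecoderArgs\r\nfrom tha4.nn.nonlinearity_factory import ReLUFactory\r\nfrom tha4.nn.normalization import InstanceNorm2dFactory\r\nfrom tha4.nn.util import BlockArgs\r\n\r\nif __name__ == '__main__':\r\n    args = ResizeConvEncoderDecoderArgs(\r\n        image_size=64,\r\n        input_channels=4,\r\n        start_channels=16,\r\n        bottleneck_image_size=8,\r\n        num_bottleneck_blocks=3,\r\n        max_channels=128,\r\n        block_args=BlockArgs(\r\n            normalization_layer_factory=InstanceNorm2dFactory(),\r\n            nonlinearity_factory=ReLUFactory(inplace=True)))\r\n    module = ResizeConvEncoderDecoder(args)\r\n    # Unlike the poser encoder-decoders, forward returns only the bottleneck\r\n    # feature followed by each upsampled feature, lowest resolution first.\r\n    for output in module.forward(torch.zeros(1, 4, 64, 64)):\r\n        print(output.shape)\r\n"
  },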
  {
    "path": "src/tha4/nn/common/resize_conv_unet.py",
    "content": "from typing import Optional, List\r\n\r\nimport torch\r\nfrom torch import Tensor\r\nfrom torch.nn import ModuleList, Module, Upsample\r\n\r\nfrom tha4.nn.common.conv_block_factory import ConvBlockFactory\r\nfrom tha4.nn.nonlinearity_factory import ReLUFactory\r\nfrom tha4.nn.normalization import InstanceNorm2dFactory\r\nfrom tha4.nn.util import BlockArgs\r\n\r\n\r\nclass ResizeConvUNetArgs:\r\n    def __init__(self,\r\n                 image_size: int,\r\n                 input_channels: int,\r\n                 start_channels: int,\r\n                 bottleneck_image_size: int,\r\n                 num_bottleneck_blocks: int,\r\n                 max_channels: int,\r\n                 upsample_mode: str = 'bilinear',\r\n                 block_args: Optional[BlockArgs] = None,\r\n                 use_separable_convolution: bool = False):\r\n        if block_args is None:\r\n            block_args = BlockArgs(\r\n                normalization_layer_factory=InstanceNorm2dFactory(),\r\n                nonlinearity_factory=ReLUFactory(inplace=False))\r\n\r\n        self.use_separable_convolution = use_separable_convolution\r\n        self.block_args = block_args\r\n        self.upsample_mode = upsample_mode\r\n        self.max_channels = max_channels\r\n        self.num_bottleneck_blocks = num_bottleneck_blocks\r\n        self.bottleneck_image_size = bottleneck_image_size\r\n        self.input_channels = input_channels\r\n        self.start_channels = start_channels\r\n        self.image_size = image_size\r\n\r\n\r\nclass ResizeConvUNet(Module):\r\n    def __init__(self, args: ResizeConvUNetArgs):\r\n        super().__init__()\r\n        self.args = args\r\n        conv_block_factory = ConvBlockFactory(args.block_args, args.use_separable_convolution)\r\n\r\n        self.downsample_blocks = ModuleList()\r\n        self.downsample_blocks.append(conv_block_factory.create_conv3_block(\r\n            self.args.input_channels,\r\n            self.args.start_channels))\r\n        current_channels = self.args.start_channels\r\n        current_size = self.args.image_size\r\n\r\n        size_to_channel = {\r\n            current_size: current_channels\r\n        }\r\n        while current_size > self.args.bottleneck_image_size:\r\n            next_size = current_size // 2\r\n            next_channels = min(self.args.max_channels, current_channels * 2)\r\n            self.downsample_blocks.append(conv_block_factory.create_downsample_block(\r\n                current_channels,\r\n                next_channels,\r\n                is_output_1x1=False))\r\n            current_size = next_size\r\n            current_channels = next_channels\r\n            size_to_channel[current_size] = current_channels\r\n\r\n        self.bottleneck_blocks = ModuleList()\r\n        for i in range(self.args.num_bottleneck_blocks):\r\n            self.bottleneck_blocks.append(conv_block_factory.create_resnet_block(current_channels, is_1x1=False))\r\n\r\n        self.output_image_sizes = [current_size]\r\n        self.output_num_channels = [current_channels]\r\n        self.upsample_blocks = ModuleList()\r\n        while current_size < self.args.image_size:\r\n            next_size = current_size * 2\r\n            next_channels = size_to_channel[next_size]\r\n            self.upsample_blocks.append(conv_block_factory.create_conv3_block(\r\n                current_channels + next_channels,\r\n                next_channels))\r\n            current_size = next_size\r\n            current_channels = 
next_channels\r\n            self.output_image_sizes.append(current_size)\r\n            self.output_num_channels.append(current_channels)\r\n\r\n        if args.upsample_mode == 'nearest':\r\n            align_corners = None\r\n        else:\r\n            align_corners = False\r\n        self.double_resolution = Upsample(scale_factor=2, mode=args.upsample_mode, align_corners=align_corners)\r\n\r\n    def forward(self, feature: Tensor) -> List[Tensor]:\r\n        downsampled_features = []\r\n        for block in self.downsample_blocks:\r\n            feature = block(feature)\r\n            downsampled_features.append(feature)\r\n\r\n        for block in self.bottleneck_blocks:\r\n            feature = block(feature)\r\n\r\n        outputs = [feature]\r\n        for i in range(0, len(self.upsample_blocks)):\r\n            feature = self.double_resolution(feature)\r\n            feature = torch.cat([feature, downsampled_features[-i - 2]], dim=1)\r\n            feature = self.upsample_blocks[i](feature)\r\n            outputs.append(feature)\r\n\r\n        return outputs"
  },
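  {
    "path": "examples/resize_conv_unet_sketch.py",
    "content": "# An illustrative sketch (hypothetical file, not part of the original codebase).\r\n# The U-Net variant differs from ResizeConvEncoderDecoder in that each upsampled\r\n# feature is concatenated with the matching downsampled feature before the next\r\n# convolution block. Assumes the tha4 package is importable; sizes are made up.\r\nimport torch\r\n\r\nfrom tha4.nn.common.resize_conv_unet import ResizeConvUNet, ResizeConvUNetArgs\r\n\r\nif __name__ == '__main__':\r\n    args = ResizeConvUNetArgs(\r\n        image_size=64,\r\n        input_channels=4,\r\n        start_channels=16,\r\n        bottleneck_image_size=8,\r\n        num_bottleneck_blocks=3,\r\n        max_channels=128)\r\n    module = ResizeConvUNet(args)\r\n    # Outputs run from the bottleneck resolution back up to the input size.\r\n    for output in module.forward(torch.zeros(1, 4, 64, 64)):\r\n        print(output.shape)\r\n"
  },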
  {
    "path": "src/tha4/nn/common/unet.py",
    "content": "import math\r\nfrom enum import Enum\r\nfrom typing import Optional, List\r\n\r\nimport torch\r\nfrom torch import zero_, Tensor\r\nfrom torch.nn import Module, GroupNorm, Sequential, SiLU, Conv2d, AvgPool2d, Linear, Dropout, ModuleList\r\nfrom torch.nn.functional import interpolate\r\n\r\nfrom tha4.shion.core.module_factory import ModuleFactory\r\n\r\n\r\nclass Identity(Module):\r\n    def __init__(self):\r\n        super().__init__()\r\n\r\n    def forward(self, x):\r\n        return x\r\n\r\n\r\nclass IdentityFactory(ModuleFactory):\r\n    def create(self) -> Module:\r\n        return Identity()\r\n\r\n\r\ndef init_to_zero(module: Module):\r\n    with torch.no_grad():\r\n        zero_(module.weight)\r\n        zero_(module.bias)\r\n    return module\r\n\r\n\r\nclass Upsample(Module):\r\n    def __init__(self, in_channels: int, out_channels: Optional[int] = None, use_conv: bool = False):\r\n        super().__init__()\r\n        if out_channels is None:\r\n            out_channels = in_channels\r\n        self.in_channels = in_channels\r\n        if use_conv or in_channels != out_channels:\r\n            self.postprocess = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)\r\n        else:\r\n            self.postprocess = Identity()\r\n\r\n    def forward(self, x):\r\n        assert x.shape[1] == self.in_channels\r\n        return self.postprocess(interpolate(x, scale_factor=2, mode=\"nearest\"))\r\n\r\n\r\nclass Downsample(Module):\r\n    def __init__(self, in_channels: int, out_channels: Optional[int] = None, use_conv: bool = False):\r\n        super().__init__()\r\n        if out_channels is None:\r\n            out_channels = in_channels\r\n        self.in_channels = in_channels\r\n        if use_conv or in_channels != out_channels:\r\n            self.op = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)\r\n        else:\r\n            self.op = AvgPool2d(kernel_size=2, stride=2)\r\n\r\n    def forward(self, x):\r\n        assert x.shape[1] == self.in_channels\r\n        return self.op(x)\r\n\r\n\r\ndef GroupNorm32(channels):\r\n    return GroupNorm(min(32, channels), channels)\r\n\r\n\r\nclass SamplingMode(Enum):\r\n    SAME_RESOLUTION = 0\r\n    UPSAMPLING = 1\r\n    DOWNSAMPING = 2\r\n\r\n\r\nclass ResBlockArgs:\r\n    def __init__(self,\r\n                 dropout_prob: float,\r\n                 use_cond0: bool = True,\r\n                 use_cond1: bool = False,\r\n                 init_conditioned_residual_to_zero: bool = False,\r\n                 use_conv_on_skip_connection: bool = False):\r\n        assert not use_cond1 or use_cond0\r\n        self.use_conv_on_skip_connection = use_conv_on_skip_connection\r\n        self.use_cond1 = use_cond1\r\n        self.use_cond0 = use_cond0\r\n        self.init_conditioned_residual_to_zero = init_conditioned_residual_to_zero\r\n        self.dropout_prob = dropout_prob\r\n\r\n\r\ndef apply_scaleshift(x: Tensor, scaleshift: Tensor, condition_bias: float = 1.0) -> Tensor:\r\n    assert len(scaleshift.shape) == 2\r\n    assert len(x.shape) == 4\r\n    assert x.shape[0] == scaleshift.shape[0]\r\n    assert 2 * x.shape[1] == scaleshift.shape[1]\r\n    scaleshift = scaleshift.reshape(scaleshift.shape[0], scaleshift.shape[1], 1, 1)\r\n    scale, shift = torch.chunk(scaleshift, 2, dim=1)\r\n    return x * (condition_bias + scale) + shift\r\n\r\n\r\nclass ResBlock(Module):\r\n    def __init__(self,\r\n                 in_channels: int,\r\n                 out_channels: int,\r\n            
     cond0_channels: Optional[int] = None,\r\n                 cond1_channels: Optional[int] = None,\r\n                 sampling_mode: SamplingMode = SamplingMode.SAME_RESOLUTION,\r\n                 dropout_prob: float = 0.1,\r\n                 condition_bias: float = 1.0):\r\n        super().__init__()\r\n        assert cond0_channels is not None or cond1_channels is None\r\n\r\n        self.in_channels = in_channels\r\n        self.out_channels = out_channels\r\n        self.sampling_mode = sampling_mode\r\n        self.cond0_channels = cond0_channels\r\n        self.cond1_channels = cond1_channels\r\n        self.condition_bias = condition_bias\r\n\r\n        if sampling_mode == SamplingMode.UPSAMPLING:\r\n            self.x_resample = Upsample(in_channels)\r\n            self.h_resample = Upsample(in_channels)\r\n        elif sampling_mode == SamplingMode.DOWNSAMPING:\r\n            self.x_resample = Downsample(in_channels)\r\n            self.h_resample = Downsample(in_channels)\r\n        else:\r\n            self.x_resample = Identity()\r\n            self.h_resample = Identity()\r\n\r\n        self.nonlinear = SiLU()\r\n\r\n        # Layers before conditioning\r\n        self.norm0 = GroupNorm32(in_channels)\r\n        self.conv0 = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)\r\n\r\n        # Conditioning layers\r\n        if cond0_channels is not None:\r\n            self.cond0_layers = Sequential(\r\n                SiLU(),\r\n                Linear(cond0_channels, 2 * out_channels))\r\n            self.norm1 = GroupNorm32(out_channels)\r\n            self.dropout = Dropout(dropout_prob)\r\n            self.conv1 = init_to_zero(Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1))\r\n        if cond1_channels is not None:\r\n            self.cond1_layers = Sequential(\r\n                SiLU(),\r\n                Linear(cond1_channels, 2 * out_channels))\r\n\r\n        # Skip layer\r\n        if in_channels == out_channels:\r\n            self.skip = Identity()\r\n        else:\r\n            self.skip = Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)\r\n\r\n    def forward(self, x: Tensor, cond0: Optional[Tensor] = None, cond1: Optional[Tensor] = None) -> Tensor:\r\n        assert self.cond0_channels is None or cond0 is not None\r\n        assert self.cond1_channels is None or cond1 is not None\r\n\r\n        h = self.conv0(self.h_resample(self.nonlinear(self.norm0(x))))\r\n        if self.cond0_channels is not None:\r\n            h = self.norm1(h)\r\n            h = apply_scaleshift(h, self.cond0_layers(cond0), self.condition_bias)\r\n            if self.cond1_channels is not None:\r\n                h = apply_scaleshift(h, self.cond1_layers(cond1), self.condition_bias)\r\n            h = self.conv1(self.dropout(self.nonlinear(h)))\r\n        return self.skip(self.x_resample(x)) + h\r\n\r\n\r\nclass AttentionBlockArgs:\r\n    def __init__(self,\r\n                 num_heads: Optional[int] = 1,\r\n                 num_head_channels: Optional[int] = None,\r\n                 use_new_attention_order: bool = False):\r\n        self.use_new_attention_order = use_new_attention_order\r\n        self.num_head_channels = num_head_channels\r\n        self.num_heads = num_heads\r\n\r\n\r\ndef qkv_attention_legacy(qkv: torch.Tensor, num_heads: int):\r\n    assert len(qkv.shape) == 3\r\n    B, W, L = qkv.shape\r\n    H = num_heads\r\n    assert W % (3 * H) == 0\r\n    C = W // (3 * H)\r\n    q, k, v = qkv.reshape(B * H, C 
* 3, L).split(C, dim=1)\r\n    scale = 1.0 / math.sqrt(math.sqrt(C))\r\n    weight = torch.einsum('bct,bcs->bts', q * scale, k * scale)\r\n    weight = torch.softmax(weight, dim=-1)\r\n    output = torch.einsum(\"bts,bcs->bct\", weight, v)\r\n    return output.reshape(B, H * C, L)\r\n\r\n\r\ndef qkv_attention(qkv: torch.Tensor, num_heads: int):\r\n    B, W, L = qkv.shape\r\n    H = num_heads\r\n    assert W % (3 * H) == 0\r\n    C = W // (3 * H)\r\n    q, k, v = qkv.chunk(3, dim=1)\r\n    scale = 1.0 / math.sqrt(math.sqrt(C))\r\n    weight = torch.einsum(\"bct,bcs->bts\", (q * scale).view(B * H, C, L), (k * scale).view(B * H, C, L))\r\n    weight = torch.softmax(weight, dim=-1)\r\n    output = torch.einsum(\"bts,bcs->bct\", weight, v.reshape(B * H, C, L))\r\n    return output.reshape(B, H * C, L)\r\n\r\n\r\nclass AttentionBlock(Module):\r\n    def __init__(self,\r\n                 num_channels: int,\r\n                 args: AttentionBlockArgs):\r\n        super().__init__()\r\n        self.use_new_attention_order = args.use_new_attention_order\r\n\r\n        if args.num_head_channels is None:\r\n            assert args.num_heads is not None\r\n            assert num_channels % args.num_heads == 0\r\n            self.num_heads = args.num_heads\r\n            self.num_head_channels = num_channels // self.num_heads\r\n        elif args.num_heads is None:\r\n            assert args.num_head_channels is not None\r\n            assert num_channels % args.num_head_channels == 0\r\n            self.num_heads = num_channels // args.num_head_channels\r\n            self.num_head_channels = args.num_head_channels\r\n\r\n        self.norm = GroupNorm32(num_channels)\r\n        self.qkv = Conv2d(num_channels, 3 * num_channels, kernel_size=1, stride=1, padding=0)\r\n        self.conv = Conv2d(num_channels, num_channels, kernel_size=1, stride=1, padding=0)\r\n        with torch.no_grad():\r\n            zero_(self.conv.weight)\r\n            zero_(self.conv.bias)\r\n\r\n    def forward(self, x: torch.Tensor):\r\n        assert len(x.shape) == 4\r\n        B, C, H, W = x.shape\r\n        qkv = self.qkv(self.norm(x)).reshape(B, 3 * C, H * W)\r\n        if self.use_new_attention_order:\r\n            h = qkv_attention(qkv, self.num_heads)\r\n        else:\r\n            h = qkv_attention_legacy(qkv, self.num_heads)\r\n        h = self.conv(h.reshape(B, C, H, W))\r\n        return x + h\r\n\r\n\r\nclass Arity3To1(Module):\r\n    def __init__(self, module: Module):\r\n        super().__init__()\r\n        self.module = module\r\n\r\n    def forward(self, x: Tensor, y: Optional[Tensor] = None, z: Optional[Tensor] = None):\r\n        return self.module(x)\r\n\r\n\r\nclass DownsamplingBlock(Module):\r\n    def __init__(self,\r\n                 in_channels: int,\r\n                 out_channels: int,\r\n                 cond0_channels: Optional[int],\r\n                 cond1_channels: Optional[int],\r\n                 num_res_blocks: int,\r\n                 dropout_prob: float,\r\n                 use_attention: bool,\r\n                 perform_downsampling: bool,\r\n                 resample_with_res_block: bool,\r\n                 use_conv_to_resample: bool,\r\n                 attention_block_args: AttentionBlockArgs,\r\n                 condition_bias: float = 1.0):\r\n        super().__init__()\r\n        self.use_attention = use_attention\r\n        self.res_blocks = ModuleList()\r\n        self.attention_blocks = ModuleList()\r\n        self.perform_downsampling = perform_downsampling\r\n        
self.output_channels = []\r\n        for j in range(num_res_blocks):\r\n            self.res_blocks.append(ResBlock(\r\n                in_channels=in_channels if j == 0 else out_channels,\r\n                out_channels=out_channels,\r\n                cond0_channels=cond0_channels,\r\n                cond1_channels=cond1_channels,\r\n                dropout_prob=dropout_prob,\r\n                condition_bias=condition_bias))\r\n            if use_attention:\r\n                self.attention_blocks.append(AttentionBlock(out_channels, attention_block_args))\r\n            self.output_channels.append(out_channels)\r\n        if perform_downsampling:\r\n            if resample_with_res_block:\r\n                self.downsample = ResBlock(\r\n                    in_channels=out_channels,\r\n                    out_channels=out_channels,\r\n                    cond0_channels=cond0_channels,\r\n                    cond1_channels=cond1_channels,\r\n                    dropout_prob=dropout_prob,\r\n                    sampling_mode=SamplingMode.DOWNSAMPING,\r\n                    condition_bias=condition_bias)\r\n            else:\r\n                self.downsample = Arity3To1(Downsample(out_channels, use_conv=use_conv_to_resample))\r\n            self.output_channels.append(out_channels)\r\n\r\n    def forward(self, h: Tensor, cond0: Optional[Tensor] = None, cond1: Optional[Tensor] = None) -> List[Tensor]:\r\n        hs = []\r\n        for i in range(len(self.res_blocks)):\r\n            h = self.res_blocks[i].forward(h, cond0, cond1)\r\n            if self.use_attention:\r\n                h = self.attention_blocks[i].forward(h)\r\n            hs.append(h)\r\n        if self.perform_downsampling:\r\n            hs.append(self.downsample(h, cond0, cond1))\r\n        return hs\r\n\r\n\r\nclass UpsamplingBlock(Module):\r\n    def __init__(self,\r\n                 in_channels: int,\r\n                 out_channels: int,\r\n                 cond0_channels: Optional[int],\r\n                 cond1_channels: Optional[int],\r\n                 num_resnet_blocks: int,\r\n                 skip_channels: List[int],\r\n                 dropout_prob: float,\r\n                 use_attention: bool,\r\n                 perform_upsampling: bool,\r\n                 resample_with_res_block: bool,\r\n                 use_conv_to_resample: bool,\r\n                 attention_block_args: AttentionBlockArgs,\r\n                 condition_bias: float = 1.0):\r\n        super().__init__()\r\n        self.use_attention = use_attention\r\n        self.resnet_blocks = ModuleList()\r\n        self.attention_blocks = ModuleList()\r\n        self.perform_upsampling = perform_upsampling\r\n        for i in range(num_resnet_blocks):\r\n            self.resnet_blocks.append(ResBlock(\r\n                in_channels=(in_channels if i == 0 else out_channels) + skip_channels[i],\r\n                out_channels=out_channels,\r\n                cond0_channels=cond0_channels,\r\n                cond1_channels=cond1_channels,\r\n                dropout_prob=dropout_prob,\r\n                condition_bias=condition_bias))\r\n            if use_attention:\r\n                self.attention_blocks.append(AttentionBlock(out_channels, attention_block_args))\r\n        if perform_upsampling:\r\n            if resample_with_res_block:\r\n                self.upsample = ResBlock(\r\n                    in_channels=out_channels,\r\n                    out_channels=out_channels,\r\n                    cond0_channels=cond0_channels,\r\n                    cond1_channels=cond1_channels,\r\n                    sampling_mode=SamplingMode.UPSAMPLING,\r\n                    dropout_prob=dropout_prob,\r\n                    condition_bias=condition_bias)\r\n            else:\r\n                self.upsample = Arity3To1(Upsample(out_channels, use_conv=use_conv_to_resample))\r\n\r\n    def forward(self,\r\n                h: Tensor,\r\n                skips: List[Tensor],\r\n                cond0: Optional[Tensor] = None,\r\n                cond1: Optional[Tensor] = None) -> Tensor:\r\n        for i in range(len(self.resnet_blocks)):\r\n            h = self.resnet_blocks[i].forward(torch.concat([h, skips[i]], dim=1), cond0, cond1)\r\n            if self.use_attention:\r\n                h = self.attention_blocks[i].forward(h)\r\n        if self.perform_upsampling:\r\n            h = self.upsample.forward(h, cond0, cond1)\r\n        return h\r\n\r\n\r\ndef compute_timestep_embedding(t: Tensor, out_channels: int):\r\n    assert len(t.shape) == 2\r\n    b, c = t.shape\r\n    assert c == 1\r\n    half_channels = out_channels // 2\r\n    scale = -math.log(10000.0) / (half_channels - 1)\r\n    log_times = scale * torch.arange(0, half_channels, device=t.device)\r\n    times = torch.exp(log_times).reshape(1, half_channels) * t\r\n    t_emb = torch.cat([torch.cos(times), torch.sin(times)], dim=1)\r\n    if out_channels % 2 == 1:\r\n        t_emb = torch.nn.functional.pad(t_emb, (0, 1), mode='constant')\r\n    return t_emb\r\n\r\n\r\nclass TimeEmbedding(Module):\r\n    def __init__(self, out_channels: int):\r\n        super().__init__()\r\n        self.out_channels = out_channels\r\n\r\n    def forward(self, t: Tensor):\r\n        return compute_timestep_embedding(t, self.out_channels)\r\n\r\n\r\nclass UnetArgs:\r\n    def __init__(self,\r\n                 in_channels: int = 3,\r\n                 out_channels: int = 3,\r\n                 model_channels: int = 64,\r\n                 level_channel_multipliers: Optional[List[int]] = None,\r\n                 level_use_attention: Optional[List[bool]] = None,\r\n                 num_res_blocks_per_level: int = 2,\r\n                 num_middle_res_blocks: int = 2,\r\n                 time_embedding_channels: Optional[int] = None,\r\n                 cond_input_channels: int = 4,\r\n                 cond_internal_channels: int = 512,\r\n                 attention_block_args: Optional[AttentionBlockArgs] = None,\r\n                 dropout_prob: float = 0.1,\r\n                 resample_with_res_block: bool = True,\r\n                 use_conv_to_resample=False,\r\n                 condition_bias: float = 1.0):\r\n        if time_embedding_channels is None:\r\n            time_embedding_channels = model_channels\r\n        if level_channel_multipliers is None:\r\n            level_channel_multipliers = [1, 2, 4, 8]\r\n        if level_use_attention is None:\r\n            level_use_attention = [False for _ in level_channel_multipliers]\r\n        if attention_block_args is None:\r\n            attention_block_args = AttentionBlockArgs(\r\n                num_heads=1,\r\n                num_head_channels=None,\r\n                use_new_attention_order=False)\r\n\r\n        assert len(level_channel_multipliers) == len(level_use_attention)\r\n        assert not use_conv_to_resample or not resample_with_res_block\r\n\r\n        self.condition_bias = condition_bias\r\n        self.use_conv_to_resample = use_conv_to_resample\r\n        self.resample_with_res_block = resample_with_res_block\r\n        
self.cond_internal_channels = cond_internal_channels\r\n        self.dropout_prob = dropout_prob\r\n        self.attention_block_args = attention_block_args\r\n        self.time_embedding_channels = time_embedding_channels\r\n        self.num_res_blocks_per_level = num_res_blocks_per_level\r\n        self.level_use_attention = level_use_attention\r\n        self.level_channel_multipliers = level_channel_multipliers\r\n        self.model_channels = model_channels\r\n        self.out_channels = out_channels\r\n        self.in_channels = in_channels\r\n        self.num_levels = len(level_channel_multipliers)\r\n        self.num_middle_res_blocks = num_middle_res_blocks\r\n        self.cond_input_channels = cond_input_channels\r\n\r\n\r\nclass Unet(Module):\r\n    def __init__(self, args: UnetArgs):\r\n        super().__init__()\r\n        self.args = args\r\n\r\n        self.time_embed = Sequential(\r\n            TimeEmbedding(self.args.time_embedding_channels),\r\n            Linear(self.args.time_embedding_channels, self.args.cond_internal_channels),\r\n            SiLU(),\r\n            Linear(self.args.cond_internal_channels, self.args.cond_internal_channels))\r\n\r\n        self.cond_embed = Sequential(\r\n            Linear(self.args.cond_input_channels, self.args.cond_internal_channels),\r\n            SiLU(),\r\n            Linear(self.args.cond_internal_channels, self.args.cond_internal_channels))\r\n\r\n        self.first_conv = Conv2d(args.in_channels, args.model_channels, kernel_size=3, stride=1, padding=1)\r\n        current_channels = args.model_channels\r\n        channels = [current_channels]\r\n\r\n        # Downsampling blocks\r\n        self.down_blocks = ModuleList()\r\n        for i in range(args.num_levels):\r\n            out_channels = args.model_channels * args.level_channel_multipliers[i]\r\n            perform_downsampling = i < args.num_levels - 1\r\n            down_block = DownsamplingBlock(\r\n                in_channels=current_channels,\r\n                out_channels=out_channels,\r\n                cond0_channels=args.cond_internal_channels,\r\n                cond1_channels=args.cond_internal_channels,\r\n                num_res_blocks=args.num_res_blocks_per_level,\r\n                dropout_prob=args.dropout_prob,\r\n                use_attention=args.level_use_attention[i],\r\n                perform_downsampling=perform_downsampling,\r\n                attention_block_args=args.attention_block_args,\r\n                resample_with_res_block=args.resample_with_res_block,\r\n                use_conv_to_resample=args.use_conv_to_resample,\r\n                condition_bias=args.condition_bias)\r\n            self.down_blocks.append(down_block)\r\n            current_channels = out_channels\r\n            channels += down_block.output_channels\r\n\r\n        # Middle blocks\r\n        self.middle_blocks = ModuleList()\r\n        for i in range(self.args.num_middle_res_blocks - 1):\r\n            self.middle_blocks.append(ResBlock(\r\n                in_channels=current_channels,\r\n                out_channels=current_channels,\r\n                cond0_channels=args.cond_internal_channels,\r\n                cond1_channels=args.cond_internal_channels,\r\n                dropout_prob=args.dropout_prob,\r\n                condition_bias=args.condition_bias))\r\n            self.middle_blocks.append(\r\n                Arity3To1(AttentionBlock(num_channels=current_channels, args=args.attention_block_args)))\r\n        self.middle_blocks.append(ResBlock(\r\n   
         in_channels=current_channels,\r\n            out_channels=current_channels,\r\n            cond0_channels=args.cond_internal_channels,\r\n            cond1_channels=args.cond_internal_channels,\r\n            dropout_prob=args.dropout_prob,\r\n            condition_bias=args.condition_bias))\r\n\r\n        # Upsampling blocks\r\n        self.up_blocks = ModuleList()\r\n        for i in reversed(range(args.num_levels)):\r\n            skip_channels = []\r\n            for j in range(args.num_res_blocks_per_level + 1):\r\n                skip_channels.append(channels.pop())\r\n            perform_upsampling = i > 0\r\n            out_channels = args.model_channels * args.level_channel_multipliers[i]\r\n            up_block = UpsamplingBlock(\r\n                in_channels=current_channels,\r\n                out_channels=out_channels,\r\n                cond0_channels=args.cond_internal_channels,\r\n                cond1_channels=args.cond_internal_channels,\r\n                num_resnet_blocks=args.num_res_blocks_per_level + 1,\r\n                skip_channels=skip_channels,\r\n                dropout_prob=args.dropout_prob,\r\n                use_attention=args.level_use_attention[i],\r\n                perform_upsampling=perform_upsampling,\r\n                attention_block_args=args.attention_block_args,\r\n                resample_with_res_block=args.resample_with_res_block,\r\n                use_conv_to_resample=args.use_conv_to_resample,\r\n                condition_bias=args.condition_bias)\r\n            self.up_blocks.append(up_block)\r\n            current_channels = out_channels\r\n        assert len(channels) == 0\r\n\r\n        self.last = Sequential(\r\n            GroupNorm32(current_channels),\r\n            SiLU(),\r\n            init_to_zero(Conv2d(current_channels, args.out_channels, kernel_size=3, stride=1, padding=1)))\r\n\r\n    def forward(self, x: Tensor, t: Tensor, cond: Tensor):\r\n        t_emb = self.time_embed(t)\r\n        cond_emb = self.cond_embed(cond)\r\n        hs = [self.first_conv(x)]\r\n        for block in self.down_blocks:\r\n            hs += block.forward(hs[-1], t_emb, cond_emb)\r\n        h = hs[-1]\r\n        for block in self.middle_blocks:\r\n            h = block(h, t_emb, cond_emb)\r\n        for block in self.up_blocks:\r\n            skips = []\r\n            for i in range(self.args.num_res_blocks_per_level + 1):\r\n                skips.append(hs.pop())\r\n            h = block.forward(h, skips, t_emb, cond_emb)\r\n        assert len(hs) == 0\r\n        return self.last(h)\r\n\r\n\r\nclass UnetWithFirstConvAddition(Module):\r\n    def __init__(self, args: UnetArgs):\r\n        super().__init__()\r\n        self.args = args\r\n\r\n        self.time_embed = Sequential(\r\n            TimeEmbedding(self.args.time_embedding_channels),\r\n            Linear(self.args.time_embedding_channels, self.args.cond_internal_channels),\r\n            SiLU(),\r\n            Linear(self.args.cond_internal_channels, self.args.cond_internal_channels))\r\n\r\n        self.cond_embed = Sequential(\r\n            Linear(self.args.cond_input_channels, self.args.cond_internal_channels),\r\n            SiLU(),\r\n            Linear(self.args.cond_internal_channels, self.args.cond_internal_channels))\r\n\r\n        self.first_conv = Conv2d(args.in_channels, args.model_channels, kernel_size=3, stride=1, padding=1)\r\n        current_channels = args.model_channels\r\n        channels = [current_channels]\r\n\r\n        # Downsampling blocks\r\n        
self.down_blocks = ModuleList()\r\n        for i in range(args.num_levels):\r\n            out_channels = args.model_channels * args.level_channel_multipliers[i]\r\n            perform_downsampling = i < args.num_levels - 1\r\n            down_block = DownsamplingBlock(\r\n                in_channels=current_channels,\r\n                out_channels=out_channels,\r\n                cond0_channels=args.cond_internal_channels,\r\n                cond1_channels=args.cond_internal_channels,\r\n                num_res_blocks=args.num_res_blocks_per_level,\r\n                dropout_prob=args.dropout_prob,\r\n                use_attention=args.level_use_attention[i],\r\n                perform_downsampling=perform_downsampling,\r\n                attention_block_args=args.attention_block_args,\r\n                resample_with_res_block=args.resample_with_res_block,\r\n                use_conv_to_resample=args.use_conv_to_resample,\r\n                condition_bias=args.condition_bias)\r\n            self.down_blocks.append(down_block)\r\n            current_channels = out_channels\r\n            channels += down_block.output_channels\r\n\r\n        # Middle blocks\r\n        self.middle_blocks = ModuleList()\r\n        for i in range(self.args.num_middle_res_blocks - 1):\r\n            self.middle_blocks.append(ResBlock(\r\n                in_channels=current_channels,\r\n                out_channels=current_channels,\r\n                cond0_channels=args.cond_internal_channels,\r\n                cond1_channels=args.cond_internal_channels,\r\n                dropout_prob=args.dropout_prob,\r\n                condition_bias=args.condition_bias))\r\n            self.middle_blocks.append(\r\n                Arity3To1(AttentionBlock(num_channels=current_channels, args=args.attention_block_args)))\r\n        self.middle_blocks.append(ResBlock(\r\n            in_channels=current_channels,\r\n            out_channels=current_channels,\r\n            cond0_channels=args.cond_internal_channels,\r\n            cond1_channels=args.cond_internal_channels,\r\n            dropout_prob=args.dropout_prob,\r\n            condition_bias=args.condition_bias))\r\n\r\n        # Upsampling blocks\r\n        self.up_blocks = ModuleList()\r\n        for i in reversed(range(args.num_levels)):\r\n            skip_channels = []\r\n            for j in range(args.num_res_blocks_per_level + 1):\r\n                skip_channels.append(channels.pop())\r\n            perform_upsampling = i > 0\r\n            out_channels = args.model_channels * args.level_channel_multipliers[i]\r\n            up_block = UpsamplingBlock(\r\n                in_channels=current_channels,\r\n                out_channels=out_channels,\r\n                cond0_channels=args.cond_internal_channels,\r\n                cond1_channels=args.cond_internal_channels,\r\n                num_resnet_blocks=args.num_res_blocks_per_level + 1,\r\n                skip_channels=skip_channels,\r\n                dropout_prob=args.dropout_prob,\r\n                use_attention=args.level_use_attention[i],\r\n                perform_upsampling=perform_upsampling,\r\n                attention_block_args=args.attention_block_args,\r\n                resample_with_res_block=args.resample_with_res_block,\r\n                use_conv_to_resample=args.use_conv_to_resample,\r\n                condition_bias=args.condition_bias)\r\n            self.up_blocks.append(up_block)\r\n            current_channels = out_channels\r\n        assert len(channels) == 0\r\n\r\n        
self.last = Sequential(\r\n            GroupNorm32(current_channels),\r\n            SiLU(),\r\n            init_to_zero(Conv2d(current_channels, args.out_channels, kernel_size=3, stride=1, padding=1)))\r\n\r\n    def forward(self, x: Tensor, t: Tensor, cond: Tensor, first_conv_addition: Tensor):\r\n        t_emb = self.time_embed(t)\r\n        cond_emb = self.cond_embed(cond)\r\n        first_conv = self.first_conv(x)\r\n        hs = [first_conv + first_conv_addition]\r\n        for block in self.down_blocks:\r\n            hs += block.forward(hs[-1], t_emb, cond_emb)\r\n        h = hs[-1]\r\n        for block in self.middle_blocks:\r\n            h = block(h, t_emb, cond_emb)\r\n        for block in self.up_blocks:\r\n            skips = []\r\n            for i in range(self.args.num_res_blocks_per_level + 1):\r\n                skips.append(hs.pop())\r\n            h = block.forward(h, skips, t_emb, cond_emb)\r\n        assert len(hs) == 0\r\n        return self.last(h)\r\n"
  },
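  {
    "path": "examples/conditional_unet_sketch.py",
    "content": "# An illustrative sketch (hypothetical file, not part of the original codebase):\r\n# a forward pass through the conditional U-Net above, which takes an image, a\r\n# diffusion-style timestep, and a conditioning vector. Assumes the tha4 package\r\n# is importable; all shapes and hyperparameters below are made up.\r\nimport torch\r\n\r\nfrom tha4.nn.common.unet import AttentionBlockArgs, Unet, UnetArgs\r\n\r\nif __name__ == '__main__':\r\n    args = UnetArgs(\r\n        in_channels=3,\r\n        out_channels=3,\r\n        model_channels=32,\r\n        level_channel_multipliers=[1, 2, 4],\r\n        level_use_attention=[False, False, True],\r\n        num_res_blocks_per_level=2,\r\n        num_middle_res_blocks=2,\r\n        cond_input_channels=4,\r\n        cond_internal_channels=128,\r\n        attention_block_args=AttentionBlockArgs(num_heads=4))\r\n    unet = Unet(args)\r\n    x = torch.zeros(2, 3, 64, 64)  # input images\r\n    t = torch.zeros(2, 1)          # timesteps, one scalar per batch item\r\n    cond = torch.zeros(2, 4)       # conditioning vectors\r\n    print(unet.forward(x, t, cond).shape)  # torch.Size([2, 3, 64, 64])\r\n"
  },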
  {
    "path": "src/tha4/nn/conv.py",
    "content": "from typing import Optional, Union, Callable\r\n\r\nfrom torch.nn import Conv2d, Module, Sequential, ConvTranspose2d\r\n\r\nfrom tha4.shion.core.module_factory import ModuleFactory\r\nfrom tha4.nn.nonlinearity_factory import resolve_nonlinearity_factory\r\nfrom tha4.nn.normalization import NormalizationLayerFactory\r\nfrom tha4.nn.util import wrap_conv_or_linear_module, BlockArgs\r\n\r\n\r\ndef create_conv7(in_channels: int, out_channels: int,\r\n                 bias: bool = False,\r\n                 initialization_method: Union[str, Callable[[Module], Module]] = 'he',\r\n                 use_spectral_norm: bool = False) -> Module:\r\n    return wrap_conv_or_linear_module(\r\n        Conv2d(in_channels, out_channels, kernel_size=7, stride=1, padding=3, bias=bias),\r\n        initialization_method,\r\n        use_spectral_norm)\r\n\r\n\r\ndef create_conv7_from_block_args(in_channels: int,\r\n                                 out_channels: int,\r\n                                 bias: bool = False,\r\n                                 block_args: Optional[BlockArgs] = None) -> Module:\r\n    if block_args is None:\r\n        block_args = BlockArgs()\r\n    return create_conv7(\r\n        in_channels, out_channels, bias,\r\n        block_args.initialization_method,\r\n        block_args.use_spectral_norm)\r\n\r\n\r\ndef create_conv3(in_channels: int,\r\n                 out_channels: int,\r\n                 bias: bool = False,\r\n                 initialization_method: Union[str, Callable[[Module], Module]] = 'he',\r\n                 use_spectral_norm: bool = False) -> Module:\r\n    return wrap_conv_or_linear_module(\r\n        Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=bias),\r\n        initialization_method,\r\n        use_spectral_norm)\r\n\r\n\r\ndef create_conv3_from_block_args(in_channels: int, out_channels: int,\r\n                                 bias: bool = False,\r\n                                 block_args: Optional[BlockArgs] = None):\r\n    if block_args is None:\r\n        block_args = BlockArgs()\r\n    return create_conv3(in_channels, out_channels, bias,\r\n                        block_args.initialization_method,\r\n                        block_args.use_spectral_norm)\r\n\r\n\r\ndef create_conv1(in_channels: int, out_channels: int,\r\n                 initialization_method: Union[str, Callable[[Module], Module]] = 'he',\r\n                 bias: bool = False,\r\n                 use_spectral_norm: bool = False) -> Module:\r\n    return wrap_conv_or_linear_module(\r\n        Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),\r\n        initialization_method,\r\n        use_spectral_norm)\r\n\r\n\r\ndef create_conv1_from_block_args(in_channels: int,\r\n                                 out_channels: int,\r\n                                 bias: bool = False,\r\n                                 block_args: Optional[BlockArgs] = None) -> Module:\r\n    if block_args is None:\r\n        block_args = BlockArgs()\r\n    return create_conv1(\r\n        in_channels=in_channels,\r\n        out_channels=out_channels,\r\n        initialization_method=block_args.initialization_method,\r\n        bias=bias,\r\n        use_spectral_norm=block_args.use_spectral_norm)\r\n\r\n\r\ndef create_conv7_block(in_channels: int, out_channels: int,\r\n                       initialization_method: Union[str, Callable[[Module], Module]] = 'he',\r\n                       nonlinearity_factory: Optional[ModuleFactory] = 
None,\r\n                       normalization_layer_factory: Optional[NormalizationLayerFactory] = None,\r\n                       use_spectral_norm: bool = False) -> Module:\r\n    nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)\r\n    return Sequential(\r\n        create_conv7(in_channels, out_channels,\r\n                     bias=False, initialization_method=initialization_method, use_spectral_norm=use_spectral_norm),\r\n        NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(out_channels, affine=True),\r\n        resolve_nonlinearity_factory(nonlinearity_factory).create())\r\n\r\n\r\ndef create_conv7_block_from_block_args(\r\n        in_channels: int, out_channels: int,\r\n        block_args: Optional[BlockArgs] = None) -> Module:\r\n    if block_args is None:\r\n        block_args = BlockArgs()\r\n    return create_conv7_block(in_channels, out_channels,\r\n                              block_args.initialization_method,\r\n                              block_args.nonlinearity_factory,\r\n                              block_args.normalization_layer_factory,\r\n                              block_args.use_spectral_norm)\r\n\r\n\r\ndef create_conv3_block(in_channels: int, out_channels: int,\r\n                       initialization_method: Union[str, Callable[[Module], Module]] = 'he',\r\n                       nonlinearity_factory: Optional[ModuleFactory] = None,\r\n                       normalization_layer_factory: Optional[NormalizationLayerFactory] = None,\r\n                       use_spectral_norm: bool = False) -> Module:\r\n    nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)\r\n    return Sequential(\r\n        create_conv3(in_channels, out_channels,\r\n                     bias=False, initialization_method=initialization_method, use_spectral_norm=use_spectral_norm),\r\n        NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(out_channels, affine=True),\r\n        resolve_nonlinearity_factory(nonlinearity_factory).create())\r\n\r\n\r\ndef create_conv3_block_from_block_args(\r\n        in_channels: int, out_channels: int, block_args: Optional[BlockArgs] = None):\r\n    if block_args is None:\r\n        block_args = BlockArgs()\r\n    return create_conv3_block(in_channels, out_channels,\r\n                              block_args.initialization_method,\r\n                              block_args.nonlinearity_factory,\r\n                              block_args.normalization_layer_factory,\r\n                              block_args.use_spectral_norm)\r\n\r\n\r\ndef create_downsample_block(in_channels: int, out_channels: int,\r\n                            is_output_1x1: bool = False,\r\n                            initialization_method: Union[str, Callable[[Module], Module]] = 'he',\r\n                            nonlinearity_factory: Optional[ModuleFactory] = None,\r\n                            normalization_layer_factory: Optional[NormalizationLayerFactory] = None,\r\n                            use_spectral_norm: bool = False) -> Module:\r\n    if is_output_1x1:\r\n        return Sequential(\r\n            wrap_conv_or_linear_module(\r\n                Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),\r\n                initialization_method,\r\n                use_spectral_norm),\r\n            resolve_nonlinearity_factory(nonlinearity_factory).create())\r\n    else:\r\n        return Sequential(\r\n            wrap_conv_or_linear_module(\r\n    
            Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),\r\n                initialization_method,\r\n                use_spectral_norm),\r\n            NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(out_channels, affine=True),\r\n            resolve_nonlinearity_factory(nonlinearity_factory).create())\r\n\r\n\r\ndef create_downsample_block_from_block_args(in_channels: int, out_channels: int,\r\n                                            is_output_1x1: bool = False,\r\n                                            block_args: Optional[BlockArgs] = None):\r\n    if block_args is None:\r\n        block_args = BlockArgs()\r\n    return create_downsample_block(\r\n        in_channels, out_channels,\r\n        is_output_1x1,\r\n        block_args.initialization_method,\r\n        block_args.nonlinearity_factory,\r\n        block_args.normalization_layer_factory,\r\n        block_args.use_spectral_norm)\r\n\r\n\r\ndef create_upsample_block(in_channels: int,\r\n                          out_channels: int,\r\n                          initialization_method: Union[str, Callable[[Module], Module]] = 'he',\r\n                          nonlinearity_factory: Optional[ModuleFactory] = None,\r\n                          normalization_layer_factory: Optional[NormalizationLayerFactory] = None,\r\n                          use_spectral_norm: bool = False) -> Module:\r\n    nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)\r\n    return Sequential(\r\n        wrap_conv_or_linear_module(\r\n            ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),\r\n            initialization_method,\r\n            use_spectral_norm),\r\n        NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(out_channels, affine=True),\r\n        resolve_nonlinearity_factory(nonlinearity_factory).create())\r\n\r\n\r\ndef create_upsample_block_from_block_args(in_channels: int,\r\n                                          out_channels: int,\r\n                                          block_args: Optional[BlockArgs] = None) -> Module:\r\n    if block_args is None:\r\n        block_args = BlockArgs()\r\n    return create_upsample_block(in_channels, out_channels,\r\n                                 block_args.initialization_method,\r\n                                 block_args.nonlinearity_factory,\r\n                                 block_args.normalization_layer_factory,\r\n                                 block_args.use_spectral_norm)\r\n"
  },
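A usage sketch for the block factories above (illustrative, not part of the repo; it assumes `BlockArgs()` defaults leave the normalization and nonlinearity factories unset, which the resolve helpers turn into instance norm and ReLU):

```python
import torch
from tha4.nn.conv import (
    create_conv3_block_from_block_args,
    create_downsample_block_from_block_args,
    create_upsample_block_from_block_args,
)

# Build a conv block plus a stride-2 downsample / transposed-conv upsample pair.
block = create_conv3_block_from_block_args(in_channels=4, out_channels=16)
down = create_downsample_block_from_block_args(in_channels=16, out_channels=32)
up = create_upsample_block_from_block_args(in_channels=32, out_channels=16)

x = torch.randn(1, 4, 64, 64)
y = up(down(block(x)))
# The downsample halves the spatial size and the upsample doubles it back.
assert y.shape == (1, 16, 64, 64)
```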
  {
    "path": "src/tha4/nn/eyebrow_decomposer/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/nn/eyebrow_decomposer/eyebrow_decomposer_00.py",
    "content": "from typing import List, Optional\r\n\r\nimport torch\r\nfrom torch import Tensor\r\nfrom torch.nn import Module\r\n\r\nfrom tha4.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00Args, PoserEncoderDecoder00\r\nfrom tha4.nn.image_processing_util import apply_color_change\r\nfrom tha4.shion.core.module_factory import ModuleFactory\r\nfrom tha4.nn.nonlinearity_factory import ReLUFactory\r\nfrom tha4.nn.normalization import InstanceNorm2dFactory\r\nfrom tha4.nn.util import BlockArgs\r\n\r\n\r\nclass EyebrowDecomposer00Args(PoserEncoderDecoder00Args):\r\n    def __init__(self,\r\n                 image_size: int = 128,\r\n                 image_channels: int = 4,\r\n                 start_channels: int = 64,\r\n                 bottleneck_image_size=16,\r\n                 num_bottleneck_blocks=6,\r\n                 max_channels: int = 512,\r\n                 block_args: Optional[BlockArgs] = None):\r\n        super().__init__(\r\n            image_size,\r\n            image_channels,\r\n            image_channels,\r\n            0,\r\n            start_channels,\r\n            bottleneck_image_size,\r\n            num_bottleneck_blocks,\r\n            max_channels,\r\n            block_args)\r\n\r\n\r\nclass EyebrowDecomposer00(Module):\r\n    def __init__(self, args: EyebrowDecomposer00Args):\r\n        super().__init__()\r\n        self.args = args\r\n        self.body = PoserEncoderDecoder00(args)\r\n        self.background_layer_alpha = self.args.create_alpha_block()\r\n        self.background_layer_color_change = self.args.create_color_change_block()\r\n        self.eyebrow_layer_alpha = self.args.create_alpha_block()\r\n        self.eyebrow_layer_color_change = self.args.create_color_change_block()\r\n\r\n    def forward(self, image: Tensor, *args) -> List[Tensor]:\r\n        feature = self.body(image)[0]\r\n\r\n        background_layer_alpha = self.background_layer_alpha(feature)\r\n        background_layer_color_change = self.background_layer_color_change(feature)\r\n        background_layer_1 = apply_color_change(background_layer_alpha, background_layer_color_change, image)\r\n\r\n        eyebrow_layer_alpha = self.eyebrow_layer_alpha(feature)\r\n        eyebrow_layer_color_change = self.eyebrow_layer_color_change(feature)\r\n        eyebrow_layer = apply_color_change(eyebrow_layer_alpha, image, eyebrow_layer_color_change)\r\n\r\n        return [\r\n            eyebrow_layer,  # 0\r\n            eyebrow_layer_alpha,  # 1\r\n            eyebrow_layer_color_change,  # 2\r\n            background_layer_1,  # 3\r\n            background_layer_alpha,  # 4\r\n            background_layer_color_change,  # 5\r\n        ]\r\n\r\n    EYEBROW_LAYER_INDEX = 0\r\n    EYEBROW_LAYER_ALPHA_INDEX = 1\r\n    EYEBROW_LAYER_COLOR_CHANGE_INDEX = 2\r\n    BACKGROUND_LAYER_INDEX = 3\r\n    BACKGROUND_LAYER_ALPHA_INDEX = 4\r\n    BACKGROUND_LAYER_COLOR_CHANGE_INDEX = 5\r\n    OUTPUT_LENGTH = 6\r\n\r\n\r\nclass EyebrowDecomposer00Factory(ModuleFactory):\r\n    def __init__(self, args: EyebrowDecomposer00Args):\r\n        super().__init__()\r\n        self.args = args\r\n\r\n    def create(self) -> Module:\r\n        return EyebrowDecomposer00(self.args)\r\n"
  },
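A minimal forward-pass sketch for the decomposer (illustrative only; it assumes the default `EyebrowDecomposer00Args()` and pixel values in [-1, 1], the range used elsewhere in this codebase):

```python
import torch
from tha4.nn.eyebrow_decomposer.eyebrow_decomposer_00 import (
    EyebrowDecomposer00,
    EyebrowDecomposer00Args,
)

# Split a 128x128 RGBA eyebrow crop into an eyebrow layer and a background layer.
decomposer = EyebrowDecomposer00(EyebrowDecomposer00Args()).eval()
image = torch.rand(1, 4, 128, 128) * 2.0 - 1.0

with torch.no_grad():
    outputs = decomposer(image)

assert len(outputs) == EyebrowDecomposer00.OUTPUT_LENGTH
eyebrow_layer = outputs[EyebrowDecomposer00.EYEBROW_LAYER_INDEX]
background_layer = outputs[EyebrowDecomposer00.BACKGROUND_LAYER_INDEX]
```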
  {
    "path": "src/tha4/nn/eyebrow_morphing_combiner/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/nn/eyebrow_morphing_combiner/eyebrow_morphing_combiner_00.py",
    "content": "from typing import List, Optional\r\n\r\nimport torch\r\nfrom torch import Tensor\r\nfrom torch.nn import Module\r\n\r\nfrom tha4.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00Args, PoserEncoderDecoder00\r\nfrom tha4.nn.image_processing_util import apply_color_change, apply_grid_change, apply_rgb_change\r\nfrom tha4.shion.core.module_factory import ModuleFactory\r\nfrom tha4.nn.nonlinearity_factory import ReLUFactory\r\nfrom tha4.nn.normalization import InstanceNorm2dFactory\r\nfrom tha4.nn.util import BlockArgs\r\n\r\n\r\nclass EyebrowMorphingCombiner00Args(PoserEncoderDecoder00Args):\r\n    def __init__(self,\r\n                 image_size: int = 128,\r\n                 image_channels: int = 4,\r\n                 num_pose_params: int = 12,\r\n                 start_channels: int = 64,\r\n                 bottleneck_image_size=16,\r\n                 num_bottleneck_blocks=6,\r\n                 max_channels: int = 512,\r\n                 block_args: Optional[BlockArgs] = None):\r\n        super().__init__(\r\n            image_size,\r\n            2 * image_channels,\r\n            image_channels,\r\n            num_pose_params,\r\n            start_channels,\r\n            bottleneck_image_size,\r\n            num_bottleneck_blocks,\r\n            max_channels,\r\n            block_args)\r\n\r\n\r\nclass EyebrowMorphingCombiner00(Module):\r\n    def __init__(self, args: EyebrowMorphingCombiner00Args):\r\n        super().__init__()\r\n        self.args = args\r\n        self.body = PoserEncoderDecoder00(args)\r\n        self.morphed_eyebrow_layer_grid_change = self.args.create_grid_change_block()\r\n        self.morphed_eyebrow_layer_alpha = self.args.create_alpha_block()\r\n        self.morphed_eyebrow_layer_color_change = self.args.create_color_change_block()\r\n        self.combine_alpha = self.args.create_alpha_block()\r\n\r\n    def forward(self, background_layer: Tensor, eyebrow_layer: Tensor, pose: Tensor, *args) -> List[Tensor]:\r\n        combined_image = torch.cat([background_layer, eyebrow_layer], dim=1)\r\n        feature = self.body(combined_image, pose)[0]\r\n\r\n        morphed_eyebrow_layer_grid_change = self.morphed_eyebrow_layer_grid_change(feature)\r\n        morphed_eyebrow_layer_alpha = self.morphed_eyebrow_layer_alpha(feature)\r\n        morphed_eyebrow_layer_color_change = self.morphed_eyebrow_layer_color_change(feature)\r\n        warped_eyebrow_layer = apply_grid_change(morphed_eyebrow_layer_grid_change, eyebrow_layer)\r\n        morphed_eyebrow_layer = apply_color_change(\r\n            morphed_eyebrow_layer_alpha, morphed_eyebrow_layer_color_change, warped_eyebrow_layer)\r\n\r\n        combine_alpha = self.combine_alpha(feature)\r\n        eyebrow_image = apply_rgb_change(combine_alpha, morphed_eyebrow_layer, background_layer)\r\n        eyebrow_image_no_combine_alpha = apply_rgb_change(\r\n            (morphed_eyebrow_layer[:, 3:4, :, :] + 1.0) / 2.0, morphed_eyebrow_layer, background_layer)\r\n\r\n        return [\r\n            eyebrow_image,  # 0\r\n            combine_alpha,  # 1\r\n            eyebrow_image_no_combine_alpha,  # 2\r\n            morphed_eyebrow_layer,  # 3\r\n            morphed_eyebrow_layer_alpha,  # 4\r\n            morphed_eyebrow_layer_color_change,  # 5\r\n            warped_eyebrow_layer,  # 6\r\n            morphed_eyebrow_layer_grid_change,  # 7\r\n        ]\r\n\r\n    EYEBROW_IMAGE_INDEX = 0\r\n    COMBINE_ALPHA_INDEX = 1\r\n    EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX = 2\r\n    
MORPHED_EYEBROW_LAYER_INDEX = 3\r\n    MORPHED_EYEBROW_LAYER_ALPHA_INDEX = 4\r\n    MORPHED_EYEBROW_LAYER_COLOR_CHANGE_INDEX = 5\r\n    WARPED_EYEBROW_LAYER_INDEX = 6\r\n    MORPHED_EYEBROW_LAYER_GRID_CHANGE_INDEX = 7\r\n    OUTPUT_LENGTH = 8\r\n\r\n\r\nclass EyebrowMorphingCombiner00Factory(ModuleFactory):\r\n    def __init__(self, args: EyebrowMorphingCombiner00Args):\r\n        super().__init__()\r\n        self.args = args\r\n\r\n    def create(self) -> Module:\r\n        return EyebrowMorphingCombiner00(self.args)\r\n"
  },
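A matching sketch for the combiner (illustrative; default args). It consumes the decomposer's two layers plus a 12-parameter eyebrow pose; the layers are concatenated internally, which is why the args pass `2 * image_channels` to the encoder-decoder body:

```python
import torch
from tha4.nn.eyebrow_morphing_combiner.eyebrow_morphing_combiner_00 import (
    EyebrowMorphingCombiner00,
    EyebrowMorphingCombiner00Args,
)

combiner = EyebrowMorphingCombiner00(EyebrowMorphingCombiner00Args()).eval()
background_layer = torch.zeros(1, 4, 128, 128)
eyebrow_layer = torch.zeros(1, 4, 128, 128)
pose = torch.zeros(1, 12)  # 12 eyebrow pose parameters (the default)

with torch.no_grad():
    outputs = combiner(background_layer, eyebrow_layer, pose)

assert len(outputs) == EyebrowMorphingCombiner00.OUTPUT_LENGTH
eyebrow_image = outputs[EyebrowMorphingCombiner00.EYEBROW_IMAGE_INDEX]
```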
  {
    "path": "src/tha4/nn/face_morpher/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/nn/face_morpher/face_morpher_08.py",
    "content": "import math\r\nfrom typing import List, Optional\r\n\r\nimport torch\r\nfrom torch import Tensor\r\nfrom torch.nn import ModuleList, Sequential, Sigmoid, Tanh, Module\r\nfrom torch.nn.functional import affine_grid, grid_sample\r\n\r\nfrom tha4.shion.core.module_factory import ModuleFactory\r\nfrom tha4.nn.conv import create_conv3_block_from_block_args, \\\r\n    create_downsample_block_from_block_args, create_upsample_block_from_block_args, create_conv3_from_block_args, \\\r\n    create_conv3\r\nfrom tha4.nn.nonlinearity_factory import LeakyReLUFactory\r\nfrom tha4.nn.normalization import InstanceNorm2dFactory\r\nfrom tha4.nn.resnet_block import ResnetBlock\r\nfrom tha4.nn.util import BlockArgs\r\n\r\n\r\nclass FaceMorpher08Args:\r\n    def __init__(self,\r\n                 image_size: int = 256,\r\n                 image_channels: int = 4,\r\n                 num_expression_params: int = 67,\r\n                 start_channels: int = 16,\r\n                 bottleneck_image_size=4,\r\n                 num_bottleneck_blocks=3,\r\n                 max_channels: int = 512,\r\n                 block_args: Optional[BlockArgs] = None,\r\n                 output_iris_mouth_grid_change: bool = False):\r\n        self.max_channels = max_channels\r\n        self.num_bottleneck_blocks = num_bottleneck_blocks\r\n        assert bottleneck_image_size > 1\r\n        self.bottleneck_image_size = bottleneck_image_size\r\n        self.start_channels = start_channels\r\n        self.image_channels = image_channels\r\n        self.num_expression_params = num_expression_params\r\n        self.image_size = image_size\r\n        self.output_iris_mouth_grid_change = output_iris_mouth_grid_change\r\n\r\n        if block_args is None:\r\n            self.block_args = BlockArgs(\r\n                normalization_layer_factory=InstanceNorm2dFactory(),\r\n                nonlinearity_factory=LeakyReLUFactory(negative_slope=0.2, inplace=True))\r\n        else:\r\n            self.block_args = block_args\r\n\r\n\r\nclass FaceMorpher08(Module):\r\n    def __init__(self, args: FaceMorpher08Args):\r\n        super().__init__()\r\n        self.args = args\r\n        self.num_levels = int(math.log2(args.image_size // args.bottleneck_image_size)) + 1\r\n\r\n        self.downsample_blocks = ModuleList()\r\n        self.downsample_blocks.append(\r\n            create_conv3_block_from_block_args(\r\n                args.image_channels,\r\n                args.start_channels,\r\n                args.block_args))\r\n        current_image_size = args.image_size\r\n        current_num_channels = args.start_channels\r\n        while current_image_size > args.bottleneck_image_size:\r\n            next_image_size = current_image_size // 2\r\n            next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)\r\n            self.downsample_blocks.append(create_downsample_block_from_block_args(\r\n                in_channels=current_num_channels,\r\n                out_channels=next_num_channels,\r\n                is_output_1x1=False,\r\n                block_args=args.block_args))\r\n            current_image_size = next_image_size\r\n            current_num_channels = next_num_channels\r\n        assert len(self.downsample_blocks) == self.num_levels\r\n\r\n        self.bottleneck_blocks = ModuleList()\r\n        self.bottleneck_blocks.append(create_conv3_block_from_block_args(\r\n            in_channels=current_num_channels + args.num_expression_params,\r\n            
out_channels=current_num_channels,\r\n            block_args=args.block_args))\r\n        for i in range(1, args.num_bottleneck_blocks):\r\n            self.bottleneck_blocks.append(\r\n                ResnetBlock.create(\r\n                    num_channels=current_num_channels,\r\n                    is1x1=False,\r\n                    block_args=args.block_args))\r\n\r\n        self.upsample_blocks = ModuleList()\r\n        while current_image_size < args.image_size:\r\n            next_image_size = current_image_size * 2\r\n            next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)\r\n            self.upsample_blocks.append(create_upsample_block_from_block_args(\r\n                in_channels=current_num_channels,\r\n                out_channels=next_num_channels,\r\n                block_args=args.block_args))\r\n            current_image_size = next_image_size\r\n            current_num_channels = next_num_channels\r\n\r\n        self.iris_mouth_grid_change = self.create_grid_change_block()\r\n        self.iris_mouth_color_change = self.create_color_change_block()\r\n        self.iris_mouth_alpha = self.create_alpha_block()\r\n\r\n        self.eye_color_change = self.create_color_change_block()\r\n        self.eye_alpha = self.create_alpha_block()\r\n\r\n    def create_alpha_block(self):\r\n        return Sequential(\r\n            create_conv3(\r\n                in_channels=self.args.start_channels,\r\n                out_channels=1,\r\n                bias=True,\r\n                initialization_method=self.args.block_args.initialization_method,\r\n                use_spectral_norm=False),\r\n            Sigmoid())\r\n\r\n    def create_color_change_block(self):\r\n        return Sequential(\r\n            create_conv3_from_block_args(\r\n                in_channels=self.args.start_channels,\r\n                out_channels=self.args.image_channels,\r\n                bias=True,\r\n                block_args=self.args.block_args),\r\n            Tanh())\r\n\r\n    def create_grid_change_block(self):\r\n        return create_conv3(\r\n            in_channels=self.args.start_channels,\r\n            out_channels=2,\r\n            bias=False,\r\n            initialization_method='zero',\r\n            use_spectral_norm=False)\r\n\r\n    def get_num_output_channels_from_level(self, level: int):\r\n        return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level))\r\n\r\n    def get_num_output_channels_from_image_size(self, image_size: int):\r\n        return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels)\r\n\r\n    def merge_down(self, top_layer: Tensor, bottom_layer: Tensor):\r\n        top_layer_rgb = top_layer[:, 0:3, :, :]\r\n        top_layer_a = top_layer[:, 3:4, :, :]\r\n        return bottom_layer * (1 - top_layer_a) + torch.cat([top_layer_rgb * top_layer_a, top_layer_a], dim=1)\r\n\r\n    def apply_grid_change(self, grid_change, image: Tensor) -> Tensor:\r\n        n, c, h, w = image.shape\r\n        device = grid_change.device\r\n        grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2)\r\n        identity = torch.tensor(\r\n            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],\r\n            device=device,\r\n            dtype=grid_change.dtype).unsqueeze(0).repeat(n, 1, 1)\r\n        base_grid = affine_grid(identity, [n, c, h, w], align_corners=False)\r\n        grid = base_grid + grid_change\r\n        resampled_image = 
grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=False)\r\n        return resampled_image\r\n\r\n    def apply_color_change(self, alpha, color_change, image: Tensor) -> Tensor:\r\n        return color_change * alpha + image * (1 - alpha)\r\n\r\n    def forward(self, image: Tensor, pose: Tensor, *args) -> List[Tensor]:\r\n        feature = image\r\n        for block in self.downsample_blocks:\r\n            feature = block(feature)\r\n        n, c = pose.shape\r\n        pose = pose.view(n, c, 1, 1).repeat(1, 1, self.args.bottleneck_image_size, self.args.bottleneck_image_size)\r\n        feature = torch.cat([feature, pose], dim=1)\r\n        for block in self.bottleneck_blocks:\r\n            feature = block(feature)\r\n        for block in self.upsample_blocks:\r\n            feature = block(feature)\r\n\r\n        iris_mouth_grid_change = self.iris_mouth_grid_change(feature)\r\n        iris_mouth_image_0 = self.apply_grid_change(iris_mouth_grid_change, image)\r\n        iris_mouth_color_change = self.iris_mouth_color_change(feature)\r\n        iris_mouth_alpha = self.iris_mouth_alpha(feature)\r\n        iris_mouth_image_1 = self.apply_color_change(iris_mouth_alpha, iris_mouth_color_change, iris_mouth_image_0)\r\n\r\n        eye_color_change = self.eye_color_change(feature)\r\n        eye_alpha = self.eye_alpha(feature)\r\n        output_image = self.apply_color_change(eye_alpha, eye_color_change, iris_mouth_image_1.detach())\r\n\r\n        outputs = [\r\n            output_image,  # 0\r\n            eye_alpha,  # 1\r\n            eye_color_change,  # 2\r\n            iris_mouth_image_1,  # 3\r\n            iris_mouth_alpha,  # 4\r\n            iris_mouth_color_change,  # 5\r\n            iris_mouth_image_0,  # 6\r\n        ]\r\n\r\n        if self.args.output_iris_mouth_grid_change:\r\n            outputs.append(iris_mouth_grid_change)\r\n\r\n        return outputs\r\n\r\n    OUTPUT_IMAGE_INDEX = 0\r\n    EYE_ALPHA_INDEX = 1\r\n    EYE_COLOR_CHANGE_INDEX = 2\r\n    IRIS_MOUTH_IMAGE_1_INDEX = 3\r\n    IRIS_MOUTH_ALPHA_INDEX = 4\r\n    IRIS_MOUTH_COLOR_CHANGE_INDEX = 5\r\n    IRIS_MOUTH_IMAGE_0_INDEX = 6\r\n    IRIS_MOUTH_GRID_CHANGE_INDEX = 7\r\n\r\n\r\nclass FaceMorpher08Factory(ModuleFactory):\r\n    def __init__(self, args: FaceMorpher08Args):\r\n        super().__init__()\r\n        self.args = args\r\n\r\n    def create(self) -> Module:\r\n        return FaceMorpher08(self.args)\r\n"
  },
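Sketch of a `FaceMorpher08` forward pass (illustrative; default args). The 67-dimensional expression vector is broadcast to a 4x4 map and concatenated with the bottleneck feature, so changing `num_expression_params` changes the first bottleneck block's input width:

```python
import torch
from tha4.nn.face_morpher.face_morpher_08 import FaceMorpher08, FaceMorpher08Args

morpher = FaceMorpher08(FaceMorpher08Args()).eval()
image = torch.zeros(1, 4, 256, 256)  # RGBA face crop
pose = torch.zeros(1, 67)            # expression parameters

with torch.no_grad():
    outputs = morpher(image, pose)

output_image = outputs[FaceMorpher08.OUTPUT_IMAGE_INDEX]
assert output_image.shape == (1, 4, 256, 256)
```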
  {
    "path": "src/tha4/nn/image_processing_util.py",
    "content": "import torch\r\nfrom torch import Tensor\r\nfrom torch.nn.functional import affine_grid, grid_sample\r\n\r\n\r\ndef apply_rgb_change(alpha: Tensor, color_change: Tensor, image: Tensor):\r\n    image_rgb = image[:, 0:3, :, :]\r\n    color_change_rgb = color_change[:, 0:3, :, :]\r\n    output_rgb = color_change_rgb * alpha + image_rgb * (1 - alpha)\r\n    return torch.cat([output_rgb, image[:, 3:4, :, :]], dim=1)\r\n\r\n\r\ndef apply_grid_change(grid_change, image: Tensor) -> Tensor:\r\n    n, c, h, w = image.shape\r\n    device = grid_change.device\r\n    grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2)\r\n    identity = torch.tensor(\r\n        [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],\r\n        dtype=grid_change.dtype,\r\n        device=device).unsqueeze(0).repeat(n, 1, 1)\r\n    base_grid = affine_grid(identity, [n, c, h, w], align_corners=False)\r\n    grid = base_grid + grid_change\r\n    resampled_image = grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=False)\r\n    return resampled_image\r\n\r\n\r\nclass GridChangeApplier:\r\n    def __init__(self):\r\n        self.last_n = None\r\n        self.last_device = None\r\n        self.last_identity = None\r\n\r\n    def apply(self, grid_change: Tensor, image: Tensor, align_corners: bool = False) -> Tensor:\r\n        n, c, h, w = image.shape\r\n        device = grid_change.device\r\n        grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2)\r\n\r\n        if n == self.last_n and device == self.last_device:\r\n            identity = self.last_identity\r\n        else:\r\n            identity = torch.tensor(\r\n                [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],\r\n                dtype=grid_change.dtype,\r\n                device=device,\r\n                requires_grad=False) \\\r\n                .unsqueeze(0).repeat(n, 1, 1)\r\n            self.last_identity = identity\r\n            self.last_n = n\r\n            self.last_device = device\r\n        base_grid = affine_grid(identity, [n, c, h, w], align_corners=align_corners)\r\n\r\n        grid = base_grid + grid_change\r\n        resampled_image = grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=align_corners)\r\n        return resampled_image\r\n\r\n\r\ndef apply_color_change(alpha, color_change, image: Tensor) -> Tensor:\r\n    return color_change * alpha + image * (1 - alpha)\r\n"
  },
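A quick sanity check for the warping helpers (illustrative): a zero grid change adds nothing to the identity sampling grid, so the output matches the input up to float error, and `GridChangeApplier` agrees with the free function while caching the identity grid per batch size and device:

```python
import torch
from tha4.nn.image_processing_util import GridChangeApplier, apply_grid_change

image = torch.rand(2, 4, 32, 32)
grid_change = torch.zeros(2, 2, 32, 32)  # no predicted offsets

applier = GridChangeApplier()
out = applier.apply(grid_change, image)

assert torch.allclose(out, apply_grid_change(grid_change, image))
assert torch.allclose(out, image, atol=1e-5)  # identity warp

# A second call with the same batch size and device reuses the cached identity grid.
out2 = applier.apply(grid_change, image)
```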
  {
    "path": "src/tha4/nn/init_function.py",
    "content": "from typing import Callable\r\n\r\nimport torch\r\nfrom torch import zero_\r\nfrom torch.nn import Module\r\nfrom torch.nn.init import kaiming_normal_, xavier_normal_, normal_\r\n\r\n\r\ndef create_init_function(method: str = 'none') -> Callable[[Module], Module]:\r\n    def init(module: Module):\r\n        if method == 'none':\r\n            return module\r\n        elif method == 'he':\r\n            kaiming_normal_(module.weight)\r\n            return module\r\n        elif method == 'xavier':\r\n            xavier_normal_(module.weight)\r\n            return module\r\n        elif method == 'dcgan':\r\n            normal_(module.weight, 0.0, 0.02)\r\n            return module\r\n        elif method == 'dcgan_001':\r\n            normal_(module.weight, 0.0, 0.01)\r\n            return module\r\n        elif method == \"zero\":\r\n            with torch.no_grad():\r\n                zero_(module.weight)\r\n            return module\r\n        else:\r\n            raise (\"Invalid initialization method %s\" % method)\r\n\r\n    return init\r\n\r\n\r\nclass HeInitialization:\r\n    def __init__(self, a: int = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'):\r\n        self.nonlinearity = nonlinearity\r\n        self.mode = mode\r\n        self.a = a\r\n\r\n    def __call__(self, module: Module) -> Module:\r\n        with torch.no_grad():\r\n            kaiming_normal_(module.weight, a=self.a, mode=self.mode, nonlinearity=self.nonlinearity)\r\n        return module\r\n\r\n\r\nclass NormalInitialization:\r\n    def __init__(self, mean: float = 0.0, std: float = 1.0):\r\n        self.std = std\r\n        self.mean = mean\r\n\r\n    def __call__(self, module: Module) -> Module:\r\n        with torch.no_grad():\r\n            normal_(module.weight, self.mean, self.std)\r\n        return module\r\n\r\n\r\nclass XavierInitialization:\r\n    def __init__(self, gain: float = 1.0):\r\n        self.gain = gain\r\n\r\n    def __call__(self, module: Module) -> Module:\r\n        with torch.no_grad():\r\n            xavier_normal_(module.weight, self.gain)\r\n        return module\r\n\r\n\r\nclass ZeroInitialization:\r\n    def __call__(self, module: Module) -> Module:\r\n        with torch.no_grad:\r\n            zero_(module.weight)\r\n        return module\r\n\r\nclass NoInitialization:\r\n    def __call__(self, module: Module) -> Module:\r\n        return module"
  },
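The string-based and callable-based initializers are interchangeable: both take a module, mutate its `weight` in place, and return the module, which is what lets `initialization_method: Union[str, Callable[[Module], Module]]` work in the conv factories. An illustrative sketch:

```python
from torch.nn import Conv2d
from tha4.nn.init_function import create_init_function, HeInitialization

conv_a = create_init_function('he')(Conv2d(3, 8, kernel_size=3))
conv_b = HeInitialization()(Conv2d(3, 8, kernel_size=3))

# 'zero' is used for grid-change heads so they start as an identity warp.
conv_c = create_init_function('zero')(Conv2d(3, 8, kernel_size=3))
assert float(conv_c.weight.abs().sum()) == 0.0
```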
  {
    "path": "src/tha4/nn/morpher/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/nn/morpher/morpher_00.py",
    "content": "from typing import List\r\n\r\nimport torch\r\nfrom torch import Tensor\r\nfrom torch.nn import Module\r\n\r\nfrom tha4.shion.core.module_factory import ModuleFactory\r\nfrom tha4.nn.image_processing_util import GridChangeApplier\r\nfrom tha4.nn.common.unet import UnetArgs, Unet, AttentionBlockArgs\r\n\r\n\r\ndef apply_color_change(alpha, color_change, image: Tensor) -> Tensor:\r\n    return color_change * alpha + image * (1 - alpha)\r\n\r\n\r\nclass Morpher00Args:\r\n    def __init__(self,\r\n                 image_size: int,\r\n                 image_channels: int,\r\n                 num_pose_parameters: int,\r\n                 unet_args: UnetArgs):\r\n        assert unet_args.in_channels == image_channels\r\n        assert unet_args.out_channels == (\r\n                image_channels +  # direct\r\n                2 +  # warp\r\n                1  # alpha\r\n        )\r\n        assert unet_args.cond_input_channels == num_pose_parameters\r\n        self.image_channels = image_channels\r\n        self.image_size = image_size\r\n        self.num_pose_parameters = num_pose_parameters\r\n        self.unet_args = unet_args\r\n\r\n\r\nclass Morpher00(Module):\r\n    def __init__(self, args: Morpher00Args):\r\n        super().__init__()\r\n        self.args = args\r\n        self.body = Unet(args.unet_args)\r\n        self.grid_change_applier = GridChangeApplier()\r\n\r\n    def forward(self, image: torch.Tensor, pose: torch.Tensor) -> List[Tensor]:\r\n        assert len(image.shape) == 4\r\n        assert image.shape[1] == self.args.image_channels\r\n        assert image.shape[2] == self.args.image_size\r\n        assert image.shape[3] == self.args.image_size\r\n        assert len(pose.shape) == 2\r\n        assert image.shape[0] == pose.shape[0]\r\n        assert pose.shape[1] == self.args.num_pose_parameters\r\n\r\n        t = torch.zeros(image.shape[0], 1, device=image.device)\r\n        body_output = self.body(image, t, pose)\r\n        direct = body_output[:, 0:self.args.image_channels, :, :]\r\n        grid_change = body_output[:, self.args.image_channels:self.args.image_channels + 2, :, :]\r\n        alpha = torch.sigmoid(body_output[:, self.args.image_channels + 2:self.args.image_channels + 3, :, :])\r\n\r\n        warped = self.grid_change_applier.apply(grid_change, image)\r\n        merged = apply_color_change(alpha, direct, warped)\r\n\r\n        return [\r\n            merged,\r\n            alpha,\r\n            warped,\r\n            grid_change,\r\n            direct\r\n        ]\r\n\r\n    INDEX_MERGED = 0\r\n    INDEX_ALPHA = 1\r\n    INDEX_WARPED = 2\r\n    INDEX_GRID_CHANGE = 3\r\n    INDEX_DIRECT = 4\r\n\r\n\r\nclass Morpher00Factory(ModuleFactory):\r\n    def __init__(self, args: Morpher00Args):\r\n        self.args = args\r\n\r\n    def create(self) -> Module:\r\n        return Morpher00(self.args)\r\n"
  },
  {
    "path": "src/tha4/nn/nonlinearity_factory.py",
    "content": "from typing import Optional\r\n\r\nfrom torch.nn import Module, ReLU, LeakyReLU, ELU, ReLU6, Hardswish, SiLU, Tanh, Sigmoid\r\n\r\nfrom tha4.shion.core.module_factory import ModuleFactory\r\n\r\n\r\nclass ReLUFactory(ModuleFactory):\r\n    def __init__(self, inplace: bool = False):\r\n        self.inplace = inplace\r\n\r\n    def create(self) -> Module:\r\n        return ReLU(self.inplace)\r\n\r\n\r\nclass LeakyReLUFactory(ModuleFactory):\r\n    def __init__(self, inplace: bool = False, negative_slope: float = 1e-2):\r\n        self.negative_slope = negative_slope\r\n        self.inplace = inplace\r\n\r\n    def create(self) -> Module:\r\n        return LeakyReLU(inplace=self.inplace, negative_slope=self.negative_slope)\r\n\r\n\r\nclass ELUFactory(ModuleFactory):\r\n    def __init__(self, inplace: bool = False, alpha: float = 1.0):\r\n        self.alpha = alpha\r\n        self.inplace = inplace\r\n\r\n    def create(self) -> Module:\r\n        return ELU(inplace=self.inplace, alpha=self.alpha)\r\n\r\n\r\nclass ReLU6Factory(ModuleFactory):\r\n    def __init__(self, inplace: bool = False):\r\n        self.inplace = inplace\r\n\r\n    def create(self) -> Module:\r\n        return ReLU6(inplace=self.inplace)\r\n\r\n\r\nclass SiLUFactory(ModuleFactory):\r\n    def __init__(self, inplace: bool = False):\r\n        self.inplace = inplace\r\n\r\n    def create(self) -> Module:\r\n        return SiLU(inplace=self.inplace)\r\n\r\n\r\nclass HardswishFactory(ModuleFactory):\r\n    def __init__(self, inplace: bool = False):\r\n        self.inplace = inplace\r\n\r\n    def create(self) -> Module:\r\n        return Hardswish(inplace=self.inplace)\r\n\r\n\r\nclass TanhFactory(ModuleFactory):\r\n    def create(self) -> Module:\r\n        return Tanh()\r\n\r\n\r\nclass SigmoidFactory(ModuleFactory):\r\n    def create(self) -> Module:\r\n        return Sigmoid()\r\n\r\n\r\ndef resolve_nonlinearity_factory(nonlinearity_fatory: Optional[ModuleFactory]) -> ModuleFactory:\r\n    if nonlinearity_fatory is None:\r\n        return ReLUFactory(inplace=False)\r\n    else:\r\n        return nonlinearity_fatory\r\n"
  },
  {
    "path": "src/tha4/nn/normalization.py",
    "content": "from abc import ABC, abstractmethod\r\nfrom typing import Optional\r\n\r\nimport torch\r\nfrom torch import layer_norm\r\nfrom torch.nn import Module, BatchNorm2d, InstanceNorm2d, Parameter\r\nfrom torch.nn.init import normal_, constant_\r\n\r\nfrom tha4.nn.pass_through import PassThrough\r\n\r\n\r\nclass PixelNormalization(Module):\r\n    def __init__(self, epsilon=1e-8):\r\n        super().__init__()\r\n        self.epsilon = epsilon\r\n\r\n    def forward(self, x):\r\n        return x / torch.sqrt((x ** 2).mean(dim=1, keepdim=True) + self.epsilon)\r\n\r\n\r\nclass NormalizationLayerFactory(ABC):\r\n    def __init__(self):\r\n        super().__init__()\r\n\r\n    @abstractmethod\r\n    def create(self, num_features: int, affine: bool = True) -> Module:\r\n        pass\r\n\r\n    @staticmethod\r\n    def resolve_2d(factory: Optional['NormalizationLayerFactory']) -> 'NormalizationLayerFactory':\r\n        if factory is None:\r\n            return InstanceNorm2dFactory()\r\n        else:\r\n            return factory\r\n\r\n\r\nclass Bias2d(Module):\r\n    def __init__(self, num_features: int):\r\n        super().__init__()\r\n        self.num_features = num_features\r\n        self.bias = Parameter(torch.zeros(1, num_features, 1, 1))\r\n\r\n    def forward(self, x):\r\n        return x + self.bias\r\n\r\n\r\nclass NoNorm2dFactory(NormalizationLayerFactory):\r\n    def __init__(self):\r\n        super().__init__()\r\n\r\n    def create(self, num_features: int, affine: bool = True) -> Module:\r\n        if affine:\r\n            return Bias2d(num_features)\r\n        else:\r\n            return PassThrough()\r\n\r\n\r\nclass BatchNorm2dFactory(NormalizationLayerFactory):\r\n    def __init__(self,\r\n                 weight_mean: Optional[float] = None,\r\n                 weight_std: Optional[float] = None,\r\n                 bias: Optional[float] = None):\r\n        super().__init__()\r\n        self.bias = bias\r\n        self.weight_std = weight_std\r\n        self.weight_mean = weight_mean\r\n\r\n    def get_weight_mean(self):\r\n        if self.weight_mean is None:\r\n            return 1.0\r\n        else:\r\n            return self.weight_mean\r\n\r\n    def get_weight_std(self):\r\n        if self.weight_std is None:\r\n            return 0.02\r\n        else:\r\n            return self.weight_std\r\n\r\n    def create(self, num_features: int, affine: bool = True) -> Module:\r\n        module = BatchNorm2d(num_features=num_features, affine=affine)\r\n        if affine:\r\n            if self.weight_mean is not None or self.weight_std is not None:\r\n                normal_(module.weight, self.get_weight_mean(), self.get_weight_std())\r\n            if self.bias is not None:\r\n                constant_(module.bias, self.bias)\r\n        return module\r\n\r\n\r\nclass InstanceNorm2dFactory(NormalizationLayerFactory):\r\n    def __init__(self):\r\n        super().__init__()\r\n\r\n    def create(self, num_features: int, affine: bool = True) -> Module:\r\n        return InstanceNorm2d(num_features=num_features, affine=affine)\r\n\r\n\r\nclass PixelNormFactory(NormalizationLayerFactory):\r\n    def __init__(self):\r\n        super().__init__()\r\n\r\n    def create(self, num_features: int, affine: bool = True) -> Module:\r\n        return PixelNormalization()\r\n\r\n\r\nclass LayerNorm2d(Module):\r\n    def __init__(self, channels: int, affine: bool = True):\r\n        super(LayerNorm2d, self).__init__()\r\n        self.channels = channels\r\n        self.affine = 
affine\r\n\r\n        if self.affine:\r\n            self.weight = Parameter(torch.ones(1, channels, 1, 1))\r\n            self.bias = Parameter(torch.zeros(1, channels, 1, 1))\r\n\r\n    def forward(self, x):\r\n        shape = x.size()[1:]\r\n        y = layer_norm(x, shape)\r\n        # Only apply the learned scale and shift when they exist.\r\n        if self.affine:\r\n            y = y * self.weight + self.bias\r\n        return y\r\n\r\n\r\nclass LayerNorm2dFactory(NormalizationLayerFactory):\r\n    def __init__(self):\r\n        super().__init__()\r\n\r\n    def create(self, num_features: int, affine: bool = True) -> Module:\r\n        return LayerNorm2d(channels=num_features, affine=affine)\r\n"
  },
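Illustrative sketch: `resolve_2d` mirrors `resolve_nonlinearity_factory`, with `None` falling back to instance normalization, and `NoNorm2dFactory` still contributes a learnable per-channel bias when `affine=True`:

```python
import torch
from tha4.nn.normalization import NormalizationLayerFactory, NoNorm2dFactory, LayerNorm2d

default_norm = NormalizationLayerFactory.resolve_2d(None).create(num_features=8)
bias_only = NoNorm2dFactory().create(num_features=8, affine=True)

x = torch.randn(2, 8, 16, 16)
assert default_norm(x).shape == x.shape
assert bias_only(x).shape == x.shape
assert LayerNorm2d(channels=8)(x).shape == x.shape
```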
  {
    "path": "src/tha4/nn/pass_through.py",
    "content": "from torch.nn import Module\r\n\r\n\r\nclass PassThrough(Module):\r\n    def __init__(self):\r\n        super().__init__()\r\n\r\n    def forward(self, x):\r\n        return x"
  },
  {
    "path": "src/tha4/nn/resnet_block.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch.nn import Module, Sequential, Parameter\n\nfrom tha4.shion.core.module_factory import ModuleFactory\nfrom tha4.nn.conv import create_conv1, create_conv3\nfrom tha4.nn.nonlinearity_factory import resolve_nonlinearity_factory\nfrom tha4.nn.normalization import NormalizationLayerFactory\nfrom tha4.nn.util import BlockArgs\n\n\nclass ResnetBlock(Module):\n    @staticmethod\n    def create(num_channels: int,\n               is1x1: bool = False,\n               use_scale_parameters: bool = False,\n               block_args: Optional[BlockArgs] = None):\n        if block_args is None:\n            block_args = BlockArgs()\n        return ResnetBlock(num_channels,\n                           is1x1,\n                           block_args.initialization_method,\n                           block_args.nonlinearity_factory,\n                           block_args.normalization_layer_factory,\n                           block_args.use_spectral_norm,\n                           use_scale_parameters)\n\n    def __init__(self,\n                 num_channels: int,\n                 is1x1: bool = False,\n                 initialization_method: str = 'he',\n                 nonlinearity_factory: ModuleFactory = None,\n                 normalization_layer_factory: Optional[NormalizationLayerFactory] = None,\n                 use_spectral_norm: bool = False,\n                 use_scale_parameter: bool = False):\n        super().__init__()\n        self.use_scale_parameter = use_scale_parameter\n        if self.use_scale_parameter:\n            self.scale = Parameter(torch.zeros(1))\n        nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)\n        if is1x1:\n            self.resnet_path = Sequential(\n                create_conv1(num_channels, num_channels, initialization_method,\n                             bias=True,\n                             use_spectral_norm=use_spectral_norm),\n                nonlinearity_factory.create(),\n                create_conv1(num_channels, num_channels, initialization_method,\n                             bias=True,\n                             use_spectral_norm=use_spectral_norm))\n        else:\n            self.resnet_path = Sequential(\n                create_conv3(num_channels, num_channels,\n                             bias=False, initialization_method=initialization_method,\n                             use_spectral_norm=use_spectral_norm),\n                NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(num_channels, affine=True),\n                nonlinearity_factory.create(),\n                create_conv3(num_channels, num_channels,\n                             bias=False, initialization_method=initialization_method,\n                             use_spectral_norm=use_spectral_norm),\n                NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(num_channels, affine=True))\n\n    def forward(self, x):\n        if self.use_scale_parameter:\n            return x + self.scale * self.resnet_path(x)\n        else:\n            return x + self.resnet_path(x)\n"
  },
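The scale parameter is initialized to zero, so a freshly created block with `use_scale_parameters=True` is exactly the identity function; training then grows the residual path gradually. Illustrative check:

```python
import torch
from tha4.nn.resnet_block import ResnetBlock

block = ResnetBlock.create(num_channels=8, use_scale_parameters=True)
x = torch.randn(1, 8, 16, 16)

# forward returns x + scale * resnet_path(x), and scale starts at 0.
assert torch.allclose(block(x), x)
```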
  {
    "path": "src/tha4/nn/resnet_block_seperable.py",
    "content": "from typing import Optional\r\n\r\nimport torch\r\nfrom torch.nn import Module, Sequential, Parameter\r\n\r\nfrom tha4.shion.core.module_factory import ModuleFactory\r\nfrom tha4.nn.conv import create_conv1\r\nfrom tha4.nn.nonlinearity_factory import resolve_nonlinearity_factory\r\nfrom tha4.nn.normalization import NormalizationLayerFactory\r\nfrom tha4.nn.separable_conv import create_separable_conv3\r\nfrom tha4.nn.util import BlockArgs\r\n\r\n\r\nclass ResnetBlockSeparable(Module):\r\n    @staticmethod\r\n    def create(num_channels: int,\r\n               is1x1: bool = False,\r\n               use_scale_parameters: bool = False,\r\n               block_args: Optional[BlockArgs] = None):\r\n        if block_args is None:\r\n            block_args = BlockArgs()\r\n        return ResnetBlockSeparable(\r\n            num_channels,\r\n            is1x1,\r\n            block_args.initialization_method,\r\n            block_args.nonlinearity_factory,\r\n            block_args.normalization_layer_factory,\r\n            block_args.use_spectral_norm,\r\n            use_scale_parameters)\r\n\r\n    def __init__(self,\r\n                 num_channels: int,\r\n                 is1x1: bool = False,\r\n                 initialization_method: str = 'he',\r\n                 nonlinearity_factory: ModuleFactory = None,\r\n                 normalization_layer_factory: Optional[NormalizationLayerFactory] = None,\r\n                 use_spectral_norm: bool = False,\r\n                 use_scale_parameter: bool = False):\r\n        super().__init__()\r\n        self.use_scale_parameter = use_scale_parameter\r\n        if self.use_scale_parameter:\r\n            self.scale = Parameter(torch.zeros(1))\r\n        nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)\r\n        if is1x1:\r\n            self.resnet_path = Sequential(\r\n                create_conv1(num_channels, num_channels, initialization_method,\r\n                             bias=True,\r\n                             use_spectral_norm=use_spectral_norm),\r\n                nonlinearity_factory.create(),\r\n                create_conv1(num_channels, num_channels, initialization_method,\r\n                             bias=True,\r\n                             use_spectral_norm=use_spectral_norm))\r\n        else:\r\n            self.resnet_path = Sequential(\r\n                create_separable_conv3(\r\n                    num_channels, num_channels,\r\n                    bias=False, initialization_method=initialization_method,\r\n                    use_spectral_norm=use_spectral_norm),\r\n                NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(num_channels, affine=True),\r\n                nonlinearity_factory.create(),\r\n                create_separable_conv3(\r\n                    num_channels, num_channels,\r\n                    bias=False, initialization_method=initialization_method,\r\n                    use_spectral_norm=use_spectral_norm),\r\n                NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(num_channels, affine=True))\r\n\r\n    def forward(self, x):\r\n        if self.use_scale_parameter:\r\n            return x + self.scale * self.resnet_path(x)\r\n        else:\r\n            return x + self.resnet_path(x)\r\n"
  },
  {
    "path": "src/tha4/nn/separable_conv.py",
    "content": "from typing import Optional\r\n\r\nfrom torch.nn import Sequential, Conv2d, ConvTranspose2d, Module\r\n\r\nfrom tha4.nn.normalization import NormalizationLayerFactory\r\nfrom tha4.nn.util import BlockArgs, wrap_conv_or_linear_module\r\n\r\n\r\ndef create_separable_conv3(in_channels: int, out_channels: int,\r\n                           bias: bool = False,\r\n                           initialization_method='he',\r\n                           use_spectral_norm: bool = False) -> Module:\r\n    return Sequential(\r\n        wrap_conv_or_linear_module(\r\n            Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=False, groups=in_channels),\r\n            initialization_method,\r\n            use_spectral_norm),\r\n        wrap_conv_or_linear_module(\r\n            Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),\r\n            initialization_method,\r\n            use_spectral_norm))\r\n\r\n\r\ndef create_separable_conv7(in_channels: int, out_channels: int,\r\n                           bias: bool = False,\r\n                           initialization_method='he',\r\n                           use_spectral_norm: bool = False) -> Module:\r\n    return Sequential(\r\n        wrap_conv_or_linear_module(\r\n            Conv2d(in_channels, in_channels, kernel_size=7, stride=1, padding=3, bias=False, groups=in_channels),\r\n            initialization_method,\r\n            use_spectral_norm),\r\n        wrap_conv_or_linear_module(\r\n            Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),\r\n            initialization_method,\r\n            use_spectral_norm))\r\n\r\n\r\ndef create_separable_conv3_block(\r\n        in_channels: int, out_channels: int, block_args: Optional[BlockArgs] = None):\r\n    if block_args is None:\r\n        block_args = BlockArgs()\r\n    return Sequential(\r\n        wrap_conv_or_linear_module(\r\n            Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=False, groups=in_channels),\r\n            block_args.initialization_method,\r\n            block_args.use_spectral_norm),\r\n        wrap_conv_or_linear_module(\r\n            Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),\r\n            block_args.initialization_method,\r\n            block_args.use_spectral_norm),\r\n        NormalizationLayerFactory.resolve_2d(block_args.normalization_layer_factory).create(out_channels, affine=True),\r\n        block_args.nonlinearity_factory.create())\r\n\r\n\r\ndef create_separable_conv7_block(\r\n        in_channels: int, out_channels: int, block_args: Optional[BlockArgs] = None):\r\n    if block_args is None:\r\n        block_args = BlockArgs()\r\n    return Sequential(\r\n        wrap_conv_or_linear_module(\r\n            Conv2d(in_channels, in_channels, kernel_size=7, stride=1, padding=3, bias=False, groups=in_channels),\r\n            block_args.initialization_method,\r\n            block_args.use_spectral_norm),\r\n        wrap_conv_or_linear_module(\r\n            Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),\r\n            block_args.initialization_method,\r\n            block_args.use_spectral_norm),\r\n        NormalizationLayerFactory.resolve_2d(block_args.normalization_layer_factory).create(out_channels, affine=True),\r\n        block_args.nonlinearity_factory.create())\r\n\r\n\r\ndef create_separable_downsample_block(\r\n        in_channels: int, out_channels: 
int, is_output_1x1: bool, block_args: Optional[BlockArgs] = None):\r\n    if block_args is None:\r\n        block_args = BlockArgs()\r\n    if is_output_1x1:\r\n        return Sequential(\r\n            wrap_conv_or_linear_module(\r\n                Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1, bias=False, groups=in_channels),\r\n                block_args.initialization_method,\r\n                block_args.use_spectral_norm),\r\n            wrap_conv_or_linear_module(\r\n                Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),\r\n                block_args.initialization_method,\r\n                block_args.use_spectral_norm),\r\n            block_args.nonlinearity_factory.create())\r\n    else:\r\n        return Sequential(\r\n            wrap_conv_or_linear_module(\r\n                Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1, bias=False, groups=in_channels),\r\n                block_args.initialization_method,\r\n                block_args.use_spectral_norm),\r\n            wrap_conv_or_linear_module(\r\n                Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),\r\n                block_args.initialization_method,\r\n                block_args.use_spectral_norm),\r\n            NormalizationLayerFactory.resolve_2d(block_args.normalization_layer_factory)\r\n                .create(out_channels, affine=True),\r\n            block_args.nonlinearity_factory.create())\r\n\r\n\r\ndef create_separable_upsample_block(\r\n        in_channels: int, out_channels: int, block_args: Optional[BlockArgs] = None):\r\n    if block_args is None:\r\n        block_args = BlockArgs()\r\n    return Sequential(\r\n        wrap_conv_or_linear_module(\r\n            ConvTranspose2d(\r\n                in_channels, in_channels, kernel_size=4, stride=2, padding=1, bias=False, groups=in_channels),\r\n            block_args.initialization_method,\r\n            block_args.use_spectral_norm),\r\n        wrap_conv_or_linear_module(\r\n            Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),\r\n            block_args.initialization_method,\r\n            block_args.use_spectral_norm),\r\n        NormalizationLayerFactory.resolve_2d(block_args.normalization_layer_factory)\r\n            .create(out_channels, affine=True),\r\n        block_args.nonlinearity_factory.create())\r\n"
  },
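The separable variants trade a full k×k convolution for a depthwise k×k plus a pointwise 1×1. A rough parameter count for 64→64 channels at k=3 (illustrative; it assumes `wrap_conv_or_linear_module` only initializes and optionally applies spectral norm, adding no parameters of its own):

```python
from tha4.nn.conv import create_conv3
from tha4.nn.separable_conv import create_separable_conv3


def num_params(module):
    return sum(p.numel() for p in module.parameters())


separable = create_separable_conv3(64, 64)  # depthwise 3x3 + pointwise 1x1
full = create_conv3(64, 64)                 # full 3x3, bias-free by default

assert num_params(separable) == 64 * 9 + 64 * 64  # 4,672
assert num_params(full) == 64 * 64 * 9            # 36,864
```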
  {
    "path": "src/tha4/nn/siren/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/nn/siren/face_morpher/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/nn/siren/face_morpher/siren_face_morpher_00.py",
    "content": "from typing import Optional, List\r\n\r\nimport torch\r\nfrom torch import Tensor\r\nfrom torch.nn import Module\r\nfrom torch.nn.functional import affine_grid\r\n\r\nfrom tha4.shion.core.module_factory import ModuleFactory\r\nfrom tha4.nn.siren.vanilla.siren import SirenArgs, Siren\r\n\r\n\r\nclass SirenFaceMorpher00Args:\r\n    def __init__(self,\r\n                 image_size: int,\r\n                 image_channels: int,\r\n                 pose_size: int,\r\n                 siren_args: SirenArgs):\r\n        assert siren_args.in_channels == pose_size + 2\r\n        assert siren_args.out_channels == image_channels\r\n        assert not siren_args.use_tanh\r\n\r\n        self.siren_args = siren_args\r\n        self.pose_size = pose_size\r\n        self.image_size = image_size\r\n        self.image_channels = image_channels\r\n\r\n\r\nclass SirenFaceMorpher00(Module):\r\n    def __init__(self, args: SirenFaceMorpher00Args):\r\n        super().__init__()\r\n        self.args = args\r\n        self.siren = Siren(self.args.siren_args)\r\n\r\n    def forward(self, pose: Tensor, position: Optional[Tensor] = None) -> Tensor:\r\n        n, p = pose.shape[0], pose.shape[1]\r\n        device = pose.device\r\n\r\n        if position is None:\r\n            h, w = self.args.image_size, self.args.image_size\r\n            identity = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device).unsqueeze(0)\r\n            position = affine_grid(identity, [1, 1, h, w], align_corners=False) \\\r\n                .view(1, h * w, 2)\r\n            position = torch.transpose(position, dim0=1, dim1=2).view(1, 2, h, w) \\\r\n                .repeat(n, 1, 1, 1)\r\n\r\n        h, w = position.shape[2], position.shape[3]\r\n        pose_image = pose.view(n, p, 1, 1).repeat(1, 1, h, w)\r\n\r\n        siren_input = torch.cat([position, pose_image], dim=1)\r\n\r\n        return self.siren.forward(siren_input)\r\n\r\n\r\nclass SirenFaceMorpher00Factory(ModuleFactory):\r\n    def __init__(self, args: SirenFaceMorpher00Args):\r\n        self.args = args\r\n\r\n    def create(self) -> Module:\r\n        return SirenFaceMorpher00(self.args)\r\n"
  },
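Sketch of a SIREN face morpher forward pass (illustrative; the `SirenArgs` values mirror those in the trainer below, and the output shape assumes the SIREN maps an (n, in_channels, h, w) input to (n, out_channels, h, w)). Each pixel's input is its (x, y) position from an identity `affine_grid` concatenated with the broadcast pose:

```python
import torch
from tha4.nn.siren.face_morpher.siren_face_morpher_00 import (
    SirenFaceMorpher00,
    SirenFaceMorpher00Args,
)
from tha4.nn.siren.vanilla.siren import SirenArgs

morpher = SirenFaceMorpher00(SirenFaceMorpher00Args(
    image_size=128,
    image_channels=4,
    pose_size=39,
    siren_args=SirenArgs(
        in_channels=39 + 2,  # pose parameters plus (x, y)
        out_channels=4,
        intermediate_channels=128,
        num_sine_layers=8))).eval()

pose = torch.zeros(1, 39)
with torch.no_grad():
    image = morpher(pose)  # the position grid is generated internally when omitted
assert image.shape == (1, 4, 128, 128)
```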
  {
    "path": "src/tha4/nn/siren/face_morpher/siren_face_morpher_00_trainer.py",
    "content": "from typing import Dict, List, Optional, Callable\r\n\r\nimport torch\r\nfrom tha4.shion.base.dataset.lazy_tensor_dataset import LazyTensorDataset\r\nfrom tha4.shion.base.image_util import extract_pytorch_image_from_filelike\r\nfrom tha4.shion.base.loss.l1_loss import L1Loss, MaskedL1Loss\r\nfrom tha4.shion.base.loss.sum_loss import SumLoss\r\nfrom tha4.shion.base.optimizer_factories import AdamOptimizerFactory\r\nfrom tha4.shion.core.training.distrib.distributed_trainer import DistributedTrainer\r\nfrom tha4.dataset.image_poses_and_aother_images_dataset import ImagePosesAndOtherImagesDataset\r\nfrom tha4.nn.siren.face_morpher.siren_face_morpher_00 import SirenFaceMorpher00Factory, SirenFaceMorpher00Args\r\nfrom tha4.nn.siren.face_morpher.siren_face_morpher_protocols_00 import SirenFaceMorpherComputationProtocol00, \\\r\n    SirenFaceMorpherSampleOutputProtocol00\r\nfrom tha4.nn.siren.morpher.siren_morpher_protocols_03 import SirenMorpherTrainingProtocol03\r\nfrom tha4.nn.siren.vanilla.siren import SirenArgs\r\nfrom tha4.poser.poser import Poser\r\nfrom torch import Tensor\r\n\r\nKEY_MODULE = \"module\"\r\nKEY_POSER = \"poser\"\r\n\r\n\r\ndef get_poser():\r\n    import tha4.poser.modes.mode_12\r\n    poser = tha4.poser.modes.mode_12.create_poser(torch.device('cpu'))\r\n    return poser\r\n\r\n\r\nclass SirenFaceMorpher00TrainerArgs:\r\n    def __init__(self,\r\n                 character_file_name: str,\r\n                 face_mask_file_name: str,\r\n                 pose_dataset_file_name: str,\r\n                 num_training_total_examples: int = 1_000_000,\r\n                 num_training_examples_per_checkpoint: int = 100_000,\r\n                 num_training_examples_lr_boundaries: Optional[List[int]] = None,\r\n                 num_training_examples_per_sample_output: Optional[int] = 5_000,\r\n                 num_training_examples_per_snapshot: int = 10_000,\r\n                 total_batch_size: int = 8,\r\n                 training_random_seed: int = 2965603729,\r\n                 sample_output_random_seed: int = 3522651501,\r\n                 total_worker: int = 16,\r\n                 poser_func: Optional[Callable[[], Poser]] = None,\r\n                 base_learning_rate: float = 1e-4):\r\n        assert num_training_total_examples % num_training_examples_per_checkpoint == 0\r\n\r\n        if num_training_examples_lr_boundaries is None:\r\n            num_training_examples_lr_boundaries = [\r\n                int(num_training_examples_per_checkpoint * 2),\r\n                int(num_training_examples_per_checkpoint * 5),\r\n                int(num_training_examples_per_checkpoint * 8),\r\n            ]\r\n\r\n        for x in num_training_examples_lr_boundaries:\r\n            assert x % num_training_examples_per_snapshot == 0\r\n\r\n        if poser_func is None:\r\n            poser_func = get_poser\r\n\r\n        self.face_mask_file_name = face_mask_file_name\r\n        self.base_learning_rate = base_learning_rate\r\n        self.poser_func = poser_func\r\n        self.total_worker = total_worker\r\n        self.num_training_examples_per_snapshot = num_training_examples_per_snapshot\r\n        self.num_training_examples_per_sample_output = num_training_examples_per_sample_output\r\n        self.sample_output_random_seed = sample_output_random_seed\r\n        self.training_random_seed = training_random_seed\r\n        self.total_batch_size = total_batch_size\r\n        self.num_training_total_examples = num_training_total_examples\r\n        
self.num_training_examples_per_checkpoint = num_training_examples_per_checkpoint\r\n        self.num_training_examples_lr_boundaries = num_training_examples_lr_boundaries\r\n        self.pose_dataset_file_name = pose_dataset_file_name\r\n        self.character_file_name = character_file_name\r\n\r\n    def get_character_image(self):\r\n        return extract_pytorch_image_from_filelike(\r\n            self.character_file_name,\r\n            scale=2.0,\r\n            offset=-1.0,\r\n            premultiply_alpha=True,\r\n            perform_srgb_to_linear=True)\r\n\r\n    def get_face_mask_image(self):\r\n        loaded_image = extract_pytorch_image_from_filelike(\r\n            self.face_mask_file_name,\r\n            scale=1.0,\r\n            offset=0.0,\r\n            premultiply_alpha=True,\r\n            perform_srgb_to_linear=True)\r\n        output_image = torch.zeros(4, 128, 128)\r\n        center_x = 256\r\n        center_y = 128 + 16\r\n        for i in range(4):\r\n            output_image[i, :, :] = loaded_image[0, center_y - 64:center_y + 64, center_x - 64:center_x + 64]\r\n        return output_image\r\n\r\n    def get_training_dataset(self):\r\n        return ImagePosesAndOtherImagesDataset(\r\n            main_image_func=self.get_character_image,\r\n            other_image_funcs=[self.get_face_mask_image],\r\n            pose_dataset=LazyTensorDataset(self.pose_dataset_file_name))\r\n\r\n    def get_module_factory(self):\r\n        return SirenFaceMorpher00Factory(\r\n            SirenFaceMorpher00Args(\r\n                image_size=128,\r\n                image_channels=4,\r\n                pose_size=39,\r\n                siren_args=SirenArgs(\r\n                    in_channels=39 + 2,\r\n                    out_channels=4,\r\n                    intermediate_channels=128,\r\n                    num_sine_layers=8)))\r\n\r\n    def transform_pose_to_module_input(self, pose: Tensor):\r\n        return pose[:, 0:39]\r\n\r\n    def transform_original_image_to_module_input(self, image: Tensor):\r\n        center_x = 256\r\n        center_y = 128 + 16\r\n        return image[:, :, center_y - 64:center_y + 64, center_x - 64:center_x + 64]\r\n\r\n    def transform_poser_posed_image_to_groundtruth(self, image: Tensor):\r\n        center_x = 96\r\n        center_y = 96 + 16\r\n        return image[:, :, center_y - 64:center_y + 64, center_x - 64:center_x + 64]\r\n\r\n    def get_training_computation_protocol(self):\r\n        return SirenFaceMorpherComputationProtocol00(\r\n            transform_pose_to_module_input_func=self.transform_pose_to_module_input,\r\n            transform_original_image_to_module_input_func=self.transform_original_image_to_module_input,\r\n            transform_poser_posed_image_to_groundtruth_func=self.transform_poser_posed_image_to_groundtruth)\r\n\r\n    def get_learning_rate(self, examples_seen_so_far) -> Dict[str, float]:\r\n        if examples_seen_so_far < self.num_training_examples_lr_boundaries[0]:\r\n            return {\r\n                KEY_MODULE: self.base_learning_rate,\r\n            }\r\n        elif examples_seen_so_far < self.num_training_examples_lr_boundaries[1]:\r\n            return {\r\n                KEY_MODULE: self.base_learning_rate / 3.0,\r\n            }\r\n        elif examples_seen_so_far < self.num_training_examples_lr_boundaries[2]:\r\n            return {\r\n                KEY_MODULE: self.base_learning_rate / 10.0,\r\n            }\r\n        else:\r\n            return {\r\n                KEY_MODULE: 
self.base_learning_rate / 30.0,\r\n            }\r\n\r\n    def get_optimizer_factories(self):\r\n        return {\r\n            KEY_MODULE: AdamOptimizerFactory(betas=(0.9, 0.999)),\r\n        }\r\n\r\n    def get_poser(self):\r\n        return self.poser_func()\r\n\r\n    def get_training_protocol(self, world_size: int):\r\n        total_examples = self.num_training_total_examples\r\n        per_checkpoint_examples = self.num_training_examples_per_checkpoint\r\n        num_checkpoints = total_examples // per_checkpoint_examples\r\n        batch_size = self.total_batch_size // world_size\r\n        return SirenMorpherTrainingProtocol03(\r\n            check_point_examples=[per_checkpoint_examples * (i + 1) for i in range(num_checkpoints)],\r\n            batch_size=batch_size,\r\n            learning_rate=self.get_learning_rate,\r\n            optimizer_factories=self.get_optimizer_factories(),\r\n            random_seed=self.training_random_seed,\r\n            poser_func=self.get_poser,\r\n            key_module=KEY_MODULE,\r\n            key_poser=KEY_POSER)\r\n\r\n    def get_sample_output_protocol(self):\r\n        return SirenFaceMorpherSampleOutputProtocol00(\r\n            num_images=8,\r\n            image_size=128,\r\n            images_per_row=2,\r\n            examples_per_sample_output=self.num_training_examples_per_sample_output,\r\n            computation_protocol=self.get_training_computation_protocol(),\r\n            poser_func=self.get_poser,\r\n            random_seed=self.sample_output_random_seed)\r\n\r\n    def get_loss(self):\r\n        protocol = self.get_training_computation_protocol()\r\n        return SumLoss([\r\n            (\r\n                'full',\r\n                L1Loss(\r\n                    expected_func=protocol.get_output_func(protocol.keys.groundtruth_posed_image),\r\n                    actual_func=protocol.get_output_func(protocol.keys.predicted_posed_image),\r\n                    weight=1.0)\r\n            ),\r\n            (\r\n                'eye_mouth',\r\n                MaskedL1Loss(\r\n                    expected_func=protocol.get_output_func(protocol.keys.groundtruth_posed_image),\r\n                    actual_func=protocol.get_output_func(protocol.keys.predicted_posed_image),\r\n                    mask_func=protocol.get_output_func(protocol.keys.eye_mouth_mask),\r\n                    weight=20.0)\r\n            ),\r\n        ])\r\n\r\n    def create_trainer(self, prefix: str, world_size: int, distrib_backend: str = 'gloo'):\r\n        if self.num_training_examples_per_sample_output is not None:\r\n            sample_output_protocol = self.get_sample_output_protocol()\r\n        else:\r\n            sample_output_protocol = None\r\n\r\n        return DistributedTrainer(\r\n            prefix=prefix,\r\n            module_factories={\r\n                KEY_MODULE: self.get_module_factory(),\r\n            },\r\n            accumulators={},\r\n            losses={\r\n                KEY_MODULE: self.get_loss(),\r\n            },\r\n            training_dataset=self.get_training_dataset(),\r\n            validation_dataset=self.get_training_dataset(),\r\n            training_protocol=self.get_training_protocol(world_size),\r\n            validation_protocol=None,\r\n            sample_output_protocol=sample_output_protocol,\r\n            pretrained_module_file_names={},\r\n            example_per_snapshot=self.num_training_examples_per_snapshot,\r\n            num_data_loader_workers=max(1, self.total_worker // world_size),\r\n     
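       # data-loader workers are split across ranks; the 'gloo' default below also works for CPU-only training\r\n     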
       distrib_backend=distrib_backend)\r\n"
  },
  {
    "path": "src/tha4/nn/siren/face_morpher/siren_face_morpher_protocols_00.py",
    "content": "import os\r\nfrom dataclasses import dataclass\r\nfrom typing import Dict, Any, Optional, Callable\r\n\r\nimport PIL.Image\r\nimport numpy\r\nimport torch\r\nfrom tha4.shion.base.dataset.util import get_indexed_batch\r\nfrom tha4.shion.base.image_util import pytorch_rgba_to_numpy_image\r\nfrom tha4.shion.core.cached_computation import CachedComputationProtocol, ComputationState, \\\r\n    ComposableCachedComputationProtocol, batch_indexing_func, add_step\r\nfrom tha4.shion.core.training.sample_output_protocol import SampleOutputProtocol\r\nfrom tha4.poser.general_poser_02 import GeneralPoser02\r\nfrom torch import Tensor\r\nfrom torch.nn import Module\r\nfrom torch.utils.data import Dataset\r\n\r\nKEY_MODULE = \"module\"\r\nKEY_POSER = \"poser\"\r\n\r\n\r\n@dataclass\r\nclass SirenMorpherProtocol00Keys:\r\n    module: str = KEY_MODULE\r\n    module_output: str = \"module_output\"\r\n\r\n    poser: str = KEY_POSER\r\n    poser_output: str = \"poser_output\"\r\n\r\n    original_image: str = \"original_image\"\r\n    original_pose: str = \"original_pose\"\r\n\r\n    module_input_image: str = \"module_input_image\"\r\n    module_input_pose: str = \"module_input_pose\"\r\n\r\n    groundtruth_posed_image: str = 'groundtruth_posed_image'\r\n    predicted_posed_image: str = 'predicted_posed_image'\r\n\r\n    eye_mouth_mask: str = 'eye_mouth_mask'\r\n\r\n\r\n@dataclass\r\nclass SirenMorpherProtocol00Indices:\r\n    batch_original_image: int = 0\r\n    batch_pose: int = 1\r\n    batch_eye_mouth_mask: int = 2\r\n    poser_posed_image: int = 0\r\n\r\n\r\nclass SirenFaceMorpherComputationProtocol00(ComposableCachedComputationProtocol):\r\n    def __init__(self,\r\n                 transform_pose_to_module_input_func: Callable[[Tensor], Tensor],\r\n                 transform_original_image_to_module_input_func: Callable[[Tensor], Tensor],\r\n                 transform_poser_posed_image_to_groundtruth_func: Callable[[Tensor], Tensor],\r\n                 keys: Optional[SirenMorpherProtocol00Keys] = None,\r\n                 indices: Optional[SirenMorpherProtocol00Indices] = None):\r\n        super().__init__()\r\n\r\n        if keys is None:\r\n            keys = SirenMorpherProtocol00Keys()\r\n        if indices is None:\r\n            indices = SirenMorpherProtocol00Indices()\r\n\r\n        self.keys = keys\r\n        self.indices = indices\r\n        self.transform_image_to_module_input_func = transform_original_image_to_module_input_func\r\n        self.transform_pose_to_module_input_func = transform_pose_to_module_input_func\r\n        self.transform_poser_posed_image_to_groundtruth_func = transform_poser_posed_image_to_groundtruth_func\r\n\r\n        self.computation_steps[keys.original_image] = batch_indexing_func(indices.batch_original_image)\r\n        self.computation_steps[keys.original_pose] = batch_indexing_func(indices.batch_pose)\r\n\r\n        @add_step(self.computation_steps, keys.module_input_pose)\r\n        def get_module_input_pose(protocol: CachedComputationProtocol, state: ComputationState):\r\n            original_pose = protocol.get_output(keys.original_pose, state)\r\n            return transform_pose_to_module_input_func(original_pose)\r\n\r\n        @add_step(self.computation_steps, keys.module_input_image)\r\n        def get_module_input_image(protocol: CachedComputationProtocol, state: ComputationState):\r\n            original_image = protocol.get_output(keys.original_image, state)\r\n            return 
transform_original_image_to_module_input_func(original_image)\r\n\r\n        @add_step(self.computation_steps, keys.poser_output)\r\n        def get_poser_output(protocol: CachedComputationProtocol, state: ComputationState):\r\n            with torch.no_grad():\r\n                poser = state.modules[keys.poser]\r\n                pose = protocol.get_output(keys.original_pose, state)\r\n                image = protocol.get_output(keys.original_image, state)\r\n                return poser.get_posing_outputs(image, pose)\r\n\r\n        @add_step(self.computation_steps, keys.groundtruth_posed_image)\r\n        def get_groundtruth_posed_image(protocol: CachedComputationProtocol, state: ComputationState):\r\n            poser_output = protocol.get_output(keys.poser_output, state)\r\n            poser_posed_image = poser_output[indices.poser_posed_image]\r\n            return transform_poser_posed_image_to_groundtruth_func(poser_posed_image)\r\n\r\n        @add_step(self.computation_steps, keys.module_output)\r\n        def get_module_output(protocol: CachedComputationProtocol, state: ComputationState):\r\n            module_input_pose = protocol.get_output(keys.module_input_pose, state)\r\n            module = state.modules[keys.module]\r\n            return module.forward(module_input_pose)\r\n\r\n        @add_step(self.computation_steps, keys.predicted_posed_image)\r\n        def get_predicted_image(protocol: CachedComputationProtocol, state: ComputationState):\r\n            return protocol.get_output(keys.module_output, state)\r\n\r\n        self.computation_steps[keys.eye_mouth_mask] = batch_indexing_func(indices.batch_eye_mouth_mask)\r\n\r\n\r\nclass SirenFaceMorpherSampleOutputProtocol00(SampleOutputProtocol):\r\n    def __init__(self,\r\n                 num_images: int,\r\n                 image_size: int,\r\n                 images_per_row: int,\r\n                 examples_per_sample_output: int,\r\n                 computation_protocol: SirenFaceMorpherComputationProtocol00,\r\n                 poser_func: Callable[[], GeneralPoser02],\r\n                 random_seed: int = 54859395058,\r\n                 batch_size: Optional[int] = None):\r\n        if batch_size is None:\r\n            batch_size = num_images\r\n\r\n        self.batch_size = batch_size\r\n        self.poser_func = poser_func\r\n        self.random_seed = random_seed\r\n        self.examples_per_sample_output = examples_per_sample_output\r\n        self.images_per_row = images_per_row\r\n        self.image_size = image_size\r\n        self.num_images = num_images\r\n        self.computation_protocol = computation_protocol\r\n\r\n    def get_examples_per_sample_output(self) -> int:\r\n        return self.examples_per_sample_output\r\n\r\n    def get_random_seed(self) -> int:\r\n        return self.random_seed\r\n\r\n    def get_sample_output_data(self, validation_dataset: Dataset, device: torch.device) -> dict:\r\n        example_indices = torch.randint(0, len(validation_dataset), (self.num_images,))\r\n        example_indices = [example_indices[i].item() for i in range(self.num_images)]\r\n        batch = get_indexed_batch(validation_dataset, example_indices, device)\r\n        poser = self.poser_func()\r\n        poser.to(device)\r\n        with torch.no_grad():\r\n            ground_truth = poser.pose(\r\n                batch[self.computation_protocol.indices.batch_original_image],\r\n                batch[self.computation_protocol.indices.batch_pose])\r\n        return {\r\n            'batch': batch,\r\n    
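        # ground truth comes from a single teacher-poser call here; it is cached and reused at every sample-output save\r\n    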
        'ground_truth': ground_truth\r\n        }\r\n\r\n    def save_sample_output_data(self,\r\n                                modules: Dict[str, Module],\r\n                                accumulated_modules: Dict[str, Module],\r\n                                sample_output_data: Any,\r\n                                prefix: str,\r\n                                examples_seen_so_far: int,\r\n                                device: torch.device):\r\n        batch = sample_output_data['batch']\r\n        ground_truth = sample_output_data['ground_truth']\r\n        ground_truth = self.computation_protocol.transform_poser_posed_image_to_groundtruth_func(ground_truth)\r\n\r\n        module = modules[self.computation_protocol.keys.module]\r\n        module.train(False)\r\n\r\n        if self.batch_size == self.num_images:\r\n            with torch.no_grad():\r\n                state = ComputationState(\r\n                    modules=modules,\r\n                    accumulated_modules=accumulated_modules,\r\n                    batch=batch,\r\n                    outputs={})\r\n                poser_output_images = self.computation_protocol.get_output(\r\n                    self.computation_protocol.keys.predicted_posed_image, state)\r\n        else:\r\n            poser_output_images_list = []\r\n            start = 0\r\n            while start < self.num_images:\r\n                end = start + self.batch_size\r\n                end = min(self.num_images, end)\r\n                minibatch = [batch[i][start:end] for i in range(len(batch))]\r\n                state = ComputationState(\r\n                    modules=modules,\r\n                    accumulated_modules=accumulated_modules,\r\n                    batch=minibatch,\r\n                    outputs={})\r\n                with torch.no_grad():\r\n                    poser_output_images = self.computation_protocol.get_output(\r\n                        self.computation_protocol.keys.predicted_posed_image, state)\r\n                poser_output_images_list.append(poser_output_images)\r\n                start = end\r\n            poser_output_images = torch.cat(poser_output_images_list, dim=0)\r\n\r\n        num_rows = self.num_images // self.images_per_row\r\n        if self.num_images % self.images_per_row > 0:\r\n            num_rows += 1\r\n        num_cols = 2 * self.images_per_row\r\n\r\n        image_channels = 4\r\n        output_image = numpy.zeros([self.image_size * num_rows, self.image_size * num_cols, image_channels])\r\n\r\n        for image_index in range(self.num_images):\r\n            row = image_index // self.images_per_row\r\n            start_row = row * self.image_size\r\n\r\n            col = 2 * (image_index % self.images_per_row)\r\n            start_col = col * self.image_size\r\n            output_image[start_row:start_row + self.image_size, start_col:start_col + self.image_size, :] \\\r\n                = pytorch_rgba_to_numpy_image(ground_truth[image_index].detach().cpu())\r\n\r\n            start_col += self.image_size\r\n            output_image[start_row:start_row + self.image_size, start_col:start_col + self.image_size, :] \\\r\n                = pytorch_rgba_to_numpy_image(poser_output_images[image_index].detach().cpu())\r\n\r\n        file_name = \"%s/sample_output_%010d.png\" % (prefix, examples_seen_so_far)\r\n        os.makedirs(os.path.dirname(file_name), exist_ok=True)\r\n        pil_image = PIL.Image.fromarray(numpy.uint8(numpy.rint(output_image * 255.0)), mode='RGBA')\r\n        
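# tiles ground-truth and predicted images side by side; saved as an RGBA PNG named by examples seen so far\r\n        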
pil_image.save(file_name)\r\n        print(\"Saved %s\" % file_name)\r\n"
  },
  {
    "path": "src/tha4/nn/siren/morpher/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/nn/siren/morpher/siren_morpher_03.py",
    "content": "from typing import List, Optional, Callable\r\n\r\nimport torch\r\nfrom torch import Tensor\r\nfrom torch.nn import Module, ModuleList, Sequential, Conv2d\r\nfrom torch.nn.functional import affine_grid, interpolate\r\n\r\nfrom tha4.shion.core.module_factory import ModuleFactory\r\nfrom tha4.shion.nn00.initialization_funcs import HeInitialization\r\nfrom tha4.nn.image_processing_util import GridChangeApplier\r\nfrom tha4.nn.siren.vanilla.siren import SineLinearLayer\r\n\r\n\r\nclass SirenMorpherLevelArgs:\r\n    def __init__(self,\r\n                 image_size: int,\r\n                 intermediate_channels: int,\r\n                 num_sine_layers: int):\r\n        assert num_sine_layers >= 2\r\n        self.image_size = image_size\r\n        self.num_sine_layers = num_sine_layers\r\n        self.intermediate_channels = intermediate_channels\r\n\r\n\r\nclass SirenMorpher03Args:\r\n    def __init__(self,\r\n                 image_size: int,\r\n                 image_channels: int,\r\n                 pose_size: int,\r\n                 level_args: List[SirenMorpherLevelArgs],\r\n                 init_func: Optional[Callable[[Module], Module]] = None):\r\n        assert len(level_args) >= 2\r\n        if init_func is None:\r\n            init_func = HeInitialization()\r\n        self.image_size = image_size\r\n        self.init_func = init_func\r\n        self.level_args = level_args\r\n        self.pose_size = pose_size\r\n        self.image_channels = image_channels\r\n\r\n\r\nclass SirenMorpher03(Module):\r\n    def __init__(self, args: SirenMorpher03Args):\r\n        super().__init__()\r\n        self.args = args\r\n\r\n        self.siren_layers = ModuleList()\r\n\r\n        for i in range(len(args.level_args)):\r\n            level_args = args.level_args[i]\r\n\r\n            layers = []\r\n\r\n            if i == 0:\r\n                layers.append(SineLinearLayer(\r\n                    in_channels=args.pose_size + 2,\r\n                    out_channels=level_args.intermediate_channels,\r\n                    is_first=True))\r\n            else:\r\n                layers.append(SineLinearLayer(\r\n                    in_channels=level_args.intermediate_channels + args.pose_size + 2,\r\n                    out_channels=level_args.intermediate_channels,\r\n                    is_first=False))\r\n\r\n            for j in range(1, level_args.num_sine_layers - 1):\r\n                layers.append(SineLinearLayer(\r\n                    in_channels=level_args.intermediate_channels,\r\n                    out_channels=level_args.intermediate_channels,\r\n                    is_first=False))\r\n\r\n            if i == len(args.level_args) - 1:\r\n                out_channels = level_args.intermediate_channels\r\n            else:\r\n                out_channels = args.level_args[i + 1].intermediate_channels\r\n            layers.append(SineLinearLayer(\r\n                in_channels=level_args.intermediate_channels,\r\n                out_channels=out_channels,\r\n                is_first=False))\r\n\r\n            self.siren_layers.append(Sequential(*layers))\r\n\r\n        self.last_linear = args.init_func(Conv2d(\r\n            args.level_args[-1].intermediate_channels,\r\n            args.image_channels + 2 + 1,\r\n            kernel_size=1,\r\n            stride=1,\r\n            padding=0,\r\n            bias=True))\r\n\r\n        self.grid_change_applier = GridChangeApplier()\r\n\r\n    def get_position_grid(self, n: int, image_size: int, device: torch.device):\r\n     
   h, w = image_size, image_size\r\n        identity = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device).unsqueeze(0)\r\n        position = affine_grid(identity, [1, 1, h, w], align_corners=False) \\\r\n            .view(1, h * w, 2)\r\n        position = torch.transpose(position, dim0=1, dim1=2).view(1, 2, h, w) \\\r\n            .repeat(n, 1, 1, 1)\r\n        return position\r\n\r\n    def get_pose_image(self, pose: Tensor, image_size: int):\r\n        n, p = pose.shape[0], pose.shape[1]\r\n        h, w = image_size, image_size\r\n        pose_image = pose.view(n, p, 1, 1).repeat(1, 1, h, w)\r\n        return pose_image\r\n\r\n    def forward(self, image: Tensor, pose: Tensor) -> List[Tensor]:\r\n        n = pose.shape[0]\r\n        device = pose.device\r\n\r\n        x = None\r\n        for i in range(len(self.args.level_args)):\r\n            args = self.args.level_args[i]\r\n            position_and_pose = torch.cat([\r\n                self.get_position_grid(n, args.image_size, device),\r\n                self.get_pose_image(pose, args.image_size)\r\n            ], dim=1)\r\n            if i == 0:\r\n                x = self.siren_layers[i].forward(position_and_pose)\r\n            else:\r\n                x = interpolate(x, size=(args.image_size, args.image_size), mode='bilinear')\r\n                x = torch.cat([x, position_and_pose], dim=1)\r\n                x = self.siren_layers[i].forward(x)\r\n\r\n        siren_output = self.last_linear(x)\r\n\r\n        grid_change = siren_output[:, 0:2, :, :]\r\n        alpha = siren_output[:, 2:3, :, :]\r\n        color_change = siren_output[:, 3:, :, :]\r\n        warped_image = self.grid_change_applier.apply(grid_change, image, align_corners=False)\r\n        blended_image = (1 - alpha) * warped_image + alpha * color_change\r\n\r\n        return [\r\n            blended_image,\r\n            alpha,\r\n            color_change,\r\n            warped_image,\r\n            grid_change\r\n        ]\r\n\r\n    INDEX_BLENDED_IMAGE = 0\r\n    INDEX_ALPHA = 1\r\n    INDEX_COLOR_CHANGE = 2\r\n    INDEX_WARPED_IMAGE = 3\r\n    INDEX_GRID_CHANGE = 4\r\n\r\n\r\nclass SirenMorpher03Factory(ModuleFactory):\r\n    def __init__(self, args: SirenMorpher03Args):\r\n        self.args = args\r\n\r\n    def create(self):\r\n        return SirenMorpher03(self.args)\r\n"
  },
  {
    "path": "src/tha4/nn/siren/morpher/siren_morpher_03_trainer.py",
    "content": "from enum import Enum\r\nfrom typing import Dict, List, Optional, Callable\r\n\r\nimport torch\r\nfrom tha4.shion.base.dataset.lazy_tensor_dataset import LazyTensorDataset\r\nfrom tha4.shion.base.image_util import extract_pytorch_image_from_filelike\r\nfrom tha4.shion.base.loss.l1_loss import L1Loss\r\nfrom tha4.shion.base.loss.sum_loss import SumLoss\r\nfrom tha4.shion.base.loss.time_dependently_weighted_loss import TimeDependentlyWeightedLoss\r\nfrom tha4.shion.base.optimizer_factories import AdamOptimizerFactory\r\nfrom tha4.shion.core.training.distrib.distributed_trainer import DistributedTrainer\r\nfrom tha4.dataset.image_poses_and_aother_images_dataset import ImagePosesAndOtherImagesDataset\r\nfrom tha4.nn.siren.morpher.siren_morpher_03 import SirenMorpherLevelArgs, SirenMorpher03Factory, SirenMorpher03Args\r\nfrom tha4.nn.siren.morpher.siren_morpher_protocols_03 import SirenMorpherComputationProtocol03, \\\r\n    SirenMorpherProtocol03Indices, KEY_MODULE, KEY_POSER, KEY_EXAMPLES_SEEN_SO_FAR, SirenMorpherTrainingProtocol03, \\\r\n    SirenMorpherSampleOutputProtocol\r\nfrom tha4.poser.poser import Poser\r\n\r\n\r\ndef get_poser():\r\n    import tha4.poser.modes.mode_07\r\n    poser = tha4.poser.modes.mode_07.create_poser(torch.device('cpu'))\r\n    return poser\r\n\r\n\r\nclass LossTerm(Enum):\r\n    full_blended = 1\r\n    full_warped = 2\r\n    full_grid_change = 3\r\n    full_color_change = 4\r\n\r\n    def get_loss(self, protocol: SirenMorpherComputationProtocol03):\r\n        if self == LossTerm.full_blended:\r\n            return L1Loss(\r\n                expected_func=protocol.get_output_func(protocol.keys.groundtruth_posed_image),\r\n                actual_func=protocol.get_output_func(protocol.keys.predicted_posed_image))\r\n        elif self == LossTerm.full_warped:\r\n            return L1Loss(\r\n                expected_func=protocol.get_output_func(protocol.keys.groundtruth_warped_image),\r\n                actual_func=protocol.get_output_func(protocol.keys.predicted_warped_image))\r\n        elif self == LossTerm.full_grid_change:\r\n            return L1Loss(\r\n                expected_func=protocol.get_output_func(protocol.keys.groundtruth_grid_change),\r\n                actual_func=protocol.get_output_func(protocol.keys.predicted_grid_change))\r\n        elif self == LossTerm.full_color_change:\r\n            return L1Loss(\r\n                expected_func=protocol.get_output_func(protocol.keys.groundtruth_posed_image),\r\n                actual_func=protocol.get_output_func(protocol.keys.predicted_color_change))\r\n        else:\r\n            raise RuntimeError(f\"Unsupported loss term {self}\")\r\n\r\n\r\nclass LossWeights:\r\n    def __init__(self, weights: Optional[Dict[LossTerm, float]] = None):\r\n        self.weights = {}\r\n        for term in LossTerm:\r\n            self.weights[term] = 0.0\r\n        if weights is not None:\r\n            for term in LossTerm:\r\n                if term in weights:\r\n                    self.weights[term] = weights[term]\r\n\r\n\r\nclass TrainingPhase:\r\n    def __init__(self,\r\n                 num_examples_upper_bound: int,\r\n                 learning_rate: float,\r\n                 loss_weights: LossWeights):\r\n        self.loss_weights = loss_weights\r\n        self.learning_rate = learning_rate\r\n        self.num_examples_upper_bound = num_examples_upper_bound\r\n\r\n\r\nclass LearningRateFunc:\r\n    def __init__(self, phases: List[TrainingPhase], keys: List[str]):\r\n        self.phases = 
phases\r\n        self.keys = keys\r\n\r\n    def make_learning_rate_dict(self, keys: List[str], value: float):\r\n        output = {}\r\n        for key in keys:\r\n            output[key] = value\r\n        return output\r\n\r\n    def __call__(self, examples_seen_so_far: int) -> Dict[str, float]:\r\n        for i in range(len(self.phases) - 1):\r\n            if examples_seen_so_far < self.phases[i].num_examples_upper_bound:\r\n                return self.make_learning_rate_dict(self.keys, self.phases[i].learning_rate)\r\n        return self.make_learning_rate_dict(self.keys, self.phases[-1].learning_rate)\r\n\r\n\r\nclass LossWeightFunc:\r\n    def __init__(self, phases: List[TrainingPhase], term: LossTerm):\r\n        self.term = term\r\n        self.phases = phases\r\n\r\n    def __call__(self, examples_seen_so_far: int) -> float:\r\n        for i in range(len(self.phases) - 1):\r\n            if examples_seen_so_far < self.phases[i].num_examples_upper_bound:\r\n                return self.phases[i].loss_weights.weights[self.term]\r\n        return self.phases[-1].loss_weights.weights[self.term]\r\n\r\n\r\nclass TrainingPhases:\r\n    def __init__(self, phases: List[TrainingPhase]):\r\n        assert len(phases) > 0\r\n        for i in range(1, len(phases)):\r\n            assert phases[i - 1].num_examples_upper_bound < phases[i].num_examples_upper_bound\r\n\r\n        self.phases = phases\r\n\r\n    def make_learning_rate_dict(self, keys: List[str], value: float):\r\n        output = {}\r\n        for key in keys:\r\n            output[key] = value\r\n        return output\r\n\r\n    def get_learning_rate_func(self, keys: List[str]):\r\n        return LearningRateFunc(self.phases, keys)\r\n\r\n    def get_loss_weight_func(self, term: LossTerm) -> Callable[[int], float]:\r\n        return LossWeightFunc(self.phases, term)\r\n\r\n\r\nclass SirenMorpher03TrainerArgs:\r\n    def __init__(self,\r\n                 character_file_name: str,\r\n                 pose_dataset_file_name: str,\r\n                 training_phases: TrainingPhases,\r\n                 num_training_examples_per_checkpoint: int = 100_000,\r\n                 num_training_examples_per_sample_output: Optional[int] = 10_000,\r\n                 num_training_examples_per_snapshot: int = 10_000,\r\n                 total_batch_size: int = 8,\r\n                 training_random_seed: int = 2965603729,\r\n                 sample_output_random_seed: int = 3522651501,\r\n                 total_worker: int = 8,\r\n                 poser_func: Optional[Callable[[], Poser]] = None,\r\n                 sample_output_batch_size: Optional[int] = None,\r\n                 pretrained_module_file_name: Optional[str] = None):\r\n        for phase in training_phases.phases:\r\n            assert phase.num_examples_upper_bound % num_training_examples_per_checkpoint == 0\r\n\r\n        if poser_func is None:\r\n            poser_func = get_poser\r\n\r\n        self.training_phases = training_phases\r\n        self.pretrained_module_file_name = pretrained_module_file_name\r\n        self.sample_output_batch_size = sample_output_batch_size\r\n        self.poser_func = poser_func\r\n        self.total_worker = total_worker\r\n        self.num_training_examples_per_snapshot = num_training_examples_per_snapshot\r\n        self.num_training_examples_per_sample_output = num_training_examples_per_sample_output\r\n        self.sample_output_random_seed = sample_output_random_seed\r\n        self.training_random_seed = training_random_seed\r\n    
    self.total_batch_size = total_batch_size\r\n        self.num_training_examples_per_checkpoint = num_training_examples_per_checkpoint\r\n        self.pose_dataset_file_name = pose_dataset_file_name\r\n        self.character_file_name = character_file_name\r\n\r\n    def get_character_image(self):\r\n        return extract_pytorch_image_from_filelike(\r\n            self.character_file_name,\r\n            scale=2.0,\r\n            offset=-1.0,\r\n            premultiply_alpha=True,\r\n            perform_srgb_to_linear=True)\r\n\r\n    def get_training_dataset(self):\r\n        return ImagePosesAndOtherImagesDataset(\r\n            main_image_func=self.get_character_image,\r\n            pose_dataset=LazyTensorDataset(self.pose_dataset_file_name),\r\n            other_image_funcs=[])\r\n\r\n    def get_module_factory(self):\r\n        return SirenMorpher03Factory(\r\n            SirenMorpher03Args(\r\n                image_size=512,\r\n                image_channels=4,\r\n                pose_size=45,\r\n                level_args=[\r\n                    SirenMorpherLevelArgs(\r\n                        image_size=128,\r\n                        intermediate_channels=360,\r\n                        num_sine_layers=3),\r\n                    SirenMorpherLevelArgs(\r\n                        image_size=256,\r\n                        intermediate_channels=180,\r\n                        num_sine_layers=3),\r\n                    SirenMorpherLevelArgs(\r\n                        image_size=512,\r\n                        intermediate_channels=90,\r\n                        num_sine_layers=3),\r\n                ]))\r\n\r\n    def get_training_computation_protocol(self):\r\n        return SirenMorpherComputationProtocol03(\r\n            indices=SirenMorpherProtocol03Indices(\r\n                batch_image=0,\r\n                batch_pose=1,\r\n                batch_face_mask=2))\r\n\r\n    def get_optimizer_factories(self):\r\n        return {\r\n            KEY_MODULE: AdamOptimizerFactory(betas=(0.9, 0.999)),\r\n        }\r\n\r\n    def get_poser(self):\r\n        return self.poser_func()\r\n\r\n    def get_training_protocol(self, world_size: int):\r\n        total_examples = self.training_phases.phases[-1].num_examples_upper_bound\r\n        per_checkpoint_examples = self.num_training_examples_per_checkpoint\r\n        num_checkpoints = total_examples // per_checkpoint_examples\r\n        batch_size = self.total_batch_size // world_size\r\n        return SirenMorpherTrainingProtocol03(\r\n            check_point_examples=[per_checkpoint_examples * (i + 1) for i in range(num_checkpoints)],\r\n            batch_size=batch_size,\r\n            learning_rate=self.training_phases.get_learning_rate_func([KEY_MODULE]),\r\n            optimizer_factories=self.get_optimizer_factories(),\r\n            random_seed=self.training_random_seed,\r\n            poser_func=self.get_poser,\r\n            key_module=KEY_MODULE,\r\n            key_poser=KEY_POSER)\r\n\r\n    def get_sample_output_protocol(self):\r\n        return SirenMorpherSampleOutputProtocol(\r\n            num_images=4,\r\n            image_size=512,\r\n            examples_per_sample_output=self.num_training_examples_per_sample_output,\r\n            computation_protocol=self.get_training_computation_protocol(),\r\n            poser_func=self.get_poser,\r\n            random_seed=self.sample_output_random_seed,\r\n            batch_size=self.sample_output_batch_size,\r\n            batch_pose_index=1,\r\n            
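# batch layout from ImagePosesAndOtherImagesDataset: index 0 is the character image, index 1 the pose\r\n            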
batch_image_index=0)\r\n\r\n    def get_loss(self):\r\n        protocol = self.get_training_computation_protocol()\r\n        losses = []\r\n        for term in LossTerm:\r\n            base_loss = term.get_loss(protocol)\r\n            loss = TimeDependentlyWeightedLoss(\r\n                base_loss,\r\n                examples_seen_so_far_func=lambda state: state.outputs[KEY_EXAMPLES_SEEN_SO_FAR],\r\n                weight_func=self.training_phases.get_loss_weight_func(term))\r\n            losses.append((term.name, loss))\r\n        return SumLoss(losses)\r\n\r\n    def create_trainer(self, prefix: str, world_size: int, distrib_backend: str = 'gloo'):\r\n        if self.num_training_examples_per_sample_output is not None:\r\n            sample_output_protocol = self.get_sample_output_protocol()\r\n        else:\r\n            sample_output_protocol = None\r\n\r\n        pretrained_module_file_names = {}\r\n        if self.pretrained_module_file_name is not None:\r\n            pretrained_module_file_names[KEY_MODULE] = self.pretrained_module_file_name\r\n\r\n        return DistributedTrainer(\r\n            prefix=prefix,\r\n            module_factories={\r\n                KEY_MODULE: self.get_module_factory(),\r\n            },\r\n            accumulators={},\r\n            losses={\r\n                KEY_MODULE: self.get_loss(),\r\n            },\r\n            training_dataset=self.get_training_dataset(),\r\n            validation_dataset=self.get_training_dataset(),\r\n            training_protocol=self.get_training_protocol(world_size),\r\n            validation_protocol=None,\r\n            sample_output_protocol=sample_output_protocol,\r\n            pretrained_module_file_names=pretrained_module_file_names,\r\n            example_per_snapshot=self.num_training_examples_per_snapshot,\r\n            num_data_loader_workers=max(1, self.total_worker // world_size),\r\n            distrib_backend=distrib_backend)\r\n"
  },
  {
    "path": "src/tha4/nn/siren/morpher/siren_morpher_protocols_03.py",
    "content": "from dataclasses import dataclass\r\nfrom typing import Optional, List, Callable, Dict, Any\r\n\r\nimport torch\r\nfrom tha4.shion.base.dataset.util import get_indexed_batch\r\n\r\nfrom tha4.shion.core.cached_computation import output_array_indexing_func, add_step, ComputationState, \\\r\n    CachedComputationProtocol, ComposableCachedComputationProtocol, batch_indexing_func, proxy_func\r\nfrom tha4.shion.core.loss import Loss\r\nfrom tha4.shion.core.optimizer_factory import OptimizerFactory\r\nfrom tha4.shion.core.training.sample_output_protocol import SampleOutputProtocol\r\nfrom tha4.shion.core.training.training_protocol import AbstractTrainingProtocol\r\nfrom tha4.nn.image_processing_util import GridChangeApplier\r\nfrom tha4.nn.siren.morpher.siren_morpher_03 import SirenMorpher03\r\nfrom tha4.poser.general_poser_02 import GeneralPoser02\r\nfrom tha4.sampleoutput.sample_image_creator import SampleImageSpec, ImageSource, ImageType, SampleImageSaver\r\nfrom torch.nn import Module\r\nfrom torch.optim import Optimizer\r\nfrom torch.utils.data import Dataset\r\n\r\nKEY_MODULE = \"module\"\r\nKEY_POSER = \"poser\"\r\nKEY_EXAMPLES_SEEN_SO_FAR = \"examples_seen_so_far\"\r\n\r\n\r\n@dataclass\r\nclass SirenMorpherProtocol03Keys:\r\n    module: str = KEY_MODULE\r\n    module_output: str = \"module_output\"\r\n\r\n    poser: str = KEY_POSER\r\n    poser_output: str = \"poser_output\"\r\n\r\n    image: str = \"image\"\r\n    pose: str = \"pose\"\r\n    face_mask: str = 'face_mask'\r\n\r\n    groundtruth_posed_image: str = 'groundtruth_posed_image'\r\n    groundtruth_grid_change: str = 'groundtruth_grid_change'\r\n    groundtruth_posed_face_mask: str = 'groundtruth_posed_face_mask'\r\n    predicted_posed_image: str = 'predicted_posed_image'\r\n\r\n    module_input_image: str = \"module_input_image\"\r\n\r\n    predicted_grid_change: str = \"predicted_grid_change\"\r\n    predicted_color_change: str = \"predicted_color_change\"\r\n    predicted_warped_image: str = \"predicted_warped_image\"\r\n    predicted_alpha: str = \"predicted_alpha\"\r\n\r\n    groundtruth_alpha: str = \"groundtruth_alpha\"\r\n    groundtruth_warped_image: str = \"groundtruth_warped_image\"\r\n\r\n    zero: str = \"zero\"\r\n\r\n\r\n@dataclass\r\nclass SirenMorpherProtocol03Indices:\r\n    batch_image: int = 0\r\n    batch_face_mask: int = 1\r\n    batch_pose: int = 2\r\n    poser_posed_image: int = 0\r\n    poser_grid_change: int = 3\r\n\r\n    poser_output_module_input_image_index: int = 5\r\n    poser_alpha: int = 1\r\n    poser_warped_image: int = 2\r\n\r\n    module_blended_image: int = SirenMorpher03.INDEX_BLENDED_IMAGE\r\n    module_grid_change: int = SirenMorpher03.INDEX_GRID_CHANGE\r\n    module_color_change: int = SirenMorpher03.INDEX_COLOR_CHANGE\r\n    module_warped_image: int = SirenMorpher03.INDEX_WARPED_IMAGE\r\n    module_alpha: int = SirenMorpher03.INDEX_ALPHA\r\n\r\n\r\nclass SirenMorpherComputationProtocol03(ComposableCachedComputationProtocol):\r\n    def __init__(self,\r\n                 keys: Optional[SirenMorpherProtocol03Keys] = None,\r\n                 indices: Optional[SirenMorpherProtocol03Indices] = None):\r\n        super().__init__()\r\n\r\n        if keys is None:\r\n            keys = SirenMorpherProtocol03Keys()\r\n        if indices is None:\r\n            indices = SirenMorpherProtocol03Indices()\r\n\r\n        self.keys = keys\r\n        self.indices = indices\r\n\r\n        self.computation_steps[keys.image] = batch_indexing_func(indices.batch_image)\r\n        
self.computation_steps[keys.pose] = batch_indexing_func(indices.batch_pose)\r\n        self.computation_steps[keys.face_mask] = batch_indexing_func(indices.batch_face_mask)\r\n        self.grid_change_applier = GridChangeApplier()\r\n\r\n        @add_step(self.computation_steps, keys.poser_output)\r\n        def get_poser_output(protocol: CachedComputationProtocol, state: ComputationState):\r\n            with torch.no_grad():\r\n                poser = state.modules[keys.poser]\r\n                pose = protocol.get_output(keys.pose, state)\r\n                image = protocol.get_output(keys.image, state)\r\n                return poser.get_posing_outputs(image, pose)\r\n\r\n        @add_step(self.computation_steps, keys.groundtruth_posed_image)\r\n        def get_groundtruth_posed_image(protocol: CachedComputationProtocol, state: ComputationState):\r\n            return protocol.get_output(keys.poser_output, state)[indices.poser_posed_image]\r\n\r\n        @add_step(self.computation_steps, keys.groundtruth_grid_change)\r\n        def get_groundtruth_grid_change(protocol: CachedComputationProtocol, state: ComputationState):\r\n            return protocol.get_output(keys.poser_output, state)[indices.poser_grid_change]\r\n\r\n        @add_step(self.computation_steps, keys.groundtruth_posed_face_mask)\r\n        def get_groundtruth_posed_face_mask(protocol: CachedComputationProtocol, state: ComputationState):\r\n            face_mask = protocol.get_output(keys.face_mask, state)\r\n            groundtruth_grid_change = protocol.get_output(keys.groundtruth_grid_change, state)\r\n            with torch.no_grad():\r\n                return self.grid_change_applier.apply(groundtruth_grid_change, face_mask)\r\n\r\n        @add_step(self.computation_steps, keys.module_input_image)\r\n        def get_module_input_image(protocol: CachedComputationProtocol, state: ComputationState):\r\n            poser_output = protocol.get_output(keys.poser_output, state)\r\n            return poser_output[indices.poser_output_module_input_image_index]\r\n\r\n        @add_step(self.computation_steps, keys.module_output)\r\n        def get_module_output(protocol: CachedComputationProtocol, state: ComputationState):\r\n            image = protocol.get_output(keys.module_input_image, state)\r\n            pose = protocol.get_output(keys.pose, state)\r\n            module = state.modules[self.keys.module]\r\n            return module.forward(image, pose)\r\n\r\n        self.computation_steps[keys.predicted_posed_image] = output_array_indexing_func(\r\n            keys.module_output, indices.module_blended_image)\r\n        self.computation_steps[keys.predicted_grid_change] = output_array_indexing_func(\r\n            keys.module_output, indices.module_grid_change)\r\n        self.computation_steps[keys.predicted_color_change] = output_array_indexing_func(\r\n            keys.module_output, indices.module_color_change)\r\n        self.computation_steps[keys.predicted_warped_image] = output_array_indexing_func(\r\n            keys.module_output, indices.module_warped_image)\r\n        
self.computation_steps[keys.predicted_alpha] = output_array_indexing_func(\r\n            keys.module_output, indices.module_alpha)\r\n\r\n        self.computation_steps[keys.groundtruth_alpha] = output_array_indexing_func(\r\n            keys.poser_output, indices.poser_alpha)\r\n        self.computation_steps[keys.groundtruth_warped_image] = output_array_indexing_func(\r\n            keys.poser_output, indices.poser_warped_image)\r\n\r\n        @add_step(self.computation_steps, keys.zero)\r\n        def get_zero(protocol: CachedComputationProtocol, state: ComputationState):\r\n            pose = protocol.get_output(keys.pose, state)\r\n            device = pose.device\r\n            return torch.zeros(1, device=device)\r\n\r\n\r\nclass SirenMorpherTrainingProtocol03(AbstractTrainingProtocol):\r\n    def __init__(self,\r\n                 check_point_examples: List[int],\r\n                 batch_size: int,\r\n                 learning_rate: Callable[[int], Dict[str, float]],\r\n                 optimizer_factories: Dict[str, OptimizerFactory],\r\n                 random_seed: int,\r\n                 poser_func: Callable[[], GeneralPoser02],\r\n                 key_module: str,\r\n                 key_poser: str = KEY_POSER,\r\n                 key_examples_seen_so_far: str = KEY_EXAMPLES_SEEN_SO_FAR):\r\n        super().__init__(check_point_examples, batch_size, learning_rate, optimizer_factories, random_seed)\r\n        self.key_examples_seen_so_far = key_examples_seen_so_far\r\n        self.key_poser = key_poser\r\n        self.key_module = key_module\r\n        self.poser_func = poser_func\r\n        self.poser = None\r\n\r\n    def run_training_iteration(\r\n            self,\r\n            batch: Any,\r\n            examples_seen_so_far: int,\r\n            modules: Dict[str, Module],\r\n            accumulated_modules: Dict[str, Module],\r\n            optimizers: Dict[str, Optimizer],\r\n            losses: Dict[str, Loss],\r\n            create_log_func: Optional[Callable[[str, int], Callable[[str, float], None]]],\r\n            device: torch.device):\r\n        if self.poser is None:\r\n            self.poser = self.poser_func()\r\n            self.poser.to(device)\r\n\r\n        module = modules[self.key_module]\r\n        module.train(True)\r\n        module_optimizer = optimizers[self.key_module]\r\n        module_optimizer.zero_grad(set_to_none=True)\r\n\r\n        loss = losses[self.key_module]\r\n        if create_log_func is not None:\r\n            log_func = create_log_func(f\"training_{self.key_module}\", examples_seen_so_far)\r\n        else:\r\n            log_func = None\r\n        state = ComputationState(\r\n            modules={\r\n                **modules,\r\n                self.key_poser: self.poser,\r\n            },\r\n            accumulated_modules=accumulated_modules,\r\n            batch=batch,\r\n            outputs={\r\n                self.key_examples_seen_so_far: examples_seen_so_far,\r\n            })\r\n        loss_value = loss.compute(state, log_func)\r\n        loss_value.backward()\r\n        module_optimizer.step()\r\n\r\n\r\nclass SirenMorpherSampleOutputProtocol(SampleOutputProtocol):\r\n    def __init__(self,\r\n                 num_images: int,\r\n                 image_size: int,\r\n                 examples_per_sample_output: int,\r\n                 computation_protocol,\r\n                 poser_func: Callable[[], GeneralPoser02],\r\n                 random_seed: int = 54859395058,\r\n                 batch_image_index: int = 
0,\r\n                 batch_pose_index: int = 2,\r\n                 batch_size: Optional[int] = None,\r\n                 sample_image_specs: Optional[List[SampleImageSpec]] = None,\r\n                 cell_size: Optional[int] = None):\r\n        if batch_size is None:\r\n            batch_size = num_images\r\n\r\n        if sample_image_specs is None:\r\n            sample_image_specs = [\r\n                SampleImageSpec(\r\n                    ImageSource.BATCH,\r\n                    computation_protocol.indices.poser_posed_image,\r\n                    ImageType.COLOR),\r\n                SampleImageSpec(\r\n                    ImageSource.OUTPUT,\r\n                    computation_protocol.indices.module_blended_image,\r\n                    ImageType.COLOR),\r\n                SampleImageSpec(\r\n                    ImageSource.OUTPUT,\r\n                    computation_protocol.indices.module_alpha,\r\n                    ImageType.ALPHA),\r\n                SampleImageSpec(\r\n                    ImageSource.OUTPUT,\r\n                    computation_protocol.indices.module_color_change,\r\n                    ImageType.COLOR),\r\n                SampleImageSpec(\r\n                    ImageSource.OUTPUT,\r\n                    computation_protocol.indices.module_warped_image,\r\n                    ImageType.COLOR),\r\n                SampleImageSpec(\r\n                    ImageSource.BATCH,\r\n                    computation_protocol.indices.poser_grid_change,\r\n                    ImageType.GRID_CHANGE),\r\n                SampleImageSpec(\r\n                    ImageSource.OUTPUT,\r\n                    computation_protocol.indices.module_grid_change,\r\n                    ImageType.GRID_CHANGE),\r\n            ]\r\n\r\n        if cell_size is None:\r\n            cell_size = image_size\r\n\r\n        self.batch_size = batch_size\r\n        self.batch_pose_index = batch_pose_index\r\n        self.batch_image_index = batch_image_index\r\n        self.poser_func = poser_func\r\n        self.random_seed = random_seed\r\n        self.examples_per_sample_output = examples_per_sample_output\r\n        self.image_size = image_size\r\n        self.num_images = num_images\r\n        self.computation_protocol = computation_protocol\r\n        self.cell_size = cell_size\r\n\r\n        self.sample_image_saver = SampleImageSaver(image_size, cell_size, 4, sample_image_specs)\r\n\r\n    def get_examples_per_sample_output(self) -> int:\r\n        return self.examples_per_sample_output\r\n\r\n    def get_random_seed(self) -> int:\r\n        return self.random_seed\r\n\r\n    def get_sample_output_data(self, validation_dataset: Dataset, device: torch.device) -> dict:\r\n        example_indices = torch.randint(0, len(validation_dataset), (self.num_images,))\r\n        example_indices = [example_indices[i].item() for i in range(self.num_images)]\r\n        batch = get_indexed_batch(validation_dataset, example_indices, device)\r\n        poser = self.poser_func()\r\n        poser.to(device)\r\n        with torch.no_grad():\r\n            ground_truth = poser.get_posing_outputs(batch[self.batch_image_index], batch[self.batch_pose_index])\r\n        return {\r\n            'batch': batch,\r\n            'ground_truth': ground_truth\r\n        }\r\n\r\n    def save_sample_output_data(self,\r\n                                modules: Dict[str, Module],\r\n                                accumulated_modules: Dict[str, Module],\r\n                                sample_output_data: Any,\r\n        
                        prefix: str,\r\n                                examples_seen_so_far: int,\r\n                                device: torch.device):\r\n        batch = sample_output_data['batch']\r\n        ground_truth = sample_output_data['ground_truth']\r\n\r\n        module = modules[self.computation_protocol.keys.module]\r\n        module.train(False)\r\n        if self.batch_size == self.num_images:\r\n            with torch.no_grad():\r\n                state = ComputationState(\r\n                    modules=modules,\r\n                    accumulated_modules=accumulated_modules,\r\n                    batch=batch,\r\n                    outputs={\r\n                        self.computation_protocol.keys.poser_output: ground_truth,\r\n                    })\r\n                module_outputs = self.computation_protocol.get_output(\r\n                    self.computation_protocol.keys.module_output, state)\r\n        else:\r\n            module_outputs_list = []\r\n            start = 0\r\n            while start < self.num_images:\r\n                end = start + self.batch_size\r\n                end = min(self.num_images, end)\r\n                minibatch = [batch[i][start:end] for i in range(len(batch))]\r\n                ground_truth_batch = [ground_truth[i][start:end] for i in range(len(ground_truth))]\r\n                state = ComputationState(\r\n                    modules=modules,\r\n                    accumulated_modules=accumulated_modules,\r\n                    batch=minibatch,\r\n                    outputs={\r\n                        self.computation_protocol.keys.poser_output: ground_truth_batch\r\n                    })\r\n                with torch.no_grad():\r\n                    module_outputs = self.computation_protocol.get_output(\r\n                        self.computation_protocol.keys.module_output, state)\r\n                module_outputs_list.append(module_outputs)\r\n                start = end\r\n\r\n            module_outputs = []\r\n            for i in range(len(module_outputs_list[0])):\r\n                tensor_list = []\r\n                for j in range(len(module_outputs_list)):\r\n                    tensor_list.append(module_outputs_list[j][i])\r\n                module_output = torch.cat(tensor_list, dim=0)\r\n                module_outputs.append(module_output)\r\n\r\n        self.sample_image_saver.save_sample_output_data(ground_truth, module_outputs, prefix, examples_seen_so_far)\r\n"
  },
  {
    "path": "src/tha4/nn/siren/vanilla/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/nn/siren/vanilla/siren.py",
    "content": "import math\r\nfrom typing import Callable, Optional, List\r\n\r\nimport torch\r\nfrom torch import Tensor\r\nfrom torch.nn import Module, Conv2d, ModuleList\r\n\r\nfrom tha4.shion.core.module_factory import ModuleFactory\r\nfrom tha4.shion.nn00.initialization_funcs import HeInitialization\r\n\r\n\r\nclass SineLinearLayer(Module):\r\n    def __init__(self,\r\n            in_channels: int,\r\n            out_channels: int,\r\n            is_first=False,\r\n            omega_0=30.0):\r\n        super().__init__()\r\n        self.out_channels = out_channels\r\n        self.in_channels = in_channels\r\n        self.omega_0 = omega_0\r\n        self.is_first = is_first\r\n        self.linear = Conv2d(\r\n            in_channels=in_channels,\r\n            out_channels=out_channels,\r\n            kernel_size=1,\r\n            stride=1,\r\n            padding=0,\r\n            bias=True)\r\n        with torch.no_grad():\r\n            if self.is_first:\r\n                self.linear.weight.uniform_(-1 / in_channels, 1.0 / in_channels)\r\n            else:\r\n                self.linear.weight.uniform_(\r\n                    -math.sqrt(6.0 / in_channels) / self.omega_0,\r\n                    math.sqrt(6.0 / in_channels) / self.omega_0)\r\n\r\n    def forward(self, x: Tensor):\r\n        return torch.sin(self.omega_0 * self.linear(x))\r\n\r\n\r\nclass SirenArgs:\r\n    def __init__(\r\n            self,\r\n            in_channels: int,\r\n            out_channels: int,\r\n            intermediate_channels: int,\r\n            num_sine_layers: int,\r\n            use_tanh: bool = False,\r\n            init_func: Optional[Callable[[Module], Module]] = None):\r\n        if init_func is None:\r\n            init_func = HeInitialization()\r\n        self.init_func = init_func\r\n        self.use_tanh = use_tanh\r\n        assert num_sine_layers >= 1\r\n        self.intermediate_channels = intermediate_channels\r\n        self.num_sine_layers = num_sine_layers\r\n        self.out_channels = out_channels\r\n        self.in_channels = in_channels\r\n\r\n\r\nclass Siren(Module):\r\n    def __init__(self, args: SirenArgs):\r\n        super().__init__()\r\n        self.args = args\r\n        self.sine_layers = ModuleList()\r\n        self.sine_layers.append(\r\n            SineLinearLayer(\r\n                in_channels=args.in_channels, out_channels=args.intermediate_channels, is_first=True))\r\n        for i in range(args.num_sine_layers - 1):\r\n            self.sine_layers.append(\r\n                SineLinearLayer(\r\n                    in_channels=args.intermediate_channels,\r\n                    out_channels=args.intermediate_channels,\r\n                    is_first=False))\r\n        self.last_linear = args.init_func(Conv2d(\r\n            args.intermediate_channels,\r\n            args.out_channels,\r\n            kernel_size=1,\r\n            stride=1,\r\n            padding=0,\r\n            bias=True))\r\n\r\n    def forward(self, x: Tensor) -> Tensor:\r\n        for i in range(self.args.num_sine_layers):\r\n            x = self.sine_layers[i].forward(x)\r\n        x = self.last_linear(x)\r\n        if self.args.use_tanh:\r\n            return torch.tanh(x)\r\n        else:\r\n            return x\r\n\r\n\r\nclass SirenFactory(ModuleFactory):\r\n    def __init__(self, args: SirenArgs):\r\n        super().__init__()\r\n        self.args = args\r\n\r\n    def create(self) -> Module:\r\n        return Siren(self.args)\r\n"
  },
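A minimal usage sketch of the Siren module above (illustrative only, not part of the repository): because every layer is a 1x1 Conv2d, the network consumes a coordinate grid laid out as an image tensor of shape (batch, in_channels, height, width) rather than a flat list of points.

```python
import torch

from tha4.nn.siren.vanilla.siren import Siren, SirenArgs

# Hypothetical configuration: a SIREN that maps 2D coordinates to RGBA values.
siren = Siren(SirenArgs(
    in_channels=2,              # an (x, y) coordinate per "pixel"
    out_channels=4,             # RGBA
    intermediate_channels=128,
    num_sine_layers=8,
    use_tanh=True))             # squash outputs into [-1, 1]

# A 64x64 coordinate grid covering [-1, 1]^2, shaped (1, 2, 64, 64).
ys, xs = torch.meshgrid(
    torch.linspace(-1.0, 1.0, 64),
    torch.linspace(-1.0, 1.0, 64),
    indexing="ij")
grid = torch.stack([xs, ys], dim=0).unsqueeze(0)

image = siren(grid)  # shape: (1, 4, 64, 64)
```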
  {
    "path": "src/tha4/nn/spectral_norm.py",
    "content": "from torch.nn import Module\r\nfrom torch.nn.utils import spectral_norm\r\n\r\n\r\ndef apply_spectral_norm(module: Module, use_spectrial_norm: bool = False) -> Module:\r\n    if use_spectrial_norm:\r\n        return spectral_norm(module)\r\n    else:\r\n        return module\r\n"
  },
  {
    "path": "src/tha4/nn/upscaler/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/nn/upscaler/upscaler_02.py",
    "content": "from typing import List\r\n\r\nimport torch\r\nfrom torch import Tensor, zero_\r\nfrom torch.nn import Module, Conv2d\r\n\r\nfrom tha4.shion.core.module_factory import ModuleFactory\r\nfrom tha4.nn.image_processing_util import GridChangeApplier\r\nfrom tha4.nn.common.unet import UnetArgs, AttentionBlockArgs, UnetWithFirstConvAddition\r\n\r\n\r\nclass Upscaler02Args:\r\n    def __init__(self,\r\n                 image_size: int,\r\n                 image_channels: int,\r\n                 num_pose_parameters: int,\r\n                 unet_args: UnetArgs):\r\n        assert unet_args.in_channels == (\r\n            image_channels\r\n        )\r\n        assert unet_args.out_channels == (\r\n                image_channels +  # direct\r\n                2 +  # warp\r\n                1  # alpha\r\n        )\r\n        assert unet_args.cond_input_channels == num_pose_parameters\r\n        self.image_channels = image_channels\r\n        self.image_size = image_size\r\n        self.num_pose_parameters = num_pose_parameters\r\n        self.unet_args = unet_args\r\n\r\n\r\ndef apply_color_change(alpha, color_change, image: Tensor) -> Tensor:\r\n    return color_change * alpha + image * (1 - alpha)\r\n\r\n\r\nclass Upscaler02(Module):\r\n    def __init__(self, args: Upscaler02Args):\r\n        super().__init__()\r\n        self.args = args\r\n        self.body = UnetWithFirstConvAddition(args.unet_args)\r\n        self.grid_change_applier = GridChangeApplier()\r\n        self.coarse_image_conv = Conv2d(\r\n            args.image_channels + args.image_channels + 2,\r\n            args.unet_args.model_channels,\r\n            kernel_size=3,\r\n            stride=1,\r\n            padding=1)\r\n        with torch.no_grad():\r\n            zero_(self.coarse_image_conv.weight)\r\n            zero_(self.coarse_image_conv.bias)\r\n\r\n    def check_image(self, image: torch.Tensor):\r\n        assert len(image.shape) == 4\r\n        assert image.shape[1] == self.args.image_channels\r\n        assert image.shape[2] == self.args.image_size\r\n        assert image.shape[3] == self.args.image_size\r\n\r\n    def forward(self,\r\n                rest_image: torch.Tensor,\r\n                coarse_posed_image: torch.Tensor,\r\n                coarse_grid_change: torch.Tensor,\r\n                pose: torch.Tensor) -> List[Tensor]:\r\n        self.check_image(rest_image)\r\n        self.check_image(coarse_posed_image)\r\n\r\n        assert len(pose.shape) == 2\r\n        assert rest_image.shape[0] == pose.shape[0]\r\n        assert coarse_posed_image.shape[0] == pose.shape[0]\r\n        assert coarse_grid_change.shape[0] == pose.shape[0]\r\n        assert coarse_grid_change.shape[1] == 2\r\n        assert coarse_grid_change.shape[2] == self.args.image_size\r\n        assert coarse_grid_change.shape[3] == self.args.image_size\r\n        assert pose.shape[1] == self.args.num_pose_parameters\r\n\r\n        warped_image = self.grid_change_applier.apply(coarse_grid_change, rest_image)\r\n\r\n        t = torch.zeros(rest_image.shape[0], 1, device=rest_image.device)\r\n        feature = torch.cat([coarse_posed_image, warped_image, coarse_grid_change], dim=1)\r\n        first_conv_addition = self.coarse_image_conv(feature)\r\n\r\n        body_output = self.body(rest_image, t, pose, first_conv_addition)\r\n\r\n        direct = body_output[:, 0:self.args.image_channels, :, :]\r\n        grid_change = body_output[:, self.args.image_channels:self.args.image_channels + 2, :, :]\r\n        alpha = 
torch.sigmoid(body_output[:, self.args.image_channels + 2:self.args.image_channels + 3, :, :])\r\n        warped = self.grid_change_applier.apply(grid_change, rest_image)\r\n        merged = apply_color_change(alpha, direct, warped)\r\n\r\n        return [\r\n            merged,\r\n            alpha,\r\n            warped,\r\n            grid_change,\r\n            direct\r\n        ]\r\n\r\n    INDEX_MERGED = 0\r\n    INDEX_ALPHA = 1\r\n    INDEX_WARPED = 2\r\n    INDEX_GRID_CHANGE = 3\r\n    INDEX_DIRECT = 4\r\n\r\n\r\nclass Upscaler02Factory(ModuleFactory):\r\n    def __init__(self, args: Upscaler02Args):\r\n        self.args = args\r\n\r\n    def create(self) -> Module:\r\n        return Upscaler02(self.args)\r\n"
  },
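For reference, the merge step performed in Upscaler02.forward above combines the directly generated image and the warped rest image using the predicted alpha map: merged = alpha * direct + (1 - alpha) * warped, which is exactly apply_color_change(alpha, direct, warped).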
  {
    "path": "src/tha4/nn/util.py",
    "content": "from typing import Optional, Callable, Union\r\n\r\nfrom torch.nn import Module\r\n\r\nfrom tha4.shion.core.module_factory import ModuleFactory\r\nfrom tha4.nn.init_function import create_init_function\r\nfrom tha4.nn.nonlinearity_factory import resolve_nonlinearity_factory\r\nfrom tha4.nn.normalization import NormalizationLayerFactory\r\nfrom tha4.nn.spectral_norm import apply_spectral_norm\r\n\r\n\r\ndef wrap_conv_or_linear_module(module: Module,\r\n                               initialization_method: Union[str, Callable[[Module], Module]],\r\n                               use_spectral_norm: bool):\r\n    if isinstance(initialization_method, str):\r\n        init = create_init_function(initialization_method)\r\n    else:\r\n        init = initialization_method\r\n    return apply_spectral_norm(init(module), use_spectral_norm)\r\n\r\n\r\nclass BlockArgs:\r\n    def __init__(self,\r\n                 initialization_method: Union[str, Callable[[Module], Module]] = 'he',\r\n                 use_spectral_norm: bool = False,\r\n                 normalization_layer_factory: Optional[NormalizationLayerFactory] = None,\r\n                 nonlinearity_factory: Optional[ModuleFactory] = None):\r\n        self.nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)\r\n        self.normalization_layer_factory = normalization_layer_factory\r\n        self.use_spectral_norm = use_spectral_norm\r\n        self.initialization_method = initialization_method\r\n\r\n    def wrap_module(self, module: Module) -> Module:\r\n        return wrap_conv_or_linear_module(module, self.get_init_func(), self.use_spectral_norm)\r\n\r\n    def get_init_func(self) -> Callable[[Module], Module]:\r\n        if isinstance(self.initialization_method, str):\r\n            return create_init_function(self.initialization_method)\r\n        else:\r\n            return self.initialization_method\r\n"
  },
  {
    "path": "src/tha4/poser/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/poser/general_poser_02.py",
    "content": "from typing import List, Optional, Tuple, Dict, Callable\r\n\r\nimport torch\r\nfrom tha4.shion.core.cached_computation import ComputationState\r\nfrom tha4.poser.poser import PoseParameterGroup, Poser\r\nfrom torch import Tensor\r\nfrom torch.nn import Module\r\n\r\n\r\nclass GeneralPoser02(Poser):\r\n    def __init__(self,\r\n                 module_loaders: Dict[str, Callable[[], Module]],\r\n                 device: torch.device,\r\n                 output_length: int,\r\n                 pose_parameters: List[PoseParameterGroup],\r\n                 output_list_func: Callable[[ComputationState], List[Tensor]],\r\n                 subrect: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None,\r\n                 default_output_index: int = 0,\r\n                 image_size: int = 256,\r\n                 dtype: torch.dtype = torch.float):\r\n        self.dtype = dtype\r\n        self.image_size = image_size\r\n        self.default_output_index = default_output_index\r\n        self.output_list_func = output_list_func\r\n        self.subrect = subrect\r\n        self.pose_parameters = pose_parameters\r\n        self.device = device\r\n        self.module_loaders = module_loaders\r\n\r\n        self.modules = None\r\n\r\n        self.num_parameters = 0\r\n        for pose_parameter in self.pose_parameters:\r\n            self.num_parameters += pose_parameter.get_arity()\r\n\r\n        self.output_length = output_length\r\n\r\n    def get_image_size(self) -> int:\r\n        return self.image_size\r\n\r\n    def get_modules(self):\r\n        if self.modules is None:\r\n            self.modules = {}\r\n            for key in self.module_loaders:\r\n                module = self.module_loaders[key]()\r\n                self.modules[key] = module\r\n                module.to(self.device)\r\n                module.train(False)\r\n        return self.modules\r\n\r\n    def get_pose_parameter_groups(self) -> List[PoseParameterGroup]:\r\n        return self.pose_parameters\r\n\r\n    def get_num_parameters(self) -> int:\r\n        return self.num_parameters\r\n\r\n    def pose(self, image: Tensor, pose: Tensor, output_index: Optional[int] = None) -> Tensor:\r\n        if output_index is None:\r\n            output_index = self.default_output_index\r\n        output_list = self.get_posing_outputs(image, pose)\r\n        return output_list[output_index]\r\n\r\n    def get_posing_outputs(self, image: Tensor, pose: Tensor) -> List[Tensor]:\r\n        modules = self.get_modules()\r\n\r\n        if len(image.shape) == 3:\r\n            image = image.unsqueeze(0)\r\n        if len(pose.shape) == 1:\r\n            pose = pose.unsqueeze(0)\r\n        if self.subrect is not None:\r\n            image = image[:, :, self.subrect[0][0]:self.subrect[0][1], self.subrect[1][0]:self.subrect[1][1]]\r\n        batch = [image, pose]\r\n\r\n        state = ComputationState(\r\n            modules=modules,\r\n            accumulated_modules={},\r\n            batch=batch,\r\n            outputs={})\r\n        return self.output_list_func(state)\r\n\r\n    def get_output_length(self) -> int:\r\n        return self.output_length\r\n\r\n    def free(self):\r\n        self.modules = None\r\n\r\n    def get_dtype(self) -> torch.dtype:\r\n        return self.dtype\r\n\r\n    def to(self, device: torch.device) -> 'GeneralPoser02':\r\n        if device == self.device:\r\n            return self\r\n        modules = self.get_modules()\r\n        self.device = device\r\n        for key in modules:\r\n        
    module = modules[key]\r\n            module.to(self.device)\r\n        return self\r\n"
  },
  {
    "path": "src/tha4/poser/modes/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/poser/modes/mode_07.py",
    "content": "from enum import Enum\r\nfrom typing import List, Dict, Optional\r\n\r\nimport torch\r\nfrom tha4.shion.core.cached_computation import CachedComputationProtocol, ComputationState\r\nfrom tha4.shion.core.load_save import torch_load\r\nfrom tha4.nn.eyebrow_decomposer.eyebrow_decomposer_00 import EyebrowDecomposer00, \\\r\n    EyebrowDecomposer00Factory, EyebrowDecomposer00Args\r\nfrom tha4.nn.eyebrow_morphing_combiner.eyebrow_morphing_combiner_00 import \\\r\n    EyebrowMorphingCombiner00Factory, EyebrowMorphingCombiner00Args, EyebrowMorphingCombiner00\r\nfrom tha4.nn.face_morpher.face_morpher_08 import FaceMorpher08Args, FaceMorpher08Factory\r\nfrom tha4.nn.nonlinearity_factory import ReLUFactory\r\nfrom tha4.nn.normalization import InstanceNorm2dFactory\r\nfrom tha4.nn.util import BlockArgs\r\nfrom tha4.nn.common.unet import UnetArgs, AttentionBlockArgs\r\nfrom tha4.nn.morpher.morpher_00 import Morpher00Args, Morpher00\r\nfrom tha4.nn.upscaler.upscaler_02 import Upscaler02Args, Upscaler02\r\nfrom tha4.poser.general_poser_02 import GeneralPoser02\r\nfrom tha4.poser.modes.pose_parameters import get_pose_parameters\r\nfrom torch import Tensor\r\nfrom torch.nn.functional import interpolate\r\n\r\n\r\nclass Network(Enum):\r\n    eyebrow_decomposer = 1\r\n    eyebrow_morphing_combiner = 2\r\n    face_morpher = 3\r\n    body_morpher = 4\r\n    upscaler = 5\r\n\r\n    @property\r\n    def outputs_key(self):\r\n        return f\"{self.name}_outputs\"\r\n\r\n\r\nclass Branch(Enum):\r\n    face_morphed_half = 1\r\n    face_morphed_full = 2\r\n    all_outputs = 3\r\n\r\n\r\nNUM_EYEBROW_PARAMS = 12\r\nNUM_FACE_PARAMS = 27\r\nNUM_ROTATION_PARAMS = 6\r\n\r\n\r\nclass FiveStepPoserComputationProtocol(CachedComputationProtocol):\r\n    def __init__(self, eyebrow_morphed_image_index: int):\r\n        super().__init__()\r\n        self.eyebrow_morphed_image_index = eyebrow_morphed_image_index\r\n        self.cached_batch_0 = None\r\n        self.cached_eyebrow_decomposer_output = None\r\n\r\n    def compute_func(self):\r\n        def func(state: ComputationState) -> List[Tensor]:\r\n            if self.cached_batch_0 is None:\r\n                new_batch_0 = True\r\n            elif state.batch[0].shape[0] != self.cached_batch_0.shape[0]:\r\n                new_batch_0 = True\r\n            else:\r\n                new_batch_0 = torch.max((state.batch[0] - self.cached_batch_0).abs()).item() > 0\r\n            if not new_batch_0:\r\n                state.outputs[Network.eyebrow_decomposer.outputs_key] = self.cached_eyebrow_decomposer_output\r\n            output = self.get_output(Branch.all_outputs.name, state)\r\n            if new_batch_0:\r\n                self.cached_batch_0 = state.batch[0]\r\n                self.cached_eyebrow_decomposer_output = state.outputs[Network.eyebrow_decomposer.outputs_key]\r\n            return output\r\n\r\n        return func\r\n\r\n    def compute_output(self, key: str, state: ComputationState) -> List[Tensor]:\r\n        if key == Network.eyebrow_decomposer.outputs_key:\r\n            input_image = state.batch[0][:, :, 64:192, 64 + 128:192 + 128]\r\n            return state.modules[Network.eyebrow_decomposer.name].forward(input_image)\r\n        elif key == Network.eyebrow_morphing_combiner.outputs_key:\r\n            eyebrow_decomposer_output = self.get_output(Network.eyebrow_decomposer.outputs_key, state)\r\n            background_layer = eyebrow_decomposer_output[EyebrowDecomposer00.BACKGROUND_LAYER_INDEX]\r\n            eyebrow_layer = 
eyebrow_decomposer_output[EyebrowDecomposer00.EYEBROW_LAYER_INDEX]\r\n            eyebrow_pose = state.batch[1][:, :NUM_EYEBROW_PARAMS]\r\n            return state.modules[Network.eyebrow_morphing_combiner.name].forward(\r\n                background_layer,\r\n                eyebrow_layer,\r\n                eyebrow_pose)\r\n        elif key == Network.face_morpher.outputs_key:\r\n            eyebrow_morphing_combiner_output = self.get_output(\r\n                Network.eyebrow_morphing_combiner.outputs_key, state)\r\n            eyebrow_morphed_image = eyebrow_morphing_combiner_output[self.eyebrow_morphed_image_index]\r\n            input_image = state.batch[0][:, :, 32:32 + 192, (32 + 128):(32 + 192 + 128)].clone()\r\n            input_image[:, :, 32:32 + 128, 32:32 + 128] = eyebrow_morphed_image\r\n            face_pose = state.batch[1][:, NUM_EYEBROW_PARAMS:NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS]\r\n            return state.modules[Network.face_morpher.name].forward(input_image, face_pose)\r\n        elif key == Branch.face_morphed_full.name:\r\n            face_morpher_output = self.get_output(Network.face_morpher.outputs_key, state)\r\n            face_morphed_image = face_morpher_output[0]\r\n            input_image = state.batch[0].clone()\r\n            input_image[:, :, 32:32 + 192, 32 + 128:32 + 192 + 128] = face_morphed_image\r\n            return [input_image]\r\n        elif key == Branch.face_morphed_half.name:\r\n            face_morphed_full = self.get_output(Branch.face_morphed_full.name, state)[0]\r\n            return [\r\n                interpolate(face_morphed_full, size=(256, 256), mode='bilinear', align_corners=False)\r\n            ]\r\n        elif key == Network.body_morpher.outputs_key:\r\n            face_morphed_half = self.get_output(Branch.face_morphed_half.name, state)[0]\r\n            rotation_pose = state.batch[1][:, NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS:]\r\n            return state.modules[Network.body_morpher.name].forward(face_morphed_half, rotation_pose)\r\n        elif key == Network.upscaler.outputs_key:\r\n            rest_image = self.get_output(Branch.face_morphed_full.name, state)[0]\r\n            body_morpher_outputs = self.get_output(\r\n                Network.body_morpher.outputs_key, state)\r\n            half_res_posed_image = body_morpher_outputs[Morpher00.INDEX_MERGED]\r\n            half_res_grid_change = body_morpher_outputs[Morpher00.INDEX_GRID_CHANGE]\r\n            coarse_posed_image = interpolate(half_res_posed_image, size=(512, 512), mode='bilinear')\r\n            coarse_grid_change = interpolate(half_res_grid_change, size=(512, 512), mode='bilinear')\r\n            rotation_pose = state.batch[1][:, NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS:]\r\n            return state.modules[Network.upscaler.name].forward(\r\n                rest_image, coarse_posed_image, coarse_grid_change, rotation_pose)\r\n        elif key == Branch.all_outputs.name:\r\n            upscaler_output = self.get_output(Network.upscaler.outputs_key, state)\r\n            face_morphed_full = self.get_output(Branch.face_morphed_full.name, state)\r\n            body_morpher_output = self.get_output(Network.body_morpher.outputs_key, state)\r\n            face_morpher_output = self.get_output(Network.face_morpher.outputs_key, state)\r\n            eyebrow_morphing_combiner_output = self.get_output(Network.eyebrow_morphing_combiner.outputs_key, state)\r\n            eyebrow_decomposer_output = self.get_output(Network.eyebrow_decomposer.outputs_key, state)\r\n            
output = upscaler_output \\\r\n                     + face_morphed_full \\\r\n                     + body_morpher_output \\\r\n                     + face_morpher_output \\\r\n                     + eyebrow_morphing_combiner_output \\\r\n                     + eyebrow_decomposer_output\r\n            return output\r\n        else:\r\n            raise RuntimeError(\"Unsupported key: \" + key)\r\n\r\n\r\ndef load_eyebrow_decomposer(file_name: str):\r\n    factory = EyebrowDecomposer00Factory(\r\n        EyebrowDecomposer00Args(\r\n            image_size=128,\r\n            image_channels=4,\r\n            start_channels=64,\r\n            bottleneck_image_size=16,\r\n            num_bottleneck_blocks=6,\r\n            max_channels=512,\r\n            block_args=BlockArgs(\r\n                initialization_method='he',\r\n                use_spectral_norm=False,\r\n                normalization_layer_factory=InstanceNorm2dFactory(),\r\n                nonlinearity_factory=ReLUFactory(inplace=True))))\r\n    print(\"Loading the eyebrow decomposer ... \", end=\"\")\r\n    module = factory.create()\r\n    module.load_state_dict(torch_load(file_name))\r\n    print(\"DONE!!!\")\r\n    return module\r\n\r\n\r\ndef load_eyebrow_morphing_combiner(file_name: str):\r\n    factory = EyebrowMorphingCombiner00Factory(\r\n        EyebrowMorphingCombiner00Args(\r\n            image_size=128,\r\n            image_channels=4,\r\n            start_channels=64,\r\n            num_pose_params=12,\r\n            bottleneck_image_size=16,\r\n            num_bottleneck_blocks=6,\r\n            max_channels=512,\r\n            block_args=BlockArgs(\r\n                initialization_method='he',\r\n                use_spectral_norm=False,\r\n                normalization_layer_factory=InstanceNorm2dFactory(),\r\n                nonlinearity_factory=ReLUFactory(inplace=True))))\r\n    print(\"Loading the eyebrow morphing combiner ... \", end=\"\")\r\n    module = factory.create()\r\n    module.load_state_dict(torch_load(file_name))\r\n    print(\"DONE!!!\")\r\n    return module\r\n\r\n\r\ndef load_face_morpher(file_name: str):\r\n    factory = FaceMorpher08Factory(\r\n        FaceMorpher08Args(\r\n            image_size=192,\r\n            image_channels=4,\r\n            num_expression_params=27,\r\n            start_channels=64,\r\n            bottleneck_image_size=24,\r\n            num_bottleneck_blocks=6,\r\n            max_channels=512,\r\n            block_args=BlockArgs(\r\n                initialization_method='he',\r\n                use_spectral_norm=False,\r\n                normalization_layer_factory=InstanceNorm2dFactory(),\r\n                nonlinearity_factory=ReLUFactory(inplace=False)\r\n            ),\r\n            output_iris_mouth_grid_change=True,\r\n        )\r\n    )\r\n    print(\"Loading the face morpher ... 
\", end=\"\")\r\n    module = factory.create()\r\n    module.load_state_dict(torch_load(file_name))\r\n    print(\"DONE!!!\")\r\n    return module\r\n\r\n\r\ndef apply_color_change(alpha, color_change, image: Tensor) -> Tensor:\r\n    return color_change * alpha + image * (1 - alpha)\r\n\r\n\r\ndef load_morpher_00(file_name: str):\r\n    unet_args = UnetArgs(\r\n        in_channels=4,\r\n        out_channels=7,\r\n        model_channels=64,\r\n        level_channel_multipliers=[1, 2, 4, 4, 4],\r\n        level_use_attention=[False, False, False, False, True],\r\n        num_res_blocks_per_level=1,\r\n        num_middle_res_blocks=4,\r\n        time_embedding_channels=None,\r\n        cond_input_channels=6,\r\n        cond_internal_channels=256,\r\n        attention_block_args=AttentionBlockArgs(\r\n            num_heads=8,\r\n            use_new_attention_order=True),\r\n        dropout_prob=0.0)\r\n    morpher_00_args = Morpher00Args(\r\n        image_size=256,\r\n        image_channels=4,\r\n        num_pose_parameters=6,\r\n        unet_args=unet_args)\r\n    morpher_00 = Morpher00(morpher_00_args)\r\n\r\n    print(\"Loading the body morpher ... \", end=\"\")\r\n    morpher_00.load_state_dict(torch_load(file_name))\r\n    print(\"DONE\")\r\n\r\n    morpher_00.train(False)\r\n    return morpher_00\r\n\r\n\r\ndef load_upscaler_02(file_name: str):\r\n    unet_args = UnetArgs(\r\n        in_channels=4,\r\n        out_channels=7,\r\n        model_channels=32,\r\n        level_channel_multipliers=[1, 2, 4, 8, 8, 8],\r\n        level_use_attention=[False, False, False, False, False, True],\r\n        num_res_blocks_per_level=1,\r\n        num_middle_res_blocks=4,\r\n        time_embedding_channels=None,\r\n        cond_input_channels=6,\r\n        cond_internal_channels=256,\r\n        attention_block_args=AttentionBlockArgs(\r\n            num_heads=8,\r\n            use_new_attention_order=True),\r\n        dropout_prob=0.0)\r\n    upscaler_02_args = Upscaler02Args(\r\n        image_size=512,\r\n        image_channels=4,\r\n        num_pose_parameters=6,\r\n        unet_args=unet_args)\r\n    upscaler_02 = Upscaler02(upscaler_02_args)\r\n\r\n    print(\"Loading the upscaler ... 
\", end=\"\")\r\n    upscaler_02.load_state_dict(torch_load(file_name))\r\n    print(\"DONE\")\r\n\r\n    upscaler_02.train(False)\r\n    return upscaler_02\r\n\r\n\r\ndef create_poser(\r\n        device: torch.device,\r\n        module_file_names: Optional[Dict[str, str]] = None,\r\n        eyebrow_morphed_image_index: int = EyebrowMorphingCombiner00.EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX,\r\n        default_output_index: int = 0) -> GeneralPoser02:\r\n    if module_file_names is None:\r\n        module_file_names = {}\r\n    if Network.eyebrow_decomposer.name not in module_file_names:\r\n        file_name = \"data/tha4/eyebrow_decomposer.pt\"\r\n        module_file_names[Network.eyebrow_decomposer.name] = file_name\r\n    if Network.eyebrow_morphing_combiner.name not in module_file_names:\r\n        file_name = \"data/tha4/eyebrow_morphing_combiner.pt\"\r\n        module_file_names[Network.eyebrow_morphing_combiner.name] = file_name\r\n    if Network.face_morpher.name not in module_file_names:\r\n        file_name = \"data/tha4/face_morpher.pt\"\r\n        module_file_names[Network.face_morpher.name] = file_name\r\n    if Network.body_morpher.name not in module_file_names:\r\n        file_name = \"data/tha4/body_morpher.pt\"\r\n        module_file_names[Network.body_morpher.name] = file_name\r\n    if Network.upscaler.name not in module_file_names:\r\n        file_name = \"data/tha4/upscaler.pt\"\r\n        module_file_names[Network.upscaler.name] = file_name\r\n\r\n    loaders = {\r\n        Network.eyebrow_decomposer.name:\r\n            lambda: load_eyebrow_decomposer(module_file_names[Network.eyebrow_decomposer.name]),\r\n        Network.eyebrow_morphing_combiner.name:\r\n            lambda: load_eyebrow_morphing_combiner(module_file_names[Network.eyebrow_morphing_combiner.name]),\r\n        Network.face_morpher.name:\r\n            lambda: load_face_morpher(module_file_names[Network.face_morpher.name]),\r\n        Network.body_morpher.name:\r\n            lambda: load_morpher_00(module_file_names[Network.body_morpher.name]),\r\n        Network.upscaler.name:\r\n            lambda: load_upscaler_02(module_file_names[Network.upscaler.name]),\r\n    }\r\n    return GeneralPoser02(\r\n        image_size=512,\r\n        module_loaders=loaders,\r\n        pose_parameters=get_pose_parameters().get_pose_parameter_groups(),\r\n        output_list_func=FiveStepPoserComputationProtocol(eyebrow_morphed_image_index).compute_func(),\r\n        subrect=None,\r\n        device=device,\r\n        output_length=5 + 1 + 5 + 8 + 8 + 6,\r\n        default_output_index=default_output_index)\r\n"
  },
  {
    "path": "src/tha4/poser/modes/mode_12.py",
    "content": "from enum import Enum\r\nfrom typing import List, Dict, Optional, Any\r\n\r\nimport torch\r\nfrom tha4.shion.core.cached_computation import CachedComputationProtocol, ComputationState\r\nfrom tha4.shion.core.load_save import torch_load\r\nfrom tha4.nn.eyebrow_decomposer.eyebrow_decomposer_00 import EyebrowDecomposer00, \\\r\n    EyebrowDecomposer00Factory, EyebrowDecomposer00Args\r\nfrom tha4.nn.eyebrow_morphing_combiner.eyebrow_morphing_combiner_00 import \\\r\n    EyebrowMorphingCombiner00Factory, EyebrowMorphingCombiner00Args, EyebrowMorphingCombiner00\r\nfrom tha4.nn.face_morpher.face_morpher_08 import FaceMorpher08Args, FaceMorpher08Factory\r\nfrom tha4.nn.nonlinearity_factory import ReLUFactory\r\nfrom tha4.nn.normalization import InstanceNorm2dFactory\r\nfrom tha4.nn.util import BlockArgs\r\nfrom tha4.poser.general_poser_02 import GeneralPoser02\r\nfrom tha4.poser.modes.pose_parameters import get_pose_parameters\r\nfrom torch import Tensor\r\n\r\n\r\nclass Network(Enum):\r\n    eyebrow_decomposer = 1\r\n    eyebrow_morphing_combiner = 2\r\n    face_morpher = 3\r\n\r\n    @property\r\n    def outputs_key(self):\r\n        return f\"{self.name}_outputs\"\r\n\r\n\r\nclass Branch(Enum):\r\n    face_morphed_half = 1\r\n    face_morphed_full = 2\r\n    all_outputs = 3\r\n\r\n\r\nNUM_EYEBROW_PARAMS = 12\r\nNUM_FACE_PARAMS = 27\r\nNUM_ROTATION_PARAMS = 6\r\n\r\n\r\nclass FiveStepPoserComputationProtocol(CachedComputationProtocol):\r\n    def __init__(self, eyebrow_morphed_image_index: int):\r\n        super().__init__()\r\n        self.eyebrow_morphed_image_index = eyebrow_morphed_image_index\r\n        self.cached_batch_0 = None\r\n        self.cached_eyebrow_decomposer_output = None\r\n\r\n    def compute_func(self):\r\n        def func(state: ComputationState) -> List[Tensor]:\r\n            if self.cached_batch_0 is None:\r\n                new_batch_0 = True\r\n            elif state.batch[0].shape[0] != self.cached_batch_0.shape[0]:\r\n                new_batch_0 = True\r\n            else:\r\n                new_batch_0 = torch.max((state.batch[0] - self.cached_batch_0).abs()).item() > 0\r\n            if not new_batch_0:\r\n                state.outputs[Network.eyebrow_decomposer.outputs_key] = self.cached_eyebrow_decomposer_output\r\n            output = self.get_output(Branch.all_outputs.name, state)\r\n            if new_batch_0:\r\n                self.cached_batch_0 = state.batch[0]\r\n                self.cached_eyebrow_decomposer_output = state.outputs[Network.eyebrow_decomposer.outputs_key]\r\n            return output\r\n\r\n        return func\r\n\r\n    def compute_output(self, key: str, state: ComputationState) -> Any:\r\n        if key == Network.eyebrow_decomposer.outputs_key:\r\n            input_image = state.batch[0][:, :, 64:192, 64 + 128:192 + 128]\r\n            return state.modules[Network.eyebrow_decomposer.name].forward(input_image)\r\n        elif key == Network.eyebrow_morphing_combiner.outputs_key:\r\n            eyebrow_decomposer_output = self.get_output(Network.eyebrow_decomposer.outputs_key, state)\r\n            background_layer = eyebrow_decomposer_output[EyebrowDecomposer00.BACKGROUND_LAYER_INDEX]\r\n            eyebrow_layer = eyebrow_decomposer_output[EyebrowDecomposer00.EYEBROW_LAYER_INDEX]\r\n            eyebrow_pose = state.batch[1][:, :NUM_EYEBROW_PARAMS]\r\n            return state.modules[Network.eyebrow_morphing_combiner.name].forward(\r\n                background_layer,\r\n                eyebrow_layer,\r\n                
eyebrow_pose)\r\n        elif key == Network.face_morpher.outputs_key:\r\n            eyebrow_morphing_combiner_output = self.get_output(\r\n                Network.eyebrow_morphing_combiner.outputs_key, state)\r\n            eyebrow_morphed_image = eyebrow_morphing_combiner_output[self.eyebrow_morphed_image_index]\r\n            input_image = state.batch[0][:, :, 32:32 + 192, (32 + 128):(32 + 192 + 128)].clone()\r\n            input_image[:, :, 32:32 + 128, 32:32 + 128] = eyebrow_morphed_image\r\n            face_pose = state.batch[1][:, NUM_EYEBROW_PARAMS:NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS]\r\n            return state.modules[Network.face_morpher.name].forward(input_image, face_pose)\r\n        elif key == Branch.all_outputs.name:\r\n            face_morpher_output = self.get_output(Network.face_morpher.outputs_key, state)\r\n            eyebrow_morphing_combiner_output = self.get_output(Network.eyebrow_morphing_combiner.outputs_key, state)\r\n            eyebrow_decomposer_output = self.get_output(Network.eyebrow_decomposer.outputs_key, state)\r\n            output = face_morpher_output \\\r\n                     + eyebrow_morphing_combiner_output \\\r\n                     + eyebrow_decomposer_output\r\n            return output\r\n        else:\r\n            raise RuntimeError(\"Unsupported key: \" + key)\r\n\r\n\r\ndef load_eyebrow_decomposer(file_name: str):\r\n    factory = EyebrowDecomposer00Factory(\r\n        EyebrowDecomposer00Args(\r\n            image_size=128,\r\n            image_channels=4,\r\n            start_channels=64,\r\n            bottleneck_image_size=16,\r\n            num_bottleneck_blocks=6,\r\n            max_channels=512,\r\n            block_args=BlockArgs(\r\n                initialization_method='he',\r\n                use_spectral_norm=False,\r\n                normalization_layer_factory=InstanceNorm2dFactory(),\r\n                nonlinearity_factory=ReLUFactory(inplace=True))))\r\n    print(\"Loading the eyebrow decomposer ... \", end=\"\")\r\n    module = factory.create()\r\n    module.load_state_dict(torch_load(file_name))\r\n    print(\"DONE!!!\")\r\n    return module\r\n\r\n\r\ndef load_eyebrow_morphing_combiner(file_name: str):\r\n    factory = EyebrowMorphingCombiner00Factory(\r\n        EyebrowMorphingCombiner00Args(\r\n            image_size=128,\r\n            image_channels=4,\r\n            start_channels=64,\r\n            num_pose_params=12,\r\n            bottleneck_image_size=16,\r\n            num_bottleneck_blocks=6,\r\n            max_channels=512,\r\n            block_args=BlockArgs(\r\n                initialization_method='he',\r\n                use_spectral_norm=False,\r\n                normalization_layer_factory=InstanceNorm2dFactory(),\r\n                nonlinearity_factory=ReLUFactory(inplace=True))))\r\n    print(\"Loading the eyebrow morphing combiner ... 
\", end=\"\")\r\n    module = factory.create()\r\n    module.load_state_dict(torch_load(file_name))\r\n    print(\"DONE!!!\")\r\n    return module\r\n\r\n\r\ndef load_face_morpher(file_name: str):\r\n    factory = FaceMorpher08Factory(\r\n        FaceMorpher08Args(\r\n            image_size=192,\r\n            image_channels=4,\r\n            num_expression_params=27,\r\n            start_channels=64,\r\n            bottleneck_image_size=24,\r\n            num_bottleneck_blocks=6,\r\n            max_channels=512,\r\n            block_args=BlockArgs(\r\n                initialization_method='he',\r\n                use_spectral_norm=False,\r\n                normalization_layer_factory=InstanceNorm2dFactory(),\r\n                nonlinearity_factory=ReLUFactory(inplace=False)),\r\n            output_iris_mouth_grid_change=True))\r\n    print(\"Loading the face morpher ... \", end=\"\")\r\n    module = factory.create()\r\n    module.load_state_dict(torch_load(file_name))\r\n    print(\"DONE!!!\")\r\n    return module\r\n\r\n\r\ndef apply_color_change(alpha, color_change, image: Tensor) -> Tensor:\r\n    return color_change * alpha + image * (1 - alpha)\r\n\r\n\r\ndef create_poser(\r\n        device: torch.device,\r\n        module_file_names: Optional[Dict[str, str]] = None,\r\n        eyebrow_morphed_image_index: int = EyebrowMorphingCombiner00.EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX,\r\n        default_output_index: int = 0) -> GeneralPoser02:\r\n    if module_file_names is None:\r\n        module_file_names = {}\r\n    if Network.eyebrow_decomposer.name not in module_file_names:\r\n        file_name = \"data/tha4/eyebrow_decomposer.pt\"\r\n        module_file_names[Network.eyebrow_decomposer.name] = file_name\r\n    if Network.eyebrow_morphing_combiner.name not in module_file_names:\r\n        file_name = \"data/tha4/eyebrow_morphing_combiner.pt\"\r\n        module_file_names[Network.eyebrow_morphing_combiner.name] = file_name\r\n    if Network.face_morpher.name not in module_file_names:\r\n        file_name = \"data/tha4/face_morpher.pt\"\r\n        module_file_names[Network.face_morpher.name] = file_name\r\n\r\n    loaders = {\r\n        Network.eyebrow_decomposer.name:\r\n            lambda: load_eyebrow_decomposer(module_file_names[Network.eyebrow_decomposer.name]),\r\n        Network.eyebrow_morphing_combiner.name:\r\n            lambda: load_eyebrow_morphing_combiner(module_file_names[Network.eyebrow_morphing_combiner.name]),\r\n        Network.face_morpher.name:\r\n            lambda: load_face_morpher(module_file_names[Network.face_morpher.name]),\r\n    }\r\n    return GeneralPoser02(\r\n        image_size=512,\r\n        module_loaders=loaders,\r\n        pose_parameters=get_pose_parameters().get_pose_parameter_groups(),\r\n        output_list_func=FiveStepPoserComputationProtocol(eyebrow_morphed_image_index).compute_func(),\r\n        subrect=None,\r\n        device=device,\r\n        output_length=5 + 5 + 8,\r\n        default_output_index=default_output_index)\r\n"
  },
  {
    "path": "src/tha4/poser/modes/mode_14.py",
    "content": "from dataclasses import dataclass\r\nfrom typing import List, Optional, Dict, Any\r\n\r\nimport torch\r\nfrom tha4.shion.core.cached_computation import CachedComputationProtocol, ComputationState\r\nfrom tha4.shion.core.load_save import torch_load\r\nfrom tha4.nn.siren.face_morpher.siren_face_morpher_00 import SirenFaceMorpher00Args, SirenFaceMorpher00\r\nfrom tha4.nn.siren.morpher.siren_morpher_03 import SirenMorpher03, SirenMorpher03Args, SirenMorpherLevelArgs\r\nfrom tha4.nn.siren.vanilla.siren import SirenArgs\r\nfrom tha4.poser.general_poser_02 import GeneralPoser02\r\nfrom tha4.poser.modes.pose_parameters import get_pose_parameters\r\nfrom torch import Tensor\r\n\r\nKEY_FACE_MORPHER = \"face_morpher\"\r\nKEY_BODY_MORPHER = \"body_morpher\"\r\n\r\n\r\n@dataclass\r\nclass Keys:\r\n    face_morpher: str = KEY_FACE_MORPHER\r\n    face_morpher_output: str = \"face_morpher_output\"\r\n\r\n    face_morpher_input_image: str = \"face_morpher_input_image\"\r\n    face_morpher_input_pose: str = \"face_morpher_input_pose\"\r\n\r\n    body_morpher_input_image: str = \"body_morpher_input_image\"\r\n\r\n    body_morpher: str = KEY_BODY_MORPHER\r\n    body_morpher_output: str = \"body_morpher_output\"\r\n\r\n    all_outputs: str = \"all_outputs\"\r\n\r\n\r\n@dataclass\r\nclass Indices:\r\n    original_image: int = 0\r\n    original_pose: int = 1\r\n\r\n\r\nclass TwoStepPoserComputationProtocol(CachedComputationProtocol):\r\n    def __init__(self, keys: Optional[Keys] = None, indices: Optional[Indices] = None):\r\n        super().__init__()\r\n\r\n        if keys is None:\r\n            keys = Keys()\r\n        if indices is None:\r\n            indices = Indices()\r\n\r\n        self.keys = keys\r\n        self.indices = indices\r\n\r\n    def compute_func(self):\r\n        def func(state: ComputationState) -> List[Tensor]:\r\n            return self.get_output(self.keys.all_outputs, state)\r\n\r\n        return func\r\n\r\n    def compute_output(self, key: str, state: ComputationState) -> Any:\r\n        if key == self.keys.face_morpher_input_image:\r\n            image = state.batch[self.indices.original_image]\r\n            center_x = 256\r\n            center_y = 128 + 16\r\n            return image[:, :, center_y - 64:center_y + 64, center_x - 64:center_x + 64]\r\n        elif key == self.keys.face_morpher_input_pose:\r\n            pose = state.batch[self.indices.original_pose]\r\n            return pose[:, 0:39]\r\n        elif key == self.keys.face_morpher_output:\r\n            module = state.modules[self.keys.face_morpher]\r\n            pose = self.get_output(self.keys.face_morpher_input_pose, state)\r\n            with torch.no_grad():\r\n                return module.forward(pose)\r\n        elif key == self.keys.body_morpher_input_image:\r\n            image = state.batch[self.indices.original_image].clone()\r\n            center_x = 256\r\n            center_y = 128 + 16\r\n            face_morphed_image = self.get_output(self.keys.face_morpher_output, state)\r\n            image[:, :, center_y - 64:center_y + 64, center_x - 64:center_x + 64] = face_morphed_image\r\n            return image\r\n        elif key == self.keys.body_morpher_output:\r\n            image = self.get_output(self.keys.body_morpher_input_image, state)\r\n            pose = state.batch[self.indices.original_pose]\r\n            body_morpher = state.modules[self.keys.body_morpher]\r\n            with torch.no_grad():\r\n                return body_morpher.forward(image, pose)\r\n        elif key 
== self.keys.all_outputs:\r\n            body_morpher_output = self.get_output(self.keys.body_morpher_output, state)\r\n            face_morpher_output = self.get_output(self.keys.face_morpher_output, state)\r\n            return body_morpher_output + [face_morpher_output]\r\n        else:\r\n            raise RuntimeError(\"Unsupported key: \" + key)\r\n\r\n\r\ndef load_face_morpher(file_name: Optional[str] = None):\r\n    module = SirenFaceMorpher00(\r\n        SirenFaceMorpher00Args(\r\n            image_size=128,\r\n            image_channels=4,\r\n            pose_size=39,\r\n            siren_args=SirenArgs(\r\n                in_channels=39 + 2,\r\n                out_channels=4,\r\n                intermediate_channels=128,\r\n                num_sine_layers=8)))\r\n    if file_name is not None:\r\n        module.load_state_dict(torch_load(file_name))\r\n    return module\r\n\r\n\r\ndef load_body_morpher(file_name: Optional[str] = None):\r\n    module = SirenMorpher03(\r\n        SirenMorpher03Args(\r\n            image_size=512,\r\n            image_channels=4,\r\n            pose_size=45,\r\n            level_args=[\r\n                SirenMorpherLevelArgs(\r\n                    image_size=128,\r\n                    intermediate_channels=360,\r\n                    num_sine_layers=3),\r\n                SirenMorpherLevelArgs(\r\n                    image_size=256,\r\n                    intermediate_channels=180,\r\n                    num_sine_layers=3),\r\n                SirenMorpherLevelArgs(\r\n                    image_size=512,\r\n                    intermediate_channels=90,\r\n                    num_sine_layers=3),\r\n            ]))\r\n    if file_name is not None:\r\n        module.load_state_dict(torch_load(file_name))\r\n    return module\r\n\r\n\r\ndef create_poser(\r\n        device: torch.device,\r\n        module_file_names: Optional[Dict[str, str]] = None,\r\n        default_output_index: int = 0) -> GeneralPoser02:\r\n    if module_file_names is None:\r\n        module_file_names = {}\r\n    if KEY_FACE_MORPHER not in module_file_names:\r\n        file_name = \"data/character_models/lambda_00/face_morpher.pt\"\r\n        module_file_names[KEY_FACE_MORPHER] = file_name\r\n    if KEY_BODY_MORPHER not in module_file_names:\r\n        file_name = \"data/character_models/lambda_00/body_morpher.pt\"\r\n        module_file_names[KEY_BODY_MORPHER] = file_name\r\n\r\n    loaders = {\r\n        KEY_FACE_MORPHER:\r\n            lambda: load_face_morpher(module_file_names[KEY_FACE_MORPHER]),\r\n        KEY_BODY_MORPHER:\r\n            lambda: load_body_morpher(module_file_names[KEY_BODY_MORPHER]),\r\n    }\r\n\r\n    return GeneralPoser02(\r\n        image_size=512,\r\n        module_loaders=loaders,\r\n        pose_parameters=get_pose_parameters().get_pose_parameter_groups(),\r\n        output_list_func=TwoStepPoserComputationProtocol().compute_func(),\r\n        subrect=None,\r\n        device=device,\r\n        output_length=5 + 1,\r\n        default_output_index=default_output_index)\r\n"
  },
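A rough usage sketch of the student-model poser defined above (a sketch only; it assumes the default checkpoints under data/character_models/lambda_00 are present on disk, and that output index 0 is the merged posed image, as the default_output_index suggests):

```python
import torch

from tha4.poser.modes.mode_14 import create_poser
from tha4.poser.modes.pose_parameters import get_pose_parameters

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
poser = create_poser(device)

# A dummy 512x512 RGBA input; a real caller would load the character image
# that the student model was distilled for.
image = torch.zeros(4, 512, 512, device=device)

# Zero pose, then open the mouth via the "mouth_aaa" parameter.
pose = torch.zeros(poser.get_num_parameters(), device=device)
pose[get_pose_parameters().get_parameter_index("mouth_aaa")] = 1.0

posed_image = poser.pose(image, pose)  # the posed image, shape (1, 4, 512, 512)
```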
  {
    "path": "src/tha4/poser/modes/pose_parameters.py",
    "content": "from tha4.poser.poser import PoseParameters, PoseParameterCategory\r\n\r\n\r\ndef get_pose_parameters():\r\n    return PoseParameters.Builder() \\\r\n        .add_parameter_group(\"eyebrow_troubled\", PoseParameterCategory.EYEBROW, arity=2) \\\r\n        .add_parameter_group(\"eyebrow_angry\", PoseParameterCategory.EYEBROW, arity=2) \\\r\n        .add_parameter_group(\"eyebrow_lowered\", PoseParameterCategory.EYEBROW, arity=2) \\\r\n        .add_parameter_group(\"eyebrow_raised\", PoseParameterCategory.EYEBROW, arity=2) \\\r\n        .add_parameter_group(\"eyebrow_happy\", PoseParameterCategory.EYEBROW, arity=2) \\\r\n        .add_parameter_group(\"eyebrow_serious\", PoseParameterCategory.EYEBROW, arity=2) \\\r\n        .add_parameter_group(\"eye_wink\", PoseParameterCategory.EYE, arity=2) \\\r\n        .add_parameter_group(\"eye_happy_wink\", PoseParameterCategory.EYE, arity=2) \\\r\n        .add_parameter_group(\"eye_surprised\", PoseParameterCategory.EYE, arity=2) \\\r\n        .add_parameter_group(\"eye_relaxed\", PoseParameterCategory.EYE, arity=2) \\\r\n        .add_parameter_group(\"eye_unimpressed\", PoseParameterCategory.EYE, arity=2) \\\r\n        .add_parameter_group(\"eye_raised_lower_eyelid\", PoseParameterCategory.EYE, arity=2) \\\r\n        .add_parameter_group(\"iris_small\", PoseParameterCategory.IRIS_MORPH, arity=2) \\\r\n        .add_parameter_group(\"mouth_aaa\", PoseParameterCategory.MOUTH, arity=1, default_value=1.0) \\\r\n        .add_parameter_group(\"mouth_iii\", PoseParameterCategory.MOUTH, arity=1) \\\r\n        .add_parameter_group(\"mouth_uuu\", PoseParameterCategory.MOUTH, arity=1) \\\r\n        .add_parameter_group(\"mouth_eee\", PoseParameterCategory.MOUTH, arity=1) \\\r\n        .add_parameter_group(\"mouth_ooo\", PoseParameterCategory.MOUTH, arity=1) \\\r\n        .add_parameter_group(\"mouth_delta\", PoseParameterCategory.MOUTH, arity=1) \\\r\n        .add_parameter_group(\"mouth_lowered_corner\", PoseParameterCategory.MOUTH, arity=2) \\\r\n        .add_parameter_group(\"mouth_raised_corner\", PoseParameterCategory.MOUTH, arity=2) \\\r\n        .add_parameter_group(\"mouth_smirk\", PoseParameterCategory.MOUTH, arity=1) \\\r\n        .add_parameter_group(\"iris_rotation_x\", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \\\r\n        .add_parameter_group(\"iris_rotation_y\", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \\\r\n        .add_parameter_group(\"head_x\", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \\\r\n        .add_parameter_group(\"head_y\", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \\\r\n        .add_parameter_group(\"neck_z\", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \\\r\n        .add_parameter_group(\"body_y\", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \\\r\n        .add_parameter_group(\"body_z\", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \\\r\n        .add_parameter_group(\"breathing\", PoseParameterCategory.BREATHING, arity=1, range=(0.0, 1.0)) \\\r\n        .build()"
  },
  {
    "path": "src/tha4/poser/poser.py",
    "content": "from abc import ABC, abstractmethod\r\nfrom enum import Enum\r\nfrom typing import Tuple, List, Optional\r\n\r\nimport torch\r\nfrom torch import Tensor\r\n\r\n\r\nclass PoseParameterCategory(Enum):\r\n    EYEBROW = 1\r\n    EYE = 2\r\n    IRIS_MORPH = 3\r\n    IRIS_ROTATION = 4\r\n    MOUTH = 5\r\n    FACE_ROTATION = 6\r\n    BODY_ROTATION = 7\r\n    BREATHING = 8\r\n\r\n\r\nclass PoseParameterGroup:\r\n    def __init__(self,\r\n                 group_name: str,\r\n                 parameter_index: int,\r\n                 category: PoseParameterCategory,\r\n                 arity: int = 1,\r\n                 discrete: bool = False,\r\n                 default_value: float = 0.0,\r\n                 range: Optional[Tuple[float, float]] = None):\r\n        assert arity == 1 or arity == 2\r\n        if range is None:\r\n            range = (0.0, 1.0)\r\n        if arity == 1:\r\n            parameter_names = [group_name]\r\n        else:\r\n            parameter_names = [group_name + \"_left\", group_name + \"_right\"]\r\n        assert len(parameter_names) == arity\r\n\r\n        self.parameter_names = parameter_names\r\n        self.range = range\r\n        self.default_value = default_value\r\n        self.discrete = discrete\r\n        self.arity = arity\r\n        self.category = category\r\n        self.parameter_index = parameter_index\r\n        self.group_name = group_name\r\n\r\n    def get_arity(self) -> int:\r\n        return self.arity\r\n\r\n    def get_group_name(self) -> str:\r\n        return self.group_name\r\n\r\n    def get_parameter_names(self) -> List[str]:\r\n        return self.parameter_names\r\n\r\n    def is_discrete(self) -> bool:\r\n        return self.discrete\r\n\r\n    def get_range(self) -> Tuple[float, float]:\r\n        return self.range\r\n\r\n    def get_default_value(self):\r\n        return self.default_value\r\n\r\n    def get_parameter_index(self):\r\n        return self.parameter_index\r\n\r\n    def get_category(self) -> PoseParameterCategory:\r\n        return self.category\r\n\r\n\r\nclass PoseParameters:\r\n    def __init__(self, pose_parameter_groups: List[PoseParameterGroup]):\r\n        self.pose_parameter_groups = pose_parameter_groups\r\n\r\n    def get_parameter_index(self, name: str) -> int:\r\n        index = 0\r\n        for parameter_group in self.pose_parameter_groups:\r\n            for param_name in parameter_group.parameter_names:\r\n                if name == param_name:\r\n                    return index\r\n                index += 1\r\n        raise RuntimeError(\"Cannot find parameter with name %s\" % name)\r\n\r\n    def get_parameter_name(self, index: int) -> str:\r\n        assert index >= 0 and index < self.get_parameter_count()\r\n\r\n        for group in self.pose_parameter_groups:\r\n            if index < group.get_arity():\r\n                return group.get_parameter_names()[index]\r\n            index -= group.arity\r\n\r\n        raise RuntimeError(\"Something is wrong here!!!\")\r\n\r\n    def get_pose_parameter_groups(self):\r\n        return self.pose_parameter_groups\r\n\r\n    def get_parameter_count(self):\r\n        count = 0\r\n        for group in self.pose_parameter_groups:\r\n            count += group.arity\r\n        return count\r\n\r\n    class Builder:\r\n        def __init__(self):\r\n            self.index = 0\r\n            self.pose_parameter_groups = []\r\n\r\n        def add_parameter_group(self,\r\n                                group_name: str,\r\n                           
     category: PoseParameterCategory,\r\n                                arity: int = 1,\r\n                                discrete: bool = False,\r\n                                default_value: float = 0.0,\r\n                                range: Optional[Tuple[float, float]] = None):\r\n            self.pose_parameter_groups.append(\r\n                PoseParameterGroup(\r\n                    group_name,\r\n                    self.index,\r\n                    category,\r\n                    arity,\r\n                    discrete,\r\n                    default_value,\r\n                    range))\r\n            self.index += arity\r\n            return self\r\n\r\n        def build(self) -> 'PoseParameters':\r\n            return PoseParameters(self.pose_parameter_groups)\r\n\r\n\r\nclass Poser(ABC):\r\n    @abstractmethod\r\n    def get_image_size(self) -> int:\r\n        pass\r\n\r\n    @abstractmethod\r\n    def get_output_length(self) -> int:\r\n        pass\r\n\r\n    @abstractmethod\r\n    def get_pose_parameter_groups(self) -> List[PoseParameterGroup]:\r\n        pass\r\n\r\n    @abstractmethod\r\n    def get_num_parameters(self) -> int:\r\n        pass\r\n\r\n    @abstractmethod\r\n    def pose(self, image: Tensor, pose: Tensor, output_index: int = 0) -> Tensor:\r\n        pass\r\n\r\n    @abstractmethod\r\n    def get_posing_outputs(self, image: Tensor, pose: Tensor) -> List[Tensor]:\r\n        pass\r\n\r\n    def get_dtype(self) -> torch.dtype:\r\n        return torch.float\r\n\r\n    @abstractmethod\r\n    def to(self, device: torch.device):\r\n        pass"
  },
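A small sketch (illustrative only) of how the classes above resolve names and indices: arity-2 groups expand into _left/_right parameters, and indices count through the groups in declaration order.

```python
from tha4.poser.modes.pose_parameters import get_pose_parameters

params = get_pose_parameters()
print(params.get_parameter_count())                 # 45 parameters in total
print(params.get_parameter_name(0))                 # "eyebrow_troubled_left"
print(params.get_parameter_index("eye_wink_left"))  # 12: six arity-2 eyebrow groups come first
```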
  {
    "path": "src/tha4/pytasuku/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/pytasuku/indexed/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/pytasuku/indexed/all_tasks.py",
    "content": "from typing import Iterable\n\nfrom tha4.pytasuku.workspace import Workspace\nfrom tha4.pytasuku.indexed.indexed_tasks import IndexedTasks\nfrom tha4.pytasuku.indexed.no_index_command_tasks import NoIndexCommandTasks\n\n\nclass AllTasks(NoIndexCommandTasks):\n    def __init__(\n            self,\n            workspace: Workspace, prefix: str,\n            tasks: Iterable[IndexedTasks],\n            command_name: str = \"all\",\n            define_tasks_immediately: bool = True):\n        super().__init__(workspace, prefix, command_name, define_tasks_immediately)\n        self.tasks = [t for t in tasks]\n        if define_tasks_immediately:\n            self.define_tasks()\n\n    def execute_run_command(self):\n        for task in self.tasks:\n            self.workspace.run(task.run_command)\n\n    def execute_clean_command(self):\n        for task in self.tasks:\n            self.workspace.run(task.clean_command)"
  },
  {
    "path": "src/tha4/pytasuku/indexed/bundled_indexed_file_tasks.py",
    "content": "import abc\nfrom typing import Iterable, List\n\nfrom tha4.pytasuku.workspace import Workspace\nfrom tha4.pytasuku.indexed.indexed_tasks import IndexedTasks\nfrom tha4.pytasuku.workspace import do_nothing\n\n\nclass BundledIndexedTasks:\n    __metaclass__ = abc.ABCMeta\n\n    @property\n    @abc.abstractmethod\n    def indexed_tasks_command_names(self) -> Iterable[str]:\n        pass\n\n    @abc.abstractmethod\n    def get_indexed_tasks(self, command_name) -> IndexedTasks:\n        pass\n\n\ndef define_all_tasks_from_list(workspace: Workspace, prefix: str, tasks: List[BundledIndexedTasks]):\n    for command_name in tasks[0].indexed_tasks_command_names:\n        workspace.create_command_task(\n            prefix + \"/\" + command_name,\n            [x.get_indexed_tasks(command_name).run_command for x in tasks],\n            do_nothing)\n        workspace.create_command_task(\n            prefix + \"/\" + command_name + \"_clean\",\n            [x.get_indexed_tasks(command_name).clean_command for x in tasks],\n            do_nothing)\n"
  },
  {
    "path": "src/tha4/pytasuku/indexed/indexed_file_tasks.py",
    "content": "import abc\nfrom typing import List\n\nfrom tha4.pytasuku.workspace import Workspace\nfrom tha4.pytasuku.indexed.indexed_tasks import IndexedTasks\n\n\nclass IndexedFileTasks(IndexedTasks, abc.ABC):\n    def __init__(self, workspace: Workspace, prefix: str):\n        super().__init__(workspace, prefix)\n\n    @property\n    @abc.abstractmethod\n    def file_list(self) -> List[str]:\n        pass\n\n    @abc.abstractmethod\n    def get_file_name(self, *indices: int) -> str:\n        pass\n"
  },
  {
    "path": "src/tha4/pytasuku/indexed/indexed_tasks.py",
    "content": "import abc\nfrom typing import List\n\nfrom tha4.pytasuku.workspace import Workspace\n\n\nclass IndexedTasks(abc.ABC):\n    def __init__(self, workspace: Workspace, prefix: str):\n        self.prefix = prefix\n        self.workspace = workspace\n\n    @property\n    @abc.abstractmethod\n    def run_command(self) -> str:\n        pass\n\n    @property\n    @abc.abstractmethod\n    def clean_command(self) -> str:\n        pass\n\n    @property\n    @abc.abstractmethod\n    def shape(self) -> List[int]:\n        pass\n\n    @property\n    @abc.abstractmethod\n    def arity(self) -> int:\n        pass\n\n    @abc.abstractmethod\n    def define_tasks(self):\n        pass\n"
  },
  {
    "path": "src/tha4/pytasuku/indexed/no_index_command_tasks.py",
    "content": "import abc\nfrom typing import List\n\nfrom tha4.pytasuku.workspace import Workspace\nfrom tha4.pytasuku.indexed.indexed_tasks import IndexedTasks\n\n\nclass NoIndexCommandTasks(IndexedTasks, abc.ABC):\n    def __init__(self, workspace: Workspace, prefix: str, command_name: str, define_tasks_immediately: bool = True):\n        super().__init__(workspace, prefix)\n        self.command_name = command_name\n        if define_tasks_immediately:\n            self.define_tasks()\n\n    @property\n    def run_command(self):\n        return self.prefix + \"/\" + self.command_name\n\n    @property\n    def clean_command(self):\n        return self.prefix + \"/\" + self.command_name + \"_clean\"\n\n    @property\n    def arity(self) -> int:\n        return 0\n\n    @property\n    def shape(self) -> List[int]:\n        return []\n\n    @abc.abstractmethod\n    def execute_run_command(self):\n        pass\n\n    @abc.abstractmethod\n    def execute_clean_command(self):\n        pass\n\n    def define_tasks(self):\n        self.workspace.create_command_task(self.run_command, [], self.execute_run_command)\n        self.workspace.create_command_task(self.clean_command, [], self.execute_clean_command)\n"
  },
  {
    "path": "src/tha4/pytasuku/indexed/no_index_file_tasks.py",
    "content": "import abc\nfrom typing import List\n\nfrom tha4.pytasuku.workspace import Workspace\nfrom tha4.pytasuku.indexed.indexed_file_tasks import IndexedFileTasks\nfrom tha4.pytasuku.indexed.util import delete_file\n\n\nclass NoIndexFileTasks(IndexedFileTasks, abc.ABC):\n    def __init__(self, workspace: Workspace, prefix: str, command_name: str, define_tasks_immediately: bool = True):\n        super().__init__(workspace, prefix)\n        self.command_name = command_name\n        if define_tasks_immediately:\n            self.define_tasks()\n\n    @property\n    @abc.abstractmethod\n    def file_name(self):\n        pass\n\n    @abc.abstractmethod\n    def create_file_task(self):\n        pass\n\n    def get_file_name(self, *indices: int) -> str:\n        if len(indices) > 0:\n            raise IndexError(\"NoIndexFileTasks has arity 0, but get_file_name is called with an index.\")\n        return self.file_name\n\n    @property\n    def run_command(self):\n        return self.prefix + \"/\" + self.command_name\n\n    @property\n    def clean_command(self):\n        return self.prefix + \"/\" + self.command_name + \"_clean\"\n\n    @property\n    def arity(self) -> int:\n        return 0\n\n    @property\n    def shape(self) -> List[int]:\n        return []\n\n    @property\n    def file_list(self) -> List[str]:\n        return [self.file_name]\n\n    def clean(self):\n        delete_file(self.file_name)\n\n    def define_tasks(self):\n        self.create_file_task()\n        self.workspace.create_command_task(self.run_command, [self.file_name])\n        self.workspace.create_command_task(self.clean_command, [], lambda: self.clean())"
  },
  {
    "path": "src/tha4/pytasuku/indexed/one_index_file_tasks.py",
    "content": "import abc\n\nfrom typing import List\n\nfrom tha4.pytasuku.workspace import Workspace\nfrom tha4.pytasuku.indexed.indexed_file_tasks import IndexedFileTasks\nfrom tha4.pytasuku.indexed.util import delete_file\n\n\nclass OneIndexFileTasks(IndexedFileTasks, abc.ABC):\n    def __init__(self, workspace: Workspace, prefix: str, command_name: str, count: int,\n                 define_tasks_immediately: bool = True):\n        super().__init__(workspace, prefix)\n        self.command_name = command_name\n        self.count = count\n        self.file_list_ = []\n        if define_tasks_immediately:\n            self.define_tasks()\n\n    @property\n    def run_command(self) -> str:\n        return self.prefix + \"/\" + self.command_name\n\n    @property\n    def clean_command(self) -> str:\n        return self.prefix + \"/\" + self.command_name + \"_clean\"\n\n    @property\n    def shape(self) -> List[int]:\n        return [self.count]\n\n    @property\n    def arity(self) -> int:\n        return 1\n\n    @abc.abstractmethod\n    def file_name(self, index):\n        pass\n\n    @abc.abstractmethod\n    def create_file_tasks(self, index):\n        pass\n\n    def get_file_name(self, *indices: int) -> str:\n        if len(indices) != 1:\n            raise IndexError(\"OneIndexFileTasks has arity 1, but \"\n                             \"get_file_name does not get the appropriate number of arguments.\")\n        return self.file_name(indices[0])\n\n    @property\n    def file_list(self):\n        if len(self.file_list_) == 0:\n            for i in range(self.count):\n                self.file_list_.append(self.file_name(i))\n        return self.file_list_\n\n    def clean(self):\n        for file in self.file_list:\n            delete_file(file)\n\n    def define_tasks(self):\n        for index in range(self.count):\n            self.create_file_tasks(index)\n        dependencies = self.file_list\n        self.workspace.create_command_task(self.run_command, dependencies)\n        self.workspace.create_command_task(self.clean_command, [], lambda: self.clean())\n"
  },
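  {
    "path": "src/tha4/pytasuku/indexed/one_index_file_tasks_example.py",
    "content": "\"\"\"Hypothetical example module added for illustration; it is not part of the\noriginal distribution. A minimal sketch of subclassing OneIndexFileTasks,\nassuming only the classes defined in this package.\n\nA concrete subclass supplies file_name(index) and create_file_tasks(index);\nthe base class then defines one command task that builds every file and one\nthat cleans them up. All paths below are invented.\n\"\"\"\nimport os\n\nfrom tha4.pytasuku.workspace import Workspace\nfrom tha4.pytasuku.indexed.one_index_file_tasks import OneIndexFileTasks\n\n\nclass NumberedTextFileTasks(OneIndexFileTasks):\n    def __init__(self, workspace: Workspace, prefix: str, count: int):\n        super().__init__(workspace, prefix, \"numbered\", count)\n\n    def file_name(self, index):\n        return \"%s/%03d.txt\" % (self.prefix, index)\n\n    def create_file_tasks(self, index):\n        def create():\n            os.makedirs(os.path.dirname(self.file_name(index)), exist_ok=True)\n            with open(self.file_name(index), \"wt\") as fout:\n                fout.write(str(index))\n\n        self.workspace.create_file_task(self.file_name(index), [], create)\n\n\nif __name__ == \"__main__\":\n    workspace = Workspace()\n    tasks = NumberedTextFileTasks(workspace, \"data/numbered_example\", 10)\n    with workspace.session():\n        # Runs \"data/numbered_example/numbered\", which depends on all ten files.\n        workspace.run(tasks.run_command)\n"
  },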
  {
    "path": "src/tha4/pytasuku/indexed/simple_no_index_file_tasks.py",
    "content": "from typing import Callable, List, Optional\r\n\r\nfrom tha4.pytasuku.workspace import Workspace\r\nfrom tha4.pytasuku.indexed.no_index_file_tasks import NoIndexFileTasks\r\n\r\n\r\nclass SimpleNoIndexFileTasks(NoIndexFileTasks):\r\n    def __init__(self,\r\n                 workspace: Workspace,\r\n                 prefix: str,\r\n                 command_name: str,\r\n                 file_name: str,\r\n                 run_func: Callable[[], None],\r\n                 dependencies: Optional[List[str]] = None):\r\n        super().__init__(workspace, prefix, command_name, define_tasks_immediately=False)\r\n        if dependencies is None:\r\n            dependencies = []\r\n        self.run_func = run_func\r\n        self._file_name = file_name\r\n        self.dependencies = dependencies\r\n        self.define_tasks()\r\n\r\n    @property\r\n    def file_name(self):\r\n        return self._file_name\r\n\r\n    def create_file_task(self):\r\n        self.workspace.create_file_task(self.file_name, self.dependencies, self.run_func)\r\n"
  },
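  {
    "path": "src/tha4/pytasuku/indexed/simple_no_index_file_tasks_example.py",
    "content": "\"\"\"Hypothetical example module added for illustration; it is not part of the\noriginal distribution. A minimal sketch of SimpleNoIndexFileTasks, which wraps\na single file-producing function without requiring a subclass.\n\nConstructing the object registers three tasks: the file task\n\"data/greeting_example/greeting.txt\", the command task\n\"data/greeting_example/greeting\", and \"data/greeting_example/greeting_clean\".\nAll paths are invented.\n\"\"\"\nimport os\n\nfrom tha4.pytasuku.workspace import Workspace\nfrom tha4.pytasuku.indexed.simple_no_index_file_tasks import SimpleNoIndexFileTasks\n\n\ndef write_greeting():\n    os.makedirs(\"data/greeting_example\", exist_ok=True)\n    with open(\"data/greeting_example/greeting.txt\", \"wt\") as fout:\n        fout.write(\"hello\")\n\n\nif __name__ == \"__main__\":\n    workspace = Workspace()\n    tasks = SimpleNoIndexFileTasks(\n        workspace,\n        prefix=\"data/greeting_example\",\n        command_name=\"greeting\",\n        file_name=\"data/greeting_example/greeting.txt\",\n        run_func=write_greeting)\n    with workspace.session():\n        workspace.run(tasks.run_command)\n"
  },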
  {
    "path": "src/tha4/pytasuku/indexed/two_indices_file_tasks.py",
    "content": "import abc\nfrom typing import List\n\nfrom tha4.pytasuku.workspace import Workspace\nfrom tha4.pytasuku.indexed.indexed_file_tasks import IndexedFileTasks\nfrom tha4.pytasuku.indexed.util import delete_file\n\n\nclass TwoIndicesFileTasks(IndexedFileTasks, abc.ABC):\n    def __init__(self, workspace: Workspace, prefix: str, command_name: str,\n                 count0: int, count1: int, define_tasks_immediately: bool = True):\n        super().__init__(workspace, prefix)\n        self.count1 = count1\n        self.count0 = count0\n        self.command_name = command_name\n        self.file_list_ = []\n        if define_tasks_immediately:\n            self.define_tasks()\n\n    @property\n    def run_command(self) -> str:\n        return self.prefix + \"/\" + self.command_name\n\n    @property\n    def clean_command(self) -> str:\n        return self.prefix + \"/\" + self.command_name + \"_clean\"\n\n    @property\n    def shape(self) -> List[int]:\n        return [self.count0, self.count1]\n\n    @property\n    def arity(self) -> int:\n        return 2\n\n    @abc.abstractmethod\n    def file_name(self, index0: int, index1: int) -> str:\n        pass\n\n    @property\n    def file_list(self) -> List[str]:\n        if len(self.file_list_) == 0:\n            for i in range(self.count0):\n                for j in range(self.count1):\n                    self.file_list_.append(self.file_name(i, j))\n        return self.file_list_\n\n    @abc.abstractmethod\n    def create_file_tasks(self, index0: int, index1: int):\n        pass\n\n    def get_file_name(self, *indices: int) -> str:\n        if len(indices) != 2:\n            raise IndexError(\"TwoIndicesFileTasks.get_file_name require two indices, \" +\n                             \"but not exactly 2 indices were provide\")\n        return self.file_name(indices[0], indices[1])\n\n    def clean(self):\n        for file in self.file_list:\n            delete_file(file)\n\n    def define_tasks(self):\n        for index0 in range(self.count0):\n            for index1 in range(self.count1):\n                self.create_file_tasks(index0, index1)\n        self.workspace.create_command_task(self.run_command, self.file_list)\n        self.workspace.create_command_task(self.clean_command, [], lambda: self.clean())\n"
  },
  {
    "path": "src/tha4/pytasuku/indexed/util.py",
    "content": "import os\nfrom typing import Iterable, Dict, Callable, List\n\nfrom tha4.pytasuku.workspace import Workspace\nfrom tha4.pytasuku.indexed.all_tasks import AllTasks\nfrom tha4.pytasuku.indexed.indexed_tasks import IndexedTasks\n\n\ndef delete_file(file_name):\n    if os.path.exists(file_name):\n        os.remove(file_name)\n        print(\"[delete] \" + file_name)\n    else:\n        print(\"[not exist] \" + file_name)\n\n\ndef all_tasks_from_named_tasks_map(\n        workspace: Workspace,\n        prefix: str,\n        tasks: Iterable[Dict[str, IndexedTasks]],\n        define_all_tasks: bool = True) \\\n        -> Dict[str, IndexedTasks]:\n    subtasks = [x for x in tasks]\n    name_to_subtask_list = {}\n    for a_subtasks in subtasks:\n        for name in a_subtasks:\n            if not define_all_tasks and name == \"all\":\n                continue\n            if name not in name_to_subtask_list:\n                name_to_subtask_list[name] = []\n            name_to_subtask_list[name].append(a_subtasks[name])\n    output = {}\n    for name in name_to_subtask_list:\n        output[name] = AllTasks(workspace, prefix, name_to_subtask_list[name], name)\n    return output\n\n\ndef create_tasks_hierarchy_helper(\n        workspace: Workspace,\n        prefix: str,\n        tasks_func: Callable[[Workspace, str, List[str]], Dict[str, IndexedTasks]],\n        branches: List[List[str]],\n        path: List[str]):\n    if len(branches) == 0:\n        return tasks_func(workspace, prefix, path)\n    else:\n        tasks = {}\n        for branch in branches[0]:\n            output_tasks = create_tasks_hierarchy_helper(\n                workspace,\n                f\"{prefix}/{branch}\",\n                tasks_func,\n                branches[1:],\n                path + [branch])\n            if output_tasks is not None:\n                tasks[branch] = output_tasks\n        return all_tasks_from_named_tasks_map(workspace, prefix, tasks.values())\n\n\ndef create_task_hierarchy(\n        workspace: Workspace,\n        prefix: str,\n        tasks_func: Callable[[Workspace, str, List[str]], Dict[str, IndexedTasks]],\n        branches: List[List[str]]) -> Dict[str, IndexedTasks]:\n    return create_tasks_hierarchy_helper(workspace, prefix, tasks_func, branches, [])\n\n\ndef write_done_file(file_name: str):\n    os.makedirs(os.path.dirname(file_name), exist_ok=True)\n    with open(file_name, \"wt\") as fout:\n        fout.write(\"DONE!!!\")"
  },
  {
    "path": "src/tha4/pytasuku/task.py",
    "content": "import os\nimport logging\nfrom typing import List\n\n\nclass Task:\n    def __init__(self, workspace: 'Workspace', name: str, dependencies: List[str]):\n        self._workspace = workspace\n        self._name = name\n        self._dependencies = dependencies\n        self._workspace.add_task(self)\n\n    def run(self):\n        pass\n\n    @property\n    def can_run(self) -> bool:\n        return True\n\n    @property\n    def needs_to_be_run(self) -> bool:\n        return False\n\n    @property\n    def name(self) -> str:\n        return self._name\n\n    @property\n    def dependencies(self) -> List[str]:\n        return self._dependencies\n\n    @property\n    def workspace(self) -> 'Workspace':\n        return self._workspace\n\n    @property\n    def timestamp(self) -> float:\n        return float(\"inf\")\n\n\nclass CommandTask(Task):\n    def __init__(self, workspace, name, dependencies):\n        super().__init__(workspace, name, dependencies)\n\n    @property\n    def needs_to_be_run(self):\n        return True\n\n\nclass PlaceholderTask(Task):\n    def __init__(self, workspace, name):\n        super().__init__(workspace, name, [])\n\n    @property\n    def can_run(self):\n        return False\n\n    def run(self):\n        raise Exception(\"A  placeholder task cannot be run! (%s)\" % self.name)\n\n    @property\n    def needs_to_be_run(self):\n        return not os.path.isfile(self.name)\n\n    @property\n    def timestamp(self) -> float:\n        if not os.path.isfile(self.name):\n            return float(\"inf\")\n        else:\n            return os.path.getmtime(self.name)\n\n\nclass FileTask(Task):\n    def __init__(self, workspace, name, dependencies):\n        super().__init__(workspace, name, dependencies)\n\n    @property\n    def timestamp(self):\n        return os.path.getmtime(self.name)\n\n    @property\n    def needs_to_be_run(self):\n        if not os.path.isfile(self.name):\n            logging.info(\"Task %s will be run because the corresponding file does not exist.\" % self.name)\n            return True\n        for dep in self.dependencies:\n            if self.workspace.needs_to_run(dep):\n                logging.info(\"Task %s will be run because dependency %s also needs to be run.\" % (self.name, dep))\n                return True\n            else:\n                self_timestamp = self.timestamp\n                dep_task = self.workspace.get_task(dep)\n                if dep_task.timestamp > self_timestamp:\n                    if isinstance(dep_task, FileTask) or isinstance(dep_task, PlaceholderTask):\n                        logging.info(\"Task %s needs to be run because task %s has later timestamp.\" %\n                                     (self.name, dep))\n                    elif isinstance(dep_task, CommandTask):\n                        logging.info(\"Task %s needs to be run because task %s is a command.\" % (self.name, dep))\n                    return True\n        return False\n"
  },
  {
    "path": "src/tha4/pytasuku/task_selector_ui.py",
    "content": "from tkinter import Tk, BOTH, Button, RIGHT, Scrollbar\nfrom tkinter.ttk import Frame, Treeview\n\nfrom tha4.pytasuku.workspace import Workspace, PlaceholderTask\n\n\nclass TaskSelectorUi(Frame):\n    def __init__(self, root, workspace: Workspace):\n        super().__init__()\n        self.root = root\n        self.workspace = workspace\n        self.master.title(\"Tasks\")\n        self.master.geometry(\"256x512\")\n\n        treeview_frame = Frame(self)\n        treeview_frame.pack(fill=BOTH, expand=True)\n\n        self.treeview = Treeview(treeview_frame)\n        self.treeview[\"columns\"] = (\"task_name\")\n        self.treeview.column(\"#0\", width=256, minwidth=256)\n        self.treeview.heading(\"#0\", text=\"Tree\")\n        self.treeview.heading(\"task_name\", text=\"Task Name\")\n\n        treeview_vertical_scroll = Scrollbar(treeview_frame,\n                                             orient='vertical',\n                                             command=self.treeview.yview)\n        self.treeview.configure(yscrollcommand=treeview_vertical_scroll.set)\n        treeview_vertical_scroll.pack(side=RIGHT, fill='y')\n        self.treeview.pack(fill=BOTH, expand=True)\n\n        treeview_horizontal_scroll = Scrollbar(treeview_frame,\n                                               orient='horizontal',\n                                               command=self.treeview.xview)\n        self.treeview.configure(xscrollcommand=treeview_horizontal_scroll.set)\n        treeview_horizontal_scroll.pack(fill='x')\n\n        self.add_tree_nodes()\n\n        self.execute_button = Button(self, text=\"Execute!\", command=self.run_selected_task)\n        self.execute_button.pack(side=RIGHT, padx=5, pady=5)\n\n        self.pack(fill=BOTH, expand=True)\n\n        self.selected_task_name = None\n\n    def add_tree_nodes(self):\n        nodes = {}\n\n        for task in self.workspace._tasks.values():\n            if isinstance(task, PlaceholderTask):\n                continue\n            comps = task.name.split('/')\n            for i in range(1, len(comps)):\n                assert len(comps) > 0\n            prefix = \"\"\n            index = 0\n            for comp in comps:\n                index = index + 1\n                parent = prefix\n                if prefix == \"\" and comp == \"\":\n                    prefix = \"/\"\n                elif prefix == \"\":\n                    prefix = prefix + comp\n                elif prefix == \"/\":\n                    prefix = prefix + comp\n                else:\n                    prefix = prefix + \"/\" + comp\n                if prefix in nodes:\n                    continue\n                if index == len(comps):\n                    data = prefix\n                else:\n                    data = \"\"\n                if prefix == \"/\":\n                    comp = \"/\"\n                nodes[prefix] = {\n                    \"name\": str(prefix),\n                    \"display_name\": comp,\n                    \"parent\": parent,\n                    \"data\": data\n                }\n\n        sorted_node_names = sorted(nodes.keys())\n        node_index = {}\n        for name in sorted_node_names:\n            node = nodes[name]\n            if node[\"parent\"] == \"\":\n                id = self.treeview.insert(\"\", \"end\", text=node[\"display_name\"], values=node[\"data\"], )\n            else:\n                parent = node_index[node[\"parent\"]]\n                id = self.treeview.insert(parent, \"end\", 
text=node[\"display_name\"], values=node[\"data\"], )\n            node_index[node[\"name\"]] = id\n\n    def run_selected_task(self):\n        selection = self.treeview.selection()\n        item = self.treeview.item(selection)\n        if item['values'] == \"\":\n            return\n        task_name = item[\"values\"][0]\n        self.selected_task_name = task_name\n        self.root.destroy()\n\n\ndef run_task_selector_ui(workspace: Workspace):\n    root = Tk()\n    task_selector_ui = TaskSelectorUi(root, workspace=workspace)\n    root.mainloop()\n\n    task_name = task_selector_ui.selected_task_name\n    if task_name is not None:\n        print(\"Running\", task_name, \"...\")\n        with workspace.session():\n            workspace.run(task_name)\n"
  },
  {
    "path": "src/tha4/pytasuku/util.py",
    "content": "import os.path\nfrom typing import List\nimport logging\n\nfrom tha4.pytasuku.workspace import Workspace\n\n\ndef create_delete_all_task(workspace: Workspace, name: str, files: List[str]):\n    def delete_all():\n        for file in files:\n            if os.path.exists(file):\n                logging.info(\"Removing %s ...\" % file)\n                os.remove(file)\n\n    workspace.create_command_task(name, [], delete_all)\n"
  },
  {
    "path": "src/tha4/pytasuku/workspace.py",
    "content": "from contextlib import contextmanager\nfrom enum import Enum\nfrom typing import List\n\nfrom tha4.pytasuku.task import Task, CommandTask, FileTask, PlaceholderTask\n\n\nclass WorkspaceState(Enum):\n    OUT_OF_SESSION = 1\n    IN_SESSION = 2\n\n\nclass NodeState(Enum):\n    IN_STACK = 1\n    VISITED = 2\n\n\nclass FuncCommandTask(CommandTask):\n    def __init__(self, workspace, name, dependencies, func):\n        super().__init__(workspace, name, dependencies)\n        self._func = func\n\n    def run(self):\n        self._func()\n\n\nclass FuncFileTask(FileTask):\n    def __init__(self, workspace, name, dependencies, func):\n        super().__init__(workspace, name, dependencies)\n        self._func = func\n\n    def run(self):\n        self._func()\n\n\ndef do_nothing():\n    pass\n\n\nclass Workspace:\n    def __init__(self):\n        self._tasks = dict()\n        self._name_to_done = None\n        self._state = WorkspaceState.OUT_OF_SESSION\n        self._modified = False\n\n    @property\n    def modified(self) -> bool:\n        return self._modified\n\n    @property\n    def state(self) -> WorkspaceState:\n        return self._state\n\n    @property\n    def in_session(self) -> bool:\n        return self._state == WorkspaceState.IN_SESSION\n\n    def task_exists(self, name: str) -> bool:\n        return name in self._tasks\n\n    def task_exists_and_not_placeholder(self, name: str) -> bool:\n        return self.task_exists(name) and not isinstance(self.get_task(name), PlaceholderTask)\n\n    def get_task(self, name: str) -> Task:\n        return self._tasks[name]\n\n    def add_task(self, task):\n        if self.in_session:\n            raise RuntimeError(\"New tasks can only be created when the workspace is out of session.\")\n        if isinstance(task, PlaceholderTask):\n            if not self.task_exists(task.name):\n                self._tasks[task.name] = task\n                self._modified = True\n        else:\n            self._tasks[task.name] = task\n            for dep in task.dependencies:\n                PlaceholderTask(self, dep)\n            self._modified = True\n\n    def start_session(self):\n        if self.in_session:\n            raise RuntimeError(\"A session can only be started when the workspace is out of session.\")\n        if self.modified:\n            self.check_cycle()\n        self._state = WorkspaceState.IN_SESSION\n        self._name_to_done = dict()\n        self._modified = False\n\n    def end_session(self):\n        if not self.in_session:\n            raise RuntimeError(\"A session can only be ended when the workspace is in session.\")\n        self._state = WorkspaceState.OUT_OF_SESSION\n        self._name_to_done = None\n\n    @contextmanager\n    def session(self):\n        try:\n            self.start_session()\n            yield\n        finally:\n            self.end_session()\n\n    def check_cycle(self):\n        node_states = dict()\n        for name in self._tasks:\n            if name not in node_states:\n                self.dfs(name, node_states)\n\n    def dfs(self, name, node_states):\n        node_states[name] = NodeState.IN_STACK\n        task = self.get_task(name)\n        for dep in task.dependencies:\n            if dep not in node_states:\n                self.dfs(dep, node_states)\n            else:\n                state = node_states[dep]\n                if state == NodeState.IN_STACK:\n                    raise RuntimeError(\"Dicovered cyclic dependency!\")\n        node_states[name] = 
NodeState.VISITED\n\n    def run(self, name):\n        if not self.in_session:\n            raise RuntimeError(\"A task can only be run when the workspace is in session.\")\n        if not self.task_exists(name):\n            raise RuntimeError(\"Task %s does not exists\" % name)\n        self.run_helper(name)\n\n    def run_helper(self, name):\n        task = self.get_task(name)\n        for dep in task.dependencies:\n            if self.needs_to_run(dep):\n                self.run_helper(dep)\n        if self.needs_to_run(name):\n            task.run()\n            self._name_to_done[name] = True\n\n    def needs_to_run(self, name):\n        if not self.in_session:\n            raise RuntimeError(\"You can only check whether a task needs to run when the workspace is in session.\")\n        if name in self._name_to_done:\n            return not self._name_to_done[name]\n        task = self.get_task(name)\n        need_to_run_value = task.needs_to_be_run\n        self._name_to_done[name] = not need_to_run_value\n        return need_to_run_value\n\n    def create_command_task(self, name, dependencies, func=do_nothing):\n        return FuncCommandTask(self, name, dependencies, func)\n\n    def create_file_task(self, name, dependencies, func):\n        return FuncFileTask(self, name, dependencies, func)\n\n\ndef command_task(workspace: Workspace, name: str, dependencies: List[str]):\n    def func(f):\n        workspace.create_command_task(name, dependencies, f)\n        return f\n\n    return func\n\n\ndef file_task(workspace: Workspace, name: str, dependencies: List[str]):\n    def func(f):\n        workspace.create_file_task(name, dependencies, f)\n        return f\n\n    return func\n"
  },
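  {
    "path": "src/tha4/pytasuku/workspace_example.py",
    "content": "\"\"\"Hypothetical example module added for illustration; it is not part of the\noriginal distribution. A minimal sketch of the Workspace API using the\nfile_task and command_task decorators.\n\nA file task reruns only when its output file is missing or older than a\ndependency; a command task reruns every time it is requested. The paths and\ntask names are invented.\n\"\"\"\nimport os\n\nfrom tha4.pytasuku.workspace import Workspace, command_task, file_task\n\nworkspace = Workspace()\n\n\n@file_task(workspace, \"data/workspace_example/output.txt\", [])\ndef make_output():\n    os.makedirs(\"data/workspace_example\", exist_ok=True)\n    with open(\"data/workspace_example/output.txt\", \"wt\") as fout:\n        fout.write(\"demo\")\n\n\n@command_task(workspace, \"data/workspace_example/report\", [\"data/workspace_example/output.txt\"])\ndef report():\n    print(\"output.txt is up to date\")\n\n\nif __name__ == \"__main__\":\n    with workspace.session():\n        workspace.run(\"data/workspace_example/report\")\n"
  },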
  {
    "path": "src/tha4/sampleoutput/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/sampleoutput/general_sample_output_protocol.py",
    "content": "import os\r\nfrom enum import Enum\r\nfrom typing import List, Dict\r\n\r\nimport PIL.Image\r\nimport numpy\r\nimport torch\r\nfrom tha4.shion.base.dataset.util import get_indexed_batch\r\nfrom tha4.shion.base.image_util import pytorch_rgb_to_numpy_image, pytorch_rgba_to_numpy_image\r\nfrom tha4.shion.core.cached_computation import TensorCachedComputationFunc, ComputationState\r\nfrom tha4.shion.core.training.sample_output_protocol import SampleOutputProtocol\r\nfrom tha4.image_util import grid_change_to_numpy_image\r\nfrom torch.nn import Module\r\nfrom torch.nn.functional import interpolate\r\nfrom torch.utils.data import Dataset\r\n\r\n\r\nclass ImageType(Enum):\r\n    COLOR = 0\r\n    ALPHA = 1\r\n    GRID_CHANGE = 2\r\n    SIGMOID_LOGIT = 3\r\n\r\n\r\nclass SampleImageSpec:\r\n    def __init__(self, value_func: TensorCachedComputationFunc, image_type: ImageType):\r\n        self.value_func = value_func\r\n        self.image_type = image_type\r\n\r\n\r\nclass SampleImageSaver:\r\n    def __init__(self,\r\n                 cell_size: int,\r\n                 image_channels: int,\r\n                 sample_image_specs: List[SampleImageSpec]):\r\n        super().__init__()\r\n        self.sample_image_specs = sample_image_specs\r\n        self.cell_size = cell_size\r\n        self.image_channels = image_channels\r\n\r\n    def save_sample_output_data(self,\r\n                                state: ComputationState,\r\n                                prefix: str,\r\n                                examples_seen_so_far: int):\r\n        num_cols = len(self.sample_image_specs)\r\n        num_rows = state.batch[0].shape[0]\r\n        output_image = numpy.zeros([self.cell_size * num_rows, self.cell_size * num_cols, self.image_channels])\r\n\r\n        for col in range(num_cols):\r\n            spec = self.sample_image_specs[col]\r\n            images = spec.value_func(state)\r\n            start_col = col * self.cell_size\r\n\r\n            for image_index in range(num_rows):\r\n                image = images[image_index].clone().detach()\r\n                row = image_index\r\n                start_row = row * self.cell_size\r\n                if spec.image_type == ImageType.COLOR:\r\n                    c, h, w = image.shape\r\n                    green_screen = torch.ones(3, h, w, device=image.device) * -1.0\r\n                    green_screen[1, :, :] = 1.0\r\n                    alpha = (image[3:4, :, :] + 1.0) * 0.5\r\n                    image[0:3, :, :] = image[0:3, :, :] * alpha + green_screen * (1 - alpha)\r\n                    image[3:4, :, :] = 1.0\r\n                    image = image.cpu()\r\n                elif spec.image_type == ImageType.GRID_CHANGE:\r\n                    image = image.cpu()\r\n                elif spec.image_type == ImageType.SIGMOID_LOGIT:\r\n                    image = torch.sigmoid(image)\r\n                    image = image.repeat(self.image_channels, 1, 1)\r\n                    image = image * 2.0 - 1.0\r\n                    image = image.cpu()\r\n                elif spec.image_type == ImageType.ALPHA:\r\n                    if image.shape[0] == 1:\r\n                        image = image.repeat(self.image_channels, 1, 1)\r\n                    image = image * 2.0 - 1.0\r\n                    image = image.cpu()\r\n                else:\r\n                    raise RuntimeError(f\"Unsupported image type: {spec.image_type}\")\r\n\r\n                output_image[start_row:start_row + self.cell_size, start_col:start_col + 
self.cell_size, :] \\\r\n                    = self.convert_to_numpy_image(image)\r\n\r\n        file_name = \"%s/sample_output_%010d.png\" % (prefix, examples_seen_so_far)\r\n        os.makedirs(os.path.dirname(file_name), exist_ok=True)\r\n        if self.image_channels == 3:\r\n            mode = 'RGB'\r\n        else:\r\n            mode = 'RGBA'\r\n        pil_image = PIL.Image.fromarray(numpy.uint8(numpy.rint(output_image * 255.0)), mode=mode)\r\n        pil_image.save(file_name)\r\n        print(\"Saved %s\" % file_name)\r\n\r\n    def convert_to_numpy_image(self, image: torch.Tensor):\r\n        image_size = image.shape[-1]\r\n        if self.cell_size != image_size:\r\n            image = interpolate(image.unsqueeze(0), size=self.cell_size).squeeze(0)\r\n        if image.shape[0] == 2:\r\n            return grid_change_to_numpy_image(image, num_channels=self.image_channels)\r\n        elif self.image_channels == 3:\r\n            return pytorch_rgb_to_numpy_image(image)\r\n        else:\r\n            return pytorch_rgba_to_numpy_image(image)\r\n\r\n\r\nclass GeneralSampleOutputProtocol(SampleOutputProtocol):\r\n    def __init__(self,\r\n                 sample_image_specs: List[SampleImageSpec],\r\n                 num_images: int = 8,\r\n                 cell_size: int = 256,\r\n                 image_channels: int = 4,\r\n                 examples_per_sample_output: int = 5000,\r\n                 random_seed: int = 1203040687):\r\n        super().__init__()\r\n        self.num_images = num_images\r\n        self.random_seed = random_seed\r\n        self.examples_per_sample_output = examples_per_sample_output\r\n        self.sample_image_saver = SampleImageSaver(cell_size, image_channels, sample_image_specs)\r\n\r\n    def get_examples_per_sample_output(self) -> int:\r\n        return self.examples_per_sample_output\r\n\r\n    def get_random_seed(self) -> int:\r\n        return self.random_seed\r\n\r\n    def get_sample_output_data(self, validation_dataset: Dataset, device: torch.device) -> dict:\r\n        example_indices = torch.randint(0, len(validation_dataset), (self.num_images,))\r\n        example_indices = [example_indices[i].item() for i in range(self.num_images)]\r\n        batch = get_indexed_batch(validation_dataset, example_indices, device)\r\n        return {'batch': batch}\r\n\r\n    def save_sample_output_data(self,\r\n                                modules: Dict[str, Module],\r\n                                accumulated_modules: Dict[str, Module],\r\n                                sample_output_data: dict, prefix: str,\r\n                                examples_seen_so_far: int,\r\n                                device: torch.device):\r\n        for key in modules:\r\n            modules[key].train(False)\r\n        batch = sample_output_data['batch']\r\n        state = ComputationState(modules, accumulated_modules, batch, {})\r\n        self.sample_image_saver.save_sample_output_data(state, prefix, examples_seen_so_far)\r\n"
  },
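  {
    "path": "src/tha4/sampleoutput/general_sample_output_protocol_example.py",
    "content": "\"\"\"Hypothetical example module added for illustration; it is not part of the\noriginal distribution. A minimal sketch of configuring\nGeneralSampleOutputProtocol.\n\nEach SampleImageSpec pairs a function that pulls a tensor batch out of the\ncomputation state with an ImageType that controls how it is rendered into the\nsample-output grid. The particular specs below (two batch entries shown side\nby side) are an invented configuration, not one used by the project.\n\"\"\"\nfrom tha4.sampleoutput.general_sample_output_protocol import GeneralSampleOutputProtocol, ImageType, SampleImageSpec\n\nprotocol = GeneralSampleOutputProtocol(\n    sample_image_specs=[\n        SampleImageSpec(lambda state: state.batch[0], ImageType.COLOR),\n        SampleImageSpec(lambda state: state.batch[1], ImageType.ALPHA),\n    ],\n    num_images=4,\n    cell_size=256,\n    image_channels=4,\n    examples_per_sample_output=10000)\n"
  },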
  {
    "path": "src/tha4/sampleoutput/poser_sampler_output_protocol.py",
    "content": "from typing import Optional, List, Dict\r\n\r\nimport torch\r\nfrom torch.nn import Module\r\nfrom torch.utils.data import Dataset\r\n\r\nfrom tha4.shion.base.dataset.util import get_indexed_batch\r\nfrom tha4.shion.core.cached_computation import CachedComputationFunc, ComputationState\r\nfrom tha4.shion.core.training.sample_output_protocol import SampleOutputProtocol\r\nfrom tha4.sampleoutput.sample_image_creator import SampleImageSpec, ImageSource, ImageType, SampleImageSaver\r\n\r\n\r\nclass PoserSampleOutputProtocol(SampleOutputProtocol):\r\n    def __init__(self,\r\n                 output_list_func: Optional[CachedComputationFunc] = None,\r\n                 num_images: int = 8,\r\n                 image_size: int = 256,\r\n                 cell_size: int = 256,\r\n                 image_channels: int = 4,\r\n                 examples_per_sample_output: int = 5000,\r\n                 sample_image_specs: Optional[List[SampleImageSpec]] = None,\r\n                 random_seed: int = 1203040687):\r\n        super().__init__()\r\n        self.num_images = num_images\r\n        self.random_seed = random_seed\r\n        self.examples_per_sample_output = examples_per_sample_output\r\n        self.output_list_func = output_list_func\r\n\r\n        if sample_image_specs is None:\r\n            sample_image_specs = [\r\n                SampleImageSpec(ImageSource.BATCH, 0, ImageType.COLOR),\r\n                SampleImageSpec(ImageSource.BATCH, 2, ImageType.COLOR),\r\n                SampleImageSpec(ImageSource.OUTPUT, 0, ImageType.COLOR),\r\n                SampleImageSpec(ImageSource.OUTPUT, 1, ImageType.COLOR),\r\n                SampleImageSpec(ImageSource.OUTPUT, 2, ImageType.ALPHA),\r\n                SampleImageSpec(ImageSource.BATCH, 3, ImageType.COLOR),\r\n                SampleImageSpec(ImageSource.OUTPUT, 3, ImageType.COLOR),\r\n                SampleImageSpec(ImageSource.OUTPUT, 4, ImageType.COLOR),\r\n                SampleImageSpec(ImageSource.OUTPUT, 5, ImageType.ALPHA),\r\n                SampleImageSpec(ImageSource.OUTPUT, 6, ImageType.COLOR),\r\n                SampleImageSpec(ImageSource.BATCH, 4, ImageType.COLOR),\r\n                SampleImageSpec(ImageSource.OUTPUT, 7, ImageType.COLOR),\r\n                SampleImageSpec(ImageSource.OUTPUT, 8, ImageType.COLOR),\r\n                SampleImageSpec(ImageSource.OUTPUT, 9, ImageType.ALPHA),\r\n                SampleImageSpec(ImageSource.OUTPUT, 10, ImageType.COLOR),\r\n            ]\r\n\r\n        self.sample_image_saver = SampleImageSaver(image_size, cell_size, image_channels, sample_image_specs)\r\n\r\n    def get_examples_per_sample_output(self) -> int:\r\n        return self.examples_per_sample_output\r\n\r\n    def get_random_seed(self) -> int:\r\n        return self.random_seed\r\n\r\n    def get_sample_output_data(self, validation_dataset: Dataset, device: torch.device) -> dict:\r\n        example_indices = torch.randint(0, len(validation_dataset), (self.num_images,))\r\n        example_indices = [example_indices[i].item() for i in range(self.num_images)]\r\n        batch = get_indexed_batch(validation_dataset, example_indices, device)\r\n        return {'batch': batch}\r\n\r\n    def save_sample_output_data(self,\r\n                                modules: Dict[str, Module],\r\n                                accumulated_modules: Dict[str, Module],\r\n                                sample_output_data: dict, prefix: str,\r\n                                examples_seen_so_far: int,\r\n                   
             device: torch.device):\r\n        for key in modules:\r\n            modules[key].train(False)\r\n        batch = sample_output_data['batch']\r\n        with torch.no_grad():\r\n            outputs = self.output_list_func(ComputationState(modules, accumulated_modules, batch))\r\n        self.sample_image_saver.save_sample_output_data(batch, outputs, prefix, examples_seen_so_far)\r\n"
  },
  {
    "path": "src/tha4/sampleoutput/sample_image_creator.py",
    "content": "import math\r\nimport os\r\nfrom enum import Enum\r\nfrom typing import List\r\n\r\nimport numpy\r\nimport torch\r\nfrom matplotlib import cm\r\nfrom torch import Tensor\r\nfrom torch.nn.functional import interpolate\r\n\r\nfrom tha4.shion.base.image_util import save_numpy_image\r\n\r\n\r\nclass ImageSource(Enum):\r\n    BATCH = 0\r\n    OUTPUT = 1\r\n\r\n\r\nclass ImageType(Enum):\r\n    COLOR = 0\r\n    ALPHA = 1\r\n    GRID_CHANGE = 2\r\n    SIGMOID_LOGIT = 3\r\n\r\n\r\nclass SampleImageSpec:\r\n    def __init__(self, image_source: ImageSource, index: int, image_type: ImageType):\r\n        self.image_type = image_type\r\n        self.index = index\r\n        self.image_source = image_source\r\n\r\n\r\ndef torch_rgb_to_numpy_image(torch_image: Tensor, min_pixel_value=-1.0, max_pixel_value=1.0):\r\n    assert torch_image.dim() == 3\r\n    assert torch_image.shape[0] == 3\r\n    height = torch_image.shape[1]\r\n    width = torch_image.shape[2]\r\n\r\n    reshaped_image = torch_image.numpy().reshape(3, height * width).transpose().reshape(height, width, 3)\r\n    numpy_image = (reshaped_image - min_pixel_value) / (max_pixel_value - min_pixel_value)\r\n    return numpy_image\r\n\r\n\r\ndef torch_rgba_to_numpy_image(torch_image: Tensor, min_pixel_value=-1.0, max_pixel_value=1.0):\r\n    assert torch_image.dim() == 3\r\n    assert torch_image.shape[0] == 4\r\n    height = torch_image.shape[1]\r\n    width = torch_image.shape[2]\r\n\r\n    reshaped_image = torch_image.numpy().reshape(4, height * width).transpose().reshape(height, width, 4)\r\n    numpy_image = (reshaped_image - min_pixel_value) / (max_pixel_value - min_pixel_value)\r\n    numpy_image = numpy.clip(numpy_image, 0.0, 1.0)\r\n    return numpy_image\r\n\r\n\r\ndef torch_grid_change_to_numpy_image(torch_image, num_channels=3):\r\n    height = torch_image.shape[1]\r\n    width = torch_image.shape[2]\r\n    size_image = (torch_image[0, :, :] ** 2 + torch_image[1, :, :] ** 2).sqrt().view(height, width, 1).numpy()\r\n    hsv = cm.get_cmap('hsv')\r\n    angle_image = hsv(((torch.atan2(\r\n        torch_image[0, :, :].view(height * width),\r\n        torch_image[1, :, :].view(height * width)).view(height, width) + math.pi) / (2 * math.pi)).numpy()) * 3\r\n    numpy_image = size_image * angle_image[:, :, 0:3]\r\n    if num_channels == 3:\r\n        return numpy_image\r\n    elif num_channels == 4:\r\n        return numpy.concatenate([numpy_image, numpy.ones_like(size_image)], axis=2)\r\n    else:\r\n        raise RuntimeError(\"Unsupported num_channels: \" + str(num_channels))\r\n\r\n\r\nclass SampleImageSaver:\r\n    def __init__(self,\r\n                 image_size: int,\r\n                 cell_size: int,\r\n                 image_channels: int,\r\n                 sample_image_specs: List[SampleImageSpec]):\r\n        super().__init__()\r\n        self.sample_image_specs = sample_image_specs\r\n        self.cell_size = cell_size\r\n        self.image_channels = image_channels\r\n        self.image_size = image_size\r\n\r\n    def save_sample_output_image(self, batch: List[Tensor], outputs: List[Tensor], file_name: str):\r\n        num_cols = len(self.sample_image_specs)\r\n\r\n        num_rows = batch[0].shape[0]\r\n        output_image = numpy.zeros([self.cell_size * num_rows, self.cell_size * num_cols, self.image_channels])\r\n\r\n        for image_index in range(num_rows):\r\n            row = image_index\r\n            start_row = row * self.cell_size\r\n\r\n            for col in range(num_cols):\r\n                
spec = self.sample_image_specs[col]\r\n                start_col = col * self.cell_size\r\n\r\n                if spec.image_source == ImageSource.BATCH:\r\n                    image = batch[spec.index][image_index].clone()\r\n                else:\r\n                    image = outputs[spec.index][image_index].clone()\r\n\r\n                if spec.image_type == ImageType.COLOR:\r\n                    c, h, w = image.shape\r\n                    green_screen = torch.ones(3, h, w, device=image.device) * -1.0\r\n                    green_screen[1, :, :] = 1.0\r\n                    alpha = (image[3:4, :, :] + 1.0) * 0.5\r\n                    image[0:3, :, :] = image[0:3, :, :] * alpha + green_screen * (1 - alpha)\r\n                    image[3:4, :, :] = 1.0\r\n                    image = image.detach().cpu()\r\n                elif spec.image_type == ImageType.GRID_CHANGE:\r\n                    image = image.detach().cpu()\r\n                elif spec.image_type == ImageType.SIGMOID_LOGIT:\r\n                    image = torch.sigmoid(image)\r\n                    image = image.repeat(self.image_channels, 1, 1)\r\n                    image = image * 2.0 - 1.0\r\n                    image = image.detach().cpu()\r\n                else:\r\n                    if image.shape[0] == 1:\r\n                        image = image.repeat(self.image_channels, 1, 1)\r\n                    image = image * 2.0 - 1.0\r\n                    image = image.detach().cpu()\r\n\r\n                output_image[start_row:start_row + self.cell_size, start_col:start_col + self.cell_size, :] \\\r\n                    = self.convert_to_numpy_image(image)\r\n\r\n        os.makedirs(os.path.dirname(file_name), exist_ok=True)\r\n        save_numpy_image(output_image, file_name, save_straight_alpha=True)\r\n\r\n    def save_sample_output_data(self,\r\n                                batch: List[Tensor],\r\n                                outputs: List[Tensor],\r\n                                prefix: str,\r\n                                examples_seen_so_far: int):\r\n        file_name = \"%s/sample_output_%010d.png\" % (prefix, examples_seen_so_far)\r\n        self.save_sample_output_image(batch, outputs, file_name)\r\n\r\n    def convert_to_numpy_image(self, image: torch.Tensor):\r\n        if self.cell_size != self.image_size:\r\n            image = interpolate(image.unsqueeze(0), size=self.cell_size).squeeze(0)\r\n        if image.shape[0] == 2:\r\n            return torch_grid_change_to_numpy_image(image, num_channels=self.image_channels)\r\n        elif self.image_channels == 3:\r\n            return torch_rgb_to_numpy_image(image)\r\n        else:\r\n            return torch_rgba_to_numpy_image(image)\r\n"
  },
  {
    "path": "src/tha4/shion/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/shion/base/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/shion/base/dataset/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/shion/base/dataset/lazy_dataset.py",
    "content": "from typing import Callable\r\n\r\nfrom torch.utils.data import Dataset\r\n\r\n\r\nclass LazyDataset(Dataset):\r\n    def __init__(self, source_func: Callable[[], Dataset]):\r\n        self.source_func = source_func\r\n        self.source = None\r\n\r\n    def get_source(self):\r\n        if self.source is None:\r\n            self.source = self.source_func()\r\n        return self.source\r\n\r\n    def __len__(self):\r\n        return len(self.get_source())\r\n\r\n    def __getitem__(self, item):\r\n        return self.get_source()[item]\r\n"
  },
  {
    "path": "src/tha4/shion/base/dataset/lazy_tensor_dataset.py",
    "content": "import torch\r\nfrom torch.utils.data import Dataset, TensorDataset\r\n\r\nfrom tha4.shion.core.load_save import torch_load\r\n\r\n\r\nclass LazyTensorDataset(Dataset):\r\n    def __init__(self, file_name: str):\r\n        self.file_name = file_name\r\n        self.dataset = None\r\n\r\n    def get_dataset(self):\r\n        if self.dataset is None:\r\n            data = torch_load(self.file_name)\r\n            if isinstance(data, torch.Tensor):\r\n                self.dataset = TensorDataset(data)\r\n            elif isinstance(data, tuple):\r\n                self.dataset = TensorDataset(*data)\r\n            elif isinstance(data, list):\r\n                self.dataset = TensorDataset(*data)\r\n            else:\r\n                raise RuntimeError(\"Unsupported data type: \" + type(data))\r\n        return self.dataset\r\n\r\n    def __len__(self):\r\n        dataset = self.get_dataset()\r\n        return len(dataset)\r\n\r\n    def __getitem__(self, item):\r\n        dataset = self.get_dataset()\r\n        return dataset.__getitem__(item)\r\n\r\n\r\n"
  },
  {
    "path": "src/tha4/shion/base/dataset/png_in_dir_dataset.py",
    "content": "import os\r\n\r\nfrom torch.nn import functional\r\nfrom torch.utils.data import Dataset\r\nfrom os import listdir\r\nfrom os.path import isfile\r\n\r\nfrom tha4.shion.base.image_util import extract_pytorch_image_from_filelike\r\n\r\n\r\nclass PngInDirDataset(Dataset):\r\n    def __init__(self, dir: str,\r\n                 downscale_kernel_size: int = 1,\r\n                 has_alpha=False,\r\n                 scale=2.0,\r\n                 offset=-1.0,\r\n                 premultiply_alpha=True,\r\n                 perfrom_srb_to_linear=True):\r\n        super().__init__()\r\n        self.perfrom_srb_to_linear = perfrom_srb_to_linear\r\n        self.premultiply_alpha = premultiply_alpha\r\n        self.offset = offset\r\n        self.scale = scale\r\n        self.has_alpha = has_alpha\r\n        self.downscale_kernel_size = downscale_kernel_size\r\n        self.dir = dir\r\n        self.file_names = None\r\n\r\n    def get_file_names(self):\r\n        if self.file_names is None:\r\n            self.file_names = [os.path.join(self.dir, x) for x in listdir(self.dir)]\r\n            self.file_names = [x for x in self.file_names if isfile(x) and x[-4:].lower() == \".png\"]\r\n            self.file_names = sorted(self.file_names)\r\n        return self.file_names\r\n\r\n    def __len__(self):\r\n        file_names = self.get_file_names()\r\n        return len(file_names)\r\n\r\n    def __getitem__(self, item):\r\n        file_names = self.get_file_names()\r\n        file_name = file_names[item]\r\n        image = extract_pytorch_image_from_filelike(\r\n            file_name,\r\n            scale=self.scale,\r\n            offset=self.offset,\r\n            premultiply_alpha=self.has_alpha and self.premultiply_alpha,\r\n            perform_srgb_to_linear=self.perfrom_srb_to_linear)\r\n        if self.downscale_kernel_size == 1:\r\n            return [image]\r\n        else:\r\n            image = functional.avg_pool2d(image.unsqueeze(0), kernel_size=self.downscale_kernel_size).squeeze(0)\r\n            return [image]\r\n"
  },
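  {
    "path": "src/tha4/shion/base/dataset/png_in_dir_dataset_example.py",
    "content": "\"\"\"Hypothetical example module added for illustration; it is not part of the\noriginal distribution. A minimal sketch of PngInDirDataset, assuming a\ndirectory of PNG files at the invented path below.\n\nEach item is a one-element list holding a CHW float tensor whose values lie in\n[-1, 1] under the default scale=2.0 and offset=-1.0. The tensor has four\nchannels when the source PNG has transparency, otherwise three.\n\"\"\"\nfrom tha4.shion.base.dataset.png_in_dir_dataset import PngInDirDataset\n\ndataset = PngInDirDataset(\"data/png_example_images\", has_alpha=True)\nprint(len(dataset))    # number of .png files found in the directory\nimage = dataset[0][0]  # CHW tensor for the alphabetically first file\nprint(image.shape)\n"
  },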
  {
    "path": "src/tha4/shion/base/dataset/util.py",
    "content": "from typing import List\r\n\r\nimport torch\r\nfrom torch.utils.data import Dataset\r\n\r\n\r\ndef get_indexed_batch(dataset: Dataset, example_indices: List[int], device: torch.device):\r\n    if len(example_indices) == 0:\r\n        return []\r\n    examples = []\r\n    for index in range(len(example_indices)):\r\n        example_index = example_indices[index]\r\n        raw_example = dataset[example_index]\r\n        example = []\r\n        for x in raw_example:\r\n            if isinstance(x, torch.Tensor):\r\n                y = x.to(device).unsqueeze(0)\r\n            elif isinstance(x, float) or isinstance(x, int):\r\n                y = torch.tensor([[x]], device=device)\r\n            else:\r\n                raise RuntimeError(f\"get_indexed_batch: Data of type {type(x)} is not supported.\")\r\n            example.append(y)\r\n        examples.append(example)\r\n    k = len(examples[0])\r\n    transposed = [[] for i in range(k)]\r\n    for example in examples:\r\n        for i in range(k):\r\n            transposed[i].append(example[i])\r\n    return [torch.cat(x, dim=0) for x in transposed]\r\n"
  },
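  {
    "path": "src/tha4/shion/base/dataset/util_example.py",
    "content": "\"\"\"Hypothetical example module added for illustration; it is not part of the\noriginal distribution. A minimal sketch of get_indexed_batch.\n\nThe function gathers the selected examples, moves each tensor to the target\ndevice with an added batch dimension, and transposes the result so that the\noutput is one batched tensor per example component.\n\"\"\"\nimport torch\nfrom torch.utils.data import TensorDataset\n\nfrom tha4.shion.base.dataset.util import get_indexed_batch\n\ndataset = TensorDataset(torch.randn(100, 3), torch.randn(100, 5))\nbatch = get_indexed_batch(dataset, [0, 17, 42], torch.device(\"cpu\"))\nprint(batch[0].shape)  # torch.Size([3, 3])\nprint(batch[1].shape)  # torch.Size([3, 5])\n"
  },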
  {
    "path": "src/tha4/shion/base/dataset/xformed_dataset.py",
    "content": "from typing import Any, Callable\r\n\r\nfrom torch.utils.data import Dataset\r\n\r\n\r\nclass XformedDataset(Dataset):\r\n    def __init__(self, source: Dataset, xform_func: Callable[[Any], Any]):\r\n        self.xform_func = xform_func\r\n        self.source = source\r\n\r\n    def __len__(self):\r\n        return len(self.source)\r\n\r\n    def __getitem__(self, item):\r\n        return self.xform_func(self.source[item])\r\n"
  },
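  {
    "path": "src/tha4/shion/base/dataset/xformed_dataset_example.py",
    "content": "\"\"\"Hypothetical example module added for illustration; it is not part of the\noriginal distribution. A minimal sketch combining LazyDataset and\nXformedDataset.\n\nLazyDataset defers building an expensive dataset until it is first touched;\nXformedDataset applies a per-example transform. The in-memory TensorDataset\nbelow stands in for a dataset that would be costly to construct.\n\"\"\"\nimport torch\nfrom torch.utils.data import TensorDataset\n\nfrom tha4.shion.base.dataset.lazy_dataset import LazyDataset\nfrom tha4.shion.base.dataset.xformed_dataset import XformedDataset\n\n\ndef make_source():\n    # Imagine this loading gigabytes from disk; it runs only on first access.\n    return TensorDataset(torch.arange(10, dtype=torch.float32))\n\n\nlazy = LazyDataset(make_source)\ndoubled = XformedDataset(lazy, lambda example: [2.0 * example[0]])\n\nprint(len(doubled))  # 10; this is the call that triggers make_source()\nprint(doubled[3])    # [tensor(6.)]\n"
  },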
  {
    "path": "src/tha4/shion/base/image_util.py",
    "content": "import os\r\n\r\nimport PIL.Image\r\nimport numpy\r\nimport torch\r\nfrom matplotlib import pyplot\r\nfrom torch import Tensor\r\n\r\n\r\ndef numpy_srgb_to_linear(x):\r\n    x = numpy.clip(x, 0.0, 1.0)\r\n    return numpy.where(x <= 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)\r\n\r\n\r\ndef numpy_linear_to_srgb(x):\r\n    x = numpy.clip(x, 0.0, 1.0)\r\n    return numpy.where(x <= 0.003130804953560372, x * 12.92, 1.055 * (x ** (1.0 / 2.4)) - 0.055)\r\n\r\n\r\ndef numpy_alpha_devide(rgb, a, epsilon=1e-5):\r\n    aaa = numpy.repeat(a, 3, axis=2)\r\n    aaa_prime = aaa + numpy.where(numpy.abs(aaa) < epsilon, epsilon, 0.0)\r\n    return numpy.where(numpy.abs(aaa) < epsilon, 0.0, rgb / aaa_prime)\r\n\r\n\r\ndef torch_srgb_to_linear(x: torch.Tensor):\r\n    x = torch.clip(x, 0.0, 1.0)\r\n    return torch.where(torch.le(x, 0.04045), x / 12.92, ((x + 0.055) / 1.055) ** 2.4)\r\n\r\n\r\ndef torch_linear_to_srgb(x):\r\n    x = torch.clip(x, 0.0, 1.0)\r\n    return torch.where(torch.le(x, 0.003130804953560372), x * 12.92, 1.055 * (x ** (1.0 / 2.4)) - 0.055)\r\n\r\n\r\ndef numpy_image_linear_to_srgb(image):\r\n    assert image.shape[2] == 3 or image.shape[2] == 4\r\n    if image.shape[2] == 3:\r\n        return numpy_linear_to_srgb(image)\r\n    else:\r\n        height, width, _ = image.shape\r\n        rgb_image = numpy_linear_to_srgb(image[:, :, 0:3])\r\n        a_image = image[:, :, 3:4]\r\n        return numpy.concatenate((rgb_image, a_image), axis=2)\r\n\r\n\r\ndef numpy_image_srgb_to_linear(image):\r\n    assert image.shape[2] == 3 or image.shape[2] == 4\r\n    if image.shape[2] == 3:\r\n        return numpy_srgb_to_linear(image)\r\n    else:\r\n        height, width, _ = image.shape\r\n        rgb_image = numpy_srgb_to_linear(image[:, :, 0:3])\r\n        a_image = image[:, :, 3:4]\r\n        return numpy.concatenate((rgb_image, a_image), axis=2)\r\n\r\n\r\ndef pytorch_rgb_to_numpy_image(torch_image: Tensor, min_pixel_value=-1.0, max_pixel_value=1.0):\r\n    assert torch_image.dim() == 3\r\n    assert torch_image.shape[0] == 3\r\n    height = torch_image.shape[1]\r\n    width = torch_image.shape[2]\r\n\r\n    reshaped_image = torch_image.numpy().reshape(3, height * width).transpose().reshape(height, width, 3)\r\n    numpy_image = (reshaped_image - min_pixel_value) / (max_pixel_value - min_pixel_value)\r\n    return numpy_linear_to_srgb(numpy_image)\r\n\r\n\r\ndef pytorch_rgba_to_numpy_image_greenscreen(torch_image: Tensor,\r\n                                            min_pixel_value=-1.0,\r\n                                            max_pixel_value=1.0,\r\n                                            include_alpha=False):\r\n    height = torch_image.shape[1]\r\n    width = torch_image.shape[2]\r\n\r\n    numpy_image = (torch_image.numpy().reshape(4, height * width).transpose().reshape(height, width,\r\n                                                                                      4) - min_pixel_value) \\\r\n                  / (max_pixel_value - min_pixel_value)\r\n    rgb_image = numpy_linear_to_srgb(numpy_image[:, :, 0:3])\r\n    a_image = numpy_image[:, :, 3]\r\n    rgb_image[:, :, 0:3] = rgb_image[:, :, 0:3] * a_image.reshape(a_image.shape[0], a_image.shape[1], 1)\r\n    rgb_image[:, :, 1] = rgb_image[:, :, 1] + (1 - a_image)\r\n\r\n    if not include_alpha:\r\n        return rgb_image\r\n    else:\r\n        return numpy.concatenate((rgb_image, numpy.ones_like(numpy_image[:, :, 3:4])), axis=2)\r\n\r\n\r\ndef pytorch_rgba_to_numpy_image(\r\n        torch_image: 
Tensor,\r\n        min_pixel_value=-1.0,\r\n        max_pixel_value=1.0,\r\n        perform_linear_to_srb: bool = True):\r\n    assert torch_image.dim() == 3\r\n    assert torch_image.shape[0] == 4\r\n    height = torch_image.shape[1]\r\n    width = torch_image.shape[2]\r\n\r\n    reshaped_image = torch_image.numpy().reshape(4, height * width).transpose().reshape(height, width, 4)\r\n    numpy_image = (reshaped_image - min_pixel_value) / (max_pixel_value - min_pixel_value)\r\n    if perform_linear_to_srb:\r\n        rgb_image = numpy_linear_to_srgb(numpy_image[:, :, 0:3])\r\n    else:\r\n        rgb_image = numpy.clip(numpy_image[:, :, 0:3], 0.0, 1.0)\r\n    a_image = numpy.clip(numpy_image[:, :, 3], 0.0, 1.0)\r\n    rgba_image = numpy.concatenate((rgb_image, a_image.reshape(height, width, 1)), axis=2)\r\n    return rgba_image\r\n\r\n\r\ndef pil_image_has_transparency(pil_image):\r\n    if pil_image.info.get(\"transparency\", None) is not None:\r\n        return True\r\n    if pil_image.mode == \"P\":\r\n        transparent = pil_image.info.get(\"transparency\", -1)\r\n        for _, index in pil_image.getcolors():\r\n            if index == transparent:\r\n                return True\r\n    elif pil_image.mode == \"RGBA\":\r\n        extrema = pil_image.getextrema()\r\n        if extrema[3][0] < 255:\r\n            return True\r\n\r\n    return False\r\n\r\n\r\ndef extract_numpy_image_from_PIL_image(pil_image, scale=2.0, offset=-1.0,\r\n                                       premultiply_alpha=True,\r\n                                       perform_srgb_to_linear=True):\r\n    has_alpha = pil_image_has_transparency(pil_image)\r\n    if has_alpha and pil_image.mode != 'RGBA':\r\n        pil_image = pil_image.convert(\"RGBA\")\r\n    if not has_alpha and pil_image.mode != 'RGB':\r\n        pil_image = pil_image.convert(\"RGB\")\r\n    if has_alpha:\r\n        num_channel = 4\r\n    else:\r\n        num_channel = 3\r\n    image_width = pil_image.width\r\n    image_height = pil_image.height\r\n\r\n    raw_image = numpy.asarray(pil_image, dtype=numpy.float32)\r\n    image = (raw_image / 255.0).reshape(image_height, image_width, num_channel)\r\n    if perform_srgb_to_linear:\r\n        image[:, :, 0:3] = numpy_srgb_to_linear(image[:, :, 0:3])\r\n        # Premultiply alpha\r\n    if has_alpha and premultiply_alpha:\r\n        image[:, :, 0:3] = image[:, :, 0:3] * image[:, :, 3:4]\r\n    return image * scale + offset\r\n\r\n\r\ndef extract_numpy_image_from_PIL_image_with_pytorch_layout(pil_image, scale=2.0, offset=-1.0,\r\n                                                           premultiply_alpha=True,\r\n                                                           perform_srgb_to_linear=True):\r\n    numpy_image = extract_numpy_image_from_PIL_image(\r\n        pil_image, scale, offset, premultiply_alpha, perform_srgb_to_linear)\r\n    image_height, image_width, num_channel = numpy_image.shape\r\n    image = numpy_image \\\r\n        .reshape(image_height * image_width, num_channel) \\\r\n        .transpose() \\\r\n        .reshape(num_channel, image_height, image_width)\r\n    return image\r\n\r\n\r\ndef extract_numpy_image_from_filelike_with_pytorch_layout(file, scale=2.0, offset=-1.0, premultiply_alpha=True):\r\n    try:\r\n        pil_image = PIL.Image.open(file)\r\n    except Exception as e:\r\n        raise RuntimeError(file)\r\n    return extract_numpy_image_from_PIL_image_with_pytorch_layout(pil_image, scale, offset, premultiply_alpha)\r\n\r\n\r\ndef 
extract_numpy_image_from_filelike(file, scale=1.0, offset=0.0,\r\n                                      premultiply_alpha=True,\r\n                                      perform_srgb_to_linear: bool = True):\r\n    try:\r\n        pil_image = PIL.Image.open(file)\r\n    except Exception as e:\r\n        raise RuntimeError(file)\r\n    return extract_numpy_image_from_PIL_image(pil_image, scale, offset, premultiply_alpha, perform_srgb_to_linear)\r\n\r\n\r\ndef extract_pytorch_image_from_filelike(file, scale=2.0, offset=-1.0, premultiply_alpha=True,\r\n                                        perform_srgb_to_linear=True):\r\n    try:\r\n        pil_image = PIL.Image.open(file)\r\n    except Exception as e:\r\n        raise RuntimeError(file)\r\n    image = extract_numpy_image_from_PIL_image_with_pytorch_layout(pil_image, scale, offset, premultiply_alpha,\r\n                                                                   perform_srgb_to_linear)\r\n    return torch.from_numpy(image).float()\r\n\r\n\r\ndef extract_pytorch_image_from_PIL_image(pil_image, scale=2.0, offset=-1.0, premultiply_alpha=True,\r\n                                         perform_srgb_to_linear=True):\r\n    image = extract_numpy_image_from_PIL_image_with_pytorch_layout(\r\n        pil_image, scale, offset, premultiply_alpha, perform_srgb_to_linear)\r\n    return torch.from_numpy(image).float()\r\n\r\n\r\ndef convert_pytorch_image_to_zero_to_one_numpy_image(\r\n        torch_image: torch.Tensor,\r\n        scale: float = 2.0,\r\n        offset: float = -1.0):\r\n    torch_image = (torch_image - offset) / scale\r\n    torch_image = torch.permute(torch_image, (1, 2, 0))\r\n    numpy_image = torch_image.cpu().numpy()\r\n    return numpy_image\r\n\r\n\r\ndef convert_zero_to_one_numpy_image_to_PIL_image(\r\n        numpy_image, use_straight_alpha=True, perform_linear_to_srgb=True):\r\n    if numpy_image.shape[2] == 4:\r\n        rgb_image = numpy_image[:, :, 0:3]\r\n        a_image = numpy.clip(numpy_image[:, :, 3:4], 0.0, 1.0)\r\n        if use_straight_alpha:\r\n            rgb_image = numpy_alpha_devide(rgb_image, a_image)\r\n        if perform_linear_to_srgb:\r\n            rgb_image = numpy_linear_to_srgb(rgb_image)\r\n        else:\r\n            rgb_image = numpy.clip(rgb_image, 0.0, 1.0)\r\n        new_numpy_image = numpy.concatenate((rgb_image, a_image), axis=2)\r\n        pil_image = PIL.Image.fromarray(numpy.uint8(numpy.rint(new_numpy_image * 255.0)), mode='RGBA')\r\n    else:\r\n        if perform_linear_to_srgb:\r\n            numpy_image = numpy_linear_to_srgb(numpy_image)\r\n        else:\r\n            numpy_image = numpy.clip(numpy_image, 0.0, 1.0)\r\n        pil_image = PIL.Image.fromarray(numpy.uint8(numpy.rint(numpy_image * 255.0)), mode='RGB')\r\n    return pil_image\r\n\r\n\r\ndef save_numpy_image(numpy_image, file_name: str, save_straight_alpha=True, perform_linear_to_srgb=True):\r\n    pil_image = convert_zero_to_one_numpy_image_to_PIL_image(numpy_image, save_straight_alpha, perform_linear_to_srgb)\r\n    os.makedirs(os.path.dirname(file_name), exist_ok=True)\r\n    pil_image.save(file_name)\r\n\r\n\r\ndef resize_PIL_image(pil_image, size=(256, 256)):\r\n    w, h = pil_image.size\r\n    d = min(w, h)\r\n    r = ((w - d) // 2, (h - d) // 2, (w + d) // 2, (h + d) // 2)\r\n    return pil_image.resize(size, resample=PIL.Image.LANCZOS, box=r)"
  },
  {
    "path": "src/tha4/shion/base/loss/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/shion/base/loss/computed_scale_loss.py",
    "content": "from typing import Optional, Callable\r\n\r\nfrom tha4.shion.core.cached_computation import TensorCachedComputationFunc, ComputationState\r\nfrom tha4.shion.core.loss import Loss\r\n\r\n\r\nclass ComputedScaleLoss(Loss):\r\n    def __init__(self,\r\n                 scale_func: TensorCachedComputationFunc,\r\n                 loss: Loss,\r\n                 weight: float = 1.0):\r\n        self.weight = weight\r\n        self.loss = loss\r\n        self.scale_func = scale_func\r\n\r\n    def compute(self, state: ComputationState, log_func: Optional[Callable[[str, float], None]] = None):\r\n        loss = self.loss.compute(state)\r\n        scale = self.scale_func(state)\r\n        loss = self.weight * scale * loss\r\n        if log_func is not None:\r\n            log_func(\"loss\", loss.item())\r\n        return loss\r\n"
  },
  {
    "path": "src/tha4/shion/base/loss/computed_scaled_l2_loss.py",
    "content": "from typing import Callable, Optional\r\n\r\nfrom tha4.shion.core.cached_computation import TensorCachedComputationFunc, ComputationState\r\nfrom tha4.shion.core.loss import Loss\r\n\r\n\r\nclass ComputedScaledL2Loss(Loss):\r\n    def __init__(self,\r\n                 expected_func: TensorCachedComputationFunc,\r\n                 actual_func: TensorCachedComputationFunc,\r\n                 element_scale_func: TensorCachedComputationFunc,\r\n                 weight: float = 1.0):\r\n        self.element_scale_func = element_scale_func\r\n        self.actual_func = actual_func\r\n        self.expected_func = expected_func\r\n        self.weight = weight\r\n\r\n    def compute(\r\n            self,\r\n            state: ComputationState,\r\n            log_func: Optional[Callable[[str, float], None]] = None):\r\n        element_scale = self.element_scale_func(state)\r\n        expected = self.expected_func(state)\r\n        actual = self.actual_func(state)\r\n        diff = (expected - actual) * element_scale\r\n        loss = self.weight * (diff ** 2).mean()\r\n        if log_func is not None:\r\n            log_func(\"loss\", loss.item())\r\n        return loss\r\n"
  },
  {
    "path": "src/tha4/shion/base/loss/l1_loss.py",
    "content": "from typing import Callable, Optional\r\n\r\nimport torch\r\n\r\nfrom tha4.shion.core.cached_computation import TensorCachedComputationFunc, ComputationState\r\nfrom tha4.shion.core.loss import Loss\r\n\r\n\r\nclass L1Loss(Loss):\r\n    def __init__(self,\r\n                 expected_func: TensorCachedComputationFunc,\r\n                 actual_func: TensorCachedComputationFunc,\r\n                 weight: float = 1.0):\r\n        self.actual_func = actual_func\r\n        self.expected_func = expected_func\r\n        self.weight = weight\r\n\r\n    def compute(self, state: ComputationState, log_func: Optional[Callable[[str, float], None]] = None):\r\n        expected = self.expected_func(state)\r\n        actual = self.actual_func(state)\r\n        loss = self.weight * (expected - actual).abs().mean()\r\n        if log_func is not None:\r\n            log_func(\"loss\", loss.item())\r\n        return loss\r\n\r\n\r\nclass ListL1Loss(Loss):\r\n    def __init__(self,\r\n                 expected_func: TensorCachedComputationFunc,\r\n                 actual_func: TensorCachedComputationFunc,\r\n                 weight: float = 1.0):\r\n        self.actual_func = actual_func\r\n        self.expected_func = expected_func\r\n        self.weight = weight\r\n\r\n    def compute(self, state: ComputationState, log_func: Optional[Callable[[str, float], None]] = None):\r\n        expected = self.expected_func(state)\r\n        actual = self.actual_func(state)\r\n        assert len(expected) == len(actual)\r\n        loss = torch.zeros(1, device=expected[0].device)\r\n        for i in range(len(expected)):\r\n            loss += (expected[i] - actual[i]).abs().mean()\r\n        loss = self.weight * loss\r\n        if log_func is not None:\r\n            log_func(\"loss\", loss.item())\r\n        return loss\r\n\r\n\r\nclass MaskedL1Loss(Loss):\r\n    def __init__(self,\r\n                 expected_func: TensorCachedComputationFunc,\r\n                 actual_func: TensorCachedComputationFunc,\r\n                 mask_func: TensorCachedComputationFunc,\r\n                 weight: float = 1.0):\r\n        self.mask_func = mask_func\r\n        self.actual_func = actual_func\r\n        self.expected_func = expected_func\r\n        self.weight = weight\r\n\r\n    def compute(self, state: ComputationState, log_func: Optional[Callable[[str, float], None]] = None):\r\n        mask = self.mask_func(state)\r\n        expected = self.expected_func(state)\r\n        actual = self.actual_func(state)\r\n        loss = self.weight * ((expected - actual) * mask).abs().mean()\r\n        if log_func is not None:\r\n            log_func(\"loss\", loss.item())\r\n        return loss"
  },
  {
    "path": "src/tha4/shion/base/loss/l2_loss.py",
    "content": "from typing import Callable, Optional\r\n\r\nfrom tha4.shion.core.cached_computation import TensorCachedComputationFunc, ComputationState\r\nfrom tha4.shion.core.loss import Loss\r\n\r\n\r\nclass L2Loss(Loss):\r\n    def __init__(self,\r\n                 expected_func: TensorCachedComputationFunc,\r\n                 actual_func: TensorCachedComputationFunc,\r\n                 weight: float = 1.0):\r\n        self.actual_func = actual_func\r\n        self.expected_func = expected_func\r\n        self.weight = weight\r\n\r\n    def compute(\r\n            self,\r\n            state: ComputationState,\r\n            log_func: Optional[Callable[[str, float], None]] = None):\r\n        expected = self.expected_func(state)\r\n        actual = self.actual_func(state)\r\n        loss = self.weight * ((expected - actual) ** 2).mean()\r\n        if log_func is not None:\r\n            log_func(\"loss\", loss.item())\r\n        return loss\r\n"
  },
  {
    "path": "src/tha4/shion/base/loss/sum_loss.py",
    "content": "from typing import List, Tuple, Callable, Optional\r\n\r\nimport torch\r\nfrom torch import Tensor\r\n\r\nfrom tha4.shion.core.cached_computation import ComputationState\r\nfrom tha4.shion.core.loss import Loss\r\n\r\n\r\nclass SumLoss(Loss):\r\n    def __init__(self, losses: List[Tuple[str, Loss]]):\r\n        self.losses = losses\r\n\r\n    def compute(self,\r\n                state: ComputationState,\r\n                log_func: Optional[Callable[[str, float], None]] = None) -> Tensor:\r\n        device = state.batch[0].device\r\n        loss_value = torch.zeros(1, device=device)\r\n        for loss_spec in self.losses:\r\n            loss_name = loss_spec[0]\r\n            loss = loss_spec[1]\r\n            if log_func is not None:\r\n                def loss_log_func(name, value):\r\n                    log_func(loss_name + \"_\" + name, value)\r\n            else:\r\n                loss_log_func = None\r\n            loss_value = loss_value + loss.compute(state, loss_log_func)\r\n\r\n        if log_func is not None:\r\n            log_func(\"loss\", loss_value.item())\r\n\r\n        return loss_value\r\n"
  },
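  {
    "path": "examples/shion_compose_losses.py",
    "content": "# Illustrative usage sketch, not part of the original codebase. It shows how the\r\n# loss classes above compose: each loss reads its operands from a ComputationState\r\n# through TensorCachedComputationFunc callables. The batch layout assumed here\r\n# (index 0 = expected, index 1 = actual) is made up for this example.\r\nimport torch\r\n\r\nfrom tha4.shion.base.loss.l1_loss import L1Loss\r\nfrom tha4.shion.base.loss.l2_loss import L2Loss\r\nfrom tha4.shion.base.loss.sum_loss import SumLoss\r\nfrom tha4.shion.core.cached_computation import ComputationState, create_batch_element_func\r\n\r\nif __name__ == \"__main__\":\r\n    expected_func = create_batch_element_func(0)\r\n    actual_func = create_batch_element_func(1)\r\n    # A weighted sum of an L1 and an L2 term; logged values are prefixed with the\r\n    # term names (\"l1_loss\", \"l2_loss\", plus the total under \"loss\").\r\n    loss = SumLoss([\r\n        (\"l1\", L1Loss(expected_func, actual_func, weight=0.5)),\r\n        (\"l2\", L2Loss(expected_func, actual_func, weight=1.0)),\r\n    ])\r\n    batch = [torch.zeros(4, 3), torch.ones(4, 3)]\r\n    state = ComputationState(modules={}, accumulated_modules={}, batch=batch)\r\n    value = loss.compute(state, log_func=lambda name, x: print(name, x))\r\n    print(value.item())  # 0.5 * 1.0 + 1.0 * 1.0 = 1.5\r\n"
  },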
  {
    "path": "src/tha4/shion/base/loss/time_dependently_weighted_loss.py",
    "content": "from typing import Callable, Optional\r\n\r\nfrom torch import Tensor\r\n\r\nfrom tha4.shion.core.cached_computation import ComputationState, CachedComputationFunc\r\nfrom tha4.shion.core.loss import Loss\r\n\r\n\r\nclass TimeDependentlyWeightedLoss(Loss):\r\n    def __init__(self,\r\n                 base_loss: Loss,\r\n                 examples_seen_so_far_func: CachedComputationFunc,\r\n                 weight_func: Callable[[int], float]):\r\n        self.weight_func = weight_func\r\n        self.examples_seen_so_far_func = examples_seen_so_far_func\r\n        self.base_loss = base_loss\r\n\r\n    def compute(self,\r\n                state: ComputationState,\r\n                log_func: Optional[Callable[[str, float], None]] = None) -> Tensor:\r\n        base_value = self.base_loss.compute(state)\r\n        examples_seen_so_far = self.examples_seen_so_far_func(state)\r\n        weight = self.weight_func(examples_seen_so_far)\r\n        loss_value = base_value * weight\r\n\r\n        if log_func is not None:\r\n            log_func(\"loss\", loss_value.item())\r\n\r\n        return loss_value\r\n"
  },
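  {
    "path": "examples/shion_warmup_weighted_loss.py",
    "content": "# Illustrative sketch, not part of the original codebase: a linear warm-up schedule\r\n# for TimeDependentlyWeightedLoss. How examples_seen_so_far reaches the state is\r\n# protocol-specific; pre-seeding it into state.outputs is an assumption made here.\r\nimport torch\r\n\r\nfrom tha4.shion.base.loss.l2_loss import L2Loss\r\nfrom tha4.shion.base.loss.time_dependently_weighted_loss import TimeDependentlyWeightedLoss\r\nfrom tha4.shion.core.cached_computation import ComputationState, create_batch_element_func\r\n\r\n\r\ndef warmup_weight(examples_seen: int, ramp_examples: int = 100_000) -> float:\r\n    # Grows linearly from 0 to 1 over the first ramp_examples examples.\r\n    return min(1.0, examples_seen / ramp_examples)\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    loss = TimeDependentlyWeightedLoss(\r\n        base_loss=L2Loss(create_batch_element_func(0), create_batch_element_func(1)),\r\n        examples_seen_so_far_func=lambda state: state.outputs[\"examples_seen_so_far\"],\r\n        weight_func=warmup_weight)\r\n    batch = [torch.zeros(2, 3), torch.ones(2, 3)]\r\n    state = ComputationState({}, {}, batch, outputs={\"examples_seen_so_far\": 50_000})\r\n    print(loss.compute(state).item())  # base L2 of 1.0 scaled by weight 0.5\r\n"
  },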
  {
    "path": "src/tha4/shion/base/module_accumulators.py",
    "content": "from typing import Optional\r\n\r\nimport torch\r\nfrom torch.nn import Module\r\n\r\nfrom tha4.shion.core.module_accumulator import ModuleAccumulator\r\n\r\n\r\n# Code from https://github.com/rosinality/style-based-gan-pytorch/blob/8437a8bbd106ad4a4691b798ce35d30b5111990b/train.py\r\ndef accumulate_modules(new_module: Module, accumulated_module: Module, beta=0.99):\r\n    with torch.no_grad():\r\n        new_module_params = dict(new_module.named_parameters())\r\n        accumulated_module_params = dict(accumulated_module.named_parameters())\r\n        for key in new_module_params.keys():\r\n            accumulated_module_params[key].mul_(beta).add_(new_module_params[key] * (1 - beta))\r\n\r\n        new_module_buffers = dict(new_module.named_buffers())\r\n        accumulated_module_buffers = dict(accumulated_module.named_buffers())\r\n        for key in new_module_buffers.keys():\r\n            accumulated_module_buffers[key].copy_(new_module_buffers[key])\r\n\r\n\r\nclass DecayAccumulator(ModuleAccumulator):\r\n    def __init__(self, decay: float = 0.999):\r\n        self.decay = decay\r\n\r\n    def accumulate(self, module: Module, output: Module, examples_seen_so_far: Optional[int] = None) -> Module:\r\n        accumulate_modules(module, output, self.decay)\r\n        return output\r\n"
  },
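  {
    "path": "examples/shion_ema_accumulator.py",
    "content": "# Illustrative sketch, not part of the original codebase: keeping an exponential\r\n# moving average (EMA) of a model's parameters with DecayAccumulator. The trainer\r\n# calls accumulate() after every optimizer step; this loop imitates that.\r\nimport copy\r\n\r\nimport torch\r\nfrom torch.nn import Linear\r\n\r\nfrom tha4.shion.base.module_accumulators import DecayAccumulator\r\n\r\nif __name__ == \"__main__\":\r\n    model = Linear(4, 4)\r\n    ema_model = copy.deepcopy(model)  # the accumulated (\"output\") module\r\n    accumulator = DecayAccumulator(decay=0.999)\r\n    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\r\n    for _ in range(10):\r\n        loss = model(torch.randn(8, 4)).pow(2).mean()\r\n        optimizer.zero_grad()\r\n        loss.backward()\r\n        optimizer.step()\r\n        # ema_param <- 0.999 * ema_param + 0.001 * param; buffers are copied verbatim.\r\n        accumulator.accumulate(model, ema_model)\r\n"
  },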
  {
    "path": "src/tha4/shion/base/optimizer_factories.py",
    "content": "from typing import Tuple, Iterable\r\n\r\nfrom torch.nn import Parameter\r\nfrom torch.optim import Optimizer, Adam, AdamW, SparseAdam, RMSprop\r\n\r\nfrom tha4.shion.core.optimizer_factory import OptimizerFactory\r\n\r\n\r\nclass AdamOptimizerFactory(OptimizerFactory):\r\n    def __init__(self, betas: Tuple[float, float] = (0.9, 0.999), epsilon: float = 1e-8, weight_decay: float = 0.0):\r\n        super().__init__()\r\n        self.weight_decay = weight_decay\r\n        self.betas = betas\r\n        self.epsilon = epsilon\r\n\r\n    def create(self, parameters: Iterable[Parameter]) -> Optimizer:\r\n        return Adam(parameters, betas=self.betas, eps=self.epsilon, weight_decay=self.weight_decay)\r\n\r\n\r\nclass AdamWOptimizerFactory(OptimizerFactory):\r\n    def __init__(self, betas: Tuple[float, float] = (0.9, 0.999), epsilon: float = 1e-8, weight_decay: float = 0.01):\r\n        super().__init__()\r\n        self.weight_decay = weight_decay\r\n        self.betas = betas\r\n        self.epsilon = epsilon\r\n\r\n    def create(self, parameters: Iterable[Parameter]) -> Optimizer:\r\n        return AdamW(parameters, betas=self.betas, eps=self.epsilon, weight_decay=self.weight_decay)\r\n\r\n\r\nclass SparseAdamOptimizerFactory(OptimizerFactory):\r\n    def __init__(self, betas: Tuple[float, float] = (0.9, 0.999), epsilon: float = 1e-8):\r\n        super().__init__()\r\n        self.betas = betas\r\n        self.epsilon = epsilon\r\n\r\n    def create(self, parameters: Iterable[Parameter]) -> Optimizer:\r\n        return SparseAdam(list(parameters), betas=self.betas, eps=self.epsilon)\r\n\r\n\r\nclass RMSpropOptimizerFactory(OptimizerFactory):\r\n    def __init__(self):\r\n        super().__init__()\r\n\r\n    def create(self, parameters: Iterable[Parameter]) -> Optimizer:\r\n        return RMSprop(parameters)\r\n"
  },
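  {
    "path": "examples/shion_optimizer_factory.py",
    "content": "# Illustrative sketch, not part of the original codebase: an OptimizerFactory\r\n# bundles optimizer hyperparameters so that the training protocols can create a\r\n# fresh optimizer per module without hard-coding them.\r\nfrom torch.nn import Linear\r\n\r\nfrom tha4.shion.base.optimizer_factories import AdamOptimizerFactory\r\n\r\nif __name__ == \"__main__\":\r\n    factory = AdamOptimizerFactory(betas=(0.5, 0.999), epsilon=1e-8, weight_decay=0.0)\r\n    optimizer = factory.create(Linear(4, 4).parameters())\r\n    print(type(optimizer).__name__)  # Adam\r\n"
  },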
  {
    "path": "src/tha4/shion/base/protocol/single_network_from_batch_input_computation_protocol.py",
    "content": "from typing import Optional, Any, List\r\n\r\nfrom tha4.shion.core.cached_computation import CachedComputationProtocol, ComputationState\r\n\r\nKEY_NETWORK = \"network\"\r\nKEY_NETWORK_OUTPUT = \"network_output\"\r\n\r\n\r\nclass SingleNetworkBatchInputComputationProtocol(CachedComputationProtocol):\r\n    def __init__(self,\r\n                 key_network: str = KEY_NETWORK,\r\n                 key_network_output: str = KEY_NETWORK_OUTPUT,\r\n                 input_index_to_batch_index: Optional[List[int]] = None):\r\n        if input_index_to_batch_index is None:\r\n            input_index_to_batch_index = [0]\r\n\r\n        self.input_index_to_batch_index = input_index_to_batch_index\r\n        self.key_network_output = key_network_output\r\n        self.key_network = key_network\r\n\r\n    def compute_output(self, key: str, state: ComputationState) -> Any:\r\n        if key == self.key_network_output:\r\n            inputs = []\r\n            for batch_index in self.input_index_to_batch_index:\r\n                inputs.append(state.batch[batch_index])\r\n            network = state.modules[self.key_network]\r\n            return network.forward(*inputs)\r\n        else:\r\n            raise RuntimeError(\"Computing output for key \" + key + \" is not supported!\")\r\n"
  },
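  {
    "path": "examples/shion_network_output_loss.py",
    "content": "# Illustrative sketch, not part of the original codebase: wiring the protocol above\r\n# into a loss. The protocol caches the network's output in the ComputationState, so\r\n# several losses could reuse it without recomputation. The batch layout assumed here\r\n# (index 0 = input, index 1 = target) is made up for this example.\r\nimport torch\r\nfrom torch.nn import Linear\r\n\r\nfrom tha4.shion.base.loss.l2_loss import L2Loss\r\nfrom tha4.shion.base.protocol.single_network_from_batch_input_computation_protocol import \\\r\n    KEY_NETWORK, KEY_NETWORK_OUTPUT, SingleNetworkBatchInputComputationProtocol\r\nfrom tha4.shion.core.cached_computation import ComputationState, create_batch_element_func\r\n\r\nif __name__ == \"__main__\":\r\n    protocol = SingleNetworkBatchInputComputationProtocol(input_index_to_batch_index=[0])\r\n    loss = L2Loss(\r\n        expected_func=create_batch_element_func(1),\r\n        actual_func=protocol.get_output_func(KEY_NETWORK_OUTPUT))\r\n    batch = [torch.randn(8, 4), torch.randn(8, 4)]\r\n    state = ComputationState({KEY_NETWORK: Linear(4, 4)}, {}, batch)\r\n    print(loss.compute(state).item())\r\n"
  },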
  {
    "path": "src/tha4/shion/base/training/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/shion/base/training/single_network.py",
    "content": "import time\r\nfrom typing import List, Dict, Callable, Any, Optional\r\n\r\nimport torch\r\nfrom torch.nn import Module\r\nfrom torch.nn.utils import clip_grad_norm_\r\nfrom torch.optim.optimizer import Optimizer\r\n\r\nfrom tha4.shion.core.cached_computation import ComputationState\r\nfrom tha4.shion.core.loss import Loss\r\nfrom tha4.shion.core.optimizer_factory import OptimizerFactory\r\nfrom tha4.shion.core.training.training_protocol import TrainingProtocol\r\nfrom tha4.shion.core.training.validation_protocol import ValidationProtocol\r\n\r\nKEY_NETWORK = \"network\"\r\n\r\n\r\nclass SingleNetworkTrainingProtocol(TrainingProtocol):\r\n    def __init__(self,\r\n                 check_point_examples: List[int],\r\n                 batch_size: int,\r\n                 learning_rate: Callable[[int], Dict[str, float]],\r\n                 optimizer_factories: Dict[str, OptimizerFactory],\r\n                 module_key: str = KEY_NETWORK,\r\n                 random_seed: int = 39549059840,\r\n                 max_grad_norm: Optional[float] = None):\r\n        super().__init__()\r\n        self.max_grad_norm = max_grad_norm\r\n        self.module_key = module_key\r\n        self.optimizer_factories = optimizer_factories\r\n        self.learning_rate = learning_rate\r\n        self.batch_size = batch_size\r\n        self.random_seed = random_seed\r\n        self.check_point_examples = check_point_examples\r\n\r\n    def get_optimizer_factories(self) -> Dict[str, OptimizerFactory]:\r\n        return self.optimizer_factories\r\n\r\n    def get_checkpoint_examples(self) -> List[int]:\r\n        return self.check_point_examples\r\n\r\n    def get_random_seed(self) -> int:\r\n        return self.random_seed\r\n\r\n    def get_batch_size(self) -> int:\r\n        return self.batch_size\r\n\r\n    def get_learning_rate(self, examples_seen_so_far: int) -> Dict[str, float]:\r\n        return self.learning_rate(examples_seen_so_far)\r\n\r\n    def run_training_iteration(\r\n            self,\r\n            batch: Any,\r\n            examples_seen_so_far: int,\r\n            modules: Dict[str, Module],\r\n            accumulated_modules: Dict[str, Module],\r\n            optimizers: Dict[str, Optimizer],\r\n            losses: Dict[str, Loss],\r\n            create_log_func: Optional[Callable[[str, int], Callable[[str, float], None]]],\r\n            device: torch.device):\r\n        module = modules[self.module_key]\r\n        module.train(True)\r\n        optimizers[self.module_key].zero_grad(set_to_none=True)\r\n        if create_log_func is not None:\r\n            log_func = create_log_func(\"training_\" + self.module_key, examples_seen_so_far)\r\n        else:\r\n            log_func = None\r\n        losses[self.module_key].compute(\r\n            ComputationState(modules, accumulated_modules, batch),\r\n            log_func).backward()\r\n        if self.max_grad_norm is not None:\r\n            clip_grad_norm_(module.parameters(), self.max_grad_norm)\r\n        optimizers[self.module_key].step()\r\n\r\n\r\nclass SingleNetworkValidationProtocol(ValidationProtocol):\r\n    def __init__(\r\n            self,\r\n            example_per_validation_iteration: int,\r\n            batch_size: int,\r\n            module_key: str = KEY_NETWORK):\r\n        super().__init__()\r\n        self.module_key = module_key\r\n        self.batch_size = batch_size\r\n        self.example_per_validation_iteration = example_per_validation_iteration\r\n\r\n    def get_batch_size(self, ) -> int:\r\n      
  return self.batch_size\r\n\r\n    def get_examples_per_validation_iteration(self) -> int:\r\n        return self.example_per_validation_iteration\r\n\r\n    def run_validation_iteration(\r\n            self,\r\n            batch: Any,\r\n            examples_seen_so_far: int,\r\n            modules: Dict[str, Module],\r\n            accumulated_modules: Dict[str, Module],\r\n            losses: Dict[str, Loss],\r\n            create_log_func: Callable[[str, int], Callable[[str, float], None]],\r\n            device: torch.device):\r\n        module = modules[self.module_key]\r\n        module.train(False)\r\n        with torch.no_grad():\r\n            log_func = create_log_func(\"validation_\" + self.module_key, examples_seen_so_far)\r\n            losses[self.module_key].compute(\r\n                ComputationState(modules, accumulated_modules, batch),\r\n                log_func)\r\n"
  },
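  {
    "path": "examples/shion_single_network_protocols.py",
    "content": "# Illustrative sketch, not part of the original codebase: configuring the training\r\n# and validation protocols above. The checkpoint schedule and learning-rate decay\r\n# are made-up values for this example.\r\nfrom typing import Dict\r\n\r\nfrom tha4.shion.base.optimizer_factories import AdamOptimizerFactory\r\nfrom tha4.shion.base.training.single_network import (\r\n    KEY_NETWORK, SingleNetworkTrainingProtocol, SingleNetworkValidationProtocol)\r\n\r\n\r\ndef learning_rate(examples_seen_so_far: int) -> Dict[str, float]:\r\n    # Halve the learning rate after the first million examples.\r\n    lr = 1e-4 if examples_seen_so_far < 1_000_000 else 5e-5\r\n    return {KEY_NETWORK: lr}\r\n\r\n\r\ntraining_protocol = SingleNetworkTrainingProtocol(\r\n    check_point_examples=[500_000, 1_000_000, 2_000_000],\r\n    batch_size=8,\r\n    learning_rate=learning_rate,\r\n    optimizer_factories={KEY_NETWORK: AdamOptimizerFactory()},\r\n    max_grad_norm=1.0)\r\n\r\nvalidation_protocol = SingleNetworkValidationProtocol(\r\n    example_per_validation_iteration=100_000,\r\n    batch_size=8)\r\n"
  },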
  {
    "path": "src/tha4/shion/base/training/single_network_with_minibatch.py",
    "content": "import time\r\nfrom typing import List, Dict, Callable, Any, Optional\r\n\r\nimport torch\r\nfrom torch.nn import Module\r\nfrom torch.nn.utils import clip_grad_norm_\r\nfrom torch.optim.optimizer import Optimizer\r\n\r\nfrom tha4.shion.core.cached_computation import ComputationState\r\nfrom tha4.shion.core.loss import Loss\r\nfrom tha4.shion.core.optimizer_factory import OptimizerFactory\r\nfrom tha4.shion.core.training.training_protocol import TrainingProtocol\r\nfrom tha4.shion.core.training.validation_protocol import ValidationProtocol\r\n\r\nKEY_NETWORK = \"network\"\r\n\r\n\r\nclass SingleNetworkWithMinibatchTrainingProtocol(TrainingProtocol):\r\n    def __init__(self,\r\n                 check_point_examples: List[int],\r\n                 batch_size: int,\r\n                 minibatch_size: int,\r\n                 learning_rate: Callable[[int], Dict[str, float]],\r\n                 optimizer_factories: Dict[str, OptimizerFactory],\r\n                 module_key: str = KEY_NETWORK,\r\n                 random_seed: int = 39549059840,\r\n                 max_grad_norm: Optional[float] = None):\r\n        super().__init__()\r\n        assert batch_size % minibatch_size == 0\r\n        self.minibatch_size = minibatch_size\r\n        self.max_grad_norm = max_grad_norm\r\n        self.module_key = module_key\r\n        self.optimizer_factories = optimizer_factories\r\n        self.learning_rate = learning_rate\r\n        self.batch_size = batch_size\r\n        self.random_seed = random_seed\r\n        self.check_point_examples = check_point_examples\r\n\r\n    def get_optimizer_factories(self) -> Dict[str, OptimizerFactory]:\r\n        return self.optimizer_factories\r\n\r\n    def get_checkpoint_examples(self) -> List[int]:\r\n        return self.check_point_examples\r\n\r\n    def get_random_seed(self) -> int:\r\n        return self.random_seed\r\n\r\n    def get_batch_size(self) -> int:\r\n        return self.batch_size\r\n\r\n    def get_learning_rate(self, examples_seen_so_far: int) -> Dict[str, float]:\r\n        return self.learning_rate(examples_seen_so_far)\r\n\r\n    def run_training_iteration(\r\n            self,\r\n            batch: Any,\r\n            examples_seen_so_far: int,\r\n            modules: Dict[str, Module],\r\n            accumulated_modules: Dict[str, Module],\r\n            optimizers: Dict[str, Optimizer],\r\n            losses: Dict[str, Loss],\r\n            create_log_func: Optional[Callable[[str, int], Callable[[str, float], None]]],\r\n            device: torch.device):\r\n        module = modules[self.module_key]\r\n        module.train(True)\r\n        optimizers[self.module_key].zero_grad(set_to_none=True)\r\n        if create_log_func is not None:\r\n            log_func = create_log_func(\"training_\" + self.module_key, examples_seen_so_far)\r\n        else:\r\n            log_func = None\r\n\r\n        num_minibatch = self.batch_size // self.minibatch_size\r\n        for minibatch_index in range(num_minibatch):\r\n            minibatch = []\r\n            for item in batch:\r\n                minibatch.append(\r\n                    item[minibatch_index * self.minibatch_size:(minibatch_index + 1) * self.minibatch_size])\r\n            loss = losses[self.module_key].compute(\r\n                ComputationState(modules, accumulated_modules, minibatch),\r\n                log_func if minibatch_index == 0 else None)\r\n            loss = loss / num_minibatch\r\n            loss.backward()\r\n\r\n        if self.max_grad_norm is not 
None:\r\n            clip_grad_norm_(module.parameters(), self.max_grad_norm)\r\n\r\n        optimizers[self.module_key].step()\r\n"
  },
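  {
    "path": "examples/shion_minibatch_protocol.py",
    "content": "# Illustrative sketch, not part of the original codebase: the minibatch variant\r\n# accumulates gradients, so a batch of 64 processed as minibatches of 16 performs\r\n# 4 backward passes (each scaled by 1/4) per optimizer step, trading speed for a\r\n# smaller memory footprint. The numbers below are made up for this example.\r\nfrom tha4.shion.base.optimizer_factories import AdamOptimizerFactory\r\nfrom tha4.shion.base.training.single_network_with_minibatch import (\r\n    KEY_NETWORK, SingleNetworkWithMinibatchTrainingProtocol)\r\n\r\nprotocol = SingleNetworkWithMinibatchTrainingProtocol(\r\n    check_point_examples=[1_000_000],\r\n    batch_size=64,\r\n    minibatch_size=16,\r\n    learning_rate=lambda examples_seen_so_far: {KEY_NETWORK: 1e-4},\r\n    optimizer_factories={KEY_NETWORK: AdamOptimizerFactory()})\r\n"
  },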
  {
    "path": "src/tha4/shion/base/training/two_networks_training_protocol.py",
    "content": "from typing import List, Dict, Callable, Any, Optional\r\n\r\nimport torch\r\nfrom torch.nn import Module\r\nfrom torch.nn.utils import clip_grad_norm_\r\nfrom torch.optim.optimizer import Optimizer\r\n\r\nfrom tha4.shion.core.cached_computation import ComputationState\r\nfrom tha4.shion.core.loss import Loss\r\nfrom tha4.shion.core.optimizer_factory import OptimizerFactory\r\nfrom tha4.shion.core.training.training_protocol import TrainingProtocol\r\n\r\n\r\nclass TwoNetworksWithMinibatchTrainingProtocol(TrainingProtocol):\r\n    def __init__(self,\r\n                 check_point_examples: List[int],\r\n                 batch_size: int,\r\n                 learning_rate: Callable[[int], Dict[str, float]],\r\n                 optimizer_factories: Dict[str, OptimizerFactory],\r\n                 key_network_0: str,\r\n                 key_network_1: str,\r\n                 train_network_0: bool = False,\r\n                 random_seed: int = 39549059840,\r\n                 max_grad_norm: Optional[float] = None,\r\n                 minibatch_size: Optional[int] = None):\r\n        super().__init__()\r\n        if minibatch_size is None:\r\n            minibatch_size = batch_size\r\n        assert batch_size % minibatch_size == 0\r\n        self.train_network_0 = train_network_0\r\n        self.key_network_1 = key_network_1\r\n        self.key_network_0 = key_network_0\r\n        self.minibatch_size = minibatch_size\r\n        self.max_grad_norm = max_grad_norm\r\n        self.optimizer_factories = optimizer_factories\r\n        self.learning_rate = learning_rate\r\n        self.batch_size = batch_size\r\n        self.random_seed = random_seed\r\n        self.check_point_examples = check_point_examples\r\n\r\n    def get_optimizer_factories(self) -> Dict[str, OptimizerFactory]:\r\n        return self.optimizer_factories\r\n\r\n    def get_checkpoint_examples(self) -> List[int]:\r\n        return self.check_point_examples\r\n\r\n    def get_random_seed(self) -> int:\r\n        return self.random_seed\r\n\r\n    def get_batch_size(self) -> int:\r\n        return self.batch_size\r\n\r\n    def get_learning_rate(self, examples_seen_so_far: int) -> Dict[str, float]:\r\n        return self.learning_rate(examples_seen_so_far)\r\n\r\n    def run_training_iteration(\r\n            self,\r\n            batch: Any,\r\n            examples_seen_so_far: int,\r\n            modules: Dict[str, Module],\r\n            accumulated_modules: Dict[str, Module],\r\n            optimizers: Dict[str, Optimizer],\r\n            losses: Dict[str, Loss],\r\n            create_log_func: Optional[Callable[[str, int], Callable[[str, float], None]]],\r\n            device: torch.device):\r\n        network_0 = modules[self.key_network_0]\r\n        network_0.train(self.train_network_0)\r\n\r\n        network_1 = modules[self.key_network_1]\r\n        network_1.train(True)\r\n\r\n        if self.train_network_0:\r\n            optimizers[self.key_network_0].zero_grad(set_to_none=True)\r\n        optimizers[self.key_network_1].zero_grad(set_to_none=True)\r\n\r\n        if create_log_func is not None:\r\n            network_0_log_func = create_log_func(\"training_\" + self.key_network_0, examples_seen_so_far)\r\n            network_1_log_func = create_log_func(\"training_\" + self.key_network_1, examples_seen_so_far)\r\n        else:\r\n            network_0_log_func = None\r\n            network_1_log_func = None\r\n\r\n        num_minibatch = self.batch_size // self.minibatch_size\r\n        for 
minibatch_index in range(num_minibatch):\r\n            minibatch = []\r\n            for item in batch:\r\n                minibatch.append(\r\n                    item[minibatch_index * self.minibatch_size:(minibatch_index + 1) * self.minibatch_size])\r\n            loss = losses[self.key_network_1].compute(\r\n                ComputationState(modules, accumulated_modules, minibatch),\r\n                network_1_log_func if minibatch_index == 0 else None)\r\n            if self.train_network_0 and self.key_network_0 in losses:\r\n                loss = loss + losses[self.key_network_0].compute(\r\n                    ComputationState(modules, accumulated_modules, minibatch),\r\n                    network_0_log_func if minibatch_index == 0 else None)\r\n            loss = loss / num_minibatch\r\n            loss.backward()\r\n\r\n        if self.max_grad_norm is not None:\r\n            clip_grad_norm_(network_1.parameters(), self.max_grad_norm)\r\n            if self.train_network_0:\r\n                clip_grad_norm_(network_0.parameters(), self.max_grad_norm)\r\n\r\n        optimizers[self.key_network_1].step()\r\n        if self.train_network_0:\r\n            optimizers[self.key_network_0].step()\r\n"
  },
  {
    "path": "src/tha4/shion/core/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/shion/core/cached_computation.py",
    "content": "from abc import ABC, abstractmethod\r\nfrom typing import Callable, Dict, Any, Optional\r\n\r\nimport torch\r\nfrom torch import Tensor\r\nfrom torch.nn import Module\r\n\r\n\r\nclass ComputationState:\r\n    def __init__(self,\r\n                 modules: Dict[str, Module],\r\n                 accumulated_modules: Dict[str, Module],\r\n                 batch: Any,\r\n                 outputs: Optional[Dict[str, Any]] = None):\r\n        if outputs is None:\r\n            outputs = {}\r\n        self.outputs = outputs\r\n        self.batch = batch\r\n        self.accumulated_modules = accumulated_modules\r\n        self.modules = modules\r\n\r\n\r\nCachedComputationFunc = Callable[[ComputationState], Any]\r\nTensorCachedComputationFunc = Callable[[ComputationState], Tensor]\r\n\r\n\r\ndef create_get_item_func(func: CachedComputationFunc, index):\r\n    def _f(state: ComputationState):\r\n        output = func(state)\r\n        return output[index]\r\n\r\n    return _f\r\n\r\n\r\ndef create_batch_element_func(index: int) -> TensorCachedComputationFunc:\r\n    def _f(state: ComputationState) -> Tensor:\r\n        return state.batch[index]\r\n\r\n    return _f\r\n\r\n\r\nclass CachedComputationProtocol(ABC):\r\n    def get_output(self, key: str, state: ComputationState) -> Any:\r\n        if key in state.outputs:\r\n            return state.outputs[key]\r\n        else:\r\n            output = self.compute_output(key, state)\r\n            state.outputs[key] = output\r\n            return state.outputs[key]\r\n\r\n    @abstractmethod\r\n    def compute_output(self, key: str, state: ComputationState) -> Any:\r\n        pass\r\n\r\n    def get_output_func(self, key: str) -> CachedComputationFunc:\r\n        def func(state: ComputationState):\r\n            return self.get_output(key, state)\r\n\r\n        return func\r\n\r\n\r\nComposableCachedComputationStep = Callable[[CachedComputationProtocol, ComputationState], Any]\r\n\r\n\r\nclass ComposableCachedComputationProtocol(CachedComputationProtocol):\r\n    def __init__(self, computation_steps: Optional[Dict[str, ComposableCachedComputationStep]] = None):\r\n        if computation_steps is None:\r\n            computation_steps = {}\r\n        self.computation_steps = computation_steps\r\n\r\n    def compute_output(self, key: str, state: ComputationState) -> Any:\r\n        if key in self.computation_steps:\r\n            return self.computation_steps[key](self, state)\r\n        else:\r\n            raise RuntimeError(\"Computing output for key \" + key + \" is not supported!\")\r\n\r\n\r\ndef batch_indexing_func(index: int):\r\n    def _f(protocol: CachedComputationProtocol, state: ComputationState):\r\n        return state.batch[index]\r\n\r\n    return _f\r\n\r\n\r\ndef proxy_func(key: str):\r\n    def _f(protocol: CachedComputationProtocol, state: ComputationState):\r\n        return protocol.get_output(key, state)\r\n\r\n    return _f\r\n\r\n\r\ndef output_array_indexing_func(key: str, index: int):\r\n    def _f(protocol: CachedComputationProtocol, state: ComputationState):\r\n        return protocol.get_output(key, state)[index]\r\n\r\n    return _f\r\n\r\n\r\ndef add_step(step_dict: Dict[str, ComposableCachedComputationStep], name: str):\r\n    def _f(func):\r\n        step_dict[name] = func\r\n        return func\r\n\r\n    return _f\r\n\r\n\r\ndef zeros_like_func(key: str):\r\n    def _f(protocol: CachedComputationProtocol, state: ComputationState):\r\n        prototype = protocol.get_output(key, state)\r\n        
return torch.zeros_like(prototype)\r\n\r\n    return _f\r\n"
  },
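  {
    "path": "examples/shion_composable_protocol.py",
    "content": "# Illustrative sketch, not part of the original codebase: building a\r\n# ComposableCachedComputationProtocol with the add_step decorator. Each step is\r\n# computed at most once per ComputationState because get_output caches results.\r\nimport torch\r\n\r\nfrom tha4.shion.core.cached_computation import (\r\n    ComposableCachedComputationProtocol, ComputationState, add_step, batch_indexing_func)\r\n\r\nsteps = {\r\n    \"input\": batch_indexing_func(0),\r\n}\r\n\r\n\r\n@add_step(steps, \"doubled\")\r\ndef compute_doubled(protocol, state):\r\n    return 2 * protocol.get_output(\"input\", state)\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    protocol = ComposableCachedComputationProtocol(steps)\r\n    state = ComputationState({}, {}, [torch.ones(2)])\r\n    print(protocol.get_output(\"doubled\", state))  # tensor([2., 2.])\r\n"
  },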
  {
    "path": "src/tha4/shion/core/load_save.py",
    "content": "import os\r\n\r\nimport torch\r\n\r\n\r\ndef torch_save(content, file_name):\r\n    os.makedirs(os.path.dirname(file_name), exist_ok=True)\r\n    with open(file_name, 'wb') as f:\r\n        torch.save(content, f)\r\n\r\n\r\ndef torch_load(file_name):\r\n    with open(file_name, 'rb') as f:\r\n        return torch.load(f, map_location=lambda storage, loc: storage)\r\n"
  },
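  {
    "path": "examples/shion_load_save.py",
    "content": "# Illustrative sketch, not part of the original codebase: torch_save creates missing\r\n# parent directories automatically, and torch_load maps tensors back onto CPU storage\r\n# regardless of the device they were saved from. The file path is hypothetical.\r\nimport torch\r\n\r\nfrom tha4.shion.core.load_save import torch_load, torch_save\r\n\r\nif __name__ == \"__main__\":\r\n    torch_save({\"weight\": torch.ones(3)}, \"data/_scratch/example.pt\")\r\n    print(torch_load(\"data/_scratch/example.pt\")[\"weight\"])\r\n"
  },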
  {
    "path": "src/tha4/shion/core/loss.py",
    "content": "from abc import ABC, abstractmethod\r\nfrom typing import Callable, Optional\r\n\r\nfrom torch import Tensor\r\n\r\nfrom tha4.shion.core.cached_computation import ComputationState\r\n\r\n\r\nclass Loss(ABC):\r\n    @abstractmethod\r\n    def compute(\r\n            self,\r\n            state: ComputationState,\r\n            log_func: Optional[Callable[[str, float], None]] = None) -> Tensor:\r\n        pass\r\n"
  },
  {
    "path": "src/tha4/shion/core/module_accumulator.py",
    "content": "from abc import ABC, abstractmethod\r\nfrom typing import Optional\r\n\r\nfrom torch.nn import Module\r\n\r\n\r\nclass ModuleAccumulator(ABC):\r\n    @abstractmethod\r\n    def accumulate(self, module: Module, output: Module, examples_seen_so_far: Optional[int] = None) -> Module:\r\n        pass\r\n"
  },
  {
    "path": "src/tha4/shion/core/module_factory.py",
    "content": "from abc import ABC, abstractmethod\r\n\r\nfrom torch.nn import Module\r\n\r\n\r\nclass ModuleFactory(ABC):\r\n    @abstractmethod\r\n    def create(self) -> Module:\r\n        pass"
  },
  {
    "path": "src/tha4/shion/core/optimizer_factory.py",
    "content": "from abc import ABC, abstractmethod\r\nfrom typing import Iterable\r\n\r\nfrom torch.nn import Parameter\r\n\r\n\r\nclass OptimizerFactory(ABC):\r\n    @abstractmethod\r\n    def create(self, parameters: Iterable[Parameter]):\r\n        pass\r\n"
  },
  {
    "path": "src/tha4/shion/core/training/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/shion/core/training/distrib/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/shion/core/training/distrib/device_mapper.py",
    "content": "from typing import Dict\r\n\r\nimport torch\r\n\r\n\r\nclass SimpleCudaDeviceMapper:\r\n    def __call__(self, rank, local_rank):\r\n        return torch.device(\"cuda\", local_rank)\r\n\r\n\r\nclass UserSpecifiedLocalRankToDeviceMapper:\r\n    def __init__(self, device_map: Dict[int, torch.device]):\r\n        self.device_map = device_map\r\n\r\n    def __call__(self, rank, local_rank):\r\n        assert local_rank in self.device_map\r\n        return self.device_map[local_rank]\r\n"
  },
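  {
    "path": "examples/shion_device_mappers.py",
    "content": "# Illustrative sketch, not part of the original codebase: device mappers translate\r\n# (rank, local_rank) into a torch.device. The default maps local rank i to cuda:i;\r\n# the user-specified variant allows an arbitrary assignment, e.g. using a node's\r\n# GPUs out of order.\r\nimport torch\r\n\r\nfrom tha4.shion.core.training.distrib.device_mapper import (\r\n    SimpleCudaDeviceMapper, UserSpecifiedLocalRankToDeviceMapper)\r\n\r\nif __name__ == \"__main__\":\r\n    default_mapper = SimpleCudaDeviceMapper()\r\n    print(default_mapper(rank=0, local_rank=0))  # cuda:0\r\n    custom_mapper = UserSpecifiedLocalRankToDeviceMapper({\r\n        0: torch.device(\"cuda\", 1),\r\n        1: torch.device(\"cuda\", 0),\r\n    })\r\n    print(custom_mapper(rank=0, local_rank=0))  # cuda:1\r\n"
  },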
  {
    "path": "src/tha4/shion/core/training/distrib/distributed_trainer.py",
    "content": "import argparse\r\nimport logging\r\nimport os.path\r\nimport time\r\nfrom datetime import datetime\r\nfrom typing import Dict, Optional, Callable, Any\r\n\r\nimport torch\r\nimport torch.distributed\r\nfrom torch.nn.parallel import DistributedDataParallel\r\nfrom torch.utils.data import Dataset, DataLoader, DistributedSampler\r\nfrom torch.utils.tensorboard import SummaryWriter\r\n\r\nfrom tha4.shion.core.load_save import torch_save, torch_load\r\nfrom tha4.shion.core.loss import Loss\r\nfrom tha4.shion.core.module_accumulator import ModuleAccumulator\r\nfrom tha4.shion.core.module_factory import ModuleFactory\r\nfrom tha4.shion.core.training.distrib.device_mapper import SimpleCudaDeviceMapper\r\nfrom tha4.shion.core.training.distrib.distributed_training_states import DistributedTrainingState\r\nfrom tha4.shion.core.training.sample_output_protocol import SampleOutputProtocol\r\nfrom tha4.shion.core.training.training_protocol import TrainingProtocol\r\nfrom tha4.shion.core.training.util import set_learning_rate, create_log_func, get_least_greater_multiple\r\nfrom tha4.shion.core.training.validation_protocol import ValidationProtocol\r\n\r\nKEY_CHECKPOINT = 'checkpoint'\r\nKEY_SNAPSHOT = 'snapshot'\r\nKEY_VALIDATION = 'validation'\r\nKEY_SAMPLE_OUTPUT = 'sample_output'\r\n\r\n\r\nclass DistributedTrainer:\r\n    def __init__(self,\r\n                 prefix: str,\r\n                 module_factories: Dict[str, ModuleFactory],\r\n                 accumulators: Dict[str, ModuleAccumulator],\r\n                 losses: Dict[str, Loss],\r\n                 training_dataset: Dataset,\r\n                 validation_dataset: Optional[Dataset],\r\n                 training_protocol: TrainingProtocol,\r\n                 validation_protocol: Optional[ValidationProtocol],\r\n                 sample_output_protocol: Optional[SampleOutputProtocol],\r\n                 pretrained_module_file_names: Dict[str, str],\r\n                 example_per_snapshot: int,\r\n                 num_data_loader_workers: int = 8,\r\n                 distrib_backend: str = 'gloo'):\r\n        self.distrib_backend = distrib_backend\r\n        self.num_data_loader_workers = num_data_loader_workers\r\n        self.accumulators = accumulators\r\n        self.sample_output_protocol = sample_output_protocol\r\n        self.example_per_snapshot = example_per_snapshot\r\n        self.pretrained_module_file_names = pretrained_module_file_names\r\n        self.losses = losses\r\n        self.validation_protocol = validation_protocol\r\n        self.training_protocol = training_protocol\r\n        self.module_factories = module_factories\r\n        self.prefix = prefix\r\n        self.training_dataset = training_dataset\r\n        self.validation_dataset = validation_dataset\r\n\r\n        self.checkpoint_examples = self.training_protocol.get_checkpoint_examples()\r\n        assert len(self.checkpoint_examples) >= 1\r\n        assert self.checkpoint_examples[0] > 0\r\n        self.checkpoint_examples = [0] + self.checkpoint_examples\r\n\r\n        self.module_names = self.module_factories.keys()\r\n        assert len(self.module_names) > 0\r\n\r\n        self.training_data_loader = None\r\n        self.training_data_loader_iter = None\r\n        self.training_data_loader_batch_size = None\r\n        self.training_data_sampler = None\r\n\r\n        self.validation_data_loader = None\r\n        self.validation_data_loader_iter = None\r\n        self.validation_data_loader_batch_size = None\r\n\r\n        
self.sample_output_data = None\r\n        self.summary_writer = None\r\n        self.log_dir = None\r\n        self.training_state = None\r\n\r\n    def get_sample_output_data_file_name(self):\r\n        return self.prefix + \"/sample_output_data.pt\"\r\n\r\n    def save_sample_output_data(self, rank: int, device: torch.device):\r\n        if rank != 0:\r\n            return\r\n        if os.path.exists(self.get_sample_output_data_file_name()):\r\n            return\r\n        if self.sample_output_protocol is not None:\r\n            torch.manual_seed(self.sample_output_protocol.get_random_seed())\r\n            sample_output_data = self.sample_output_protocol.get_sample_output_data(self.validation_dataset, device)\r\n            torch_save(sample_output_data, self.get_sample_output_data_file_name())\r\n        else:\r\n            torch_save({}, self.get_sample_output_data_file_name())\r\n\r\n    def load_sample_output_data(self, rank: int, device: torch.device):\r\n        if rank != 0:\r\n            return None\r\n        else:\r\n            self.save_sample_output_data(rank, device)\r\n            return torch_load(self.get_sample_output_data_file_name())\r\n\r\n    def get_snapshot_prefix(self) -> str:\r\n        return self.prefix + \"/snapshot\"\r\n\r\n    def can_load_training_state(self, prefix: str, world_size: int) -> bool:\r\n        return DistributedTrainingState.can_load(\r\n            prefix,\r\n            self.module_factories,\r\n            self.accumulators,\r\n            self.training_protocol.get_optimizer_factories(),\r\n            world_size)\r\n\r\n    def load_training_state(self, prefix, rank: int, local_rank: int, device: torch.device) -> DistributedTrainingState:\r\n        return DistributedTrainingState.load(\r\n            prefix,\r\n            self.module_factories,\r\n            self.accumulators,\r\n            self.training_protocol.get_optimizer_factories(),\r\n            rank,\r\n            local_rank,\r\n            device)\r\n\r\n    @staticmethod\r\n    def checkpoint_prefix(prefix: str, checkpoint_index: int) -> str:\r\n        return \"%s/checkpoint/%04d\" % (prefix, checkpoint_index)\r\n\r\n    def get_checkpoint_prefix(self, checkpoint_index) -> str:\r\n        return DistributedTrainer.checkpoint_prefix(self.prefix, checkpoint_index)\r\n\r\n    def get_initial_training_state(self, rank: int, local_rank: int, device: torch.device) -> DistributedTrainingState:\r\n        training_state = DistributedTrainingState.new(\r\n            self.module_factories,\r\n            self.accumulators,\r\n            self.training_protocol.get_optimizer_factories(),\r\n            self.training_protocol.get_random_seed(),\r\n            rank,\r\n            local_rank,\r\n            device,\r\n            self.pretrained_module_file_names)\r\n        logging.info(\"Created a new initial training state.\")\r\n        return training_state\r\n\r\n    def load_previous_training_state(self,\r\n                                     target_checkpoint_examples: int,\r\n                                     world_size: int,\r\n                                     rank: int,\r\n                                     local_rank: int,\r\n                                     device: torch.device) -> DistributedTrainingState:\r\n        if self.can_load_training_state(self.get_snapshot_prefix(), world_size):\r\n            examples_seen_so_far = DistributedTrainingState.get_examples_seen_so_far(self.get_snapshot_prefix())\r\n            diff = examples_seen_so_far - 
target_checkpoint_examples\r\n            if diff < self.training_protocol.get_batch_size():\r\n                return self.load_training_state(self.get_snapshot_prefix(), rank, local_rank, device)\r\n        num_checkpoints = len(self.checkpoint_examples)\r\n        for checkpoint_index in range(num_checkpoints - 1, -1, -1):\r\n            if self.can_load_training_state(self.get_checkpoint_prefix(checkpoint_index), world_size):\r\n                examples_seen_so_far = DistributedTrainingState.get_examples_seen_so_far(\r\n                    self.get_checkpoint_prefix(checkpoint_index))\r\n                diff = examples_seen_so_far - target_checkpoint_examples\r\n                if diff < self.training_protocol.get_batch_size():\r\n                    return self.load_training_state(\r\n                        self.get_checkpoint_prefix(checkpoint_index), rank, local_rank, device)\r\n\r\n        training_state = self.get_initial_training_state(rank, local_rank, device)\r\n        training_state.save(self.get_checkpoint_prefix(0), rank, lambda: self.barrier(local_rank))\r\n        training_state = self.load_training_state(self.get_checkpoint_prefix(0), rank, local_rank, device)\r\n        return training_state\r\n\r\n    def get_log_dir(self):\r\n        if self.log_dir is None:\r\n            now = datetime.now()\r\n            self.log_dir = self.prefix + \"/log/\" + now.strftime(\"%Y_%m_%d__%H_%M_%S\")\r\n        return self.log_dir\r\n\r\n    def get_summary_writer(self, rank: int) -> Optional[SummaryWriter]:\r\n        if rank != 0:\r\n            return None\r\n        if self.summary_writer is None:\r\n            self.summary_writer = SummaryWriter(log_dir=self.get_log_dir())\r\n        return self.summary_writer\r\n\r\n    def get_effective_training_epoch_size(self, world_size: int):\r\n        batch_size = self.training_protocol.get_batch_size()\r\n        N = len(self.training_dataset)\r\n        N = (N // world_size) * world_size\r\n        N = (N // batch_size) * batch_size\r\n        return N\r\n\r\n    def get_training_epoch_index(self, examples_seen_so_far: int, world_size: int):\r\n        epoch_size = self.get_effective_training_epoch_size(world_size)\r\n        batch_size = self.training_protocol.get_batch_size()\r\n        return (examples_seen_so_far + batch_size * world_size) // epoch_size\r\n\r\n    def get_next_training_batch(self, examples_seen_so_far: int, world_size: int, device: torch.device):\r\n        batch_size = self.training_protocol.get_batch_size()\r\n        dataset = self.training_dataset\r\n        if self.training_data_loader is None:\r\n            self.training_data_sampler = DistributedSampler(\r\n                dataset,\r\n                shuffle=True,\r\n                drop_last=True)\r\n            self.training_data_loader = DataLoader(\r\n                dataset,\r\n                batch_size=batch_size,\r\n                sampler=self.training_data_sampler,\r\n                shuffle=False,\r\n                num_workers=self.num_data_loader_workers,\r\n                drop_last=True)\r\n        if self.training_data_loader_iter is None:\r\n            epoch_index = self.get_training_epoch_index(examples_seen_so_far, world_size)\r\n            logging.info(f\"Started a new epoch: index = {epoch_index}, examples_seen_so_far = {examples_seen_so_far}\")\r\n            self.training_data_sampler.set_epoch(epoch_index)\r\n            self.training_data_loader_iter = iter(self.training_data_loader)\r\n        try:\r\n            batch = 
next(self.training_data_loader_iter)\r\n        except StopIteration:\r\n            epoch_index = self.get_training_epoch_index(examples_seen_so_far, world_size)\r\n            logging.info(f\"Started a new epoch: index = {epoch_index}, examples_seen_so_far = {examples_seen_so_far}\")\r\n            self.training_data_sampler.set_epoch(epoch_index)\r\n            self.training_data_loader_iter = iter(self.training_data_loader)\r\n            batch = next(self.training_data_loader_iter)\r\n        return [x.to(device) for x in batch]\r\n\r\n    def get_next_checkpoint_num_examples(self, examples_seen_so_far) -> int:\r\n        next_index = next(\r\n            (i for i in range(len(self.checkpoint_examples)) if self.checkpoint_examples[i] > examples_seen_so_far),\r\n            -1)\r\n        return self.checkpoint_examples[next_index]\r\n\r\n    def get_next_snapshot_num_examples(self, examples_seen_so_far) -> int:\r\n        return get_least_greater_multiple(examples_seen_so_far, self.example_per_snapshot)\r\n\r\n    def get_next_validation_num_examples(self, examples_seen_so_far) -> int:\r\n        if self.validation_protocol is None:\r\n            return -1\r\n        return get_least_greater_multiple(examples_seen_so_far,\r\n                                          self.validation_protocol.get_examples_per_validation_iteration())\r\n\r\n    def get_next_sample_output_num_examples(self, examples_seen_so_far) -> int:\r\n        if self.sample_output_protocol is None:\r\n            return -1\r\n        return get_least_greater_multiple(examples_seen_so_far,\r\n                                          self.sample_output_protocol.get_examples_per_sample_output())\r\n\r\n    def get_next_num_examples(self, examples_seen_so_far) -> Dict[str, int]:\r\n        return {\r\n            KEY_CHECKPOINT: self.get_next_checkpoint_num_examples(examples_seen_so_far),\r\n            KEY_SNAPSHOT: self.get_next_snapshot_num_examples(examples_seen_so_far),\r\n            KEY_VALIDATION: self.get_next_validation_num_examples(examples_seen_so_far),\r\n            KEY_SAMPLE_OUTPUT: self.get_next_sample_output_num_examples(examples_seen_so_far)\r\n        }\r\n\r\n    def get_next_validation_batch(self, device: torch.device):\r\n        if self.validation_dataset is None:\r\n            return None\r\n        if self.validation_data_loader is None:\r\n            self.validation_data_loader = DataLoader(\r\n                self.validation_dataset,\r\n                batch_size=self.validation_protocol.get_batch_size(),\r\n                shuffle=True,\r\n                num_workers=1,\r\n                drop_last=True)\r\n        if self.validation_data_loader_iter is None:\r\n            self.validation_data_loader_iter = iter(self.validation_data_loader)\r\n        try:\r\n            batch = next(self.validation_data_loader_iter)\r\n        except StopIteration:\r\n            self.validation_data_loader_iter = iter(self.validation_data_loader)\r\n            batch = next(self.validation_data_loader_iter)\r\n        return [x.to(device) for x in batch]\r\n\r\n    def get_checkpoint_index_to_save(self, examples_seen_so_far: int) -> int:\r\n        checkpoint_index = 0\r\n        for i in range(len(self.checkpoint_examples)):\r\n            if self.checkpoint_examples[i] <= examples_seen_so_far:\r\n                checkpoint_index = i\r\n        return checkpoint_index\r\n\r\n    def barrier(self, local_rank: int):\r\n        if self.distrib_backend == 'nccl':\r\n            
torch.distributed.barrier(device_ids=[local_rank])\r\n        else:\r\n            torch.distributed.barrier()\r\n\r\n    def train(self,\r\n              world_size: int,\r\n              rank: int,\r\n              local_rank: int,\r\n              target_checkpoint_examples: Optional[int] = None,\r\n              device_mapper: Optional[Callable[[int, int], torch.device]] = None):\r\n        if target_checkpoint_examples is None:\r\n            target_checkpoint_examples = self.checkpoint_examples[-1]\r\n\r\n        if device_mapper is None:\r\n            device_mapper = SimpleCudaDeviceMapper()\r\n        device = device_mapper(rank, local_rank)\r\n\r\n        sample_output_data = self.load_sample_output_data(rank, device)\r\n        training_state = self.load_previous_training_state(\r\n            target_checkpoint_examples, world_size, rank, local_rank, device)\r\n        summary_writer = self.get_summary_writer(rank)\r\n        if summary_writer is not None:\r\n            log_func_factory = lambda name, num: create_log_func(summary_writer, name, num)\r\n        else:\r\n            log_func_factory = None\r\n        last_time = time.time()\r\n\r\n        while training_state.examples_seen_so_far < target_checkpoint_examples:\r\n            # Set the learning rate\r\n            learning_rate_by_module_name = self.training_protocol.get_learning_rate(training_state.examples_seen_so_far)\r\n            for module_name in self.module_factories.keys():\r\n                if module_name not in learning_rate_by_module_name or module_name not in training_state.optimizers:\r\n                    continue\r\n                lr = learning_rate_by_module_name[module_name]\r\n                set_learning_rate(training_state.optimizers[module_name], lr)\r\n                if summary_writer is not None:\r\n                    summary_writer.add_scalar(\r\n                        module_name + \"_learning_rate\", lr, training_state.examples_seen_so_far)\r\n\r\n            # One training iteration\r\n            training_batch = self.get_next_training_batch(training_state.examples_seen_so_far, world_size, device)\r\n            self.training_protocol.run_training_iteration(\r\n                training_batch,\r\n                training_state.examples_seen_so_far,\r\n                training_state.modules,\r\n                training_state.accumulated_modules,\r\n                training_state.optimizers,\r\n                self.losses,\r\n                log_func_factory,\r\n                device)\r\n\r\n            # Accumulate model data\r\n            for module_name in self.accumulators:\r\n                new_module = training_state.modules[module_name]\r\n                if isinstance(new_module, DistributedDataParallel):\r\n                    new_module = new_module.module\r\n                buffer_module = training_state.accumulated_modules[module_name]\r\n                self.accumulators[module_name].accumulate(\r\n                    new_module, buffer_module, examples_seen_so_far=training_state.examples_seen_so_far)\r\n\r\n            # Advance the number of examples seen so far\r\n            next_num_examples = self.get_next_num_examples(training_state.examples_seen_so_far)\r\n            training_state.examples_seen_so_far += self.training_protocol.get_batch_size() * world_size\r\n\r\n            # Validation iteration\r\n            if self.validation_protocol is not None \\\r\n                    and training_state.examples_seen_so_far >= next_num_examples[KEY_VALIDATION] 
\\\r\n                    and rank == 0:\r\n                validation_batch = self.get_next_validation_batch(device)\r\n                self.validation_protocol.run_validation_iteration(\r\n                    validation_batch,\r\n                    training_state.examples_seen_so_far,\r\n                    training_state.modules,\r\n                    training_state.accumulated_modules,\r\n                    self.losses,\r\n                    log_func_factory,\r\n                    device)\r\n\r\n            # Save sample output\r\n            if self.sample_output_protocol is not None \\\r\n                    and training_state.examples_seen_so_far >= next_num_examples[KEY_SAMPLE_OUTPUT]:\r\n                if rank == 0:\r\n                    self.sample_output_protocol.save_sample_output_data(\r\n                        training_state.modules,\r\n                        training_state.accumulated_modules,\r\n                        sample_output_data,\r\n                        self.prefix + \"/sample_outputs\",\r\n                        training_state.examples_seen_so_far,\r\n                        device)\r\n                self.barrier(local_rank)\r\n\r\n            # Save checkpoint\r\n            if training_state.examples_seen_so_far >= next_num_examples[KEY_CHECKPOINT]:\r\n                checkpoint_index = self.get_checkpoint_index_to_save(training_state.examples_seen_so_far)\r\n                training_state.save(\r\n                    self.get_checkpoint_prefix(checkpoint_index), rank, lambda: self.barrier(local_rank))\r\n                if next_num_examples[KEY_CHECKPOINT] != next_num_examples[KEY_SNAPSHOT]:\r\n                    training_state.save(self.get_snapshot_prefix(), rank, lambda: self.barrier(local_rank))\r\n\r\n            # Save snapshot\r\n            if training_state.examples_seen_so_far >= next_num_examples[KEY_SNAPSHOT]:\r\n                training_state.save(self.get_snapshot_prefix(), rank, lambda: self.barrier(local_rank))\r\n\r\n            now = time.time()\r\n            if now - last_time > 10:\r\n                logging.info(\"Showed %d training examples.\" % training_state.examples_seen_so_far)\r\n                last_time = now\r\n\r\n    @staticmethod\r\n    def get_default_arg_parser() -> argparse.ArgumentParser:\r\n        parser = argparse.ArgumentParser(description='Training script.')\r\n        parser.add_argument(\"--target_checkpoint_examples\", type=int)\r\n        return parser\r\n\r\n    @staticmethod\r\n    def run_with_args(trainer_factory: Callable[[int, str], 'DistributedTrainer'],\r\n                      args,\r\n                      backend: str = 'gloo',\r\n                      device_mapper: Optional[Callable[[int, int], torch.device]] = None):\r\n        world_size = int(os.environ['WORLD_SIZE'])\r\n        rank = int(os.environ['RANK'])\r\n        local_rank = int(os.environ['LOCAL_RANK'])\r\n\r\n        torch.distributed.init_process_group(backend)\r\n        trainer = trainer_factory(world_size, backend)\r\n        trainer.train(world_size, rank, local_rank, args.target_checkpoint_examples, device_mapper)\r\n\r\n    @staticmethod\r\n    def run(trainer_factory: Callable[[int, str], 'DistributedTrainer'],\r\n            backend: str = 'gloo',\r\n            device_mapper: Optional[Callable[[int, int], torch.device]] = None,\r\n            args: Optional[Any] = None):\r\n        if args is None:\r\n            parser = DistributedTrainer.get_default_arg_parser()\r\n            args = parser.parse_args()\r\n\r\n  
      DistributedTrainer.run_with_args(trainer_factory, args, backend, device_mapper)\r\n"
  },
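  {
    "path": "examples/shion_distributed_training_script.py",
    "content": "# Illustrative sketch, not part of the original codebase: the skeleton of a training\r\n# script built around DistributedTrainer. The toy network, dataset, schedule, and\r\n# output prefix are all made up. DistributedTrainer.run reads WORLD_SIZE, RANK, and\r\n# LOCAL_RANK from the environment, so this script is meant to be launched with\r\n# torchrun (e.g. torchrun --standalone --nproc_per_node=1 this_script.py) on a\r\n# machine with a CUDA device, since the default device mapper targets cuda:local_rank.\r\nimport torch\r\nfrom torch.nn import Linear, Module\r\nfrom torch.utils.data import TensorDataset\r\n\r\nfrom tha4.shion.base.loss.l2_loss import L2Loss\r\nfrom tha4.shion.base.optimizer_factories import AdamOptimizerFactory\r\nfrom tha4.shion.base.protocol.single_network_from_batch_input_computation_protocol import \\\r\n    KEY_NETWORK, KEY_NETWORK_OUTPUT, SingleNetworkBatchInputComputationProtocol\r\nfrom tha4.shion.base.training.single_network import SingleNetworkTrainingProtocol\r\nfrom tha4.shion.core.cached_computation import create_batch_element_func\r\nfrom tha4.shion.core.module_factory import ModuleFactory\r\nfrom tha4.shion.core.training.distrib.distributed_trainer import DistributedTrainer\r\n\r\n\r\nclass LinearFactory(ModuleFactory):\r\n    def create(self) -> Module:\r\n        return Linear(4, 4)\r\n\r\n\r\ndef create_trainer(world_size: int, backend: str) -> DistributedTrainer:\r\n    protocol = SingleNetworkBatchInputComputationProtocol()\r\n    loss = L2Loss(\r\n        expected_func=create_batch_element_func(1),\r\n        actual_func=protocol.get_output_func(KEY_NETWORK_OUTPUT))\r\n    dataset = TensorDataset(torch.randn(1024, 4), torch.randn(1024, 4))\r\n    return DistributedTrainer(\r\n        prefix=\"data/_scratch/example_run\",\r\n        module_factories={KEY_NETWORK: LinearFactory()},\r\n        accumulators={},\r\n        losses={KEY_NETWORK: loss},\r\n        training_dataset=dataset,\r\n        validation_dataset=None,\r\n        training_protocol=SingleNetworkTrainingProtocol(\r\n            check_point_examples=[1024],\r\n            batch_size=8,\r\n            learning_rate=lambda examples_seen_so_far: {KEY_NETWORK: 1e-4},\r\n            optimizer_factories={KEY_NETWORK: AdamOptimizerFactory()}),\r\n        validation_protocol=None,\r\n        sample_output_protocol=None,\r\n        pretrained_module_file_names={},\r\n        example_per_snapshot=256)\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    DistributedTrainer.run(create_trainer, backend='gloo')\r\n"
  },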
  {
    "path": "src/tha4/shion/core/training/distrib/distributed_training_states.py",
    "content": "import copy\r\nimport logging\r\nimport os\r\nfrom typing import Dict, Optional, Callable\r\n\r\nimport torch\r\nfrom torch.nn import Module\r\nfrom torch.nn.parallel import DistributedDataParallel\r\nfrom torch.optim.optimizer import Optimizer\r\n\r\nfrom tha4.shion.core.load_save import torch_save, torch_load\r\nfrom tha4.shion.core.module_accumulator import ModuleAccumulator\r\nfrom tha4.shion.core.module_factory import ModuleFactory\r\nfrom tha4.shion.core.optimizer_factory import OptimizerFactory\r\nfrom tha4.shion.core.training.util import optimizer_to_device\r\n\r\n\r\nclass DistributedTrainingState:\r\n    def __init__(self,\r\n                 examples_seen_so_far: int,\r\n                 modules: Dict[str, Module],\r\n                 accumulated_modules: Dict[str, Module],\r\n                 optimizers: Dict[str, Optimizer]):\r\n        self.accumulated_modules = accumulated_modules\r\n        self.optimizers = optimizers\r\n        self.modules = modules\r\n        self.examples_seen_so_far = examples_seen_so_far\r\n\r\n    @staticmethod\r\n    def get_examples_seen_so_far_file_name(prefix) -> str:\r\n        return prefix + \"/examples_seen_so_far.txt\"\r\n\r\n    @staticmethod\r\n    def get_module_file_name(prefix, module_name) -> str:\r\n        return \"%s/module_%s.pt\" % (prefix, module_name)\r\n\r\n    @staticmethod\r\n    def get_accumulated_module_file_name(prefix, module_name) -> str:\r\n        return \"%s/accumulated_%s.pt\" % (prefix, module_name)\r\n\r\n    @staticmethod\r\n    def get_optimizer_file_name(prefix, module_name) -> str:\r\n        return \"%s/optimizer_%s.pt\" % (prefix, module_name)\r\n\r\n    @staticmethod\r\n    def get_rng_state_file_name(prefix, rank: int):\r\n        return \"%s/rng_state_%08d.pt\" % (prefix, rank)\r\n\r\n    def mkdir(self, prefix: str):\r\n        os.makedirs(prefix, exist_ok=True)\r\n\r\n    def save_data(self, prefix: str, rank: int):\r\n        assert os.path.exists(prefix)\r\n\r\n        torch_save(torch.get_rng_state(), DistributedTrainingState.get_rng_state_file_name(prefix, rank))\r\n        logging.info(\"Saved %s\" % DistributedTrainingState.get_rng_state_file_name(prefix, rank))\r\n\r\n        if rank == 0:\r\n            logging.info(\"Saving training state to %s\" % prefix)\r\n            with open(DistributedTrainingState.get_examples_seen_so_far_file_name(prefix), \"wt\") as fout:\r\n                fout.write(\"%d\\n\" % self.examples_seen_so_far)\r\n                logging.info(\"Saved %s\" % DistributedTrainingState.get_examples_seen_so_far_file_name(prefix))\r\n            for module_name in self.modules:\r\n                file_name = DistributedTrainingState.get_module_file_name(prefix, module_name)\r\n                module = self.modules[module_name]\r\n                if isinstance(module, DistributedDataParallel):\r\n                    state_dict = module.module.state_dict()\r\n                else:\r\n                    state_dict = module.state_dict()\r\n                torch_save(state_dict, file_name)\r\n                logging.info(\"Saved %s\" % file_name)\r\n            for module_name in self.accumulated_modules:\r\n                file_name = DistributedTrainingState.get_accumulated_module_file_name(prefix, module_name)\r\n                torch_save(self.accumulated_modules[module_name].state_dict(), file_name)\r\n                logging.info(\"Saved %s\" % file_name)\r\n            for module_name in self.optimizers:\r\n                file_name = 
DistributedTrainingState.get_optimizer_file_name(prefix, module_name)\r\n                torch_save(self.optimizers[module_name].state_dict(), file_name)\r\n                logging.info(\"Saved %s\" % file_name)\r\n\r\n        logging.info(\"Done saving training state to %s\" % prefix)\r\n\r\n    def save(self, prefix: str, rank: int, barrier_func: Callable[[], None]):\r\n        if rank == 0:\r\n            self.mkdir(prefix)\r\n        barrier_func()\r\n        self.save_data(prefix, rank)\r\n        barrier_func()\r\n\r\n    @staticmethod\r\n    def get_examples_seen_so_far(prefix: str) -> int:\r\n        with open(DistributedTrainingState.get_examples_seen_so_far_file_name(prefix)) as fin:\r\n            lines = fin.readlines()\r\n            return int(lines[0])\r\n\r\n    @staticmethod\r\n    def load(\r\n            prefix: str,\r\n            module_factories: Dict[str, ModuleFactory],\r\n            accumulators: Dict[str, ModuleAccumulator],\r\n            optimizer_factories: Dict[str, OptimizerFactory],\r\n            rank: int,\r\n            local_rank: int,\r\n            device: torch.device) -> 'DistributedTrainingState':\r\n        logging.info(\"Loading training state from %s\" % prefix)\r\n\r\n        with open(DistributedTrainingState.get_examples_seen_so_far_file_name(prefix)) as fin:\r\n            lines = fin.readlines()\r\n            examples_seen_so_far = int(lines[0])\r\n            logging.info(\"Loaded %s\" % DistributedTrainingState.get_examples_seen_so_far_file_name(prefix))\r\n\r\n        modules = {\r\n            module_name: factory.create()\r\n            for (module_name, factory) in module_factories.items()\r\n        }\r\n        for module_name in modules:\r\n            file_name = DistributedTrainingState.get_module_file_name(prefix, module_name)\r\n            module = modules[module_name]\r\n            state_dict = torch_load(file_name)\r\n            module.load_state_dict(state_dict)\r\n            module.to(device)\r\n            modules[module_name] = DistributedDataParallel(\r\n                module,\r\n                device_ids=[device.index],\r\n                output_device=device.index)\r\n            logging.info(\"Loaded %s\" % file_name)\r\n\r\n        accumulated_modules = {}\r\n        for module_name in accumulators:\r\n            module_factory = module_factories[module_name]\r\n            module = module_factory.create()\r\n            file_name = DistributedTrainingState.get_accumulated_module_file_name(prefix, module_name)\r\n            module.load_state_dict(torch_load(file_name))\r\n            module.to(device)\r\n            accumulated_modules[module_name] = module\r\n            logging.info(\"Loaded %s\" % file_name)\r\n\r\n        optimizers = {}\r\n        for module_name in optimizer_factories:\r\n            optimizer = optimizer_factories[module_name].create(modules[module_name].parameters())\r\n            file_name = DistributedTrainingState.get_optimizer_file_name(prefix, module_name)\r\n            optimizer.load_state_dict(torch_load(file_name))\r\n            optimizer_to_device(optimizer, device)\r\n            optimizers[module_name] = optimizer\r\n            logging.info(\"Loaded %s\" % file_name)\r\n\r\n        torch.set_rng_state(torch_load(DistributedTrainingState.get_rng_state_file_name(prefix, rank)))\r\n        logging.info(\"Loaded %s\" % DistributedTrainingState.get_rng_state_file_name(prefix, rank))\r\n\r\n        logging.info(\"Done loading training state from %s\" % prefix)\r\n\r\n        
return DistributedTrainingState(examples_seen_so_far, modules, accumulated_modules, optimizers)\r\n\r\n    @staticmethod\r\n    def new(module_factories: Dict[str, ModuleFactory],\r\n            accumulators: Dict[str, ModuleAccumulator],\r\n            optimizer_factories: Dict[str, OptimizerFactory],\r\n            random_seed: int,\r\n            rank: int,\r\n            local_rank: int,\r\n            device: torch.device,\r\n            pretrained_module_file_names: Optional[Dict[str, str]] = None) -> 'DistributedTrainingState':\r\n        examples_seen_so_far = 0\r\n\r\n        modules = {\r\n            module_name: factory.create()\r\n            for (module_name, factory) in module_factories.items()\r\n        }\r\n        for module_name in modules:\r\n            modules[module_name].to(device)\r\n        if pretrained_module_file_names is not None:\r\n            for module_name in modules:\r\n                if module_name in pretrained_module_file_names:\r\n                    file_name = pretrained_module_file_names[module_name]\r\n                    modules[module_name].load_state_dict(torch_load(file_name))\r\n                    logging.info(\"Loaded initial state from %s ...\" % file_name)\r\n\r\n        accumulated_modules = {}\r\n        for module_name in accumulators:\r\n            accumulated_modules[module_name] = copy.deepcopy(modules[module_name])\r\n\r\n        for module_name in modules:\r\n            module = modules[module_name]\r\n            modules[module_name] = DistributedDataParallel(\r\n                module,\r\n                device_ids=[device.index],\r\n                output_device=device.index)\r\n\r\n        optimizers = {}\r\n        for module_name in optimizer_factories:\r\n            module = modules[module_name]\r\n            optimizer = optimizer_factories[module_name].create(module.parameters())\r\n            optimizer_to_device(optimizer, device)\r\n            optimizers[module_name] = optimizer\r\n\r\n        torch.manual_seed(random_seed + rank)\r\n\r\n        return DistributedTrainingState(examples_seen_so_far, modules, accumulated_modules, optimizers)\r\n\r\n    @staticmethod\r\n    def can_load(prefix: str,\r\n                 module_factories: Dict[str, ModuleFactory],\r\n                 accumulators: Dict[str, ModuleAccumulator],\r\n                 optimizer_factories: Dict[str, OptimizerFactory],\r\n                 world_size: int) -> bool:\r\n        logging.info(f\"Checking directory {prefix}\")\r\n        if not os.path.isdir(prefix):\r\n            logging.info(f\"Cannot load files in {prefix} because it is not a directory\")\r\n            return False\r\n        examples_seen_so_far_file_name = DistributedTrainingState.get_examples_seen_so_far_file_name(prefix)\r\n        if not os.path.isfile(examples_seen_so_far_file_name):\r\n            logging.info(f\"Cannot load files in {prefix} because {examples_seen_so_far_file_name} is not a file.\")\r\n            return False\r\n        for module_name in module_factories.keys():\r\n            file_name = DistributedTrainingState.get_module_file_name(prefix, module_name)\r\n            if not os.path.isfile(file_name):\r\n                logging.info(f\"Cannot load files in {prefix} because {file_name} is not a file.\")\r\n                return False\r\n        for module_name in accumulators:\r\n            file_name = DistributedTrainingState.get_accumulated_module_file_name(prefix, module_name)\r\n            if not os.path.isfile(file_name):\r\n                
logging.info(f\"Cannot load files in {prefix} because {file_name} is not a file.\")\r\n                return False\r\n        for module_name in optimizer_factories:\r\n            file_name = DistributedTrainingState.get_optimizer_file_name(prefix, module_name)\r\n            if not os.path.isfile(file_name):\r\n                logging.info(f\"Cannot load files in {prefix} because {file_name} is not a file.\")\r\n                return False\r\n        for rank in range(world_size):\r\n            file_name = DistributedTrainingState.get_rng_state_file_name(prefix, rank)\r\n            if not os.path.isfile(file_name):\r\n                logging.info(f\"Cannot load files in {prefix} because {file_name} is not a file.\")\r\n                return False\r\n        return True\r\n"
  },
  {
    "path": "src/tha4/shion/core/training/distrib/distributed_training_tasks.py",
    "content": "import logging\r\nimport os\r\nimport sys\r\nfrom typing import Callable, List, Optional\r\n\r\nfrom tha4.pytasuku.workspace import Workspace\r\nfrom tha4.shion.core.training.distrib.distributed_trainer import DistributedTrainer\r\nfrom tha4.shion.core.training.distrib.distributed_training_states import DistributedTrainingState\r\n\r\n\r\ndef get_torchrun_executable():\r\n    return os.path.dirname(sys.executable) + os.path.sep + \"torchrun\"\r\n\r\n\r\ndef run_distributed_training_script(\r\n        training_script_file_name: str,\r\n        num_nodes: int,\r\n        node_rank: int,\r\n        num_proc_per_node: int,\r\n        master_addr: int = \"127.0.0.1\",\r\n        master_port: int = 8888):\r\n    command = f\"{get_torchrun_executable()} \" \\\r\n              f\"--nproc_per_node={num_proc_per_node} \" \\\r\n              f\"--nnodes={num_nodes} \" \\\r\n              f\"--node_rank={node_rank} \" \\\r\n              f\"--master_addr={master_addr} \" \\\r\n              f\"--master_port={master_port} \" \\\r\n              f\"{training_script_file_name}\"\r\n    logging.info(f\"Executing -- {command}\")\r\n    os.system(command)\r\n\r\n\r\nclass RdzvConfig:\r\n    def __init__(self, id: int, port: int):\r\n        self.port = port\r\n        self.id = id\r\n\r\n\r\ndef run_standalone_distributed_training_script(\r\n        training_script_file_name: str,\r\n        num_proc_per_node: int,\r\n        target_checkpoint_examples: Optional[int] = None,\r\n        rdzv_config: Optional[RdzvConfig] = None):\r\n    command = f\"{get_torchrun_executable()} \" \\\r\n              f\"--nnodes=1 \" \\\r\n              f\"--nproc_per_node={num_proc_per_node} \"\r\n    if rdzv_config is not None:\r\n        command += f\"--rdzv_endpoint=localhost:{rdzv_config.port} \"\r\n        command += \"--rdzv_backend=c10d \"\r\n        command += f\"--rdzv_id={rdzv_config.id} \"\r\n    else:\r\n        command += \"--standalone \"\r\n    command += f\"{training_script_file_name} \"\r\n    if target_checkpoint_examples is not None:\r\n        command += f\"--target_checkpoint_examples {target_checkpoint_examples} \"\r\n    logging.info(f\"Executing -- {command}\")\r\n    os.system(command)\r\n\r\n\r\ndef define_distributed_training_tasks(\r\n        workspace: Workspace,\r\n        prefix: str,\r\n        training_script_file_name: str,\r\n        num_nodes: int,\r\n        num_proc_per_node: int,\r\n        master_addr: int = \"127.0.0.1\",\r\n        master_port: int = 8888):\r\n    def run_training_script_func(rank: int):\r\n        def _f():\r\n            run_distributed_training_script(\r\n                training_script_file_name,\r\n                num_nodes,\r\n                rank,\r\n                num_proc_per_node, master_addr,\r\n                master_port)\r\n\r\n        return _f\r\n\r\n    for i in range(num_nodes):\r\n        workspace.create_command_task(f\"{prefix}/train_node_%06d\" % i, [], run_training_script_func(i))\r\n\r\n\r\ndef define_standalone_distributed_training_tasks(\r\n        workspace: Workspace,\r\n        distributed_trainer_func: Callable[[int], DistributedTrainer],\r\n        training_script_file_name: str,\r\n        num_proc_per_node: int,\r\n        dependencies: Optional[List[str]] = None,\r\n        rdzv_config: Optional[RdzvConfig] = None):\r\n    trainer = distributed_trainer_func(1)\r\n    checkpoint_examples = trainer.training_protocol.get_checkpoint_examples()\r\n    assert len(checkpoint_examples) >= 1\r\n    assert 
checkpoint_examples[0] > 0\r\n    checkpoint_examples = [0] + checkpoint_examples\r\n\r\n    if dependencies is None:\r\n        dependencies = []\r\n    module_file_dependencies = dependencies[:]\r\n    for module_name in trainer.pretrained_module_file_names:\r\n        module_file_dependencies.append(trainer.pretrained_module_file_names[module_name])\r\n\r\n    def create_train_func(target_checkpoint_examples: int):\r\n        return lambda: run_standalone_distributed_training_script(\r\n            training_script_file_name,\r\n            num_proc_per_node,\r\n            target_checkpoint_examples,\r\n            rdzv_config=rdzv_config)\r\n\r\n    train_tasks = []\r\n    for checkpoint_index in range(0, len(checkpoint_examples)):\r\n        for module_name in trainer.module_names:\r\n            module_file_name = DistributedTrainingState.get_module_file_name(\r\n                trainer.get_checkpoint_prefix(checkpoint_index),\r\n                module_name)\r\n            workspace.create_file_task(\r\n                module_file_name,\r\n                module_file_dependencies,\r\n                create_train_func(trainer.checkpoint_examples[checkpoint_index]))\r\n        for module_name in trainer.accumulators:\r\n            accumulated_module_file_name = DistributedTrainingState.get_accumulated_module_file_name(\r\n                trainer.get_checkpoint_prefix(checkpoint_index),\r\n                module_name)\r\n            workspace.create_file_task(\r\n                accumulated_module_file_name,\r\n                module_file_dependencies,\r\n                create_train_func(checkpoint_examples[checkpoint_index]))\r\n        workspace.create_command_task(\r\n            trainer.get_checkpoint_prefix(checkpoint_index) + \"/train_standalone\",\r\n            module_file_dependencies,\r\n            create_train_func(checkpoint_examples[checkpoint_index]))\r\n        train_tasks.append(trainer.get_checkpoint_prefix(checkpoint_index) + \"/train_standlone\")\r\n    workspace.create_file_task(\r\n        trainer.prefix + \"/train_standalone\",\r\n        module_file_dependencies,\r\n        create_train_func(checkpoint_examples[-1]))\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    print(os.path.dirname(sys.executable) + os.path.sep + \"torchrun\")\r\n"
  },
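For orientation, `run_standalone_distributed_training_script` above only assembles a `torchrun` command line and hands it to `os.system`. A minimal sketch of the command it produces when an `RdzvConfig` is supplied; the script name `train.py`, the two processes, port 29400, id 42, and the target example count are all hypothetical values chosen for illustration:

```python
# Reconstruct the torchrun command line that
#     run_standalone_distributed_training_script(
#         "train.py", 2, 100000, rdzv_config=RdzvConfig(id=42, port=29400))
# would log and execute. All concrete values here are illustrative.
from tha4.shion.core.training.distrib.distributed_training_tasks import get_torchrun_executable

command = (
    f"{get_torchrun_executable()} "
    f"--nnodes=1 "
    f"--nproc_per_node=2 "
    f"--rdzv_endpoint=localhost:29400 "
    f"--rdzv_backend=c10d "
    f"--rdzv_id=42 "
    f"train.py "
    f"--target_checkpoint_examples 100000 "
)
print(command)
```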
  {
    "path": "src/tha4/shion/core/training/sample_output_protocol.py",
    "content": "from abc import ABC, abstractmethod\r\nfrom typing import Dict, Any\r\n\r\nimport torch\r\nfrom torch.nn import Module\r\nfrom torch.utils.data import Dataset\r\n\r\n\r\nclass SampleOutputProtocol(ABC):\r\n    @abstractmethod\r\n    def get_examples_per_sample_output(self) -> int:\r\n        pass\r\n\r\n    @abstractmethod\r\n    def get_random_seed(self) -> int:\r\n        pass\r\n\r\n    @abstractmethod\r\n    def get_sample_output_data(self, validation_dataset: Dataset, device: torch.device) -> Any:\r\n        pass\r\n\r\n    @abstractmethod\r\n    def save_sample_output_data(\r\n            self,\r\n            modules: Dict[str, Module],\r\n            accumulated_modules: Dict[str, Module],\r\n            sample_output_data: Any,\r\n            prefix: str,\r\n            examples_seen_so_far: int,\r\n            device: torch.device):\r\n        pass\r\n\r\n\r\nclass AbstractSampleOutputProtocol(SampleOutputProtocol, ABC):\r\n    def __init__(self, examples_per_sample_output: int, random_seed: int):\r\n        self.random_seed = random_seed\r\n        self.examples_per_sample_output = examples_per_sample_output\r\n\r\n    def get_examples_per_sample_output(self) -> int:\r\n        return self.examples_per_sample_output\r\n\r\n    def get_random_seed(self) -> int:\r\n        return self.random_seed\r\n"
  },
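As a usage illustration of the protocol above, here is a minimal sketch of a concrete subclass of `AbstractSampleOutputProtocol`. The module name `"generator"`, the eight-example probe batch, and the assumption that dataset items are `(input, target)` tuples are hypothetical choices, not part of the library:

```python
import os
from typing import Any, Dict

import torch
from torch.nn import Module
from torch.utils.data import Dataset

from tha4.shion.core.load_save import torch_save
from tha4.shion.core.training.sample_output_protocol import AbstractSampleOutputProtocol


class FixedBatchSampleOutputProtocol(AbstractSampleOutputProtocol):
    """Hypothetical protocol: cache one fixed batch of validation inputs,
    then periodically save the accumulated module's outputs on it."""

    def get_sample_output_data(self, validation_dataset: Dataset, device: torch.device) -> Any:
        # Use the first eight validation examples as a fixed probe batch.
        inputs = torch.stack([validation_dataset[i][0] for i in range(8)])
        return {"inputs": inputs.to(device)}

    def save_sample_output_data(self,
                                modules: Dict[str, Module],
                                accumulated_modules: Dict[str, Module],
                                sample_output_data: Any,
                                prefix: str,
                                examples_seen_so_far: int,
                                device: torch.device):
        os.makedirs(prefix, exist_ok=True)
        with torch.no_grad():
            outputs = accumulated_modules["generator"](sample_output_data["inputs"])
        torch_save(outputs.cpu(), "%s/sample_%012d.pt" % (prefix, examples_seen_so_far))
```

It would be instantiated as, e.g., `FixedBatchSampleOutputProtocol(examples_per_sample_output=50_000, random_seed=0)`.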
  {
    "path": "src/tha4/shion/core/training/single/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/shion/core/training/single/training_states.py",
    "content": "import copy\r\nimport logging\r\nimport os\r\nfrom typing import Dict, Optional\r\n\r\nimport torch\r\nfrom torch.nn import Module\r\nfrom torch.optim import Optimizer\r\n\r\nfrom tha4.shion.core.load_save import torch_save, torch_load\r\nfrom tha4.shion.core.module_accumulator import ModuleAccumulator\r\nfrom tha4.shion.core.module_factory import ModuleFactory\r\nfrom tha4.shion.core.optimizer_factory import OptimizerFactory\r\nfrom tha4.shion.core.training.util import optimizer_to_device\r\n\r\n\r\nclass TrainingState:\r\n    def __init__(self,\r\n                 examples_seen_so_far: int,\r\n                 modules: Dict[str, Module],\r\n                 accumulated_modules: Dict[str, Module],\r\n                 optimizers: Dict[str, Optimizer]):\r\n        self.accumulated_modules = accumulated_modules\r\n        self.optimizers = optimizers\r\n        self.modules = modules\r\n        self.examples_seen_so_far = examples_seen_so_far\r\n\r\n    @staticmethod\r\n    def get_examples_seen_so_far_file_name(prefix) -> str:\r\n        return prefix + \"/examples_seen_so_far.txt\"\r\n\r\n    @staticmethod\r\n    def get_module_file_name(prefix, module_name) -> str:\r\n        return \"%s/module_%s.pt\" % (prefix, module_name)\r\n\r\n    @staticmethod\r\n    def get_accumulated_module_file_name(prefix, module_name) -> str:\r\n        return \"%s/accumulated_%s.pt\" % (prefix, module_name)\r\n\r\n    @staticmethod\r\n    def get_optimizer_file_name(prefix, module_name) -> str:\r\n        return \"%s/optimizer_%s.pt\" % (prefix, module_name)\r\n\r\n    @staticmethod\r\n    def get_rng_state_file_name(prefix):\r\n        return \"%s/rng_state.pt\" % prefix\r\n\r\n    def save(self, prefix):\r\n        logging.info(\"Saving training state to %s\" % prefix)\r\n        os.makedirs(prefix, exist_ok=True)\r\n        with open(TrainingState.get_examples_seen_so_far_file_name(prefix), \"wt\") as fout:\r\n            fout.write(\"%d\\n\" % self.examples_seen_so_far)\r\n            logging.info(\"Saved %s\" % TrainingState.get_examples_seen_so_far_file_name(prefix))\r\n        for module_name in self.modules:\r\n            file_name = TrainingState.get_module_file_name(prefix, module_name)\r\n            torch_save(self.modules[module_name].state_dict(), file_name)\r\n            logging.info(\"Saved %s\" % file_name)\r\n        for module_name in self.accumulated_modules:\r\n            file_name = TrainingState.get_accumulated_module_file_name(prefix, module_name)\r\n            torch_save(self.accumulated_modules[module_name].state_dict(), file_name)\r\n            logging.info(\"Saved %s\" % file_name)\r\n        for module_name in self.optimizers:\r\n            file_name = TrainingState.get_optimizer_file_name(prefix, module_name)\r\n            torch_save(self.optimizers[module_name].state_dict(), file_name)\r\n            logging.info(\"Saved %s\" % file_name)\r\n        torch_save(torch.get_rng_state(), TrainingState.get_rng_state_file_name(prefix))\r\n        logging.info(\"Saved %s\" % TrainingState.get_rng_state_file_name(prefix))\r\n        logging.info(\"Done saving training state to %s\" % prefix)\r\n\r\n    @staticmethod\r\n    def get_examples_seen_so_far(prefix: str) -> int:\r\n        with open(TrainingState.get_examples_seen_so_far_file_name(prefix)) as fin:\r\n            lines = fin.readlines()\r\n            return int(lines[0])\r\n\r\n    @staticmethod\r\n    def load(prefix: str,\r\n             module_factories: Dict[str, ModuleFactory],\r\n             
accumulators: Dict[str, ModuleAccumulator],\r\n             optimizer_factories: Dict[str, OptimizerFactory],\r\n             device: torch.device) -> 'TrainingState':\r\n        logging.info(\"Loading training state from %s\" % prefix)\r\n\r\n        with open(TrainingState.get_examples_seen_so_far_file_name(prefix)) as fin:\r\n            lines = fin.readlines()\r\n            examples_seen_so_far = int(lines[0])\r\n            logging.info(\"Loaded %s\" % TrainingState.get_examples_seen_so_far_file_name(prefix))\r\n\r\n        modules = {\r\n            module_name: factory.create()\r\n            for (module_name, factory) in module_factories.items()\r\n        }\r\n        for module_name in modules:\r\n            file_name = TrainingState.get_module_file_name(prefix, module_name)\r\n            modules[module_name].load_state_dict(torch_load(file_name))\r\n            modules[module_name].to(device)\r\n            logging.info(\"Loaded %s\" % file_name)\r\n\r\n        accumulated_modules = {}\r\n        for module_name in accumulators:\r\n            module_factory = module_factories[module_name]\r\n            module = module_factory.create()\r\n            file_name = TrainingState.get_accumulated_module_file_name(prefix, module_name)\r\n            module.load_state_dict(torch_load(file_name))\r\n            module.to(device)\r\n            accumulated_modules[module_name] = module\r\n            logging.info(\"Loaded %s\" % file_name)\r\n\r\n        optimizers = {}\r\n        for module_name in optimizer_factories:\r\n            optimizer = optimizer_factories[module_name].create(modules[module_name].parameters())\r\n            file_name = TrainingState.get_optimizer_file_name(prefix, module_name)\r\n            optimizer.load_state_dict(torch_load(file_name))\r\n            optimizer_to_device(optimizer, device)\r\n            optimizers[module_name] = optimizer\r\n            logging.info(\"Loaded %s\" % file_name)\r\n\r\n        torch.set_rng_state(torch_load(TrainingState.get_rng_state_file_name(prefix)))\r\n        logging.info(\"Loaded %s\" % TrainingState.get_rng_state_file_name(prefix))\r\n\r\n        logging.info(\"Done loading training state from %s\" % prefix)\r\n\r\n        return TrainingState(examples_seen_so_far, modules, accumulated_modules, optimizers)\r\n\r\n    @staticmethod\r\n    def new(module_factories: Dict[str, ModuleFactory],\r\n            accumulators: Dict[str, ModuleAccumulator],\r\n            optimizer_factories: Dict[str, OptimizerFactory],\r\n            random_seed: int,\r\n            device: torch.device,\r\n            pretrained_module_file_names: Optional[Dict[str, str]] = None) -> 'TrainingState':\r\n        examples_seen_so_far = 0\r\n\r\n        modules = {\r\n            module_name: factory.create()\r\n            for (module_name, factory) in module_factories.items()\r\n        }\r\n        for module_name in modules:\r\n            modules[module_name].to(device)\r\n        if pretrained_module_file_names is not None:\r\n            for module_name in modules:\r\n                if module_name in pretrained_module_file_names:\r\n                    file_name = pretrained_module_file_names[module_name]\r\n                    modules[module_name].load_state_dict(torch_load(file_name))\r\n                    logging.info(\"Loaded initial state from %s ...\" % file_name)\r\n\r\n        accumulated_modules = {}\r\n        for module_name in accumulators:\r\n            accumulated_modules[module_name] = 
copy.deepcopy(modules[module_name])\r\n\r\n        optimizers = {}\r\n        for module_name in optimizer_factories:\r\n            module = modules[module_name]\r\n            optimizer = optimizer_factories[module_name].create(module.parameters())\r\n            optimizer_to_device(optimizer, device)\r\n            optimizers[module_name] = optimizer\r\n\r\n        torch.manual_seed(random_seed)\r\n\r\n        return TrainingState(examples_seen_so_far, modules, accumulated_modules, optimizers)\r\n\r\n    @staticmethod\r\n    def can_load(prefix: str,\r\n                 module_factories: Dict[str, ModuleFactory],\r\n                 accumulators: Dict[str, ModuleAccumulator],\r\n                 optimizer_factories: Dict[str, OptimizerFactory]) -> bool:\r\n        if not os.path.isdir(prefix):\r\n            return False\r\n        if not os.path.isfile(TrainingState.get_examples_seen_so_far_file_name(prefix)):\r\n            return False\r\n        for module_name in module_factories.keys():\r\n            if not os.path.isfile(TrainingState.get_module_file_name(prefix, module_name)):\r\n                return False\r\n        for module_name in accumulators:\r\n            if not os.path.isfile(TrainingState.get_accumulated_module_file_name(prefix, module_name)):\r\n                return False\r\n        for module_name in optimizer_factories:\r\n            if not os.path.isfile(TrainingState.get_optimizer_file_name(prefix, module_name)):\r\n                return False\r\n        if not os.path.isfile(TrainingState.get_rng_state_file_name(prefix)):\r\n            return False\r\n        return True\r\n"
  },
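The save/load/can_load cycle above can be exercised in isolation. A minimal sketch, assuming `ModuleFactory` and `OptimizerFactory` only require the `create()` methods the training code calls; the tiny linear module, the module name `"main"`, and the output directory are illustrative:

```python
import torch
from torch.nn import Linear

from tha4.shion.core.module_factory import ModuleFactory
from tha4.shion.core.optimizer_factory import OptimizerFactory
from tha4.shion.core.training.single.training_states import TrainingState


class TinyModuleFactory(ModuleFactory):
    def create(self):
        return Linear(4, 4)


class AdamFactory(OptimizerFactory):
    def create(self, parameters):
        return torch.optim.Adam(parameters, lr=1e-4)


module_factories = {"main": TinyModuleFactory()}
optimizer_factories = {"main": AdamFactory()}
device = torch.device("cpu")

# Create a fresh state, persist it, verify it is loadable, and restore it.
state = TrainingState.new(module_factories, {}, optimizer_factories, random_seed=0, device=device)
state.save("data/_example_run/checkpoint/0000")
assert TrainingState.can_load("data/_example_run/checkpoint/0000",
                              module_factories, {}, optimizer_factories)
restored = TrainingState.load("data/_example_run/checkpoint/0000",
                              module_factories, {}, optimizer_factories, device)
assert restored.examples_seen_so_far == 0
```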
  {
    "path": "src/tha4/shion/core/training/single/training_tasks.py",
    "content": "import logging\r\nimport time\r\nfrom datetime import datetime\r\nfrom typing import Optional, Dict, List\r\n\r\nimport torch\r\nfrom torch.utils.data import Dataset, DataLoader\r\nfrom torch.utils.tensorboard import SummaryWriter\r\n\r\nfrom tha4.pytasuku.workspace import Workspace\r\nfrom tha4.shion.core.load_save import torch_save, torch_load\r\nfrom tha4.shion.core.loss import Loss\r\nfrom tha4.shion.core.module_accumulator import ModuleAccumulator\r\nfrom tha4.shion.core.module_factory import ModuleFactory\r\nfrom tha4.shion.core.training.sample_output_protocol import SampleOutputProtocol\r\nfrom tha4.shion.core.training.single.training_states import TrainingState\r\nfrom tha4.shion.core.training.training_protocol import TrainingProtocol\r\nfrom tha4.shion.core.training.util import get_least_greater_multiple, create_log_func, set_learning_rate\r\nfrom tha4.shion.core.training.validation_protocol import ValidationProtocol\r\n\r\nKEY_CHECKPOINT = 'checkpoint'\r\nKEY_SNAPSHOT = 'snapshot'\r\nKEY_VALIDATION = 'validation'\r\nKEY_SAMPLE_OUTPUT = 'sample_output'\r\n\r\n\r\nclass TrainingTasks:\r\n    def __init__(\r\n            self,\r\n            workspace: Workspace,\r\n            prefix: str,\r\n            module_factories: Dict[str, ModuleFactory],\r\n            accumulators: Dict[str, ModuleAccumulator],\r\n            losses: Dict[str, Loss],\r\n            training_dataset: Dataset,\r\n            validation_dataset: Optional[Dataset],\r\n            training_protocol: TrainingProtocol,\r\n            validation_protocol: Optional[ValidationProtocol],\r\n            sample_output_protocol: Optional[SampleOutputProtocol],\r\n            pretrained_module_file_names: Dict[str, str],\r\n            example_per_snapshot: int,\r\n            device: torch.device,\r\n            num_data_loader_workers: int = 8,\r\n            dependencies: Optional[List[str]] = None):\r\n        super().__init__()\r\n        self.num_data_loader_workers = num_data_loader_workers\r\n        self.accumulators = accumulators\r\n        self.device = device\r\n        self.sample_output_protocol = sample_output_protocol\r\n        self.example_per_snapshot = example_per_snapshot\r\n        self.pretrained_module_file_names = pretrained_module_file_names\r\n        self.losses = losses\r\n        self.validation_protocol = validation_protocol\r\n        self.training_protocol = training_protocol\r\n        self.module_factories = module_factories\r\n        self.prefix = prefix\r\n        self.training_dataset = training_dataset\r\n        self.validation_dataset = validation_dataset\r\n\r\n        self.checkpoint_examples = self.training_protocol.get_checkpoint_examples()\r\n        assert len(self.checkpoint_examples) >= 1\r\n        assert self.checkpoint_examples[0] > 0\r\n        self.checkpoint_examples = [0] + self.checkpoint_examples\r\n\r\n        self.module_names = self.module_factories.keys()\r\n        assert len(self.module_names) > 0\r\n\r\n        self.training_data_loader = None\r\n        self.training_data_loader_iter = None\r\n        self.training_data_loader_batch_size = None\r\n        self.validation_data_loader = None\r\n        self.validation_data_loader_iter = None\r\n        self.validation_data_loader_batch_size = None\r\n        self.sample_output_data = None\r\n        self.summary_writer = None\r\n        self.log_dir = None\r\n        self.training_state = None\r\n\r\n        if dependencies is None:\r\n            dependencies = []\r\n        
self.sample_output_data_task = workspace.create_file_task(\r\n            self.get_sample_output_data_file_name(), dependencies, self.save_sample_output_data)\r\n\r\n        module_file_dependencies = [self.sample_output_data_task.name]\r\n        for module_name in pretrained_module_file_names:\r\n            module_file_dependencies.append(self.pretrained_module_file_names[module_name])\r\n\r\n        def create_train_func(target_examples: int):\r\n            return lambda: self.train(target_examples)\r\n\r\n        train_tasks = []\r\n        for checkpoint_index in range(1, len(self.checkpoint_examples)):\r\n            for module_name in self.module_names:\r\n                module_file_name = TrainingState.get_module_file_name(\r\n                    self.get_checkpoint_prefix(checkpoint_index),\r\n                    module_name)\r\n                workspace.create_file_task(\r\n                    module_file_name,\r\n                    module_file_dependencies,\r\n                    create_train_func(self.checkpoint_examples[checkpoint_index]))\r\n            for module_name in self.accumulators:\r\n                accumulated_module_file_name = TrainingState.get_accumulated_module_file_name(\r\n                    self.get_checkpoint_prefix(checkpoint_index),\r\n                    module_name)\r\n                workspace.create_file_task(\r\n                    accumulated_module_file_name,\r\n                    module_file_dependencies,\r\n                    create_train_func(self.checkpoint_examples[checkpoint_index]))\r\n            workspace.create_command_task(\r\n                self.get_checkpoint_prefix(checkpoint_index) + \"/train\",\r\n                module_file_dependencies,\r\n                create_train_func(self.checkpoint_examples[checkpoint_index]))\r\n            train_tasks.append(self.get_checkpoint_prefix(checkpoint_index) + \"/train\")\r\n\r\n        self.train_task = workspace.create_file_task(\r\n            self.get_train_command_name(),\r\n            module_file_dependencies,\r\n            create_train_func(self.checkpoint_examples[-1]))\r\n\r\n    def get_sample_output_data_file_name(self):\r\n        return self.prefix + \"/sample_output_data.pt\"\r\n\r\n    def save_sample_output_data(self):\r\n        if self.sample_output_protocol is not None:\r\n            torch.manual_seed(self.sample_output_protocol.get_random_seed())\r\n            sample_output_data = self.sample_output_protocol.get_sample_output_data(self.validation_dataset,\r\n                                                                                    self.device)\r\n            torch_save(sample_output_data, self.get_sample_output_data_file_name())\r\n        else:\r\n            torch_save({}, self.get_sample_output_data_file_name())\r\n\r\n    def get_module_file_name(self, checkpoint_index, module_name):\r\n        return TrainingState.get_module_file_name(self.get_checkpoint_prefix(checkpoint_index), module_name)\r\n\r\n    def get_last_module_file_name(self, module_name):\r\n        return self.get_module_file_name(len(self.checkpoint_examples) - 1, module_name)\r\n\r\n    def get_log_dir(self):\r\n        if self.log_dir is None:\r\n            now = datetime.now()\r\n            self.log_dir = self.prefix + \"/log/\" + now.strftime(\"%Y_%m_%d__%H_%M_%S\")\r\n        return self.log_dir\r\n\r\n    def get_summary_writer(self) -> SummaryWriter:\r\n        if self.summary_writer is None:\r\n            self.summary_writer = SummaryWriter(log_dir=self.get_log_dir())\r\n 
       return self.summary_writer\r\n\r\n    def get_train_command_name(self) -> str:\r\n        return self.prefix + \"/train\"\r\n\r\n    def get_snapshot_prefix(self) -> str:\r\n        return self.prefix + \"/snapshot\"\r\n\r\n    def get_checkpoint_prefix(self, checkpoint_index) -> str:\r\n        return \"%s/checkpoint/%04d\" % (self.prefix, checkpoint_index)\r\n\r\n    def can_load_training_state(self, prefix) -> bool:\r\n        return TrainingState.can_load(\r\n            prefix,\r\n            self.module_factories,\r\n            self.accumulators,\r\n            self.training_protocol.get_optimizer_factories())\r\n\r\n    def load_training_state(self, prefix) -> TrainingState:\r\n        return TrainingState.load(\r\n            prefix,\r\n            self.module_factories,\r\n            self.accumulators,\r\n            self.training_protocol.get_optimizer_factories(),\r\n            self.device)\r\n\r\n    def get_initial_training_state(self) -> TrainingState:\r\n        training_state = TrainingState.new(\r\n            self.module_factories,\r\n            self.accumulators,\r\n            self.training_protocol.get_optimizer_factories(),\r\n            self.training_protocol.get_random_seed(),\r\n            self.device,\r\n            self.pretrained_module_file_names)\r\n        logging.info(\"Created a new initial training state.\")\r\n        return training_state\r\n\r\n    def load_previous_training_state(self, target_checkpoint_examples: int) -> TrainingState:\r\n        if self.can_load_training_state(self.get_snapshot_prefix()):\r\n            examples_seen_so_far = TrainingState.get_examples_seen_so_far(self.get_snapshot_prefix())\r\n            diff = examples_seen_so_far - target_checkpoint_examples\r\n            if diff < self.training_protocol.get_batch_size():\r\n                return self.load_training_state(self.get_snapshot_prefix())\r\n        num_checkpoints = len(self.checkpoint_examples)\r\n        for checkpoint_index in range(num_checkpoints - 1, -1, -1):\r\n            if self.can_load_training_state(self.get_checkpoint_prefix(checkpoint_index)):\r\n                examples_seen_so_far = TrainingState.get_examples_seen_so_far(\r\n                    self.get_checkpoint_prefix(checkpoint_index))\r\n                diff = examples_seen_so_far - target_checkpoint_examples\r\n                if diff < self.training_protocol.get_batch_size():\r\n                    return self.load_training_state(self.get_checkpoint_prefix(checkpoint_index))\r\n        return self.get_initial_training_state()\r\n\r\n    def get_next_checkpoint_num_examples(self, examples_seen_so_far) -> int:\r\n        next_index = next(\r\n            (i for i in range(len(self.checkpoint_examples)) if self.checkpoint_examples[i] > examples_seen_so_far),\r\n            -1)\r\n        return self.checkpoint_examples[next_index]\r\n\r\n    def get_next_snapshot_num_examples(self, examples_seen_so_far) -> int:\r\n        return get_least_greater_multiple(examples_seen_so_far, self.example_per_snapshot)\r\n\r\n    def get_next_validation_num_examples(self, examples_seen_so_far) -> int:\r\n        if self.validation_protocol is None:\r\n            return -1\r\n        return get_least_greater_multiple(examples_seen_so_far,\r\n                                          self.validation_protocol.get_examples_per_validation_iteration())\r\n\r\n    def get_next_sample_output_num_examples(self, examples_seen_so_far) -> int:\r\n        if self.sample_output_protocol is None:\r\n            
return -1\r\n        return get_least_greater_multiple(examples_seen_so_far,\r\n                                          self.sample_output_protocol.get_examples_per_sample_output())\r\n\r\n    def get_next_num_examples(self, examples_seen_so_far) -> Dict[str, int]:\r\n        return {\r\n            KEY_CHECKPOINT: self.get_next_checkpoint_num_examples(examples_seen_so_far),\r\n            KEY_SNAPSHOT: self.get_next_snapshot_num_examples(examples_seen_so_far),\r\n            KEY_VALIDATION: self.get_next_validation_num_examples(examples_seen_so_far),\r\n            KEY_SAMPLE_OUTPUT: self.get_next_sample_output_num_examples(examples_seen_so_far)\r\n        }\r\n\r\n    def get_checkpoint_index_to_save(self, examples_seen_so_far: int) -> int:\r\n        checkpoint_index = 0\r\n        for i in range(len(self.checkpoint_examples)):\r\n            if self.checkpoint_examples[i] <= examples_seen_so_far:\r\n                checkpoint_index = i\r\n        return checkpoint_index\r\n\r\n    def get_next_training_batch(self):\r\n        if self.training_data_loader is None:\r\n            self.training_data_loader = DataLoader(\r\n                self.training_dataset,\r\n                batch_size=self.training_protocol.get_batch_size(),\r\n                shuffle=True,\r\n                num_workers=self.num_data_loader_workers,\r\n                drop_last=True)\r\n        if self.training_data_loader_iter is None:\r\n            self.training_data_loader_iter = iter(self.training_data_loader)\r\n        try:\r\n            batch = next(self.training_data_loader_iter)\r\n        except StopIteration:\r\n            self.training_data_loader_iter = iter(self.training_data_loader)\r\n            batch = next(self.training_data_loader_iter)\r\n        return [x.to(self.device) for x in batch]\r\n\r\n    def get_next_validation_batch(self):\r\n        if self.validation_dataset is None:\r\n            return None\r\n        if self.validation_data_loader is None:\r\n            self.validation_data_loader = DataLoader(\r\n                self.validation_dataset,\r\n                batch_size=self.validation_protocol.get_batch_size(),\r\n                shuffle=True,\r\n                num_workers=self.num_data_loader_workers,\r\n                drop_last=True)\r\n        if self.validation_data_loader_iter is None:\r\n            self.validation_data_loader_iter = iter(self.validation_data_loader)\r\n        try:\r\n            batch = next(self.validation_data_loader_iter)\r\n        except StopIteration:\r\n            self.validation_data_loader_iter = iter(self.validation_data_loader)\r\n            batch = next(self.validation_data_loader_iter)\r\n        return [x.to(self.device) for x in batch]\r\n\r\n    def get_checkpoint_index(self, target_checkpoint_examples: int):\r\n        return self.checkpoint_examples.index(target_checkpoint_examples)\r\n\r\n    def train(self, target_checkpoint_examples: Optional[int] = None):\r\n        if target_checkpoint_examples is None:\r\n            target_checkpoint_examples = self.checkpoint_examples[-1]\r\n\r\n        sample_output_data = torch_load(self.get_sample_output_data_file_name())\r\n        logging.info(\"Loaded sampled output data from %s\", self.get_sample_output_data_file_name())\r\n        training_state = self.load_previous_training_state(target_checkpoint_examples)\r\n        summary_writer = self.get_summary_writer()\r\n        last_time = time.time()\r\n\r\n        while training_state.examples_seen_so_far < 
target_checkpoint_examples:\r\n            # One training iteration\r\n            learning_rate = self.training_protocol.get_learning_rate(training_state.examples_seen_so_far)\r\n            for module_name in self.module_factories.keys():\r\n                if module_name not in learning_rate or module_name not in training_state.optimizers:\r\n                    continue\r\n                lr = learning_rate[module_name]\r\n                set_learning_rate(training_state.optimizers[module_name], lr)\r\n                self.get_summary_writer().add_scalar(\r\n                    module_name + \"_learning_rate\", lr, training_state.examples_seen_so_far)\r\n            training_batch = self.get_next_training_batch()\r\n            self.training_protocol.run_training_iteration(\r\n                training_batch,\r\n                training_state.examples_seen_so_far,\r\n                training_state.modules,\r\n                training_state.accumulated_modules,\r\n                training_state.optimizers,\r\n                self.losses,\r\n                lambda name, num: create_log_func(summary_writer, name, num),\r\n                self.device)\r\n\r\n            # Accumulate model data\r\n            for module_name in self.accumulators:\r\n                new_module = training_state.modules[module_name]\r\n                buffer_module = training_state.accumulated_modules[module_name]\r\n                self.accumulators[module_name].accumulate(\r\n                    new_module,\r\n                    buffer_module,\r\n                    training_state.examples_seen_so_far)\r\n\r\n            # Advance the number of examples seen so far\r\n            next_num_examples = self.get_next_num_examples(training_state.examples_seen_so_far)\r\n            training_state.examples_seen_so_far += self.training_protocol.get_batch_size()\r\n\r\n            # Validation iteration\r\n            if self.validation_protocol is not None \\\r\n                    and training_state.examples_seen_so_far >= next_num_examples[KEY_VALIDATION]:\r\n                validation_batch = self.get_next_validation_batch()\r\n                self.validation_protocol.run_validation_iteration(\r\n                    validation_batch,\r\n                    training_state.examples_seen_so_far,\r\n                    training_state.modules,\r\n                    training_state.accumulated_modules,\r\n                    self.losses,\r\n                    lambda name, num: create_log_func(summary_writer, name, num),\r\n                    self.device)\r\n\r\n            # Save sample output\r\n            if self.sample_output_protocol is not None \\\r\n                    and training_state.examples_seen_so_far >= next_num_examples[KEY_SAMPLE_OUTPUT]:\r\n                self.sample_output_protocol.save_sample_output_data(\r\n                    training_state.modules,\r\n                    training_state.accumulated_modules,\r\n                    sample_output_data,\r\n                    self.prefix + \"/sample_outputs\",\r\n                    training_state.examples_seen_so_far,\r\n                    self.device)\r\n\r\n            # Save checkpoint\r\n            if training_state.examples_seen_so_far >= next_num_examples[KEY_CHECKPOINT]:\r\n                checkpoint_index = self.get_checkpoint_index_to_save(training_state.examples_seen_so_far)\r\n                training_state.save(self.get_checkpoint_prefix(checkpoint_index))\r\n                if next_num_examples[KEY_CHECKPOINT] != 
next_num_examples[KEY_SNAPSHOT]:\r\n                    training_state.save(self.get_snapshot_prefix())\r\n\r\n            # Save snapshot\r\n            if training_state.examples_seen_so_far >= next_num_examples[KEY_SNAPSHOT]:\r\n                training_state.save(self.get_snapshot_prefix())\r\n\r\n            now = time.time()\r\n            if now - last_time > 10:\r\n                logging.info(\"Showed %d training examples.\" % training_state.examples_seen_so_far)\r\n                last_time = now\r\n"
  },
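The event scheduling in `get_next_num_examples` above reduces to `get_least_greater_multiple` arithmetic per event type: each call returns the next threshold, and the training loop fires the event once `examples_seen_so_far` crosses it. A self-contained check of that arithmetic, with made-up cadences:

```python
from tha4.shion.core.training.util import get_least_greater_multiple

# Hypothetical cadences: a snapshot every 10,000 examples, a validation
# iteration every 2,000, and a sample output dump every 50,000.
examples_seen_so_far = 123_456
assert get_least_greater_multiple(examples_seen_so_far, 10_000) == 130_000
assert get_least_greater_multiple(examples_seen_so_far, 2_000) == 124_000
assert get_least_greater_multiple(examples_seen_so_far, 50_000) == 150_000
```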
  {
    "path": "src/tha4/shion/core/training/swarm/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/shion/core/training/swarm/swarm_training_tasks.py",
    "content": "from typing import Callable, Optional, List\r\n\r\nfrom tha4.pytasuku.workspace import Workspace\r\nfrom tha4.shion.core.training.distrib.distributed_training_tasks import RdzvConfig, \\\r\n    run_standalone_distributed_training_script\r\nfrom tha4.shion.core.training.single.training_states import TrainingState\r\nfrom tha4.shion.core.training.swarm.swarm_unit_trainer import SwarmUnitTrainer\r\n\r\n\r\ndef define_standalone_swarm_training_tasks(\r\n        workspace: Workspace,\r\n        swarm_unit_trainer_func: Callable[[], SwarmUnitTrainer],\r\n        training_script_file_name: str,\r\n        num_proc_per_node: int,\r\n        dependencies: Optional[List[str]] = None,\r\n        rdzv_config: Optional[RdzvConfig] = None):\r\n    trainer = swarm_unit_trainer_func()\r\n    checkpoint_examples = trainer.training_protocol.get_checkpoint_examples()\r\n    assert len(checkpoint_examples) >= 1\r\n    assert checkpoint_examples[0] > 0\r\n    checkpoint_examples = [0] + checkpoint_examples\r\n\r\n    if dependencies is None:\r\n        dependencies = []\r\n    module_file_dependencies = dependencies[:]\r\n    for module_name in trainer.pretrained_module_file_names:\r\n        module_file_dependencies.append(trainer.pretrained_module_file_names[module_name])\r\n\r\n    def create_train_func(target_checkpoint_examples: int):\r\n        return lambda: run_standalone_distributed_training_script(\r\n            training_script_file_name,\r\n            num_proc_per_node,\r\n            target_checkpoint_examples,\r\n            rdzv_config=rdzv_config)\r\n\r\n    train_tasks = []\r\n    for checkpoint_index in range(0, len(checkpoint_examples)):\r\n        for module_name in trainer.module_names:\r\n            module_file_name = TrainingState.get_module_file_name(\r\n                trainer.get_checkpoint_prefix(checkpoint_index),\r\n                module_name)\r\n            workspace.create_file_task(\r\n                module_file_name,\r\n                module_file_dependencies,\r\n                create_train_func(trainer.checkpoint_examples[checkpoint_index]))\r\n        for module_name in trainer.accumulators:\r\n            accumulated_module_file_name = TrainingState.get_accumulated_module_file_name(\r\n                trainer.get_checkpoint_prefix(checkpoint_index),\r\n                module_name)\r\n            workspace.create_file_task(\r\n                accumulated_module_file_name,\r\n                module_file_dependencies,\r\n                create_train_func(checkpoint_examples[checkpoint_index]))\r\n        workspace.create_command_task(\r\n            trainer.get_checkpoint_prefix(checkpoint_index) + \"/train_standalone\",\r\n            module_file_dependencies,\r\n            create_train_func(checkpoint_examples[checkpoint_index]))\r\n        train_tasks.append(trainer.get_checkpoint_prefix(checkpoint_index) + \"/train_standlone\")\r\n    workspace.create_file_task(\r\n        trainer.prefix + \"/train_standalone\",\r\n        module_file_dependencies,\r\n        create_train_func(checkpoint_examples[-1]))\r\n"
  },
  {
    "path": "src/tha4/shion/core/training/swarm/swarm_unit_trainer.py",
    "content": "import argparse\r\nimport logging\r\nimport os\r\nimport time\r\nfrom datetime import datetime\r\nfrom typing import Dict, Optional, Callable\r\nimport torch.distributed\r\n\r\nimport torch\r\nfrom torch.utils.data import Dataset, DataLoader\r\nfrom torch.utils.tensorboard import SummaryWriter\r\n\r\nfrom tha4.shion.core.load_save import torch_save, torch_load\r\nfrom tha4.shion.core.loss import Loss\r\nfrom tha4.shion.core.module_accumulator import ModuleAccumulator\r\nfrom tha4.shion.core.module_factory import ModuleFactory\r\nfrom tha4.shion.core.training.distrib.device_mapper import SimpleCudaDeviceMapper\r\nfrom tha4.shion.core.training.sample_output_protocol import SampleOutputProtocol\r\nfrom tha4.shion.core.training.single.training_states import TrainingState\r\nfrom tha4.shion.core.training.single.training_tasks import KEY_CHECKPOINT, KEY_SNAPSHOT, KEY_VALIDATION, KEY_SAMPLE_OUTPUT\r\nfrom tha4.shion.core.training.training_protocol import TrainingProtocol\r\nfrom tha4.shion.core.training.util import get_least_greater_multiple, create_log_func, set_learning_rate\r\nfrom tha4.shion.core.training.validation_protocol import ValidationProtocol\r\n\r\n\r\nclass SwarmUnitTrainer:\r\n    def __init__(self,\r\n                 prefix: str,\r\n                 module_factories: Dict[str, ModuleFactory],\r\n                 accumulators: Dict[str, ModuleAccumulator],\r\n                 losses: Dict[str, Loss],\r\n                 training_dataset: Dataset,\r\n                 validation_dataset: Optional[Dataset],\r\n                 training_protocol: TrainingProtocol,\r\n                 validation_protocol: Optional[ValidationProtocol],\r\n                 sample_output_protocol: Optional[SampleOutputProtocol],\r\n                 pretrained_module_file_names: Dict[str, str],\r\n                 example_per_snapshot: int,\r\n                 num_data_loader_workers: int = 8):\r\n        self.num_data_loader_workers = num_data_loader_workers\r\n        self.accumulators = accumulators\r\n        self.sample_output_protocol = sample_output_protocol\r\n        self.example_per_snapshot = example_per_snapshot\r\n        self.pretrained_module_file_names = pretrained_module_file_names\r\n        self.losses = losses\r\n        self.validation_protocol = validation_protocol\r\n        self.training_protocol = training_protocol\r\n        self.module_factories = module_factories\r\n        self.prefix = prefix\r\n        self.training_dataset = training_dataset\r\n        self.validation_dataset = validation_dataset\r\n\r\n        self.checkpoint_examples = self.training_protocol.get_checkpoint_examples()\r\n        assert len(self.checkpoint_examples) >= 1\r\n        assert self.checkpoint_examples[0] > 0\r\n        self.checkpoint_examples = [0] + self.checkpoint_examples\r\n\r\n        self.module_names = self.module_factories.keys()\r\n        assert len(self.module_names) > 0\r\n\r\n        self.training_data_loader = None\r\n        self.training_data_loader_iter = None\r\n        self.training_data_loader_batch_size = None\r\n        self.training_data_sampler = None\r\n\r\n        self.validation_data_loader = None\r\n        self.validation_data_loader_iter = None\r\n        self.validation_data_loader_batch_size = None\r\n\r\n        self.sample_output_data = None\r\n        self.summary_writer = None\r\n        self.log_dir = None\r\n        self.training_state = None\r\n\r\n    def get_sample_output_data_file_name(self):\r\n        return self.prefix + 
\"/sample_output_data.pt\"\r\n\r\n    def save_sample_output_data(self, device: torch.device):\r\n        if os.path.exists(self.get_sample_output_data_file_name()):\r\n            return\r\n        if self.sample_output_protocol is not None:\r\n            torch.manual_seed(self.sample_output_protocol.get_random_seed())\r\n            sample_output_data = self.sample_output_protocol.get_sample_output_data(self.validation_dataset, device)\r\n            torch_save(sample_output_data, self.get_sample_output_data_file_name())\r\n        else:\r\n            torch_save({}, self.get_sample_output_data_file_name())\r\n\r\n    def load_sample_output_data(self, device: torch.device):\r\n        self.save_sample_output_data(device)\r\n        return torch_load(self.get_sample_output_data_file_name())\r\n\r\n    def get_snapshot_prefix(self) -> str:\r\n        return self.prefix + \"/snapshot\"\r\n\r\n    def can_load_training_state(self, prefix: str) -> bool:\r\n        return TrainingState.can_load(\r\n            prefix,\r\n            self.module_factories,\r\n            self.accumulators,\r\n            self.training_protocol.get_optimizer_factories())\r\n\r\n    def load_training_state(self, prefix, device: torch.device) -> TrainingState:\r\n        return TrainingState.load(\r\n            prefix,\r\n            self.module_factories,\r\n            self.accumulators,\r\n            self.training_protocol.get_optimizer_factories(),\r\n            device)\r\n\r\n    @staticmethod\r\n    def checkpoint_prefix(prefix: str, checkpoint_index: int) -> str:\r\n        return \"%s/checkpoint/%04d\" % (prefix, checkpoint_index)\r\n\r\n    def get_checkpoint_prefix(self, checkpoint_index) -> str:\r\n        return SwarmUnitTrainer.checkpoint_prefix(self.prefix, checkpoint_index)\r\n\r\n    def get_initial_training_state(self, device: torch.device) -> TrainingState:\r\n        training_state = TrainingState.new(\r\n            self.module_factories,\r\n            self.accumulators,\r\n            self.training_protocol.get_optimizer_factories(),\r\n            self.training_protocol.get_random_seed(),\r\n            device,\r\n            self.pretrained_module_file_names)\r\n        logging.info(\"Created a new initial training state.\")\r\n        return training_state\r\n\r\n    def load_previous_training_state(self,\r\n                                     target_checkpoint_examples: int,\r\n                                     device: torch.device) -> TrainingState:\r\n        if self.can_load_training_state(self.get_snapshot_prefix()):\r\n            examples_seen_so_far = TrainingState.get_examples_seen_so_far(self.get_snapshot_prefix())\r\n            diff = examples_seen_so_far - target_checkpoint_examples\r\n            if diff < self.training_protocol.get_batch_size():\r\n                return self.load_training_state(self.get_snapshot_prefix(), device)\r\n        num_checkpoints = len(self.checkpoint_examples)\r\n        for checkpoint_index in range(num_checkpoints - 1, -1, -1):\r\n            if self.can_load_training_state(self.get_checkpoint_prefix(checkpoint_index)):\r\n                examples_seen_so_far = TrainingState.get_examples_seen_so_far(\r\n                    self.get_checkpoint_prefix(checkpoint_index))\r\n                diff = examples_seen_so_far - target_checkpoint_examples\r\n                if diff < self.training_protocol.get_batch_size():\r\n                    return self.load_training_state(\r\n                        
self.get_checkpoint_prefix(checkpoint_index), device)\r\n\r\n        training_state = self.get_initial_training_state(device)\r\n        training_state.save(self.get_checkpoint_prefix(0))\r\n        training_state = self.load_training_state(self.get_checkpoint_prefix(0), device)\r\n        return training_state\r\n\r\n    def get_log_dir(self):\r\n        if self.log_dir is None:\r\n            now = datetime.now()\r\n            self.log_dir = self.prefix + \"/log/\" + now.strftime(\"%Y_%m_%d__%H_%M_%S\")\r\n        return self.log_dir\r\n\r\n    def get_summary_writer(self) -> Optional[SummaryWriter]:\r\n        if self.summary_writer is None:\r\n            self.summary_writer = SummaryWriter(log_dir=self.get_log_dir())\r\n        return self.summary_writer\r\n\r\n    def get_next_training_batch(self, device: torch.device):\r\n        if self.training_data_loader is None:\r\n            self.training_data_loader = DataLoader(\r\n                self.training_dataset,\r\n                batch_size=self.training_protocol.get_batch_size(),\r\n                shuffle=True,\r\n                num_workers=self.num_data_loader_workers,\r\n                drop_last=True)\r\n        if self.training_data_loader_iter is None:\r\n            self.training_data_loader_iter = iter(self.training_data_loader)\r\n        try:\r\n            batch = next(self.training_data_loader_iter)\r\n        except StopIteration:\r\n            self.training_data_loader_iter = iter(self.training_data_loader)\r\n            batch = next(self.training_data_loader_iter)\r\n        return [x.to(device) for x in batch]\r\n\r\n    def get_next_checkpoint_num_examples(self, examples_seen_so_far) -> int:\r\n        next_index = next(\r\n            (i for i in range(len(self.checkpoint_examples)) if self.checkpoint_examples[i] > examples_seen_so_far),\r\n            -1)\r\n        return self.checkpoint_examples[next_index]\r\n\r\n    def get_next_snapshot_num_examples(self, examples_seen_so_far) -> int:\r\n        return get_least_greater_multiple(examples_seen_so_far, self.example_per_snapshot)\r\n\r\n    def get_next_validation_num_examples(self, examples_seen_so_far) -> int:\r\n        if self.validation_protocol is None:\r\n            return -1\r\n        return get_least_greater_multiple(examples_seen_so_far,\r\n                                          self.validation_protocol.get_examples_per_validation_iteration())\r\n\r\n    def get_next_sample_output_num_examples(self, examples_seen_so_far) -> int:\r\n        if self.sample_output_protocol is None:\r\n            return -1\r\n        return get_least_greater_multiple(examples_seen_so_far,\r\n                                          self.sample_output_protocol.get_examples_per_sample_output())\r\n\r\n    def get_next_num_examples(self, examples_seen_so_far) -> Dict[str, int]:\r\n        return {\r\n            KEY_CHECKPOINT: self.get_next_checkpoint_num_examples(examples_seen_so_far),\r\n            KEY_SNAPSHOT: self.get_next_snapshot_num_examples(examples_seen_so_far),\r\n            KEY_VALIDATION: self.get_next_validation_num_examples(examples_seen_so_far),\r\n            KEY_SAMPLE_OUTPUT: self.get_next_sample_output_num_examples(examples_seen_so_far)\r\n        }\r\n\r\n    def get_next_validation_batch(self, device: torch.device):\r\n        if self.validation_dataset is None:\r\n            return None\r\n        if self.validation_data_loader is None:\r\n            self.validation_data_loader = DataLoader(\r\n                
self.validation_dataset,\r\n                batch_size=self.validation_protocol.get_batch_size(),\r\n                shuffle=True,\r\n                num_workers=1,\r\n                drop_last=True)\r\n        if self.validation_data_loader_iter is None:\r\n            self.validation_data_loader_iter = iter(self.validation_data_loader)\r\n        try:\r\n            batch = next(self.validation_data_loader_iter)\r\n        except StopIteration:\r\n            self.validation_data_loader_iter = iter(self.validation_data_loader)\r\n            batch = next(self.validation_data_loader_iter)\r\n        return [x.to(device) for x in batch]\r\n\r\n    def get_checkpoint_index_to_save(self, examples_seen_so_far: int) -> int:\r\n        checkpoint_index = 0\r\n        for i in range(len(self.checkpoint_examples)):\r\n            if self.checkpoint_examples[i] <= examples_seen_so_far:\r\n                checkpoint_index = i\r\n        return checkpoint_index\r\n\r\n    def train(self,\r\n              rank: int,\r\n              local_rank: int,\r\n              target_checkpoint_examples: Optional[int] = None,\r\n              device_mapper: Optional[Callable[[int, int], torch.device]] = None):\r\n        if target_checkpoint_examples is None:\r\n            target_checkpoint_examples = self.checkpoint_examples[-1]\r\n\r\n        if device_mapper is None:\r\n            device_mapper = SimpleCudaDeviceMapper()\r\n        device = device_mapper(rank, local_rank)\r\n\r\n        sample_output_data = self.load_sample_output_data(device)\r\n        training_state = self.load_previous_training_state(\r\n            target_checkpoint_examples, device)\r\n        summary_writer = self.get_summary_writer()\r\n        if summary_writer is not None:\r\n            log_func_factory = lambda name, num: create_log_func(summary_writer, name, num)\r\n        else:\r\n            log_func_factory = None\r\n        last_time = time.time()\r\n\r\n        while training_state.examples_seen_so_far < target_checkpoint_examples:\r\n            # Set the learning rate\r\n            learning_rate_by_module_name = self.training_protocol.get_learning_rate(training_state.examples_seen_so_far)\r\n            for module_name in self.module_factories.keys():\r\n                if module_name not in learning_rate_by_module_name or module_name not in training_state.optimizers:\r\n                    continue\r\n                lr = learning_rate_by_module_name[module_name]\r\n                set_learning_rate(training_state.optimizers[module_name], lr)\r\n                if summary_writer is not None:\r\n                    summary_writer.add_scalar(\r\n                        module_name + \"_learning_rate\", lr, training_state.examples_seen_so_far)\r\n\r\n            # One training iteration\r\n            training_batch = self.get_next_training_batch(device)\r\n            self.training_protocol.run_training_iteration(\r\n                training_batch,\r\n                training_state.examples_seen_so_far,\r\n                training_state.modules,\r\n                training_state.accumulated_modules,\r\n                training_state.optimizers,\r\n                self.losses,\r\n                log_func_factory,\r\n                device)\r\n\r\n            # Accumulate model data\r\n            for module_name in self.accumulators:\r\n                new_module = training_state.modules[module_name]\r\n                buffer_module = training_state.accumulated_modules[module_name]\r\n                
self.accumulators[module_name].accumulate(\r\n                    new_module, buffer_module, examples_seen_so_far=training_state.examples_seen_so_far)\r\n\r\n            # Advance the number of examples seen so far\r\n            next_num_examples = self.get_next_num_examples(training_state.examples_seen_so_far)\r\n            training_state.examples_seen_so_far += self.training_protocol.get_batch_size()\r\n\r\n            # Validation iteration\r\n            if self.validation_protocol is not None \\\r\n                    and training_state.examples_seen_so_far >= next_num_examples[KEY_VALIDATION]:\r\n                validation_batch = self.get_next_validation_batch(device)\r\n                self.validation_protocol.run_validation_iteration(\r\n                    validation_batch,\r\n                    training_state.examples_seen_so_far,\r\n                    training_state.modules,\r\n                    training_state.accumulated_modules,\r\n                    self.losses,\r\n                    log_func_factory,\r\n                    device)\r\n\r\n            # Save sample output\r\n            if self.sample_output_protocol is not None \\\r\n                    and training_state.examples_seen_so_far >= next_num_examples[KEY_SAMPLE_OUTPUT]:\r\n                self.sample_output_protocol.save_sample_output_data(\r\n                    training_state.modules,\r\n                    training_state.accumulated_modules,\r\n                    sample_output_data,\r\n                    self.prefix + \"/sample_outputs\",\r\n                    training_state.examples_seen_so_far,\r\n                    device)\r\n\r\n            # Save checkpoint\r\n            if training_state.examples_seen_so_far >= next_num_examples[KEY_CHECKPOINT]:\r\n                checkpoint_index = self.get_checkpoint_index_to_save(training_state.examples_seen_so_far)\r\n                training_state.save(self.get_checkpoint_prefix(checkpoint_index))\r\n                if next_num_examples[KEY_CHECKPOINT] != next_num_examples[KEY_SNAPSHOT]:\r\n                    training_state.save(self.get_snapshot_prefix())\r\n\r\n            # Save snapshot\r\n            if training_state.examples_seen_so_far >= next_num_examples[KEY_SNAPSHOT]:\r\n                training_state.save(self.get_snapshot_prefix())\r\n\r\n            now = time.time()\r\n            if now - last_time > 10:\r\n                logging.info(\"[Rank %d] Showed %d training examples.\" % (rank, training_state.examples_seen_so_far))\r\n                last_time = now\r\n\r\n    @staticmethod\r\n    def run(trainer_factory: Dict[int, Callable[[], 'SwarmUnitTrainer']],\r\n            backend: str = 'gloo',\r\n            device_mapper: Optional[Callable[[int, int], torch.device]] = None):\r\n        parser = argparse.ArgumentParser(description='Training script.')\r\n        parser.add_argument(\"--target_checkpoint_examples\", type=int)\r\n        args = parser.parse_args()\r\n\r\n        rank = int(os.environ['RANK'])\r\n        local_rank = int(os.environ['LOCAL_RANK'])\r\n\r\n        torch.distributed.init_process_group(backend)\r\n        if rank in trainer_factory:\r\n            trainer = trainer_factory[rank]()\r\n            trainer.train(rank, local_rank, args.target_checkpoint_examples, device_mapper)\r\n"
  },
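`SwarmUnitTrainer.run` is meant to be the body of a training script launched by `run_standalone_distributed_training_script`, under which `torchrun` sets the `RANK` and `LOCAL_RANK` environment variables it reads. A minimal sketch of such a script; `build_unit_trainer` and the unit prefixes are hypothetical stand-ins for a project-specific factory:

```python
# Hypothetical training script: each rank trains its own independent unit,
# so the factory dict maps a rank to a function that builds that unit's trainer.
from typing import Callable

from tha4.shion.core.training.swarm.swarm_unit_trainer import SwarmUnitTrainer


def build_unit_trainer(prefix: str) -> Callable[[], SwarmUnitTrainer]:
    def factory() -> SwarmUnitTrainer:
        # Construct the SwarmUnitTrainer for this unit here (datasets,
        # module factories, protocols, ...); omitted in this sketch.
        raise NotImplementedError(prefix)

    return factory


if __name__ == "__main__":
    SwarmUnitTrainer.run(
        trainer_factory={
            0: build_unit_trainer("data/units/unit_00"),
            1: build_unit_trainer("data/units/unit_01"),
        },
        backend="gloo")
```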
  {
    "path": "src/tha4/shion/core/training/training_protocol.py",
    "content": "from abc import ABC, abstractmethod\r\nfrom typing import Dict, List, Callable, Any, Optional\r\n\r\nimport torch\r\nfrom torch.nn import Module\r\nfrom torch.optim.optimizer import Optimizer\r\n\r\nfrom tha4.shion.core.loss import Loss\r\nfrom tha4.shion.core.optimizer_factory import OptimizerFactory\r\n\r\n\r\nclass TrainingProtocol(ABC):\r\n    @abstractmethod\r\n    def get_optimizer_factories(self) -> Dict[str, OptimizerFactory]:\r\n        pass\r\n\r\n    @abstractmethod\r\n    def get_checkpoint_examples(self) -> List[int]:\r\n        pass\r\n\r\n    @abstractmethod\r\n    def get_random_seed(self) -> int:\r\n        pass\r\n\r\n    @abstractmethod\r\n    def get_batch_size(self) -> int:\r\n        pass\r\n\r\n    @abstractmethod\r\n    def get_learning_rate(self, examples_seen_so_far: int) -> Dict[str, float]:\r\n        pass\r\n\r\n    @abstractmethod\r\n    def run_training_iteration(\r\n            self,\r\n            batch: Any,\r\n            examples_seen_so_far: int,\r\n            modules: Dict[str, Module],\r\n            accumulated_modules: Dict[str, Module],\r\n            optimizers: Dict[str, Optimizer],\r\n            losses: Dict[str, Loss],\r\n            create_log_func: Optional[Callable[[str, int], Callable[[str, float], None]]],\r\n            device: torch.device):\r\n        pass\r\n\r\n\r\nclass AbstractTrainingProtocol(TrainingProtocol, ABC):\r\n    def __init__(self,\r\n                 check_point_examples: List[int],\r\n                 batch_size: int,\r\n                 learning_rate: Callable[[int], Dict[str, float]],\r\n                 optimizer_factories: Dict[str, OptimizerFactory],\r\n                 random_seed: int):\r\n        self.random_seed = random_seed\r\n        self.optimizer_factories = optimizer_factories\r\n        self.learning_rate = learning_rate\r\n        self.batch_size = batch_size\r\n        self.check_point_examples = check_point_examples\r\n\r\n    def get_optimizer_factories(self) -> Dict[str, OptimizerFactory]:\r\n        return self.optimizer_factories\r\n\r\n    def get_checkpoint_examples(self) -> List[int]:\r\n        return self.check_point_examples\r\n\r\n    def get_random_seed(self) -> int:\r\n        return self.random_seed\r\n\r\n    def get_batch_size(self) -> int:\r\n        return self.batch_size\r\n\r\n    def get_learning_rate(self, examples_seen_so_far: int) -> Dict[str, float]:\r\n        return self.learning_rate(examples_seen_so_far)\r\n"
  },
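  {
    "path": "examples/plain_training_protocol.py",
    "content": "\"\"\"\r\nEditor-added sketch, not part of the original project: a minimal concrete\r\nTrainingProtocol built on AbstractTrainingProtocol. The module/optimizer name\r\n'net', the (input, target) batch structure, and the direct MSE loss (used in\r\nplace of the project's Loss objects) are assumptions made for illustration.\r\n\"\"\"\r\nfrom typing import Any, Callable, Dict, Optional\r\n\r\nimport torch\r\nimport torch.nn.functional as F\r\nfrom torch.nn import Module\r\nfrom torch.optim.optimizer import Optimizer\r\n\r\nfrom tha4.shion.core.loss import Loss\r\nfrom tha4.shion.core.training.training_protocol import AbstractTrainingProtocol\r\n\r\n\r\nclass PlainTrainingProtocol(AbstractTrainingProtocol):\r\n    def run_training_iteration(\r\n            self,\r\n            batch: Any,\r\n            examples_seen_so_far: int,\r\n            modules: Dict[str, Module],\r\n            accumulated_modules: Dict[str, Module],\r\n            optimizers: Dict[str, Optimizer],\r\n            losses: Dict[str, Loss],\r\n            create_log_func: Optional[Callable[[str, int], Callable[[str, float], None]]],\r\n            device: torch.device):\r\n        net = modules['net']\r\n        optimizer = optimizers['net']\r\n        x, y = batch  # assumed (input, target) structure\r\n        optimizer.zero_grad()\r\n        loss = F.mse_loss(net(x.to(device)), y.to(device))\r\n        loss.backward()\r\n        optimizer.step()\r\n        if create_log_func is not None:\r\n            log = create_log_func('training', examples_seen_so_far)\r\n            log('loss', loss.item())\r\n\r\n\r\n# Construction sketch (optimizer_factories: Dict[str, OptimizerFactory], built elsewhere):\r\n#   protocol = PlainTrainingProtocol(\r\n#       check_point_examples=[100000, 200000],\r\n#       batch_size=8,\r\n#       learning_rate=lambda n: {'net': 1e-4},\r\n#       optimizer_factories=optimizer_factories,\r\n#       random_seed=42)\r\n"
  },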
  {
    "path": "src/tha4/shion/core/training/util.py",
    "content": "from typing import Callable\r\n\r\nimport torch\r\nfrom torch.nn import Module\r\nfrom torch.optim import Optimizer\r\n\r\n\r\ndef optimizer_to_device(optim: Optimizer, device: torch.device):\r\n    for state in optim.state.values():\r\n        for k, v in state.items():\r\n            if isinstance(v, torch.Tensor):\r\n                state[k] = v.to(device)\r\n\r\n\r\ndef zero_module(module: Module):\r\n    parameters = dict(module.named_parameters())\r\n    for k in parameters.keys():\r\n        parameters[k].data.zero_()\r\n\r\n\r\ndef get_least_greater_multiple(x: int, m: int) -> int:\r\n    \"\"\"\r\n    :param x: a non-negative integer\r\n    :param m: a positive integer\r\n    :return: the next multiple of m that is greater than x\r\n    \"\"\"\r\n    assert x >= 0\r\n    assert m > 0\r\n    return (x // m + 1) * m\r\n\r\n\r\ndef create_log_func(summary_writer, prefix: str, examples_seen_so_far: int) -> Callable[[str, float], None]:\r\n    def log_func(tag: str, value: float):\r\n        summary_writer.add_scalar(prefix + \"_\" + tag, value, examples_seen_so_far)\r\n\r\n    return log_func\r\n\r\n\r\ndef set_learning_rate(module, lr):\r\n    for param_group in module.param_groups:\r\n        param_group['lr'] = lr\r\n"
  },
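  {
    "path": "examples/training_util_demo.py",
    "content": "\"\"\"\r\nEditor-added sketch, not part of the original project: exercises the helpers\r\nin tha4.shion.core.training.util with a throwaway linear model.\r\n\"\"\"\r\nimport torch\r\nfrom tha4.shion.core.training.util import (\r\n    get_least_greater_multiple, set_learning_rate, zero_module)\r\n\r\n# The least multiple of m that is strictly greater than x:\r\nassert get_least_greater_multiple(0, 8) == 8\r\nassert get_least_greater_multiple(7, 8) == 8\r\nassert get_least_greater_multiple(8, 8) == 16\r\n\r\n# set_learning_rate updates every parameter group of an optimizer:\r\nnet = torch.nn.Linear(4, 4)\r\noptimizer = torch.optim.Adam(net.parameters(), lr=1e-3)\r\nset_learning_rate(optimizer, 1e-4)\r\nassert optimizer.param_groups[0]['lr'] == 1e-4\r\n\r\n# zero_module zeroes all parameters in place:\r\nzero_module(net)\r\nassert net.weight.abs().sum().item() == 0.0\r\n"
  },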
  {
    "path": "src/tha4/shion/core/training/validation_protocol.py",
    "content": "from abc import ABC, abstractmethod\r\nfrom typing import Dict, Callable, Any\r\n\r\nimport torch\r\nfrom torch.nn import Module\r\n\r\nfrom tha4.shion.core.loss import Loss\r\n\r\n\r\nclass ValidationProtocol(ABC):\r\n    @abstractmethod\r\n    def get_batch_size(self) -> int:\r\n        pass\r\n\r\n    @abstractmethod\r\n    def get_examples_per_validation_iteration(self) -> int:\r\n        pass\r\n\r\n    @abstractmethod\r\n    def run_validation_iteration(\r\n            self,\r\n            batch: Any,\r\n            examples_seen_so_far: int,\r\n            modules: Dict[str, Module],\r\n            accumulated_modules: Dict[str, Module],\r\n            losses: Dict[str, Loss],\r\n            create_log_func: Callable[[str, int], Callable[[str, float], None]],\r\n            device: torch.device):\r\n        pass\r\n\r\n\r\nclass AbstractValidationProtocol(ValidationProtocol, ABC):\r\n    def __init__(self,\r\n                 example_per_validation_iteration: int,\r\n                 batch_size: int):\r\n        self.batch_size = batch_size\r\n        self.example_per_validation_iteration = example_per_validation_iteration\r\n\r\n    def get_batch_size(self) -> int:\r\n        return self.batch_size\r\n\r\n    def get_examples_per_validation_iteration(self) -> int:\r\n        return self.example_per_validation_iteration\r\n"
  },
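  {
    "path": "examples/plain_validation_protocol.py",
    "content": "\"\"\"\r\nEditor-added sketch, not part of the original project: a minimal concrete\r\nValidationProtocol. The module name 'net', the (input, target) batch\r\nstructure, and the direct MSE loss are assumptions made for illustration.\r\n\"\"\"\r\nfrom typing import Any, Callable, Dict\r\n\r\nimport torch\r\nimport torch.nn.functional as F\r\nfrom torch.nn import Module\r\n\r\nfrom tha4.shion.core.loss import Loss\r\nfrom tha4.shion.core.training.validation_protocol import AbstractValidationProtocol\r\n\r\n\r\nclass PlainValidationProtocol(AbstractValidationProtocol):\r\n    def run_validation_iteration(\r\n            self,\r\n            batch: Any,\r\n            examples_seen_so_far: int,\r\n            modules: Dict[str, Module],\r\n            accumulated_modules: Dict[str, Module],\r\n            losses: Dict[str, Loss],\r\n            create_log_func: Callable[[str, int], Callable[[str, float], None]],\r\n            device: torch.device):\r\n        net = modules['net']\r\n        x, y = batch\r\n        with torch.no_grad():\r\n            loss = F.mse_loss(net(x.to(device)), y.to(device))\r\n        log = create_log_func('validation', examples_seen_so_far)\r\n        log('loss', loss.item())\r\n\r\n\r\n# Construction sketch: validate on 64 examples per validation pass, 8 at a time.\r\n#   protocol = PlainValidationProtocol(example_per_validation_iteration=64, batch_size=8)\r\n"
  },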
  {
    "path": "src/tha4/shion/nn00/__init__.py",
    "content": ""
  },
  {
    "path": "src/tha4/shion/nn00/block_args.py",
    "content": "from typing import Optional\r\n\r\nfrom torch.nn import Module, Sequential\r\n\r\nfrom tha4.shion.core.module_factory import ModuleFactory\r\nfrom tha4.shion.nn00.linear_module_args import LinearModuleArgs\r\nfrom tha4.shion.nn00.nonlinearity_factories import resolve_nonlinearity_factory\r\nfrom tha4.shion.nn00.normalization_layer_factories import resolve_normalization_layer_factory\r\nfrom tha4.shion.nn00.normalization_layer_factory import NormalizationLayerFactory\r\n\r\n\r\nclass BlockArgs:\r\n    def __init__(\r\n            self,\r\n            linear_module_args: Optional[LinearModuleArgs] = None,\r\n            normalization_layer_factory: Optional[NormalizationLayerFactory] = None,\r\n            nonlinearity_factory: Optional[ModuleFactory] = None):\r\n        if linear_module_args is None:\r\n            linear_module_args = LinearModuleArgs()\r\n        self.nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)\r\n        self.normalization_layer_factory = resolve_normalization_layer_factory(normalization_layer_factory)\r\n        self.linear_module_args = linear_module_args"
  },
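  {
    "path": "examples/block_args_demo.py",
    "content": "\"\"\"\r\nEditor-added sketch, not part of the original project: BlockArgs bundles the\r\nthree knobs shared by the nn00 block builders. Passing None for any of them\r\nselects the defaults: plain initialization without spectral norm,\r\nInstanceNorm2d, and non-inplace ReLU.\r\n\"\"\"\r\nfrom tha4.shion.nn00.block_args import BlockArgs\r\nfrom tha4.shion.nn00.initialization_funcs import HeInitialization\r\nfrom tha4.shion.nn00.linear_module_args import LinearModuleArgs\r\nfrom tha4.shion.nn00.nonlinearity_factories import LeakyReLUFactory\r\nfrom tha4.shion.nn00.normalization_layer_factories import BatchNorm2dFactory\r\n\r\n# A fully customized configuration:\r\nblock_args = BlockArgs(\r\n    linear_module_args=LinearModuleArgs(\r\n        initialization_func=HeInitialization(),  # Kaiming-normal weights\r\n        use_spectral_norm=False),\r\n    normalization_layer_factory=BatchNorm2dFactory(),\r\n    nonlinearity_factory=LeakyReLUFactory(negative_slope=0.2))\r\n\r\n# BlockArgs() with no arguments resolves every field to the defaults above.\r\ndefaults = BlockArgs()\r\n"
  },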
  {
    "path": "src/tha4/shion/nn00/conv.py",
    "content": "from typing import Optional, Union, Callable\r\n\r\nfrom torch.nn import Conv2d, Module, Sequential, ConvTranspose2d\r\n\r\nfrom tha4.shion.nn00.block_args import BlockArgs\r\nfrom tha4.shion.nn00.linear_module_args import LinearModuleArgs, wrap_linear_module\r\n\r\n\r\ndef create_conv7(\r\n        in_channels: int,\r\n        out_channels: int,\r\n        bias: bool = False,\r\n        linear_module_args: Optional[LinearModuleArgs] = None) -> Module:\r\n    return wrap_linear_module(\r\n        Conv2d(in_channels, out_channels, kernel_size=7, stride=1, padding=3, bias=bias),\r\n        linear_module_args)\r\n\r\n\r\ndef create_conv3(in_channels: int,\r\n                 out_channels: int,\r\n                 bias: bool = False,\r\n                 linear_module_args: Optional[LinearModuleArgs] = None) -> Module:\r\n    return wrap_linear_module(\r\n        Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=bias),\r\n        linear_module_args)\r\n\r\n\r\ndef create_conv1(\r\n        in_channels: int, out_channels: int,\r\n        bias: bool = False,\r\n        linear_module_args: Optional[LinearModuleArgs] = None) -> Module:\r\n    return wrap_linear_module(\r\n        Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),\r\n        linear_module_args)\r\n\r\n\r\ndef create_conv7_block(\r\n        in_channels: int,\r\n        out_channels: int,\r\n        block_args: Optional[BlockArgs] = None) -> Module:\r\n    if block_args is None:\r\n        block_args = BlockArgs()\r\n    return Sequential(\r\n        create_conv7(\r\n            in_channels,\r\n            out_channels,\r\n            bias=False,\r\n            linear_module_args=block_args.linear_module_args),\r\n        block_args.normalization_layer_factory.create(out_channels, affine=True),\r\n        block_args.nonlinearity_factory.create())\r\n\r\n\r\ndef create_conv3_block(\r\n        in_channels: int,\r\n        out_channels: int,\r\n        block_args: Optional[BlockArgs] = None) -> Module:\r\n    if block_args is None:\r\n        block_args = BlockArgs()\r\n    return Sequential(\r\n        create_conv7(\r\n            in_channels,\r\n            out_channels,\r\n            bias=False,\r\n            linear_module_args=block_args.linear_module_args),\r\n        block_args.normalization_layer_factory.create(out_channels, affine=True),\r\n        block_args.nonlinearity_factory.create())\r\n\r\n\r\ndef create_downsample_block(\r\n        in_channels: int,\r\n        out_channels: int,\r\n        is_output_1x1: bool = False,\r\n        block_args: Optional[BlockArgs] = None) -> Module:\r\n    if block_args is None:\r\n        block_args = BlockArgs()\r\n    if is_output_1x1:\r\n        return Sequential(\r\n            wrap_linear_module(\r\n                Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),\r\n                block_args.linear_module_args),\r\n            block_args.nonlinearity_factory.create())\r\n    else:\r\n        return Sequential(\r\n            wrap_linear_module(\r\n                Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),\r\n                block_args.linear_module_args),\r\n            block_args.normalization_layer_factory.create(out_channels, affine=True),\r\n            block_args.nonlinearity_factory.create())\r\n\r\n\r\ndef create_upsample_block(\r\n        in_channels: int,\r\n        out_channels: int,\r\n        block_args: Optional[BlockArgs] = 
None) -> Module:\r\n    if block_args is None:\r\n        block_args = BlockArgs()\r\n    return Sequential(\r\n        wrap_linear_module(\r\n            ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),\r\n            linear_module_args=block_args.linear_module_args),\r\n        block_args.normalization_layer_factory.create(out_channels, affine=True),\r\n        block_args.nonlinearity_factory.create())\r\n"
  },
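  {
    "path": "examples/conv_blocks_demo.py",
    "content": "\"\"\"\r\nEditor-added smoke test, not part of the original project: checks the output\r\nshapes of the nn00 convolution block builders with default BlockArgs.\r\n\"\"\"\r\nimport torch\r\nfrom tha4.shion.nn00.conv import (\r\n    create_conv3_block, create_conv7_block, create_downsample_block, create_upsample_block)\r\n\r\nx = torch.randn(1, 16, 64, 64)\r\n\r\n# The 7x7 and 3x3 blocks keep the spatial resolution:\r\nassert create_conv7_block(16, 32)(x).shape == (1, 32, 64, 64)\r\nassert create_conv3_block(16, 32)(x).shape == (1, 32, 64, 64)\r\n\r\n# The stride-2 conv halves H and W; the transposed conv doubles them:\r\nassert create_downsample_block(16, 32)(x).shape == (1, 32, 32, 32)\r\nassert create_upsample_block(16, 32)(x).shape == (1, 32, 128, 128)\r\n"
  },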
  {
    "path": "src/tha4/shion/nn00/initialization_funcs.py",
    "content": "from typing import Callable, Optional\r\n\r\nimport torch\r\nfrom torch import zero_\r\nfrom torch.nn import Module\r\nfrom torch.nn.init import kaiming_normal_, xavier_normal_, normal_\r\n\r\n\r\nclass HeInitialization:\r\n    def __init__(self, a: int = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'):\r\n        self.nonlinearity = nonlinearity\r\n        self.mode = mode\r\n        self.a = a\r\n\r\n    def __call__(self, module: Module) -> Module:\r\n        with torch.no_grad():\r\n            kaiming_normal_(module.weight, a=self.a, mode=self.mode, nonlinearity=self.nonlinearity)\r\n        return module\r\n\r\n\r\nclass NormalInitialization:\r\n    def __init__(self, mean: float = 0.0, std: float = 1.0):\r\n        self.std = std\r\n        self.mean = mean\r\n\r\n    def __call__(self, module: Module) -> Module:\r\n        with torch.no_grad():\r\n            normal_(module.weight, self.mean, self.std)\r\n        return module\r\n\r\n\r\nclass XavierInitialization:\r\n    def __init__(self, gain: float = 1.0):\r\n        self.gain = gain\r\n\r\n    def __call__(self, module: Module) -> Module:\r\n        with torch.no_grad():\r\n            xavier_normal_(module.weight, self.gain)\r\n        return module\r\n\r\n\r\nclass ZeroInitialization:\r\n    def __call__(self, module: Module) -> Module:\r\n        with torch.no_grad:\r\n            zero_(module.weight)\r\n        return module\r\n\r\n\r\nclass NoInitialization:\r\n    def __call__(self, module: Module) -> Module:\r\n        return module\r\n\r\n\r\ndef resolve_initialization_func(initialization: Optional[Callable[[Module], Module]]):\r\n    if initialization is None:\r\n        return NoInitialization()\r\n    else:\r\n        return initialization\r\n"
  },
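  {
    "path": "examples/initialization_demo.py",
    "content": "\"\"\"\r\nEditor-added sketch, not part of the original project: the initialization\r\nobjects are callables that rewrite a module's weight in place and return the\r\nmodule, so they compose directly with module construction.\r\n\"\"\"\r\nfrom torch.nn import Conv2d\r\nfrom tha4.shion.nn00.initialization_funcs import (\r\n    HeInitialization, XavierInitialization, ZeroInitialization)\r\n\r\n# Each call re-initializes conv.weight in place and returns the module:\r\nconv = HeInitialization(nonlinearity='relu')(Conv2d(3, 8, kernel_size=3))\r\nconv = XavierInitialization(gain=1.0)(conv)\r\n\r\n# ZeroInitialization zeroes the weight (the bias is left untouched):\r\nZeroInitialization()(conv)\r\nassert conv.weight.abs().sum().item() == 0.0\r\n"
  },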
  {
    "path": "src/tha4/shion/nn00/linear_module_args.py",
    "content": "from typing import Optional, Callable\r\n\r\nfrom torch.nn import Module\r\nfrom torch.nn.utils import spectral_norm\r\n\r\nfrom tha4.shion.nn00.initialization_funcs import resolve_initialization_func\r\n\r\n\r\nclass LinearModuleArgs:\r\n    def __init__(\r\n            self,\r\n            initialization_func: Optional[Callable[[Module], Module]] = None,\r\n            use_spectral_norm: bool = False):\r\n        self.use_spectral_norm = use_spectral_norm\r\n        self.initialization_func = resolve_initialization_func(initialization_func)\r\n\r\n    def wrap_linear_module(self, module: Module) -> Module:\r\n        module = self.initialization_func(module)\r\n        if self.use_spectral_norm:\r\n            module = spectral_norm(module)\r\n        return module\r\n\r\n\r\ndef wrap_linear_module(module: Module, linear_module_args: Optional[LinearModuleArgs] = None):\r\n    if linear_module_args is None:\r\n        linear_module_args = LinearModuleArgs()\r\n    module = linear_module_args.initialization_func(module)\r\n    if linear_module_args.use_spectral_norm:\r\n        module = spectral_norm(module)\r\n    return module\r\n"
  },
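  {
    "path": "examples/linear_module_args_demo.py",
    "content": "\"\"\"\r\nEditor-added sketch, not part of the original project: LinearModuleArgs\r\nbundles an initialization function with an optional spectral-norm wrapper;\r\nwrap_linear_module applies both to a freshly built linear/conv module.\r\n\"\"\"\r\nfrom torch.nn import Conv2d\r\nfrom tha4.shion.nn00.initialization_funcs import NormalInitialization\r\nfrom tha4.shion.nn00.linear_module_args import LinearModuleArgs, wrap_linear_module\r\n\r\nargs = LinearModuleArgs(\r\n    initialization_func=NormalInitialization(mean=0.0, std=0.02),\r\n    use_spectral_norm=True)\r\nconv = wrap_linear_module(Conv2d(3, 8, kernel_size=1), args)\r\n\r\n# spectral_norm re-registers the raw weight under 'weight_orig':\r\nassert hasattr(conv, 'weight_orig')\r\n\r\n# With no arguments, the module is returned as built, without wrapping:\r\nplain = wrap_linear_module(Conv2d(3, 8, kernel_size=1))\r\nassert not hasattr(plain, 'weight_orig')\r\n"
  },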
  {
    "path": "src/tha4/shion/nn00/nonlinearity_factories.py",
    "content": "from typing import Optional\r\n\r\nimport torch\r\nfrom torch import Tensor\r\nfrom torch.nn import Module, ReLU, LeakyReLU, ELU, ReLU6, Hardswish, SiLU, Tanh, Sigmoid\r\n\r\nfrom tha4.shion.core.module_factory import ModuleFactory\r\n\r\n\r\nclass ReLUFactory(ModuleFactory):\r\n    def __init__(self, inplace: bool = False):\r\n        self.inplace = inplace\r\n\r\n    def create(self) -> Module:\r\n        return ReLU(self.inplace)\r\n\r\n\r\nclass LeakyReLUFactory(ModuleFactory):\r\n    def __init__(self, inplace: bool = False, negative_slope: float = 1e-2):\r\n        self.negative_slope = negative_slope\r\n        self.inplace = inplace\r\n\r\n    def create(self) -> Module:\r\n        return LeakyReLU(inplace=self.inplace, negative_slope=self.negative_slope)\r\n\r\n\r\nclass ELUFactory(ModuleFactory):\r\n    def __init__(self, inplace: bool = False, alpha: float = 1.0):\r\n        self.alpha = alpha\r\n        self.inplace = inplace\r\n\r\n    def create(self) -> Module:\r\n        return ELU(inplace=self.inplace, alpha=self.alpha)\r\n\r\n\r\nclass ReLU6Factory(ModuleFactory):\r\n    def __init__(self, inplace: bool = False):\r\n        self.inplace = inplace\r\n\r\n    def create(self) -> Module:\r\n        return ReLU6(inplace=self.inplace)\r\n\r\n\r\nclass SiLUFactory(ModuleFactory):\r\n    def __init__(self, inplace: bool = False):\r\n        self.inplace = inplace\r\n\r\n    def create(self) -> Module:\r\n        return SiLU(inplace=self.inplace)\r\n\r\n\r\nclass HardswishFactory(ModuleFactory):\r\n    def __init__(self, inplace: bool = False):\r\n        self.inplace = inplace\r\n\r\n    def create(self) -> Module:\r\n        return Hardswish(inplace=self.inplace)\r\n\r\n\r\nclass TanhFactory(ModuleFactory):\r\n    def create(self) -> Module:\r\n        return Tanh()\r\n\r\n\r\nclass SigmoidFactory(ModuleFactory):\r\n    def create(self) -> Module:\r\n        return Sigmoid()\r\n\r\n\r\nclass Swish(Module):\r\n    def __init__(self):\r\n        super().__init__()\r\n\r\n    def forward(self, x: Tensor):\r\n        return x * torch.sigmoid(x)\r\n\r\n\r\nclass SwishFactory(ModuleFactory):\r\n    def create(self) -> Module:\r\n        return Swish()\r\n\r\n\r\ndef resolve_nonlinearity_factory(nonlinearity_factory: Optional[ModuleFactory]) -> ModuleFactory:\r\n    if nonlinearity_factory is None:\r\n        return ReLUFactory(inplace=False)\r\n    else:\r\n        return nonlinearity_factory"
  },
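  {
    "path": "examples/nonlinearity_demo.py",
    "content": "\"\"\"\r\nEditor-added sketch, not part of the original project: the factories wrap\r\nstandard activations behind ModuleFactory so block builders can stay agnostic\r\nabout which nonlinearity they instantiate.\r\n\"\"\"\r\nimport torch\r\nfrom tha4.shion.nn00.nonlinearity_factories import SwishFactory, resolve_nonlinearity_factory\r\n\r\nx = torch.tensor([-1.0, 2.0])\r\n\r\n# None resolves to a plain (non-inplace) ReLU factory:\r\nrelu = resolve_nonlinearity_factory(None).create()\r\nassert torch.equal(relu(x), torch.tensor([0.0, 2.0]))\r\n\r\n# Swish is x * sigmoid(x); numerically it matches torch's SiLU:\r\nswish = SwishFactory().create()\r\nassert torch.allclose(swish(x), torch.nn.functional.silu(x))\r\n"
  },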
  {
    "path": "src/tha4/shion/nn00/normalization_layer_factories.py",
    "content": "from typing import Optional\r\n\r\nimport torch\r\nfrom torch.nn import Module, Parameter, BatchNorm2d, InstanceNorm2d, GroupNorm\r\nfrom torch.nn.functional import layer_norm\r\nfrom torch.nn.init import normal_, constant_\r\n\r\nfrom tha4.shion.nn00.normalization_layer_factory import NormalizationLayerFactory\r\nfrom tha4.shion.nn00.pass_through import PassThrough\r\n\r\n\r\nclass Bias2d(Module):\r\n    def __init__(self, num_features: int):\r\n        super().__init__()\r\n        self.num_features = num_features\r\n        self.bias = Parameter(torch.zeros(1, num_features, 1, 1))\r\n\r\n    def forward(self, x):\r\n        return x + self.bias\r\n\r\n\r\nclass NoNorm2dFactory(NormalizationLayerFactory):\r\n    def __init__(self):\r\n        super().__init__()\r\n\r\n    def create(self, num_features: int, affine: bool = True) -> Module:\r\n        if affine:\r\n            return Bias2d(num_features)\r\n        else:\r\n            return PassThrough()\r\n\r\n\r\nclass BatchNorm2dFactory(NormalizationLayerFactory):\r\n    def __init__(self,\r\n                 weight_mean: Optional[float] = None,\r\n                 weight_std: Optional[float] = None,\r\n                 bias: Optional[float] = None):\r\n        super().__init__()\r\n        self.bias = bias\r\n        self.weight_std = weight_std\r\n        self.weight_mean = weight_mean\r\n\r\n    def get_weight_mean(self):\r\n        if self.weight_mean is None:\r\n            return 1.0\r\n        else:\r\n            return self.weight_mean\r\n\r\n    def get_weight_std(self):\r\n        if self.weight_std is None:\r\n            return 0.02\r\n        else:\r\n            return self.weight_std\r\n\r\n    def create(self, num_features: int, affine: bool = True) -> Module:\r\n        module = BatchNorm2d(num_features=num_features, affine=affine)\r\n        if affine:\r\n            if self.weight_mean is not None or self.weight_std is not None:\r\n                normal_(module.weight, self.get_weight_mean(), self.get_weight_std())\r\n            if self.bias is not None:\r\n                constant_(module.bias, self.bias)\r\n        return module\r\n\r\n\r\nclass InstanceNorm2dFactory(NormalizationLayerFactory):\r\n    def __init__(self):\r\n        super().__init__()\r\n\r\n    def create(self, num_features: int, affine: bool = True) -> Module:\r\n        return InstanceNorm2d(num_features=num_features, affine=affine)\r\n\r\n\r\nclass LayerNorm2d(Module):\r\n    def __init__(self, channels: int, affine: bool = True):\r\n        super(LayerNorm2d, self).__init__()\r\n        self.channels = channels\r\n        self.affine = affine\r\n\r\n        if self.affine:\r\n            self.weight = Parameter(torch.ones(1, channels, 1, 1))\r\n            self.bias = Parameter(torch.zeros(1, channels, 1, 1))\r\n\r\n    def forward(self, x):\r\n        shape = x.size()[1:]\r\n        y = layer_norm(x, shape) * self.weight + self.bias\r\n        return y\r\n\r\n\r\nclass LayerNorm2dFactory(NormalizationLayerFactory):\r\n    def __init__(self):\r\n        super().__init__()\r\n\r\n    def create(self, num_features: int, affine: bool = True) -> Module:\r\n        return LayerNorm2d(channels=num_features, affine=affine)\r\n\r\n\r\nclass GroupNormFactory(NormalizationLayerFactory):\r\n    def __init__(self, num_groups: int, eps=1e-6):\r\n        super().__init__()\r\n        self.eps = eps\r\n        self.num_groups = num_groups\r\n\r\n    def create(self, num_features: int, affine: bool = True) -> Module:\r\n        return 
GroupNorm(num_channels=num_features, num_groups=self.num_groups, eps=self.eps, affine=affine)\r\n\r\n\r\ndef resolve_normalization_layer_factory(factory: Optional['NormalizationLayerFactory']) -> 'NormalizationLayerFactory':\r\n    if factory is None:\r\n        return InstanceNorm2dFactory()\r\n    else:\r\n        return factory\r\n"
  },
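  {
    "path": "examples/normalization_demo.py",
    "content": "\"\"\"\r\nEditor-added sketch, not part of the original project: each factory defers\r\nthe channel count to create(), so BlockArgs can carry one factory object and\r\nstamp out correctly sized normalization layers per block.\r\n\"\"\"\r\nimport torch\r\nfrom tha4.shion.nn00.normalization_layer_factories import (\r\n    GroupNormFactory, NoNorm2dFactory, resolve_normalization_layer_factory)\r\n\r\nx = torch.randn(2, 8, 4, 4)\r\n\r\n# None resolves to InstanceNorm2d:\r\ndefault_norm = resolve_normalization_layer_factory(None).create(8)\r\nassert default_norm(x).shape == x.shape\r\n\r\n# GroupNormFactory fixes the group count; create() supplies the channel count:\r\ngroup_norm = GroupNormFactory(num_groups=4).create(8)\r\nassert group_norm(x).shape == x.shape\r\n\r\n# With affine=True, NoNorm2dFactory degenerates to a per-channel bias,\r\n# which is initialized to zero:\r\nbias_only = NoNorm2dFactory().create(8, affine=True)\r\nassert torch.equal(bias_only(x), x)\r\n"
  },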
  {
    "path": "src/tha4/shion/nn00/normalization_layer_factory.py",
    "content": "from abc import ABC, abstractmethod\r\n\r\nfrom torch.nn import Module\r\n\r\n\r\nclass NormalizationLayerFactory(ABC):\r\n    def __init__(self):\r\n        super().__init__()\r\n\r\n    @abstractmethod\r\n    def create(self, num_features: int, affine: bool = True) -> Module:\r\n        pass\r\n"
  },
  {
    "path": "src/tha4/shion/nn00/pass_through.py",
    "content": "from torch.nn import Module\r\n\r\n\r\nclass PassThrough(Module):\r\n    def __init__(self):\r\n        super().__init__()\r\n\r\n    def forward(self, x):\r\n        return x"
  },
  {
    "path": "src/tha4/shion/nn00/resnet_block.py",
    "content": "from typing import Optional\r\n\r\nimport torch\r\nfrom torch.nn import Module, Sequential, Parameter\r\n\r\nfrom tha4.shion.nn00.block_args import BlockArgs\r\nfrom tha4.shion.nn00.conv import create_conv1, create_conv3\r\n\r\n\r\nclass ResnetBlock(Module):\r\n    def __init__(self,\r\n                 num_channels: int,\r\n                 is1x1: bool = False,\r\n                 use_scale_parameter: bool = False,\r\n                 block_args: Optional[BlockArgs] = None):\r\n        super().__init__()\r\n        if block_args is None:\r\n            block_args = BlockArgs()\r\n        self.use_scale_parameter = use_scale_parameter\r\n        if self.use_scale_parameter:\r\n            self.scale = Parameter(torch.zeros(1))\r\n        if is1x1:\r\n            self.resnet_path = Sequential(\r\n                create_conv1(\r\n                    num_channels,\r\n                    num_channels,\r\n                    bias=True,\r\n                    linear_module_args=block_args.linear_module_args),\r\n                block_args.nonlinearity_factory.create(),\r\n                create_conv1(\r\n                    num_channels,\r\n                    num_channels,\r\n                    bias=True,\r\n                    linear_module_args=block_args.linear_module_args))\r\n        else:\r\n            self.resnet_path = Sequential(\r\n                create_conv3(\r\n                    num_channels,\r\n                    num_channels,\r\n                    bias=False,\r\n                    linear_module_args=block_args.linear_module_args),\r\n                block_args.normalization_layer_factory.create(num_channels, affine=True),\r\n                block_args.nonlinearity_factory.create(),\r\n                create_conv3(\r\n                    num_channels,\r\n                    num_channels,\r\n                    bias=False,\r\n                    linear_module_args=block_args.linear_module_args),\r\n                block_args.normalization_layer_factory.create(num_channels, affine=True))\r\n\r\n    def forward(self, x):\r\n        if self.use_scale_parameter:\r\n            return x + self.scale * self.resnet_path(x)\r\n        else:\r\n            return x + self.resnet_path(x)\r\n"
  }
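  ,
  {
    "path": "examples/resnet_block_demo.py",
    "content": "\"\"\"\r\nEditor-added sketch, not part of the original project: ResnetBlock preserves\r\nthe channel count and resolution, and with use_scale_parameter=True the\r\nresidual path is gated by a scalar initialized to zero, so the block starts\r\nout as an identity function.\r\n\"\"\"\r\nimport torch\r\nfrom tha4.shion.nn00.resnet_block import ResnetBlock\r\n\r\nx = torch.randn(1, 32, 16, 16)\r\n\r\n# Standard residual block: two 3x3 convs with normalization on the residual path.\r\nblock = ResnetBlock(num_channels=32)\r\nassert block(x).shape == x.shape\r\n\r\n# The zero-initialized scale gate makes the block initially the identity.\r\ngated = ResnetBlock(num_channels=32, use_scale_parameter=True)\r\nassert torch.equal(gated(x), x)\r\n"
  }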
]