Repository: pkhungurn/talking-head-anime-4-demo Branch: main Commit: 320640116abd Files: 185 Total size: 655.5 KB Directory structure: gitextract_k8uli292/ ├── .gitignore ├── .python-version ├── LICENSE ├── README.md ├── bin/ │ ├── activate-venv.bat │ ├── activate-venv.sh │ ├── run │ └── run.bat ├── distiller-ui-doc/ │ ├── index.html │ └── params/ │ ├── body_morpher_batch_size.html │ ├── body_morpher_random_seed_0.html │ ├── body_morpher_random_seed_1.html │ ├── character_image_file_name.html │ ├── face_mask_image_file_name.html │ ├── face_morpher_batch_size.html │ ├── face_morpher_random_seed_0.html │ ├── face_morpher_random_seed_1.html │ ├── num_cpu_workers.html │ ├── num_gpus.html │ ├── num_training_examples_per_sample_output.html │ └── prefix.html ├── docs/ │ ├── character_model_ifacialmocap_puppeteer.md │ ├── character_model_manual_poser.md │ ├── character_model_mediapipe_puppeteer.md │ ├── distill.md │ ├── distiller_ui.md │ └── full_manual_poser.md ├── poetry/ │ ├── README.md │ └── pyproject.toml └── src/ └── tha4/ ├── __init__.py ├── app/ │ ├── __init__.py │ ├── character_model_ifacialmocap_puppeteer.py │ ├── character_model_manual_poser.py │ ├── character_model_mediapipe_puppeteer.py │ ├── distill.py │ ├── distiller_ui.py │ └── full_manual_poser.py ├── charmodel/ │ ├── __init__.py │ └── character_model.py ├── dataset/ │ ├── __init__.py │ └── image_poses_and_aother_images_dataset.py ├── distiller/ │ ├── __init__.py │ ├── config_based_training_tasks.py │ ├── distill_body_morpher.py │ ├── distill_face_morpher.py │ ├── distiller_config.py │ └── ui/ │ ├── __init__.py │ ├── distiller_config_state.py │ └── distiller_ui_main_frame.py ├── image_util.py ├── mocap/ │ ├── __init__.py │ ├── ifacialmocap_constants.py │ ├── ifacialmocap_pose.py │ ├── ifacialmocap_pose_converter.py │ ├── ifacialmocap_pose_converter_25.py │ ├── ifacialmocap_v2.py │ ├── mediapipe_constants.py │ ├── mediapipe_face_pose.py │ ├── mediapipe_face_pose_converter.py │ └── mediapipe_face_pose_converter_00.py ├── nn/ │ ├── __init__.py │ ├── common/ │ │ ├── __init__.py │ │ ├── conv_block_factory.py │ │ ├── poser_args.py │ │ ├── poser_encoder_decoder_00.py │ │ ├── poser_encoder_decoder_00_separable.py │ │ ├── resize_conv_encoder_decoder.py │ │ ├── resize_conv_unet.py │ │ └── unet.py │ ├── conv.py │ ├── eyebrow_decomposer/ │ │ ├── __init__.py │ │ └── eyebrow_decomposer_00.py │ ├── eyebrow_morphing_combiner/ │ │ ├── __init__.py │ │ └── eyebrow_morphing_combiner_00.py │ ├── face_morpher/ │ │ ├── __init__.py │ │ └── face_morpher_08.py │ ├── image_processing_util.py │ ├── init_function.py │ ├── morpher/ │ │ ├── __init__.py │ │ └── morpher_00.py │ ├── nonlinearity_factory.py │ ├── normalization.py │ ├── pass_through.py │ ├── resnet_block.py │ ├── resnet_block_seperable.py │ ├── separable_conv.py │ ├── siren/ │ │ ├── __init__.py │ │ ├── face_morpher/ │ │ │ ├── __init__.py │ │ │ ├── siren_face_morpher_00.py │ │ │ ├── siren_face_morpher_00_trainer.py │ │ │ └── siren_face_morpher_protocols_00.py │ │ ├── morpher/ │ │ │ ├── __init__.py │ │ │ ├── siren_morpher_03.py │ │ │ ├── siren_morpher_03_trainer.py │ │ │ └── siren_morpher_protocols_03.py │ │ └── vanilla/ │ │ ├── __init__.py │ │ └── siren.py │ ├── spectral_norm.py │ ├── upscaler/ │ │ ├── __init__.py │ │ └── upscaler_02.py │ └── util.py ├── poser/ │ ├── __init__.py │ ├── general_poser_02.py │ ├── modes/ │ │ ├── __init__.py │ │ ├── mode_07.py │ │ ├── mode_12.py │ │ ├── mode_14.py │ │ └── pose_parameters.py │ └── poser.py ├── pytasuku/ │ ├── __init__.py │ ├── indexed/ │ │ ├── __init__.py 
│ │ ├── all_tasks.py │ │ ├── bundled_indexed_file_tasks.py │ │ ├── indexed_file_tasks.py │ │ ├── indexed_tasks.py │ │ ├── no_index_command_tasks.py │ │ ├── no_index_file_tasks.py │ │ ├── one_index_file_tasks.py │ │ ├── simple_no_index_file_tasks.py │ │ ├── two_indices_file_tasks.py │ │ └── util.py │ ├── task.py │ ├── task_selector_ui.py │ ├── util.py │ └── workspace.py ├── sampleoutput/ │ ├── __init__.py │ ├── general_sample_output_protocol.py │ ├── poser_sampler_output_protocol.py │ └── sample_image_creator.py └── shion/ ├── __init__.py ├── base/ │ ├── __init__.py │ ├── dataset/ │ │ ├── __init__.py │ │ ├── lazy_dataset.py │ │ ├── lazy_tensor_dataset.py │ │ ├── png_in_dir_dataset.py │ │ ├── util.py │ │ └── xformed_dataset.py │ ├── image_util.py │ ├── loss/ │ │ ├── __init__.py │ │ ├── computed_scale_loss.py │ │ ├── computed_scaled_l2_loss.py │ │ ├── l1_loss.py │ │ ├── l2_loss.py │ │ ├── sum_loss.py │ │ └── time_dependently_weighted_loss.py │ ├── module_accumulators.py │ ├── optimizer_factories.py │ ├── protocol/ │ │ └── single_network_from_batch_input_computation_protocol.py │ └── training/ │ ├── __init__.py │ ├── single_network.py │ ├── single_network_with_minibatch.py │ └── two_networks_training_protocol.py ├── core/ │ ├── __init__.py │ ├── cached_computation.py │ ├── load_save.py │ ├── loss.py │ ├── module_accumulator.py │ ├── module_factory.py │ ├── optimizer_factory.py │ └── training/ │ ├── __init__.py │ ├── distrib/ │ │ ├── __init__.py │ │ ├── device_mapper.py │ │ ├── distributed_trainer.py │ │ ├── distributed_training_states.py │ │ └── distributed_training_tasks.py │ ├── sample_output_protocol.py │ ├── single/ │ │ ├── __init__.py │ │ ├── training_states.py │ │ └── training_tasks.py │ ├── swarm/ │ │ ├── __init__.py │ │ ├── swarm_training_tasks.py │ │ └── swarm_unit_trainer.py │ ├── training_protocol.py │ ├── util.py │ └── validation_protocol.py └── nn00/ ├── __init__.py ├── block_args.py ├── conv.py ├── initialization_funcs.py ├── linear_module_args.py ├── nonlinearity_factories.py ├── normalization_layer_factories.py ├── normalization_layer_factory.py ├── pass_through.py └── resnet_block.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Compiled class file *.class # Log file *.log # BlueJ files *.ctxt # Mobile Tools for Java (J2ME) .mtj.tmp/ # Package Files # *.jar *.war *.nar *.ear *.zip *.tar.gz *.rar # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml hs_err_pid* .gradle .vscode out/ .idea/ .gradle/ build/ /data/ .idea/* *.iml */.idea/* */build */.gradle/* */out/* *.pyc *.pyd **/.cache/* */bin __pycache__/ ./tools/ temp/ */*/bin/ venv/* # io dump to ABCI tasks *.o* ================================================ FILE: .python-version ================================================ 3.10.11 ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2024 pixiv Inc. 
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # Demo Code for "Talking Head(?) Anime from a Single Image 4: Improved Model and Its Distillation" This repository contains demo programs for the "Talking Head(?) Anime from a Single Image 4: Improved Model and Its Distillation" project. Roughly, the project is about a machine learning model that can animate an anime character given only one image. However, the model is too slow to run in real time. So, the project also proposes an algorithm that uses the model to train a small machine learning model, specialized to a single character image, that can animate the character in real time. This demo code has two parts. * **Improved model.** This part gives a model similar to [Version 3](https://github.com/pkhungurn/talking-head-anime-3-demo) of the project. It has one demo program: * The `full_manual_poser` allows the user to manipulate a character's facial expression and body rotation through a graphical user interface. There are no real-time demos because the new model is too slow for that. * **Distillation.** This part allows the user to train small models (which we will refer to as **student models**) to mimic the behavior of the full system with regard to a specific character image. It also allows the user to run these models under various interfaces. The demo programs are: * `distill` trains a student model given a configuration file, a $512 \times 512$ RGBA character image, and a mask of facial organs. * `distiller_ui` provides a user-friendly interface to `distill`, allowing you to create training configurations and providing useful documentation. * `character_model_manual_poser` allows the user to control trained student models with a graphical user interface. * `character_model_ifacialmocap_puppeteer` allows the user to control trained student models with their facial movement, which is captured by the [iFacialMocap](https://www.ifacialmocap.com/) software. To run this software, you must have an iOS device and, of course, iFacialMocap. * `character_model_mediapipe_puppeteer` allows the user to control trained student models with their facial movement, which is captured by a web camera and processed by the [Mediapipe FaceLandmarker](https://developers.google.com/mediapipe/solutions/vision/face_landmarker) model. To run this software, you need a web camera. ## Preemptive FAQs ### What is the program to control character images with my facial movement?
There is no such program in this release. If you want one, try the `ifacialmocap_puppeteer` of [Version 3](https://github.com/pkhungurn/talking-head-anime-3-demo). ### OK. I'm confused. Isn't your work about easy VTubing? Are you saying this release cannot do it? NO. This release does it in a more complicated way. In order to control an image, you need to create a "student model." It is a small (< 2MB) and fast machine learning model that knows how to animate that particular image. Then, the student model can be controlled with facial movement. You can find two student models in the `data/character_models` directory. The [two](https://pkhungurn.github.io/talking-head-anime-4/supplementary/webcam-demo/index.html) [demos](https://pkhungurn.github.io/talking-head-anime-4/supplementary/manual-poser-demo/index.html) on the project website feature 13 student models. ### So, for this release, you can control only these few characters in real time? No. You can create your own student models. ### How do I create this student model then? 1. You prepare your character image according to the "Constraints on Input Images" section below. 2. You prepare a black-and-white mask image that covers the eyes and the mouth of the character, like [this image](data/images/lambda_00_face_mask.png). You can see how I made it with [GIMP](https://www.gimp.org/) by inspecting this [GIMP file](data/images/lambda_00_face_mask.xcf). 3. You use `distiller_ui` to create a configuration file that specifies how the student model should be trained. 4. You use `distiller_ui` or `distill` to start the training process. 5. You wait several tens of hours for the student model to finish training. Last time I tried, it was about 30 hours on a computer with an Nvidia RTX A6000 GPU. 6. After that, you can control the student model with `character_model_ifacialmocap_puppeteer` and `character_model_mediapipe_puppeteer`. ### Why is this release so hard to use? [Version 3](https://github.com/pkhungurn/talking-head-anime-3-demo) is arguably easier to use because you can give it an image and control it with your facial movement immediately. However, I was not satisfied with its image quality and speed. In this release, I explore a new way of doing things. I added a new preprocessing stage (i.e., training the student models) that has to be done one time per character image. It allows the image to be animated much faster at a higher image quality level. In other words, it makes the user's life difficult but the engineer/researcher happy. Patient users who are willing to go through the steps, though, will be rewarded with faster animation. ### Can I use a student model from a web browser? No. A student model created by `distill` is a [PyTorch](https://pytorch.org/) model, which cannot run directly in the browser. It needs to be converted to the appropriate format ([TensorFlow.js](https://www.tensorflow.org/js)) first, and the [web](https://pkhungurn.github.io/talking-head-anime-4/supplementary/webcam-demo/index.html) [demos](https://pkhungurn.github.io/talking-head-anime-4/supplementary/manual-poser-demo/index.html) use the converted models. However, the conversion code is not included in this repository. I will not release it unless I change my mind. ## Hardware Requirements All programs require a recent and powerful Nvidia GPU to run. I developed the programs on a machine with an Nvidia RTX A6000. However, anything after the GeForce RTX 2080 should be fine.
The `character_model_ifacialmocap_puppeteer` program requires an iOS device that is capable of computing [blend shape parameters](https://developer.apple.com/documentation/arkit/arfaceanchor/2928251-blendshapes) from a video feed. This means that the device must be able to run iOS 11.0 or higher and must have a TrueDepth front-facing camera. (See [this page](https://developer.apple.com/documentation/arkit/content_anchors/tracking_and_visualizing_faces) for more info.) In other words, if you have the iPhone X or something better, you should be all set. Personally, I have used an iPhone 12 mini. The `character_model_mediapipe_puppeteer` program requires a web camera. ## Software Requirements ### GPU Driver and CUDA Toolkit Please update your GPU's device driver and install the [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit) that is compatible with your GPU and is newer than the version you will be installing in the next subsection. ### Python and Python Libraries All programs are written in the [Python](https://www.python.org/) programming language. The following libraries are required: * `python` 3.10.11 * `torch` 1.13.1 with CUDA support * `torchvision` 0.14.1 * `tensorboard` 2.15.1 * `opencv-python` 4.8.1.78 * `wxpython` 4.2.1 * `numpy-quaternion` 2022.4.2 * `pillow` 9.4.0 * `matplotlib` 3.6.3 * `einops` 0.6.0 * `mediapipe` 0.10.3 * `numpy` 1.26.3 * `scipy` 1.12.0 * `omegaconf` 2.3.0 Instead of installing these libraries yourself, you should follow the recommended method to set up a Python environment in the next section. ### iFacialMocap If you want to use ``character_model_ifacialmocap_puppeteer``, you will also need an iOS app called [iFacialMocap](https://www.ifacialmocap.com/) (a 980-yen purchase on the App Store). Your iOS device and your computer must be on the same network. For example, you may connect them to the same wireless router. ## Creating Python Environment ### Installing Python Please install [Python 3.10.11](https://www.python.org/downloads/release/python-31011/). I recommend using [`pyenv`](https://github.com/pyenv/pyenv) (or [`pyenv-win`](https://github.com/pyenv-win/pyenv-win) for Windows users) to manage multiple Python versions on your system. If you use `pyenv`, this repository has a `.python-version` file that indicates that it uses Python 3.10.11, so this version will be selected automatically once you `cd` into the repository's directory. Make sure that you can run Python from the command line. ### Installing Poetry Please install [Poetry](https://python-poetry.org/) 1.7 or later. We will use it to automatically install the required libraries. Again, make sure that you can run it from the command line. ### Cloning the Repository Please clone the repository to an arbitrary directory on your machine. ### Instruction for Linux/OSX Users 1. Open a shell. 2. `cd` to the directory into which you just cloned the repository. ``` cd SOMEWHERE/talking-head-anime-4-demo ``` 3. Use Python to create a virtual environment under the `venv` directory. ``` python -m venv venv --prompt talking-head-anime-4-demo ``` 4. Activate the newly created virtual environment. You can either use the script I provide: ``` source bin/activate-venv.sh ``` or do it yourself: ``` source venv/bin/activate ``` 5. Use Poetry to install libraries. ``` cd poetry poetry install ``` ### Instruction for Windows Users 1. Open a shell. 2. `cd` to the directory into which you just cloned the repository. ``` cd SOMEWHERE\talking-head-anime-4-demo ``` 3. Use Python to create a virtual environment under the `venv` directory.
``` python -m venv venv --prompt talking-head-anime-4-demo ``` 4. Activate the newly created virtual environment. You can either use the script I provide: ``` bin\activate-venv.bat ``` or do it yourself: ``` venv\Scripts\activate ``` 5. Use Poetry to install libraries. ``` cd poetry poetry install ``` ## Download the Models/Dataset Files ### THA4 Models Please download [this ZIP file](https://www.dropbox.com/scl/fi/7wec0sur7449iqgtlpi3n/tha4-models.zip?rlkey=0f9d1djmbvjjjn09469s1adx8&dl=0) hosted on Dropbox, and unzip it to the `data/tha4` directory under the repository's directory. In the end, the directory tree should look like the following diagram: ``` + talking-head-anime-4-demo + data - character_models - distill_examples + tha4 - body_morpher.pt - eyebrow_decomposer.pt - eyebrow_morphing_combiner.pt - face_morpher.pt - upscaler.pt - images - third_party ``` ### Pose Dataset If you want to create your own student models, you also need to download a dataset of poses that are needed for the training process. Download [this `pose_dataset.pt` file](https://www.dropbox.com/scl/fi/du10e6buzr5bslbe025qu/pose_dataset.pt?rlkey=y052g4n3xb14nu2elctzouc5x&dl=0) and save it to the `data` folder. The directory tree should then look like the following diagram: ``` + talking-head-anime-4-demo + data - character_models - distill_examples - tha4 - images - third_party - pose_dataset.pt ``` ## Running the Programs The programs are located in the `src/tha4/app` directory. You need to run them from a shell with the provided scripts. ### Instruction for Linux/OSX Users 1. Open a shell. 2. `cd` to the repository's directory. ``` cd SOMEWHERE/talking-head-anime-4-demo ``` 3. Run a program. ``` bin/run src/tha4/app/<program-file-name> ``` where `<program-file-name>` can be replaced with: * `character_model_ifacialmocap_puppeteer.py` * `character_model_manual_poser.py` * `character_model_mediapipe_puppeteer.py` * `distill.py` * `distiller_ui.py` * `full_manual_poser.py` ### Instruction for Windows Users 1. Open a shell. 2. `cd` to the repository's directory. ``` cd SOMEWHERE\talking-head-anime-4-demo ``` 3. Run a program. ``` bin\run.bat src\tha4\app\<program-file-name> ``` where `<program-file-name>` can be replaced with: * `character_model_ifacialmocap_puppeteer.py` * `character_model_manual_poser.py` * `character_model_mediapipe_puppeteer.py` * `distill.py` * `distiller_ui.py` * `full_manual_poser.py` ## Constraints on Input Images In order for the system to work well, the input image must obey the following constraints: * It should be of resolution 512 x 512. (If a demo program receives an input image of any other size, it will resize the image to this resolution and also produce output at this resolution.) * It must have an alpha channel. * It must contain only one humanoid character. * The character should be standing upright and facing forward. * The character's hands should be below and far from the head. * The head of the character should roughly be contained in the 128 x 128 box in the middle of the top half of the image. * The alpha channels of all pixels that do not belong to the character (i.e., background pixels) must be 0.
![An example of an image that conforms to the above criteria](docs/images/input_spec.png "An example of an image that conforms to the above criteria") ## Documentation for the Tools * [`character_model_ifacialmocap_puppeteer`](docs/character_model_ifacialmocap_puppeteer.md) * [`character_model_manual_poser`](docs/character_model_manual_poser.md) * [`character_model_mediapipe_puppeteer`](docs/character_model_mediapipe_puppeteer.md) * [`distill`](docs/distill.md) * [`distiller_ui`](docs/distiller_ui.md) * [`full_manual_poser`](docs/full_manual_poser.md) ## Disclaimer The author is an employee of [pixiv Inc.](https://www.pixiv.co.jp/) This project is a part of his work as a researcher. However, this project is NOT a pixiv product. The company will NOT provide any support for this project. The author will try to support the project, but there are no Service Level Agreements (SLAs) that he will maintain. The code is released under the [MIT license](https://github.com/pkhungurn/talking-head-anime-2-demo/blob/master/LICENSE). The THA4 models and the images under the `data/images` directory are released under the [Creative Commons Attribution-NonCommercial 4.0 International](https://creativecommons.org/licenses/by-nc/4.0/deed.en) license. This repository redistributes a version of the [Face landmark detection model](https://developers.google.com/mediapipe/solutions/vision/face_landmarker) from the [MediaPipe](https://developers.google.com/mediapipe) project. The model has been released under the [Apache License, Version 2.0](https://www.apache.org/licenses/LICENSE-2.0.html). ================================================ FILE: bin/activate-venv.bat ================================================ venv\Scripts\activate ================================================ FILE: bin/activate-venv.sh ================================================ #! /bin/bash source venv/bin/activate ================================================ FILE: bin/run ================================================ #! /bin/bash export PYTHONPATH=$(pwd)/src venv/bin/python "$@" ================================================ FILE: bin/run.bat ================================================ set PYTHONPATH=%cd%\src venv\Scripts\python.exe %* ================================================ FILE: distiller-ui-doc/index.html ================================================ Distiller UI Documentation

How to use Distiller UI

This program is called distiller_ui. It allows you to create and modify configurations for the process of distilling the full but slow THA4 system into a student model that can run in real time on computers with moderately powerful GPUs.

Basic Usage

This program manipulates YAML files that are used as configurations for the distillation process. The menus

  • File → New
  • File → Open
  • File → Save
do what they are supposed to do in typical application programs.

You can use the UI in the middle panel to change various parameters of the configuration. If you do not understand the meaning of a parameter, click the "Help" button for that parameter to learn more.

Once you have modified the parameters to your liking, click the "RUN" button at the bottom of the middle panel to carry out the distillation. This will take several tens of hours, so sit back and relax.

The distillation process can be interrupted and resumed at any time. As a result, you do not have to worry that you may lose data if there's a blackout or if you need to free your GPU(s) to do something else. Resuming can be done through this program or through the distill script.
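Because the configuration is an ordinary YAML file, it can also be inspected or tweaked outside this program. Below is a minimal sketch using omegaconf (a library this project already depends on); the flat field names follow the parameter documentation below, but the layout of a real configuration may differ, so treat data/distill_examples/lambda_00/config.yaml as the authoritative example.

```python
# Sketch: inspect and tweak a distillation configuration outside the UI.
# The field names here follow the parameter documentation; the real
# config.yaml may organize them differently.
from omegaconf import OmegaConf

config = OmegaConf.load("data/distill_examples/lambda_00/config.yaml")
print(OmegaConf.to_yaml(config))  # dump every field for inspection

# Hypothetical edit: reduce the batch size if the GPU runs out of memory.
config.body_morpher_batch_size = 4
OmegaConf.save(config, "data/distill_examples/lambda_00/my_config.yaml")
```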

Explanation of Configuration Parameters

================================================ FILE: distiller-ui-doc/params/body_morpher_batch_size.html ================================================ Distiller UI Documentation: body_morpher_batch_size

body_morpher_batch_size

The "batch size" is the number of training examples shown to a machine learning model in one round of parameter update. This parameter is the batch size for training the student body morpher. We recommend you set it to 8. However, if your computer does not have enough GPU RAM, you can reduce the number to any smaller positive integer.


Back to main documentation ================================================ FILE: distiller-ui-doc/params/body_morpher_random_seed_0.html ================================================ Distiller UI Documentation: body_morpher_random_seed_0

body_morpher_random_seed_0

This parameter will be used as a random seed in the process of training the student body morpher. It can be any non-negative integer from 0 to 2^64 - 1. You can specify the number directly, or use the "Randomize" button to specify a random one.
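If you prefer to generate a seed outside this program, any unsigned 64-bit integer works. A minimal sketch in Python:

```python
# Draw a random seed from the documented range [0, 2^64 - 1].
import random

seed = random.getrandbits(64)
print(seed)
```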


Back to main documentation ================================================ FILE: distiller-ui-doc/params/body_morpher_random_seed_1.html ================================================ Distiller UI Documentation: body_morpher_random_seed_1

body_morpher_random_seed_1

This parameter will be used as a random seed in the process of training the student body morpher. It can be any non-negative integer from 0 to 2^64 - 1. You can specify the number directly, or use the "Randomize" button to specify a random one.


Back to main documentation ================================================ FILE: distiller-ui-doc/params/character_image_file_name.html ================================================ Distiller UI Documentation: character_image_file_name

character_image_file_name

This is the name of a file containing an image of a humanoid character. The image must conform to the following specifications.

  • It MUST be in the PNG format.
  • It MUST have an alpha channel.
  • It MUST be 512 x 512.
  • It MUST contain only one humanoid character.
  • The character should be standing upright and facing forward.
  • The character's hands should be below and far from the head.
  • The head of the character should roughly be contained in the 128 x 128 box in the middle of the top half of the image.
  • The alpha channels of all pixels that do not belong to the character (i.e., background pixels) must be 0.

Once you have chosen the image, a crop of the character's face will be shown on the right side of the window. In order for the distillation process to work correctly, make sure that all the movable parts of the face (eyes, eyebrows, mouth, jaw line) can be seen in this crop.


This image is GOOD because we can see all of the eyes, eyebrows, mouth, and jaw line in the image.

This image is NOT GOOD because we cannot see the whole of the jaw line in the image.

This image is NOT GOOD because we cannot see the whole of the right eye and eyebrow in the image.

This image is NOT GOOD because we cannot see the whole of the eyebrows in the image.

The data/images directory contains two example images that conform to all the above specifications: data/images/lambda_00.png and data/images/lambda_01.png. Please use them as references.
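The machine-checkable parts of the specification above can be verified with a short script before you load the image into the UI. Below is a sketch using Pillow and NumPy (both dependencies of this project); the softer requirements, such as the pose and the placement of the head, still need to be judged by eye.

```python
# Sketch: check the machine-verifiable parts of the character image spec.
import PIL.Image
import numpy

image = PIL.Image.open("data/images/lambda_00.png")
assert image.format == "PNG", "must be in the PNG format"
assert image.size == (512, 512), "must be 512 x 512"
assert image.mode == "RGBA", "must have an alpha channel"

# Background pixels must be fully transparent; as a cheap sanity check,
# confirm that the top corners (which should be background) have alpha 0.
alpha = numpy.asarray(image)[:, :, 3]
assert alpha[0, 0] == 0 and alpha[0, -1] == 0, "background alpha must be 0"
```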


Back to main documentation ================================================ FILE: distiller-ui-doc/params/face_mask_image_file_name.html ================================================ Distiller UI Documentation: face_mask_image_file_name

face_mask_image_file_name

This is the name of the file containing binary masks of the movable facial organs of the character. It is probably best to look at an example.

A "face mask image" conforms to the following specification.

  • It must be in the PNG format.
  • It must be 512 x 512.
  • It must be an RGB image (i.e., no alpha channel).
  • All pixels must be either black (0,0,0) or white (255,255,255).
  • The white pixels should cover movable parts of the face.

We recommend creating three rectangles.

  • One covers the right eye and eyebrow.
  • One covers the left eye and eyebrow.
  • One covers the mouth and the jaw line.

The rectangles for the eyes and the eyebrows should extend above the eyes to some extent because the eyebrows can move upward.

Once you have specified the face mask image with the "Change..." button, a crop of the face area will show up on the left side of the window. If the character image has also been specified, an image of the face mask laid over the character's face will also show up. Use this image to check whether the masks are covering everything.
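Likewise, the hard constraints in the specification above are easy to check programmatically. Below is a sketch using Pillow and NumPy; whether the white rectangles actually cover all the movable parts must still be checked visually, as described above.

```python
# Sketch: check a face mask image against the documented specification.
import PIL.Image
import numpy

mask = PIL.Image.open("data/images/lambda_00_face_mask.png")
assert mask.format == "PNG", "must be in the PNG format"
assert mask.size == (512, 512), "must be 512 x 512"
assert mask.mode == "RGB", "must be RGB with no alpha channel"

pixels = numpy.asarray(mask).reshape(-1, 3)
is_black = (pixels == 0).all(axis=1)
is_white = (pixels == 255).all(axis=1)
assert (is_black | is_white).all(), "every pixel must be black or white"
```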


Back to main documentation ================================================ FILE: distiller-ui-doc/params/face_morpher_batch_size.html ================================================ Distiller UI Documentation: face_morpher_batch_size

face_morpher_batch_size

The "batch size" is the number of training examples shown to a machine learning model in one round of parameter update. This parameter is the batch size for training the student face morpher. We recommend you set it to 8. However, if your computer does not have enough GPU RAM, you can reduce the number to any smaller positive integer.


Back to main documentation ================================================ FILE: distiller-ui-doc/params/face_morpher_random_seed_0.html ================================================ Distiller UI Documentation: face_morpher_random_seed_0

face_morpher_random_seed_0

This parameter will be used as a random seed in the process of training the student face morpher. It can be any non-negative integer from 0 to 2^64 - 1. You can specify the number directly, or use the "Randomize" button to specify a random one.


Back to main documentation ================================================ FILE: distiller-ui-doc/params/face_morpher_random_seed_1.html ================================================ Distiller UI Documentation: face_morpher_random_seed_1

face_morpher_random_seed_1

This parameter will be used as a random seed in the process of training the student face morpher. It can be any non-negative integer from 0 to 2^64 - 1. You can specify the number directly, or use the "Randomize" button to specify a random one.


Back to main documentation ================================================ FILE: distiller-ui-doc/params/num_cpu_workers.html ================================================ Distiller UI Documentation: num_cpu_workers

num_cpu_workers

This is the number of worker threads that are used to process pose data during training of the student models. Typically, 1 would be enough, but you can specify up to the number of CPUs your computer has.


Back to main documentation ================================================ FILE: distiller-ui-doc/params/num_gpus.html ================================================ Distiller UI Documentation: num_gpus

num_gpus

This is the number of GPUs that are used to train the student models. Typically, 1 would be enough. However, you can specify up to the number of Nvidia GPUs that your PC has.


Back to main documentation ================================================ FILE: distiller-ui-doc/params/num_training_examples_per_sample_output.html ================================================ Distiller UI Documentation: num_training_examples_per_sample_output

num_training_examples_per_sample_output

During training of a student model, the training process periodically creates "sample outputs" from the model being trained in order to allow the user to see training progress and observe whether there are any anomalies.

This parameter specifies how frequently the sample outputs are generated. You can indicate whether you want a sample output to be generated every time the model being trained has been shown 10,000, 100,000, or 1,000,000 training examples. If you do not care about sample outputs, you can also make the process not generate any sample outputs at all.
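Note that the interval is measured in training examples, not in parameter updates, so the two are related through the batch size. A sketch of the arithmetic, assuming the recommended batch size of 8:

```python
# With batch size 8, a sample output every 100,000 training examples
# corresponds to one sample output every 12,500 parameter updates.
batch_size = 8
examples_per_sample_output = 100_000
print(examples_per_sample_output // batch_size)  # 12500
```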


Back to main documentation ================================================ FILE: distiller-ui-doc/params/prefix.html ================================================ Distiller UI Documentation: prefix

prefix

prefix is the name of the directory under which the distillation process will store the trained models and other intermediate data. Please choose a directory that is a subdirectory of the talking-head-anime-4-demo repository's directory.


Back to main documentation ================================================ FILE: docs/character_model_ifacialmocap_puppeteer.md ================================================ # `character_model_ifacialmocap_puppeteer` This program allows the user to control trained student models with their facial movement, which is captured by the [iFacialMocap](https://www.ifacialmocap.com/) software. You can purchase the software from the App Store for 980 Japanese Yen. ## Invoking the Program Make sure you have (1) created a Python environment and (2) downloaded model files as instructed in the [main README file](../README.md). ### Instruction for Linux/OSX Users 1. Open a shell. 2. `cd` to the repository's directory. ``` cd SOMEWHERE/talking-head-anime-4-demo ``` 3. Run the program. ``` bin/run src/tha4/app/character_model_ifacialmocap_puppeteer.py ``` ### Instruction for Windows Users 1. Open a shell. 2. `cd` to the repository's directory. ``` cd SOMEWHERE\talking-head-anime-4-demo ``` 3. Run the program. ``` bin\run.bat src\tha4\app\character_model_ifacialmocap_puppeteer.py ``` ## Usage 1. Run iFacialMocap on your iOS device. It should show you the device's IP address. Jot it down. Keep the app open. ![IP address in iFacialMocap screen](images/ifacialmocap_ip.jpg "IP address in iFacialMocap screen") 2. Invoke the `character_model_ifacialmocap_puppeteer` application. 3. You will see a text box labeled "Capture Device IP." Type in the iOS device's IP address that you jotted down. ![Write IP address of your iOS device in the 'Capture Device IP' text box.](images/ifacialmocap-puppeteer-device-ip.png "Write IP address of your iOS device in the 'Capture Device IP' text box.") 4. Click the "START CAPTURE!" button to the right. ![Click the 'START CAPTURE!' button.](images/ifacialmocap-puppeteer-start-capture.png "Click the 'START CAPTURE!' button.") If the programs are connected properly, you should see the numbers in the bottom part of the window change when you move your head. ![The numbers in the bottom part of the window should change when you move your head.](images/ifacialmocap-puppeteer-moving-numbers.png "The numbers in the bottom part of the window should change when you move your head.") 5. Now, you can load a student model, and the character should follow your facial movement. ================================================ FILE: docs/character_model_manual_poser.md ================================================ # `character_model_manual_poser` This program allows the user to control trained student models with a graphical user interface, mostly sliders. ## Invoking the Program Make sure you have (1) created a Python environment and (2) downloaded model files as instructed in the [main README file](../README.md). ### Instruction for Linux/OSX Users 1. Open a shell. 2. `cd` to the repository's directory. ``` cd SOMEWHERE/talking-head-anime-4-demo ``` 3. Run the program. ``` bin/run src/tha4/app/character_model_manual_poser.py ``` ### Instruction for Windows Users 1. Open a shell. 2. `cd` to the repository's directory. ``` cd SOMEWHERE\talking-head-anime-4-demo ``` 3. Run the program.
``` bin\run.bat src\tha4\app\character_model_manual_poser.py ``` ================================================ FILE: docs/character_model_mediapipe_puppeteer.md ================================================ # `character_model_mediapipe_puppeteer` This program allows the user to control trained student models with their facial movement, which is captured by a web camera and processed by the [Mediapipe FaceLandmarker](https://developers.google.com/mediapipe/solutions/vision/face_landmarker) model. ## Web Camera Please make sure that, before you invoke the program, your computer has a web camera plugged in. The program will use a web camera, but it does not allow you to specify which. In case your machine has more than one web camera, you can turn off all cameras except the one that you want to use. You can also inspect the [source code](../src/tha4/app/character_model_mediapipe_puppeteer.py) and change the ``` video_capture = cv2.VideoCapture(0) ``` line to choose a particular camera that you want to use. ## Invoking the Program Make sure you have (1) created a Python environment and (2) downloaded model files as instructed in the [main README file](../README.md). ### Instruction for Linux/OSX Users 1. Open a shell. 2. `cd` to the repository's directory. ``` cd SOMEWHERE/talking-head-anime-4-demo ``` 3. Run the program. ``` bin/run src/tha4/app/character_model_mediapipe_puppeteer.py ``` ### Instruction for Windows Users 1. Open a shell. 2. `cd` to the repository's directory. ``` cd SOMEWHERE\talking-head-anime-4-demo ``` 3. Run the program. ``` bin\run.bat src\tha4\app\character_model_mediapipe_puppeteer.py ``` ================================================ FILE: docs/distill.md ================================================ # `distill` This program trains a student model given a configuration file, a $512 \times 512$ RGBA character image, and a mask of facial organs. ## Invoking the Program Make sure you have (1) created a Python environment and (2) downloaded model files as instructed in the [main README file](../README.md). ### Instruction for Linux/OSX Users 1. Open a shell. 2. `cd` to the repository's directory. ``` cd SOMEWHERE/talking-head-anime-4-demo ``` 3. Run the program. ``` bin/run src/tha4/app/distill.py <configuration-file> ``` where `<configuration-file>` is a configuration file for creating a student model. More on this later. ### Instruction for Windows Users 1. Open a shell. 2. `cd` to the repository's directory. ``` cd SOMEWHERE\talking-head-anime-4-demo ``` 3. Run the program. ``` bin\run.bat src\tha4\app\distill.py <configuration-file> ``` where `<configuration-file>` is a configuration file for creating a student model. More on this later. ## Configuration File A configuration file is a [YAML](https://yaml.org/) file that specifies how to create a student model. This repository comes with two valid configuration files that you can peruse: * [data/distill_examples/lambda_00/config.yaml](../data/distill_examples/lambda_00/config.yaml) * [data/distill_examples/lambda_01/config.yaml](../data/distill_examples/lambda_01/config.yaml) I recommend that you use the `distiller_ui` program to create configuration files rather than writing them yourself. Inside the program, you can see what the fields are and what they mean. ## What `distill` Outputs Inside the configuration file, the `prefix` field specifies the directory where the student models will be saved.
After `distill` is done with its job, the output directory will look like this: ``` + <prefix> + body_morpher + face_morpher + character_model - config.yaml ``` Here: * `config.yaml` is a copy of the configuration file that you wrote. * The `character_model` directory contains a trained student model that can be used with `character_model_manual_poser`, `character_model_ifacialmocap_puppeteer`, and `character_model_mediapipe_puppeteer`. * `body_morpher` is a scratch directory that was used to save intermediate results during the training of a part of the student model. * `face_morpher` is a scratch directory that was used to save intermediate results during the training of another part of the student model. You only need what is inside the `character_model` directory. As a result, you can delete the other files after the `character_model` directory has been filled. You can move the directory elsewhere and rename it, as long as its contents are not modified. ## The Training Process Is Interruptible Invoking `distill` on a configuration will start a rather long process of training a student model. On a machine with an Nvidia RTX A6000 GPU, it takes about 30 hours to complete. As a result, it might take several days on machines with less powerful GPUs. The training process is robust and interruptible. You can stop it at any time by closing the shell window or by pressing `Ctrl+C`. Intermediate results are periodically saved in the scratch directories, ready to be picked up at a later time when you are ready to train the student model again. To resume the process, just invoke `distill` again with the same configuration file that you started with, and the process will take care of itself. ================================================ FILE: docs/distiller_ui.md ================================================ # `distiller_ui` This program provides a user-friendly interface to the [`distill`](distill.md) program, allowing you to create training configurations and providing useful documentation. ## Invoking the Program Make sure you have (1) created a Python environment and (2) downloaded model files as instructed in the [main README file](../README.md). ### Instruction for Linux/OSX Users 1. Open a shell. 2. `cd` to the repository's directory. ``` cd SOMEWHERE/talking-head-anime-4-demo ``` 3. Run the program. ``` bin/run src/tha4/app/distiller_ui.py ``` ### Instruction for Windows Users 1. Open a shell. 2. `cd` to the repository's directory. ``` cd SOMEWHERE\talking-head-anime-4-demo ``` 3. Run the program. ``` bin\run.bat src\tha4\app\distiller_ui.py ``` ## Usage Please consult the documentation inside the program itself. It is available on the rightmost panel. ================================================ FILE: docs/full_manual_poser.md ================================================ # `full_manual_poser` This program uses the full version of the Talking Head(?) Anime 4 system to animate character images. ## Invoking the Program Make sure you have (1) created a Python environment and (2) downloaded model files as instructed in the [main README file](../README.md). ### Instruction for Linux/OSX Users 1. Open a shell. 2. `cd` to the repository's directory. ``` cd SOMEWHERE/talking-head-anime-4-demo ``` 3. Run the program. ``` bin/run src/tha4/app/full_manual_poser.py ``` ### Instruction for Windows Users 1. Open a shell. 2. `cd` to the repository's directory. ``` cd SOMEWHERE\talking-head-anime-4-demo ``` 3. Run the program.
``` bin\run.bat src\tha4\app\full_manual_poser.py ``` ================================================ FILE: poetry/README.md ================================================ ================================================ FILE: poetry/pyproject.toml ================================================ [tool.poetry] name = "talking-head-anime-4-demo" version = "0.1.0" description = "Demo code for Talking Head(?) Anime 4" authors = ["Pramook Khungurn "] readme = "README.md" packages = [ {include = "tha4", from = "../src"}, ] [tool.poetry.dependencies] python = ">=3.10, <3.11" torch = {version = "1.13.1", source = "torch_cu117"} torchvision = {version = "0.14.1", source = "torch_cu117"} tensorboard = "^2.15.1" opencv-python = "^4.8.1.78" wxpython = "^4.2.1" numpy-quaternion = "^2022.4.2" pillow = "^9.4.0" matplotlib = "^3.6.3" einops = "^0.6.0" mediapipe = "^0.10.3" numpy = "^1.26.3" scipy = "^1.12.0" omegaconf = "^2.3.0" [[tool.poetry.source]] name = "torch_cu117" url = "https://download.pytorch.org/whl/cu117" priority = "explicit" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" ================================================ FILE: src/tha4/__init__.py ================================================ ================================================ FILE: src/tha4/app/__init__.py ================================================ ================================================ FILE: src/tha4/app/character_model_ifacialmocap_puppeteer.py ================================================ import os import socket import sys import threading import time from typing import Optional import PIL.Image from tha4.shion.base.image_util import torch_linear_to_srgb from tha4.image_util import convert_linear_to_srgb from tha4.mocap.ifacialmocap_pose_converter_25 import create_ifacialmocap_pose_converter from tha4.app.full_manual_poser import resize_PIL_image from tha4.charmodel.character_model import CharacterModel sys.path.append(os.getcwd()) from tha4.mocap.ifacialmocap_pose import create_default_ifacialmocap_pose from tha4.mocap.ifacialmocap_v2 import IFACIALMOCAP_PORT, IFACIALMOCAP_START_STRING, parse_ifacialmocap_v2_pose import torch import wx from tha4.mocap.ifacialmocap_constants import * from tha4.mocap.ifacialmocap_pose_converter import IFacialMocapPoseConverter class FpsStatistics: def __init__(self): self.count = 100 self.fps = [] def add_fps(self, fps): self.fps.append(fps) while len(self.fps) > self.count: del self.fps[0] def get_average_fps(self): if len(self.fps) == 0: return 0.0 else: return sum(self.fps) / len(self.fps) class MainFrame(wx.Frame): IMAGE_SIZE = 512 def __init__(self, pose_converter: IFacialMocapPoseConverter, device: torch.device): super().__init__(None, wx.ID_ANY, "iFacialMocap Puppeteer (Fuji)") self.poser = None self.pose_converter = pose_converter self.device = device self.ifacialmocap_pose = create_default_ifacialmocap_pose() self.source_image_bitmap = wx.Bitmap(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE) self.result_image_bitmap = wx.Bitmap(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE) self.wx_source_image = None self.torch_source_image = None self.last_pose = None self.fps_statistics = FpsStatistics() self.last_update_time = None self.create_receiving_socket() self.create_ui() self.create_timers() self.Bind(wx.EVT_CLOSE, self.on_close) self.update_source_image_bitmap() self.update_result_image_bitmap() def create_receiving_socket(self): self.receiving_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 
self.receiving_socket.bind(("", IFACIALMOCAP_PORT)) self.receiving_socket.setblocking(False) def create_timers(self): self.capture_timer = wx.Timer(self, wx.ID_ANY) self.Bind(wx.EVT_TIMER, self.update_capture_panel, id=self.capture_timer.GetId()) self.animation_timer = wx.Timer(self, wx.ID_ANY) self.Bind(wx.EVT_TIMER, self.update_result_image_bitmap, id=self.animation_timer.GetId()) def on_close(self, event: wx.Event): # Stop the timers self.animation_timer.Stop() self.capture_timer.Stop() # Close receiving socket self.receiving_socket.close() # Destroy the windows self.Destroy() event.Skip() def on_start_capture(self, event: wx.Event): capture_device_ip_address = self.capture_device_ip_text_ctrl.GetValue() out_socket = None try: address = (capture_device_ip_address, IFACIALMOCAP_PORT) out_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) out_socket.sendto(IFACIALMOCAP_START_STRING, address) except Exception as e: message_dialog = wx.MessageDialog(self, str(e), "Error!", wx.OK) message_dialog.ShowModal() message_dialog.Destroy() finally: if out_socket is not None: out_socket.close() def read_ifacialmocap_pose(self): if not self.animation_timer.IsRunning(): return self.ifacialmocap_pose socket_bytes = None while True: try: socket_bytes = self.receiving_socket.recv(8192) except socket.error as e: break if socket_bytes is not None: socket_string = socket_bytes.decode("utf-8") self.ifacialmocap_pose = parse_ifacialmocap_v2_pose(socket_string) return self.ifacialmocap_pose def on_erase_background(self, event: wx.Event): pass def create_animation_panel(self, parent): self.animation_panel = wx.Panel(parent, style=wx.RAISED_BORDER) self.animation_panel_sizer = wx.BoxSizer(wx.HORIZONTAL) self.animation_panel.SetSizer(self.animation_panel_sizer) self.animation_panel.SetAutoLayout(1) image_size = MainFrame.IMAGE_SIZE if True: self.input_panel = wx.Panel(self.animation_panel, size=(image_size, image_size + 128), style=wx.SIMPLE_BORDER) self.input_panel_sizer = wx.BoxSizer(wx.VERTICAL) self.input_panel.SetSizer(self.input_panel_sizer) self.input_panel.SetAutoLayout(1) self.animation_panel_sizer.Add(self.input_panel, 0, wx.FIXED_MINSIZE) self.source_image_panel = wx.Panel(self.input_panel, size=(image_size, image_size), style=wx.SIMPLE_BORDER) self.source_image_panel.Bind(wx.EVT_PAINT, self.paint_source_image_panel) self.source_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background) self.input_panel_sizer.Add(self.source_image_panel, 0, wx.FIXED_MINSIZE) self.load_model_button = wx.Button(self.input_panel, wx.ID_ANY, "Load Model") self.input_panel_sizer.Add(self.load_model_button, 1, wx.EXPAND) self.load_model_button.Bind(wx.EVT_BUTTON, self.load_model) self.input_panel_sizer.Fit(self.input_panel) if True: self.pose_converter.init_pose_converter_panel(self.animation_panel) if True: self.animation_left_panel = wx.Panel(self.animation_panel, style=wx.SIMPLE_BORDER) self.animation_left_panel_sizer = wx.BoxSizer(wx.VERTICAL) self.animation_left_panel.SetSizer(self.animation_left_panel_sizer) self.animation_left_panel.SetAutoLayout(1) self.animation_panel_sizer.Add(self.animation_left_panel, 0, wx.EXPAND) self.result_image_panel = wx.Panel(self.animation_left_panel, size=(image_size, image_size), style=wx.SIMPLE_BORDER) self.result_image_panel.Bind(wx.EVT_PAINT, self.paint_result_image_panel) self.result_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background) self.animation_left_panel_sizer.Add(self.result_image_panel, 0, wx.FIXED_MINSIZE) separator = 
wx.StaticLine(self.animation_left_panel, -1, size=(256, 5)) self.animation_left_panel_sizer.Add(separator, 0, wx.EXPAND) background_text = wx.StaticText(self.animation_left_panel, label="--- Background ---", style=wx.ALIGN_CENTER) self.animation_left_panel_sizer.Add(background_text, 0, wx.EXPAND) self.output_background_choice = wx.Choice( self.animation_left_panel, choices=[ "TRANSPARENT", "GREEN", "BLUE", "BLACK", "WHITE" ]) self.output_background_choice.SetSelection(0) self.animation_left_panel_sizer.Add(self.output_background_choice, 0, wx.EXPAND) separator = wx.StaticLine(self.animation_left_panel, -1, size=(256, 5)) self.animation_left_panel_sizer.Add(separator, 0, wx.EXPAND) self.fps_text = wx.StaticText(self.animation_left_panel, label="") self.animation_left_panel_sizer.Add(self.fps_text, wx.SizerFlags().Border()) self.animation_left_panel_sizer.Fit(self.animation_left_panel) self.animation_panel_sizer.Fit(self.animation_panel) def create_ui(self): self.main_sizer = wx.BoxSizer(wx.VERTICAL) self.SetSizer(self.main_sizer) self.SetAutoLayout(1) self.capture_pose_lock = threading.Lock() self.create_connection_panel(self) self.main_sizer.Add(self.connection_panel, wx.SizerFlags(0).Expand().Border(wx.ALL, 5)) self.create_animation_panel(self) self.main_sizer.Add(self.animation_panel, wx.SizerFlags(0).Expand().Border(wx.ALL, 5)) self.create_capture_panel(self) self.main_sizer.Add(self.capture_panel, wx.SizerFlags(0).Expand().Border(wx.ALL, 5)) self.main_sizer.Fit(self) def create_connection_panel(self, parent): self.connection_panel = wx.Panel(parent, style=wx.RAISED_BORDER) self.connection_panel_sizer = wx.BoxSizer(wx.HORIZONTAL) self.connection_panel.SetSizer(self.connection_panel_sizer) self.connection_panel.SetAutoLayout(1) capture_device_ip_text = wx.StaticText(self.connection_panel, label="Capture Device IP:", style=wx.ALIGN_RIGHT) self.connection_panel_sizer.Add(capture_device_ip_text, wx.SizerFlags(0).FixedMinSize().Border(wx.ALL, 3)) self.capture_device_ip_text_ctrl = wx.TextCtrl(self.connection_panel, value="192.168.0.1") self.connection_panel_sizer.Add(self.capture_device_ip_text_ctrl, wx.SizerFlags(1).Expand().Border(wx.ALL, 3)) self.start_capture_button = wx.Button(self.connection_panel, label="START CAPTURE!") self.connection_panel_sizer.Add(self.start_capture_button, wx.SizerFlags(0).FixedMinSize().Border(wx.ALL, 3)) self.start_capture_button.Bind(wx.EVT_BUTTON, self.on_start_capture) def create_capture_panel(self, parent): self.capture_panel = wx.Panel(parent, style=wx.RAISED_BORDER) self.capture_panel_sizer = wx.FlexGridSizer(cols=5) for i in range(5): self.capture_panel_sizer.AddGrowableCol(i) self.capture_panel.SetSizer(self.capture_panel_sizer) self.capture_panel.SetAutoLayout(1) self.rotation_labels = {} self.rotation_value_labels = {} rotation_column_0 = self.create_rotation_column(self.capture_panel, RIGHT_EYE_BONE_ROTATIONS) self.capture_panel_sizer.Add(rotation_column_0, wx.SizerFlags(0).Expand().Border(wx.ALL, 3)) rotation_column_1 = self.create_rotation_column(self.capture_panel, LEFT_EYE_BONE_ROTATIONS) self.capture_panel_sizer.Add(rotation_column_1, wx.SizerFlags(0).Expand().Border(wx.ALL, 3)) rotation_column_2 = self.create_rotation_column(self.capture_panel, HEAD_BONE_ROTATIONS) self.capture_panel_sizer.Add(rotation_column_2, wx.SizerFlags(0).Expand().Border(wx.ALL, 3)) def create_rotation_column(self, parent, rotation_names): column_panel = wx.Panel(parent, style=wx.SIMPLE_BORDER) column_panel_sizer = wx.FlexGridSizer(cols=2) 
column_panel_sizer.AddGrowableCol(1) column_panel.SetSizer(column_panel_sizer) column_panel.SetAutoLayout(1) for rotation_name in rotation_names: self.rotation_labels[rotation_name] = wx.StaticText( column_panel, label=rotation_name, style=wx.ALIGN_RIGHT) column_panel_sizer.Add(self.rotation_labels[rotation_name], wx.SizerFlags(1).Expand().Border(wx.ALL, 3)) self.rotation_value_labels[rotation_name] = wx.TextCtrl( column_panel, style=wx.TE_RIGHT) self.rotation_value_labels[rotation_name].SetValue("0.00") self.rotation_value_labels[rotation_name].Disable() column_panel_sizer.Add(self.rotation_value_labels[rotation_name], wx.SizerFlags(1).Expand().Border(wx.ALL, 3)) column_panel.GetSizer().Fit(column_panel) return column_panel def paint_capture_panel(self, event: wx.Event): self.update_capture_panel(event) def update_capture_panel(self, event: wx.Event): data = self.ifacialmocap_pose for rotation_name in ROTATION_NAMES: value = data[rotation_name] self.rotation_value_labels[rotation_name].SetValue("%0.2f" % value) @staticmethod def convert_to_100(x): return int(max(0.0, min(1.0, x)) * 100) def paint_source_image_panel(self, event: wx.Event): wx.BufferedPaintDC(self.source_image_panel, self.source_image_bitmap) def update_source_image_bitmap(self): dc = wx.MemoryDC() dc.SelectObject(self.source_image_bitmap) if self.wx_source_image is None: self.draw_nothing_yet_string(dc) else: dc.Clear() dc.DrawBitmap(self.wx_source_image, 0, 0, True) del dc def draw_nothing_yet_string(self, dc): dc.Clear() font = wx.Font(wx.FontInfo(14).Family(wx.FONTFAMILY_SWISS)) dc.SetFont(font) w, h = dc.GetTextExtent("Nothing yet!") dc.DrawText("Nothing yet!", (MainFrame.IMAGE_SIZE - w) // 2, (MainFrame.IMAGE_SIZE - h) // 2) def paint_result_image_panel(self, event: wx.Event): wx.BufferedPaintDC(self.result_image_panel, self.result_image_bitmap) def update_result_image_bitmap(self, event: Optional[wx.Event] = None): ifacialmocap_pose = self.read_ifacialmocap_pose() current_pose = self.pose_converter.convert(ifacialmocap_pose) if self.last_pose is not None and self.last_pose == current_pose: return self.last_pose = current_pose if self.torch_source_image is None or self.poser is None: dc = wx.MemoryDC() dc.SelectObject(self.result_image_bitmap) self.draw_nothing_yet_string(dc) del dc return pose = torch.tensor(current_pose, device=self.device, dtype=self.poser.get_dtype()) with torch.no_grad(): output_image = self.poser.pose(self.torch_source_image, pose)[0].float() output_image = torch.clip((output_image + 1.0) / 2.0, 0.0, 1.0) output_image = convert_linear_to_srgb(output_image) background_choice = self.output_background_choice.GetSelection() if background_choice == 0: pass else: background = torch.zeros(4, output_image.shape[1], output_image.shape[2], device=self.device) background[3, :, :] = 1.0 if background_choice == 1: background[1, :, :] = 1.0 output_image = self.blend_with_background(output_image, background) elif background_choice == 2: background[2, :, :] = 1.0 output_image = self.blend_with_background(output_image, background) elif background_choice == 3: output_image = self.blend_with_background(output_image, background) else: background[0:3, :, :] = 1.0 output_image = self.blend_with_background(output_image, background) c, h, w = output_image.shape output_image = 255.0 * torch.transpose(output_image.reshape(c, h * w), 0, 1).reshape(h, w, c) output_image = output_image.byte() numpy_image = output_image.detach().cpu().numpy() wx_image = wx.ImageFromBuffer(numpy_image.shape[0], numpy_image.shape[1], 
numpy_image[:, :, 0:3].tobytes(), numpy_image[:, :, 3].tobytes()) wx_bitmap = wx_image.ConvertToBitmap() dc = wx.MemoryDC() dc.SelectObject(self.result_image_bitmap) dc.Clear() dc.DrawBitmap(wx_bitmap, (MainFrame.IMAGE_SIZE - numpy_image.shape[0]) // 2, (MainFrame.IMAGE_SIZE - numpy_image.shape[1]) // 2, True) del dc time_now = time.time_ns() if self.last_update_time is not None: elapsed_time = time_now - self.last_update_time fps = 1.0 / (elapsed_time / 10 ** 9) if self.torch_source_image is not None: self.fps_statistics.add_fps(fps) self.fps_text.SetLabelText("FPS = %0.2f" % self.fps_statistics.get_average_fps()) self.last_update_time = time_now self.Refresh() def blend_with_background(self, numpy_image, background): alpha = numpy_image[3:4, :, :] color = numpy_image[0:3, :, :] new_color = color * alpha + (1.0 - alpha) * background[0:3, :, :] return torch.cat([new_color, background[3:4, :, :]], dim=0) def load_model(self, event: wx.Event): dir_name = "data/character_models" file_dialog = wx.FileDialog(self, "Choose a model", dir_name, "", "*.yaml", wx.FD_OPEN) if file_dialog.ShowModal() == wx.ID_OK: character_model_json_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename()) try: self.character_model = CharacterModel.load(character_model_json_file_name) self.torch_source_image = self.character_model.get_character_image(self.device) pil_image = resize_PIL_image( PIL.Image.open(self.character_model.character_image_file_name), (MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE)) w, h = pil_image.size self.wx_source_image = wx.Bitmap.FromBufferRGBA(w, h, pil_image.convert("RGBA").tobytes()) self.update_source_image_bitmap() self.poser = self.character_model.get_poser(self.device) except Exception: message_dialog = wx.MessageDialog( self, "Could not load character model " + character_model_json_file_name, "Poser", wx.OK) message_dialog.ShowModal() message_dialog.Destroy() file_dialog.Destroy() self.Refresh() if __name__ == "__main__": device = torch.device('cuda:0') pose_converter = create_ifacialmocap_pose_converter() app = wx.App() main_frame = MainFrame(pose_converter, device) main_frame.Show(True) main_frame.capture_timer.Start(10) main_frame.animation_timer.Start(10) app.MainLoop() ================================================ FILE: src/tha4/app/character_model_manual_poser.py ================================================ import logging import os import sys import time from typing import List from tha4.charmodel.character_model import CharacterModel from tha4.image_util import resize_PIL_image, convert_output_image_from_torch_to_numpy from tha4.poser.modes.mode_14 import get_pose_parameters sys.path.append(os.getcwd()) import PIL.Image import torch import wx from tha4.poser.poser import PoseParameterCategory, PoseParameterGroup class MorphCategoryControlPanel(wx.Panel): def __init__(self, parent, title: str, pose_param_category: PoseParameterCategory, param_groups: List[PoseParameterGroup]): super().__init__(parent, style=wx.SIMPLE_BORDER) self.pose_param_category = pose_param_category self.sizer = wx.BoxSizer(wx.VERTICAL) self.SetSizer(self.sizer) self.SetAutoLayout(1) title_text = wx.StaticText(self, label=title, style=wx.ALIGN_CENTER) self.sizer.Add(title_text, 0, wx.EXPAND) self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category] self.choice = wx.Choice(self, choices=[group.get_group_name() for group in self.param_groups]) if len(self.param_groups) > 0: self.choice.SetSelection(0) self.choice.Bind(wx.EVT_CHOICE, 
self.on_choice_updated) self.sizer.Add(self.choice, 0, wx.EXPAND) self.left_slider = wx.Slider(self, minValue=-1000, maxValue=1000, value=-1000, style=wx.HORIZONTAL) self.sizer.Add(self.left_slider, 0, wx.EXPAND) self.right_slider = wx.Slider(self, minValue=-1000, maxValue=1000, value=-1000, style=wx.HORIZONTAL) self.sizer.Add(self.right_slider, 0, wx.EXPAND) self.checkbox = wx.CheckBox(self, label="Show") self.checkbox.SetValue(True) self.sizer.Add(self.checkbox, 0, wx.SHAPED | wx.ALIGN_CENTER) self.update_ui() self.sizer.Fit(self) def update_ui(self): param_group = self.param_groups[self.choice.GetSelection()] if param_group.is_discrete(): self.left_slider.Enable(False) self.right_slider.Enable(False) self.checkbox.Enable(True) elif param_group.get_arity() == 1: self.left_slider.Enable(True) self.right_slider.Enable(False) self.checkbox.Enable(False) else: self.left_slider.Enable(True) self.right_slider.Enable(True) self.checkbox.Enable(False) def on_choice_updated(self, event: wx.Event): param_group = self.param_groups[self.choice.GetSelection()] if param_group.is_discrete(): self.checkbox.SetValue(True) self.update_ui() def set_param_value(self, pose: List[float]): if len(self.param_groups) == 0: return selected_morph_index = self.choice.GetSelection() param_group = self.param_groups[selected_morph_index] param_index = param_group.get_parameter_index() if param_group.is_discrete(): if self.checkbox.GetValue(): for i in range(param_group.get_arity()): pose[param_index + i] = 1.0 else: param_range = param_group.get_range() alpha = (self.left_slider.GetValue() + 1000) / 2000.0 pose[param_index] = param_range[0] + (param_range[1] - param_range[0]) * alpha if param_group.get_arity() == 2: alpha = (self.right_slider.GetValue() + 1000) / 2000.0 pose[param_index + 1] = param_range[0] + (param_range[1] - param_range[0]) * alpha class SimpleParamGroupsControlPanel(wx.Panel): def __init__(self, parent, pose_param_category: PoseParameterCategory, param_groups: List[PoseParameterGroup]): super().__init__(parent, style=wx.SIMPLE_BORDER) self.sizer = wx.BoxSizer(wx.VERTICAL) self.SetSizer(self.sizer) self.SetAutoLayout(1) self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category] for param_group in self.param_groups: assert not param_group.is_discrete() assert param_group.get_arity() == 1 self.sliders = [] for param_group in self.param_groups: static_text = wx.StaticText( self, label=" ------------ %s ------------ " % param_group.get_group_name(), style=wx.ALIGN_CENTER) self.sizer.Add(static_text, 0, wx.EXPAND) range = param_group.get_range() min_value = int(range[0] * 1000) max_value = int(range[1] * 1000) slider = wx.Slider(self, minValue=min_value, maxValue=max_value, value=0, style=wx.HORIZONTAL) self.sizer.Add(slider, 0, wx.EXPAND) self.sliders.append(slider) self.sizer.Fit(self) def set_param_value(self, pose: List[float]): if len(self.param_groups) == 0: return for param_group_index in range(len(self.param_groups)): param_group = self.param_groups[param_group_index] slider = self.sliders[param_group_index] param_range = param_group.get_range() param_index = param_group.get_parameter_index() alpha = (slider.GetValue() - slider.GetMin()) * 1.0 / (slider.GetMax() - slider.GetMin()) pose[param_index] = param_range[0] + (param_range[1] - param_range[0]) * alpha class MainFrame(wx.Frame): IMAGE_SIZE = 512 OUTPUT_LENGTH = 6 NUM_PARAMETERS = 45 def __init__(self, device: torch.device): super().__init__(None, wx.ID_ANY, "Poser") self.poser = None self.device = 
device self.wx_source_image = None self.torch_source_image = None self.main_sizer = wx.BoxSizer(wx.HORIZONTAL) self.SetSizer(self.main_sizer) self.SetAutoLayout(1) self.init_left_panel() self.init_control_panel() self.init_right_panel() self.main_sizer.Fit(self) self.timer = wx.Timer(self, wx.ID_ANY) self.Bind(wx.EVT_TIMER, self.update_images, self.timer) save_image_id = wx.NewIdRef() self.Bind(wx.EVT_MENU, self.on_save_image, id=save_image_id) accelerator_table = wx.AcceleratorTable([ (wx.ACCEL_CTRL, ord('S'), save_image_id) ]) self.SetAcceleratorTable(accelerator_table) self.last_pose = None self.last_output_index = self.output_index_choice.GetSelection() self.last_output_numpy_image = None self.wx_source_image = None self.torch_source_image = None self.source_image_bitmap = wx.Bitmap(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE) self.result_image_bitmap = wx.Bitmap(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE) self.source_image_dirty = True def init_left_panel(self): self.control_panel = wx.Panel(self, style=wx.SIMPLE_BORDER, size=(MainFrame.IMAGE_SIZE, -1)) self.left_panel = wx.Panel(self, style=wx.SIMPLE_BORDER) left_panel_sizer = wx.BoxSizer(wx.VERTICAL) self.left_panel.SetSizer(left_panel_sizer) self.left_panel.SetAutoLayout(1) self.source_image_panel = wx.Panel(self.left_panel, size=(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE), style=wx.SIMPLE_BORDER) self.source_image_panel.Bind(wx.EVT_PAINT, self.paint_source_image_panel) self.source_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background) left_panel_sizer.Add(self.source_image_panel, 0, wx.FIXED_MINSIZE) self.load_model_button = wx.Button(self.left_panel, wx.ID_ANY, "\nLoad Model\n\n") left_panel_sizer.Add(self.load_model_button, 1, wx.EXPAND) self.load_model_button.Bind(wx.EVT_BUTTON, self.load_model) left_panel_sizer.Fit(self.left_panel) self.main_sizer.Add(self.left_panel, 0, wx.FIXED_MINSIZE) def on_erase_background(self, event: wx.Event): pass def init_control_panel(self): self.control_panel_sizer = wx.BoxSizer(wx.VERTICAL) self.control_panel.SetSizer(self.control_panel_sizer) self.control_panel.SetMinSize(wx.Size(256, 1)) morph_categories = [ PoseParameterCategory.EYEBROW, PoseParameterCategory.EYE, PoseParameterCategory.MOUTH, PoseParameterCategory.IRIS_MORPH ] morph_category_titles = { PoseParameterCategory.EYEBROW: " ------------ Eyebrow ------------ ", PoseParameterCategory.EYE: " ------------ Eye ------------ ", PoseParameterCategory.MOUTH: " ------------ Mouth ------------ ", PoseParameterCategory.IRIS_MORPH: " ------------ Iris morphs ------------ ", } self.morph_control_panels = {} param_groups = get_pose_parameters().get_pose_parameter_groups() for category in morph_categories: filtered_param_groups = [group for group in param_groups if group.get_category() == category] if len(filtered_param_groups) == 0: continue control_panel = MorphCategoryControlPanel( self.control_panel, morph_category_titles[category], category, param_groups) self.morph_control_panels[category] = control_panel self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND) self.non_morph_control_panels = {} non_morph_categories = [ PoseParameterCategory.IRIS_ROTATION, PoseParameterCategory.FACE_ROTATION, PoseParameterCategory.BODY_ROTATION, PoseParameterCategory.BREATHING ] for category in non_morph_categories: filtered_param_groups = [group for group in param_groups if group.get_category() == category] if len(filtered_param_groups) == 0: continue control_panel = SimpleParamGroupsControlPanel( self.control_panel, category, param_groups) 
            self.non_morph_control_panels[category] = control_panel
            self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND)
        self.control_panel_sizer.Fit(self.control_panel)
        self.main_sizer.Add(self.control_panel, 1, wx.FIXED_MINSIZE)

    def init_right_panel(self):
        self.right_panel = wx.Panel(self, style=wx.SIMPLE_BORDER)
        right_panel_sizer = wx.BoxSizer(wx.VERTICAL)
        self.right_panel.SetSizer(right_panel_sizer)
        self.right_panel.SetAutoLayout(1)
        self.result_image_panel = wx.Panel(self.right_panel,
                                           size=(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE),
                                           style=wx.SIMPLE_BORDER)
        self.result_image_panel.Bind(wx.EVT_PAINT, self.paint_result_image_panel)
        self.result_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
        self.output_index_choice = wx.Choice(
            self.right_panel,
            choices=[str(i) for i in range(MainFrame.OUTPUT_LENGTH)])
        self.output_index_choice.SetSelection(0)
        right_panel_sizer.Add(self.result_image_panel, 0, wx.FIXED_MINSIZE)
        right_panel_sizer.Add(self.output_index_choice, 0, wx.EXPAND)
        self.save_image_button = wx.Button(self.right_panel, wx.ID_ANY, "\nSave Image\n\n")
        right_panel_sizer.Add(self.save_image_button, 1, wx.EXPAND)
        self.save_image_button.Bind(wx.EVT_BUTTON, self.on_save_image)
        right_panel_sizer.Fit(self.right_panel)
        self.main_sizer.Add(self.right_panel, 0, wx.FIXED_MINSIZE)

    def create_param_category_choice(self, param_category: PoseParameterCategory):
        params = []
        for param_group in self.poser.get_pose_parameter_groups():
            if param_group.get_category() == param_category:
                params.append(param_group.get_group_name())
        choice = wx.Choice(self.control_panel, choices=params)
        if len(params) > 0:
            choice.SetSelection(0)
        return choice

    def load_model(self, event: wx.Event):
        dir_name = "data/character_models"
        file_dialog = wx.FileDialog(self, "Choose a model", dir_name, "", "*.yaml", wx.FD_OPEN)
        if file_dialog.ShowModal() == wx.ID_OK:
            character_model_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())
            try:
                self.character_model = CharacterModel.load(character_model_file_name)
                self.torch_source_image = self.character_model.get_character_image(self.device)
                pil_image = resize_PIL_image(
                    PIL.Image.open(self.character_model.character_image_file_name),
                    (MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE))
                w, h = pil_image.size
                self.wx_source_image = wx.Bitmap.FromBufferRGBA(w, h, pil_image.convert("RGBA").tobytes())
                self.poser = self.character_model.get_poser(self.device)
                self.source_image_dirty = True
                self.Refresh()
                self.Update()
            except RuntimeError:
                message_dialog = wx.MessageDialog(
                    self,
                    "Could not load character model " + character_model_file_name,
                    "Poser",
                    wx.OK)
                message_dialog.ShowModal()
                message_dialog.Destroy()
        file_dialog.Destroy()

    def paint_source_image_panel(self, event: wx.Event):
        wx.BufferedPaintDC(self.source_image_panel, self.source_image_bitmap)

    def paint_result_image_panel(self, event: wx.Event):
        wx.BufferedPaintDC(self.result_image_panel, self.result_image_bitmap)

    def draw_nothing_yet_string_to_bitmap(self, bitmap):
        dc = wx.MemoryDC()
        dc.SelectObject(bitmap)
        dc.Clear()
        font = wx.Font(wx.FontInfo(14).Family(wx.FONTFAMILY_SWISS))
        dc.SetFont(font)
        w, h = dc.GetTextExtent("Nothing yet!")
        dc.DrawText("Nothing yet!", (MainFrame.IMAGE_SIZE - w) // 2, (MainFrame.IMAGE_SIZE - h) // 2)
        del dc

    def get_current_pose(self):
        current_pose = [0.0 for i in range(MainFrame.NUM_PARAMETERS)]
        for morph_control_panel in self.morph_control_panels.values():
            morph_control_panel.set_param_value(current_pose)
        for rotation_control_panel in self.non_morph_control_panels.values():
            rotation_control_panel.set_param_value(current_pose)
        return current_pose

    def update_images(self, event: wx.Event):
        current_pose = self.get_current_pose()
        if not self.source_image_dirty \
                and self.last_pose is not None \
                and self.last_pose == current_pose \
                and self.last_output_index == self.output_index_choice.GetSelection():
            return
        self.last_pose = current_pose
        self.last_output_index = self.output_index_choice.GetSelection()
        if self.torch_source_image is None or self.poser is None:
            self.draw_nothing_yet_string_to_bitmap(self.source_image_bitmap)
            self.draw_nothing_yet_string_to_bitmap(self.result_image_bitmap)
            self.source_image_dirty = False
            self.Refresh()
            self.Update()
            return
        if self.source_image_dirty:
            dc = wx.MemoryDC()
            dc.SelectObject(self.source_image_bitmap)
            dc.Clear()
            dc.DrawBitmap(self.wx_source_image, 0, 0)
            self.source_image_dirty = False
        pose = torch.tensor(current_pose, device=self.device)
        output_index = self.output_index_choice.GetSelection()
        with torch.no_grad():
            start_cuda_event = torch.cuda.Event(enable_timing=True)
            end_cuda_event = torch.cuda.Event(enable_timing=True)
            start_cuda_event.record()
            start_time = time.time()
            output_image = self.poser.pose(self.torch_source_image, pose, output_index)[0].detach().cpu()
            end_time = time.time()
            end_cuda_event.record()
            torch.cuda.synchronize()
            print("cuda time (ms):", start_cuda_event.elapsed_time(end_cuda_event))
            print("elapsed time (ms):", (end_time - start_time) * 1000.0)
        numpy_image = convert_output_image_from_torch_to_numpy(output_image)
        self.last_output_numpy_image = numpy_image
        wx_image = wx.ImageFromBuffer(
            numpy_image.shape[0],
            numpy_image.shape[1],
            numpy_image[:, :, 0:3].tobytes(),
            numpy_image[:, :, 3].tobytes())
        wx_bitmap = wx_image.ConvertToBitmap()
        dc = wx.MemoryDC()
        dc.SelectObject(self.result_image_bitmap)
        dc.Clear()
        dc.DrawBitmap(wx_bitmap,
                      (MainFrame.IMAGE_SIZE - numpy_image.shape[0]) // 2,
                      (MainFrame.IMAGE_SIZE - numpy_image.shape[1]) // 2,
                      True)
        del dc
        self.Refresh()
        self.Update()

    def on_save_image(self, event: wx.Event):
        if self.last_output_numpy_image is None:
            logging.info("There is no output image to save!")
            return
        dir_name = "data/images"
        file_dialog = wx.FileDialog(self, "Choose an image", dir_name, "", "*.png", wx.FD_SAVE)
        if file_dialog.ShowModal() == wx.ID_OK:
            image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())
            try:
                if os.path.exists(image_file_name):
                    message_dialog = wx.MessageDialog(self, f"Overwrite {image_file_name}?", "Manual Poser",
                                                      wx.YES_NO | wx.ICON_QUESTION)
                    result = message_dialog.ShowModal()
                    if result == wx.ID_YES:
                        self.save_last_numpy_image(image_file_name)
                    message_dialog.Destroy()
                else:
                    self.save_last_numpy_image(image_file_name)
            except Exception:
                message_dialog = wx.MessageDialog(self, f"Could not save {image_file_name}", "Manual Poser", wx.OK)
                message_dialog.ShowModal()
                message_dialog.Destroy()
        file_dialog.Destroy()

    def save_last_numpy_image(self, image_file_name):
        numpy_image = self.last_output_numpy_image
        pil_image = PIL.Image.fromarray(numpy_image, mode='RGBA')
        os.makedirs(os.path.dirname(image_file_name), exist_ok=True)
        pil_image.save(image_file_name)


if __name__ == "__main__":
    device = torch.device('cuda:0')
    app = wx.App()
    main_frame = MainFrame(device)
    main_frame.Show(True)
    main_frame.timer.Start(16)
    app.MainLoop()


================================================
FILE: src/tha4/app/character_model_mediapipe_puppeteer.py
================================================
import os
import sys
import threading
import time
from typing import Optional

import PIL.Image
import cv2 import mediapipe from scipy.spatial.transform import Rotation from tha4.shion.base.image_util import resize_PIL_image from tha4.charmodel.character_model import CharacterModel from tha4.image_util import convert_linear_to_srgb from tha4.mocap.mediapipe_constants import HEAD_ROTATIONS, HEAD_X, HEAD_Y, HEAD_Z from tha4.mocap.mediapipe_face_pose import MediaPipeFacePose from tha4.mocap.mediapipe_face_pose_converter_00 import MediaPoseFacePoseConverter00 sys.path.append(os.getcwd()) import torch import wx class FpsStatistics: def __init__(self): self.count = 100 self.fps = [] def add_fps(self, fps): self.fps.append(fps) while len(self.fps) > self.count: del self.fps[0] def get_average_fps(self): if len(self.fps) == 0: return 0.0 else: return sum(self.fps) / len(self.fps) class MainFrame(wx.Frame): IMAGE_SIZE = 512 def __init__(self, pose_converter: MediaPoseFacePoseConverter00, video_capture, face_landmarker, device: torch.device): super().__init__(None, wx.ID_ANY, "THA4 Character Model MediaPipe Puppeteer") self.face_landmarker = face_landmarker self.video_capture = video_capture self.pose_converter = pose_converter self.device = device self.source_image_bitmap = wx.Bitmap(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE) self.result_image_bitmap = wx.Bitmap(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE) self.webcam_capture_bitmap = wx.Bitmap(256, 192) self.wx_source_image = None self.torch_source_image = None self.last_pose = None self.mediapipe_face_pose = None self.fps_statistics = FpsStatistics() self.last_update_time = None self.character_model = None self.poser = None self.create_ui() self.create_timers() self.Bind(wx.EVT_CLOSE, self.on_close) self.update_source_image_bitmap() self.update_result_image_bitmap() def create_timers(self): self.capture_timer = wx.Timer(self, wx.ID_ANY) self.Bind(wx.EVT_TIMER, self.update_capture_panel, id=self.capture_timer.GetId()) self.animation_timer = wx.Timer(self, wx.ID_ANY) self.Bind(wx.EVT_TIMER, self.update_result_image_bitmap, id=self.animation_timer.GetId()) def on_close(self, event: wx.Event): # Stop the timers self.animation_timer.Stop() self.capture_timer.Stop() # Destroy the windows self.Destroy() event.Skip() def on_erase_background(self, event: wx.Event): pass def create_animation_panel(self, parent): self.animation_panel = wx.Panel(parent, style=wx.RAISED_BORDER) self.animation_panel_sizer = wx.BoxSizer(wx.HORIZONTAL) self.animation_panel.SetSizer(self.animation_panel_sizer) self.animation_panel.SetAutoLayout(1) image_size = MainFrame.IMAGE_SIZE if True: self.input_panel = wx.Panel(self.animation_panel, size=(image_size, image_size + 128), style=wx.SIMPLE_BORDER) self.input_panel_sizer = wx.BoxSizer(wx.VERTICAL) self.input_panel.SetSizer(self.input_panel_sizer) self.input_panel.SetAutoLayout(1) self.animation_panel_sizer.Add(self.input_panel, 0, wx.FIXED_MINSIZE) self.source_image_panel = wx.Panel(self.input_panel, size=(image_size, image_size), style=wx.SIMPLE_BORDER) self.source_image_panel.Bind(wx.EVT_PAINT, self.paint_source_image_panel) self.source_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background) self.input_panel_sizer.Add(self.source_image_panel, 0, wx.FIXED_MINSIZE) self.load_model_button = wx.Button(self.input_panel, wx.ID_ANY, "Load Model") self.input_panel_sizer.Add(self.load_model_button, 1, wx.EXPAND) self.load_model_button.Bind(wx.EVT_BUTTON, self.load_model) self.input_panel_sizer.Fit(self.input_panel) if True: def current_pose_supplier() -> Optional[MediaPipeFacePose]: return self.mediapipe_face_pose 
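            # The pose converter panel polls current_pose_supplier() on its own schedule,
            # so it always reads the latest pose written by update_mediapipe_face_pose()
            # without the capture and animation timers having to synchronize.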
self.pose_converter.init_pose_converter_panel(self.animation_panel, current_pose_supplier) if True: self.animation_left_panel = wx.Panel(self.animation_panel, style=wx.SIMPLE_BORDER) self.animation_left_panel_sizer = wx.BoxSizer(wx.VERTICAL) self.animation_left_panel.SetSizer(self.animation_left_panel_sizer) self.animation_left_panel.SetAutoLayout(1) self.animation_panel_sizer.Add(self.animation_left_panel, 0, wx.EXPAND) self.result_image_panel = wx.Panel(self.animation_left_panel, size=(image_size, image_size), style=wx.SIMPLE_BORDER) self.result_image_panel.Bind(wx.EVT_PAINT, self.paint_result_image_panel) self.result_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background) self.animation_left_panel_sizer.Add(self.result_image_panel, 0, wx.FIXED_MINSIZE) separator = wx.StaticLine(self.animation_left_panel, -1, size=(256, 5)) self.animation_left_panel_sizer.Add(separator, 0, wx.EXPAND) background_text = wx.StaticText(self.animation_left_panel, label="--- Background ---", style=wx.ALIGN_CENTER) self.animation_left_panel_sizer.Add(background_text, 0, wx.EXPAND) self.output_background_choice = wx.Choice( self.animation_left_panel, choices=[ "TRANSPARENT", "GREEN", "BLUE", "BLACK", "WHITE" ]) self.output_background_choice.SetSelection(0) self.animation_left_panel_sizer.Add(self.output_background_choice, 0, wx.EXPAND) separator = wx.StaticLine(self.animation_left_panel, -1, size=(256, 5)) self.animation_left_panel_sizer.Add(separator, 0, wx.EXPAND) self.fps_text = wx.StaticText(self.animation_left_panel, label="") self.animation_left_panel_sizer.Add(self.fps_text, wx.SizerFlags().Border()) self.animation_left_panel_sizer.Fit(self.animation_left_panel) self.animation_panel_sizer.Fit(self.animation_panel) def create_ui(self): self.main_sizer = wx.BoxSizer(wx.VERTICAL) self.SetSizer(self.main_sizer) self.SetAutoLayout(1) self.capture_pose_lock = threading.Lock() self.create_animation_panel(self) self.main_sizer.Add(self.animation_panel, wx.SizerFlags(0).Expand().Border(wx.ALL, 5)) self.create_capture_panel(self) self.main_sizer.Add(self.capture_panel, wx.SizerFlags(0).Expand().Border(wx.ALL, 5)) self.main_sizer.Fit(self) def create_capture_panel(self, parent): self.capture_panel = wx.Panel(parent, style=wx.RAISED_BORDER) self.capture_panel_sizer = wx.BoxSizer(wx.HORIZONTAL) self.capture_panel.SetSizer(self.capture_panel_sizer) self.capture_panel.SetAutoLayout(1) self.webcam_capture_panel = wx.Panel(self.capture_panel, size=(256, 192), style=wx.SIMPLE_BORDER) self.webcam_capture_panel.Bind(wx.EVT_PAINT, self.paint_webcam_capture_panel) self.webcam_capture_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background) self.capture_panel_sizer.Add(self.webcam_capture_panel, wx.SizerFlags(0).FixedMinSize().Border(wx.ALL, 5)) self.rotation_labels = {} self.rotation_value_labels = {} rotation_column = self.create_rotation_column(self.capture_panel, HEAD_ROTATIONS) self.capture_panel_sizer.Add(rotation_column, wx.SizerFlags(0).Expand().Border(wx.ALL, 3)) def paint_webcam_capture_panel(self, event: wx.Event): wx.BufferedPaintDC(self.webcam_capture_panel, self.webcam_capture_bitmap) def create_rotation_column(self, parent, rotation_names): column_panel = wx.Panel(parent, style=wx.SIMPLE_BORDER) column_panel_sizer = wx.FlexGridSizer(cols=2) column_panel_sizer.AddGrowableCol(1) column_panel.SetSizer(column_panel_sizer) column_panel.SetAutoLayout(1) for rotation_name in rotation_names: self.rotation_labels[rotation_name] = wx.StaticText( column_panel, label=rotation_name, style=wx.ALIGN_RIGHT) 
column_panel_sizer.Add(self.rotation_labels[rotation_name], wx.SizerFlags(1).Expand().Border(wx.ALL, 3)) self.rotation_value_labels[rotation_name] = wx.TextCtrl( column_panel, style=wx.TE_RIGHT) self.rotation_value_labels[rotation_name].SetValue("0.00") self.rotation_value_labels[rotation_name].Disable() column_panel_sizer.Add(self.rotation_value_labels[rotation_name], wx.SizerFlags(1).Expand().Border(wx.ALL, 3)) column_panel.GetSizer().Fit(column_panel) return column_panel def update_capture_panel(self, event: wx.Event): there_is_frame, frame = self.video_capture.read() if not there_is_frame: dc = wx.MemoryDC() dc.SelectObject(self.webcam_capture_bitmap) self.draw_nothing_yet_string(dc) del dc return rgb_frame = cv2.flip(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), 1) resized_frame = cv2.resize(rgb_frame, (256, 192)) wx_image = wx.ImageFromBuffer(256, 192, resized_frame.tobytes()) wx_bitmap = wx_image.ConvertToBitmap() dc = wx.MemoryDC() dc.SelectObject(self.webcam_capture_bitmap) dc.Clear() dc.DrawBitmap(wx_bitmap, 0, 0, True) del dc self.webcam_capture_panel.Refresh() time_ms = int(time.time() * 1000) mediapipe_image = mediapipe.Image(image_format=mediapipe.ImageFormat.SRGB, data=rgb_frame) detection_result = self.face_landmarker.detect_for_video(mediapipe_image, time_ms) self.update_mediapipe_face_pose(detection_result) def update_mediapipe_face_pose(self, detection_result): if len(detection_result.facial_transformation_matrixes) == 0: return xform_matrix = detection_result.facial_transformation_matrixes[0] blendshape_params = {} for item in detection_result.face_blendshapes[0]: blendshape_params[item.category_name] = item.score M = xform_matrix[0:3, 0:3] rot = Rotation.from_matrix(M) euler_angles = rot.as_euler('xyz', degrees=True) self.rotation_value_labels[HEAD_X].SetValue("%0.2f" % euler_angles[0]) self.rotation_value_labels[HEAD_X].Refresh() self.rotation_value_labels[HEAD_Y].SetValue("%0.2f" % euler_angles[1]) self.rotation_value_labels[HEAD_Y].Refresh() self.rotation_value_labels[HEAD_Z].SetValue("%0.2f" % euler_angles[2]) self.rotation_value_labels[HEAD_Z].Refresh() self.mediapipe_face_pose = MediaPipeFacePose(blendshape_params, xform_matrix) @staticmethod def convert_to_100(x): return int(max(0.0, min(1.0, x)) * 100) def paint_source_image_panel(self, event: wx.Event): wx.BufferedPaintDC(self.source_image_panel, self.source_image_bitmap) def update_source_image_bitmap(self): dc = wx.MemoryDC() dc.SelectObject(self.source_image_bitmap) if self.wx_source_image is None: self.draw_nothing_yet_string(dc) else: dc.Clear() dc.DrawBitmap(self.wx_source_image, 0, 0, True) del dc def draw_nothing_yet_string(self, dc): dc.Clear() font = wx.Font(wx.FontInfo(14).Family(wx.FONTFAMILY_SWISS)) dc.SetFont(font) w, h = dc.GetTextExtent("Nothing yet!") dc.DrawText("Nothing yet!", (MainFrame.IMAGE_SIZE - w) // 2, (MainFrame.IMAGE_SIZE - h) // 2) def paint_result_image_panel(self, event: wx.Event): wx.BufferedPaintDC(self.result_image_panel, self.result_image_bitmap) def update_result_image_bitmap(self, event: Optional[wx.Event] = None): if self.mediapipe_face_pose is None or self.poser is None: dc = wx.MemoryDC() dc.SelectObject(self.result_image_bitmap) self.draw_nothing_yet_string(dc) del dc return current_pose = self.pose_converter.convert(self.mediapipe_face_pose) if self.last_pose is not None and self.last_pose == current_pose: return self.last_pose = current_pose if self.torch_source_image is None: dc = wx.MemoryDC() dc.SelectObject(self.result_image_bitmap) self.draw_nothing_yet_string(dc) del 
dc return pose = torch.tensor(current_pose, device=self.device, dtype=self.poser.get_dtype()) with torch.no_grad(): output_image = self.poser.pose(self.torch_source_image, pose)[0].float() output_image = torch.clip((output_image + 1.0) / 2.0, 0.0, 1.0) output_image = convert_linear_to_srgb(output_image) background_choice = self.output_background_choice.GetSelection() if background_choice == 0: pass else: background = torch.zeros(4, output_image.shape[1], output_image.shape[2], device=self.device) background[3, :, :] = 1.0 if background_choice == 1: background[1, :, :] = 1.0 output_image = self.blend_with_background(output_image, background) elif background_choice == 2: background[2, :, :] = 1.0 output_image = self.blend_with_background(output_image, background) elif background_choice == 3: output_image = self.blend_with_background(output_image, background) else: background[0:3, :, :] = 1.0 output_image = self.blend_with_background(output_image, background) c, h, w = output_image.shape output_image = 255.0 * torch.transpose(output_image.reshape(c, h * w), 0, 1).reshape(h, w, c) output_image = output_image.byte() numpy_image = output_image.detach().cpu().numpy() wx_image = wx.ImageFromBuffer(numpy_image.shape[0], numpy_image.shape[1], numpy_image[:, :, 0:3].tobytes(), numpy_image[:, :, 3].tobytes()) wx_bitmap = wx_image.ConvertToBitmap() dc = wx.MemoryDC() dc.SelectObject(self.result_image_bitmap) dc.Clear() dc.DrawBitmap(wx_bitmap, (MainFrame.IMAGE_SIZE - numpy_image.shape[0]) // 2, (MainFrame.IMAGE_SIZE - numpy_image.shape[1]) // 2, True) del dc time_now = time.time_ns() if self.last_update_time is not None: elapsed_time = time_now - self.last_update_time fps = 1.0 / (elapsed_time / 10 ** 9) if self.torch_source_image is not None: self.fps_statistics.add_fps(fps) self.fps_text.SetLabelText("FPS = %0.2f" % self.fps_statistics.get_average_fps()) self.last_update_time = time_now self.Refresh() def blend_with_background(self, numpy_image, background): alpha = numpy_image[3:4, :, :] color = numpy_image[0:3, :, :] new_color = color * alpha + (1.0 - alpha) * background[0:3, :, :] return torch.cat([new_color, background[3:4, :, :]], dim=0) def load_model(self, event: wx.Event): dir_name = "data/character_models" file_dialog = wx.FileDialog(self, "Choose a model", dir_name, "", "*.yaml", wx.FD_OPEN) if file_dialog.ShowModal() == wx.ID_OK: character_model_json_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename()) try: self.character_model = CharacterModel.load(character_model_json_file_name) self.torch_source_image = self.character_model.get_character_image(self.device) pil_image = resize_PIL_image( PIL.Image.open(self.character_model.character_image_file_name), (MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE)) w, h = pil_image.size self.wx_source_image = wx.Bitmap.FromBufferRGBA(w, h, pil_image.convert("RGBA").tobytes()) self.update_source_image_bitmap() self.poser = self.character_model.get_poser(self.device) except Exception: message_dialog = wx.MessageDialog( self, "Could not load character model " + character_model_json_file_name, "Poser", wx.OK) message_dialog.ShowModal() message_dialog.Destroy() file_dialog.Destroy() self.Refresh() if __name__ == "__main__": device = torch.device("cuda:0") pose_converter = MediaPoseFacePoseConverter00() face_landmarker_base_options = mediapipe.tasks.BaseOptions( model_asset_path='data/thirdparty/mediapipe/face_landmarker_v2_with_blendshapes.task') options = mediapipe.tasks.vision.FaceLandmarkerOptions( 
base_options=face_landmarker_base_options, running_mode=mediapipe.tasks.vision.RunningMode.VIDEO, output_face_blendshapes=True, output_facial_transformation_matrixes=True, num_faces=1) face_landmarker = mediapipe.tasks.vision.FaceLandmarker.create_from_options(options) video_capture = cv2.VideoCapture(0) app = wx.App() main_frame = MainFrame(pose_converter, video_capture, face_landmarker, device) main_frame.Show(True) main_frame.capture_timer.Start(30) main_frame.animation_timer.Start(30) app.MainLoop() ================================================ FILE: src/tha4/app/distill.py ================================================ import argparse import logging from tha4.distiller.distiller_config import DistillerConfig from tha4.pytasuku.workspace import Workspace def run_config(config_file_name: str): config = DistillerConfig.load(config_file_name) logging.basicConfig(level=logging.INFO, force=True) workspace = Workspace() config.define_tasks(workspace) workspace.start_session() workspace.run(f"{config.prefix}/all") workspace.end_session() if __name__ == "__main__": parser = argparse.ArgumentParser(description='Training script.') parser.add_argument("--config_file", type=str, required=True, help="The name of the config file for the distillation process.") args = parser.parse_args() run_config(args.config_file) ================================================ FILE: src/tha4/app/distiller_ui.py ================================================ import wx from tha4.app.distill import run_config from tha4.distiller.ui.distiller_ui_main_frame import DistillerUiMainFrame if __name__ == "__main__": app = wx.App() main_frame = DistillerUiMainFrame() main_frame.Show(True) app.MainLoop() if main_frame.config_file_to_run is not None: run_config(main_frame.config_file_to_run) ================================================ FILE: src/tha4/app/full_manual_poser.py ================================================ import logging import os import sys import time from typing import List from tha4.shion.base.image_util import extract_pytorch_image_from_PIL_image, pytorch_rgba_to_numpy_image, \ pytorch_rgb_to_numpy_image from tha4.image_util import grid_change_to_numpy_image, resize_PIL_image sys.path.append(os.getcwd()) import PIL.Image import numpy import torch import wx from tha4.poser.poser import Poser, PoseParameterCategory, PoseParameterGroup class MorphCategoryControlPanel(wx.Panel): def __init__(self, parent, title: str, pose_param_category: PoseParameterCategory, param_groups: List[PoseParameterGroup]): super().__init__(parent, style=wx.SIMPLE_BORDER) self.pose_param_category = pose_param_category self.sizer = wx.BoxSizer(wx.VERTICAL) self.SetSizer(self.sizer) self.SetAutoLayout(1) title_text = wx.StaticText(self, label=title, style=wx.ALIGN_CENTER) self.sizer.Add(title_text, 0, wx.EXPAND) self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category] self.choice = wx.Choice(self, choices=[group.get_group_name() for group in self.param_groups]) if len(self.param_groups) > 0: self.choice.SetSelection(0) self.choice.Bind(wx.EVT_CHOICE, self.on_choice_updated) self.sizer.Add(self.choice, 0, wx.EXPAND) self.left_slider = wx.Slider(self, minValue=-1000, maxValue=1000, value=-1000, style=wx.HORIZONTAL) self.sizer.Add(self.left_slider, 0, wx.EXPAND) self.right_slider = wx.Slider(self, minValue=-1000, maxValue=1000, value=-1000, style=wx.HORIZONTAL) self.sizer.Add(self.right_slider, 0, wx.EXPAND) self.checkbox = wx.CheckBox(self, label="Show") 
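        # Discrete parameter groups are toggled with this "Show" checkbox, while continuous
        # groups use the sliders: a slider value v in [-1000, 1000] is mapped to the group's
        # range through alpha = (v + 1000) / 2000 in set_param_value() below, so v = 0 lands
        # exactly at the middle of the range.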
self.checkbox.SetValue(True) self.sizer.Add(self.checkbox, 0, wx.SHAPED | wx.ALIGN_CENTER) self.update_ui() self.sizer.Fit(self) def update_ui(self): param_group = self.param_groups[self.choice.GetSelection()] if param_group.is_discrete(): self.left_slider.Enable(False) self.right_slider.Enable(False) self.checkbox.Enable(True) elif param_group.get_arity() == 1: self.left_slider.Enable(True) self.right_slider.Enable(False) self.checkbox.Enable(False) else: self.left_slider.Enable(True) self.right_slider.Enable(True) self.checkbox.Enable(False) def on_choice_updated(self, event: wx.Event): param_group = self.param_groups[self.choice.GetSelection()] if param_group.is_discrete(): self.checkbox.SetValue(True) self.update_ui() def set_param_value(self, pose: List[float]): if len(self.param_groups) == 0: return selected_morph_index = self.choice.GetSelection() param_group = self.param_groups[selected_morph_index] param_index = param_group.get_parameter_index() if param_group.is_discrete(): if self.checkbox.GetValue(): for i in range(param_group.get_arity()): pose[param_index + i] = 1.0 else: param_range = param_group.get_range() alpha = (self.left_slider.GetValue() + 1000) / 2000.0 pose[param_index] = param_range[0] + (param_range[1] - param_range[0]) * alpha if param_group.get_arity() == 2: alpha = (self.right_slider.GetValue() + 1000) / 2000.0 pose[param_index + 1] = param_range[0] + (param_range[1] - param_range[0]) * alpha class SimpleParamGroupsControlPanel(wx.Panel): def __init__(self, parent, pose_param_category: PoseParameterCategory, param_groups: List[PoseParameterGroup]): super().__init__(parent, style=wx.SIMPLE_BORDER) self.sizer = wx.BoxSizer(wx.VERTICAL) self.SetSizer(self.sizer) self.SetAutoLayout(1) self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category] for param_group in self.param_groups: assert not param_group.is_discrete() assert param_group.get_arity() == 1 self.sliders = [] for param_group in self.param_groups: static_text = wx.StaticText( self, label=" ------------ %s ------------ " % param_group.get_group_name(), style=wx.ALIGN_CENTER) self.sizer.Add(static_text, 0, wx.EXPAND) range = param_group.get_range() min_value = int(range[0] * 1000) max_value = int(range[1] * 1000) slider = wx.Slider(self, minValue=min_value, maxValue=max_value, value=0, style=wx.HORIZONTAL) self.sizer.Add(slider, 0, wx.EXPAND) self.sliders.append(slider) self.sizer.Fit(self) def set_param_value(self, pose: List[float]): if len(self.param_groups) == 0: return for param_group_index in range(len(self.param_groups)): param_group = self.param_groups[param_group_index] slider = self.sliders[param_group_index] param_range = param_group.get_range() param_index = param_group.get_parameter_index() alpha = (slider.GetValue() - slider.GetMin()) * 1.0 / (slider.GetMax() - slider.GetMin()) pose[param_index] = param_range[0] + (param_range[1] - param_range[0]) * alpha def convert_output_image_from_torch_to_numpy(output_image): if output_image.shape[2] == 2: h, w, c = output_image.shape numpy_image = torch.transpose(output_image.reshape(h * w, c), 0, 1).reshape(c, h, w) elif output_image.shape[0] == 4: numpy_image = pytorch_rgba_to_numpy_image(output_image) elif output_image.shape[0] == 3: numpy_image = pytorch_rgb_to_numpy_image(output_image) elif output_image.shape[0] == 1: c, h, w = output_image.shape alpha_image = torch.cat([output_image.repeat(3, 1, 1) * 2.0 - 1.0, torch.ones(1, h, w)], dim=0) numpy_image = pytorch_rgba_to_numpy_image(alpha_image) elif 
output_image.shape[0] == 2: numpy_image = grid_change_to_numpy_image(output_image, num_channels=4) else: raise RuntimeError("Unsupported # image channels: %d" % output_image.shape[0]) numpy_image = numpy.uint8(numpy.rint(numpy_image * 255.0)) return numpy_image class MainFrame(wx.Frame): def __init__(self, poser: Poser, device: torch.device): super().__init__(None, wx.ID_ANY, "Poser") self.poser = poser self.dtype = self.poser.get_dtype() self.device = device self.image_size = self.poser.get_image_size() self.wx_source_image = None self.torch_source_image = None self.main_sizer = wx.BoxSizer(wx.HORIZONTAL) self.SetSizer(self.main_sizer) self.SetAutoLayout(1) self.init_left_panel() self.init_control_panel() self.init_right_panel() self.main_sizer.Fit(self) self.timer = wx.Timer(self, wx.ID_ANY) self.Bind(wx.EVT_TIMER, self.update_images, self.timer) save_image_id = wx.NewIdRef() self.Bind(wx.EVT_MENU, self.on_save_image, id=save_image_id) accelerator_table = wx.AcceleratorTable([ (wx.ACCEL_CTRL, ord('S'), save_image_id) ]) self.SetAcceleratorTable(accelerator_table) self.last_pose = None self.last_output_index = self.output_index_choice.GetSelection() self.last_output_numpy_image = None self.wx_source_image = None self.torch_source_image = None self.source_image_bitmap = wx.Bitmap(self.image_size, self.image_size) self.result_image_bitmap = wx.Bitmap(self.image_size, self.image_size) self.source_image_dirty = True def init_left_panel(self): self.control_panel = wx.Panel(self, style=wx.SIMPLE_BORDER, size=(self.image_size, -1)) self.left_panel = wx.Panel(self, style=wx.SIMPLE_BORDER) left_panel_sizer = wx.BoxSizer(wx.VERTICAL) self.left_panel.SetSizer(left_panel_sizer) self.left_panel.SetAutoLayout(1) self.source_image_panel = wx.Panel(self.left_panel, size=(self.image_size, self.image_size), style=wx.SIMPLE_BORDER) self.source_image_panel.Bind(wx.EVT_PAINT, self.paint_source_image_panel) self.source_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background) left_panel_sizer.Add(self.source_image_panel, 0, wx.FIXED_MINSIZE) self.load_image_button = wx.Button(self.left_panel, wx.ID_ANY, "\nLoad Image\n\n") left_panel_sizer.Add(self.load_image_button, 1, wx.EXPAND) self.load_image_button.Bind(wx.EVT_BUTTON, self.load_image) left_panel_sizer.Fit(self.left_panel) self.main_sizer.Add(self.left_panel, 0, wx.FIXED_MINSIZE) def on_erase_background(self, event: wx.Event): pass def init_control_panel(self): self.control_panel_sizer = wx.BoxSizer(wx.VERTICAL) self.control_panel.SetSizer(self.control_panel_sizer) self.control_panel.SetMinSize(wx.Size(256, 1)) morph_categories = [ PoseParameterCategory.EYEBROW, PoseParameterCategory.EYE, PoseParameterCategory.MOUTH, PoseParameterCategory.IRIS_MORPH ] morph_category_titles = { PoseParameterCategory.EYEBROW: " ------------ Eyebrow ------------ ", PoseParameterCategory.EYE: " ------------ Eye ------------ ", PoseParameterCategory.MOUTH: " ------------ Mouth ------------ ", PoseParameterCategory.IRIS_MORPH: " ------------ Iris morphs ------------ ", } self.morph_control_panels = {} for category in morph_categories: param_groups = self.poser.get_pose_parameter_groups() filtered_param_groups = [group for group in param_groups if group.get_category() == category] if len(filtered_param_groups) == 0: continue control_panel = MorphCategoryControlPanel( self.control_panel, morph_category_titles[category], category, self.poser.get_pose_parameter_groups()) self.morph_control_panels[category] = control_panel self.control_panel_sizer.Add(control_panel, 0, 
wx.EXPAND) self.non_morph_control_panels = {} non_morph_categories = [ PoseParameterCategory.IRIS_ROTATION, PoseParameterCategory.FACE_ROTATION, PoseParameterCategory.BODY_ROTATION, PoseParameterCategory.BREATHING ] for category in non_morph_categories: param_groups = self.poser.get_pose_parameter_groups() filtered_param_groups = [group for group in param_groups if group.get_category() == category] if len(filtered_param_groups) == 0: continue control_panel = SimpleParamGroupsControlPanel( self.control_panel, category, self.poser.get_pose_parameter_groups()) self.non_morph_control_panels[category] = control_panel self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND) self.control_panel_sizer.Fit(self.control_panel) self.main_sizer.Add(self.control_panel, 1, wx.FIXED_MINSIZE) def init_right_panel(self): self.right_panel = wx.Panel(self, style=wx.SIMPLE_BORDER) right_panel_sizer = wx.BoxSizer(wx.VERTICAL) self.right_panel.SetSizer(right_panel_sizer) self.right_panel.SetAutoLayout(1) self.result_image_panel = wx.Panel(self.right_panel, size=(self.image_size, self.image_size), style=wx.SIMPLE_BORDER) self.result_image_panel.Bind(wx.EVT_PAINT, self.paint_result_image_panel) self.result_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background) self.output_index_choice = wx.Choice( self.right_panel, choices=[str(i) for i in range(self.poser.get_output_length())]) self.output_index_choice.SetSelection(0) right_panel_sizer.Add(self.result_image_panel, 0, wx.FIXED_MINSIZE) right_panel_sizer.Add(self.output_index_choice, 0, wx.EXPAND) self.save_image_button = wx.Button(self.right_panel, wx.ID_ANY, "\nSave Image\n\n") right_panel_sizer.Add(self.save_image_button, 1, wx.EXPAND) self.save_image_button.Bind(wx.EVT_BUTTON, self.on_save_image) right_panel_sizer.Fit(self.right_panel) self.main_sizer.Add(self.right_panel, 0, wx.FIXED_MINSIZE) def create_param_category_choice(self, param_category: PoseParameterCategory): params = [] for param_group in self.poser.get_pose_parameter_groups(): if param_group.get_category() == param_category: params.append(param_group.get_group_name()) choice = wx.Choice(self.control_panel, choices=params) if len(params) > 0: choice.SetSelection(0) return choice def load_image(self, event: wx.Event): dir_name = "data/images" file_dialog = wx.FileDialog(self, "Choose an image", dir_name, "", "*.png", wx.FD_OPEN) if file_dialog.ShowModal() == wx.ID_OK: image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename()) try: pil_image = resize_PIL_image(PIL.Image.open(image_file_name), (self.poser.get_image_size(), self.poser.get_image_size())) w, h = pil_image.size if pil_image.mode != 'RGBA': self.source_image_string = "Image must have alpha channel!" 
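                    # The alpha channel doubles as the character's matte when the output
                    # is composited, so images without one are rejected rather than converted.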
                    self.wx_source_image = None
                    self.torch_source_image = None
                else:
                    self.wx_source_image = wx.Bitmap.FromBufferRGBA(w, h, pil_image.convert("RGBA").tobytes())
                    self.torch_source_image = extract_pytorch_image_from_PIL_image(pil_image) \
                        .to(self.device).to(self.dtype)
                self.source_image_dirty = True
                self.Refresh()
                self.Update()
            except Exception:
                message_dialog = wx.MessageDialog(self, "Could not load image " + image_file_name, "Poser", wx.OK)
                message_dialog.ShowModal()
                message_dialog.Destroy()
        file_dialog.Destroy()

    def paint_source_image_panel(self, event: wx.Event):
        wx.BufferedPaintDC(self.source_image_panel, self.source_image_bitmap)

    def paint_result_image_panel(self, event: wx.Event):
        wx.BufferedPaintDC(self.result_image_panel, self.result_image_bitmap)

    def draw_nothing_yet_string_to_bitmap(self, bitmap):
        dc = wx.MemoryDC()
        dc.SelectObject(bitmap)
        dc.Clear()
        font = wx.Font(wx.FontInfo(14).Family(wx.FONTFAMILY_SWISS))
        dc.SetFont(font)
        w, h = dc.GetTextExtent("Nothing yet!")
        dc.DrawText("Nothing yet!", (self.image_size - w) // 2, (self.image_size - h) // 2)
        del dc

    def get_current_pose(self):
        current_pose = [0.0 for i in range(self.poser.get_num_parameters())]
        for morph_control_panel in self.morph_control_panels.values():
            morph_control_panel.set_param_value(current_pose)
        for rotation_control_panel in self.non_morph_control_panels.values():
            rotation_control_panel.set_param_value(current_pose)
        return current_pose

    def update_images(self, event: wx.Event):
        current_pose = self.get_current_pose()
        if not self.source_image_dirty \
                and self.last_pose is not None \
                and self.last_pose == current_pose \
                and self.last_output_index == self.output_index_choice.GetSelection():
            return
        self.last_pose = current_pose
        self.last_output_index = self.output_index_choice.GetSelection()
        if self.torch_source_image is None:
            self.draw_nothing_yet_string_to_bitmap(self.source_image_bitmap)
            self.draw_nothing_yet_string_to_bitmap(self.result_image_bitmap)
            self.source_image_dirty = False
            self.Refresh()
            self.Update()
            return
        if self.source_image_dirty:
            dc = wx.MemoryDC()
            dc.SelectObject(self.source_image_bitmap)
            dc.Clear()
            dc.DrawBitmap(self.wx_source_image, 0, 0)
            self.source_image_dirty = False
        pose = torch.tensor(current_pose, device=self.device, dtype=self.dtype)
        output_index = self.output_index_choice.GetSelection()
        with torch.no_grad():
            start_cuda_event = torch.cuda.Event(enable_timing=True)
            end_cuda_event = torch.cuda.Event(enable_timing=True)
            start_cuda_event.record()
            start_time = time.time()
            output_image = self.poser.pose(self.torch_source_image, pose, output_index)[0].detach().cpu()
            end_time = time.time()
            end_cuda_event.record()
            torch.cuda.synchronize()
            print("cuda time (ms):", start_cuda_event.elapsed_time(end_cuda_event))
            print("elapsed time (ms):", (end_time - start_time) * 1000.0)
        numpy_image = convert_output_image_from_torch_to_numpy(output_image)
        self.last_output_numpy_image = numpy_image
        wx_image = wx.ImageFromBuffer(
            numpy_image.shape[0],
            numpy_image.shape[1],
            numpy_image[:, :, 0:3].tobytes(),
            numpy_image[:, :, 3].tobytes())
        wx_bitmap = wx_image.ConvertToBitmap()
        dc = wx.MemoryDC()
        dc.SelectObject(self.result_image_bitmap)
        dc.Clear()
        dc.DrawBitmap(wx_bitmap,
                      (self.image_size - numpy_image.shape[0]) // 2,
                      (self.image_size - numpy_image.shape[1]) // 2,
                      True)
        del dc
        self.Refresh()
        self.Update()

    def on_save_image(self, event: wx.Event):
        if self.last_output_numpy_image is None:
            logging.info("There is no output image to save!")
            return
        dir_name = "data/images"
        file_dialog = wx.FileDialog(self, "Choose an image", dir_name, "", "*.png", wx.FD_SAVE)
        if file_dialog.ShowModal() == wx.ID_OK:
            image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())
            try:
                if os.path.exists(image_file_name):
                    message_dialog = wx.MessageDialog(self, f"Overwrite {image_file_name}?", "Manual Poser",
                                                      wx.YES_NO | wx.ICON_QUESTION)
                    result = message_dialog.ShowModal()
                    if result == wx.ID_YES:
                        self.save_last_numpy_image(image_file_name)
                    message_dialog.Destroy()
                else:
                    self.save_last_numpy_image(image_file_name)
            except Exception:
                message_dialog = wx.MessageDialog(self, f"Could not save {image_file_name}", "Manual Poser", wx.OK)
                message_dialog.ShowModal()
                message_dialog.Destroy()
        file_dialog.Destroy()

    def save_last_numpy_image(self, image_file_name):
        numpy_image = self.last_output_numpy_image
        pil_image = PIL.Image.fromarray(numpy_image, mode='RGBA')
        os.makedirs(os.path.dirname(image_file_name), exist_ok=True)
        pil_image.save(image_file_name)


if __name__ == "__main__":
    device = torch.device('cuda:0')
    try:
        import tha4.poser.modes.mode_07

        poser = tha4.poser.modes.mode_07.create_poser(device)
    except RuntimeError as e:
        print(e)
        sys.exit()
    app = wx.App()
    main_frame = MainFrame(poser, device)
    main_frame.Show(True)
    main_frame.timer.Start(16)
    app.MainLoop()


================================================
FILE: src/tha4/charmodel/__init__.py
================================================


================================================
FILE: src/tha4/charmodel/character_model.py
================================================
import json
import os.path

import PIL.Image
import torch
from omegaconf import OmegaConf

from tha4.shion.base.image_util import extract_pytorch_image_from_PIL_image
from tha4.poser.modes.mode_14 import create_poser, KEY_FACE_MORPHER, KEY_BODY_MORPHER


class CharacterModel:
    def __init__(self,
                 character_image_file_name: str,
                 face_morpher_file_name: str,
                 body_morpher_file_name: str):
        self.body_morpher_file_name = body_morpher_file_name
        self.face_morpher_file_name = face_morpher_file_name
        self.character_image_file_name = character_image_file_name
        self.poser = None
        self.character_image = None

    def get_poser(self, device: torch.device):
        if self.poser is not None:
            self.poser.to(device)
        else:
            self.poser = create_poser(
                device,
                module_file_names={
                    KEY_FACE_MORPHER: self.face_morpher_file_name,
                    KEY_BODY_MORPHER: self.body_morpher_file_name
                })
        return self.poser

    def get_character_image(self, device: torch.device):
        if self.character_image is None:
            pil_image = PIL.Image.open(self.character_image_file_name)
            if pil_image.mode != 'RGBA':
                raise RuntimeError("Character image is not an RGBA image!")
            self.character_image = extract_pytorch_image_from_PIL_image(pil_image)
        self.character_image = self.character_image.to(device)
        return self.character_image

    def save(self, file_name: str):
        dir = os.path.dirname(file_name)
        rel_char_image_file_name = os.path.relpath(self.character_image_file_name, dir)
        rel_face_morpher_file_name = os.path.relpath(self.face_morpher_file_name, dir)
        rel_body_morpher_file_name = os.path.relpath(self.body_morpher_file_name, dir)
        data = {
            "character_image_file_name": rel_char_image_file_name,
            "face_morpher_file_name": rel_face_morpher_file_name,
            "body_morpher_file_name": rel_body_morpher_file_name,
        }
        conf = OmegaConf.create(data)
        os.makedirs(dir, exist_ok=True)
        with open(file_name, "wt") as fout:
            fout.write(OmegaConf.to_yaml(conf))

    @staticmethod
    def load(file_name: str):
        conf = OmegaConf.to_container(OmegaConf.load(file_name))
        dir = os.path.dirname(file_name)
        character_image_file_name = os.path.join(dir,
conf["character_image_file_name"]) face_morpher_file_name = os.path.join(dir, conf["face_morpher_file_name"]) body_morpher_file_name = os.path.join(dir, conf["body_morpher_file_name"]) return CharacterModel( character_image_file_name, face_morpher_file_name, body_morpher_file_name) ================================================ FILE: src/tha4/dataset/__init__.py ================================================ ================================================ FILE: src/tha4/dataset/image_poses_and_aother_images_dataset.py ================================================ from typing import List, Callable from torch import Tensor from torch.utils.data import Dataset class ImagePosesAndOtherImagesDataset(Dataset): def __init__(self, main_image_func: Callable[[], Tensor], pose_dataset: Dataset, other_image_funcs: List[Callable[[], Tensor]]): self.main_image_func = main_image_func self.other_image_funcs = other_image_funcs self.pose_dataset = pose_dataset self.main_image = None self.other_images = [None for i in range(len(self.other_image_funcs))] def get_main_image(self): if self.main_image is None: self.main_image = self.main_image_func() return self.main_image def get_other_image(self, image_index: int): if self.other_images[image_index] is None: self.other_images[image_index] = self.other_image_funcs[image_index]() return self.other_images[image_index] def __len__(self): return len(self.pose_dataset) def __getitem__(self, index): main_image = self.get_main_image() pose = self.pose_dataset[index][0] other_images = [self.get_other_image(i) for i in range(len(self.other_image_funcs))] return [main_image, pose] + other_images ================================================ FILE: src/tha4/distiller/__init__.py ================================================ ================================================ FILE: src/tha4/distiller/config_based_training_tasks.py ================================================ import logging import os import sys from typing import Callable, List, Optional from tha4.pytasuku.workspace import Workspace from tha4.shion.core.training.distrib.distributed_trainer import DistributedTrainer from tha4.shion.core.training.distrib.distributed_training_states import DistributedTrainingState def get_torchrun_executable(): return os.path.dirname(sys.executable) + os.path.sep + "torchrun" class RdzvConfig: def __init__(self, id: int, port: int): self.port = port self.id = id def run_standalone_config_based_training_script( training_script_file_name: str, config_file_name: str, num_proc_per_node: int, target_checkpoint_examples: Optional[int] = None, rdzv_config: Optional[RdzvConfig] = None): command = f"{get_torchrun_executable()} " \ f"--nnodes=1 " \ f"--nproc_per_node={num_proc_per_node} " if rdzv_config is not None: command += f"--rdzv_endpoint=localhost:{rdzv_config.port} " command += "--rdzv_backend=c10d " command += f"--rdzv_id={rdzv_config.id} " else: command += "--standalone " command += f"{training_script_file_name} " if target_checkpoint_examples is not None: command += f"--target_checkpoint_examples {target_checkpoint_examples} " command += f"--config_file={config_file_name} " logging.info(f"Executing -- {command}") os.system(command) def define_standalone_config_based_training_tasks( workspace: Workspace, distributed_trainer_func: Callable[[], DistributedTrainer], training_script_file_name: str, config_file_name: str, num_proc_per_node: int, dependencies: Optional[List[str]] = None, rdzv_config: Optional[RdzvConfig] = None): trainer = distributed_trainer_func() 
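    # The training protocol reports checkpoint milestones as cumulative example counts
    # (e.g. [200_000, 400_000, ...]); a zeroth milestone is prepended below so that
    # checkpoint index 0 denotes the untrained state.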
    checkpoint_examples = trainer.training_protocol.get_checkpoint_examples()
    assert len(checkpoint_examples) >= 1
    assert checkpoint_examples[0] > 0
    checkpoint_examples = [0] + checkpoint_examples
    if dependencies is None:
        dependencies = []
    module_file_dependencies = dependencies[:]
    for module_name in trainer.pretrained_module_file_names:
        module_file_dependencies.append(trainer.pretrained_module_file_names[module_name])

    def create_train_func(target_checkpoint_examples: int):
        return lambda: run_standalone_config_based_training_script(
            training_script_file_name,
            config_file_name,
            num_proc_per_node,
            target_checkpoint_examples,
            rdzv_config=rdzv_config)

    train_tasks = []
    for checkpoint_index in range(0, len(checkpoint_examples)):
        for module_name in trainer.module_names:
            module_file_name = DistributedTrainingState.get_module_file_name(
                trainer.get_checkpoint_prefix(checkpoint_index), module_name)
            workspace.create_file_task(
                module_file_name,
                module_file_dependencies,
                create_train_func(checkpoint_examples[checkpoint_index]))
        for module_name in trainer.accumulators:
            accumulated_module_file_name = DistributedTrainingState.get_accumulated_module_file_name(
                trainer.get_checkpoint_prefix(checkpoint_index), module_name)
            workspace.create_file_task(
                accumulated_module_file_name,
                module_file_dependencies,
                create_train_func(checkpoint_examples[checkpoint_index]))
        workspace.create_command_task(
            trainer.get_checkpoint_prefix(checkpoint_index) + "/train_standalone",
            module_file_dependencies,
            create_train_func(checkpoint_examples[checkpoint_index]))
        train_tasks.append(trainer.get_checkpoint_prefix(checkpoint_index) + "/train_standalone")
    workspace.create_file_task(
        trainer.prefix + "/train_standalone",
        module_file_dependencies,
        create_train_func(checkpoint_examples[-1]))


================================================
FILE: src/tha4/distiller/distill_body_morpher.py
================================================
import logging

from tha4.shion.core.training.distrib.distributed_trainer import DistributedTrainer
from tha4.distiller.distiller_config import DistillerConfig

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    parser = DistributedTrainer.get_default_arg_parser()
    parser.add_argument('--config_file', type=str)
    args = parser.parse_args()
    config_file_name = args.config_file
    config = DistillerConfig.load(config_file_name)
    DistributedTrainer.run_with_args(config.get_body_morpher_trainer, args)


================================================
FILE: src/tha4/distiller/distill_face_morpher.py
================================================
import logging

from tha4.shion.core.training.distrib.distributed_trainer import DistributedTrainer
from tha4.distiller.distiller_config import DistillerConfig

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    parser = DistributedTrainer.get_default_arg_parser()
    parser.add_argument('--config_file', type=str)
    args = parser.parse_args()
    config_file_name = args.config_file
    config = DistillerConfig.load(config_file_name)
    DistributedTrainer.run_with_args(config.get_face_morpher_trainer, args)


================================================
FILE: src/tha4/distiller/distiller_config.py
================================================
import os.path
import shutil

import PIL.Image
from dataclasses import dataclass
from typing import Optional

from omegaconf import OmegaConf

from tha4.charmodel.character_model import CharacterModel
from tha4.pytasuku.workspace import Workspace, file_task
from tha4.distiller.config_based_training_tasks import define_standalone_config_based_training_tasks
from tha4.nn.siren.face_morpher.siren_face_morpher_00_trainer import SirenFaceMorpher00TrainerArgs
from tha4.nn.siren.morpher.siren_morpher_03_trainer import SirenMorpher03TrainerArgs, TrainingPhases, TrainingPhase, \
    LossWeights, LossTerm
from tha4.shion.base.image_util import pil_image_has_transparency

POSE_DATASET_FILE_NAME = 'data/pose_dataset.pt'


def copy_file(source_file_name: str, dest_file_name):
    os.makedirs(os.path.dirname(dest_file_name), exist_ok=True)
    shutil.copyfile(source_file_name, dest_file_name)


@dataclass
class DistillerConfig:
    prefix: str
    character_image_file_name: str
    face_mask_image_file_name: str

    face_morpher_random_seed_0: int = 12771885812175595441
    face_morpher_random_seed_1: int = 14367217090963479175
    face_morpher_num_training_examples_per_sample_output: Optional[int] = 10_000
    face_morpher_batch_size: int = 8

    body_morpher_random_seed_0: int = 2892221210020292507
    body_morpher_random_seed_1: int = 9998918537095922080
    body_morpher_num_training_examples_per_sample_output: Optional[int] = 10_000
    body_morpher_batch_size: int = 8

    num_cpu_workers: int = 1
    num_gpus: int = 1

    def check(self):
        DistillerConfig.check_prefix(self.prefix)
        DistillerConfig.check_character_image_file_name(self.character_image_file_name)
        DistillerConfig.check_face_mask_image_file_name(self.face_mask_image_file_name)
        DistillerConfig.check_num_cpu_workers(self.num_cpu_workers)
        DistillerConfig.check_num_gpus(self.num_gpus)
        DistillerConfig.check_random_seed(self.face_morpher_random_seed_0, "face_morpher_random_seed_0")
        DistillerConfig.check_random_seed(self.face_morpher_random_seed_1, "face_morpher_random_seed_1")
        DistillerConfig.check_batch_size(self.face_morpher_batch_size, "face_morpher_batch_size")
        DistillerConfig.check_num_training_examples_per_sample_output(
            self.face_morpher_num_training_examples_per_sample_output,
            "face_morpher_num_training_examples_per_sample_output")
        DistillerConfig.check_random_seed(self.body_morpher_random_seed_0, "body_morpher_random_seed_0")
        DistillerConfig.check_random_seed(self.body_morpher_random_seed_1, "body_morpher_random_seed_1")
        DistillerConfig.check_batch_size(self.body_morpher_batch_size, "body_morpher_batch_size")
        DistillerConfig.check_num_training_examples_per_sample_output(
            self.body_morpher_num_training_examples_per_sample_output,
            "body_morpher_num_training_examples_per_sample_output")

    @staticmethod
    def check_prefix(prefix):
        assert os.path.exists(prefix), f"The {prefix} directory does not exist."
        assert os.path.isdir(prefix), "The 'prefix' must be a directory."

    @staticmethod
    def check_character_image_file_name(file_name):
        _, ext = os.path.splitext(file_name)
        assert os.path.isfile(file_name), \
            f"The specified character image file name, {file_name}, does not point to a file."
        assert ext.lower() == ".png", "The character image file name must have extension '.png'."
        image = PIL.Image.open(file_name)
        assert pil_image_has_transparency(image), "The character image must have an alpha channel."
        assert image.width == 512 and image.height == 512, "The character image must be 512x512."
        image.close()

    @staticmethod
    def check_face_mask_image_file_name(file_name):
        _, ext = os.path.splitext(file_name)
        assert os.path.isfile(file_name), \
            f"The specified face mask image file name, {file_name}, does not point to a file."
        assert ext.lower() == ".png", "The face mask image file name must have extension '.png'."
        image = PIL.Image.open(file_name)
        assert image.width == 512 and image.height == 512, "The face mask image must be 512x512."
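        # Every channel must be strictly binary; the per-pixel scan below rejects any
        # anti-aliased (gray) pixel, so the mask should be painted with hard edges.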
assert image.mode == "RGB", "The face mask image must be an RGB image." for x in range(512): for y in range(512): r, g, b = image.getpixel((x, y)) assert (r == 0) or (r == 255), "The R channel of the face mask image must be 0 or 255" assert (g == 0) or (g == 255), "The G channel of the face mask image must be 0 or 255" assert (b == 0) or (b == 255), "The B channel of the face mask image must be 0 or 255" image.close() @staticmethod def check_batch_size(value, field_name: str): assert isinstance(value, int), f"The {field_name} must be an integer." assert value >= 1, f"The {field_name} must be at least 1." assert value <= 8, f"The {field_name} must be at most 8." @staticmethod def check_num_cpu_workers(value): assert value >= 1, "The value of 'num_cpu_workers must be at least 1." @staticmethod def check_num_gpus(value): assert value >= 1, "The value of 'num_gpus' must be at least 1." @staticmethod def check_random_seed(value, field_name: str): assert isinstance(value, int), f"The {field_name} must be an integer." assert value >= 0 and value <= 0x_ffff_ffff_ffff_ffff, "A random seed must be between 0 and 2**64-1." @staticmethod def check_num_training_examples_per_sample_output(value, field_name): assert value in [10_000, 100_000, 1_000_000, None], f"The {field_name} must be 10_000, 100_00, 1_000_000_000, or None." def save(self, file_name: str): conf = OmegaConf.structured(self) os.makedirs(self.prefix, exist_ok=True) with open(file_name, "wt") as fout: fout.write(OmegaConf.to_yaml(conf)) def config_yaml_file_name(self): return f"{self.prefix}/config.yaml" def create_config_yaml_file(self): if os.path.exists(self.config_yaml_file_name()): return self.save(self.config_yaml_file_name()) @staticmethod def load(file_name: str) -> 'DistillerConfig': conf = OmegaConf.to_container(OmegaConf.load(file_name)) args = DistillerConfig(**conf) args.check() return args def face_morpher_prefix(self): return f"{self.prefix}/face_morpher" def get_face_morpher_trainer(self, world_size: Optional[int] = None, backend: str = 'gloo'): if world_size is None: world_size = self.num_gpus args = SirenFaceMorpher00TrainerArgs( character_file_name=self.character_image_file_name, face_mask_file_name=self.face_mask_image_file_name, pose_dataset_file_name=POSE_DATASET_FILE_NAME, total_worker=self.num_cpu_workers, num_training_examples_per_sample_output=self.face_morpher_num_training_examples_per_sample_output, total_batch_size=self.face_morpher_batch_size, training_random_seed=self.face_morpher_random_seed_0, sample_output_random_seed=self.face_morpher_random_seed_1) return args.create_trainer(self.face_morpher_prefix(), world_size, backend) def body_morpher_prefix(self): return f"{self.prefix}/body_morpher" def get_body_morpher_trainer(self, world_size: Optional[int] = None, backend: str = 'gloo'): if world_size is None: world_size = self.num_gpus args = SirenMorpher03TrainerArgs( character_file_name=self.character_image_file_name, pose_dataset_file_name=POSE_DATASET_FILE_NAME, total_worker=self.num_cpu_workers, num_training_examples_per_sample_output=self.body_morpher_num_training_examples_per_sample_output, training_random_seed=self.body_morpher_random_seed_0, sample_output_random_seed=self.body_morpher_random_seed_1, total_batch_size=self.body_morpher_batch_size, sample_output_batch_size=1, training_phases=TrainingPhases([ TrainingPhase( num_examples_upper_bound=200_000, learning_rate=1e-4, loss_weights=LossWeights(weights={ LossTerm.full_blended: 0.25, LossTerm.full_warped: 0.25, LossTerm.full_grid_change: 0.5, 
                        LossTerm.full_color_change: 2.0,
                    })),
                TrainingPhase(
                    num_examples_upper_bound=400_000,
                    learning_rate=3e-5,
                    loss_weights=LossWeights(weights={
                        LossTerm.full_blended: 0.25,
                        LossTerm.full_warped: 0.25,
                        LossTerm.full_grid_change: 0.5,
                        LossTerm.full_color_change: 2.0,
                    })),
                TrainingPhase(
                    num_examples_upper_bound=600_000,
                    learning_rate=3e-5,
                    loss_weights=LossWeights(weights={
                        LossTerm.full_blended: 1.0,
                        LossTerm.full_warped: 2.5,
                        LossTerm.full_grid_change: 5.0,
                        LossTerm.full_color_change: 1.0,
                    })),
                TrainingPhase(
                    num_examples_upper_bound=800_000,
                    learning_rate=1e-5,
                    loss_weights=LossWeights(weights={
                        LossTerm.full_blended: 1.0,
                        LossTerm.full_warped: 2.5,
                        LossTerm.full_grid_change: 5.0,
                        LossTerm.full_color_change: 1.0,
                    })),
                TrainingPhase(
                    num_examples_upper_bound=1_300_000,
                    learning_rate=1e-5,
                    loss_weights=LossWeights(weights={
                        LossTerm.full_blended: 10.0,
                        LossTerm.full_warped: 1.0,
                        LossTerm.full_grid_change: 1.0,
                        LossTerm.full_color_change: 1.0,
                    })),
                TrainingPhase(
                    num_examples_upper_bound=1_500_000,
                    learning_rate=3e-6,
                    loss_weights=LossWeights(weights={
                        LossTerm.full_blended: 10.0,
                        LossTerm.full_warped: 1.0,
                        LossTerm.full_grid_change: 1.0,
                        LossTerm.full_color_change: 1.0,
                    })),
            ]))
        return args.create_trainer(self.body_morpher_prefix(), world_size, backend)

    def character_model_prefix(self):
        return f"{self.prefix}/character_model"

    def character_model_face_morpher_file_name(self):
        return f"{self.character_model_prefix()}/face_morpher.pt"

    def character_model_body_morpher_file_name(self):
        return f"{self.character_model_prefix()}/body_morpher.pt"

    def character_model_character_png_file_name(self):
        return f"{self.character_model_prefix()}/character.png"

    def character_model_yaml_file_name(self):
        return f"{self.character_model_prefix()}/character_model.yaml"

    def define_tasks(self, workspace: Workspace):
        workspace.create_file_task(self.config_yaml_file_name(), [], self.create_config_yaml_file)

        define_standalone_config_based_training_tasks(
            workspace,
            self.get_face_morpher_trainer,
            "src/tha4/distiller/distill_face_morpher.py",
            self.config_yaml_file_name(),
            num_proc_per_node=self.num_gpus,
            dependencies=[
                self.config_yaml_file_name(),
            ])
        define_standalone_config_based_training_tasks(
            workspace,
            self.get_body_morpher_trainer,
            "src/tha4/distiller/distill_body_morpher.py",
            self.config_yaml_file_name(),
            num_proc_per_node=self.num_gpus,
            dependencies=[
                self.config_yaml_file_name(),
            ])

        @file_task(workspace, self.character_model_character_png_file_name(), [self.character_image_file_name])
        def copy_character_image_file_name():
            copy_file(self.character_image_file_name, self.character_model_character_png_file_name())

        @file_task(workspace, self.character_model_face_morpher_file_name(), [
            f"{self.face_morpher_prefix()}/checkpoint/0010/module_module.pt",
        ])
        def copy_face_morpher():
            copy_file(
                f"{self.face_morpher_prefix()}/checkpoint/0010/module_module.pt",
                self.character_model_face_morpher_file_name())

        @file_task(workspace, self.character_model_body_morpher_file_name(), [
            f"{self.body_morpher_prefix()}/checkpoint/0015/module_module.pt",
        ])
        def copy_body_morpher():
            copy_file(
                f"{self.body_morpher_prefix()}/checkpoint/0015/module_module.pt",
                self.character_model_body_morpher_file_name())

        @file_task(workspace, self.character_model_yaml_file_name(), [])
        def create_character_model_yaml_file():
            character_model = CharacterModel(
                self.character_model_character_png_file_name(),
                self.character_model_face_morpher_file_name(),
                self.character_model_body_morpher_file_name())
            character_model.save(self.character_model_yaml_file_name())
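        # A hedged usage sketch (the exact pytasuku entry point is not shown in
        # this file; "workspace.run" and the config path below are hypothetical).
        # After define_tasks has registered everything, the aggregate
        # "<prefix>/all" command task defined next would be the one to invoke:
        #
        #     config = DistillerConfig.load("data/my_character/config.yaml")  # hypothetical path
        #     config.define_tasks(workspace)
        #     workspace.run(f"{config.prefix}/all")  # assumed Workspace API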
        workspace.create_command_task(
            f"{self.prefix}/all",
            [
                f"{self.face_morpher_prefix()}/train_standalone",
                f"{self.body_morpher_prefix()}/train_standalone",
                self.character_model_character_png_file_name(),
                self.character_model_face_morpher_file_name(),
                self.character_model_body_morpher_file_name(),
                self.character_model_yaml_file_name(),
            ])


================================================
FILE: src/tha4/distiller/ui/__init__.py
================================================


================================================
FILE: src/tha4/distiller/ui/distiller_config_state.py
================================================
import os.path
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, Any, Optional

from tha4.distiller.distiller_config import DistillerConfig


class DistillerConfigState:
    def __init__(self):
        self.config = DistillerConfig(prefix="", character_image_file_name="", face_mask_image_file_name="")
        self.last_saved_timestamp = None
        self.dirty = False

    def load(self, file_name):
        self.config = DistillerConfig.load(file_name)
        if os.path.exists(self.config.config_yaml_file_name()):
            self.last_saved_timestamp = os.path.getmtime(self.config.config_yaml_file_name())
        else:
            self.last_saved_timestamp = None
        self.dirty = False

    def need_to_check_overwrite(self):
        if self.last_saved_timestamp is None:
            return True
        if not os.path.exists(self.config.config_yaml_file_name()):
            return False
        if self.last_saved_timestamp < os.path.getmtime(self.config.config_yaml_file_name()):
            return True
        return False

    def save(self):
        self.config.save(self.config.config_yaml_file_name())
        self.dirty = False
        self.last_saved_timestamp = os.path.getmtime(self.config.config_yaml_file_name())

    @contextmanager
    def updating_value(self, value_func: Callable[[], Any]):
        old_value = value_func()
        yield
        new_value = value_func()
        if new_value != old_value:
            self.dirty = True

    def set_prefix(self, new_value):
        with self.updating_value(lambda: self.config.prefix):
            new_relative_path = self.get_relative_path_to_cwd(
                new_value,
                "The prefix directory must be a subdirectory of the talking-head-anime-4-demo's "
                "source code directory.")
            DistillerConfig.check_prefix(new_relative_path)
            self.config.prefix = new_relative_path

    def set_character_image_file_name(self, new_value):
        with self.updating_value(lambda: self.config.character_image_file_name):
            new_relative_path = self.get_relative_path_to_cwd(
                new_value,
                "The character image file must be under talking-head-anime-4-demo's source code directory.")
            DistillerConfig.check_character_image_file_name(new_relative_path)
            self.config.character_image_file_name = new_relative_path

    def set_face_mask_image_file_name(self, new_value):
        with self.updating_value(lambda: self.config.face_mask_image_file_name):
            new_relative_path = self.get_relative_path_to_cwd(
                new_value,
                "The face mask image file must be under talking-head-anime-4-demo's source code directory.")
            DistillerConfig.check_face_mask_image_file_name(new_relative_path)
            self.config.face_mask_image_file_name = new_relative_path

    def set_num_cpu_workers(self, new_value: int):
        with self.updating_value(lambda: self.config.num_cpu_workers):
            DistillerConfig.check_num_cpu_workers(new_value)
            self.config.num_cpu_workers = new_value

    def set_num_gpus(self, new_value: int):
        with self.updating_value(lambda: self.config.num_gpus):
            DistillerConfig.check_num_gpus(new_value)
            self.config.num_gpus = new_value

    def set_face_morpher_random_seed_0(self, new_value: int):
        with self.updating_value(lambda:
self.config.face_morpher_random_seed_0): DistillerConfig.check_random_seed(new_value, "face_morpher_random_seed_0") self.config.face_morpher_random_seed_0 = new_value def set_face_morpher_random_seed_1(self, new_value: int): with self.updating_value(lambda: self.config.face_morpher_random_seed_1): DistillerConfig.check_random_seed(new_value, "face_morpher_random_seed_1") self.config.face_morpher_random_seed_1 = new_value def set_face_morpher_num_training_examples_per_sample_output(self, new_value: Optional[int]): with self.updating_value(lambda: self.config.face_morpher_num_training_examples_per_sample_output): DistillerConfig.check_num_training_examples_per_sample_output( new_value, "face_morpher_num_training_examples_per_sample_output") self.config.face_morpher_num_training_examples_per_sample_output = new_value def set_face_morpher_batch_size(self, new_value: int): with self.updating_value(lambda: self.config.face_morpher_batch_size): DistillerConfig.check_batch_size(new_value, "face_morpher_batch_size") self.config.face_morpher_batch_size = new_value def set_body_morpher_random_seed_0(self, new_value: int): with self.updating_value(lambda: self.config.body_morpher_random_seed_0): DistillerConfig.check_random_seed(new_value, "body_morpher_random_seed_0") self.config.body_morpher_random_seed_0 = new_value def set_body_morpher_random_seed_1(self, new_value: int): with self.updating_value(lambda: self.config.body_morpher_random_seed_1): DistillerConfig.check_random_seed(new_value, "body_morpher_random_seed_1") self.config.body_morpher_random_seed_1 = new_value def set_body_morpher_num_training_examples_per_sample_output(self, new_value: Optional[int]): with self.updating_value(lambda: self.config.body_morpher_num_training_examples_per_sample_output): DistillerConfig.check_num_training_examples_per_sample_output( new_value, "body_morpher_num_training_examples_per_sample_output") self.config.body_morpher_num_training_examples_per_sample_output = new_value def set_body_morpher_batch_size(self, new_value: int): with self.updating_value(lambda: self.config.body_morpher_batch_size): DistillerConfig.check_batch_size(new_value, "body_morpher_batch_size") self.config.body_morpher_batch_size = new_value def get_relative_path_to_cwd(self, file_name: str, message: str): cwd = os.getcwd() assert os.path.commonprefix([cwd, file_name]) == cwd, message cwd_path = Path(cwd).as_posix() new_path = Path(file_name).as_posix() new_relative_path = os.path.relpath(str(new_path), cwd_path) new_relative_path = str(Path(new_relative_path).as_posix()) return new_relative_path def can_show_character_image(self): return os.path.isfile(self.config.character_image_file_name) def can_show_face_mask_image(self): return os.path.isfile(self.config.face_mask_image_file_name) def can_show_mask_on_face_image(self): return self.can_show_character_image() and self.can_show_face_mask_image() def can_save(self): return os.path.isdir(self.config.prefix) \ and os.path.isfile(self.config.character_image_file_name) \ and os.path.isfile(self.config.face_mask_image_file_name) ================================================ FILE: src/tha4/distiller/ui/distiller_ui_main_frame.py ================================================ import multiprocessing import random from contextlib import contextmanager from typing import Callable import PIL.Image import torch import wx import wx.html import wx.lib.intctrl from tha4.distiller.ui.distiller_config_state import DistillerConfigState from tha4.image_util import 
convert_output_image_from_torch_to_numpy
from tha4.shion.base.image_util import extract_pytorch_image_from_PIL_image


def wx_bind_event(widget, evt):
    def f(handler):
        widget.Bind(evt, handler)
        return handler
    return f


class DistillerUiMainFrame(wx.Frame):
    PARAM_NAME_STATIC_TEXT_MIN_WIDTH = 400
    NUM_TRAINING_EXAMPLES_PER_SAMPLE_OUTPUT_CHOICES = [
        "10_000",
        "100_000",
        "1_000_000",
        "Do not generate sample outputs"]

    def __init__(self):
        super().__init__(None, wx.ID_ANY, "Distiller UI")
        self.init_ui()
        self.init_menus()
        self.init_bitmaps()
        self.Bind(wx.EVT_CLOSE, self.on_close)
        self.state = DistillerConfigState()
        self.update_ui()
        self.config_file_to_run = None

    def init_ui(self):
        main_sizer = wx.BoxSizer(wx.HORIZONTAL)
        self.SetSizer(main_sizer)
        self.SetAutoLayout(1)
        left_panel = self.init_left_panel(self)
        main_sizer.Add(left_panel, 0, wx.FIXED_MINSIZE)
        middle_panel = self.init_middle_panel(self)
        main_sizer.Add(middle_panel, 0, wx.EXPAND)
        right_panel = self.init_right_panel(self)
        main_sizer.Add(right_panel, 1, wx.EXPAND)
        main_sizer.Fit(self)

    def init_menus(self):
        self.file_menu = wx.Menu()

        self.new_menu_id = wx.Window.NewControlId()
        self.file_menu.Append(
            self.new_menu_id,
            item="&New\tCTRL+N",
            helpString="Create a new distiller configuration.")
        self.Bind(wx.EVT_MENU, self.on_new, id=self.new_menu_id)

        self.open_menu_id = wx.Window.NewControlId()
        self.file_menu.Append(
            self.open_menu_id,
            item="&Open\tCTRL+O",
            helpString="Open a distiller configuration.")
        self.Bind(wx.EVT_MENU, self.on_open, id=self.open_menu_id)

        self.save_menu_id = wx.Window.NewControlId()
        self.save_menu_item = wx.MenuItem(
            self.file_menu,
            id=self.save_menu_id,
            text="&Save\tCTRL+S",
            helpString="Save the current distiller configuration. An error message will be shown "
                       "if it is not well formed.")
        self.Bind(wx.EVT_MENU, self.on_save, id=self.save_menu_id)
        self.file_menu.Append(self.save_menu_item)

        self.file_menu.AppendSeparator()

        self.exit_menu_id = wx.ID_EXIT
        self.file_menu.Append(
            self.exit_menu_id,
            item="E&xit\tCTRL+Q",
            helpString="Exit the application.")
        self.Bind(wx.EVT_MENU, self.on_close, id=self.exit_menu_id)

        self.menu_bar = wx.MenuBar()
        self.menu_bar.Append(self.file_menu, "&File")
        self.SetMenuBar(self.menu_bar)

    def init_bitmaps(self):
        self.face_image_bitmap = wx.Bitmap(128, 128)
        self.face_image_pytorch = None
        self.face_mask_image_bitmap = wx.Bitmap(128, 128)
        self.face_mask_image_pytorch = None
        self.mask_on_face_image_bitmap = wx.Bitmap(128, 128)
        self.draw_nothing_yet_string_to_bitmap(self.face_image_bitmap, 128, 128)
        self.draw_nothing_yet_string_to_bitmap(self.face_mask_image_bitmap, 128, 128)
        self.draw_nothing_yet_string_to_bitmap(self.mask_on_face_image_bitmap, 128, 128)

    @contextmanager
    def create_panel(self, parent, sizer, *args, **kwargs):
        panel = wx.Panel(parent, *args, **kwargs)
        panel.SetSizer(sizer)
        panel.SetAutoLayout(1)
        try:
            yield panel, sizer
        finally:
            sizer.Fit(panel)

    def init_left_panel(self, parent):
        with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, sizer):
            self.face_image_panel = wx.Panel(panel, size=(128, 128), style=wx.SIMPLE_BORDER)
            self.face_image_panel.Bind(wx.EVT_PAINT, self.on_face_image_panel_paint)
            sizer.Add(self.face_image_panel, 0, wx.EXPAND)
            static_text = wx.StaticText(panel, label="Face", style=wx.ALIGN_CENTER)
            sizer.Add(static_text, 0, wx.EXPAND)
            self.face_mask_image_panel = wx.Panel(panel, size=(128, 128), style=wx.SIMPLE_BORDER)
            self.face_mask_image_panel.Bind(wx.EVT_PAINT, self.on_face_mask_image_panel_paint)
            sizer.Add(self.face_mask_image_panel, 0, wx.EXPAND)
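            # Each 128x128 preview panel is backed by a wx.Bitmap and repainted
            # through wx.BufferedPaintDC in the on_*_paint handlers below.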
static_text = wx.StaticText(panel, label="Face mask", style=wx.ALIGN_CENTER) sizer.Add(static_text, 0, wx.EXPAND) self.mask_on_face_image_panel = wx.Panel(panel, size=(128, 128), style=wx.SIMPLE_BORDER) self.mask_on_face_image_panel.Bind(wx.EVT_PAINT, self.on_mask_on_face_image_panel_paint) sizer.Add(self.mask_on_face_image_panel, 0, wx.EXPAND) static_text = wx.StaticText(panel, label="Mask upon face", style=wx.ALIGN_CENTER) sizer.Add(static_text, 0, wx.EXPAND) return panel def on_erase_background(self, event): pass def on_face_image_panel_paint(self, event): wx.BufferedPaintDC(self.face_image_panel, self.face_image_bitmap) def on_face_mask_image_panel_paint(self, event): wx.BufferedPaintDC(self.face_mask_image_panel, self.face_mask_image_bitmap) def on_mask_on_face_image_panel_paint(self, event): wx.BufferedPaintDC(self.mask_on_face_image_panel, self.mask_on_face_image_bitmap) def init_middle_panel(self, parent): with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, sizer): sizer.Add(self.init_prefix_panel(panel), 0, wx.EXPAND) sizer.Add(self.init_character_image_file_name_panel(panel), 0, wx.EXPAND) sizer.Add(self.init_face_mask_image_file_name_panel(panel), 0, wx.EXPAND) sizer.Add(self.init_num_cpu_workers_panel(panel), 0, wx.EXPAND) sizer.Add(self.init_num_gpus_panel(panel), 0, wx.EXPAND) sizer.Add(self.init_face_morpher_random_seed_0_panel(panel), 0, wx.EXPAND) sizer.Add(self.init_face_morpher_random_seed_1_panel(panel), 0, wx.EXPAND) sizer.Add(self.init_face_morpher_batch_size_panel(panel), 0, wx.EXPAND) sizer.Add(self.init_body_morpher_random_seed_0_panel(panel), 0, wx.EXPAND) sizer.Add(self.init_body_morpher_random_seed_1_panel(panel), 0, wx.EXPAND) sizer.Add(self.init_body_morpher_batch_size_panel(panel), 0, wx.EXPAND) sizer.Add(self.init_num_training_examples_per_sample_output_panel(panel), 0, wx.EXPAND) self.run_button = wx.Button(panel, label="RUN") self.run_button.SetMinSize((-1, 64)) self.run_button.Bind(wx.EVT_BUTTON, self.on_run) sizer.Add(self.run_button, 1, wx.EXPAND) return panel def init_prefix_panel(self, parent): with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer): prefix_param_name_panel = self.create_param_name_panel_with_help_button( panel, "prefix (i.e. 
project directory)", self.create_help_button_func("distiller-ui-doc/params/prefix.html")) panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND) with self.create_panel(panel, wx.BoxSizer(wx.HORIZONTAL), style=wx.BORDER_NONE) \ as (prefix_panel, prefix_sizer): self.prefix_text_ctrl = wx.TextCtrl(prefix_panel, value="") self.prefix_text_ctrl.SetEditable(False) prefix_sizer.Add(self.prefix_text_ctrl, 1, wx.EXPAND) self.prefix_change_button = wx.Button(prefix_panel, label="Change...") self.prefix_change_button.Bind(wx.EVT_BUTTON, self.on_prefix_change_button) prefix_sizer.Add(self.prefix_change_button, 0, wx.EXPAND) panel_sizer.Add(prefix_panel, 1, wx.EXPAND) return panel def on_prefix_change_button(self, event): dir_dialog = wx.DirDialog(self, "Choose a directory.", style=wx.DD_DEFAULT_STYLE | wx.DD_NEW_DIR_BUTTON) if dir_dialog.ShowModal() != wx.ID_OK: return prefix_value = dir_dialog.GetPath() try: self.state.set_prefix(prefix_value) self.update_ui() except Exception as e: message_dialog = wx.MessageDialog(self, str(e), "Error", wx.OK | wx.ICON_ERROR) message_dialog.ShowModal() def init_character_image_file_name_panel(self, parent): with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer): prefix_param_name_panel = self.create_param_name_panel_with_help_button( panel, "character_image_file_name", self.create_help_button_func("distiller-ui-doc/params/character_image_file_name.html")) panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND) with self.create_panel(panel, wx.BoxSizer(wx.HORIZONTAL), style=wx.BORDER_NONE) as (sub_panel, sub_sizer): self.character_image_file_name_text_ctrl = wx.TextCtrl(sub_panel, value="") self.character_image_file_name_text_ctrl.SetEditable(False) sub_sizer.Add(self.character_image_file_name_text_ctrl, 1, wx.EXPAND) self.character_image_change_button = wx.Button(sub_panel, label="Change...") self.character_image_change_button.Bind(wx.EVT_BUTTON, self.on_character_image_change_button) sub_sizer.Add(self.character_image_change_button, 0, wx.EXPAND) panel_sizer.Add(sub_panel, 1, wx.EXPAND) return panel def on_character_image_change_button(self, event): file_dialog = wx.FileDialog(self, "Choose a PNG file", wildcard="*.png", style=wx.FD_OPEN) if file_dialog.ShowModal() != wx.ID_OK: return file_name = file_dialog.GetPath() try: self.state.set_character_image_file_name(file_name) self.update_face_image_bitmap(file_name) self.update_ui() except Exception as e: message_dialog = wx.MessageDialog(self, str(e), "Error", wx.OK | wx.ICON_ERROR) message_dialog.ShowModal() def update_face_image_bitmap(self, new_file_name: str): pil_image = PIL.Image.open(new_file_name) subimage = pil_image.crop((256 - 64, 80, 256 + 64, 208)) self.face_image_bitmap = wx.Bitmap.FromBufferRGBA(128, 128, subimage.convert("RGBA").tobytes()) self.face_image_pytorch = extract_pytorch_image_from_PIL_image(subimage).to(torch.float) self.update_mask_on_face_image_bitmap() def init_face_mask_image_file_name_panel(self, parent): with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer): prefix_param_name_panel = self.create_param_name_panel_with_help_button( panel, "face_mask_image_file_name", self.create_help_button_func("distiller-ui-doc/params/face_mask_image_file_name.html")) panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND) with self.create_panel(panel, wx.BoxSizer(wx.HORIZONTAL), style=wx.BORDER_NONE) as (sub_panel, sub_sizer): self.face_mask_image_file_name_text_ctrl = wx.TextCtrl(sub_panel, value="") 
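                # The path text controls are display-only (SetEditable(False)):
                # paths can only change through the file dialogs, so the state
                # setters' validation always runs before a value is accepted.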
self.face_mask_image_file_name_text_ctrl.SetEditable(False) sub_sizer.Add(self.face_mask_image_file_name_text_ctrl, 1, wx.EXPAND) self.face_mask_image_file_name_change_button = wx.Button(sub_panel, label="Change...") self.face_mask_image_file_name_change_button.Bind(wx.EVT_BUTTON, self.on_face_mask_image_change_button) sub_sizer.Add(self.face_mask_image_file_name_change_button, 0, wx.EXPAND) panel_sizer.Add(sub_panel, 1, wx.EXPAND) return panel def on_face_mask_image_change_button(self, event): file_dialog = wx.FileDialog(self, "Choose a PNG file", wildcard="*.png", style=wx.FD_OPEN) if file_dialog.ShowModal() != wx.ID_OK: return file_name = file_dialog.GetPath() try: self.state.set_face_mask_image_file_name(file_name) self.update_face_mask_image_bitmap(file_name) self.update_ui() except Exception as e: message_dialog = wx.MessageDialog(self, str(e), "Error", wx.OK | wx.ICON_ERROR) message_dialog.ShowModal() def update_face_mask_image_bitmap(self, new_file_name): pil_image = PIL.Image.open(new_file_name) subimage = pil_image.crop((256 - 64, 80, 256 + 64, 208)) self.face_mask_image_bitmap = wx.Bitmap.FromBufferRGBA(128, 128, subimage.convert("RGBA").tobytes()) self.face_mask_image_pytorch = extract_pytorch_image_from_PIL_image(subimage).to(torch.float) self.face_mask_image_pytorch = self.face_mask_image_pytorch[0:1, :, :] self.update_mask_on_face_image_bitmap() def update_mask_on_face_image_bitmap(self): if self.face_image_pytorch is None: return if self.face_mask_image_pytorch is None: return mask_on_face_image = (0.5 * self.face_image_pytorch) + (0.5 * self.face_mask_image_pytorch) numpy_image = convert_output_image_from_torch_to_numpy(mask_on_face_image) wx_image = wx.ImageFromBuffer( numpy_image.shape[0], numpy_image.shape[1], numpy_image[:, :, 0:3].tobytes(), numpy_image[:, :, 3].tobytes()) self.mask_on_face_image_bitmap = wx_image.ConvertToBitmap() def init_num_cpu_workers_panel(self, parent): with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer): prefix_param_name_panel = self.create_param_name_panel_with_help_button( panel, "num_cpu_workers", self.create_help_button_func("distiller-ui-doc/params/num_cpu_workers.html")) panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND) num_cpus = multiprocessing.cpu_count() self.num_cpu_workers_spin_ctrl = wx.SpinCtrl(panel, initial=1, min=1, max=num_cpus) @wx_bind_event(self.num_cpu_workers_spin_ctrl, wx.EVT_SPINCTRL) def on_num_cpu_workers_spin_ctrl(event): self.state.set_num_cpu_workers(self.num_cpu_workers_spin_ctrl.GetValue()) self.Refresh() panel_sizer.Add(self.num_cpu_workers_spin_ctrl, 1, wx.EXPAND) return panel def init_num_gpus_panel(self, parent): with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer): prefix_param_name_panel = self.create_param_name_panel_with_help_button( panel, "num_gpus", self.create_help_button_func("distiller-ui-doc/params/num_gpus.html")) panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND) num_gpus = torch.cuda.device_count() self.num_gpus_spin_ctrl = wx.SpinCtrl(panel, initial=1, min=1, max=max(1, num_gpus)) @wx_bind_event(self.num_gpus_spin_ctrl, wx.EVT_SPINCTRL) def on_num_cpu_workers_spin_ctrl(event): self.state.set_num_gpus(self.num_gpus_spin_ctrl.GetValue()) self.Refresh() panel_sizer.Add(self.num_gpus_spin_ctrl, 1, wx.EXPAND) return panel def init_face_morpher_random_seed_0_panel(self, parent): with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer): 
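            # The four random-seed panels below are structurally identical: a
            # wx.lib.intctrl.IntCtrl constrained to [0, 2**64 - 1] plus a
            # "Randomize" button that draws a fresh seed from random.randint.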
prefix_param_name_panel = self.create_param_name_panel_with_help_button( panel, "face_morpher_random_seed_0", self.create_help_button_func("distiller-ui-doc/params/face_morpher_random_seed_0.html")) panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND) with self.create_panel(panel, wx.BoxSizer(wx.HORIZONTAL), style=wx.BORDER_NONE) as (sub_panel, sub_sizer): initial_value = random.randint(0, 2 ** 64 - 1) self.face_morpher_random_seed_0_int_ctrl = wx.lib.intctrl.IntCtrl( sub_panel, value=initial_value, min=0, max=0x_ffff_ffff_ffff_ffff) @wx_bind_event(self.face_morpher_random_seed_0_int_ctrl, wx.EVT_TEXT) def on_face_morpher_random_seed_0_int_ctrl_text(event): self.state.set_face_morpher_random_seed_0(self.face_morpher_random_seed_0_int_ctrl.GetValue()) sub_sizer.Add(self.face_morpher_random_seed_0_int_ctrl, 1, wx.EXPAND) self.face_morpher_random_seed_0_randomize_button = wx.Button(sub_panel, label="Randomize") @wx_bind_event(self.face_morpher_random_seed_0_randomize_button, wx.EVT_BUTTON) def on_face_morpher_random_seed_0_randomize_button(event): new_value = random.randint(0, 0x_ffff_ffff_ffff_ffff) self.face_morpher_random_seed_0_int_ctrl.SetValue(new_value) self.state.set_face_morpher_random_seed_0(new_value) sub_sizer.Add(self.face_morpher_random_seed_0_randomize_button, 0, wx.EXPAND) panel_sizer.Add(sub_panel, 1, wx.EXPAND) return panel def init_face_morpher_random_seed_1_panel(self, parent): with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer): prefix_param_name_panel = self.create_param_name_panel_with_help_button( panel, "face_morpher_random_seed_1", self.create_help_button_func("distiller-ui-doc/params/face_morpher_random_seed_1.html")) panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND) with self.create_panel(panel, wx.BoxSizer(wx.HORIZONTAL), style=wx.BORDER_NONE) as (sub_panel, sub_sizer): initial_value = random.randint(0, 2 ** 64 - 1) self.face_morpher_random_seed_1_int_ctrl = wx.lib.intctrl.IntCtrl( sub_panel, value=initial_value, min=0, max=0x_ffff_ffff_ffff_ffff) @wx_bind_event(self.face_morpher_random_seed_1_int_ctrl, wx.EVT_TEXT) def on_face_morpher_random_seed_1_int_ctrl_text(event): self.state.set_face_morpher_random_seed_1(self.face_morpher_random_seed_1_int_ctrl.GetValue()) sub_sizer.Add(self.face_morpher_random_seed_1_int_ctrl, 1, wx.EXPAND) self.face_morpher_random_seed_1_randomize_button = wx.Button(sub_panel, label="Randomize") @wx_bind_event(self.face_morpher_random_seed_1_randomize_button, wx.EVT_BUTTON) def on_face_morpher_random_seed_1_randomize_button(event): new_value = random.randint(0, 0x_ffff_ffff_ffff_ffff) self.face_morpher_random_seed_1_int_ctrl.SetValue(new_value) self.state.set_face_morpher_random_seed_1(new_value) sub_sizer.Add(self.face_morpher_random_seed_1_randomize_button, 0, wx.EXPAND) panel_sizer.Add(sub_panel, 1, wx.EXPAND) return panel def init_face_morpher_batch_size_panel(self, parent): with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer): prefix_param_name_panel = self.create_param_name_panel_with_help_button( panel, "face_morpher_batch_size", self.create_help_button_func("distiller-ui-doc/params/face_morpher_batch_size.html")) panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND) self.face_morpher_batch_size_spin_ctrl = wx.SpinCtrl(panel, initial=8, min=1, max=8) @wx_bind_event(self.face_morpher_batch_size_spin_ctrl, wx.EVT_SPINCTRL) def on_face_morpher_batch_size_spin_ctrl(event): 
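                # Spin-control edits flow straight into the config state; the
                # updating_value context manager marks the config dirty when the
                # stored value actually changes.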
self.state.set_face_morpher_batch_size(self.face_morpher_batch_size_spin_ctrl.GetValue()) panel_sizer.Add(self.face_morpher_batch_size_spin_ctrl, 1, wx.EXPAND) return panel def init_body_morpher_random_seed_0_panel(self, parent): with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer): prefix_param_name_panel = self.create_param_name_panel_with_help_button( panel, "body_morpher_random_seed_0", self.create_help_button_func("distiller-ui-doc/params/body_morpher_random_seed_0.html")) panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND) with self.create_panel(panel, wx.BoxSizer(wx.HORIZONTAL), style=wx.BORDER_NONE) as (sub_panel, sub_sizer): initial_value = random.randint(0, 2 ** 64 - 1) self.body_morpher_random_seed_0_int_ctrl = wx.lib.intctrl.IntCtrl( sub_panel, value=initial_value, min=0, max=0x_ffff_ffff_ffff_ffff) @wx_bind_event(self.body_morpher_random_seed_0_int_ctrl, wx.EVT_TEXT) def on_body_morpher_random_seed_0_int_ctrl_text(event): self.state.set_body_morpher_random_seed_0(self.body_morpher_random_seed_0_int_ctrl.GetValue()) sub_sizer.Add(self.body_morpher_random_seed_0_int_ctrl, 1, wx.EXPAND) self.body_morpher_random_seed_0_randomize_button = wx.Button(sub_panel, label="Randomize") @wx_bind_event(self.body_morpher_random_seed_0_randomize_button, wx.EVT_BUTTON) def on_body_morpher_random_seed_0_randomize_button(event): new_value = random.randint(0, 0x_ffff_ffff_ffff_ffff) self.body_morpher_random_seed_0_int_ctrl.SetValue(new_value) self.state.set_body_morpher_random_seed_0(new_value) sub_sizer.Add(self.body_morpher_random_seed_0_randomize_button, 0, wx.EXPAND) panel_sizer.Add(sub_panel, 1, wx.EXPAND) return panel def init_body_morpher_random_seed_1_panel(self, parent): with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer): prefix_param_name_panel = self.create_param_name_panel_with_help_button( panel, "body_morpher_random_seed_1", self.create_help_button_func("distiller-ui-doc/params/body_morpher_random_seed_1.html")) panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND) with self.create_panel(panel, wx.BoxSizer(wx.HORIZONTAL), style=wx.BORDER_NONE) as (sub_panel, sub_sizer): initial_value = random.randint(0, 2 ** 64 - 1) self.body_morpher_random_seed_1_int_ctrl = wx.lib.intctrl.IntCtrl( sub_panel, value=initial_value, min=0, max=0x_ffff_ffff_ffff_ffff) @wx_bind_event(self.body_morpher_random_seed_1_int_ctrl, wx.EVT_TEXT) def on_body_morpher_random_seed_1_int_ctrl_text(event): self.state.set_body_morpher_random_seed_1(self.body_morpher_random_seed_1_int_ctrl.GetValue()) sub_sizer.Add(self.body_morpher_random_seed_1_int_ctrl, 1, wx.EXPAND) self.body_morpher_random_seed_1_randomize_button = wx.Button(sub_panel, label="Randomize") @wx_bind_event(self.body_morpher_random_seed_1_randomize_button, wx.EVT_BUTTON) def on_body_morpher_random_seed_1_randomize_button(event): new_value = random.randint(0, 0x_ffff_ffff_ffff_ffff) self.body_morpher_random_seed_1_int_ctrl.SetValue(new_value) self.state.set_body_morpher_random_seed_1(new_value) sub_sizer.Add(self.body_morpher_random_seed_1_randomize_button, 0, wx.EXPAND) panel_sizer.Add(sub_panel, 1, wx.EXPAND) return panel def init_body_morpher_batch_size_panel(self, parent): with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer): prefix_param_name_panel = self.create_param_name_panel_with_help_button( panel, "body_morpher_batch_size", 
self.create_help_button_func("distiller-ui-doc/params/body_morpher_batch_size.html")) panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND) self.body_morpher_batch_size_spin_ctrl = wx.SpinCtrl(panel, initial=8, min=1, max=8) @wx_bind_event(self.body_morpher_batch_size_spin_ctrl, wx.EVT_SPINCTRL) def on_body_morpher_batch_size_spin_ctrl(event): self.state.set_body_morpher_batch_size(self.body_morpher_batch_size_spin_ctrl.GetValue()) panel_sizer.Add(self.body_morpher_batch_size_spin_ctrl, 1, wx.EXPAND) return panel def init_num_training_examples_per_sample_output_panel(self, parent): with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer): prefix_param_name_panel = self.create_param_name_panel_with_help_button( panel, "num_training_examples_per_sample_output", self.create_help_button_func("distiller-ui-doc/params/num_training_examples_per_sample_output.html")) panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND) self.num_training_examples_per_sample_output_combobox = \ wx.ComboBox(panel, value="10_000", choices=DistillerUiMainFrame.NUM_TRAINING_EXAMPLES_PER_SAMPLE_OUTPUT_CHOICES) @wx_bind_event(self.num_training_examples_per_sample_output_combobox, wx.EVT_COMBOBOX) def on_num_training_examples_per_sample_output_combobox(event): index = self.num_training_examples_per_sample_output_combobox.GetSelection() if index == 3: self.state.set_face_morpher_num_training_examples_per_sample_output(None) self.state.set_body_morpher_num_training_examples_per_sample_output(None) else: selected = DistillerUiMainFrame.NUM_TRAINING_EXAMPLES_PER_SAMPLE_OUTPUT_CHOICES[index] new_value = int(selected) self.state.set_face_morpher_num_training_examples_per_sample_output(new_value) self.state.set_body_morpher_num_training_examples_per_sample_output(new_value) panel_sizer.Add(self.num_training_examples_per_sample_output_combobox, 1, wx.EXPAND) return panel def on_close(self, event): if self.state.dirty: confirmation_dialog = wx.MessageDialog( parent=self, message=f"You have not saved your work. 
Do you want to exit anyway?", caption="Confirmation", style=wx.YES_NO | wx.ICON_QUESTION) result = confirmation_dialog.ShowModal() if result == wx.ID_NO: return self.Destroy() def create_help_button_func(self, html_file_name: str): def init_help_button_func(parent): button = wx.Button(parent, label="Help") @wx_bind_event(button, wx.EVT_BUTTON) def on_prefix_button(event): self.html_window.LoadPage(html_file_name) self.Refresh() return button return init_help_button_func def create_param_name_panel_with_help_button( self, parent, param_name: str, help_button_func: Callable[[wx.Window], wx.Button]): with self.create_panel(parent, wx.BoxSizer(wx.HORIZONTAL), style=wx.NO_BORDER) \ as (panel, sizer): title_text_panel = self.create_vertically_centered_text_panel( panel, param_name, DistillerUiMainFrame.PARAM_NAME_STATIC_TEXT_MIN_WIDTH) sizer.Add(title_text_panel, 1, wx.EXPAND) help_button = help_button_func(panel) sizer.Add(help_button, 0, wx.EXPAND) return panel def create_vertically_centered_text_panel(self, parent, text: str, min_width: int): with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.NO_BORDER) as (panel, sizer): sizer.AddStretchSpacer(1) text = wx.StaticText( panel, label=text, style=wx.ALIGN_CENTER) text.SetMinSize((min_width, -1)) sizer.Add(text, 0, wx.EXPAND) sizer.AddStretchSpacer(1) return panel def init_right_panel(self, parent): with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, sizer): self.html_window = wx.html.HtmlWindow(panel) self.html_window.SetMinSize((600, 600)) self.html_window.SetFonts("Times New Roman", "Courier New", sizes=[10, 12, 14, 16, 18, 20, 24]) self.html_window.LoadPage("distiller-ui-doc/index.html") sizer.Add(self.html_window, 1, wx.EXPAND) go_to_main_documentation_button = wx.Button(panel, label="Go to Main Documentation") sizer.Add(go_to_main_documentation_button, 0, wx.EXPAND) @wx_bind_event(go_to_main_documentation_button, wx.EVT_BUTTON) def on_go_to_main_documentation_button(event): self.html_window.LoadPage("distiller-ui-doc/index.html") self.Refresh() return panel def populate_distiller_config(self): self.state.config.prefix = self.prefix_text_ctrl.GetValue() self.state.config.character_image_file_name = self.character_image_file_name_text_ctrl.GetValue() self.state.config.face_mask_image_file_name = self.face_mask_image_file_name_text_ctrl.GetValue() self.state.config.num_cpu_workers = self.num_cpu_workers_spin_ctrl.GetValue() self.state.config.num_gpus = self.num_gpus_spin_ctrl.GetValue() self.state.config.face_morpher_random_seed_0 = self.face_morpher_random_seed_0_int_ctrl.GetValue() self.state.config.face_morpher_random_seed_1 = self.face_morpher_random_seed_1_int_ctrl.GetValue() self.state.config.face_morpher_batch_size = self.face_morpher_batch_size_spin_ctrl.GetValue() self.state.config.body_morpher_random_seed_0 = self.body_morpher_random_seed_0_int_ctrl.GetValue() self.state.config.body_morpher_random_seed_1 = self.body_morpher_random_seed_1_int_ctrl.GetValue() self.state.config.body_morpher_batch_size = self.body_morpher_batch_size_spin_ctrl.GetValue() if self.num_training_examples_per_sample_output_combobox.GetValue() == \ DistillerUiMainFrame.NUM_TRAINING_EXAMPLES_PER_SAMPLE_OUTPUT_CHOICES[-1]: self.state.config.face_morpher_num_training_examples_per_sample_output = None self.state.config.body_morpher_num_training_examples_per_sample_output = None else: value = int(self.num_training_examples_per_sample_output_combobox.GetValue()) 
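            # The combobox choices use Python-style digit separators (e.g.
            # "10_000"), which int() parses directly in Python 3.6+.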
self.state.config.face_morpher_num_training_examples_per_sample_output = value self.state.config.body_morpher_num_training_examples_per_sample_output = value def update_ui(self): self.prefix_text_ctrl.SetValue(self.state.config.prefix) self.character_image_file_name_text_ctrl.SetValue(self.state.config.character_image_file_name) self.face_mask_image_file_name_text_ctrl.SetValue(self.state.config.face_mask_image_file_name) if not self.state.can_show_character_image(): self.draw_nothing_yet_string_to_bitmap(self.face_image_bitmap, 128, 128) if not self.state.can_show_face_mask_image(): self.draw_nothing_yet_string_to_bitmap(self.face_mask_image_bitmap, 128, 128) if not self.state.can_show_mask_on_face_image(): self.draw_nothing_yet_string_to_bitmap(self.mask_on_face_image_bitmap, 128, 128) self.num_cpu_workers_spin_ctrl.SetValue(self.state.config.num_cpu_workers) self.num_gpus_spin_ctrl.SetValue(self.state.config.num_gpus) self.face_morpher_random_seed_0_int_ctrl.SetValue(self.state.config.face_morpher_random_seed_0) self.face_morpher_random_seed_1_int_ctrl.SetValue(self.state.config.face_morpher_random_seed_1) self.face_morpher_batch_size_spin_ctrl.SetValue(self.state.config.face_morpher_batch_size) self.body_morpher_random_seed_0_int_ctrl.SetValue(self.state.config.body_morpher_random_seed_0) self.body_morpher_random_seed_1_int_ctrl.SetValue(self.state.config.body_morpher_random_seed_1) self.body_morpher_batch_size_spin_ctrl.SetValue(self.state.config.body_morpher_batch_size) if self.state.config.body_morpher_num_training_examples_per_sample_output is None: self.num_training_examples_per_sample_output_combobox.SetSelection(3) else: choices = [int(x) for x in DistillerUiMainFrame.NUM_TRAINING_EXAMPLES_PER_SAMPLE_OUTPUT_CHOICES[:-1]] self.num_training_examples_per_sample_output_combobox.SetSelection( choices.index(self.state.config.body_morpher_num_training_examples_per_sample_output)) self.save_menu_item.Enable(self.state.can_save()) self.Refresh() def draw_nothing_yet_string_to_bitmap(self, bitmap, width: int, height: int): dc = wx.MemoryDC() dc.SelectObject(bitmap) dc.Clear() font = wx.Font(wx.FontInfo(14).Family(wx.FONTFAMILY_SWISS)) dc.SetFont(font) w, h = dc.GetTextExtent("Nothing yet!") dc.DrawText("Nothing yet!", (width - w) // 2, (height - h) // 2) del dc def try_saving(self): if not self.state.can_save(): message_dialog = wx.MessageDialog( self, "Cannot save yet! Please make sure you set the prefix, character_image_file_name, " "and face_mask_image_file_name first.", "Error", wx.OK | wx.ICON_ERROR) message_dialog.ShowModal() return False else: if self.state.need_to_check_overwrite(): confirmation_dialog = wx.MessageDialog( parent=self, message=f"Overwriting {self.state.config.config_yaml_file_name()}?", caption="Confirmation", style=wx.YES_NO | wx.CANCEL | wx.ICON_QUESTION) result = confirmation_dialog.ShowModal() if result == wx.ID_YES: self.state.save() return True elif result == wx.ID_NO: return False else: return False else: self.state.save() return True def on_save(self, event): return self.try_saving() def on_new(self, event): if self.state.dirty: confirmation_dialog = wx.MessageDialog( parent=self, message=f"You have not saved the current config. 
Do you want to proceed?", caption="Confirmation", style=wx.YES_NO | wx.ICON_QUESTION) result = confirmation_dialog.ShowModal() if result == wx.ID_NO: return self.state = DistillerConfigState() self.update_ui() def on_open(self, event): if self.state.dirty: confirmation_dialog = wx.MessageDialog( parent=self, message=f"You have not saved the current config. Do you want to proceed?", caption="Confirmation", style=wx.YES_NO | wx.ICON_QUESTION) result = confirmation_dialog.ShowModal() if result == wx.ID_NO: return file_dialog = wx.FileDialog(self, "Choose a YAML file", wildcard="*.yaml", style=wx.FD_OPEN) if file_dialog.ShowModal() != wx.ID_OK: return file_name = file_dialog.GetPath() try: self.state.load(file_name) self.face_image_pytorch = None self.face_mask_image_pytorch = None self.update_face_image_bitmap(self.state.config.character_image_file_name) self.update_face_mask_image_bitmap(self.state.config.face_mask_image_file_name) self.update_ui() except Exception as e: message_dialog = wx.MessageDialog(self, str(e), "Error", wx.OK | wx.ICON_ERROR) message_dialog.ShowModal() def on_run(self, event): try: self.state.config.check() except Exception as e: message_dialog = wx.MessageDialog(self, str(e), "Error", wx.OK | wx.ICON_ERROR) message_dialog.ShowModal() return if self.state.dirty: message_dialog = wx.MessageDialog( self, "Please save the configuration first.", "Error", wx.OK | wx.ICON_ERROR) message_dialog.ShowModal() return self.config_file_to_run = self.state.config.config_yaml_file_name() self.Destroy() ================================================ FILE: src/tha4/image_util.py ================================================ import math import PIL.Image import numpy import torch from matplotlib import cm from tha4.shion.base.image_util import numpy_linear_to_srgb, pytorch_rgba_to_numpy_image, pytorch_rgb_to_numpy_image, \ torch_linear_to_srgb def grid_change_to_numpy_image(torch_image, num_channels=3): height = torch_image.shape[1] width = torch_image.shape[2] size_image = (torch_image[0, :, :] ** 2 + torch_image[1, :, :] ** 2).sqrt().view(height, width, 1).numpy() hsv = cm.get_cmap('hsv') angle_image = hsv(((torch.atan2( torch_image[0, :, :].view(height * width), torch_image[1, :, :].view(height * width)).view(height, width) + math.pi) / (2 * math.pi)).numpy()) * 3 numpy_image = size_image * angle_image[:, :, 0:3] rgb_image = numpy_linear_to_srgb(numpy_image) if num_channels == 3: return rgb_image elif num_channels == 4: return numpy.concatenate([rgb_image, numpy.ones_like(size_image)], axis=2) else: raise RuntimeError("Unsupported num_channels: " + str(num_channels)) def resize_PIL_image(pil_image, size=(256, 256)): w, h = pil_image.size d = min(w, h) r = ((w - d) // 2, (h - d) // 2, (w + d) // 2, (h + d) // 2) return pil_image.resize(size, resample=PIL.Image.LANCZOS, box=r) def convert_output_image_from_torch_to_numpy(output_image): if output_image.shape[2] == 2: h, w, c = output_image.shape numpy_image = torch.transpose(output_image.reshape(h * w, c), 0, 1).reshape(c, h, w) elif output_image.shape[0] == 4: numpy_image = pytorch_rgba_to_numpy_image(output_image) elif output_image.shape[0] == 3: numpy_image = pytorch_rgb_to_numpy_image(output_image) elif output_image.shape[0] == 1: c, h, w = output_image.shape alpha_image = torch.cat([output_image.repeat(3, 1, 1) * 2.0 - 1.0, torch.ones(1, h, w)], dim=0) numpy_image = pytorch_rgba_to_numpy_image(alpha_image) elif output_image.shape[0] == 2: numpy_image = grid_change_to_numpy_image(output_image, num_channels=4) else: raise 
RuntimeError("Unsupported # image channels: %d" % output_image.shape[0]) numpy_image = numpy.uint8(numpy.rint(numpy_image * 255.0)) return numpy_image def convert_linear_to_srgb(image: torch.Tensor) -> torch.Tensor: rgb_image = torch_linear_to_srgb(image[0:3, :, :]) return torch.cat([rgb_image, image[3:4, :, :]], dim=0) ================================================ FILE: src/tha4/mocap/__init__.py ================================================ ================================================ FILE: src/tha4/mocap/ifacialmocap_constants.py ================================================ EYE_LOOK_IN_LEFT = "eyeLookInLeft" EYE_LOOK_OUT_LEFT = "eyeLookOutLeft" EYE_LOOK_DOWN_LEFT = "eyeLookDownLeft" EYE_LOOK_UP_LEFT = "eyeLookUpLeft" EYE_BLINK_LEFT = "eyeBlinkLeft" EYE_SQUINT_LEFT = "eyeSquintLeft" EYE_WIDE_LEFT = "eyeWideLeft" EYE_LOOK_IN_RIGHT = "eyeLookInRight" EYE_LOOK_OUT_RIGHT = "eyeLookOutRight" EYE_LOOK_DOWN_RIGHT = "eyeLookDownRight" EYE_LOOK_UP_RIGHT = "eyeLookUpRight" EYE_BLINK_RIGHT = "eyeBlinkRight" EYE_SQUINT_RIGHT = "eyeSquintRight" EYE_WIDE_RIGHT = "eyeWideRight" BROW_DOWN_LEFT = "browDownLeft" BROW_OUTER_UP_LEFT = "browOuterUpLeft" BROW_DOWN_RIGHT = "browDownRight" BROW_OUTER_UP_RIGHT = "browOuterUpRight" BROW_INNER_UP = "browInnerUp" NOSE_SNEER_LEFT = "noseSneerLeft" NOSE_SNEER_RIGHT = "noseSneerRight" CHEEK_SQUINT_LEFT = "cheekSquintLeft" CHEEK_SQUINT_RIGHT = "cheekSquintRight" CHEEK_PUFF = "cheekPuff" MOUTH_LEFT = "mouthLeft" MOUTH_DIMPLE_LEFT = "mouthDimpleLeft" MOUTH_FROWN_LEFT = "mouthFrownLeft" MOUTH_LOWER_DOWN_LEFT = "mouthLowerDownLeft" MOUTH_PRESS_LEFT = "mouthPressLeft" MOUTH_SMILE_LEFT = "mouthSmileLeft" MOUTH_STRETCH_LEFT = "mouthStretchLeft" MOUTH_UPPER_UP_LEFT = "mouthUpperUpLeft" MOUTH_RIGHT = "mouthRight" MOUTH_DIMPLE_RIGHT = "mouthDimpleRight" MOUTH_FROWN_RIGHT = "mouthFrownRight" MOUTH_LOWER_DOWN_RIGHT = "mouthLowerDownRight" MOUTH_PRESS_RIGHT = "mouthPressRight" MOUTH_SMILE_RIGHT = "mouthSmileRight" MOUTH_STRETCH_RIGHT = "mouthStretchRight" MOUTH_UPPER_UP_RIGHT = "mouthUpperUpRight" MOUTH_CLOSE = "mouthClose" MOUTH_FUNNEL = "mouthFunnel" MOUTH_PUCKER = "mouthPucker" MOUTH_ROLL_LOWER = "mouthRollLower" MOUTH_ROLL_UPPER = "mouthRollUpper" MOUTH_SHRUG_LOWER = "mouthShrugLower" MOUTH_SHRUG_UPPER = "mouthShrugUpper" JAW_LEFT = "jawLeft" JAW_RIGHT = "jawRight" JAW_FORWARD = "jawForward" JAW_OPEN = "jawOpen" TONGUE_OUT = "tongueOut" BLENDSHAPE_NAMES = [ EYE_LOOK_IN_LEFT, # 0 EYE_LOOK_OUT_LEFT, # 1 EYE_LOOK_DOWN_LEFT, # 2 EYE_LOOK_UP_LEFT, # 3 EYE_BLINK_LEFT, # 4 EYE_SQUINT_LEFT, # 5 EYE_WIDE_LEFT, # 6 EYE_LOOK_IN_RIGHT, # 7 EYE_LOOK_OUT_RIGHT, # 8 EYE_LOOK_DOWN_RIGHT, # 9 EYE_LOOK_UP_RIGHT, # 10 EYE_BLINK_RIGHT, # 11 EYE_SQUINT_RIGHT, # 12 EYE_WIDE_RIGHT, # 13 BROW_DOWN_LEFT, # 14 BROW_OUTER_UP_LEFT, # 15 BROW_DOWN_RIGHT, # 16 BROW_OUTER_UP_RIGHT, # 17 BROW_INNER_UP, # 18 NOSE_SNEER_LEFT, # 19 NOSE_SNEER_RIGHT, # 20 CHEEK_SQUINT_LEFT, # 21 CHEEK_SQUINT_RIGHT, # 22 CHEEK_PUFF, # 23 MOUTH_LEFT, # 24 MOUTH_DIMPLE_LEFT, # 25 MOUTH_FROWN_LEFT, # 26 MOUTH_LOWER_DOWN_LEFT, # 27 MOUTH_PRESS_LEFT, # 28 MOUTH_SMILE_LEFT, # 29 MOUTH_STRETCH_LEFT, # 30 MOUTH_UPPER_UP_LEFT, # 31 MOUTH_RIGHT, # 32 MOUTH_DIMPLE_RIGHT, # 33 MOUTH_FROWN_RIGHT, # 34 MOUTH_LOWER_DOWN_RIGHT, # 35 MOUTH_PRESS_RIGHT, # 36 MOUTH_SMILE_RIGHT, # 37 MOUTH_STRETCH_RIGHT, # 38 MOUTH_UPPER_UP_RIGHT, # 39 MOUTH_CLOSE, # 40 MOUTH_FUNNEL, # 41 MOUTH_PUCKER, # 42 MOUTH_ROLL_LOWER, # 43 MOUTH_ROLL_UPPER, # 44 MOUTH_SHRUG_LOWER, # 45 MOUTH_SHRUG_UPPER, # 46 JAW_LEFT, # 47 JAW_RIGHT, # 48 JAW_FORWARD, # 49 
JAW_OPEN, # 50 TONGUE_OUT, # 51 ] EYE_LEFT_BLENDSHAPES = [ EYE_LOOK_IN_LEFT, # 0 EYE_LOOK_OUT_LEFT, # 1 EYE_LOOK_DOWN_LEFT, # 2 EYE_LOOK_UP_LEFT, # 3 EYE_BLINK_LEFT, # 4 EYE_SQUINT_LEFT, # 5 EYE_WIDE_LEFT, # 6 ] EYE_RIGHT_BLENDSHAPES = [ EYE_LOOK_IN_RIGHT, # 7 EYE_LOOK_OUT_RIGHT, # 8 EYE_LOOK_DOWN_RIGHT, # 9 EYE_LOOK_UP_RIGHT, # 10 EYE_BLINK_RIGHT, # 11 EYE_SQUINT_RIGHT, # 12 EYE_WIDE_RIGHT, # 13 ] BROW_LEFT_BLENDSHAPES = [ BROW_DOWN_LEFT, # 14 BROW_OUTER_UP_LEFT, # 15 ] BROW_RIGHT_BLENDSHAPES = [ BROW_DOWN_RIGHT, # 16 BROW_OUTER_UP_RIGHT, # 17 ] BROW_BOTH_BLENDSHAPES = [ BROW_INNER_UP, # 18 ] NOSE_BLENDSHAPES = [ NOSE_SNEER_LEFT, # 19 NOSE_SNEER_RIGHT, # 20 ] CHECK_BLENDSHAPES = [ CHEEK_SQUINT_LEFT, # 21 CHEEK_SQUINT_RIGHT, # 22 CHEEK_PUFF, # 23 ] MOUTH_LEFT_BLENDSHAPES = [ MOUTH_LEFT, # 24 MOUTH_DIMPLE_LEFT, # 25 MOUTH_FROWN_LEFT, # 26 MOUTH_LOWER_DOWN_LEFT, # 27 MOUTH_PRESS_LEFT, # 28 MOUTH_SMILE_LEFT, # 29 MOUTH_STRETCH_LEFT, # 30 MOUTH_UPPER_UP_LEFT, # 31 ] MOUTH_RIGHT_BLENDSHAPES = [ MOUTH_RIGHT, # 32 MOUTH_DIMPLE_RIGHT, # 33 MOUTH_FROWN_RIGHT, # 34 MOUTH_LOWER_DOWN_RIGHT, # 35 MOUTH_PRESS_RIGHT, # 36 MOUTH_SMILE_RIGHT, # 37 MOUTH_STRETCH_RIGHT, # 38 MOUTH_UPPER_UP_RIGHT, # 39 ] MOUTH_BOTH_BLENDSHAPES = [ MOUTH_CLOSE, # 40 MOUTH_FUNNEL, # 41 MOUTH_PUCKER, # 42 MOUTH_ROLL_LOWER, # 43 MOUTH_ROLL_UPPER, # 44 MOUTH_SHRUG_LOWER, # 45 MOUTH_SHRUG_UPPER, # 46 ] JAW_BLENDSHAPES = [ JAW_LEFT, # 47 JAW_RIGHT, # 48 JAW_FORWARD, # 49 JAW_OPEN, # 50 ] TONGUE_BLENDSHAPES = [ TONGUE_OUT, # 51 ] COLUMN_0_BLENDSHAPES = EYE_RIGHT_BLENDSHAPES + BROW_RIGHT_BLENDSHAPES + [NOSE_SNEER_RIGHT, CHEEK_SQUINT_RIGHT] COLUMN_1_BLENDSHAPES = EYE_LEFT_BLENDSHAPES + BROW_LEFT_BLENDSHAPES + [NOSE_SNEER_LEFT, CHEEK_SQUINT_LEFT] COLUMN_2_BLENDSHAPES = MOUTH_RIGHT_BLENDSHAPES + [JAW_RIGHT] COLUMN_3_BLENDSHAPES = MOUTH_LEFT_BLENDSHAPES + [JAW_LEFT] COLUMN_4_BLENDSHAPES = [BROW_INNER_UP, CHEEK_PUFF] + MOUTH_BOTH_BLENDSHAPES + [JAW_FORWARD, JAW_OPEN, TONGUE_OUT] BLENDSHAPE_COLUMNS = [ COLUMN_0_BLENDSHAPES, COLUMN_1_BLENDSHAPES, COLUMN_2_BLENDSHAPES, COLUMN_3_BLENDSHAPES, COLUMN_4_BLENDSHAPES, ] RIGHT_EYE_BONE_X = "rightEyeBoneX" RIGHT_EYE_BONE_Y = "rightEyeBoneY" RIGHT_EYE_BONE_Z = "rightEyeBoneZ" RIGHT_EYE_BONE_ROTATIONS = [RIGHT_EYE_BONE_X, RIGHT_EYE_BONE_Y, RIGHT_EYE_BONE_Z] LEFT_EYE_BONE_X = "leftEyeBoneX" LEFT_EYE_BONE_Y = "leftEyeBoneY" LEFT_EYE_BONE_Z = "leftEyeBoneZ" LEFT_EYE_BONE_ROTATIONS = [LEFT_EYE_BONE_X, LEFT_EYE_BONE_Y, LEFT_EYE_BONE_Z] HEAD_BONE_X = "headBoneX" HEAD_BONE_Y = "headBoneY" HEAD_BONE_Z = "headBoneZ" HEAD_BONE_ROTATIONS = [HEAD_BONE_X, HEAD_BONE_Y, HEAD_BONE_Z] ROTATION_NAMES = RIGHT_EYE_BONE_ROTATIONS + LEFT_EYE_BONE_ROTATIONS + HEAD_BONE_ROTATIONS RIGHT_EYE_BONE_QUAT = "rightEyeBoneQuat" LEFT_EYE_BONE_QUAT = "leftEyeBoneQuat" HEAD_BONE_QUAT = "headBoneQuat" QUATERNION_NAMES = [ RIGHT_EYE_BONE_QUAT, LEFT_EYE_BONE_QUAT, HEAD_BONE_QUAT ] IFACIALMOCAP_DATETIME_FORMAT = "%Y/%m/%d-%H:%M:%S.%f" ================================================ FILE: src/tha4/mocap/ifacialmocap_pose.py ================================================ from tha4.mocap.ifacialmocap_constants import BLENDSHAPE_NAMES, HEAD_BONE_X, HEAD_BONE_Y, HEAD_BONE_Z, \ HEAD_BONE_QUAT, LEFT_EYE_BONE_X, LEFT_EYE_BONE_Y, LEFT_EYE_BONE_Z, LEFT_EYE_BONE_QUAT, RIGHT_EYE_BONE_X, \ RIGHT_EYE_BONE_Y, RIGHT_EYE_BONE_Z, RIGHT_EYE_BONE_QUAT def create_default_ifacialmocap_pose(): data = {} for blendshape_name in BLENDSHAPE_NAMES: data[blendshape_name] = 0.0 data[HEAD_BONE_X] = 0.0 data[HEAD_BONE_Y] = 0.0 data[HEAD_BONE_Z] = 0.0 
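    # Quaternions appear to be stored as [x, y, z, w], so [0.0, 0.0, 0.0, 1.0]
    # is the identity rotation (inferred from the defaults below).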
data[HEAD_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0] data[LEFT_EYE_BONE_X] = 0.0 data[LEFT_EYE_BONE_Y] = 0.0 data[LEFT_EYE_BONE_Z] = 0.0 data[LEFT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0] data[RIGHT_EYE_BONE_X] = 0.0 data[RIGHT_EYE_BONE_Y] = 0.0 data[RIGHT_EYE_BONE_Z] = 0.0 data[RIGHT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0] return data ================================================ FILE: src/tha4/mocap/ifacialmocap_pose_converter.py ================================================ from abc import ABC, abstractmethod from typing import Dict, List class IFacialMocapPoseConverter(ABC): @abstractmethod def convert(self, ifacialmocap_pose: Dict[str, float]) -> List[float]: pass @abstractmethod def init_pose_converter_panel(self, parent): pass ================================================ FILE: src/tha4/mocap/ifacialmocap_pose_converter_25.py ================================================ import math import time from enum import Enum from typing import Optional, Dict, List, Callable import numpy import scipy.optimize import wx from tha4.mocap.ifacialmocap_constants import MOUTH_SMILE_LEFT, MOUTH_SHRUG_UPPER, MOUTH_SMILE_RIGHT, \ BROW_INNER_UP, BROW_OUTER_UP_RIGHT, BROW_OUTER_UP_LEFT, BROW_DOWN_LEFT, BROW_DOWN_RIGHT, EYE_WIDE_LEFT, \ EYE_WIDE_RIGHT, EYE_BLINK_LEFT, EYE_BLINK_RIGHT, CHEEK_SQUINT_LEFT, CHEEK_SQUINT_RIGHT, EYE_LOOK_IN_LEFT, \ EYE_LOOK_OUT_LEFT, EYE_LOOK_IN_RIGHT, EYE_LOOK_OUT_RIGHT, EYE_LOOK_UP_LEFT, EYE_LOOK_UP_RIGHT, EYE_LOOK_DOWN_RIGHT, \ EYE_LOOK_DOWN_LEFT, HEAD_BONE_X, HEAD_BONE_Y, HEAD_BONE_Z, JAW_OPEN, MOUTH_FROWN_LEFT, MOUTH_FROWN_RIGHT, \ MOUTH_LOWER_DOWN_LEFT, MOUTH_LOWER_DOWN_RIGHT, MOUTH_FUNNEL, MOUTH_PUCKER from tha4.mocap.ifacialmocap_pose_converter import IFacialMocapPoseConverter from tha4.poser.modes.pose_parameters import get_pose_parameters class EyebrowDownMode(Enum): TROUBLED = 1 ANGRY = 2 LOWERED = 3 SERIOUS = 4 class WinkMode(Enum): NORMAL = 1 RELAXED = 2 def rad_to_deg(rad): return rad * 180.0 / math.pi def deg_to_rad(deg): return deg * math.pi / 180.0 def clamp(x, min_value, max_value): return max(min_value, min(max_value, x)) class IFacialMocapPoseConverter25Args: def __init__(self, smile_threshold_min: float = 0.4, smile_threshold_max: float = 0.6, eyebrow_down_mode: EyebrowDownMode = EyebrowDownMode.ANGRY, wink_mode: WinkMode = WinkMode.NORMAL, eye_surprised_max: float = 0.5, eye_blink_max: float = 0.8, eyebrow_down_max: float = 0.4, cheek_squint_min: float = 0.1, cheek_squint_max: float = 0.7, eye_rotation_factor: float = 1.0 / 0.75, jaw_open_min: float = 0.1, jaw_open_max: float = 0.4, mouth_frown_max: float = 0.6, mouth_funnel_min: float = 0.25, mouth_funnel_max: float = 0.5, iris_small_left=0.0, iris_small_right=0.0): self.iris_small_right = iris_small_left self.iris_small_left = iris_small_right self.wink_mode = wink_mode self.mouth_funnel_max = mouth_funnel_max self.mouth_funnel_min = mouth_funnel_min self.mouth_frown_max = mouth_frown_max self.jaw_open_max = jaw_open_max self.jaw_open_min = jaw_open_min self.eye_rotation_factor = eye_rotation_factor self.cheek_squint_max = cheek_squint_max self.cheek_squint_min = cheek_squint_min self.eyebrow_down_max = eyebrow_down_max self.eye_blink_max = eye_blink_max self.eye_surprised_max = eye_surprised_max self.eyebrow_down_mode = eyebrow_down_mode self.smile_threshold_min = smile_threshold_min self.smile_threshold_max = smile_threshold_max def set_smile_threshold_min(self, new_value: float): self.smile_threshold_min = new_value def set_smile_threshold_max(self, new_value: float): self.smile_threshold_max = 
    def set_eye_surprised_max(self, new_value: float):
        self.eye_surprised_max = new_value

    def set_eye_blink_max(self, new_value: float):
        self.eye_blink_max = new_value

    def set_eyebrow_down_max(self, new_value: float):
        self.eyebrow_down_max = new_value

    def set_cheek_squint_min(self, new_value: float):
        self.cheek_squint_min = new_value

    def set_cheek_squint_max(self, new_value: float):
        self.cheek_squint_max = new_value

    def set_jaw_open_min(self, new_value: float):
        self.jaw_open_min = new_value

    def set_jaw_open_max(self, new_value: float):
        self.jaw_open_max = new_value

    def set_mouth_frown_max(self, new_value: float):
        self.mouth_frown_max = new_value

    def set_mouth_funnel_min(self, new_value: float):
        self.mouth_funnel_min = new_value

    def set_mouth_funnel_max(self, new_value: float):
        self.mouth_funnel_max = new_value


class IFacialMocapPoseConverter25(IFacialMocapPoseConverter):
    def __init__(self, args: Optional[IFacialMocapPoseConverter25Args] = None):
        super().__init__()
        if args is None:
            args = IFacialMocapPoseConverter25Args()
        self.args = args
        pose_parameters = get_pose_parameters()
        self.pose_size = 45
        self.eyebrow_troubled_left_index = pose_parameters.get_parameter_index("eyebrow_troubled_left")
        self.eyebrow_troubled_right_index = pose_parameters.get_parameter_index("eyebrow_troubled_right")
        self.eyebrow_angry_left_index = pose_parameters.get_parameter_index("eyebrow_angry_left")
        self.eyebrow_angry_right_index = pose_parameters.get_parameter_index("eyebrow_angry_right")
        self.eyebrow_happy_left_index = pose_parameters.get_parameter_index("eyebrow_happy_left")
        self.eyebrow_happy_right_index = pose_parameters.get_parameter_index("eyebrow_happy_right")
        self.eyebrow_raised_left_index = pose_parameters.get_parameter_index("eyebrow_raised_left")
        self.eyebrow_raised_right_index = pose_parameters.get_parameter_index("eyebrow_raised_right")
        self.eyebrow_lowered_left_index = pose_parameters.get_parameter_index("eyebrow_lowered_left")
        self.eyebrow_lowered_right_index = pose_parameters.get_parameter_index("eyebrow_lowered_right")
        self.eyebrow_serious_left_index = pose_parameters.get_parameter_index("eyebrow_serious_left")
        self.eyebrow_serious_right_index = pose_parameters.get_parameter_index("eyebrow_serious_right")
        self.eye_surprised_left_index = pose_parameters.get_parameter_index("eye_surprised_left")
        self.eye_surprised_right_index = pose_parameters.get_parameter_index("eye_surprised_right")
        self.eye_wink_left_index = pose_parameters.get_parameter_index("eye_wink_left")
        self.eye_wink_right_index = pose_parameters.get_parameter_index("eye_wink_right")
        self.eye_happy_wink_left_index = pose_parameters.get_parameter_index("eye_happy_wink_left")
        self.eye_happy_wink_right_index = pose_parameters.get_parameter_index("eye_happy_wink_right")
        self.eye_relaxed_left_index = pose_parameters.get_parameter_index("eye_relaxed_left")
        self.eye_relaxed_right_index = pose_parameters.get_parameter_index("eye_relaxed_right")
        self.eye_raised_lower_eyelid_left_index = pose_parameters.get_parameter_index("eye_raised_lower_eyelid_left")
        self.eye_raised_lower_eyelid_right_index = pose_parameters.get_parameter_index("eye_raised_lower_eyelid_right")
        self.iris_small_left_index = pose_parameters.get_parameter_index("iris_small_left")
        self.iris_small_right_index = pose_parameters.get_parameter_index("iris_small_right")
        self.iris_rotation_x_index = pose_parameters.get_parameter_index("iris_rotation_x")
        self.iris_rotation_y_index = pose_parameters.get_parameter_index("iris_rotation_y")
        self.head_x_index =
pose_parameters.get_parameter_index("head_x") self.head_y_index = pose_parameters.get_parameter_index("head_y") self.neck_z_index = pose_parameters.get_parameter_index("neck_z") self.mouth_aaa_index = pose_parameters.get_parameter_index("mouth_aaa") self.mouth_iii_index = pose_parameters.get_parameter_index("mouth_iii") self.mouth_uuu_index = pose_parameters.get_parameter_index("mouth_uuu") self.mouth_eee_index = pose_parameters.get_parameter_index("mouth_eee") self.mouth_ooo_index = pose_parameters.get_parameter_index("mouth_ooo") self.mouth_lowered_corner_left_index = pose_parameters.get_parameter_index("mouth_lowered_corner_left") self.mouth_lowered_corner_right_index = pose_parameters.get_parameter_index("mouth_lowered_corner_right") self.mouth_raised_corner_left_index = pose_parameters.get_parameter_index("mouth_raised_corner_left") self.mouth_raised_corner_right_index = pose_parameters.get_parameter_index("mouth_raised_corner_right") self.body_y_index = pose_parameters.get_parameter_index("body_y") self.body_z_index = pose_parameters.get_parameter_index("body_z") self.breathing_index = pose_parameters.get_parameter_index("breathing") self.breathing_start_time = time.time() self.panel = None def init_pose_converter_panel(self, parent): self.panel = wx.Panel(parent, style=wx.SIMPLE_BORDER) self.panel_sizer = wx.BoxSizer(wx.VERTICAL) self.panel.SetSizer(self.panel_sizer) self.panel.SetAutoLayout(1) parent.GetSizer().Add(self.panel, 0, wx.EXPAND) if True: eyebrow_down_mode_text = wx.StaticText(self.panel, label=" --- Eyebrow Down Mode --- ", style=wx.ALIGN_CENTER) self.panel_sizer.Add(eyebrow_down_mode_text, 0, wx.EXPAND) self.eyebrow_down_mode_choice = wx.Choice( self.panel, choices=[ "ANGRY", "TROUBLED", "SERIOUS", "LOWERED", ]) self.eyebrow_down_mode_choice.SetSelection(0) self.panel_sizer.Add(self.eyebrow_down_mode_choice, 0, wx.EXPAND) self.eyebrow_down_mode_choice.Bind(wx.EVT_CHOICE, self.change_eyebrow_down_mode) separator = wx.StaticLine(self.panel, -1, size=(256, 5)) self.panel_sizer.Add(separator, 0, wx.EXPAND) if True: wink_mode_text = wx.StaticText(self.panel, label=" --- Wink Mode --- ", style=wx.ALIGN_CENTER) self.panel_sizer.Add(wink_mode_text, 0, wx.EXPAND) self.wink_mode_choice = wx.Choice( self.panel, choices=[ "NORMAL", "RELAXED", ]) self.wink_mode_choice.SetSelection(0) self.panel_sizer.Add(self.wink_mode_choice, 0, wx.EXPAND) self.wink_mode_choice.Bind(wx.EVT_CHOICE, self.change_wink_mode) separator = wx.StaticLine(self.panel, -1, size=(256, 5)) self.panel_sizer.Add(separator, 0, wx.EXPAND) if True: iris_size_text = wx.StaticText(self.panel, label=" --- Iris Size --- ", style=wx.ALIGN_CENTER) self.panel_sizer.Add(iris_size_text, 0, wx.EXPAND) self.iris_left_slider = wx.Slider(self.panel, minValue=0, maxValue=1000, value=0, style=wx.HORIZONTAL) self.panel_sizer.Add(self.iris_left_slider, 0, wx.EXPAND) self.iris_left_slider.Bind(wx.EVT_SLIDER, self.change_iris_size) self.iris_right_slider = wx.Slider(self.panel, minValue=0, maxValue=1000, value=0, style=wx.HORIZONTAL) self.panel_sizer.Add(self.iris_right_slider, 0, wx.EXPAND) self.iris_right_slider.Bind(wx.EVT_SLIDER, self.change_iris_size) self.iris_right_slider.Enable(False) self.link_left_right_irises = wx.CheckBox( self.panel, label="Use same value for both sides") self.link_left_right_irises.SetValue(True) self.panel_sizer.Add(self.link_left_right_irises, wx.SizerFlags().CenterHorizontal().Border()) self.link_left_right_irises.Bind(wx.EVT_CHECKBOX, self.link_left_right_irises_clicked) separator = 
wx.StaticLine(self.panel, -1, size=(256, 5)) self.panel_sizer.Add(separator, 0, wx.EXPAND) if True: breathing_frequency_text = wx.StaticText( self.panel, label=" --- Breathing --- ", style=wx.ALIGN_CENTER) self.panel_sizer.Add(breathing_frequency_text, 0, wx.EXPAND) self.restart_breathing_cycle_button = wx.Button(self.panel, label="Restart Breathing Cycle") self.restart_breathing_cycle_button.Bind(wx.EVT_BUTTON, self.restart_breathing_cycle_clicked) self.panel_sizer.Add(self.restart_breathing_cycle_button, 0, wx.EXPAND) self.breathing_frequency_slider = wx.Slider( self.panel, minValue=0, maxValue=60, value=20, style=wx.HORIZONTAL) self.panel_sizer.Add(self.breathing_frequency_slider, 0, wx.EXPAND) self.breathing_gauge = wx.Gauge(self.panel, style=wx.GA_HORIZONTAL, range=1000) self.panel_sizer.Add(self.breathing_gauge, 0, wx.EXPAND) if True: separator = wx.StaticLine(self.panel, -1, size=(256, 5)) self.panel_sizer.Add(separator, 0, wx.EXPAND) convertion_parameters_text = wx.StaticText( self.panel, label="--- Conversion Parameters ---", style=wx.ALIGN_CENTER) self.panel_sizer.Add(convertion_parameters_text, 0, wx.EXPAND) conversion_param_panel = wx.Panel(self.panel) self.panel_sizer.Add(conversion_param_panel, 0, wx.EXPAND) conversion_panel_sizer = wx.FlexGridSizer(cols=2) conversion_panel_sizer.AddGrowableCol(1) conversion_param_panel.SetSizer(conversion_panel_sizer) conversion_param_panel.SetAutoLayout(1) self.smile_thresold_min_spin = self.create_spin_control( conversion_param_panel, "Smile Threshold Min:", self.args.smile_threshold_min, self.args.set_smile_threshold_min) self.smile_thresold_max_spin = self.create_spin_control( conversion_param_panel, "Smile Threshold Max:", self.args.smile_threshold_max, self.args.set_smile_threshold_max) self.eye_surprised_max_spin = self.create_spin_control( conversion_param_panel, "Eye Surprised Max:", self.args.eye_surprised_max, self.args.set_eye_surprised_max) self.eye_blink_max_spin = self.create_spin_control( conversion_param_panel, "Eye Blink Max:", self.args.eye_blink_max, self.args.set_eye_blink_max) self.eyebrow_down_max_spin = self.create_spin_control( conversion_param_panel, "Eyebrow Down Max:", self.args.eyebrow_down_max, self.args.set_eyebrow_down_max) self.cheek_squint_min_spin = self.create_spin_control( conversion_param_panel, "Cheek Squint Min:", self.args.cheek_squint_min, self.args.set_cheek_squint_min) self.cheek_squint_max_spin = self.create_spin_control( conversion_param_panel, "Cheek Squint Max:", self.args.cheek_squint_max, self.args.set_cheek_squint_max) self.jaw_open_min_spin = self.create_spin_control( conversion_param_panel, "Jaw Open Min:", self.args.jaw_open_min, self.args.set_jaw_open_min) self.jaw_open_max_spin = self.create_spin_control( conversion_param_panel, "Jaw Open Max:", self.args.jaw_open_max, self.args.set_jaw_open_max) self.mouth_frown_max_spin = self.create_spin_control( conversion_param_panel, "Mouth Frown Max:", self.args.mouth_frown_max, self.args.set_mouth_frown_max) self.mouth_funnel_min_spin = self.create_spin_control( conversion_param_panel, "Mouth Funnel Min:", self.args.mouth_funnel_min, self.args.set_mouth_funnel_min) self.mouth_funnel_max_spin = self.create_spin_control( conversion_param_panel, "Mouth Funnel Max:", self.args.mouth_funnel_max, self.args.set_mouth_funnel_max) self.panel_sizer.Fit(self.panel) def create_spin_control(self, parent, label: str, initial_value: float, set_func: Callable[[float], None]): sizer = parent.GetSizer() text = wx.StaticText(parent, label=label) sizer.Add(text, 
wx.SizerFlags().Right().Border(wx.ALL, 2)) spin_ctrl = wx.SpinCtrlDouble( parent, wx.ID_ANY, min=0.0, max=1.0, initial=initial_value, inc=0.01) sizer.Add(spin_ctrl, wx.SizerFlags().Border(wx.ALL, 2).Expand()) def handler(event: wx.Event): new_value = spin_ctrl.GetValue() set_func(new_value) spin_ctrl.Bind(wx.EVT_SPINCTRLDOUBLE, handler) return spin_ctrl def restart_breathing_cycle_clicked(self, event: wx.Event): self.breathing_start_time = time.time() def change_eyebrow_down_mode(self, event: wx.Event): selected_index = self.eyebrow_down_mode_choice.GetSelection() if selected_index == 0: self.args.eyebrow_down_mode = EyebrowDownMode.ANGRY elif selected_index == 1: self.args.eyebrow_down_mode = EyebrowDownMode.TROUBLED elif selected_index == 2: self.args.eyebrow_down_mode = EyebrowDownMode.SERIOUS else: self.args.eyebrow_down_mode = EyebrowDownMode.LOWERED def change_wink_mode(self, event: wx.Event): selected_index = self.wink_mode_choice.GetSelection() if selected_index == 0: self.args.wink_mode = WinkMode.NORMAL else: self.args.wink_mode = WinkMode.RELAXED def change_iris_size(self, event: wx.Event): if self.link_left_right_irises.GetValue(): left_value = self.iris_left_slider.GetValue() right_value = self.iris_right_slider.GetValue() if left_value != right_value: self.iris_right_slider.SetValue(left_value) self.args.iris_small_left = left_value / 1000.0 self.args.iris_small_right = left_value / 1000.0 else: self.args.iris_small_left = self.iris_left_slider.GetValue() / 1000.0 self.args.iris_small_right = self.iris_right_slider.GetValue() / 1000.0 def link_left_right_irises_clicked(self, event: wx.Event): if self.link_left_right_irises.GetValue(): self.iris_right_slider.Enable(False) else: self.iris_right_slider.Enable(True) self.change_iris_size(event) def decompose_head_body_param(self, param, threshold=2.0 / 3): if abs(param) < threshold: return (param, 0.0) else: if param < 0: sign = -1.0 else: sign = 1.0 return (threshold * sign, (abs(param) - threshold) * sign) def convert(self, ifacialmocap_pose: Dict[str, float]) -> List[float]: pose = [0.0 for i in range(self.pose_size)] smile_value = \ (ifacialmocap_pose[MOUTH_SMILE_LEFT] + ifacialmocap_pose[MOUTH_SMILE_RIGHT]) / 2.0 \ + ifacialmocap_pose[MOUTH_SHRUG_UPPER] if self.args.smile_threshold_min >= self.args.smile_threshold_max: smile_degree = 0.0 else: if smile_value < self.args.smile_threshold_min: smile_degree = 0.0 elif smile_value > self.args.smile_threshold_max: smile_degree = 1.0 else: smile_degree = (smile_value - self.args.smile_threshold_min) / ( self.args.smile_threshold_max - self.args.smile_threshold_min) # Eyebrow if True: brow_inner_up = ifacialmocap_pose[BROW_INNER_UP] brow_outer_up_right = ifacialmocap_pose[BROW_OUTER_UP_RIGHT] brow_outer_up_left = ifacialmocap_pose[BROW_OUTER_UP_LEFT] brow_up_left = clamp(brow_inner_up + brow_outer_up_left, 0.0, 1.0) brow_up_right = clamp(brow_inner_up + brow_outer_up_right, 0.0, 1.0) pose[self.eyebrow_raised_left_index] = brow_up_left pose[self.eyebrow_raised_right_index] = brow_up_right if self.args.eyebrow_down_max <= 0.0: brow_down_left = 0.0 brow_down_right = 0.0 else: brow_down_left = (1.0 - smile_degree) \ * clamp(ifacialmocap_pose[BROW_DOWN_LEFT] / self.args.eyebrow_down_max, 0.0, 1.0) brow_down_right = (1.0 - smile_degree) \ * clamp(ifacialmocap_pose[BROW_DOWN_RIGHT] / self.args.eyebrow_down_max, 0.0, 1.0) if self.args.eyebrow_down_mode == EyebrowDownMode.TROUBLED: pose[self.eyebrow_troubled_left_index] = brow_down_left pose[self.eyebrow_troubled_right_index] = 
brow_down_right elif self.args.eyebrow_down_mode == EyebrowDownMode.ANGRY: pose[self.eyebrow_angry_left_index] = brow_down_left pose[self.eyebrow_angry_right_index] = brow_down_right elif self.args.eyebrow_down_mode == EyebrowDownMode.LOWERED: pose[self.eyebrow_lowered_left_index] = brow_down_left pose[self.eyebrow_lowered_right_index] = brow_down_right elif self.args.eyebrow_down_mode == EyebrowDownMode.SERIOUS: pose[self.eyebrow_serious_left_index] = brow_down_left pose[self.eyebrow_serious_right_index] = brow_down_right brow_happy_value = clamp(smile_value, 0.0, 1.0) * smile_degree pose[self.eyebrow_happy_left_index] = brow_happy_value pose[self.eyebrow_happy_right_index] = brow_happy_value # Eye if True: # Surprised if self.args.eye_surprised_max <= 0.0: pose[self.eye_surprised_left_index] = 0.0 pose[self.eye_surprised_right_index] = 0.0 else: pose[self.eye_surprised_left_index] = clamp( ifacialmocap_pose[EYE_WIDE_LEFT] / self.args.eye_surprised_max, 0.0, 1.0) pose[self.eye_surprised_right_index] = clamp( ifacialmocap_pose[EYE_WIDE_RIGHT] / self.args.eye_surprised_max, 0.0, 1.0) # Wink if self.args.wink_mode == WinkMode.NORMAL: wink_left_index = self.eye_wink_left_index wink_right_index = self.eye_wink_right_index else: wink_left_index = self.eye_relaxed_left_index wink_right_index = self.eye_relaxed_right_index if self.args.eye_blink_max <= 0: pose[wink_left_index] = 0.0 pose[wink_right_index] = 0.0 pose[self.eye_happy_wink_left_index] = 0.0 pose[self.eye_happy_wink_right_index] = 0.0 else: pose[wink_left_index] = (1.0 - smile_degree) * clamp( ifacialmocap_pose[EYE_BLINK_LEFT] / self.args.eye_blink_max, 0.0, 1.0) pose[wink_right_index] = (1.0 - smile_degree) * clamp( ifacialmocap_pose[EYE_BLINK_RIGHT] / self.args.eye_blink_max, 0.0, 1.0) pose[self.eye_happy_wink_left_index] = smile_degree * clamp( ifacialmocap_pose[EYE_BLINK_LEFT] / self.args.eye_blink_max, 0.0, 1.0) pose[self.eye_happy_wink_right_index] = smile_degree * clamp( ifacialmocap_pose[EYE_BLINK_RIGHT] / self.args.eye_blink_max, 0.0, 1.0) # Lower eyelid cheek_squint_denom = self.args.cheek_squint_max - self.args.cheek_squint_min if cheek_squint_denom <= 0.0: pose[self.eye_raised_lower_eyelid_left_index] = 0.0 pose[self.eye_raised_lower_eyelid_right_index] = 0.0 else: pose[self.eye_raised_lower_eyelid_left_index] = \ clamp( (ifacialmocap_pose[CHEEK_SQUINT_LEFT] - self.args.cheek_squint_min) / cheek_squint_denom, 0.0, 1.0) pose[self.eye_raised_lower_eyelid_right_index] = \ clamp( (ifacialmocap_pose[CHEEK_SQUINT_RIGHT] - self.args.cheek_squint_min) / cheek_squint_denom, 0.0, 1.0) # Iris rotation if True: eye_rotation_y = (ifacialmocap_pose[EYE_LOOK_IN_LEFT] - ifacialmocap_pose[EYE_LOOK_OUT_LEFT] - ifacialmocap_pose[EYE_LOOK_IN_RIGHT] + ifacialmocap_pose[EYE_LOOK_OUT_RIGHT]) / 2.0 * self.args.eye_rotation_factor pose[self.iris_rotation_y_index] = clamp(eye_rotation_y, -1.0, 1.0) eye_rotation_x = (ifacialmocap_pose[EYE_LOOK_UP_LEFT] + ifacialmocap_pose[EYE_LOOK_UP_RIGHT] - ifacialmocap_pose[EYE_LOOK_DOWN_LEFT] - ifacialmocap_pose[EYE_LOOK_DOWN_RIGHT]) / 2.0 * self.args.eye_rotation_factor pose[self.iris_rotation_x_index] = clamp(eye_rotation_x, -1.0, 1.0) # Iris size if True: pose[self.iris_small_left_index] = self.args.iris_small_left pose[self.iris_small_right_index] = self.args.iris_small_right # Head rotation if True: x_param = clamp(-ifacialmocap_pose[HEAD_BONE_X] * 180.0 / math.pi, -15.0, 15.0) / 15.0 pose[self.head_x_index] = x_param y_param = clamp(-ifacialmocap_pose[HEAD_BONE_Y] * 180.0 / math.pi, -10.0, 10.0) / 10.0 
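            # iFacialMocap reports head rotation in radians. Each axis is converted
            # to degrees, clamped to a fixed range (+/-15 degrees for x and z,
            # +/-10 degrees for y), and normalized to [-1, 1] for the pose vector.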
            pose[self.head_y_index] = y_param
            pose[self.body_y_index] = y_param
            z_param = clamp(ifacialmocap_pose[HEAD_BONE_Z] * 180.0 / math.pi, -15.0, 15.0) / 15.0
            pose[self.neck_z_index] = z_param
            pose[self.body_z_index] = z_param

        # Mouth
        if True:
            jaw_open_denom = self.args.jaw_open_max - self.args.jaw_open_min
            if jaw_open_denom <= 0:
                mouth_open = 0.0
            else:
                mouth_open = clamp((ifacialmocap_pose[JAW_OPEN] - self.args.jaw_open_min) / jaw_open_denom, 0.0, 1.0)
            pose[self.mouth_aaa_index] = mouth_open
            pose[self.mouth_raised_corner_left_index] = clamp(smile_value, 0.0, 1.0)
            pose[self.mouth_raised_corner_right_index] = clamp(smile_value, 0.0, 1.0)
            is_mouth_open = mouth_open > 0.0
            if not is_mouth_open:
                if self.args.mouth_frown_max <= 0:
                    mouth_frown_value = 0.0
                else:
                    mouth_frown_value = clamp(
                        (ifacialmocap_pose[MOUTH_FROWN_LEFT] + ifacialmocap_pose[
                            MOUTH_FROWN_RIGHT]) / self.args.mouth_frown_max, 0.0, 1.0)
                pose[self.mouth_lowered_corner_left_index] = mouth_frown_value
                pose[self.mouth_lowered_corner_right_index] = mouth_frown_value
            else:
                mouth_lower_down = clamp(
                    ifacialmocap_pose[MOUTH_LOWER_DOWN_LEFT] + ifacialmocap_pose[MOUTH_LOWER_DOWN_RIGHT], 0.0, 1.0)
                mouth_funnel = ifacialmocap_pose[MOUTH_FUNNEL]
                mouth_pucker = ifacialmocap_pose[MOUTH_PUCKER]
                # Express the measured mouth state as a combination of the four
                # prototype mouth shapes (aaa, iii, uuu, ooo).
                mouth_point = [mouth_open, mouth_lower_down, mouth_funnel, mouth_pucker]
                aaa_point = [1.0, 1.0, 0.0, 0.0]
                iii_point = [0.0, 1.0, 0.0, 0.0]
                uuu_point = [0.5, 0.3, 0.25, 0.75]
                ooo_point = [1.0, 0.5, 0.5, 0.4]
                decomp = numpy.array([0, 0, 0, 0])
                M = numpy.array([
                    aaa_point,
                    iii_point,
                    uuu_point,
                    ooo_point
                ])

                def loss(decomp):
                    return numpy.linalg.norm(numpy.matmul(decomp, M) - mouth_point) \
                           + 0.01 * numpy.linalg.norm(decomp, ord=1)

                opt_result = scipy.optimize.minimize(
                    loss, decomp, bounds=[(0.0, 1.0), (0.0, 1.0), (0.0, 1.0), (0.0, 1.0)])
                decomp = opt_result["x"]
                restricted_decomp = [decomp.item(0), decomp.item(1), decomp.item(2), decomp.item(3)]
                pose[self.mouth_aaa_index] = restricted_decomp[0]
                pose[self.mouth_iii_index] = restricted_decomp[1]
                mouth_funnel_denom = self.args.mouth_funnel_max - self.args.mouth_funnel_min
                if mouth_funnel_denom <= 0:
                    ooo_alpha = 0.0
                    uo_value = 0.0
                else:
                    ooo_alpha = clamp((mouth_funnel - self.args.mouth_funnel_min) / mouth_funnel_denom, 0.0, 1.0)
                    uo_value = clamp(restricted_decomp[2] + restricted_decomp[3], 0.0, 1.0)
                pose[self.mouth_uuu_index] = uo_value * (1.0 - ooo_alpha)
                pose[self.mouth_ooo_index] = uo_value * ooo_alpha

        if self.panel is not None:
            frequency = self.breathing_frequency_slider.GetValue()
            if frequency == 0:
                value = 0.0
                pose[self.breathing_index] = value
                self.breathing_start_time = time.time()
            else:
                period = 60.0 / frequency
                now = time.time()
                diff = now - self.breathing_start_time
                frac = (diff % period) / period
                value = (-math.cos(2 * math.pi * frac) + 1.0) / 2.0
                pose[self.breathing_index] = value
            self.breathing_gauge.SetValue(int(1000 * value))

        return pose


def create_ifacialmocap_pose_converter(
        args: Optional[IFacialMocapPoseConverter25Args] = None) -> IFacialMocapPoseConverter:
    return IFacialMocapPoseConverter25(args)


================================================
FILE: src/tha4/mocap/ifacialmocap_v2.py
================================================
import math

from tha4.mocap.ifacialmocap_constants import BLENDSHAPE_NAMES, HEAD_BONE_X, HEAD_BONE_Y, HEAD_BONE_Z, \
    RIGHT_EYE_BONE_X, RIGHT_EYE_BONE_Y, RIGHT_EYE_BONE_Z, LEFT_EYE_BONE_X, LEFT_EYE_BONE_Y, LEFT_EYE_BONE_Z, \
    HEAD_BONE_QUAT, LEFT_EYE_BONE_QUAT, RIGHT_EYE_BONE_QUAT

IFACIALMOCAP_PORT = 49983
IFACIALMOCAP_START_STRING =
"iFacialMocap_sahuasouryya9218sauhuiayeta91555dy3719|sendDataVersion=v2".encode('utf-8') def parse_ifacialmocap_v2_pose(ifacialmocap_output): output = {} parts = ifacialmocap_output.split("|") for part in parts: part = part.strip() if len(part) == 0: continue if "&" in part: components = part.split("&") assert len(components) == 2 key = components[0] value = float(components[1]) / 100.0 if key.endswith("_L"): key = key[:-2] + "Left" elif key.endswith("_R"): key = key[:-2] + "Right" if key in BLENDSHAPE_NAMES: output[key] = value elif part.startswith("=head#"): components = part[len("=head#"):].split(",") assert len(components) == 6 output[HEAD_BONE_X] = float(components[0]) * math.pi / 180 output[HEAD_BONE_Y] = float(components[1]) * math.pi / 180 output[HEAD_BONE_Z] = float(components[2]) * math.pi / 180 elif part.startswith("rightEye#"): components = part[len("rightEye#"):].split(",") output[RIGHT_EYE_BONE_X] = float(components[0]) * math.pi / 180 output[RIGHT_EYE_BONE_Y] = float(components[1]) * math.pi / 180 output[RIGHT_EYE_BONE_Z] = float(components[2]) * math.pi / 180 elif part.startswith("leftEye#"): components = part[len("leftEye#"):].split(",") output[LEFT_EYE_BONE_X] = float(components[0]) * math.pi / 180 output[LEFT_EYE_BONE_Y] = float(components[1]) * math.pi / 180 output[LEFT_EYE_BONE_Z] = float(components[2]) * math.pi / 180 output[HEAD_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0] output[LEFT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0] output[RIGHT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0] return output def parse_ifacialmocap_v1_pose(ifacialmocap_output): output = {} parts = ifacialmocap_output.split("|") for part in parts: part = part.strip() if len(part) == 0: continue if part.startswith("=head#"): components = part[len("=head#"):].split(",") assert len(components) == 6 output[HEAD_BONE_X] = float(components[0]) * math.pi / 180 output[HEAD_BONE_Y] = float(components[1]) * math.pi / 180 output[HEAD_BONE_Z] = float(components[2]) * math.pi / 180 elif part.startswith("rightEye#"): components = part[len("rightEye#"):].split(",") output[RIGHT_EYE_BONE_X] = float(components[0]) * math.pi / 180 output[RIGHT_EYE_BONE_Y] = float(components[1]) * math.pi / 180 output[RIGHT_EYE_BONE_Z] = float(components[2]) * math.pi / 180 elif part.startswith("leftEye#"): components = part[len("leftEye#"):].split(",") output[LEFT_EYE_BONE_X] = float(components[0]) * math.pi / 180 output[LEFT_EYE_BONE_Y] = float(components[1]) * math.pi / 180 output[LEFT_EYE_BONE_Z] = float(components[2]) * math.pi / 180 else: components = part.split("-") assert len(components) == 2 key = components[0] value = float(components[1]) / 100.0 if key.endswith("_L"): key = key[:-2] + "Left" elif key.endswith("_R"): key = key[:-2] + "Right" if key in BLENDSHAPE_NAMES: output[key] = value output[HEAD_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0] output[LEFT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0] output[RIGHT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0] return output ================================================ FILE: src/tha4/mocap/mediapipe_constants.py ================================================ EYE_LOOK_IN_LEFT = "eyeLookInLeft" EYE_LOOK_OUT_LEFT = "eyeLookOutLeft" EYE_LOOK_DOWN_LEFT = "eyeLookDownLeft" EYE_LOOK_UP_LEFT = "eyeLookUpLeft" EYE_BLINK_LEFT = "eyeBlinkLeft" EYE_SQUINT_LEFT = "eyeSquintLeft" EYE_WIDE_LEFT = "eyeWideLeft" EYE_LOOK_IN_RIGHT = "eyeLookInRight" EYE_LOOK_OUT_RIGHT = "eyeLookOutRight" EYE_LOOK_DOWN_RIGHT = "eyeLookDownRight" EYE_LOOK_UP_RIGHT = "eyeLookUpRight" EYE_BLINK_RIGHT = "eyeBlinkRight" EYE_SQUINT_RIGHT = "eyeSquintRight" 
EYE_WIDE_RIGHT = "eyeWideRight" BROW_DOWN_LEFT = "browDownLeft" BROW_OUTER_UP_LEFT = "browOuterUpLeft" BROW_DOWN_RIGHT = "browDownRight" BROW_OUTER_UP_RIGHT = "browOuterUpRight" BROW_INNER_UP = "browInnerUp" NOSE_SNEER_LEFT = "noseSneerLeft" NOSE_SNEER_RIGHT = "noseSneerRight" CHEEK_SQUINT_LEFT = "cheekSquintLeft" CHEEK_SQUINT_RIGHT = "cheekSquintRight" CHEEK_PUFF = "cheekPuff" MOUTH_LEFT = "mouthLeft" MOUTH_DIMPLE_LEFT = "mouthDimpleLeft" MOUTH_FROWN_LEFT = "mouthFrownLeft" MOUTH_LOWER_DOWN_LEFT = "mouthLowerDownLeft" MOUTH_PRESS_LEFT = "mouthPressLeft" MOUTH_SMILE_LEFT = "mouthSmileLeft" MOUTH_STRETCH_LEFT = "mouthStretchLeft" MOUTH_UPPER_UP_LEFT = "mouthUpperUpLeft" MOUTH_RIGHT = "mouthRight" MOUTH_DIMPLE_RIGHT = "mouthDimpleRight" MOUTH_FROWN_RIGHT = "mouthFrownRight" MOUTH_LOWER_DOWN_RIGHT = "mouthLowerDownRight" MOUTH_PRESS_RIGHT = "mouthPressRight" MOUTH_SMILE_RIGHT = "mouthSmileRight" MOUTH_STRETCH_RIGHT = "mouthStretchRight" MOUTH_UPPER_UP_RIGHT = "mouthUpperUpRight" MOUTH_CLOSE = "mouthClose" MOUTH_FUNNEL = "mouthFunnel" MOUTH_PUCKER = "mouthPucker" MOUTH_ROLL_LOWER = "mouthRollLower" MOUTH_ROLL_UPPER = "mouthRollUpper" MOUTH_SHRUG_LOWER = "mouthShrugLower" MOUTH_SHRUG_UPPER = "mouthShrugUpper" JAW_LEFT = "jawLeft" JAW_RIGHT = "jawRight" JAW_FORWARD = "jawForward" JAW_OPEN = "jawOpen" NEUTRAL = "_neutral" BLENDSHAPE_NAMES = [ EYE_LOOK_IN_LEFT, # 0 EYE_LOOK_OUT_LEFT, # 1 EYE_LOOK_DOWN_LEFT, # 2 EYE_LOOK_UP_LEFT, # 3 EYE_BLINK_LEFT, # 4 EYE_SQUINT_LEFT, # 5 EYE_WIDE_LEFT, # 6 EYE_LOOK_IN_RIGHT, # 7 EYE_LOOK_OUT_RIGHT, # 8 EYE_LOOK_DOWN_RIGHT, # 9 EYE_LOOK_UP_RIGHT, # 10 EYE_BLINK_RIGHT, # 11 EYE_SQUINT_RIGHT, # 12 EYE_WIDE_RIGHT, # 13 BROW_DOWN_LEFT, # 14 BROW_OUTER_UP_LEFT, # 15 BROW_DOWN_RIGHT, # 16 BROW_OUTER_UP_RIGHT, # 17 BROW_INNER_UP, # 18 NOSE_SNEER_LEFT, # 19 NOSE_SNEER_RIGHT, # 20 CHEEK_SQUINT_LEFT, # 21 CHEEK_SQUINT_RIGHT, # 22 CHEEK_PUFF, # 23 MOUTH_LEFT, # 24 MOUTH_DIMPLE_LEFT, # 25 MOUTH_FROWN_LEFT, # 26 MOUTH_LOWER_DOWN_LEFT, # 27 MOUTH_PRESS_LEFT, # 28 MOUTH_SMILE_LEFT, # 29 MOUTH_STRETCH_LEFT, # 30 MOUTH_UPPER_UP_LEFT, # 31 MOUTH_RIGHT, # 32 MOUTH_DIMPLE_RIGHT, # 33 MOUTH_FROWN_RIGHT, # 34 MOUTH_LOWER_DOWN_RIGHT, # 35 MOUTH_PRESS_RIGHT, # 36 MOUTH_SMILE_RIGHT, # 37 MOUTH_STRETCH_RIGHT, # 38 MOUTH_UPPER_UP_RIGHT, # 39 MOUTH_CLOSE, # 40 MOUTH_FUNNEL, # 41 MOUTH_PUCKER, # 42 MOUTH_ROLL_LOWER, # 43 MOUTH_ROLL_UPPER, # 44 MOUTH_SHRUG_LOWER, # 45 MOUTH_SHRUG_UPPER, # 46 JAW_LEFT, # 47 JAW_RIGHT, # 48 JAW_FORWARD, # 49 JAW_OPEN, # 50 NEUTRAL, # 51 ] EYE_LEFT_BLENDSHAPES = [ EYE_LOOK_IN_LEFT, # 0 EYE_LOOK_OUT_LEFT, # 1 EYE_LOOK_DOWN_LEFT, # 2 EYE_LOOK_UP_LEFT, # 3 EYE_BLINK_LEFT, # 4 EYE_SQUINT_LEFT, # 5 EYE_WIDE_LEFT, # 6 ] EYE_RIGHT_BLENDSHAPES = [ EYE_LOOK_IN_RIGHT, # 7 EYE_LOOK_OUT_RIGHT, # 8 EYE_LOOK_DOWN_RIGHT, # 9 EYE_LOOK_UP_RIGHT, # 10 EYE_BLINK_RIGHT, # 11 EYE_SQUINT_RIGHT, # 12 EYE_WIDE_RIGHT, # 13 ] BROW_LEFT_BLENDSHAPES = [ BROW_DOWN_LEFT, # 14 BROW_OUTER_UP_LEFT, # 15 ] BROW_RIGHT_BLENDSHAPES = [ BROW_DOWN_RIGHT, # 16 BROW_OUTER_UP_RIGHT, # 17 ] BROW_BOTH_BLENDSHAPES = [ BROW_INNER_UP, # 18 ] NOSE_BLENDSHAPES = [ NOSE_SNEER_LEFT, # 19 NOSE_SNEER_RIGHT, # 20 ] CHECK_BLENDSHAPES = [ CHEEK_SQUINT_LEFT, # 21 CHEEK_SQUINT_RIGHT, # 22 CHEEK_PUFF, # 23 ] MOUTH_LEFT_BLENDSHAPES = [ MOUTH_LEFT, # 24 MOUTH_DIMPLE_LEFT, # 25 MOUTH_FROWN_LEFT, # 26 MOUTH_LOWER_DOWN_LEFT, # 27 MOUTH_PRESS_LEFT, # 28 MOUTH_SMILE_LEFT, # 29 MOUTH_STRETCH_LEFT, # 30 MOUTH_UPPER_UP_LEFT, # 31 ] MOUTH_RIGHT_BLENDSHAPES = [ MOUTH_RIGHT, # 32 MOUTH_DIMPLE_RIGHT, # 33 MOUTH_FROWN_RIGHT, # 34 
    MOUTH_LOWER_DOWN_RIGHT,  # 35
    MOUTH_PRESS_RIGHT,  # 36
    MOUTH_SMILE_RIGHT,  # 37
    MOUTH_STRETCH_RIGHT,  # 38
    MOUTH_UPPER_UP_RIGHT,  # 39
]
MOUTH_BOTH_BLENDSHAPES = [
    MOUTH_CLOSE,  # 40
    MOUTH_FUNNEL,  # 41
    MOUTH_PUCKER,  # 42
    MOUTH_ROLL_LOWER,  # 43
    MOUTH_ROLL_UPPER,  # 44
    MOUTH_SHRUG_LOWER,  # 45
    MOUTH_SHRUG_UPPER,  # 46
]
JAW_BLENDSHAPES = [
    JAW_LEFT,  # 47
    JAW_RIGHT,  # 48
    JAW_FORWARD,  # 49
    JAW_OPEN,  # 50
]
NEUTRAL_BLENDSHAPES = [
    NEUTRAL,  # 51
]

COLUMN_0_BLENDSHAPES = EYE_RIGHT_BLENDSHAPES + BROW_RIGHT_BLENDSHAPES + [NOSE_SNEER_RIGHT, CHEEK_SQUINT_RIGHT]
COLUMN_1_BLENDSHAPES = EYE_LEFT_BLENDSHAPES + BROW_LEFT_BLENDSHAPES + [NOSE_SNEER_LEFT, CHEEK_SQUINT_LEFT]
COLUMN_2_BLENDSHAPES = MOUTH_RIGHT_BLENDSHAPES + [JAW_RIGHT]
COLUMN_3_BLENDSHAPES = MOUTH_LEFT_BLENDSHAPES + [JAW_LEFT]
COLUMN_4_BLENDSHAPES = [BROW_INNER_UP, CHEEK_PUFF] + MOUTH_BOTH_BLENDSHAPES + [JAW_FORWARD, JAW_OPEN, NEUTRAL]

BLENDSHAPE_COLUMNS = [
    COLUMN_0_BLENDSHAPES,
    COLUMN_1_BLENDSHAPES,
    COLUMN_2_BLENDSHAPES,
    COLUMN_3_BLENDSHAPES,
    COLUMN_4_BLENDSHAPES,
]

HEAD_X = "headX"
HEAD_Y = "headY"
HEAD_Z = "headZ"
HEAD_ROTATIONS = [HEAD_X, HEAD_Y, HEAD_Z]


================================================
FILE: src/tha4/mocap/mediapipe_face_pose.py
================================================
import json
import os
from typing import Optional, Dict

import numpy


class MediaPipeFacePose:
    KEY_BLENDSHAPE_PARAMS = "blendshape_params"
    KEY_XFORM_MATRIX = "xform_matrix"

    def __init__(self,
                 blendshape_params: Optional[Dict[str, float]],
                 xform_matrix: Optional[numpy.ndarray]):
        if blendshape_params is None:
            blendshape_params = {}
        if xform_matrix is None:
            # Default to the 4x4 identity transform.
            xform_matrix = numpy.zeros((4, 4))
            for i in range(4):
                xform_matrix[i, i] = 1.0
        self.blendshape_params = blendshape_params
        self.xform_matrix = xform_matrix

    def get_json(self):
        return {
            MediaPipeFacePose.KEY_BLENDSHAPE_PARAMS: self.blendshape_params.copy(),
            MediaPipeFacePose.KEY_XFORM_MATRIX: self.xform_matrix.tolist()
        }

    def save(self, file_name: str):
        os.makedirs(os.path.dirname(file_name), exist_ok=True)
        with open(file_name, "wt") as fout:
            fout.write(json.dumps(self.get_json()))

    @staticmethod
    def load(file_name: str):
        with open(file_name, "rt") as fin:
            s = fin.read()
        json_data = json.loads(s)
        return MediaPipeFacePose(
            json_data[MediaPipeFacePose.KEY_BLENDSHAPE_PARAMS],
            xform_matrix=numpy.array(json_data[MediaPipeFacePose.KEY_XFORM_MATRIX]))


================================================
FILE: src/tha4/mocap/mediapipe_face_pose_converter.py
================================================
from abc import ABC, abstractmethod
from typing import List, Callable, Optional

from tha4.mocap.mediapipe_face_pose import MediaPipeFacePose


class MediaPipeFacePoseConverter(ABC):
    @abstractmethod
    def convert(self, mediapipe_face_pose: MediaPipeFacePose) -> List[float]:
        pass

    @abstractmethod
    def init_pose_converter_panel(
            self,
            parent,
            current_pose_supplier: Callable[[], Optional[MediaPipeFacePose]]):
        pass


================================================
FILE: src/tha4/mocap/mediapipe_face_pose_converter_00.py
================================================
import math
import time
from enum import Enum
from typing import Optional, List, Callable

import numpy
import scipy.optimize
import wx
from scipy.spatial.transform import Rotation

from tha4.poser.modes.pose_parameters import get_pose_parameters
from tha4.mocap.mediapipe_constants import MOUTH_SMILE_LEFT, MOUTH_SHRUG_UPPER, MOUTH_SMILE_RIGHT, \
    BROW_INNER_UP, BROW_OUTER_UP_RIGHT, BROW_OUTER_UP_LEFT, BROW_DOWN_LEFT, BROW_DOWN_RIGHT, EYE_WIDE_LEFT, \
    EYE_WIDE_RIGHT, EYE_BLINK_LEFT, EYE_BLINK_RIGHT, CHEEK_SQUINT_LEFT, CHEEK_SQUINT_RIGHT, EYE_LOOK_IN_LEFT, \
    EYE_LOOK_OUT_LEFT, EYE_LOOK_IN_RIGHT, EYE_LOOK_OUT_RIGHT, EYE_LOOK_UP_LEFT, EYE_LOOK_UP_RIGHT, EYE_LOOK_DOWN_RIGHT, \
    EYE_LOOK_DOWN_LEFT, JAW_OPEN, MOUTH_FROWN_LEFT, MOUTH_FROWN_RIGHT, \
    MOUTH_LOWER_DOWN_LEFT, MOUTH_LOWER_DOWN_RIGHT, MOUTH_FUNNEL, MOUTH_PUCKER
from tha4.mocap.mediapipe_face_pose import MediaPipeFacePose
from tha4.mocap.mediapipe_face_pose_converter import MediaPipeFacePoseConverter


class EyebrowDownMode(Enum):
    TROUBLED = 1
    ANGRY = 2
    LOWERED = 3
    SERIOUS = 4


class WinkMode(Enum):
    NORMAL = 1
    RELAXED = 2


def rad_to_deg(rad):
    return rad * 180.0 / math.pi


def deg_to_rad(deg):
    return deg * math.pi / 180.0


def clamp(x, min_value, max_value):
    return max(min_value, min(max_value, x))


class MediaPipeFacePoseConverter00Args:
    def __init__(self,
                 smile_threshold_min: float = 0.4,
                 smile_threshold_max: float = 0.6,
                 eyebrow_down_mode: EyebrowDownMode = EyebrowDownMode.ANGRY,
                 wink_mode: WinkMode = WinkMode.NORMAL,
                 eye_surprised_max: float = 0.5,
                 eye_blink_max: float = 0.8,
                 eyebrow_down_max: float = 0.4,
                 cheek_squint_min: float = 0.1,
                 cheek_squint_max: float = 0.7,
                 eye_rotation_factor: float = 1.0 / 0.75,
                 jaw_open_min: float = 0.1,
                 jaw_open_max: float = 0.4,
                 mouth_frown_max: float = 0.6,
                 mouth_funnel_min: float = 0.25,
                 mouth_funnel_max: float = 0.5,
                 iris_small_left=0.0,
                 iris_small_right=0.0,
                 head_x_offset=0.0,
                 head_y_offset=0.0,
                 head_z_offset=0.0):
        self.iris_small_right = iris_small_right
        self.iris_small_left = iris_small_left
        self.wink_mode = wink_mode
        self.mouth_funnel_max = mouth_funnel_max
        self.mouth_funnel_min = mouth_funnel_min
        self.mouth_frown_max = mouth_frown_max
        self.jaw_open_max = jaw_open_max
        self.jaw_open_min = jaw_open_min
        self.eye_rotation_factor = eye_rotation_factor
        self.cheek_squint_max = cheek_squint_max
        self.cheek_squint_min = cheek_squint_min
        self.eyebrow_down_max = eyebrow_down_max
        self.eye_blink_max = eye_blink_max
        self.eye_surprised_max = eye_surprised_max
        self.smile_threshold_min = smile_threshold_min
        self.smile_threshold_max = smile_threshold_max
        self.head_z_offset = head_z_offset
        self.head_y_offset = head_y_offset
        self.head_x_offset = head_x_offset
        self.eyebrow_down_mode = eyebrow_down_mode

    def set_smile_threshold_min(self, new_value: float):
        self.smile_threshold_min = new_value

    def set_smile_threshold_max(self, new_value: float):
        self.smile_threshold_max = new_value

    def set_eye_surprised_max(self, new_value: float):
        self.eye_surprised_max = new_value

    def set_eye_blink_max(self, new_value: float):
        self.eye_blink_max = new_value

    def set_eyebrow_down_max(self, new_value: float):
        self.eyebrow_down_max = new_value

    def set_cheek_squint_min(self, new_value: float):
        self.cheek_squint_min = new_value

    def set_cheek_squint_max(self, new_value: float):
        self.cheek_squint_max = new_value

    def set_jaw_open_min(self, new_value: float):
        self.jaw_open_min = new_value

    def set_jaw_open_max(self, new_value: float):
        self.jaw_open_max = new_value

    def set_mouth_frown_max(self, new_value: float):
        self.mouth_frown_max = new_value

    def set_mouth_funnel_min(self, new_value: float):
        self.mouth_funnel_min = new_value

    def set_mouth_funnel_max(self, new_value: float):
        self.mouth_funnel_max = new_value


class MediaPoseFacePoseConverter00(MediaPipeFacePoseConverter):
    def __init__(self, args: Optional[MediaPipeFacePoseConverter00Args] = None):
        super().__init__()
        if args is None:
            args = MediaPipeFacePoseConverter00Args()
        self.args = args
        pose_parameters = get_pose_parameters()
        self.pose_size = 45
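        # The output pose is a 45-dimensional vector. The indices below are
        # resolved by parameter name, so the converter does not depend on a
        # fixed ordering of the pose parameters.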
self.eyebrow_troubled_left_index = pose_parameters.get_parameter_index("eyebrow_troubled_left") self.eyebrow_troubled_right_index = pose_parameters.get_parameter_index("eyebrow_troubled_right") self.eyebrow_angry_left_index = pose_parameters.get_parameter_index("eyebrow_angry_left") self.eyebrow_angry_right_index = pose_parameters.get_parameter_index("eyebrow_angry_right") self.eyebrow_happy_left_index = pose_parameters.get_parameter_index("eyebrow_happy_left") self.eyebrow_happy_right_index = pose_parameters.get_parameter_index("eyebrow_happy_right") self.eyebrow_raised_left_index = pose_parameters.get_parameter_index("eyebrow_raised_left") self.eyebrow_raised_right_index = pose_parameters.get_parameter_index("eyebrow_raised_right") self.eyebrow_lowered_left_index = pose_parameters.get_parameter_index("eyebrow_lowered_left") self.eyebrow_lowered_right_index = pose_parameters.get_parameter_index("eyebrow_lowered_right") self.eyebrow_serious_left_index = pose_parameters.get_parameter_index("eyebrow_serious_left") self.eyebrow_serious_right_index = pose_parameters.get_parameter_index("eyebrow_serious_right") self.eye_surprised_left_index = pose_parameters.get_parameter_index("eye_surprised_left") self.eye_surprised_right_index = pose_parameters.get_parameter_index("eye_surprised_right") self.eye_wink_left_index = pose_parameters.get_parameter_index("eye_wink_left") self.eye_wink_right_index = pose_parameters.get_parameter_index("eye_wink_right") self.eye_happy_wink_left_index = pose_parameters.get_parameter_index("eye_happy_wink_left") self.eye_happy_wink_right_index = pose_parameters.get_parameter_index("eye_happy_wink_right") self.eye_relaxed_left_index = pose_parameters.get_parameter_index("eye_relaxed_left") self.eye_relaxed_right_index = pose_parameters.get_parameter_index("eye_relaxed_right") self.eye_raised_lower_eyelid_left_index = pose_parameters.get_parameter_index("eye_raised_lower_eyelid_left") self.eye_raised_lower_eyelid_right_index = pose_parameters.get_parameter_index("eye_raised_lower_eyelid_right") self.iris_small_left_index = pose_parameters.get_parameter_index("iris_small_left") self.iris_small_right_index = pose_parameters.get_parameter_index("iris_small_right") self.iris_rotation_x_index = pose_parameters.get_parameter_index("iris_rotation_x") self.iris_rotation_y_index = pose_parameters.get_parameter_index("iris_rotation_y") self.head_x_index = pose_parameters.get_parameter_index("head_x") self.head_y_index = pose_parameters.get_parameter_index("head_y") self.neck_z_index = pose_parameters.get_parameter_index("neck_z") self.mouth_aaa_index = pose_parameters.get_parameter_index("mouth_aaa") self.mouth_iii_index = pose_parameters.get_parameter_index("mouth_iii") self.mouth_uuu_index = pose_parameters.get_parameter_index("mouth_uuu") self.mouth_eee_index = pose_parameters.get_parameter_index("mouth_eee") self.mouth_ooo_index = pose_parameters.get_parameter_index("mouth_ooo") self.mouth_lowered_corner_left_index = pose_parameters.get_parameter_index("mouth_lowered_corner_left") self.mouth_lowered_corner_right_index = pose_parameters.get_parameter_index("mouth_lowered_corner_right") self.mouth_raised_corner_left_index = pose_parameters.get_parameter_index("mouth_raised_corner_left") self.mouth_raised_corner_right_index = pose_parameters.get_parameter_index("mouth_raised_corner_right") self.body_y_index = pose_parameters.get_parameter_index("body_y") self.body_z_index = pose_parameters.get_parameter_index("body_z") self.breathing_index = 
pose_parameters.get_parameter_index("breathing") self.breathing_start_time = time.time() self.panel = None self.current_pose_supplier = None def init_pose_converter_panel( self, parent, current_pose_supplier: Callable[[], Optional[MediaPipeFacePose]]): self.panel = wx.Panel(parent, style=wx.SIMPLE_BORDER) self.panel_sizer = wx.BoxSizer(wx.VERTICAL) self.panel.SetSizer(self.panel_sizer) self.panel.SetAutoLayout(1) parent.GetSizer().Add(self.panel, 0, wx.EXPAND) self.current_pose_supplier = current_pose_supplier if True: eyebrow_down_mode_text = wx.StaticText(self.panel, label=" --- Eyebrow Down Mode --- ", style=wx.ALIGN_CENTER) self.panel_sizer.Add(eyebrow_down_mode_text, 0, wx.EXPAND) self.eyebrow_down_mode_choice = wx.Choice( self.panel, choices=[ "ANGRY", "TROUBLED", "SERIOUS", "LOWERED", ]) self.eyebrow_down_mode_choice.SetSelection(0) self.panel_sizer.Add(self.eyebrow_down_mode_choice, 0, wx.EXPAND) self.eyebrow_down_mode_choice.Bind(wx.EVT_CHOICE, self.change_eyebrow_down_mode) if True: separator = wx.StaticLine(self.panel, -1, size=(256, 5)) self.panel_sizer.Add(separator, 0, wx.EXPAND) wink_mode_text = wx.StaticText(self.panel, label=" --- Wink Mode --- ", style=wx.ALIGN_CENTER) self.panel_sizer.Add(wink_mode_text, 0, wx.EXPAND) self.wink_mode_choice = wx.Choice( self.panel, choices=[ "NORMAL", "RELAXED", ]) self.wink_mode_choice.SetSelection(0) self.panel_sizer.Add(self.wink_mode_choice, 0, wx.EXPAND) self.wink_mode_choice.Bind(wx.EVT_CHOICE, self.change_wink_mode) if True: separator = wx.StaticLine(self.panel, -1, size=(256, 5)) self.panel_sizer.Add(separator, 0, wx.EXPAND) iris_size_text = wx.StaticText(self.panel, label=" --- Iris Size --- ", style=wx.ALIGN_CENTER) self.panel_sizer.Add(iris_size_text, 0, wx.EXPAND) self.iris_left_slider = wx.Slider(self.panel, minValue=0, maxValue=1000, value=0, style=wx.HORIZONTAL) self.panel_sizer.Add(self.iris_left_slider, 0, wx.EXPAND) self.iris_left_slider.Bind(wx.EVT_SLIDER, self.change_iris_size) self.iris_right_slider = wx.Slider(self.panel, minValue=0, maxValue=1000, value=0, style=wx.HORIZONTAL) self.panel_sizer.Add(self.iris_right_slider, 0, wx.EXPAND) self.iris_right_slider.Bind(wx.EVT_SLIDER, self.change_iris_size) self.iris_right_slider.Enable(False) self.link_left_right_irises = wx.CheckBox( self.panel, label="Use same value for both sides") self.link_left_right_irises.SetValue(True) self.panel_sizer.Add(self.link_left_right_irises, wx.SizerFlags().CenterHorizontal().Border()) self.link_left_right_irises.Bind(wx.EVT_CHECKBOX, self.link_left_right_irises_clicked) if True: separator = wx.StaticLine(self.panel, -1, size=(256, 5)) self.panel_sizer.Add(separator, 0, wx.EXPAND) breathing_frequency_text = wx.StaticText( self.panel, label=" --- Breathing --- ", style=wx.ALIGN_CENTER) self.panel_sizer.Add(breathing_frequency_text, 0, wx.EXPAND) self.restart_breathing_cycle_button = wx.Button(self.panel, label="Restart Breathing Cycle") self.restart_breathing_cycle_button.Bind(wx.EVT_BUTTON, self.restart_breathing_cycle_clicked) self.panel_sizer.Add(self.restart_breathing_cycle_button, 0, wx.EXPAND) self.breathing_frequency_slider = wx.Slider( self.panel, minValue=0, maxValue=60, value=20, style=wx.HORIZONTAL) self.panel_sizer.Add(self.breathing_frequency_slider, 0, wx.EXPAND) self.breathing_gauge = wx.Gauge(self.panel, style=wx.GA_HORIZONTAL, range=1000) self.panel_sizer.Add(self.breathing_gauge, 0, wx.EXPAND) if True: separator = wx.StaticLine(self.panel, -1, size=(256, 5)) self.panel_sizer.Add(separator, 0, wx.EXPAND) 
face_orientation_text = wx.StaticText( self.panel, label="--- Face Orientation ---", style=wx.ALIGN_CENTER) self.panel_sizer.Add(face_orientation_text, 0, wx.EXPAND) self.calibrate_face_orientation_button = wx.Button(self.panel, label="Calibrate (I'm looking forward)") self.calibrate_face_orientation_button.Bind(wx.EVT_BUTTON, self.calibrate_face_orientation_clicked) self.panel_sizer.Add(self.calibrate_face_orientation_button, 0, wx.EXPAND) if True: separator = wx.StaticLine(self.panel, -1, size=(256, 5)) self.panel_sizer.Add(separator, 0, wx.EXPAND) convertion_parameters_text = wx.StaticText( self.panel, label="--- Conversion Parameters ---", style=wx.ALIGN_CENTER) self.panel_sizer.Add(convertion_parameters_text, 0, wx.EXPAND) conversion_param_panel = wx.Panel(self.panel) self.panel_sizer.Add(conversion_param_panel, 0, wx.EXPAND) conversion_panel_sizer = wx.FlexGridSizer(cols=2) conversion_panel_sizer.AddGrowableCol(1) conversion_param_panel.SetSizer(conversion_panel_sizer) conversion_param_panel.SetAutoLayout(1) self.smile_thresold_min_spin = self.create_spin_control( conversion_param_panel, "Smile Threshold Min:", self.args.smile_threshold_min, self.args.set_smile_threshold_min) self.smile_thresold_max_spin = self.create_spin_control( conversion_param_panel, "Smile Threshold Max:", self.args.smile_threshold_max, self.args.set_smile_threshold_max) self.eye_surprised_max_spin = self.create_spin_control( conversion_param_panel, "Eye Surprised Max:", self.args.eye_surprised_max, self.args.set_eye_surprised_max) self.eye_blink_max_spin = self.create_spin_control( conversion_param_panel, "Eye Blink Max:", self.args.eye_blink_max, self.args.set_eye_blink_max) self.eyebrow_down_max_spin = self.create_spin_control( conversion_param_panel, "Eyebrow Down Max:", self.args.eyebrow_down_max, self.args.set_eyebrow_down_max) self.cheek_squint_min_spin = self.create_spin_control( conversion_param_panel, "Cheek Squint Min:", self.args.cheek_squint_min, self.args.set_cheek_squint_min) self.cheek_squint_max_spin = self.create_spin_control( conversion_param_panel, "Cheek Squint Max:", self.args.cheek_squint_max, self.args.set_cheek_squint_max) self.jaw_open_min_spin = self.create_spin_control( conversion_param_panel, "Jaw Open Min:", self.args.jaw_open_min, self.args.set_jaw_open_min) self.jaw_open_max_spin = self.create_spin_control( conversion_param_panel, "Jaw Open Max:", self.args.jaw_open_max, self.args.set_jaw_open_max) self.mouth_frown_max_spin = self.create_spin_control( conversion_param_panel, "Mouth Frown Max:", self.args.mouth_frown_max, self.args.set_mouth_frown_max) self.mouth_funnel_min_spin = self.create_spin_control( conversion_param_panel, "Mouth Funnel Min:", self.args.mouth_funnel_min, self.args.set_mouth_funnel_min) self.mouth_funnel_max_spin = self.create_spin_control( conversion_param_panel, "Mouth Funnel Max:", self.args.mouth_funnel_max, self.args.set_mouth_funnel_max) self.panel_sizer.Fit(self.panel) def create_spin_control(self, parent, label: str, initial_value: float, set_func: Callable[[float], None]): sizer = parent.GetSizer() text = wx.StaticText(parent, label=label) sizer.Add(text, wx.SizerFlags().Right().Border(wx.ALL, 2)) spin_ctrl = wx.SpinCtrlDouble( parent, wx.ID_ANY, min=0.0, max=1.0, initial=initial_value, inc=0.01) sizer.Add(spin_ctrl, wx.SizerFlags().Border(wx.ALL, 2).Expand()) def handler(event: wx.Event): new_value = spin_ctrl.GetValue() set_func(new_value) spin_ctrl.Bind(wx.EVT_SPINCTRLDOUBLE, handler) return spin_ctrl def extract_euler_angles(self, 
mediapipe_face_pose: MediaPipeFacePose): M = mediapipe_face_pose.xform_matrix[0:3, 0:3] rot = Rotation.from_matrix(M) return rot.as_euler('xyz', degrees=False) def calibrate_face_orientation_clicked(self, event: wx.Event): if self.current_pose_supplier is None: return mediapipe_face_pose = self.current_pose_supplier() if mediapipe_face_pose is None: return euler_angles = self.extract_euler_angles(mediapipe_face_pose) self.args.head_x_offset = euler_angles[0] self.args.head_y_offset = euler_angles[1] self.args.head_z_offset = euler_angles[2] def restart_breathing_cycle_clicked(self, event: wx.Event): self.breathing_start_time = time.time() def change_eyebrow_down_mode(self, event: wx.Event): selected_index = self.eyebrow_down_mode_choice.GetSelection() if selected_index == 0: self.args.eyebrow_down_mode = EyebrowDownMode.ANGRY elif selected_index == 1: self.args.eyebrow_down_mode = EyebrowDownMode.TROUBLED elif selected_index == 2: self.args.eyebrow_down_mode = EyebrowDownMode.SERIOUS else: self.args.eyebrow_down_mode = EyebrowDownMode.LOWERED def change_wink_mode(self, event: wx.Event): selected_index = self.wink_mode_choice.GetSelection() if selected_index == 0: self.args.wink_mode = WinkMode.NORMAL else: self.args.wink_mode = WinkMode.RELAXED def change_iris_size(self, event: wx.Event): if self.link_left_right_irises.GetValue(): left_value = self.iris_left_slider.GetValue() right_value = self.iris_right_slider.GetValue() if left_value != right_value: self.iris_right_slider.SetValue(left_value) self.args.iris_small_left = left_value / 1000.0 self.args.iris_small_right = left_value / 1000.0 else: self.args.iris_small_left = self.iris_left_slider.GetValue() / 1000.0 self.args.iris_small_right = self.iris_right_slider.GetValue() / 1000.0 def link_left_right_irises_clicked(self, event: wx.Event): if self.link_left_right_irises.GetValue(): self.iris_right_slider.Enable(False) else: self.iris_right_slider.Enable(True) self.change_iris_size(event) def decompose_head_body_param(self, param, threshold=2.0 / 3): if abs(param) < threshold: return (param, 0.0) else: if param < 0: sign = -1.0 else: sign = 1.0 return (threshold * sign, (abs(param) - threshold) * sign) def convert(self, mediapipe_face_pose: MediaPipeFacePose) -> List[float]: pose = [0.0 for i in range(self.pose_size)] blendshape_params = mediapipe_face_pose.blendshape_params smile_value = \ (blendshape_params[MOUTH_SMILE_LEFT] + blendshape_params[MOUTH_SMILE_RIGHT]) / 2.0 \ + blendshape_params[MOUTH_SHRUG_UPPER] if self.args.smile_threshold_min >= self.args.smile_threshold_max: smile_degree = 0.0 else: if smile_value < self.args.smile_threshold_min: smile_degree = 0.0 elif smile_value > self.args.smile_threshold_max: smile_degree = 1.0 else: smile_degree = (smile_value - self.args.smile_threshold_min) / ( self.args.smile_threshold_max - self.args.smile_threshold_min) # Eyebrow if True: brow_inner_up = blendshape_params[BROW_INNER_UP] brow_outer_up_right = blendshape_params[BROW_OUTER_UP_RIGHT] brow_outer_up_left = blendshape_params[BROW_OUTER_UP_LEFT] brow_up_left = clamp(brow_inner_up + brow_outer_up_left, 0.0, 1.0) brow_up_right = clamp(brow_inner_up + brow_outer_up_right, 0.0, 1.0) pose[self.eyebrow_raised_left_index] = brow_up_left pose[self.eyebrow_raised_right_index] = brow_up_right if self.args.eyebrow_down_max <= 0.0: brow_down_left = 0.0 brow_down_right = 0.0 else: brow_down_left = (1.0 - smile_degree) \ * clamp(blendshape_params[BROW_DOWN_LEFT] / self.args.eyebrow_down_max, 0.0, 1.0) brow_down_right = (1.0 - smile_degree) \ * 
clamp(blendshape_params[BROW_DOWN_RIGHT] / self.args.eyebrow_down_max, 0.0, 1.0) if self.args.eyebrow_down_mode == EyebrowDownMode.TROUBLED: pose[self.eyebrow_troubled_left_index] = brow_down_left pose[self.eyebrow_troubled_right_index] = brow_down_right elif self.args.eyebrow_down_mode == EyebrowDownMode.ANGRY: pose[self.eyebrow_angry_left_index] = brow_down_left pose[self.eyebrow_angry_right_index] = brow_down_right elif self.args.eyebrow_down_mode == EyebrowDownMode.LOWERED: pose[self.eyebrow_lowered_left_index] = brow_down_left pose[self.eyebrow_lowered_right_index] = brow_down_right elif self.args.eyebrow_down_mode == EyebrowDownMode.SERIOUS: pose[self.eyebrow_serious_left_index] = brow_down_left pose[self.eyebrow_serious_right_index] = brow_down_right brow_happy_value = clamp(smile_value, 0.0, 1.0) * smile_degree pose[self.eyebrow_happy_left_index] = brow_happy_value pose[self.eyebrow_happy_right_index] = brow_happy_value # Eye if True: # Surprised if self.args.eye_surprised_max <= 0.0: pose[self.eye_surprised_left_index] = 0.0 pose[self.eye_surprised_right_index] = 0.0 else: pose[self.eye_surprised_left_index] = clamp( blendshape_params[EYE_WIDE_LEFT] / self.args.eye_surprised_max, 0.0, 1.0) pose[self.eye_surprised_right_index] = clamp( blendshape_params[EYE_WIDE_RIGHT] / self.args.eye_surprised_max, 0.0, 1.0) # Wink if self.args.wink_mode == WinkMode.NORMAL: wink_left_index = self.eye_wink_left_index wink_right_index = self.eye_wink_right_index else: wink_left_index = self.eye_relaxed_left_index wink_right_index = self.eye_relaxed_right_index if self.args.eye_blink_max <= 0: pose[wink_left_index] = 0.0 pose[wink_right_index] = 0.0 pose[self.eye_happy_wink_left_index] = 0.0 pose[self.eye_happy_wink_right_index] = 0.0 else: pose[wink_left_index] = (1.0 - smile_degree) * clamp( blendshape_params[EYE_BLINK_LEFT] / self.args.eye_blink_max, 0.0, 1.0) pose[wink_right_index] = (1.0 - smile_degree) * clamp( blendshape_params[EYE_BLINK_RIGHT] / self.args.eye_blink_max, 0.0, 1.0) pose[self.eye_happy_wink_left_index] = smile_degree * clamp( blendshape_params[EYE_BLINK_LEFT] / self.args.eye_blink_max, 0.0, 1.0) pose[self.eye_happy_wink_right_index] = smile_degree * clamp( blendshape_params[EYE_BLINK_RIGHT] / self.args.eye_blink_max, 0.0, 1.0) # Lower eyelid cheek_squint_denom = self.args.cheek_squint_max - self.args.cheek_squint_min if cheek_squint_denom <= 0.0: pose[self.eye_raised_lower_eyelid_left_index] = 0.0 pose[self.eye_raised_lower_eyelid_right_index] = 0.0 else: pose[self.eye_raised_lower_eyelid_left_index] = \ clamp( (blendshape_params[CHEEK_SQUINT_LEFT] - self.args.cheek_squint_min) / cheek_squint_denom, 0.0, 1.0) pose[self.eye_raised_lower_eyelid_right_index] = \ clamp( (blendshape_params[CHEEK_SQUINT_RIGHT] - self.args.cheek_squint_min) / cheek_squint_denom, 0.0, 1.0) # Iris rotation if True: eye_rotation_y = (blendshape_params[EYE_LOOK_IN_LEFT] - blendshape_params[EYE_LOOK_OUT_LEFT] - blendshape_params[EYE_LOOK_IN_RIGHT] + blendshape_params[EYE_LOOK_OUT_RIGHT]) / 2.0 * self.args.eye_rotation_factor pose[self.iris_rotation_y_index] = clamp(eye_rotation_y, -1.0, 1.0) eye_rotation_x = (blendshape_params[EYE_LOOK_UP_LEFT] + blendshape_params[EYE_LOOK_UP_RIGHT] - blendshape_params[EYE_LOOK_DOWN_LEFT] - blendshape_params[EYE_LOOK_DOWN_RIGHT]) / 2.0 * self.args.eye_rotation_factor pose[self.iris_rotation_x_index] = clamp(eye_rotation_x, -1.0, 1.0) # Iris size if True: pose[self.iris_small_left_index] = self.args.iris_small_left pose[self.iris_small_right_index] = 
self.args.iris_small_right # Head rotation if True: euler_angles = self.extract_euler_angles(mediapipe_face_pose) euler_angles[0] -= self.args.head_x_offset euler_angles[1] -= self.args.head_y_offset euler_angles[2] -= self.args.head_z_offset x_param = clamp(-euler_angles[0] * 180.0 / math.pi, -15.0, 15.0) / 15.0 pose[self.head_x_index] = x_param y_param = clamp(-euler_angles[1] * 180.0 / math.pi, -10.0, 10.0) / 10.0 pose[self.head_y_index] = y_param pose[self.body_y_index] = y_param z_param = clamp(euler_angles[2] * 180.0 / math.pi, -15.0, 15.0) / 15.0 pose[self.neck_z_index] = z_param pose[self.body_z_index] = z_param # Mouth if True: jaw_open_denom = self.args.jaw_open_max - self.args.jaw_open_min if jaw_open_denom <= 0: mouth_open = 0.0 else: mouth_open = clamp((blendshape_params[JAW_OPEN] - self.args.jaw_open_min) / jaw_open_denom, 0.0, 1.0) pose[self.mouth_aaa_index] = mouth_open pose[self.mouth_raised_corner_left_index] = clamp(smile_value, 0.0, 1.0) pose[self.mouth_raised_corner_right_index] = clamp(smile_value, 0.0, 1.0) is_mouth_open = mouth_open > 0.0 if not is_mouth_open: if self.args.mouth_frown_max <= 0: mouth_frown_value = 0.0 else: mouth_frown_value = clamp( (blendshape_params[MOUTH_FROWN_LEFT] + blendshape_params[ MOUTH_FROWN_RIGHT]) / self.args.mouth_frown_max, 0.0, 1.0) pose[self.mouth_lowered_corner_left_index] = mouth_frown_value pose[self.mouth_lowered_corner_right_index] = mouth_frown_value else: mouth_lower_down = clamp( blendshape_params[MOUTH_LOWER_DOWN_LEFT] + blendshape_params[MOUTH_LOWER_DOWN_RIGHT], 0.0, 1.0) mouth_funnel = blendshape_params[MOUTH_FUNNEL] mouth_pucker = blendshape_params[MOUTH_PUCKER] mouth_point = [mouth_open, mouth_lower_down, mouth_funnel, mouth_pucker] aaa_point = [1.0, 1.0, 0.0, 0.0] iii_point = [0.0, 1.0, 0.0, 0.0] uuu_point = [0.5, 0.3, 0.25, 0.75] ooo_point = [1.0, 0.5, 0.5, 0.4] decomp = numpy.array([0, 0, 0, 0]) M = numpy.array([ aaa_point, iii_point, uuu_point, ooo_point ]) def loss(decomp): return numpy.linalg.norm(numpy.matmul(decomp, M) - mouth_point) \ + 0.01 * numpy.linalg.norm(decomp, ord=1) opt_result = scipy.optimize.minimize( loss, decomp, bounds=[(0.0, 1.0), (0.0, 1.0), (0.0, 1.0), (0.0, 1.0)]) decomp = opt_result["x"] restricted_decomp = [decomp.item(0), decomp.item(1), decomp.item(2), decomp.item(3)] pose[self.mouth_aaa_index] = restricted_decomp[0] pose[self.mouth_iii_index] = restricted_decomp[1] mouth_funnel_denom = self.args.mouth_funnel_max - self.args.mouth_funnel_min if mouth_funnel_denom <= 0: ooo_alpha = 0.0 uo_value = 0.0 else: ooo_alpha = clamp((mouth_funnel - self.args.mouth_funnel_min) / mouth_funnel_denom, 0.0, 1.0) uo_value = clamp(restricted_decomp[2] + restricted_decomp[3], 0.0, 1.0) pose[self.mouth_uuu_index] = uo_value * (1.0 - ooo_alpha) pose[self.mouth_ooo_index] = uo_value * ooo_alpha if self.panel is not None: frequency = self.breathing_frequency_slider.GetValue() if frequency == 0: value = 0.0 pose[self.breathing_index] = value self.breathing_start_time = time.time() else: period = 60.0 / frequency now = time.time() diff = now - self.breathing_start_time frac = (diff % period) / period value = (-math.cos(2 * math.pi * frac) + 1.0) / 2.0 pose[self.breathing_index] = value self.breathing_gauge.SetValue(int(1000 * value)) return pose ================================================ FILE: src/tha4/nn/__init__.py ================================================ ================================================ FILE: src/tha4/nn/common/__init__.py ================================================ 
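Both pose converters above recover the mouth shape by expressing the observed (open, lower-down, funnel, pucker) measurements as a bounded combination of four prototype mouth shapes (aaa, iii, uuu, ooo), solved with scipy as a box-constrained minimization with a small L1 penalty that favors sparse decompositions. The following is a minimal standalone sketch of that decomposition, using the same prototype points as the converters; the input measurement vector here is hypothetical, chosen only for illustration.

import numpy
import scipy.optimize

# Prototype mouth shapes in (mouth_open, mouth_lower_down, mouth_funnel, mouth_pucker) space.
M = numpy.array([
    [1.0, 1.0, 0.0, 0.0],    # aaa
    [0.0, 1.0, 0.0, 0.0],    # iii
    [0.5, 0.3, 0.25, 0.75],  # uuu
    [1.0, 0.5, 0.5, 0.4],    # ooo
])


def decompose_mouth(mouth_point: numpy.ndarray) -> numpy.ndarray:
    """Return weights (aaa, iii, uuu, ooo), each in [0, 1], approximating mouth_point."""
    def loss(decomp):
        # Residual of the linear combination plus an L1 penalty.
        return numpy.linalg.norm(numpy.matmul(decomp, M) - mouth_point) \
               + 0.01 * numpy.linalg.norm(decomp, ord=1)

    result = scipy.optimize.minimize(
        loss, numpy.zeros(4), bounds=[(0.0, 1.0)] * 4)
    return result["x"]


# Hypothetical measurement: a fairly open mouth with some puckering.
print(decompose_mouth(numpy.array([0.8, 0.6, 0.2, 0.5])))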
================================================
FILE: src/tha4/nn/common/conv_block_factory.py
================================================
from typing import Optional

from tha4.nn.conv import create_conv7_block_from_block_args, create_conv3_block_from_block_args, \
    create_downsample_block_from_block_args, create_conv3
from tha4.nn.resnet_block import ResnetBlock
from tha4.nn.resnet_block_seperable import ResnetBlockSeparable
from tha4.nn.separable_conv import create_separable_conv7_block, create_separable_conv3_block, \
    create_separable_downsample_block, create_separable_conv3
from tha4.nn.util import BlockArgs


class ConvBlockFactory:
    def __init__(self,
                 block_args: BlockArgs,
                 use_separable_convolution: bool = False):
        self.use_separable_convolution = use_separable_convolution
        self.block_args = block_args

    def create_conv3(self,
                     in_channels: int,
                     out_channels: int,
                     bias: bool,
                     initialization_method: Optional[str] = None):
        if initialization_method is None:
            initialization_method = self.block_args.initialization_method
        if self.use_separable_convolution:
            return create_separable_conv3(
                in_channels, out_channels, bias, initialization_method, self.block_args.use_spectral_norm)
        else:
            return create_conv3(
                in_channels, out_channels, bias, initialization_method, self.block_args.use_spectral_norm)

    def create_conv7_block(self, in_channels: int, out_channels: int):
        if self.use_separable_convolution:
            return create_separable_conv7_block(in_channels, out_channels, self.block_args)
        else:
            return create_conv7_block_from_block_args(in_channels, out_channels, self.block_args)

    def create_conv3_block(self, in_channels: int, out_channels: int):
        if self.use_separable_convolution:
            return create_separable_conv3_block(in_channels, out_channels, self.block_args)
        else:
            return create_conv3_block_from_block_args(in_channels, out_channels, self.block_args)

    def create_downsample_block(self, in_channels: int, out_channels: int, is_output_1x1: bool):
        if self.use_separable_convolution:
            return create_separable_downsample_block(in_channels, out_channels, is_output_1x1, self.block_args)
        else:
            return create_downsample_block_from_block_args(in_channels, out_channels, is_output_1x1, self.block_args)

    def create_resnet_block(self, num_channels: int, is_1x1: bool):
        if self.use_separable_convolution:
            return ResnetBlockSeparable.create(num_channels, is_1x1, block_args=self.block_args)
        else:
            return ResnetBlock.create(num_channels, is_1x1, block_args=self.block_args)


================================================
FILE: src/tha4/nn/common/poser_args.py
================================================
from typing import Optional

from torch.nn import Sigmoid, Sequential, Tanh

from tha4.nn.conv import create_conv3, create_conv3_from_block_args
from tha4.nn.nonlinearity_factory import ReLUFactory
from tha4.nn.normalization import InstanceNorm2dFactory
from tha4.nn.util import BlockArgs


class PoserArgs00:
    def __init__(self,
                 image_size: int,
                 input_image_channels: int,
                 output_image_channels: int,
                 start_channels: int,
                 num_pose_params: int,
                 block_args: Optional[BlockArgs] = None):
        self.num_pose_params = num_pose_params
        self.start_channels = start_channels
        self.output_image_channels = output_image_channels
        self.input_image_channels = input_image_channels
        self.image_size = image_size
        if block_args is None:
            self.block_args = BlockArgs(
                normalization_layer_factory=InstanceNorm2dFactory(),
                nonlinearity_factory=ReLUFactory(inplace=True))
        else:
            self.block_args = block_args

    def create_alpha_block(self):
        from torch.nn import Sequential
        return Sequential(
            create_conv3(
================================================
FILE: src/tha4/nn/common/poser_args.py
================================================
from typing import Optional

from torch.nn import Sigmoid, Sequential, Tanh

from tha4.nn.conv import create_conv3, create_conv3_from_block_args
from tha4.nn.nonlinearity_factory import ReLUFactory
from tha4.nn.normalization import InstanceNorm2dFactory
from tha4.nn.util import BlockArgs


class PoserArgs00:
    def __init__(self,
                 image_size: int,
                 input_image_channels: int,
                 output_image_channels: int,
                 start_channels: int,
                 num_pose_params: int,
                 block_args: Optional[BlockArgs] = None):
        self.num_pose_params = num_pose_params
        self.start_channels = start_channels
        self.output_image_channels = output_image_channels
        self.input_image_channels = input_image_channels
        self.image_size = image_size
        if block_args is None:
            self.block_args = BlockArgs(
                normalization_layer_factory=InstanceNorm2dFactory(),
                nonlinearity_factory=ReLUFactory(inplace=True))
        else:
            self.block_args = block_args

    def create_alpha_block(self):
        from torch.nn import Sequential
        return Sequential(
            create_conv3(
                in_channels=self.start_channels,
                out_channels=1,
                bias=True,
                initialization_method=self.block_args.initialization_method,
                use_spectral_norm=False),
            Sigmoid())

    def create_all_channel_alpha_block(self):
        from torch.nn import Sequential
        return Sequential(
            create_conv3(
                in_channels=self.start_channels,
                out_channels=self.output_image_channels,
                bias=True,
                initialization_method=self.block_args.initialization_method,
                use_spectral_norm=False),
            Sigmoid())

    def create_color_change_block(self):
        return Sequential(
            create_conv3_from_block_args(
                in_channels=self.start_channels,
                out_channels=self.output_image_channels,
                bias=True,
                block_args=self.block_args),
            Tanh())

    def create_grid_change_block(self):
        return create_conv3(
            in_channels=self.start_channels,
            out_channels=2,
            bias=False,
            initialization_method='zero',
            use_spectral_norm=False)
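Note that create_grid_change_block uses the 'zero' initialization method: a freshly initialized poser therefore emits an all-zero warp field, i.e. the identity warp, and training only has to learn deviations from it. A small sketch of that behavior (shapes are hypothetical; apply_grid_change from tha4.nn.image_processing_util, shown later in this dump, is used for illustration):

import torch
from tha4.nn.image_processing_util import apply_grid_change

image = torch.rand(1, 4, 8, 8)
zero_grid_change = torch.zeros(1, 2, 8, 8)  # what a zero-initialized head emits
warped = apply_grid_change(zero_grid_change, image)
# A zero offset field added to the identity sampling grid reproduces the
# input up to bilinear interpolation error.
print(torch.allclose(warped, image, atol=1e-5))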
================================================
FILE: src/tha4/nn/common/poser_encoder_decoder_00.py
================================================
import math
from typing import Optional, List

import torch
from torch import Tensor
from torch.nn import ModuleList, Module

from tha4.nn.common.poser_args import PoserArgs00
from tha4.nn.conv import create_conv3_block_from_block_args, create_downsample_block_from_block_args, \
    create_upsample_block_from_block_args
from tha4.nn.nonlinearity_factory import ReLUFactory
from tha4.nn.normalization import InstanceNorm2dFactory
from tha4.nn.resnet_block import ResnetBlock
from tha4.nn.util import BlockArgs


class PoserEncoderDecoder00Args(PoserArgs00):
    def __init__(self,
                 image_size: int,
                 input_image_channels: int,
                 output_image_channels: int,
                 num_pose_params: int,
                 start_channels: int,
                 bottleneck_image_size,
                 num_bottleneck_blocks,
                 max_channels: int,
                 block_args: Optional[BlockArgs] = None):
        super().__init__(
            image_size, input_image_channels, output_image_channels, start_channels, num_pose_params, block_args)
        self.max_channels = max_channels
        self.num_bottleneck_blocks = num_bottleneck_blocks
        self.bottleneck_image_size = bottleneck_image_size
        assert bottleneck_image_size > 1
        if block_args is None:
            self.block_args = BlockArgs(
                normalization_layer_factory=InstanceNorm2dFactory(),
                nonlinearity_factory=ReLUFactory(inplace=True))
        else:
            self.block_args = block_args


class PoserEncoderDecoder00(Module):
    def __init__(self, args: PoserEncoderDecoder00Args):
        super().__init__()
        self.args = args
        self.num_levels = int(math.log2(args.image_size // args.bottleneck_image_size)) + 1

        self.downsample_blocks = ModuleList()
        self.downsample_blocks.append(
            create_conv3_block_from_block_args(
                args.input_image_channels, args.start_channels, args.block_args))
        current_image_size = args.image_size
        current_num_channels = args.start_channels
        while current_image_size > args.bottleneck_image_size:
            next_image_size = current_image_size // 2
            next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
            self.downsample_blocks.append(create_downsample_block_from_block_args(
                in_channels=current_num_channels,
                out_channels=next_num_channels,
                is_output_1x1=False,
                block_args=args.block_args))
            current_image_size = next_image_size
            current_num_channels = next_num_channels
        assert len(self.downsample_blocks) == self.num_levels

        self.bottleneck_blocks = ModuleList()
        self.bottleneck_blocks.append(create_conv3_block_from_block_args(
            in_channels=current_num_channels + args.num_pose_params,
            out_channels=current_num_channels,
            block_args=args.block_args))
        for i in range(1, args.num_bottleneck_blocks):
            self.bottleneck_blocks.append(
                ResnetBlock.create(
                    num_channels=current_num_channels,
                    is1x1=False,
                    block_args=args.block_args))

        self.upsample_blocks = ModuleList()
        while current_image_size < args.image_size:
            next_image_size = current_image_size * 2
            next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
            self.upsample_blocks.append(create_upsample_block_from_block_args(
                in_channels=current_num_channels,
                out_channels=next_num_channels,
                block_args=args.block_args))
            current_image_size = next_image_size
            current_num_channels = next_num_channels

    def get_num_output_channels_from_level(self, level: int):
        return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level))

    def get_num_output_channels_from_image_size(self, image_size: int):
        return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels)

    def forward(self, image: Tensor, pose: Optional[Tensor] = None) -> List[Tensor]:
        if self.args.num_pose_params != 0:
            assert pose is not None
        else:
            assert pose is None
        outputs = []
        feature = image
        outputs.append(feature)
        for block in self.downsample_blocks:
            feature = block(feature)
            outputs.append(feature)
        if pose is not None:
            n, c = pose.shape
            pose = pose.view(n, c, 1, 1).repeat(1, 1, self.args.bottleneck_image_size, self.args.bottleneck_image_size)
            feature = torch.cat([feature, pose], dim=1)
        for block in self.bottleneck_blocks:
            feature = block(feature)
        outputs.append(feature)
        for block in self.upsample_blocks:
            feature = block(feature)
            outputs.append(feature)
        outputs.reverse()
        return outputs
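get_num_output_channels_from_image_size doubles the channel count each time the resolution halves, capped at max_channels. A quick trace of the schedule, using illustrative settings rather than repository defaults:

# Channel schedule implied by get_num_output_channels_from_image_size,
# shown for hypothetical settings image_size=512, start_channels=32,
# max_channels=512.
image_size, start_channels, max_channels = 512, 32, 512

def num_channels(size: int) -> int:
    return min(start_channels * (image_size // size), max_channels)

for size in [512, 256, 128, 64, 32, 16]:
    print(size, num_channels(size))
# 512 -> 32, 256 -> 64, 128 -> 128, 64 -> 256, 32 -> 512, 16 -> 512 (capped)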
================================================
FILE: src/tha4/nn/common/poser_encoder_decoder_00_separable.py
================================================
import math
from typing import Optional, List

import torch
from torch import Tensor
from torch.nn import ModuleList, Module

from tha4.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00Args
from tha4.nn.resnet_block_seperable import ResnetBlockSeparable
from tha4.nn.separable_conv import create_separable_conv3_block, create_separable_downsample_block, \
    create_separable_upsample_block


class PoserEncoderDecoder00Separable(Module):
    def __init__(self, args: PoserEncoderDecoder00Args):
        super().__init__()
        self.args = args
        self.num_levels = int(math.log2(args.image_size // args.bottleneck_image_size)) + 1

        self.downsample_blocks = ModuleList()
        self.downsample_blocks.append(
            create_separable_conv3_block(
                args.input_image_channels, args.start_channels, args.block_args))
        current_image_size = args.image_size
        current_num_channels = args.start_channels
        while current_image_size > args.bottleneck_image_size:
            next_image_size = current_image_size // 2
            next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
            self.downsample_blocks.append(create_separable_downsample_block(
                in_channels=current_num_channels,
                out_channels=next_num_channels,
                is_output_1x1=False,
                block_args=args.block_args))
            current_image_size = next_image_size
            current_num_channels = next_num_channels
        assert len(self.downsample_blocks) == self.num_levels

        self.bottleneck_blocks = ModuleList()
        self.bottleneck_blocks.append(create_separable_conv3_block(
            in_channels=current_num_channels + args.num_pose_params,
            out_channels=current_num_channels,
            block_args=args.block_args))
        for i in range(1, args.num_bottleneck_blocks):
            self.bottleneck_blocks.append(
                ResnetBlockSeparable.create(
                    num_channels=current_num_channels,
                    is1x1=False,
                    block_args=args.block_args))

        self.upsample_blocks = ModuleList()
        while current_image_size < args.image_size:
            next_image_size = current_image_size * 2
            next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
            self.upsample_blocks.append(create_separable_upsample_block(
                in_channels=current_num_channels,
                out_channels=next_num_channels,
                block_args=args.block_args))
            current_image_size = next_image_size
            current_num_channels = next_num_channels

    def get_num_output_channels_from_level(self, level: int):
        return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level))

    def get_num_output_channels_from_image_size(self, image_size: int):
        return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels)

    def forward(self, image: Tensor, pose: Optional[Tensor] = None) -> List[Tensor]:
        if self.args.num_pose_params != 0:
            assert pose is not None
        else:
            assert pose is None
        outputs = []
        feature = image
        outputs.append(feature)
        for block in self.downsample_blocks:
            feature = block(feature)
            outputs.append(feature)
        if pose is not None:
            n, c = pose.shape
            pose = pose.view(n, c, 1, 1).repeat(1, 1, self.args.bottleneck_image_size, self.args.bottleneck_image_size)
            feature = torch.cat([feature, pose], dim=1)
        for block in self.bottleneck_blocks:
            feature = block(feature)
        outputs.append(feature)
        for block in self.upsample_blocks:
            feature = block(feature)
            outputs.append(feature)
        outputs.reverse()
        return outputs
================================================
FILE: src/tha4/nn/common/resize_conv_encoder_decoder.py
================================================
import math
from typing import Optional, List

import torch
from torch import Tensor
from torch.nn import Module, ModuleList, Sequential, Upsample

from tha4.nn.common.conv_block_factory import ConvBlockFactory
from tha4.nn.nonlinearity_factory import LeakyReLUFactory
from tha4.nn.normalization import InstanceNorm2dFactory
from tha4.nn.util import BlockArgs


class ResizeConvEncoderDecoderArgs:
    def __init__(self,
                 image_size: int,
                 input_channels: int,
                 start_channels: int,
                 bottleneck_image_size,
                 num_bottleneck_blocks,
                 max_channels: int,
                 block_args: Optional[BlockArgs] = None,
                 upsample_mode: str = 'bilinear',
                 use_separable_convolution=False):
        self.use_separable_convolution = use_separable_convolution
        self.upsample_mode = upsample_mode
        self.block_args = block_args
        self.max_channels = max_channels
        self.num_bottleneck_blocks = num_bottleneck_blocks
        self.bottleneck_image_size = bottleneck_image_size
        self.start_channels = start_channels
        self.image_size = image_size
        self.input_channels = input_channels


class ResizeConvEncoderDecoder(Module):
    def __init__(self, args: ResizeConvEncoderDecoderArgs):
        super().__init__()
        self.args = args
        self.num_levels = int(math.log2(args.image_size // args.bottleneck_image_size)) + 1
        conv_block_factory = ConvBlockFactory(args.block_args, args.use_separable_convolution)

        self.downsample_blocks = ModuleList()
        self.downsample_blocks.append(conv_block_factory.create_conv7_block(args.input_channels, args.start_channels))
        current_image_size = args.image_size
        current_num_channels = args.start_channels
        while current_image_size > args.bottleneck_image_size:
            next_image_size = current_image_size // 2
            next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
            self.downsample_blocks.append(conv_block_factory.create_downsample_block(
                in_channels=current_num_channels,
                out_channels=next_num_channels,
                is_output_1x1=False))
            current_image_size = next_image_size
            current_num_channels = next_num_channels
        assert len(self.downsample_blocks) == self.num_levels

        self.bottleneck_blocks = ModuleList()
        for i in range(args.num_bottleneck_blocks):
            self.bottleneck_blocks.append(conv_block_factory.create_resnet_block(current_num_channels, is_1x1=False))

        self.output_image_sizes = [current_image_size]
        self.output_num_channels = [current_num_channels]
        self.upsample_blocks = ModuleList()
        if args.upsample_mode == 'nearest':
            align_corners = None
        else:
            align_corners = False
        while current_image_size < args.image_size:
            next_image_size = current_image_size * 2
            next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
            self.upsample_blocks.append(
                Sequential(
                    Upsample(scale_factor=2, mode=args.upsample_mode, align_corners=align_corners),
                    conv_block_factory.create_conv3_block(
                        in_channels=current_num_channels,
                        out_channels=next_num_channels)))
            current_image_size = next_image_size
            current_num_channels = next_num_channels
            self.output_image_sizes.append(current_image_size)
            self.output_num_channels.append(current_num_channels)

    def get_num_output_channels_from_level(self, level: int):
        return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level))

    def get_num_output_channels_from_image_size(self, image_size: int):
        return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels)

    def forward(self, feature: Tensor) -> List[Tensor]:
        outputs = []
        for block in self.downsample_blocks:
            feature = block(feature)
        for block in self.bottleneck_blocks:
            feature = block(feature)
        outputs.append(feature)
        for block in self.upsample_blocks:
            feature = block(feature)
            outputs.append(feature)
        return outputs
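The decoder upsamples with an interpolating Upsample followed by a 3x3 conv block rather than a stride-2 transposed convolution; resize-then-convolve is a common way to avoid the checkerboard artifacts transposed convolutions can produce. A shape-level sketch of the pattern, with made-up channel counts:

import torch
from torch.nn import Conv2d, Sequential, Upsample

# Resize-convolution: interpolate first, then convolve. A stride-2
# ConvTranspose2d would learn the upsampling kernel instead, which is
# prone to checkerboard artifacts when kernel size and stride interact badly.
up = Sequential(
    Upsample(scale_factor=2, mode='bilinear', align_corners=False),
    Conv2d(256, 128, kernel_size=3, stride=1, padding=1))
x = torch.rand(1, 256, 16, 16)
print(up(x).shape)  # torch.Size([1, 128, 32, 32])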
================================================
FILE: src/tha4/nn/common/resize_conv_unet.py
================================================
from typing import Optional, List

import torch
from torch import Tensor
from torch.nn import ModuleList, Module, Upsample

from tha4.nn.common.conv_block_factory import ConvBlockFactory
from tha4.nn.nonlinearity_factory import ReLUFactory
from tha4.nn.normalization import InstanceNorm2dFactory
from tha4.nn.util import BlockArgs


class ResizeConvUNetArgs:
    def __init__(self,
                 image_size: int,
                 input_channels: int,
                 start_channels: int,
                 bottleneck_image_size: int,
                 num_bottleneck_blocks: int,
                 max_channels: int,
                 upsample_mode: str = 'bilinear',
                 block_args: Optional[BlockArgs] = None,
                 use_separable_convolution: bool = False):
        if block_args is None:
            block_args = BlockArgs(
                normalization_layer_factory=InstanceNorm2dFactory(),
                nonlinearity_factory=ReLUFactory(inplace=False))
        self.use_separable_convolution = use_separable_convolution
        self.block_args = block_args
        self.upsample_mode = upsample_mode
        self.max_channels = max_channels
        self.num_bottleneck_blocks = num_bottleneck_blocks
        self.bottleneck_image_size = bottleneck_image_size
        self.input_channels = input_channels
        self.start_channels = start_channels
        self.image_size = image_size


class ResizeConvUNet(Module):
    def __init__(self, args: ResizeConvUNetArgs):
        super().__init__()
        self.args = args
        conv_block_factory = ConvBlockFactory(args.block_args, args.use_separable_convolution)

        self.downsample_blocks = ModuleList()
        self.downsample_blocks.append(conv_block_factory.create_conv3_block(
            self.args.input_channels, self.args.start_channels))
        current_channels = self.args.start_channels
        current_size = self.args.image_size
        size_to_channel = {
            current_size: current_channels
        }
        while current_size > self.args.bottleneck_image_size:
            next_size = current_size // 2
            next_channels = min(self.args.max_channels, current_channels * 2)
            self.downsample_blocks.append(conv_block_factory.create_downsample_block(
                current_channels, next_channels, is_output_1x1=False))
            current_size = next_size
            current_channels = next_channels
            size_to_channel[current_size] = current_channels

        self.bottleneck_blocks = ModuleList()
        for i in range(self.args.num_bottleneck_blocks):
            self.bottleneck_blocks.append(conv_block_factory.create_resnet_block(current_channels, is_1x1=False))

        self.output_image_sizes = [current_size]
        self.output_num_channels = [current_channels]
        self.upsample_blocks = ModuleList()
        while current_size < self.args.image_size:
            next_size = current_size * 2
            next_channels = size_to_channel[next_size]
            self.upsample_blocks.append(conv_block_factory.create_conv3_block(
                current_channels + next_channels, next_channels))
            current_size = next_size
            current_channels = next_channels
            self.output_image_sizes.append(current_size)
            self.output_num_channels.append(current_channels)

        if args.upsample_mode == 'nearest':
            align_corners = None
        else:
            align_corners = False
        self.double_resolution = Upsample(scale_factor=2, mode=args.upsample_mode, align_corners=align_corners)

    def forward(self, feature: Tensor) -> List[Tensor]:
        downsampled_features = []
        for block in self.downsample_blocks:
            feature = block(feature)
            downsampled_features.append(feature)
        for block in self.bottleneck_blocks:
            feature = block(feature)
        outputs = [feature]
        for i in range(0, len(self.upsample_blocks)):
            feature = self.double_resolution(feature)
            feature = torch.cat([feature, downsampled_features[-i - 2]], dim=1)
            feature = self.upsample_blocks[i](feature)
            outputs.append(feature)
        return outputs
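In the decoding loop, each upsampled feature map is concatenated channel-wise with the encoder feature of the same resolution: downsampled_features[-1] is the deepest encoder output, so step i pairs with [-i - 2], and the conv block at each level sees current_channels + next_channels input channels. A toy trace of that indexing, with made-up resolutions:

downsampled_sizes = [256, 128, 64, 32]  # resolutions recorded during encoding
current = 32                            # bottleneck resolution
for i in range(3):                      # upsample 32 -> 64 -> 128 -> 256
    current *= 2
    print(f"step {i}: upsample to {current}, concat skip at {downsampled_sizes[-i - 2]}")
# step 0: upsample to 64, concat skip at 64
# step 1: upsample to 128, concat skip at 128
# step 2: upsample to 256, concat skip at 256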
================================================
FILE: src/tha4/nn/common/unet.py
================================================
import math
from enum import Enum
from typing import Optional, List

import torch
from torch import zero_, Tensor
from torch.nn import Module, GroupNorm, Sequential, SiLU, Conv2d, AvgPool2d, Linear, Dropout, ModuleList
from torch.nn.functional import interpolate

from tha4.shion.core.module_factory import ModuleFactory


class Identity(Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


class IdentityFactory(ModuleFactory):
    def create(self) -> Module:
        return Identity()


def init_to_zero(module: Module):
    with torch.no_grad():
        zero_(module.weight)
        zero_(module.bias)
    return module


class Upsample(Module):
    def __init__(self, in_channels: int, out_channels: Optional[int] = None, use_conv: bool = False):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        if use_conv or in_channels != out_channels:
            self.postprocess = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        else:
            self.postprocess = Identity()

    def forward(self, x):
        assert x.shape[1] == self.in_channels
        return self.postprocess(interpolate(x, scale_factor=2, mode="nearest"))


class Downsample(Module):
    def __init__(self, in_channels: int, out_channels: Optional[int] = None, use_conv: bool = False):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        if use_conv or in_channels != out_channels:
            self.op = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        else:
            self.op = AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        assert x.shape[1] == self.in_channels
        return self.op(x)


def GroupNorm32(channels):
    return GroupNorm(min(32, channels), channels)


class SamplingMode(Enum):
    SAME_RESOLUTION = 0
    UPSAMPLING = 1
    DOWNSAMPING = 2


class ResBlockArgs:
    def __init__(self,
                 dropout_prob: float,
                 use_cond0: bool = True,
                 use_cond1: bool = False,
                 init_conditioned_residual_to_zero: bool = False,
                 use_conv_on_skip_connection: bool = False):
        assert not use_cond1 or use_cond0
        self.use_conv_on_skip_connection = use_conv_on_skip_connection
        self.use_cond1 = use_cond1
        self.use_cond0 = use_cond0
        self.init_conditioned_residual_to_zero = init_conditioned_residual_to_zero
        self.dropout_prob = dropout_prob


def apply_scaleshift(x: Tensor, scaleshift: Tensor, condition_bias: float = 1.0) -> Tensor:
    assert len(scaleshift.shape) == 2
    assert len(x.shape) == 4
    assert x.shape[0] == scaleshift.shape[0]
    assert 2 * x.shape[1] == scaleshift.shape[1]
    scaleshift = scaleshift.reshape(scaleshift.shape[0], scaleshift.shape[1], 1, 1)
    scale, shift = torch.chunk(scaleshift, 2, dim=1)
    return x * (condition_bias + scale) + shift


class ResBlock(Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 cond0_channels: Optional[int] = None,
                 cond1_channels: Optional[int] = None,
                 sampling_mode: SamplingMode = SamplingMode.SAME_RESOLUTION,
                 dropout_prob: float = 0.1,
                 condition_bias: float = 1.0):
        super().__init__()
        assert cond0_channels is not None or cond1_channels is None
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.sampling_mode = sampling_mode
        self.cond0_channels = cond0_channels
        self.cond1_channels = cond1_channels
        self.condition_bias = condition_bias
        if sampling_mode == SamplingMode.UPSAMPLING:
            self.x_resample = Upsample(in_channels)
            self.h_resample = Upsample(in_channels)
        elif sampling_mode == SamplingMode.DOWNSAMPING:
            self.x_resample = Downsample(in_channels)
            self.h_resample = Downsample(in_channels)
        else:
            self.x_resample = Identity()
            self.h_resample = Identity()
        self.nonlinear = SiLU()

        # Layers before conditioning
        self.norm0 = GroupNorm32(in_channels)
        self.conv0 = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # Conditioning layers
        if cond0_channels is not None:
            self.cond0_layers = Sequential(
                SiLU(),
                Linear(cond0_channels, 2 * out_channels))
        self.norm1 = GroupNorm32(out_channels)
        self.dropout = Dropout(dropout_prob)
        self.conv1 = init_to_zero(Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1))
        if cond1_channels is not None:
            self.cond1_layers = Sequential(
                SiLU(),
                Linear(cond1_channels, 2 * out_channels))

        # Skip layer
        if in_channels == out_channels:
            self.skip = Identity()
        else:
            self.skip = Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: Tensor, cond0: Optional[Tensor] = None, cond1: Optional[Tensor] = None) -> Tensor:
        assert self.cond0_channels is None or cond0 is not None
        assert self.cond1_channels is None or cond1 is not None
        h = self.conv0(self.h_resample(self.nonlinear(self.norm0(x))))
        if self.cond0_channels is not None:
            h = self.norm1(h)
            h = apply_scaleshift(h, self.cond0_layers(cond0), self.condition_bias)
        if self.cond1_channels is not None:
            h = apply_scaleshift(h, self.cond1_layers(cond1), self.condition_bias)
        h = self.conv1(self.dropout(self.nonlinear(h)))
        return self.skip(self.x_resample(x)) + h


class AttentionBlockArgs:
    def __init__(self,
                 num_heads: Optional[int] = 1,
                 num_head_channels: Optional[int] = None,
                 use_new_attention_order: bool = False):
        self.use_new_attention_order = use_new_attention_order
        self.num_head_channels = num_head_channels
        self.num_heads = num_heads


def qkv_attention_legacy(qkv: torch.Tensor, num_heads: int):
    assert len(qkv.shape) == 3
    B, W, L = qkv.shape
    H = num_heads
    assert W % (3 * H) == 0
    C = W // (3 * H)
    q, k, v = qkv.reshape(B * H, C * 3, L).split(C, dim=1)
    scale = 1.0 / math.sqrt(math.sqrt(C))
    weight = torch.einsum('bct,bcs->bts', q * scale, k * scale)
    weight = torch.softmax(weight, dim=-1)
    output = torch.einsum("bts,bcs->bct", weight, v)
    return output.reshape(B, H * C, L)


def qkv_attention(qkv: torch.Tensor, num_heads: int):
    B, W, L = qkv.shape
    H = num_heads
    assert W % (3 * H) == 0
    C = W // (3 * H)
    q, k, v = qkv.chunk(3, dim=1)
    scale = 1.0 / math.sqrt(math.sqrt(C))
    weight = torch.einsum(
        "bct,bcs->bts",
        (q * scale).view(B * H, C, L),
        (k * scale).view(B * H, C, L))
    weight = torch.softmax(weight, dim=-1)
    output = torch.einsum("bts,bcs->bct", weight, v.reshape(B * H, C, L))
    return output.reshape(B, H * C, L)


class AttentionBlock(Module):
    def __init__(self, num_channels: int, args: AttentionBlockArgs):
        super().__init__()
        self.use_new_attention_order = args.use_new_attention_order
        if args.num_head_channels is None:
            assert args.num_heads is not None
            assert num_channels % args.num_heads == 0
            self.num_heads = args.num_heads
            self.num_head_channels = num_channels // self.num_heads
        elif args.num_heads is None:
            assert args.num_head_channels is not None
            assert num_channels % args.num_head_channels == 0
            self.num_heads = num_channels // args.num_head_channels
            self.num_head_channels = args.num_head_channels
        self.norm = GroupNorm32(num_channels)
        self.qkv = Conv2d(num_channels, 3 * num_channels, kernel_size=1, stride=1, padding=0)
        self.conv = Conv2d(num_channels, num_channels, kernel_size=1, stride=1, padding=0)
        with torch.no_grad():
            zero_(self.conv.weight)
            zero_(self.conv.bias)

    def forward(self, x: torch.Tensor):
        assert len(x.shape) == 4
        B, C, H, W = x.shape
        qkv = self.qkv(self.norm(x)).reshape(B, 3 * C, H * W)
        if self.use_new_attention_order:
            h = qkv_attention(qkv, self.num_heads)
        else:
            h = qkv_attention_legacy(qkv, self.num_heads)
        h = self.conv(h.reshape(B, C, H, W))
        return x + h


class Arity3To1(Module):
    def __init__(self, module: Module):
        super().__init__()
        self.module = module

    def forward(self, x: Tensor, y: Optional[Tensor] = None, z: Optional[Tensor] = None):
        return self.module(x)


class DownsamplingBlock(Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 cond0_channels: Optional[int],
                 cond1_channels: Optional[int],
                 num_res_blocks: int,
                 dropout_prob: float,
                 use_attention: bool,
                 perform_downsampling: bool,
                 resample_with_res_block: bool,
                 use_conv_to_resample: bool,
                 attention_block_args: AttentionBlockArgs,
                 condition_bias: float = 1.0):
        super().__init__()
        self.use_attention = use_attention
        self.res_blocks = ModuleList()
        self.attention_blocks = ModuleList()
        self.perform_downsampling = perform_downsampling
        self.output_channels = []
        for j in range(num_res_blocks):
            self.res_blocks.append(ResBlock(
                in_channels=in_channels if j == 0 else out_channels,
                out_channels=out_channels,
                cond0_channels=cond0_channels,
                cond1_channels=cond1_channels,
                dropout_prob=dropout_prob,
                condition_bias=condition_bias))
            if use_attention:
                self.attention_blocks.append(AttentionBlock(out_channels, attention_block_args))
            self.output_channels.append(out_channels)
        if perform_downsampling:
            if resample_with_res_block:
                self.downsample = ResBlock(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    cond0_channels=cond0_channels,
                    cond1_channels=cond1_channels,
                    dropout_prob=dropout_prob,
                    sampling_mode=SamplingMode.DOWNSAMPING,
                    condition_bias=condition_bias)
            else:
                self.downsample = Arity3To1(Downsample(out_channels, use_conv=use_conv_to_resample))
            self.output_channels.append(out_channels)

    def forward(self, h: Tensor, cond0: Optional[Tensor] = None, cond1: Optional[Tensor] = None) -> List[Tensor]:
        hs = []
        for i in range(len(self.res_blocks)):
            h = self.res_blocks[i].forward(h, cond0, cond1)
            if self.use_attention:
                h = self.attention_blocks[i].forward(h)
            hs.append(h)
        if self.perform_downsampling:
            hs.append(self.downsample(h, cond0, cond1))
        return hs


class UpsamplingBlock(Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 cond0_channels: Optional[int],
                 cond1_channels: Optional[int],
                 num_resnet_blocks: int,
                 skip_channels: List[int],
                 dropout_prob: float,
                 use_attention: bool,
                 perform_upsampling: bool,
                 resample_with_res_block: bool,
                 use_conv_to_resample: bool,
                 attention_block_args: AttentionBlockArgs,
                 condition_bias: float = 1.0):
        super().__init__()
        self.use_attention = use_attention
        self.resnet_blocks = ModuleList()
        self.attention_blocks = ModuleList()
        self.perform_upsampling = perform_upsampling
        for i in range(num_resnet_blocks):
            self.resnet_blocks.append(ResBlock(
                in_channels=(in_channels if i == 0 else out_channels) + skip_channels[i],
                out_channels=out_channels,
                cond0_channels=cond0_channels,
                cond1_channels=cond1_channels,
                dropout_prob=dropout_prob,
                condition_bias=condition_bias))
            if use_attention:
                self.attention_blocks.append(AttentionBlock(out_channels, attention_block_args))
        if perform_upsampling:
            if resample_with_res_block:
                self.upsample = ResBlock(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    cond0_channels=cond0_channels,
                    cond1_channels=cond1_channels,
                    sampling_mode=SamplingMode.UPSAMPLING,
                    dropout_prob=dropout_prob,
                    condition_bias=condition_bias)
            else:
                self.upsample = Arity3To1(Upsample(out_channels, use_conv=use_conv_to_resample))

    def forward(self, h: Tensor, skips: List[Tensor], cond0: Optional[Tensor] = None,
                cond1: Optional[Tensor] = None) -> Tensor:
        for i in range(len(self.resnet_blocks)):
            h = self.resnet_blocks[i].forward(torch.concat([h, skips[i]], dim=1), cond0, cond1)
            if self.use_attention:
                h = self.attention_blocks[i].forward(h)
        if self.perform_upsampling:
            h = self.upsample.forward(h, cond0, cond1)
        return h


def compute_timestep_embedding(t: Tensor, out_channels: int):
    assert len(t.shape) == 2
    b, c = t.shape
    assert c == 1
    half_channels = out_channels // 2
    scale = -math.log(10000.0) / (half_channels - 1)
    log_times = scale * torch.arange(0, half_channels, device=t.device)
    times = torch.exp(log_times).reshape(1, half_channels) * t
    t_emb = torch.cat([torch.cos(times), torch.sin(times)], dim=1)
    if out_channels % 2 == 1:
        t_emb = torch.nn.functional.pad(t_emb, (0, 1), mode='constant')
    return t_emb


class TimeEmbedding(Module):
    def __init__(self, out_channels: int):
        super().__init__()
        self.out_channels = out_channels

    def forward(self, t: Tensor):
        return compute_timestep_embedding(t, self.out_channels)


class UnetArgs:
    def __init__(self,
                 in_channels: int = 3,
                 out_channels: int = 3,
                 model_channels: int = 64,
                 level_channel_multipliers: Optional[List[int]] = None,
                 level_use_attention: Optional[List[bool]] = None,
                 num_res_blocks_per_level: int = 2,
                 num_middle_res_blocks: int = 2,
                 time_embedding_channels: Optional[int] = None,
                 cond_input_channels: int = 4,
                 cond_internal_channels: int = 512,
                 attention_block_args: Optional[AttentionBlockArgs] = None,
                 dropout_prob: float = 0.1,
                 resample_with_res_block: bool = True,
                 use_conv_to_resample=False,
                 condition_bias: float = 1.0):
        if time_embedding_channels is None:
            time_embedding_channels = model_channels
        if level_channel_multipliers is None:
            level_channel_multipliers = [1, 2, 4, 8]
        if level_use_attention is None:
            level_use_attention = [False for _ in level_channel_multipliers]
        if attention_block_args is None:
            attention_block_args = AttentionBlockArgs(
                num_heads=1,
                num_head_channels=None,
                use_new_attention_order=False)
        assert len(level_channel_multipliers) == len(level_use_attention)
        assert not use_conv_to_resample or not resample_with_res_block
        self.condition_bias = condition_bias
        self.use_conv_to_resample = use_conv_to_resample
        self.resample_with_res_block = resample_with_res_block
        self.cond_internal_channels = cond_internal_channels
        self.dropout_prob = dropout_prob
        self.attention_block_args = attention_block_args
        self.time_embedding_channels = time_embedding_channels
        self.num_res_blocks_per_level = num_res_blocks_per_level
        self.level_use_attention = level_use_attention
        self.level_channel_multipliers = level_channel_multipliers
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.in_channels = in_channels
        self.num_levels = len(level_channel_multipliers)
        self.num_middle_res_blocks = num_middle_res_blocks
        self.cond_input_channels = cond_input_channels


class Unet(Module):
    def __init__(self, args: UnetArgs):
        super().__init__()
        self.args = args

        self.time_embed = Sequential(
            TimeEmbedding(self.args.time_embedding_channels),
            Linear(self.args.time_embedding_channels, self.args.cond_internal_channels),
            SiLU(),
            Linear(self.args.cond_internal_channels, self.args.cond_internal_channels))
        self.cond_embed = Sequential(
            Linear(self.args.cond_input_channels, self.args.cond_internal_channels),
            SiLU(),
            Linear(self.args.cond_internal_channels, self.args.cond_internal_channels))

        self.first_conv = Conv2d(args.in_channels, args.model_channels, kernel_size=3, stride=1, padding=1)
        current_channels = args.model_channels
        channels = [current_channels]

        # Downsampling blocks
        self.down_blocks = ModuleList()
        for i in range(args.num_levels):
            out_channels = args.model_channels * args.level_channel_multipliers[i]
            perform_downsampling = i < args.num_levels - 1
            down_block = DownsamplingBlock(
                in_channels=current_channels,
                out_channels=out_channels,
                cond0_channels=args.cond_internal_channels,
                cond1_channels=args.cond_internal_channels,
                num_res_blocks=args.num_res_blocks_per_level,
                dropout_prob=args.dropout_prob,
                use_attention=args.level_use_attention[i],
                perform_downsampling=perform_downsampling,
                attention_block_args=args.attention_block_args,
                resample_with_res_block=args.resample_with_res_block,
                use_conv_to_resample=args.use_conv_to_resample,
                condition_bias=args.condition_bias)
            self.down_blocks.append(down_block)
            current_channels = out_channels
            channels += down_block.output_channels

        # Middle blocks
        self.middle_blocks = ModuleList()
        for i in range(self.args.num_middle_res_blocks - 1):
            self.middle_blocks.append(ResBlock(
                in_channels=current_channels,
                out_channels=current_channels,
                cond0_channels=args.cond_internal_channels,
                cond1_channels=args.cond_internal_channels,
                dropout_prob=args.dropout_prob,
                condition_bias=args.condition_bias))
            self.middle_blocks.append(
                Arity3To1(AttentionBlock(num_channels=current_channels, args=args.attention_block_args)))
        self.middle_blocks.append(ResBlock(
            in_channels=current_channels,
            out_channels=current_channels,
            cond0_channels=args.cond_internal_channels,
            cond1_channels=args.cond_internal_channels,
            dropout_prob=args.dropout_prob,
            condition_bias=args.condition_bias))

        # Upsampling blocks
        self.up_blocks = ModuleList()
        for i in reversed(range(args.num_levels)):
            skip_channels = []
            for j in range(args.num_res_blocks_per_level + 1):
                skip_channels.append(channels.pop())
            perform_upsampling = i > 0
            out_channels = args.model_channels * args.level_channel_multipliers[i]
            up_block = UpsamplingBlock(
                in_channels=current_channels,
                out_channels=out_channels,
                cond0_channels=args.cond_internal_channels,
                cond1_channels=args.cond_internal_channels,
                num_resnet_blocks=args.num_res_blocks_per_level + 1,
                skip_channels=skip_channels,
                dropout_prob=args.dropout_prob,
                use_attention=args.level_use_attention[i],
                perform_upsampling=perform_upsampling,
                attention_block_args=args.attention_block_args,
                resample_with_res_block=args.resample_with_res_block,
                use_conv_to_resample=args.use_conv_to_resample,
                condition_bias=args.condition_bias)
            self.up_blocks.append(up_block)
            current_channels = out_channels
        assert len(channels) == 0

        self.last = Sequential(
            GroupNorm32(current_channels),
            SiLU(),
            init_to_zero(Conv2d(current_channels, args.out_channels, kernel_size=3, stride=1, padding=1)))

    def forward(self, x: Tensor, t: Tensor, cond: Tensor):
        t_emb = self.time_embed(t)
        cond_emb = self.cond_embed(cond)
        hs = [self.first_conv(x)]
        for block in self.down_blocks:
            hs += block.forward(hs[-1], t_emb, cond_emb)
        h = hs[-1]
        for block in self.middle_blocks:
            h = block(h, t_emb, cond_emb)
        for block in self.up_blocks:
            skips = []
            for i in range(self.args.num_res_blocks_per_level + 1):
                skips.append(hs.pop())
            h = block.forward(h, skips, t_emb, cond_emb)
        assert len(hs) == 0
        return self.last(h)


class UnetWithFirstConvAddition(Module):
    def __init__(self, args: UnetArgs):
        super().__init__()
        self.args = args

        self.time_embed = Sequential(
            TimeEmbedding(self.args.time_embedding_channels),
            Linear(self.args.time_embedding_channels, self.args.cond_internal_channels),
            SiLU(),
            Linear(self.args.cond_internal_channels, self.args.cond_internal_channels))
        self.cond_embed = Sequential(
            Linear(self.args.cond_input_channels, self.args.cond_internal_channels),
            SiLU(),
            Linear(self.args.cond_internal_channels, self.args.cond_internal_channels))

        self.first_conv = Conv2d(args.in_channels, args.model_channels, kernel_size=3, stride=1, padding=1)
        current_channels = args.model_channels
        channels = [current_channels]

        # Downsampling blocks
        self.down_blocks = ModuleList()
        for i in range(args.num_levels):
            out_channels = args.model_channels * args.level_channel_multipliers[i]
            perform_downsampling = i < args.num_levels - 1
            down_block = DownsamplingBlock(
                in_channels=current_channels,
                out_channels=out_channels,
                cond0_channels=args.cond_internal_channels,
                cond1_channels=args.cond_internal_channels,
                num_res_blocks=args.num_res_blocks_per_level,
                dropout_prob=args.dropout_prob,
                use_attention=args.level_use_attention[i],
                perform_downsampling=perform_downsampling,
                attention_block_args=args.attention_block_args,
                resample_with_res_block=args.resample_with_res_block,
                use_conv_to_resample=args.use_conv_to_resample,
                condition_bias=args.condition_bias)
            self.down_blocks.append(down_block)
            current_channels = out_channels
            channels += down_block.output_channels

        # Middle blocks
        self.middle_blocks = ModuleList()
        for i in range(self.args.num_middle_res_blocks - 1):
            self.middle_blocks.append(ResBlock(
                in_channels=current_channels,
                out_channels=current_channels,
                cond0_channels=args.cond_internal_channels,
                cond1_channels=args.cond_internal_channels,
                dropout_prob=args.dropout_prob,
                condition_bias=args.condition_bias))
            self.middle_blocks.append(
                Arity3To1(AttentionBlock(num_channels=current_channels, args=args.attention_block_args)))
        self.middle_blocks.append(ResBlock(
            in_channels=current_channels,
            out_channels=current_channels,
            cond0_channels=args.cond_internal_channels,
            cond1_channels=args.cond_internal_channels,
            dropout_prob=args.dropout_prob,
            condition_bias=args.condition_bias))

        # Upsampling blocks
        self.up_blocks = ModuleList()
        for i in reversed(range(args.num_levels)):
            skip_channels = []
            for j in range(args.num_res_blocks_per_level + 1):
                skip_channels.append(channels.pop())
            perform_upsampling = i > 0
            out_channels = args.model_channels * args.level_channel_multipliers[i]
            up_block = UpsamplingBlock(
                in_channels=current_channels,
                out_channels=out_channels,
                cond0_channels=args.cond_internal_channels,
                cond1_channels=args.cond_internal_channels,
                num_resnet_blocks=args.num_res_blocks_per_level + 1,
                skip_channels=skip_channels,
                dropout_prob=args.dropout_prob,
                use_attention=args.level_use_attention[i],
                perform_upsampling=perform_upsampling,
                attention_block_args=args.attention_block_args,
                resample_with_res_block=args.resample_with_res_block,
                use_conv_to_resample=args.use_conv_to_resample,
                condition_bias=args.condition_bias)
            self.up_blocks.append(up_block)
            current_channels = out_channels
        assert len(channels) == 0

        self.last = Sequential(
            GroupNorm32(current_channels),
            SiLU(),
            init_to_zero(Conv2d(current_channels, args.out_channels, kernel_size=3, stride=1, padding=1)))

    def forward(self, x: Tensor, t: Tensor, cond: Tensor, first_conv_addition: Tensor):
        t_emb = self.time_embed(t)
        cond_emb = self.cond_embed(cond)
        first_conv = self.first_conv(x)
        hs = [first_conv + first_conv_addition]
        for block in self.down_blocks:
            hs += block.forward(hs[-1], t_emb, cond_emb)
        h = hs[-1]
        for block in self.middle_blocks:
            h = block(h, t_emb, cond_emb)
        for block in self.up_blocks:
            skips = []
            for i in range(self.args.num_res_blocks_per_level + 1):
                skips.append(hs.pop())
            h = block.forward(h, skips, t_emb, cond_emb)
        assert len(hs) == 0
        return self.last(h)
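compute_timestep_embedding is the standard sinusoidal embedding from diffusion models: the half_channels frequencies form a geometric progression from 1 down to 1/10000, and the embedding concatenates cosines and sines. Note that morpher_00.py below drives this U-Net with a constant t = 0, so the time pathway collapses to a learned constant bias. A worked shape check of the even-channel case, written independently of the repository code:

import math
import torch

def timestep_embedding(t: torch.Tensor, out_channels: int) -> torch.Tensor:
    # Mirrors compute_timestep_embedding above for even out_channels.
    half = out_channels // 2
    scale = -math.log(10000.0) / (half - 1)
    freqs = torch.exp(scale * torch.arange(half)).reshape(1, half)
    times = freqs * t
    return torch.cat([torch.cos(times), torch.sin(times)], dim=1)

emb = timestep_embedding(torch.zeros(2, 1), 8)
print(emb.shape)  # torch.Size([2, 8])
print(emb[0])     # t = 0: all cosines are 1, all sines are 0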
================================================
FILE: src/tha4/nn/conv.py
================================================
from typing import Optional, Union, Callable

from torch.nn import Conv2d, Module, Sequential, ConvTranspose2d

from tha4.shion.core.module_factory import ModuleFactory
from tha4.nn.nonlinearity_factory import resolve_nonlinearity_factory
from tha4.nn.normalization import NormalizationLayerFactory
from tha4.nn.util import wrap_conv_or_linear_module, BlockArgs


def create_conv7(in_channels: int, out_channels: int, bias: bool = False,
                 initialization_method: Union[str, Callable[[Module], Module]] = 'he',
                 use_spectral_norm: bool = False) -> Module:
    return wrap_conv_or_linear_module(
        Conv2d(in_channels, out_channels, kernel_size=7, stride=1, padding=3, bias=bias),
        initialization_method,
        use_spectral_norm)


def create_conv7_from_block_args(in_channels: int, out_channels: int, bias: bool = False,
                                 block_args: Optional[BlockArgs] = None) -> Module:
    if block_args is None:
        block_args = BlockArgs()
    return create_conv7(
        in_channels, out_channels, bias,
        block_args.initialization_method,
        block_args.use_spectral_norm)


def create_conv3(in_channels: int, out_channels: int, bias: bool = False,
                 initialization_method: Union[str, Callable[[Module], Module]] = 'he',
                 use_spectral_norm: bool = False) -> Module:
    return wrap_conv_or_linear_module(
        Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=bias),
        initialization_method,
        use_spectral_norm)


def create_conv3_from_block_args(in_channels: int, out_channels: int, bias: bool = False,
                                 block_args: Optional[BlockArgs] = None):
    if block_args is None:
        block_args = BlockArgs()
    return create_conv3(in_channels, out_channels, bias,
                        block_args.initialization_method,
                        block_args.use_spectral_norm)


def create_conv1(in_channels: int, out_channels: int,
                 initialization_method: Union[str, Callable[[Module], Module]] = 'he',
                 bias: bool = False,
                 use_spectral_norm: bool = False) -> Module:
    return wrap_conv_or_linear_module(
        Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
        initialization_method,
        use_spectral_norm)


def create_conv1_from_block_args(in_channels: int, out_channels: int, bias: bool = False,
                                 block_args: Optional[BlockArgs] = None) -> Module:
    if block_args is None:
        block_args = BlockArgs()
    return create_conv1(
        in_channels=in_channels,
        out_channels=out_channels,
        initialization_method=block_args.initialization_method,
        bias=bias,
        use_spectral_norm=block_args.use_spectral_norm)


def create_conv7_block(in_channels: int, out_channels: int,
                       initialization_method: Union[str, Callable[[Module], Module]] = 'he',
                       nonlinearity_factory: Optional[ModuleFactory] = None,
                       normalization_layer_factory: Optional[NormalizationLayerFactory] = None,
                       use_spectral_norm: bool = False) -> Module:
    nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)
    return Sequential(
        create_conv7(in_channels, out_channels, bias=False,
                     initialization_method=initialization_method,
                     use_spectral_norm=use_spectral_norm),
        NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(out_channels, affine=True),
        resolve_nonlinearity_factory(nonlinearity_factory).create())


def create_conv7_block_from_block_args(
        in_channels: int, out_channels: int, block_args: Optional[BlockArgs] = None) -> Module:
    if block_args is None:
        block_args = BlockArgs()
    return create_conv7_block(in_channels, out_channels,
                              block_args.initialization_method,
                              block_args.nonlinearity_factory,
                              block_args.normalization_layer_factory,
                              block_args.use_spectral_norm)


def create_conv3_block(in_channels: int, out_channels: int,
                       initialization_method: Union[str, Callable[[Module], Module]] = 'he',
                       nonlinearity_factory: Optional[ModuleFactory] = None,
                       normalization_layer_factory: Optional[NormalizationLayerFactory] = None,
                       use_spectral_norm: bool = False) -> Module:
    nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)
    return Sequential(
        create_conv3(in_channels, out_channels, bias=False,
                     initialization_method=initialization_method,
                     use_spectral_norm=use_spectral_norm),
        NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(out_channels, affine=True),
        resolve_nonlinearity_factory(nonlinearity_factory).create())


def create_conv3_block_from_block_args(
        in_channels: int, out_channels: int, block_args: Optional[BlockArgs] = None):
    if block_args is None:
        block_args = BlockArgs()
    return create_conv3_block(in_channels, out_channels,
                              block_args.initialization_method,
                              block_args.nonlinearity_factory,
                              block_args.normalization_layer_factory,
                              block_args.use_spectral_norm)


def create_downsample_block(in_channels: int, out_channels: int, is_output_1x1: bool = False,
                            initialization_method: Union[str, Callable[[Module], Module]] = 'he',
                            nonlinearity_factory: Optional[ModuleFactory] = None,
                            normalization_layer_factory: Optional[NormalizationLayerFactory] = None,
                            use_spectral_norm: bool = False) -> Module:
    if is_output_1x1:
        return Sequential(
            wrap_conv_or_linear_module(
                Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
                initialization_method,
                use_spectral_norm),
            resolve_nonlinearity_factory(nonlinearity_factory).create())
    else:
        return Sequential(
            wrap_conv_or_linear_module(
                Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
                initialization_method,
                use_spectral_norm),
            NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(out_channels, affine=True),
            resolve_nonlinearity_factory(nonlinearity_factory).create())


def create_downsample_block_from_block_args(in_channels: int, out_channels: int, is_output_1x1: bool = False,
                                            block_args: Optional[BlockArgs] = None):
    if block_args is None:
        block_args = BlockArgs()
    return create_downsample_block(
        in_channels, out_channels, is_output_1x1,
        block_args.initialization_method,
        block_args.nonlinearity_factory,
        block_args.normalization_layer_factory,
        block_args.use_spectral_norm)


def create_upsample_block(in_channels: int, out_channels: int,
                          initialization_method: Union[str, Callable[[Module], Module]] = 'he',
                          nonlinearity_factory: Optional[ModuleFactory] = None,
                          normalization_layer_factory: Optional[NormalizationLayerFactory] = None,
                          use_spectral_norm: bool = False) -> Module:
    nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)
    return Sequential(
        wrap_conv_or_linear_module(
            ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            initialization_method,
            use_spectral_norm),
        NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(out_channels, affine=True),
        resolve_nonlinearity_factory(nonlinearity_factory).create())


def create_upsample_block_from_block_args(in_channels: int, out_channels: int,
                                          block_args: Optional[BlockArgs] = None) -> Module:
    if block_args is None:
        block_args = BlockArgs()
    return create_upsample_block(in_channels, out_channels,
                                 block_args.initialization_method,
                                 block_args.nonlinearity_factory,
                                 block_args.normalization_layer_factory,
                                 block_args.use_spectral_norm)
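Each create_*_block helper produces a Sequential of convolution, normalization, and nonlinearity; the downsample variant uses a kernel-4, stride-2, padding-1 convolution that exactly halves the resolution. A quick shape check (default BlockArgs assumed constructible, as the helpers above do themselves):

import torch
from tha4.nn.conv import create_conv3_block_from_block_args, create_downsample_block_from_block_args

x = torch.rand(1, 4, 64, 64)
same_res = create_conv3_block_from_block_args(4, 16)
half_res = create_downsample_block_from_block_args(16, 32)
print(same_res(x).shape)            # torch.Size([1, 16, 64, 64])
print(half_res(same_res(x)).shape)  # torch.Size([1, 32, 32, 32])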
================================================
FILE: src/tha4/nn/eyebrow_decomposer/__init__.py
================================================


================================================
FILE: src/tha4/nn/eyebrow_decomposer/eyebrow_decomposer_00.py
================================================
from typing import List, Optional

import torch
from torch import Tensor
from torch.nn import Module

from tha4.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00Args, PoserEncoderDecoder00
from tha4.nn.image_processing_util import apply_color_change
from tha4.shion.core.module_factory import ModuleFactory
from tha4.nn.nonlinearity_factory import ReLUFactory
from tha4.nn.normalization import InstanceNorm2dFactory
from tha4.nn.util import BlockArgs


class EyebrowDecomposer00Args(PoserEncoderDecoder00Args):
    def __init__(self,
                 image_size: int = 128,
                 image_channels: int = 4,
                 start_channels: int = 64,
                 bottleneck_image_size=16,
                 num_bottleneck_blocks=6,
                 max_channels: int = 512,
                 block_args: Optional[BlockArgs] = None):
        super().__init__(
            image_size,
            image_channels,
            image_channels,
            0,
            start_channels,
            bottleneck_image_size,
            num_bottleneck_blocks,
            max_channels,
            block_args)


class EyebrowDecomposer00(Module):
    def __init__(self, args: EyebrowDecomposer00Args):
        super().__init__()
        self.args = args
        self.body = PoserEncoderDecoder00(args)
        self.background_layer_alpha = self.args.create_alpha_block()
        self.background_layer_color_change = self.args.create_color_change_block()
        self.eyebrow_layer_alpha = self.args.create_alpha_block()
        self.eyebrow_layer_color_change = self.args.create_color_change_block()

    def forward(self, image: Tensor, *args) -> List[Tensor]:
        feature = self.body(image)[0]

        background_layer_alpha = self.background_layer_alpha(feature)
        background_layer_color_change = self.background_layer_color_change(feature)
        background_layer_1 = apply_color_change(background_layer_alpha, background_layer_color_change, image)

        eyebrow_layer_alpha = self.eyebrow_layer_alpha(feature)
        eyebrow_layer_color_change = self.eyebrow_layer_color_change(feature)
        eyebrow_layer = apply_color_change(eyebrow_layer_alpha, image, eyebrow_layer_color_change)

        return [
            eyebrow_layer,  # 0
            eyebrow_layer_alpha,  # 1
            eyebrow_layer_color_change,  # 2
            background_layer_1,  # 3
            background_layer_alpha,  # 4
            background_layer_color_change,  # 5
        ]


EYEBROW_LAYER_INDEX = 0
EYEBROW_LAYER_ALPHA_INDEX = 1
EYEBROW_LAYER_COLOR_CHANGE_INDEX = 2
BACKGROUND_LAYER_INDEX = 3
BACKGROUND_LAYER_ALPHA_INDEX = 4
BACKGROUND_LAYER_COLOR_CHANGE_INDEX = 5
OUTPUT_LENGTH = 6


class EyebrowDecomposer00Factory(ModuleFactory):
    def __init__(self, args: EyebrowDecomposer00Args):
        super().__init__()
        self.args = args

    def create(self) -> Module:
        return EyebrowDecomposer00(self.args)
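The module returns a plain list and publishes index constants instead of a structured result type, so downstream code selects outputs by constant. A hedged usage sketch (tensor sizes follow the defaults above; the [-1, 1] value range is suggested by the Tanh color-change heads, not stated explicitly here):

import torch
from tha4.nn.eyebrow_decomposer.eyebrow_decomposer_00 import (
    EyebrowDecomposer00, EyebrowDecomposer00Args, EYEBROW_LAYER_INDEX, BACKGROUND_LAYER_INDEX)

decomposer = EyebrowDecomposer00(EyebrowDecomposer00Args())
image = torch.rand(1, 4, 128, 128) * 2.0 - 1.0  # RGBA, assumed to lie in [-1, 1]
outputs = decomposer(image)
eyebrow_layer = outputs[EYEBROW_LAYER_INDEX]
background = outputs[BACKGROUND_LAYER_INDEX]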
================================================
FILE: src/tha4/nn/eyebrow_morphing_combiner/__init__.py
================================================


================================================
FILE: src/tha4/nn/eyebrow_morphing_combiner/eyebrow_morphing_combiner_00.py
================================================
from typing import List, Optional

import torch
from torch import Tensor
from torch.nn import Module

from tha4.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00Args, PoserEncoderDecoder00
from tha4.nn.image_processing_util import apply_color_change, apply_grid_change, apply_rgb_change
from tha4.shion.core.module_factory import ModuleFactory
from tha4.nn.nonlinearity_factory import ReLUFactory
from tha4.nn.normalization import InstanceNorm2dFactory
from tha4.nn.util import BlockArgs


class EyebrowMorphingCombiner00Args(PoserEncoderDecoder00Args):
    def __init__(self,
                 image_size: int = 128,
                 image_channels: int = 4,
                 num_pose_params: int = 12,
                 start_channels: int = 64,
                 bottleneck_image_size=16,
                 num_bottleneck_blocks=6,
                 max_channels: int = 512,
                 block_args: Optional[BlockArgs] = None):
        super().__init__(
            image_size,
            2 * image_channels,
            image_channels,
            num_pose_params,
            start_channels,
            bottleneck_image_size,
            num_bottleneck_blocks,
            max_channels,
            block_args)


class EyebrowMorphingCombiner00(Module):
    def __init__(self, args: EyebrowMorphingCombiner00Args):
        super().__init__()
        self.args = args
        self.body = PoserEncoderDecoder00(args)
        self.morphed_eyebrow_layer_grid_change = self.args.create_grid_change_block()
        self.morphed_eyebrow_layer_alpha = self.args.create_alpha_block()
        self.morphed_eyebrow_layer_color_change = self.args.create_color_change_block()
        self.combine_alpha = self.args.create_alpha_block()

    def forward(self, background_layer: Tensor, eyebrow_layer: Tensor, pose: Tensor, *args) -> List[Tensor]:
        combined_image = torch.cat([background_layer, eyebrow_layer], dim=1)
        feature = self.body(combined_image, pose)[0]

        morphed_eyebrow_layer_grid_change = self.morphed_eyebrow_layer_grid_change(feature)
        morphed_eyebrow_layer_alpha = self.morphed_eyebrow_layer_alpha(feature)
        morphed_eyebrow_layer_color_change = self.morphed_eyebrow_layer_color_change(feature)
        warped_eyebrow_layer = apply_grid_change(morphed_eyebrow_layer_grid_change, eyebrow_layer)
        morphed_eyebrow_layer = apply_color_change(
            morphed_eyebrow_layer_alpha, morphed_eyebrow_layer_color_change, warped_eyebrow_layer)

        combine_alpha = self.combine_alpha(feature)
        eyebrow_image = apply_rgb_change(combine_alpha, morphed_eyebrow_layer, background_layer)
        eyebrow_image_no_combine_alpha = apply_rgb_change(
            (morphed_eyebrow_layer[:, 3:4, :, :] + 1.0) / 2.0,
            morphed_eyebrow_layer,
            background_layer)

        return [
            eyebrow_image,  # 0
            combine_alpha,  # 1
            eyebrow_image_no_combine_alpha,  # 2
            morphed_eyebrow_layer,  # 3
            morphed_eyebrow_layer_alpha,  # 4
            morphed_eyebrow_layer_color_change,  # 5
            warped_eyebrow_layer,  # 6
            morphed_eyebrow_layer_grid_change,  # 7
        ]


EYEBROW_IMAGE_INDEX = 0
COMBINE_ALPHA_INDEX = 1
EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX = 2
MORPHED_EYEBROW_LAYER_INDEX = 3
MORPHED_EYEBROW_LAYER_ALPHA_INDEX = 4
MORPHED_EYEBROW_LAYER_COLOR_CHANGE_INDEX = 5
WARPED_EYEBROW_LAYER_INDEX = 6
MORPHED_EYEBROW_LAYER_GRID_CHANGE_INDEX = 7
OUTPUT_LENGTH = 8


class EyebrowMorphingCombiner00Factory(ModuleFactory):
    def __init__(self, args: EyebrowMorphingCombiner00Args):
        super().__init__()
        self.args = args

    def create(self) -> Module:
        return EyebrowMorphingCombiner00(self.args)
================================================
FILE: src/tha4/nn/face_morpher/__init__.py
================================================


================================================
FILE: src/tha4/nn/face_morpher/face_morpher_08.py
================================================
import math
from typing import List, Optional

import torch
from torch import Tensor
from torch.nn import ModuleList, Sequential, Sigmoid, Tanh, Module
from torch.nn.functional import affine_grid, grid_sample

from tha4.shion.core.module_factory import ModuleFactory
from tha4.nn.conv import create_conv3_block_from_block_args, \
    create_downsample_block_from_block_args, create_upsample_block_from_block_args, create_conv3_from_block_args, \
    create_conv3
from tha4.nn.nonlinearity_factory import LeakyReLUFactory
from tha4.nn.normalization import InstanceNorm2dFactory
from tha4.nn.resnet_block import ResnetBlock
from tha4.nn.util import BlockArgs


class FaceMorpher08Args:
    def __init__(self,
                 image_size: int = 256,
                 image_channels: int = 4,
                 num_expression_params: int = 67,
                 start_channels: int = 16,
                 bottleneck_image_size=4,
                 num_bottleneck_blocks=3,
                 max_channels: int = 512,
                 block_args: Optional[BlockArgs] = None,
                 output_iris_mouth_grid_change: bool = False):
        self.max_channels = max_channels
        self.num_bottleneck_blocks = num_bottleneck_blocks
        assert bottleneck_image_size > 1
        self.bottleneck_image_size = bottleneck_image_size
        self.start_channels = start_channels
        self.image_channels = image_channels
        self.num_expression_params = num_expression_params
        self.image_size = image_size
        self.output_iris_mouth_grid_change = output_iris_mouth_grid_change
        if block_args is None:
            self.block_args = BlockArgs(
                normalization_layer_factory=InstanceNorm2dFactory(),
                nonlinearity_factory=LeakyReLUFactory(negative_slope=0.2, inplace=True))
        else:
            self.block_args = block_args


class FaceMorpher08(Module):
    def __init__(self, args: FaceMorpher08Args):
        super().__init__()
        self.args = args
        self.num_levels = int(math.log2(args.image_size // args.bottleneck_image_size)) + 1

        self.downsample_blocks = ModuleList()
        self.downsample_blocks.append(
            create_conv3_block_from_block_args(
                args.image_channels, args.start_channels, args.block_args))
        current_image_size = args.image_size
        current_num_channels = args.start_channels
        while current_image_size > args.bottleneck_image_size:
            next_image_size = current_image_size // 2
            next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
            self.downsample_blocks.append(create_downsample_block_from_block_args(
                in_channels=current_num_channels,
                out_channels=next_num_channels,
                is_output_1x1=False,
                block_args=args.block_args))
            current_image_size = next_image_size
            current_num_channels = next_num_channels
        assert len(self.downsample_blocks) == self.num_levels

        self.bottleneck_blocks = ModuleList()
        self.bottleneck_blocks.append(create_conv3_block_from_block_args(
            in_channels=current_num_channels + args.num_expression_params,
            out_channels=current_num_channels,
            block_args=args.block_args))
        for i in range(1, args.num_bottleneck_blocks):
            self.bottleneck_blocks.append(
                ResnetBlock.create(
                    num_channels=current_num_channels,
                    is1x1=False,
                    block_args=args.block_args))

        self.upsample_blocks = ModuleList()
        while current_image_size < args.image_size:
            next_image_size = current_image_size * 2
            next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
            self.upsample_blocks.append(create_upsample_block_from_block_args(
                in_channels=current_num_channels,
                out_channels=next_num_channels,
                block_args=args.block_args))
            current_image_size = next_image_size
            current_num_channels = next_num_channels

        self.iris_mouth_grid_change = self.create_grid_change_block()
        self.iris_mouth_color_change = self.create_color_change_block()
        self.iris_mouth_alpha = self.create_alpha_block()
        self.eye_color_change = self.create_color_change_block()
        self.eye_alpha = self.create_alpha_block()

    def create_alpha_block(self):
        return Sequential(
            create_conv3(
                in_channels=self.args.start_channels,
                out_channels=1,
                bias=True,
                initialization_method=self.args.block_args.initialization_method,
                use_spectral_norm=False),
            Sigmoid())

    def create_color_change_block(self):
        return Sequential(
            create_conv3_from_block_args(
                in_channels=self.args.start_channels,
                out_channels=self.args.image_channels,
                bias=True,
                block_args=self.args.block_args),
            Tanh())

    def create_grid_change_block(self):
        return create_conv3(
            in_channels=self.args.start_channels,
            out_channels=2,
            bias=False,
            initialization_method='zero',
            use_spectral_norm=False)

    def get_num_output_channels_from_level(self, level: int):
        return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level))

    def get_num_output_channels_from_image_size(self, image_size: int):
        return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels)

    def merge_down(self, top_layer: Tensor, bottom_layer: Tensor):
        top_layer_rgb = top_layer[:, 0:3, :, :]
        top_layer_a = top_layer[:, 3:4, :, :]
        return bottom_layer * (1 - top_layer_a) + torch.cat([top_layer_rgb * top_layer_a, top_layer_a], dim=1)

    def apply_grid_change(self, grid_change, image: Tensor) -> Tensor:
        n, c, h, w = image.shape
        device = grid_change.device
        grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2)
        identity = torch.tensor(
            [[1.0, 0.0, 0.0],
             [0.0, 1.0, 0.0]],
            device=device,
            dtype=grid_change.dtype).unsqueeze(0).repeat(n, 1, 1)
        base_grid = affine_grid(identity, [n, c, h, w], align_corners=False)
        grid = base_grid + grid_change
        resampled_image = grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=False)
        return resampled_image

    def apply_color_change(self, alpha, color_change, image: Tensor) -> Tensor:
        return color_change * alpha + image * (1 - alpha)

    def forward(self, image: Tensor, pose: Tensor, *args) -> List[Tensor]:
        feature = image
        for block in self.downsample_blocks:
            feature = block(feature)
        n, c = pose.shape
        pose = pose.view(n, c, 1, 1).repeat(1, 1, self.args.bottleneck_image_size, self.args.bottleneck_image_size)
        feature = torch.cat([feature, pose], dim=1)
        for block in self.bottleneck_blocks:
            feature = block(feature)
        for block in self.upsample_blocks:
            feature = block(feature)

        iris_mouth_grid_change = self.iris_mouth_grid_change(feature)
        iris_mouth_image_0 = self.apply_grid_change(iris_mouth_grid_change, image)
        iris_mouth_color_change = self.iris_mouth_color_change(feature)
        iris_mouth_alpha = self.iris_mouth_alpha(feature)
        iris_mouth_image_1 = self.apply_color_change(iris_mouth_alpha, iris_mouth_color_change, iris_mouth_image_0)

        eye_color_change = self.eye_color_change(feature)
        eye_alpha = self.eye_alpha(feature)
        output_image = self.apply_color_change(eye_alpha, eye_color_change, iris_mouth_image_1.detach())

        outputs = [
            output_image,  # 0
            eye_alpha,  # 1
            eye_color_change,  # 2
            iris_mouth_image_1,  # 3
            iris_mouth_alpha,  # 4
            iris_mouth_color_change,  # 5
            iris_mouth_image_0,  # 6
        ]
        if self.args.output_iris_mouth_grid_change:
            outputs.append(iris_mouth_grid_change)
        return outputs


OUTPUT_IMAGE_INDEX = 0
EYE_ALPHA_INDEX = 1
EYE_COLOR_CHANGE_INDEX = 2
IRIS_MOUTH_IMAGE_1_INDEX = 3
IRIS_MOUTH_ALPHA_INDEX = 4
IRIS_MOUTH_COLOR_CHANGE_INDEX = 5
IRIS_MOUTH_IMAGE_0_INDEX = 6
IRIS_MOUTH_GRID_CHANGE_INDEX = 7


class FaceMorpher08Factory(ModuleFactory):
    def __init__(self, args: FaceMorpher08Args):
        super().__init__()
        self.args = args

    def create(self) -> Module:
        return FaceMorpher08(self.args)
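merge_down composites a straight-alpha RGBA top layer over a bottom layer; reading the arithmetic, the result (and the expected bottom layer) are in premultiplied-alpha form, which is the standard "over" operator. A standalone restatement with a tiny check (the premultiplied interpretation is my reading, not stated in the source):

import torch

def merge_down(top_layer: torch.Tensor, bottom_layer: torch.Tensor) -> torch.Tensor:
    # Same arithmetic as FaceMorpher08.merge_down.
    top_rgb = top_layer[:, 0:3, :, :]
    top_a = top_layer[:, 3:4, :, :]
    return bottom_layer * (1 - top_a) + torch.cat([top_rgb * top_a, top_a], dim=1)

opaque_red = torch.cat(
    [torch.ones(1, 1, 2, 2), torch.zeros(1, 2, 2, 2), torch.ones(1, 1, 2, 2)], dim=1)
transparent = torch.zeros(1, 4, 2, 2)
print(merge_down(opaque_red, transparent)[:, :, 0, 0])  # red, fully opaque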
================================================
FILE: src/tha4/nn/image_processing_util.py
================================================
import torch
from torch import Tensor
from torch.nn.functional import affine_grid, grid_sample


def apply_rgb_change(alpha: Tensor, color_change: Tensor, image: Tensor):
    image_rgb = image[:, 0:3, :, :]
    color_change_rgb = color_change[:, 0:3, :, :]
    output_rgb = color_change_rgb * alpha + image_rgb * (1 - alpha)
    return torch.cat([output_rgb, image[:, 3:4, :, :]], dim=1)


def apply_grid_change(grid_change, image: Tensor) -> Tensor:
    n, c, h, w = image.shape
    device = grid_change.device
    grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2)
    identity = torch.tensor(
        [[1.0, 0.0, 0.0],
         [0.0, 1.0, 0.0]],
        dtype=grid_change.dtype,
        device=device).unsqueeze(0).repeat(n, 1, 1)
    base_grid = affine_grid(identity, [n, c, h, w], align_corners=False)
    grid = base_grid + grid_change
    resampled_image = grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=False)
    return resampled_image


class GridChangeApplier:
    def __init__(self):
        self.last_n = None
        self.last_device = None
        self.last_identity = None

    def apply(self, grid_change: Tensor, image: Tensor, align_corners: bool = False) -> Tensor:
        n, c, h, w = image.shape
        device = grid_change.device
        grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2)
        if n == self.last_n and device == self.last_device:
            identity = self.last_identity
        else:
            identity = torch.tensor(
                [[1.0, 0.0, 0.0],
                 [0.0, 1.0, 0.0]],
                dtype=grid_change.dtype,
                device=device,
                requires_grad=False) \
                .unsqueeze(0).repeat(n, 1, 1)
            self.last_identity = identity
            self.last_n = n
            self.last_device = device
        base_grid = affine_grid(identity, [n, c, h, w], align_corners=align_corners)
        grid = base_grid + grid_change
        resampled_image = grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=align_corners)
        return resampled_image


def apply_color_change(alpha, color_change, image: Tensor) -> Tensor:
    return color_change * alpha + image * (1 - alpha)
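GridChangeApplier exists so real-time puppeteering code does not rebuild the identity affine matrix every frame: it caches the identity tensor keyed on batch size and device, and only the cheap affine_grid/grid_sample calls run per frame. A small demo that warps an image by a constant horizontal offset (the offset value is arbitrary):

import torch
from tha4.nn.image_processing_util import GridChangeApplier

applier = GridChangeApplier()
image = torch.rand(1, 4, 64, 64)
# Grid coordinates are normalized to [-1, 1], so an x-offset of 2/64 shifts
# the sampling grid by one pixel; sampling further to the right moves the
# image content to the left.
grid_change = torch.zeros(1, 2, 64, 64)
grid_change[:, 0, :, :] = 2.0 / 64.0
shifted = applier.apply(grid_change, image)
print(shifted.shape)  # torch.Size([1, 4, 64, 64])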
================================================
FILE: src/tha4/nn/init_function.py
================================================
from typing import Callable

import torch
from torch import zero_
from torch.nn import Module
from torch.nn.init import kaiming_normal_, xavier_normal_, normal_


def create_init_function(method: str = 'none') -> Callable[[Module], Module]:
    def init(module: Module):
        if method == 'none':
            return module
        elif method == 'he':
            kaiming_normal_(module.weight)
            return module
        elif method == 'xavier':
            xavier_normal_(module.weight)
            return module
        elif method == 'dcgan':
            normal_(module.weight, 0.0, 0.02)
            return module
        elif method == 'dcgan_001':
            normal_(module.weight, 0.0, 0.01)
            return module
        elif method == "zero":
            with torch.no_grad():
                zero_(module.weight)
            return module
        else:
            # Raising a bare string is invalid in Python 3; raise a proper exception.
            raise ValueError("Invalid initialization method %s" % method)

    return init


class HeInitialization:
    def __init__(self, a: int = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'):
        self.nonlinearity = nonlinearity
        self.mode = mode
        self.a = a

    def __call__(self, module: Module) -> Module:
        with torch.no_grad():
            kaiming_normal_(module.weight, a=self.a, mode=self.mode, nonlinearity=self.nonlinearity)
        return module


class NormalInitialization:
    def __init__(self, mean: float = 0.0, std: float = 1.0):
        self.std = std
        self.mean = mean

    def __call__(self, module: Module) -> Module:
        with torch.no_grad():
            normal_(module.weight, self.mean, self.std)
        return module


class XavierInitialization:
    def __init__(self, gain: float = 1.0):
        self.gain = gain

    def __call__(self, module: Module) -> Module:
        with torch.no_grad():
            xavier_normal_(module.weight, self.gain)
        return module


class ZeroInitialization:
    def __call__(self, module: Module) -> Module:
        # `torch.no_grad` must be called; the bare `with torch.no_grad:` form raises at runtime.
        with torch.no_grad():
            zero_(module.weight)
        return module


class NoInitialization:
    def __call__(self, module: Module) -> Module:
        return module
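A quick illustration of the closure-based initializer above (a minimal sketch, assuming the repo is importable): the 'zero' method zeroes a layer's weight in place, which is how heads that should start as a no-op, such as the grid-change head of FaceMorpher08, are initialized.

    import torch
    from torch.nn import Conv2d

    from tha4.nn.init_function import create_init_function

    # create_init_function returns a callable that initializes a module's weight
    # in place and returns the module, so it composes with module construction.
    init = create_init_function('zero')
    conv = init(Conv2d(8, 2, kernel_size=3, padding=1, bias=False))
    print(conv.weight.abs().sum().item())  # expected: 0.0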
================================================
FILE: src/tha4/nn/morpher/__init__.py
================================================


================================================
FILE: src/tha4/nn/morpher/morpher_00.py
================================================
from typing import List

import torch
from torch import Tensor
from torch.nn import Module

from tha4.shion.core.module_factory import ModuleFactory
from tha4.nn.image_processing_util import GridChangeApplier
from tha4.nn.common.unet import UnetArgs, Unet, AttentionBlockArgs


def apply_color_change(alpha, color_change, image: Tensor) -> Tensor:
    return color_change * alpha + image * (1 - alpha)


class Morpher00Args:
    def __init__(self,
                 image_size: int,
                 image_channels: int,
                 num_pose_parameters: int,
                 unet_args: UnetArgs):
        assert unet_args.in_channels == image_channels
        assert unet_args.out_channels == (
                image_channels +  # direct
                2 +  # warp
                1  # alpha
        )
        assert unet_args.cond_input_channels == num_pose_parameters
        self.image_channels = image_channels
        self.image_size = image_size
        self.num_pose_parameters = num_pose_parameters
        self.unet_args = unet_args


class Morpher00(Module):
    def __init__(self, args: Morpher00Args):
        super().__init__()
        self.args = args
        self.body = Unet(args.unet_args)
        self.grid_change_applier = GridChangeApplier()

    def forward(self, image: torch.Tensor, pose: torch.Tensor) -> List[Tensor]:
        assert len(image.shape) == 4
        assert image.shape[1] == self.args.image_channels
        assert image.shape[2] == self.args.image_size
        assert image.shape[3] == self.args.image_size
        assert len(pose.shape) == 2
        assert image.shape[0] == pose.shape[0]
        assert pose.shape[1] == self.args.num_pose_parameters

        t = torch.zeros(image.shape[0], 1, device=image.device)
        body_output = self.body(image, t, pose)
        direct = body_output[:, 0:self.args.image_channels, :, :]
        grid_change = body_output[:, self.args.image_channels:self.args.image_channels + 2, :, :]
        alpha = torch.sigmoid(body_output[:, self.args.image_channels + 2:self.args.image_channels + 3, :, :])
        warped = self.grid_change_applier.apply(grid_change, image)
        merged = apply_color_change(alpha, direct, warped)
        return [
            merged,
            alpha,
            warped,
            grid_change,
            direct
        ]


INDEX_MERGED = 0
INDEX_ALPHA = 1
INDEX_WARPED = 2
INDEX_GRID_CHANGE = 3
INDEX_DIRECT = 4


class Morpher00Factory(ModuleFactory):
    def __init__(self, args: Morpher00Args):
        self.args = args

    def create(self) -> Module:
        return Morpher00(self.args)


================================================
FILE: src/tha4/nn/nonlinearity_factory.py
================================================
from typing import Optional

from torch.nn import Module, ReLU, LeakyReLU, ELU, ReLU6, Hardswish, SiLU, Tanh, Sigmoid

from tha4.shion.core.module_factory import ModuleFactory


class ReLUFactory(ModuleFactory):
    def __init__(self, inplace: bool = False):
        self.inplace = inplace

    def create(self) -> Module:
        return ReLU(self.inplace)


class LeakyReLUFactory(ModuleFactory):
    def __init__(self, inplace: bool = False, negative_slope: float = 1e-2):
        self.negative_slope = negative_slope
        self.inplace = inplace

    def create(self) -> Module:
        return LeakyReLU(inplace=self.inplace, negative_slope=self.negative_slope)


class ELUFactory(ModuleFactory):
    def __init__(self, inplace: bool = False, alpha: float = 1.0):
        self.alpha = alpha
        self.inplace = inplace

    def create(self) -> Module:
        return ELU(inplace=self.inplace, alpha=self.alpha)


class ReLU6Factory(ModuleFactory):
    def __init__(self, inplace: bool = False):
        self.inplace = inplace

    def create(self) -> Module:
        return ReLU6(inplace=self.inplace)


class SiLUFactory(ModuleFactory):
    def __init__(self, inplace: bool = False):
        self.inplace = inplace

    def create(self) -> Module:
        return SiLU(inplace=self.inplace)


class HardswishFactory(ModuleFactory):
    def __init__(self, inplace: bool = False):
        self.inplace = inplace

    def create(self) -> Module:
        return Hardswish(inplace=self.inplace)


class TanhFactory(ModuleFactory):
    def create(self) -> Module:
        return Tanh()


class SigmoidFactory(ModuleFactory):
    def create(self) -> Module:
        return Sigmoid()


def resolve_nonlinearity_factory(nonlinearity_factory: Optional[ModuleFactory]) -> ModuleFactory:
    if nonlinearity_factory is None:
        return ReLUFactory(inplace=False)
    else:
        return nonlinearity_factory
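A small usage sketch of the factory indirection above (assuming the repo is importable): passing None falls back to a plain ReLU, while an explicit factory passes through and stamps out fresh Module instances on demand.

    from tha4.nn.nonlinearity_factory import LeakyReLUFactory, resolve_nonlinearity_factory

    # None resolves to the default ReLUFactory(inplace=False).
    default_factory = resolve_nonlinearity_factory(None)
    print(type(default_factory).__name__)  # ReLUFactory

    # An explicit factory is returned unchanged; create() builds a new module each call.
    leaky = resolve_nonlinearity_factory(LeakyReLUFactory(negative_slope=0.2))
    print(leaky.create())                  # LeakyReLU(negative_slope=0.2)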
================================================
FILE: src/tha4/nn/normalization.py
================================================
from abc import ABC, abstractmethod
from typing import Optional

import torch
from torch import layer_norm
from torch.nn import Module, BatchNorm2d, InstanceNorm2d, Parameter
from torch.nn.init import normal_, constant_

from tha4.nn.pass_through import PassThrough


class PixelNormalization(Module):
    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        return x / torch.sqrt((x ** 2).mean(dim=1, keepdim=True) + self.epsilon)


class NormalizationLayerFactory(ABC):
    def __init__(self):
        super().__init__()

    @abstractmethod
    def create(self, num_features: int, affine: bool = True) -> Module:
        pass

    @staticmethod
    def resolve_2d(factory: Optional['NormalizationLayerFactory']) -> 'NormalizationLayerFactory':
        if factory is None:
            return InstanceNorm2dFactory()
        else:
            return factory


class Bias2d(Module):
    def __init__(self, num_features: int):
        super().__init__()
        self.num_features = num_features
        self.bias = Parameter(torch.zeros(1, num_features, 1, 1))

    def forward(self, x):
        return x + self.bias


class NoNorm2dFactory(NormalizationLayerFactory):
    def __init__(self):
        super().__init__()

    def create(self, num_features: int, affine: bool = True) -> Module:
        if affine:
            return Bias2d(num_features)
        else:
            return PassThrough()


class BatchNorm2dFactory(NormalizationLayerFactory):
    def __init__(self,
                 weight_mean: Optional[float] = None,
                 weight_std: Optional[float] = None,
                 bias: Optional[float] = None):
        super().__init__()
        self.bias = bias
        self.weight_std = weight_std
        self.weight_mean = weight_mean

    def get_weight_mean(self):
        if self.weight_mean is None:
            return 1.0
        else:
            return self.weight_mean

    def get_weight_std(self):
        if self.weight_std is None:
            return 0.02
        else:
            return self.weight_std

    def create(self, num_features: int, affine: bool = True) -> Module:
        module = BatchNorm2d(num_features=num_features, affine=affine)
        if affine:
            if self.weight_mean is not None or self.weight_std is not None:
                normal_(module.weight, self.get_weight_mean(), self.get_weight_std())
            if self.bias is not None:
                constant_(module.bias, self.bias)
        return module


class InstanceNorm2dFactory(NormalizationLayerFactory):
    def __init__(self):
        super().__init__()

    def create(self, num_features: int, affine: bool = True) -> Module:
        return InstanceNorm2d(num_features=num_features, affine=affine)


class PixelNormFactory(NormalizationLayerFactory):
    def __init__(self):
        super().__init__()

    def create(self, num_features: int, affine: bool = True) -> Module:
        return PixelNormalization()


class LayerNorm2d(Module):
    def __init__(self, channels: int, affine: bool = True):
        super(LayerNorm2d, self).__init__()
        self.channels = channels
        self.affine = affine
        if self.affine:
            self.weight = Parameter(torch.ones(1, channels, 1, 1))
            self.bias = Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x):
        shape = x.size()[1:]
        y = layer_norm(x, shape)
        # Guard the affine transform: weight/bias only exist when affine=True.
        if self.affine:
            y = y * self.weight + self.bias
        return y


class LayerNorm2dFactory(NormalizationLayerFactory):
    def __init__(self):
        super().__init__()

    def create(self, num_features: int, affine: bool = True) -> Module:
        return LayerNorm2d(channels=num_features, affine=affine)
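A minimal sanity check for the normalization layers above (hedged sketch, repo on PYTHONPATH): PixelNormalization rescales each pixel's feature vector to approximately unit RMS across channels, and the factories build layers from a feature count so block builders stay generic.

    import torch

    from tha4.nn.normalization import PixelNormalization, LayerNorm2dFactory

    x = torch.randn(2, 16, 8, 8)

    # Per-pixel RMS over the channel dimension becomes ~1 after PixelNormalization.
    y = PixelNormalization()(x)
    print(y.pow(2).mean(dim=1).sqrt().mean().item())  # ~1.0

    # Factories are resolved by NormalizationLayerFactory.resolve_2d in the blocks below.
    norm = LayerNorm2dFactory().create(num_features=16, affine=True)
    print(norm(x).shape)  # torch.Size([2, 16, 8, 8])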
================================================
FILE: src/tha4/nn/pass_through.py
================================================
from torch.nn import Module


class PassThrough(Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


================================================
FILE: src/tha4/nn/resnet_block.py
================================================
from typing import Optional

import torch
from torch.nn import Module, Sequential, Parameter

from tha4.shion.core.module_factory import ModuleFactory
from tha4.nn.conv import create_conv1, create_conv3
from tha4.nn.nonlinearity_factory import resolve_nonlinearity_factory
from tha4.nn.normalization import NormalizationLayerFactory
from tha4.nn.util import BlockArgs


class ResnetBlock(Module):
    @staticmethod
    def create(num_channels: int,
               is1x1: bool = False,
               use_scale_parameters: bool = False,
               block_args: Optional[BlockArgs] = None):
        if block_args is None:
            block_args = BlockArgs()
        return ResnetBlock(num_channels,
                           is1x1,
                           block_args.initialization_method,
                           block_args.nonlinearity_factory,
                           block_args.normalization_layer_factory,
                           block_args.use_spectral_norm,
                           use_scale_parameters)

    def __init__(self,
                 num_channels: int,
                 is1x1: bool = False,
                 initialization_method: str = 'he',
                 nonlinearity_factory: ModuleFactory = None,
                 normalization_layer_factory: Optional[NormalizationLayerFactory] = None,
                 use_spectral_norm: bool = False,
                 use_scale_parameter: bool = False):
        super().__init__()
        self.use_scale_parameter = use_scale_parameter
        if self.use_scale_parameter:
            self.scale = Parameter(torch.zeros(1))
        nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)
        if is1x1:
            self.resnet_path = Sequential(
                create_conv1(num_channels, num_channels, initialization_method,
                             bias=True,
                             use_spectral_norm=use_spectral_norm),
                nonlinearity_factory.create(),
                create_conv1(num_channels, num_channels, initialization_method,
                             bias=True,
                             use_spectral_norm=use_spectral_norm))
        else:
            self.resnet_path = Sequential(
                create_conv3(num_channels, num_channels,
                             bias=False,
                             initialization_method=initialization_method,
                             use_spectral_norm=use_spectral_norm),
                NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(num_channels, affine=True),
                nonlinearity_factory.create(),
                create_conv3(num_channels, num_channels,
                             bias=False,
                             initialization_method=initialization_method,
                             use_spectral_norm=use_spectral_norm),
                NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(num_channels, affine=True))

    def forward(self, x):
        if self.use_scale_parameter:
            return x + self.scale * self.resnet_path(x)
        else:
            return x + self.resnet_path(x)


================================================
FILE: src/tha4/nn/resnet_block_seperable.py
================================================
from typing import Optional

import torch
from torch.nn import Module, Sequential, Parameter

from tha4.shion.core.module_factory import ModuleFactory
from tha4.nn.conv import create_conv1
from tha4.nn.nonlinearity_factory import resolve_nonlinearity_factory
from tha4.nn.normalization import NormalizationLayerFactory
from tha4.nn.separable_conv import create_separable_conv3
from tha4.nn.util import BlockArgs


class ResnetBlockSeparable(Module):
    @staticmethod
    def create(num_channels: int,
               is1x1: bool = False,
               use_scale_parameters: bool = False,
               block_args: Optional[BlockArgs] = None):
        if block_args is None:
            block_args = BlockArgs()
        return ResnetBlockSeparable(
            num_channels,
            is1x1,
            block_args.initialization_method,
            block_args.nonlinearity_factory,
            block_args.normalization_layer_factory,
            block_args.use_spectral_norm,
            use_scale_parameters)

    def __init__(self,
                 num_channels: int,
                 is1x1: bool = False,
                 initialization_method: str = 'he',
                 nonlinearity_factory: ModuleFactory = None,
                 normalization_layer_factory: Optional[NormalizationLayerFactory] = None,
                 use_spectral_norm: bool = False,
                 use_scale_parameter: bool = False):
        super().__init__()
        self.use_scale_parameter = use_scale_parameter
        if self.use_scale_parameter:
            self.scale = Parameter(torch.zeros(1))
        nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)
        if is1x1:
            self.resnet_path = Sequential(
                create_conv1(num_channels, num_channels, initialization_method,
                             bias=True,
                             use_spectral_norm=use_spectral_norm),
                nonlinearity_factory.create(),
                create_conv1(num_channels, num_channels, initialization_method,
                             bias=True,
                             use_spectral_norm=use_spectral_norm))
        else:
            self.resnet_path = Sequential(
                create_separable_conv3(
                    num_channels, num_channels,
                    bias=False,
                    initialization_method=initialization_method,
                    use_spectral_norm=use_spectral_norm),
                NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(num_channels, affine=True),
                nonlinearity_factory.create(),
                create_separable_conv3(
                    num_channels, num_channels,
                    bias=False,
                    initialization_method=initialization_method,
                    use_spectral_norm=use_spectral_norm),
                NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(num_channels, affine=True))

    def forward(self, x):
        if self.use_scale_parameter:
            return x + self.scale * self.resnet_path(x)
        else:
            return x + self.resnet_path(x)
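A small hedged sketch of the residual blocks above (assuming the repo and its tha4.nn.util.BlockArgs defaults are importable): with use_scale_parameters=True the residual branch is multiplied by a scale parameter initialized to zero, so the block starts out as an exact identity map.

    import torch

    from tha4.nn.resnet_block import ResnetBlock

    block = ResnetBlock.create(num_channels=64, use_scale_parameters=True)
    x = torch.randn(1, 64, 32, 32)

    # scale == 0 at initialization, so x + scale * path(x) == x exactly.
    with torch.no_grad():
        y = block(x)
    print(torch.equal(y, x))  # True
    print(y.shape)            # torch.Size([1, 64, 32, 32])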
================================================
FILE: src/tha4/nn/separable_conv.py
================================================
from typing import Optional

from torch.nn import Sequential, Conv2d, ConvTranspose2d, Module

from tha4.nn.normalization import NormalizationLayerFactory
from tha4.nn.util import BlockArgs, wrap_conv_or_linear_module


def create_separable_conv3(in_channels: int, out_channels: int,
                           bias: bool = False,
                           initialization_method='he',
                           use_spectral_norm: bool = False) -> Module:
    return Sequential(
        wrap_conv_or_linear_module(
            Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=False, groups=in_channels),
            initialization_method,
            use_spectral_norm),
        wrap_conv_or_linear_module(
            Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
            initialization_method,
            use_spectral_norm))


def create_separable_conv7(in_channels: int, out_channels: int,
                           bias: bool = False,
                           initialization_method='he',
                           use_spectral_norm: bool = False) -> Module:
    return Sequential(
        wrap_conv_or_linear_module(
            Conv2d(in_channels, in_channels, kernel_size=7, stride=1, padding=3, bias=False, groups=in_channels),
            initialization_method,
            use_spectral_norm),
        wrap_conv_or_linear_module(
            Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
            initialization_method,
            use_spectral_norm))


def create_separable_conv3_block(
        in_channels: int,
        out_channels: int,
        block_args: Optional[BlockArgs] = None):
    if block_args is None:
        block_args = BlockArgs()
    return Sequential(
        wrap_conv_or_linear_module(
            Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=False, groups=in_channels),
            block_args.initialization_method,
            block_args.use_spectral_norm),
        wrap_conv_or_linear_module(
            Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            block_args.initialization_method,
            block_args.use_spectral_norm),
        NormalizationLayerFactory.resolve_2d(block_args.normalization_layer_factory).create(out_channels, affine=True),
        block_args.nonlinearity_factory.create())


def create_separable_conv7_block(
        in_channels: int,
        out_channels: int,
        block_args: Optional[BlockArgs] = None):
    if block_args is None:
        block_args = BlockArgs()
    return Sequential(
        wrap_conv_or_linear_module(
            Conv2d(in_channels, in_channels, kernel_size=7, stride=1, padding=3, bias=False, groups=in_channels),
            block_args.initialization_method,
            block_args.use_spectral_norm),
        wrap_conv_or_linear_module(
            Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            block_args.initialization_method,
            block_args.use_spectral_norm),
        NormalizationLayerFactory.resolve_2d(block_args.normalization_layer_factory).create(out_channels, affine=True),
        block_args.nonlinearity_factory.create())


def create_separable_downsample_block(
        in_channels: int,
        out_channels: int,
        is_output_1x1: bool,
        block_args: Optional[BlockArgs] = None):
    if block_args is None:
        block_args = BlockArgs()
    if is_output_1x1:
        return Sequential(
            wrap_conv_or_linear_module(
                Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1, bias=False, groups=in_channels),
                block_args.initialization_method,
                block_args.use_spectral_norm),
            wrap_conv_or_linear_module(
                Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
                block_args.initialization_method,
                block_args.use_spectral_norm),
            block_args.nonlinearity_factory.create())
    else:
        return Sequential(
            wrap_conv_or_linear_module(
                Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1, bias=False, groups=in_channels),
                block_args.initialization_method,
                block_args.use_spectral_norm),
            wrap_conv_or_linear_module(
                Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
                block_args.initialization_method,
                block_args.use_spectral_norm),
            NormalizationLayerFactory.resolve_2d(block_args.normalization_layer_factory)
            .create(out_channels, affine=True),
            block_args.nonlinearity_factory.create())


def create_separable_upsample_block(
        in_channels: int,
        out_channels: int,
        block_args: Optional[BlockArgs] = None):
    if block_args is None:
        block_args = BlockArgs()
    return Sequential(
        wrap_conv_or_linear_module(
            ConvTranspose2d(
                in_channels, in_channels, kernel_size=4, stride=2, padding=1, bias=False, groups=in_channels),
            block_args.initialization_method,
            block_args.use_spectral_norm),
        wrap_conv_or_linear_module(
            Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            block_args.initialization_method,
            block_args.use_spectral_norm),
        NormalizationLayerFactory.resolve_2d(block_args.normalization_layer_factory)
        .create(out_channels, affine=True),
        block_args.nonlinearity_factory.create())
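The depthwise-separable factories above trade a dense k-by-k convolution for a per-channel k-by-k convolution followed by a 1x1 mixing convolution, which cuts parameters roughly by a factor of k squared for wide layers. A small comparison using plain torch.nn (parameter counts only; no tha4 imports needed):

    import torch.nn as nn

    def num_params(m: nn.Module) -> int:
        return sum(p.numel() for p in m.parameters())

    dense = nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False)
    separable = nn.Sequential(
        nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False, groups=128),  # depthwise
        nn.Conv2d(128, 128, kernel_size=1, bias=False))                         # pointwise

    print(num_params(dense))      # 147456 (128 * 128 * 3 * 3)
    print(num_params(separable))  # 17536  (128 * 3 * 3 + 128 * 128)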
================================================
FILE: src/tha4/nn/siren/__init__.py
================================================


================================================
FILE: src/tha4/nn/siren/face_morpher/__init__.py
================================================


================================================
FILE: src/tha4/nn/siren/face_morpher/siren_face_morpher_00.py
================================================
from typing import Optional, List

import torch
from torch import Tensor
from torch.nn import Module
from torch.nn.functional import affine_grid

from tha4.shion.core.module_factory import ModuleFactory
from tha4.nn.siren.vanilla.siren import SirenArgs, Siren


class SirenFaceMorpher00Args:
    def __init__(self,
                 image_size: int,
                 image_channels: int,
                 pose_size: int,
                 siren_args: SirenArgs):
        assert siren_args.in_channels == pose_size + 2
        assert siren_args.out_channels == image_channels
        assert not siren_args.use_tanh
        self.siren_args = siren_args
        self.pose_size = pose_size
        self.image_size = image_size
        self.image_channels = image_channels


class SirenFaceMorpher00(Module):
    def __init__(self, args: SirenFaceMorpher00Args):
        super().__init__()
        self.args = args
        self.siren = Siren(self.args.siren_args)

    def forward(self, pose: Tensor, position: Optional[Tensor] = None) -> Tensor:
        n, p = pose.shape[0], pose.shape[1]
        device = pose.device
        if position is None:
            h, w = self.args.image_size, self.args.image_size
            identity = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device).unsqueeze(0)
            position = affine_grid(identity, [1, 1, h, w], align_corners=False) \
                .view(1, h * w, 2)
            position = torch.transpose(position, dim0=1, dim1=2).view(1, 2, h, w) \
                .repeat(n, 1, 1, 1)
        h, w = position.shape[2], position.shape[3]
        pose_image = pose.view(n, p, 1, 1).repeat(1, 1, h, w)
        siren_input = torch.cat([position, pose_image], dim=1)
        return self.siren.forward(siren_input)


class SirenFaceMorpher00Factory(ModuleFactory):
    def __init__(self, args: SirenFaceMorpher00Args):
        self.args = args

    def create(self) -> Module:
        return SirenFaceMorpher00(self.args)
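A hedged toy instantiation of the SIREN face morpher above (repo on PYTHONPATH): the args assert that the SIREN's input width equals pose_size + 2, because the network eats the pose vector broadcast over the image plus an (x, y) coordinate grid, and synthesizes the output image purely from the pose.

    import torch

    from tha4.nn.siren.face_morpher.siren_face_morpher_00 import (
        SirenFaceMorpher00, SirenFaceMorpher00Args)
    from tha4.nn.siren.vanilla.siren import SirenArgs

    # Small toy configuration; the trainer below uses image_size=128 and pose_size=39.
    args = SirenFaceMorpher00Args(
        image_size=32,
        image_channels=4,
        pose_size=6,
        siren_args=SirenArgs(
            in_channels=6 + 2,   # pose parameters + (x, y) position channels
            out_channels=4,
            intermediate_channels=32,
            num_sine_layers=3))
    model = SirenFaceMorpher00(args)

    pose = torch.zeros(2, 6)
    image = model(pose)      # no input image; the output is a function of the pose alone
    print(image.shape)       # torch.Size([2, 4, 32, 32])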
================================================
FILE: src/tha4/nn/siren/face_morpher/siren_face_morpher_00_trainer.py
================================================
from typing import Dict, List, Optional, Callable

import torch
from tha4.shion.base.dataset.lazy_tensor_dataset import LazyTensorDataset
from tha4.shion.base.image_util import extract_pytorch_image_from_filelike
from tha4.shion.base.loss.l1_loss import L1Loss, MaskedL1Loss
from tha4.shion.base.loss.sum_loss import SumLoss
from tha4.shion.base.optimizer_factories import AdamOptimizerFactory
from tha4.shion.core.training.distrib.distributed_trainer import DistributedTrainer
from tha4.dataset.image_poses_and_aother_images_dataset import ImagePosesAndOtherImagesDataset
from tha4.nn.siren.face_morpher.siren_face_morpher_00 import SirenFaceMorpher00Factory, SirenFaceMorpher00Args
from tha4.nn.siren.face_morpher.siren_face_morpher_protocols_00 import SirenFaceMorpherComputationProtocol00, \
    SirenFaceMorpherSampleOutputProtocol00
from tha4.nn.siren.morpher.siren_morpher_protocols_03 import SirenMorpherTrainingProtocol03
from tha4.nn.siren.vanilla.siren import SirenArgs
from tha4.poser.poser import Poser
from torch import Tensor

KEY_MODULE = "module"
KEY_POSER = "poser"


def get_poser():
    import tha4.poser.modes.mode_12
    poser = tha4.poser.modes.mode_12.create_poser(torch.device('cpu'))
    return poser


class SirenFaceMorpher00TrainerArgs:
    def __init__(self,
                 character_file_name: str,
                 face_mask_file_name: str,
                 pose_dataset_file_name: str,
                 num_training_total_examples: int = 1_000_000,
                 num_training_examples_per_checkpoint: int = 100_000,
                 num_training_examples_lr_boundaries: Optional[List[int]] = None,
                 num_training_examples_per_sample_output: Optional[int] = 5_000,
                 num_training_examples_per_snapshot: int = 10_000,
                 total_batch_size: int = 8,
                 training_random_seed: int = 2965603729,
                 sample_output_random_seed: int = 3522651501,
                 total_worker: int = 16,
                 poser_func: Optional[Callable[[], Poser]] = None,
                 base_learning_rate: float = 1e-4):
        assert num_training_total_examples % num_training_examples_per_checkpoint == 0
        if num_training_examples_lr_boundaries is None:
            num_training_examples_lr_boundaries = [
                int(num_training_examples_per_checkpoint * 2),
                int(num_training_examples_per_checkpoint * 5),
                int(num_training_examples_per_checkpoint * 8),
            ]
        for x in num_training_examples_lr_boundaries:
            assert x % num_training_examples_per_snapshot == 0
        if poser_func is None:
            poser_func = get_poser
        self.face_mask_file_name = face_mask_file_name
        self.base_learning_rate = base_learning_rate
        self.poser_func = poser_func
        self.total_worker = total_worker
        self.num_training_examples_per_snapshot = num_training_examples_per_snapshot
        self.num_training_examples_per_sample_output = num_training_examples_per_sample_output
        self.sample_output_random_seed = sample_output_random_seed
        self.training_random_seed = training_random_seed
        self.total_batch_size = total_batch_size
        self.num_training_total_examples = num_training_total_examples
        self.num_training_examples_per_checkpoint = num_training_examples_per_checkpoint
        self.num_training_examples_lr_boundaries = num_training_examples_lr_boundaries
        self.pose_dataset_file_name = pose_dataset_file_name
        self.character_file_name = character_file_name

    def get_character_image(self):
        return extract_pytorch_image_from_filelike(
            self.character_file_name,
            scale=2.0,
            offset=-1.0,
            premultiply_alpha=True,
            perform_srgb_to_linear=True)

    def get_face_mask_image(self):
        loaded_image = extract_pytorch_image_from_filelike(
            self.face_mask_file_name,
            scale=1.0,
            offset=0.0,
            premultiply_alpha=True,
            perform_srgb_to_linear=True)
        output_image = torch.zeros(4, 128, 128)
        center_x = 256
        center_y = 128 + 16
        for i in range(4):
            output_image[i, :, :] = loaded_image[0, center_y - 64:center_y + 64, center_x - 64:center_x + 64]
        return output_image

    def get_training_dataset(self):
        return ImagePosesAndOtherImagesDataset(
            main_image_func=self.get_character_image,
            other_image_funcs=[self.get_face_mask_image],
            pose_dataset=LazyTensorDataset(self.pose_dataset_file_name))

    def get_module_factory(self):
        return SirenFaceMorpher00Factory(
            SirenFaceMorpher00Args(
                image_size=128,
                image_channels=4,
                pose_size=39,
                siren_args=SirenArgs(
                    in_channels=39 + 2,
                    out_channels=4,
                    intermediate_channels=128,
                    num_sine_layers=8)))

    def transform_pose_to_module_input(self, pose: Tensor):
        return pose[:, 0:39]

    def transform_original_image_to_module_input(self, image: Tensor):
        center_x = 256
        center_y = 128 + 16
        return image[:, :, center_y - 64:center_y + 64, center_x - 64:center_x + 64]

    def transform_poser_posed_image_to_groundtruth(self, image: Tensor):
        center_x = 96
        center_y = 96 + 16
        return image[:, :, center_y - 64:center_y + 64, center_x - 64:center_x + 64]

    def get_training_computation_protocol(self):
        return SirenFaceMorpherComputationProtocol00(
            transform_pose_to_module_input_func=self.transform_pose_to_module_input,
            transform_original_image_to_module_input_func=self.transform_original_image_to_module_input,
            transform_poser_posed_image_to_groundtruth_func=self.transform_poser_posed_image_to_groundtruth)

    def get_learning_rate(self, examples_seen_so_far) -> Dict[str, float]:
        if examples_seen_so_far < self.num_training_examples_lr_boundaries[0]:
            return {
                KEY_MODULE: self.base_learning_rate,
            }
        elif examples_seen_so_far < self.num_training_examples_lr_boundaries[1]:
            return {
                KEY_MODULE: self.base_learning_rate / 3.0,
            }
        elif examples_seen_so_far < self.num_training_examples_lr_boundaries[2]:
            return {
                KEY_MODULE: self.base_learning_rate / 10.0,
            }
        else:
            return {
                KEY_MODULE: self.base_learning_rate / 30.0,
            }

    def get_optimizer_factories(self):
        return {
            KEY_MODULE: AdamOptimizerFactory(betas=(0.9, 0.999)),
        }

    def get_poser(self):
        return self.poser_func()

    def get_training_protocol(self, world_size: int):
        total_examples = self.num_training_total_examples
        per_checkpoint_examples = self.num_training_examples_per_checkpoint
        num_checkpoints = total_examples // per_checkpoint_examples
        batch_size = self.total_batch_size // world_size
        return SirenMorpherTrainingProtocol03(
            check_point_examples=[per_checkpoint_examples * (i + 1) for i in range(num_checkpoints)],
            batch_size=batch_size,
            learning_rate=self.get_learning_rate,
            optimizer_factories=self.get_optimizer_factories(),
            random_seed=self.training_random_seed,
            poser_func=self.get_poser,
            key_module=KEY_MODULE,
            key_poser=KEY_POSER)

    def get_sample_output_protocol(self):
        return SirenFaceMorpherSampleOutputProtocol00(
            num_images=8,
            image_size=128,
            images_per_row=2,
            examples_per_sample_output=self.num_training_examples_per_sample_output,
            computation_protocol=self.get_training_computation_protocol(),
            poser_func=self.get_poser,
            random_seed=self.sample_output_random_seed)

    def get_loss(self):
        protocol = self.get_training_computation_protocol()
        return SumLoss([
            (
                'full',
                L1Loss(
                    expected_func=protocol.get_output_func(protocol.keys.groundtruth_posed_image),
                    actual_func=protocol.get_output_func(protocol.keys.predicted_posed_image),
                    weight=1.0)
            ),
            (
                'eye_mouth',
                MaskedL1Loss(
                    expected_func=protocol.get_output_func(protocol.keys.groundtruth_posed_image),
                    actual_func=protocol.get_output_func(protocol.keys.predicted_posed_image),
                    mask_func=protocol.get_output_func(protocol.keys.eye_mouth_mask),
                    weight=20.0)
            ),
        ])

    def create_trainer(self, prefix: str, world_size: int, distrib_backend: str = 'gloo'):
        if self.num_training_examples_per_sample_output is not None:
            sample_output_protocol = self.get_sample_output_protocol()
        else:
            sample_output_protocol = None
        return DistributedTrainer(
            prefix=prefix,
            module_factories={
                KEY_MODULE: self.get_module_factory(),
            },
            accumulators={},
            losses={
                KEY_MODULE: self.get_loss(),
            },
            training_dataset=self.get_training_dataset(),
            validation_dataset=self.get_training_dataset(),
            training_protocol=self.get_training_protocol(world_size),
            validation_protocol=None,
            sample_output_protocol=sample_output_protocol,
            pretrained_module_file_names={},
            example_per_snapshot=self.num_training_examples_per_snapshot,
            num_data_loader_workers=max(1, self.total_worker // world_size),
            distrib_backend=distrib_backend)
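The learning-rate schedule above is a staircase over examples seen: the boundaries default to 2x, 5x, and 8x the checkpoint interval, dividing the base rate by 3, 10, and 30 in turn. A hedged sketch of evaluating it (the constructor only records the file names, so dummy paths suffice; the repo's dependencies must be installed for the import to succeed):

    from tha4.nn.siren.face_morpher.siren_face_morpher_00_trainer import (
        KEY_MODULE, SirenFaceMorpher00TrainerArgs)

    # Hypothetical file names; nothing is read at construction time.
    args = SirenFaceMorpher00TrainerArgs(
        character_file_name='data/character.png',
        face_mask_file_name='data/face_mask.png',
        pose_dataset_file_name='data/pose_dataset.pt')

    # Boundaries default to [200_000, 500_000, 800_000] examples.
    for n in [0, 200_000, 500_000, 800_000]:
        print(n, args.get_learning_rate(n)[KEY_MODULE])
    # 1e-04, then ~3.3e-05, 1e-05, and ~3.3e-06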
================================================
FILE: src/tha4/nn/siren/face_morpher/siren_face_morpher_protocols_00.py
================================================
import os
from dataclasses import dataclass
from typing import Dict, Any, Optional, Callable

import PIL.Image
import numpy
import torch
from tha4.shion.base.dataset.util import get_indexed_batch
from tha4.shion.base.image_util import pytorch_rgba_to_numpy_image
from tha4.shion.core.cached_computation import CachedComputationProtocol, ComputationState, \
    ComposableCachedComputationProtocol, batch_indexing_func, add_step
from tha4.shion.core.training.sample_output_protocol import SampleOutputProtocol
from tha4.poser.general_poser_02 import GeneralPoser02
from torch import Tensor
from torch.nn import Module
from torch.utils.data import Dataset

KEY_MODULE = "module"
KEY_POSER = "poser"


@dataclass
class SirenMorpherProtocol00Keys:
    module: str = KEY_MODULE
    module_output: str = "module_output"
    poser: str = KEY_POSER
    poser_output: str = "poser_output"
    original_image: str = "original_image"
    original_pose: str = "original_pose"
    module_input_image: str = "module_input_image"
    module_input_pose: str = "module_input_pose"
    groundtruth_posed_image: str = 'groundtruth_posed_image'
    predicted_posed_image: str = 'predicted_posed_image'
    eye_mouth_mask: str = 'eye_mouth_mask'


@dataclass
class SirenMorpherProtocol00Indices:
    batch_original_image: int = 0
    batch_pose: int = 1
    batch_eye_mouth_mask: int = 2
    poser_posed_image: int = 0


class SirenFaceMorpherComputationProtocol00(ComposableCachedComputationProtocol):
    def __init__(self,
                 transform_pose_to_module_input_func: Callable[[Tensor], Tensor],
                 transform_original_image_to_module_input_func: Callable[[Tensor], Tensor],
                 transform_poser_posed_image_to_groundtruth_func: Callable[[Tensor], Tensor],
                 keys: Optional[SirenMorpherProtocol00Keys] = None,
                 indices: Optional[SirenMorpherProtocol00Indices] = None):
        super().__init__()
        if keys is None:
            keys = SirenMorpherProtocol00Keys()
        if indices is None:
            indices = SirenMorpherProtocol00Indices()
        self.keys = keys
        self.indices = indices
        self.transform_image_to_module_input_func = transform_original_image_to_module_input_func
        self.transform_pose_to_module_input_func = transform_pose_to_module_input_func
        self.transform_poser_posed_image_to_groundtruth_func = transform_poser_posed_image_to_groundtruth_func

        self.computation_steps[keys.original_image] = batch_indexing_func(indices.batch_original_image)
        self.computation_steps[keys.original_pose] = batch_indexing_func(indices.batch_pose)

        @add_step(self.computation_steps, keys.module_input_pose)
        def get_module_input_pose(protocol: CachedComputationProtocol, state: ComputationState):
            original_pose = protocol.get_output(keys.original_pose, state)
            return transform_pose_to_module_input_func(original_pose)

        @add_step(self.computation_steps, keys.module_input_image)
        def get_module_input_image(protocol: CachedComputationProtocol, state: ComputationState):
            original_image = protocol.get_output(keys.original_image, state)
            return transform_original_image_to_module_input_func(original_image)

        @add_step(self.computation_steps, keys.poser_output)
        def get_poser_output(protocol: CachedComputationProtocol, state: ComputationState):
            with torch.no_grad():
                poser = state.modules[keys.poser]
                pose = protocol.get_output(keys.original_pose, state)
                image = protocol.get_output(keys.original_image, state)
                return poser.get_posing_outputs(image, pose)

        @add_step(self.computation_steps, keys.groundtruth_posed_image)
        def get_groundtruth_posed_image(protocol: CachedComputationProtocol, state: ComputationState):
            poser_output = protocol.get_output(keys.poser_output, state)
            poser_posed_image = poser_output[indices.poser_posed_image]
            return transform_poser_posed_image_to_groundtruth_func(poser_posed_image)

        @add_step(self.computation_steps, keys.module_output)
        def get_module_output(protocol: CachedComputationProtocol, state: ComputationState):
            module_input_pose = protocol.get_output(keys.module_input_pose, state)
            module = state.modules[keys.module]
            return module.forward(module_input_pose)

        @add_step(self.computation_steps, keys.predicted_posed_image)
        def get_predicted_image(protocol: CachedComputationProtocol, state: ComputationState):
            return protocol.get_output(keys.module_output, state)

        self.computation_steps[keys.eye_mouth_mask] = batch_indexing_func(indices.batch_eye_mouth_mask)


class SirenFaceMorpherSampleOutputProtocol00(SampleOutputProtocol):
    def __init__(self,
                 num_images: int,
                 image_size: int,
                 images_per_row: int,
                 examples_per_sample_output: int,
                 computation_protocol: SirenFaceMorpherComputationProtocol00,
                 poser_func: Callable[[], GeneralPoser02],
                 random_seed: int = 54859395058,
                 batch_size: Optional[int] = None):
        if batch_size is None:
            batch_size = num_images
        self.batch_size = batch_size
        self.poser_func = poser_func
        self.random_seed = random_seed
        self.examples_per_sample_output = examples_per_sample_output
        self.images_per_row = images_per_row
        self.image_size = image_size
        self.num_images = num_images
        self.computation_protocol = computation_protocol

    def get_examples_per_sample_output(self) -> int:
        return self.examples_per_sample_output

    def get_random_seed(self) -> int:
        return self.random_seed

    def get_sample_output_data(self, validation_dataset: Dataset, device: torch.device) -> dict:
        example_indices = torch.randint(0, len(validation_dataset), (self.num_images,))
        example_indices = [example_indices[i].item() for i in range(self.num_images)]
        batch = get_indexed_batch(validation_dataset, example_indices, device)
        poser = self.poser_func()
        poser.to(device)
        with torch.no_grad():
            ground_truth = poser.pose(
                batch[self.computation_protocol.indices.batch_original_image],
                batch[self.computation_protocol.indices.batch_pose])
        return {
            'batch': batch,
            'ground_truth': ground_truth
        }

    def save_sample_output_data(self,
                                modules: Dict[str, Module],
                                accumulated_modules: Dict[str, Module],
                                sample_output_data: Any,
                                prefix: str,
                                examples_seen_so_far: int,
                                device: torch.device):
        batch = sample_output_data['batch']
        ground_truth = sample_output_data['ground_truth']
        ground_truth = self.computation_protocol.transform_poser_posed_image_to_groundtruth_func(ground_truth)
        module = modules[self.computation_protocol.keys.module]
        module.train(False)
        if self.batch_size == self.num_images:
            with torch.no_grad():
                state = ComputationState(
                    modules=modules,
                    accumulated_modules=accumulated_modules,
                    batch=batch,
                    outputs={})
                poser_output_images = self.computation_protocol.get_output(
                    self.computation_protocol.keys.predicted_posed_image, state)
        else:
            poser_output_images_list = []
            start = 0
            while start < self.num_images:
                end = start + self.batch_size
                end = min(self.num_images, end)
                minibatch = [batch[i][start:end] for i in range(len(batch))]
                state = ComputationState(
                    modules=modules,
                    accumulated_modules=accumulated_modules,
                    batch=minibatch,
                    outputs={})
                with torch.no_grad():
                    poser_output_images = self.computation_protocol.get_output(
                        self.computation_protocol.keys.predicted_posed_image, state)
                poser_output_images_list.append(poser_output_images)
                start = end
            poser_output_images = torch.cat(poser_output_images_list, dim=0)
        num_rows = self.num_images // self.images_per_row
        if self.num_images % self.images_per_row > 0:
            num_rows += 1
        num_cols = 2 * self.images_per_row
        image_channels = 4
        output_image = numpy.zeros([self.image_size * num_rows, self.image_size * num_cols, image_channels])
        for image_index in range(self.num_images):
            row = image_index // self.images_per_row
            start_row = row * self.image_size
            col = 2 * (image_index % self.images_per_row)
            start_col = col * self.image_size
            output_image[start_row:start_row + self.image_size, start_col:start_col + self.image_size, :] \
                = pytorch_rgba_to_numpy_image(ground_truth[image_index].detach().cpu())
            start_col += self.image_size
            output_image[start_row:start_row + self.image_size, start_col:start_col + self.image_size, :] \
                = pytorch_rgba_to_numpy_image(poser_output_images[image_index].detach().cpu())
        file_name = "%s/sample_output_%010d.png" % (prefix, examples_seen_so_far)
        os.makedirs(os.path.dirname(file_name), exist_ok=True)
        pil_image = PIL.Image.fromarray(numpy.uint8(numpy.rint(output_image * 255.0)), mode='RGBA')
        pil_image.save(file_name)
        print("Saved %s" % file_name)


================================================
FILE: src/tha4/nn/siren/morpher/__init__.py
================================================


================================================
FILE: src/tha4/nn/siren/morpher/siren_morpher_03.py
================================================
from typing import List, Optional, Callable

import torch
from torch import Tensor
from torch.nn import Module, ModuleList, Sequential, Conv2d
from torch.nn.functional import affine_grid, interpolate

from tha4.shion.core.module_factory import ModuleFactory
from tha4.shion.nn00.initialization_funcs import HeInitialization
from tha4.nn.image_processing_util import GridChangeApplier
from tha4.nn.siren.vanilla.siren import SineLinearLayer


class SirenMorpherLevelArgs:
    def __init__(self,
                 image_size: int,
                 intermediate_channels: int,
                 num_sine_layers: int):
        assert num_sine_layers >= 2
        self.image_size = image_size
        self.num_sine_layers = num_sine_layers
        self.intermediate_channels = intermediate_channels


class SirenMorpher03Args:
    def __init__(self,
                 image_size: int,
                 image_channels: int,
                 pose_size: int,
                 level_args: List[SirenMorpherLevelArgs],
                 init_func: Optional[Callable[[Module], Module]] = None):
        assert len(level_args) >= 2
        if init_func is None:
            init_func = HeInitialization()
        self.image_size = image_size
        self.init_func = init_func
        self.level_args = level_args
        self.pose_size = pose_size
        self.image_channels = image_channels


class SirenMorpher03(Module):
    def __init__(self, args: SirenMorpher03Args):
        super().__init__()
        self.args = args
        self.siren_layers = ModuleList()
        for i in range(len(args.level_args)):
            level_args = args.level_args[i]
            layers = []
            if i == 0:
                layers.append(SineLinearLayer(
                    in_channels=args.pose_size + 2,
                    out_channels=level_args.intermediate_channels,
                    is_first=True))
            else:
                layers.append(SineLinearLayer(
                    in_channels=level_args.intermediate_channels + args.pose_size + 2,
                    out_channels=level_args.intermediate_channels,
                    is_first=False))
            for j in range(1, level_args.num_sine_layers - 1):
                layers.append(SineLinearLayer(
                    in_channels=level_args.intermediate_channels,
                    out_channels=level_args.intermediate_channels,
                    is_first=False))
            if i == len(args.level_args) - 1:
                out_channels = level_args.intermediate_channels
            else:
                out_channels = args.level_args[i + 1].intermediate_channels
            layers.append(SineLinearLayer(
                in_channels=level_args.intermediate_channels,
                out_channels=out_channels,
                is_first=False))
            self.siren_layers.append(Sequential(*layers))
        self.last_linear = args.init_func(Conv2d(
            args.level_args[-1].intermediate_channels,
            args.image_channels + 2 + 1,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True))
        self.grid_change_applier = GridChangeApplier()

    def get_position_grid(self, n: int, image_size: int, device: torch.device):
        h, w = image_size, image_size
        identity = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device).unsqueeze(0)
        position = affine_grid(identity, [1, 1, h, w], align_corners=False) \
            .view(1, h * w, 2)
        position = torch.transpose(position, dim0=1, dim1=2).view(1, 2, h, w) \
            .repeat(n, 1, 1, 1)
        return position

    def get_pose_image(self, pose: Tensor, image_size: int):
        n, p = pose.shape[0], pose.shape[1]
        h, w = image_size, image_size
        pose_image = pose.view(n, p, 1, 1).repeat(1, 1, h, w)
        return pose_image

    def forward(self, image: Tensor, pose: Tensor) -> List[Tensor]:
        n = pose.shape[0]
        device = pose.device
        x = None
        for i in range(len(self.args.level_args)):
            args = self.args.level_args[i]
            position_and_pose = torch.cat([
                self.get_position_grid(n, args.image_size, device),
                self.get_pose_image(pose, args.image_size)
            ], dim=1)
            if i == 0:
                x = self.siren_layers[i].forward(position_and_pose)
            else:
                x = interpolate(x, size=(args.image_size, args.image_size), mode='bilinear')
                x = torch.cat([x, position_and_pose], dim=1)
                x = self.siren_layers[i].forward(x)
        siren_output = self.last_linear(x)
        grid_change = siren_output[:, 0:2, :, :]
        alpha = siren_output[:, 2:3, :, :]
        color_change = siren_output[:, 3:, :, :]
        warped_image = self.grid_change_applier.apply(grid_change, image, align_corners=False)
        blended_image = (1 - alpha) * warped_image + alpha * color_change
        return [
            blended_image,
            alpha,
            color_change,
            warped_image,
            grid_change
        ]

    # Output indices. These are class attributes because siren_morpher_protocols_03.py
    # below references them as SirenMorpher03.INDEX_*.
    INDEX_BLENDED_IMAGE = 0
    INDEX_ALPHA = 1
    INDEX_COLOR_CHANGE = 2
    INDEX_WARPED_IMAGE = 3
    INDEX_GRID_CHANGE = 4


class SirenMorpher03Factory(ModuleFactory):
    def __init__(self, args: SirenMorpher03Args):
        self.args = args

    def create(self):
        return SirenMorpher03(self.args)
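A hedged toy instantiation of the multi-level SIREN morpher above (repo on PYTHONPATH): each level runs sine layers at its own resolution, feeding the upsampled features plus a fresh coordinate/pose grid to the next level; the final 1x1 convolution emits a grid change (2 channels), an alpha (1 channel), and a color change (image_channels).

    import torch

    from tha4.nn.siren.morpher.siren_morpher_03 import (
        SirenMorpher03, SirenMorpher03Args, SirenMorpherLevelArgs)

    # Toy two-level pyramid; the trainer below uses 128/256/512 with 360/180/90 channels.
    model = SirenMorpher03(SirenMorpher03Args(
        image_size=64,
        image_channels=4,
        pose_size=8,
        level_args=[
            SirenMorpherLevelArgs(image_size=32, intermediate_channels=24, num_sine_layers=2),
            SirenMorpherLevelArgs(image_size=64, intermediate_channels=16, num_sine_layers=2),
        ]))

    image = torch.rand(1, 4, 64, 64)
    pose = torch.zeros(1, 8)
    blended, alpha, color_change, warped, grid_change = model(image, pose)
    print(blended.shape, grid_change.shape)
    # torch.Size([1, 4, 64, 64]) torch.Size([1, 2, 64, 64])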
================================================
FILE: src/tha4/nn/siren/morpher/siren_morpher_03_trainer.py
================================================
from enum import Enum
from typing import Dict, List, Optional, Callable

import torch
from tha4.shion.base.dataset.lazy_tensor_dataset import LazyTensorDataset
from tha4.shion.base.image_util import extract_pytorch_image_from_filelike
from tha4.shion.base.loss.l1_loss import L1Loss
from tha4.shion.base.loss.sum_loss import SumLoss
from tha4.shion.base.loss.time_dependently_weighted_loss import TimeDependentlyWeightedLoss
from tha4.shion.base.optimizer_factories import AdamOptimizerFactory
from tha4.shion.core.training.distrib.distributed_trainer import DistributedTrainer
from tha4.dataset.image_poses_and_aother_images_dataset import ImagePosesAndOtherImagesDataset
from tha4.nn.siren.morpher.siren_morpher_03 import SirenMorpherLevelArgs, SirenMorpher03Factory, SirenMorpher03Args
from tha4.nn.siren.morpher.siren_morpher_protocols_03 import SirenMorpherComputationProtocol03, \
    SirenMorpherProtocol03Indices, KEY_MODULE, KEY_POSER, KEY_EXAMPLES_SEEN_SO_FAR, SirenMorpherTrainingProtocol03, \
    SirenMorpherSampleOutputProtocol
from tha4.poser.poser import Poser


def get_poser():
    import tha4.poser.modes.mode_07
    poser = tha4.poser.modes.mode_07.create_poser(torch.device('cpu'))
    return poser


class LossTerm(Enum):
    full_blended = 1
    full_warped = 2
    full_grid_change = 3
    full_color_change = 4

    def get_loss(self, protocol: SirenMorpherComputationProtocol03):
        if self == LossTerm.full_blended:
            return L1Loss(
                expected_func=protocol.get_output_func(protocol.keys.groundtruth_posed_image),
                actual_func=protocol.get_output_func(protocol.keys.predicted_posed_image))
        elif self == LossTerm.full_warped:
            return L1Loss(
                expected_func=protocol.get_output_func(protocol.keys.groundtruth_warped_image),
                actual_func=protocol.get_output_func(protocol.keys.predicted_warped_image))
        elif self == LossTerm.full_grid_change:
            return L1Loss(
                expected_func=protocol.get_output_func(protocol.keys.groundtruth_grid_change),
                actual_func=protocol.get_output_func(protocol.keys.predicted_grid_change))
        elif self == LossTerm.full_color_change:
            return L1Loss(
                expected_func=protocol.get_output_func(protocol.keys.groundtruth_posed_image),
                actual_func=protocol.get_output_func(protocol.keys.predicted_color_change))
        else:
            raise RuntimeError(f"Unsupported loss term {self}")


class LossWeights:
    def __init__(self, weights: Optional[Dict[LossTerm, float]] = None):
        self.weights = {}
        for term in LossTerm:
            self.weights[term] = 0.0
        if weights is not None:
            for term in LossTerm:
                if term in weights:
                    self.weights[term] = weights[term]


class TrainingPhase:
    def __init__(self, num_examples_upper_bound: int, learning_rate: float, loss_weights: LossWeights):
        self.loss_weights = loss_weights
        self.learning_rate = learning_rate
        self.num_examples_upper_bound = num_examples_upper_bound


class LearningRateFunc:
    def __init__(self, phases: List[TrainingPhase], keys: List[str]):
        self.phases = phases
        self.keys = keys

    def make_learning_rate_dict(self, keys: List[str], value: float):
        output = {}
        for key in keys:
            output[key] = value
        return output

    def __call__(self, examples_seen_so_far: int) -> Dict[str, float]:
        for i in range(len(self.phases) - 1):
            if examples_seen_so_far < self.phases[i].num_examples_upper_bound:
                return self.make_learning_rate_dict(self.keys, self.phases[i].learning_rate)
        return self.make_learning_rate_dict(self.keys, self.phases[-1].learning_rate)


class LossWeightFunc:
    def __init__(self, phases: List[TrainingPhase], term: LossTerm):
        self.term = term
        self.phases = phases

    def __call__(self, examples_seen_so_far: int) -> float:
        for i in range(len(self.phases) - 1):
            if examples_seen_so_far < self.phases[i].num_examples_upper_bound:
                return self.phases[i].loss_weights.weights[self.term]
        return self.phases[-1].loss_weights.weights[self.term]


class TrainingPhases:
    def __init__(self, phases: List[TrainingPhase]):
        assert len(phases) > 0
        for i in range(1, len(phases)):
            assert phases[i - 1].num_examples_upper_bound < phases[i].num_examples_upper_bound
        self.phases = phases

    def make_learning_rate_dict(self, keys: List[str], value: float):
        output = {}
        for key in keys:
            output[key] = value
        return output

    def get_learning_rate_func(self, keys: List[str]):
        return LearningRateFunc(self.phases, keys)
    def get_loss_weight_func(self, term: LossTerm) -> Callable[[int], float]:
        return LossWeightFunc(self.phases, term)


class SirenMorpher03TrainerArgs:
    def __init__(self,
                 character_file_name: str,
                 pose_dataset_file_name: str,
                 training_phases: TrainingPhases,
                 num_training_examples_per_checkpoint: int = 100_000,
                 num_training_examples_per_sample_output: Optional[int] = 10_000,
                 num_training_examples_per_snapshot: int = 10_000,
                 total_batch_size: int = 8,
                 training_random_seed: int = 2965603729,
                 sample_output_random_seed: int = 3522651501,
                 total_worker: int = 8,
                 poser_func: Optional[Callable[[], Poser]] = None,
                 sample_output_batch_size: Optional[int] = None,
                 pretrained_module_file_name: Optional[str] = None):
        for phase in training_phases.phases:
            assert phase.num_examples_upper_bound % num_training_examples_per_checkpoint == 0
        if poser_func is None:
            poser_func = get_poser
        self.training_phases = training_phases
        self.pretrained_module_file_name = pretrained_module_file_name
        self.sample_output_batch_size = sample_output_batch_size
        self.poser_func = poser_func
        self.total_worker = total_worker
        self.num_training_examples_per_snapshot = num_training_examples_per_snapshot
        self.num_training_examples_per_sample_output = num_training_examples_per_sample_output
        self.sample_output_random_seed = sample_output_random_seed
        self.training_random_seed = training_random_seed
        self.total_batch_size = total_batch_size
        self.num_training_examples_per_checkpoint = num_training_examples_per_checkpoint
        self.pose_dataset_file_name = pose_dataset_file_name
        self.character_file_name = character_file_name

    def get_character_image(self):
        return extract_pytorch_image_from_filelike(
            self.character_file_name,
            scale=2.0,
            offset=-1.0,
            premultiply_alpha=True,
            perform_srgb_to_linear=True)

    def get_training_dataset(self):
        return ImagePosesAndOtherImagesDataset(
            main_image_func=self.get_character_image,
            pose_dataset=LazyTensorDataset(self.pose_dataset_file_name),
            other_image_funcs=[])

    def get_module_factory(self):
        return SirenMorpher03Factory(
            SirenMorpher03Args(
                image_size=512,
                image_channels=4,
                pose_size=45,
                level_args=[
                    SirenMorpherLevelArgs(
                        image_size=128,
                        intermediate_channels=360,
                        num_sine_layers=3),
                    SirenMorpherLevelArgs(
                        image_size=256,
                        intermediate_channels=180,
                        num_sine_layers=3),
                    SirenMorpherLevelArgs(
                        image_size=512,
                        intermediate_channels=90,
                        num_sine_layers=3),
                ]))

    def get_training_computation_protocol(self):
        return SirenMorpherComputationProtocol03(
            indices=SirenMorpherProtocol03Indices(
                batch_image=0,
                batch_pose=1,
                batch_face_mask=2))

    def get_optimizer_factories(self):
        return {
            KEY_MODULE: AdamOptimizerFactory(betas=(0.9, 0.999)),
        }

    def get_poser(self):
        return self.poser_func()

    def get_training_protocol(self, world_size: int):
        total_examples = self.training_phases.phases[-1].num_examples_upper_bound
        per_checkpoint_examples = self.num_training_examples_per_checkpoint
        num_checkpoints = total_examples // per_checkpoint_examples
        batch_size = self.total_batch_size // world_size
        return SirenMorpherTrainingProtocol03(
            check_point_examples=[per_checkpoint_examples * (i + 1) for i in range(num_checkpoints)],
            batch_size=batch_size,
            learning_rate=self.training_phases.get_learning_rate_func([KEY_MODULE]),
            optimizer_factories=self.get_optimizer_factories(),
            random_seed=self.training_random_seed,
            poser_func=self.get_poser,
            key_module=KEY_MODULE,
            key_poser=KEY_POSER)

    def get_sample_output_protocol(self):
        return SirenMorpherSampleOutputProtocol(
            num_images=4,
            image_size=512,
            examples_per_sample_output=self.num_training_examples_per_sample_output,
            computation_protocol=self.get_training_computation_protocol(),
            poser_func=self.get_poser,
            random_seed=self.sample_output_random_seed,
            batch_size=self.sample_output_batch_size,
            batch_pose_index=1,
            batch_image_index=0)

    def get_loss(self):
        protocol = self.get_training_computation_protocol()
        losses = []
        for term in LossTerm:
            base_loss = term.get_loss(protocol)
            loss = TimeDependentlyWeightedLoss(
                base_loss,
                examples_seen_so_far_func=lambda state: state.outputs[KEY_EXAMPLES_SEEN_SO_FAR],
                weight_func=self.training_phases.get_loss_weight_func(term))
            losses.append((term.name, loss))
        return SumLoss(losses)

    def create_trainer(self, prefix: str, world_size: int, distrib_backend: str = 'gloo'):
        if self.num_training_examples_per_sample_output is not None:
            sample_output_protocol = self.get_sample_output_protocol()
        else:
            sample_output_protocol = None
        pretrained_module_file_names = {}
        if self.pretrained_module_file_name is not None:
            pretrained_module_file_names[KEY_MODULE] = self.pretrained_module_file_name
        return DistributedTrainer(
            prefix=prefix,
            module_factories={
                KEY_MODULE: self.get_module_factory(),
            },
            accumulators={},
            losses={
                KEY_MODULE: self.get_loss(),
            },
            training_dataset=self.get_training_dataset(),
            validation_dataset=self.get_training_dataset(),
            training_protocol=self.get_training_protocol(world_size),
            validation_protocol=None,
            sample_output_protocol=sample_output_protocol,
            pretrained_module_file_names=pretrained_module_file_names,
            example_per_snapshot=self.num_training_examples_per_snapshot,
            num_data_loader_workers=max(1, self.total_worker // world_size),
            distrib_backend=distrib_backend)
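The phase machinery above drives both the learning rate and the per-term loss weights from the same examples-seen counter. A hedged sketch with hypothetical phase boundaries and weights (repo importable; nothing here touches files or GPUs):

    from tha4.nn.siren.morpher.siren_morpher_03_trainer import (
        KEY_MODULE, LossTerm, LossWeights, TrainingPhase, TrainingPhases)

    # Two hypothetical phases: warm up on warp-related terms, then switch to blending.
    phases = TrainingPhases([
        TrainingPhase(
            num_examples_upper_bound=100_000,
            learning_rate=1e-4,
            loss_weights=LossWeights({LossTerm.full_warped: 1.0, LossTerm.full_grid_change: 1.0})),
        TrainingPhase(
            num_examples_upper_bound=300_000,
            learning_rate=3e-5,
            loss_weights=LossWeights({LossTerm.full_blended: 1.0})),
    ])

    lr_func = phases.get_learning_rate_func([KEY_MODULE])
    blended_weight = phases.get_loss_weight_func(LossTerm.full_blended)
    print(lr_func(50_000))         # {'module': 0.0001}
    print(blended_weight(50_000))  # 0.0 (not active in the first phase)
    print(blended_weight(150_000)) # 1.0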
================================================
FILE: src/tha4/nn/siren/morpher/siren_morpher_protocols_03.py
================================================
from dataclasses import dataclass
from typing import Optional, List, Callable, Dict, Any

import torch
from tha4.shion.base.dataset.util import get_indexed_batch
from tha4.shion.core.cached_computation import output_array_indexing_func, add_step, ComputationState, \
    CachedComputationProtocol, ComposableCachedComputationProtocol, batch_indexing_func, proxy_func
from tha4.shion.core.loss import Loss
from tha4.shion.core.optimizer_factory import OptimizerFactory
from tha4.shion.core.training.sample_output_protocol import SampleOutputProtocol
from tha4.shion.core.training.training_protocol import AbstractTrainingProtocol
from tha4.nn.image_processing_util import GridChangeApplier
from tha4.nn.siren.morpher.siren_morpher_03 import SirenMorpher03
from tha4.poser.general_poser_02 import GeneralPoser02
from tha4.sampleoutput.sample_image_creator import SampleImageSpec, ImageSource, ImageType, SampleImageSaver
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import Dataset

KEY_MODULE = "module"
KEY_POSER = "poser"
KEY_EXAMPLES_SEEN_SO_FAR = "examples_seen_so_far"


@dataclass
class SirenMorpherProtocol03Keys:
    module: str = KEY_MODULE
    module_output: str = "module_output"
    poser: str = KEY_POSER
    poser_output: str = "poser_output"
    image: str = "image"
    pose: str = "pose"
    face_mask: str = 'face_mask'
    groundtruth_posed_image: str = 'groundtruth_posed_image'
    groundtruth_grid_change: str = 'groundtruth_grid_change'
    groundtruth_posed_face_mask: str = 'groundtruth_posed_face_mask'
    predicted_posed_image: str = 'predicted_posed_image'
    module_input_image: str = "module_input_image"
    predicted_grid_change: str = "predicted_grid_change"
    predicted_color_change: str = "predicted_color_change"
    predicted_warped_image: str = "predicted_warped_image"
    predicted_alpha: str = "predicted_alpha"
    groundtruth_alpha: str = "groundtruth_alpha"
    groundtruth_warped_image: str = "groundtruth_warped_image"
    zero: str = "zero"


@dataclass
class SirenMorpherProtocol03Indices:
    batch_image: int = 0
    batch_face_mask: int = 1
    batch_pose: int = 2
    poser_posed_image: int = 0
    poser_grid_change: int = 3
    poser_output_module_input_image_index: int = 5
    poser_alpha: int = 1
    poser_warped_image: int = 2
    module_blended_image: int = SirenMorpher03.INDEX_BLENDED_IMAGE
    module_grid_change: int = SirenMorpher03.INDEX_GRID_CHANGE
    module_color_change: int = SirenMorpher03.INDEX_COLOR_CHANGE
    module_warped_image: int = SirenMorpher03.INDEX_WARPED_IMAGE
    module_alpha: int = SirenMorpher03.INDEX_ALPHA


class SirenMorpherComputationProtocol03(ComposableCachedComputationProtocol):
    def __init__(self,
                 keys: Optional[SirenMorpherProtocol03Keys] = None,
                 indices: Optional[SirenMorpherProtocol03Indices] = None):
        super().__init__()
        if keys is None:
            keys = SirenMorpherProtocol03Keys()
        if indices is None:
            indices = SirenMorpherProtocol03Indices()
        self.keys = keys
        self.indices = indices
        self.computation_steps[keys.image] = batch_indexing_func(indices.batch_image)
        self.computation_steps[keys.pose] = batch_indexing_func(indices.batch_pose)
        self.computation_steps[keys.face_mask] = batch_indexing_func(indices.batch_face_mask)
        self.grid_change_applier = GridChangeApplier()

        # NOTE: this pose-only step (and the proxy registration just below) is
        # superseded further down by the image-conditioned module step and the
        # indexed predicted_posed_image step; the later registration for a key wins.
        @add_step(self.computation_steps, keys.module_output)
        def get_module_output(protocol: CachedComputationProtocol, state: ComputationState):
            pose = protocol.get_output(keys.pose, state)
            module = state.modules[self.keys.module]
            return module.forward(pose)

        self.computation_steps[keys.predicted_posed_image] = proxy_func(keys.module_output)

        @add_step(self.computation_steps, keys.poser_output)
        def get_poser_output(protocol: CachedComputationProtocol, state: ComputationState):
            with torch.no_grad():
                poser = state.modules[keys.poser]
                pose = protocol.get_output(keys.pose, state)
                image = protocol.get_output(keys.image, state)
                return poser.get_posing_outputs(image, pose)

        @add_step(self.computation_steps, keys.groundtruth_posed_image)
        def get_groundtruth_posed_image(protocol: CachedComputationProtocol, state: ComputationState):
            return protocol.get_output(keys.poser_output, state)[indices.poser_posed_image]

        @add_step(self.computation_steps, keys.groundtruth_grid_change)
        def get_groundtruth_grid_change(protocol: CachedComputationProtocol, state: ComputationState):
            return protocol.get_output(keys.poser_output, state)[indices.poser_grid_change]

        @add_step(self.computation_steps, keys.groundtruth_posed_face_mask)
        def get_groundtruth_posed_face_mask(protocol: CachedComputationProtocol, state: ComputationState):
            face_mask = protocol.get_output(keys.face_mask, state)
            groundtruth_grid_change = protocol.get_output(keys.groundtruth_grid_change, state)
            with torch.no_grad():
                return self.grid_change_applier.apply(groundtruth_grid_change, face_mask)

        @add_step(self.computation_steps, keys.module_input_image)
        def get_module_input_image(protocol: CachedComputationProtocol, state: ComputationState):
            poser_output = protocol.get_output(keys.poser_output, state)
            return poser_output[indices.poser_output_module_input_image_index]

        @add_step(self.computation_steps, keys.module_output)
        def get_module_output(protocol: CachedComputationProtocol, state: ComputationState):
            image = protocol.get_output(keys.module_input_image, state)
            pose = protocol.get_output(keys.pose, state)
            module = state.modules[self.keys.module]
            return module.forward(image, pose)
        self.computation_steps[keys.predicted_posed_image] = output_array_indexing_func(
            keys.module_output, indices.module_blended_image)
        self.computation_steps[keys.predicted_grid_change] = output_array_indexing_func(
            keys.module_output, indices.module_grid_change)
        self.computation_steps[keys.predicted_color_change] = output_array_indexing_func(
            keys.module_output, indices.module_color_change)
        self.computation_steps[keys.predicted_warped_image] = output_array_indexing_func(
            keys.module_output, indices.module_warped_image)
        self.computation_steps[keys.predicted_alpha] = output_array_indexing_func(
            keys.module_output, indices.module_alpha)
        self.computation_steps[keys.groundtruth_alpha] = output_array_indexing_func(
            keys.poser_output, indices.poser_alpha)
        self.computation_steps[keys.groundtruth_warped_image] = output_array_indexing_func(
            keys.poser_output, indices.poser_warped_image)

        @add_step(self.computation_steps, keys.zero)
        def get_zero(protocol: CachedComputationProtocol, state: ComputationState):
            pose = protocol.get_output(keys.pose, state)
            device = pose.device
            return torch.zeros(1, device=device)


class SirenMorpherTrainingProtocol03(AbstractTrainingProtocol):
    def __init__(self,
                 check_point_examples: List[int],
                 batch_size: int,
                 learning_rate: Callable[[int], Dict[str, float]],
                 optimizer_factories: Dict[str, OptimizerFactory],
                 random_seed: int,
                 poser_func: Callable[[], GeneralPoser02],
                 key_module: str,
                 key_poser: str = KEY_POSER,
                 key_examples_seen_so_far: str = KEY_EXAMPLES_SEEN_SO_FAR):
        super().__init__(check_point_examples, batch_size, learning_rate, optimizer_factories, random_seed)
        self.key_examples_seen_so_far = key_examples_seen_so_far
        self.key_poser = key_poser
        self.key_module = key_module
        self.poser_func = poser_func
        self.poser = None

    def run_training_iteration(
            self,
            batch: Any,
            examples_seen_so_far: int,
            modules: Dict[str, Module],
            accumulated_modules: Dict[str, Module],
            optimizers: Dict[str, Optimizer],
            losses: Dict[str, Loss],
            create_log_func: Optional[Callable[[str, int], Callable[[str, float], None]]],
            device: torch.device):
        if self.poser is None:
            self.poser = self.poser_func()
            self.poser.to(device)
        module = modules[self.key_module]
        module.train(True)
        module_optimizer = optimizers[self.key_module]
        module_optimizer.zero_grad(set_to_none=True)
        loss = losses[self.key_module]
        if create_log_func is not None:
            log_func = create_log_func(f"training_{self.key_module}", examples_seen_so_far)
        else:
            log_func = None
        state = ComputationState(
            modules={
                **modules,
                self.key_poser: self.poser,
            },
            accumulated_modules=accumulated_modules,
            batch=batch,
            outputs={
                self.key_examples_seen_so_far: examples_seen_so_far,
            })
        loss_value = loss.compute(state, log_func)
        loss_value.backward()
        module_optimizer.step()


class SirenMorpherSampleOutputProtocol(SampleOutputProtocol):
    def __init__(self,
                 num_images: int,
                 image_size: int,
                 examples_per_sample_output: int,
                 computation_protocol,
                 poser_func: Callable[[], GeneralPoser02],
                 random_seed: int = 54859395058,
                 batch_image_index: int = 0,
                 batch_pose_index: int = 2,
                 batch_size: Optional[int] = None,
                 sample_image_specs: Optional[List[SampleImageSpec]] = None,
                 cell_size: Optional[int] = None):
        if batch_size is None:
            batch_size = num_images
        if sample_image_specs is None:
            sample_image_specs = [
                SampleImageSpec(
                    ImageSource.BATCH,
                    computation_protocol.indices.poser_posed_image,
                    ImageType.COLOR),
                SampleImageSpec(
                    ImageSource.OUTPUT,
                    computation_protocol.indices.module_blended_image,
                    ImageType.COLOR),
                SampleImageSpec(
                    ImageSource.OUTPUT,
computation_protocol.indices.module_alpha, ImageType.ALPHA), SampleImageSpec( ImageSource.OUTPUT, computation_protocol.indices.module_color_change, ImageType.COLOR), SampleImageSpec( ImageSource.OUTPUT, computation_protocol.indices.module_warped_image, ImageType.COLOR), SampleImageSpec( ImageSource.BATCH, computation_protocol.indices.poser_grid_change, ImageType.GRID_CHANGE), SampleImageSpec( ImageSource.OUTPUT, computation_protocol.indices.module_grid_change, ImageType.GRID_CHANGE), ] if cell_size is None: cell_size = image_size self.batch_size = batch_size self.batch_pose_index = batch_pose_index self.batch_image_index = batch_image_index self.poser_func = poser_func self.random_seed = random_seed self.examples_per_sample_output = examples_per_sample_output self.image_size = image_size self.num_images = num_images self.computation_protocol = computation_protocol self.cell_size = cell_size self.sample_image_saver = SampleImageSaver(image_size, cell_size, 4, sample_image_specs) def get_examples_per_sample_output(self) -> int: return self.examples_per_sample_output def get_random_seed(self) -> int: return self.random_seed def get_sample_output_data(self, validation_dataset: Dataset, device: torch.device) -> dict: example_indices = torch.randint(0, len(validation_dataset), (self.num_images,)) example_indices = [example_indices[i].item() for i in range(self.num_images)] batch = get_indexed_batch(validation_dataset, example_indices, device) poser = self.poser_func() poser.to(device) with torch.no_grad(): ground_truth = poser.get_posing_outputs(batch[self.batch_image_index], batch[self.batch_pose_index]) return { 'batch': batch, 'ground_truth': ground_truth } def save_sample_output_data(self, modules: Dict[str, Module], accumulated_modules: Dict[str, Module], sample_output_data: Any, prefix: str, examples_seen_so_far: int, device: torch.device): batch = sample_output_data['batch'] ground_truth = sample_output_data['ground_truth'] module = modules[self.computation_protocol.keys.module] module.train(False) if self.batch_size == self.num_images: with torch.no_grad(): state = ComputationState( modules=modules, accumulated_modules=accumulated_modules, batch=batch, outputs={ self.computation_protocol.keys.poser_output: ground_truth, }) module_outputs = self.computation_protocol.get_output( self.computation_protocol.keys.module_output, state) else: module_outputs_list = [] start = 0 while start < self.num_images: end = start + self.batch_size end = min(self.num_images, end) minibatch = [batch[i][start:end] for i in range(len(batch))] ground_truth_batch = [ground_truth[i][start:end] for i in range(len(ground_truth))] state = ComputationState( modules=modules, accumulated_modules=accumulated_modules, batch=minibatch, outputs={ self.computation_protocol.keys.poser_output: ground_truth_batch }) with torch.no_grad(): module_outputs = self.computation_protocol.get_output( self.computation_protocol.keys.module_output, state) module_outputs_list.append(module_outputs) start = end module_outputs = [] for i in range(len(module_outputs_list[0])): tensor_list = [] for j in range(len(module_outputs_list)): tensor_list.append(module_outputs_list[j][i]) module_output = torch.cat(tensor_list, dim=0) module_outputs.append(module_output) self.sample_image_saver.save_sample_output_data(ground_truth, module_outputs, prefix, examples_seen_so_far) ================================================ FILE: src/tha4/nn/siren/vanilla/__init__.py ================================================ 
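The protocol classes above (and the mode_07/mode_12/mode_14 protocols later in this section) all follow the same step-registry pattern: named computation steps stored in a dict, resolved on demand, and memoized per computation state. Below is a minimal, self-contained sketch of that pattern; the State, Protocol, and add_step names are illustrative stand-ins, not the actual tha4.shion API.

# usage_sketch_cached_computation.py -- illustrative only, not part of the repository.
from typing import Any, Callable, Dict

class State:
    """Holds the input batch and memoized step outputs (stand-in for ComputationState)."""
    def __init__(self, batch, outputs=None):
        self.batch = batch
        self.outputs: Dict[str, Any] = outputs if outputs is not None else {}

class Protocol:
    """Stand-in for the cached computation protocol: a dict of named steps."""
    def __init__(self):
        self.computation_steps: Dict[str, Callable[["Protocol", State], Any]] = {}

    def get_output(self, key: str, state: State) -> Any:
        # Each step runs at most once per state; results are memoized in state.outputs.
        if key not in state.outputs:
            state.outputs[key] = self.computation_steps[key](self, state)
        return state.outputs[key]

def add_step(steps: Dict[str, Callable], key: str):
    # Registers the decorated function under `key`; a later registration for the
    # same key silently replaces an earlier one, which is why only the final
    # registration for a key in an __init__ takes effect.
    def decorator(func):
        steps[key] = func
        return func
    return decorator

protocol = Protocol()

@add_step(protocol.computation_steps, "pose")
def get_pose(p: Protocol, s: State):
    return s.batch[1]

@add_step(protocol.computation_steps, "posed_image")
def get_posed_image(p: Protocol, s: State):
    pose = p.get_output("pose", s)  # pulls (and caches) the "pose" step
    return f"posed_with({pose})"

state = State(batch=["rest_image", "pose_vector"])
print(protocol.get_output("posed_image", state))  # posed_with(pose_vector)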
================================================ FILE: src/tha4/nn/siren/vanilla/siren.py ================================================ import math from typing import Callable, Optional, List import torch from torch import Tensor from torch.nn import Module, Conv2d, ModuleList from tha4.shion.core.module_factory import ModuleFactory from tha4.shion.nn00.initialization_funcs import HeInitialization class SineLinearLayer(Module): def __init__(self, in_channels: int, out_channels: int, is_first=False, omega_0=30.0): super().__init__() self.out_channels = out_channels self.in_channels = in_channels self.omega_0 = omega_0 self.is_first = is_first self.linear = Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0, bias=True) with torch.no_grad(): if self.is_first: self.linear.weight.uniform_(-1 / in_channels, 1.0 / in_channels) else: self.linear.weight.uniform_( -math.sqrt(6.0 / in_channels) / self.omega_0, math.sqrt(6.0 / in_channels) / self.omega_0) def forward(self, x: Tensor): return torch.sin(self.omega_0 * self.linear(x)) class SirenArgs: def __init__( self, in_channels: int, out_channels: int, intermediate_channels: int, num_sine_layers: int, use_tanh: bool = False, init_func: Optional[Callable[[Module], Module]] = None): if init_func is None: init_func = HeInitialization() self.init_func = init_func self.use_tanh = use_tanh assert num_sine_layers >= 1 self.intermediate_channels = intermediate_channels self.num_sine_layers = num_sine_layers self.out_channels = out_channels self.in_channels = in_channels class Siren(Module): def __init__(self, args: SirenArgs): super().__init__() self.args = args self.sine_layers = ModuleList() self.sine_layers.append( SineLinearLayer( in_channels=args.in_channels, out_channels=args.intermediate_channels, is_first=True)) for i in range(args.num_sine_layers - 1): self.sine_layers.append( SineLinearLayer( in_channels=args.intermediate_channels, out_channels=args.intermediate_channels, is_first=False)) self.last_linear = args.init_func(Conv2d( args.intermediate_channels, args.out_channels, kernel_size=1, stride=1, padding=0, bias=True)) def forward(self, x: Tensor) -> Tensor: for i in range(self.args.num_sine_layers): x = self.sine_layers[i].forward(x) x = self.last_linear(x) if self.args.use_tanh: return torch.tanh(x) else: return x class SirenFactory(ModuleFactory): def __init__(self, args: SirenArgs): super().__init__() self.args = args def create(self) -> Module: return Siren(self.args) ================================================ FILE: src/tha4/nn/spectral_norm.py ================================================ from torch.nn import Module from torch.nn.utils import spectral_norm def apply_spectral_norm(module: Module, use_spectral_norm: bool = False) -> Module: if use_spectral_norm: return spectral_norm(module) else: return module ================================================ FILE: src/tha4/nn/upscaler/__init__.py ================================================ ================================================ FILE: src/tha4/nn/upscaler/upscaler_02.py ================================================ from typing import List import torch from torch import Tensor, zero_ from torch.nn import Module, Conv2d from tha4.shion.core.module_factory import ModuleFactory from tha4.nn.image_processing_util import GridChangeApplier from tha4.nn.common.unet import UnetArgs, AttentionBlockArgs, UnetWithFirstConvAddition class Upscaler02Args: def __init__(self, image_size: int, image_channels: int,
num_pose_parameters: int, unet_args: UnetArgs): assert unet_args.in_channels == ( image_channels ) assert unet_args.out_channels == ( image_channels + # direct 2 + # warp 1 # alpha ) assert unet_args.cond_input_channels == num_pose_parameters self.image_channels = image_channels self.image_size = image_size self.num_pose_parameters = num_pose_parameters self.unet_args = unet_args def apply_color_change(alpha, color_change, image: Tensor) -> Tensor: return color_change * alpha + image * (1 - alpha) class Upscaler02(Module): def __init__(self, args: Upscaler02Args): super().__init__() self.args = args self.body = UnetWithFirstConvAddition(args.unet_args) self.grid_change_applier = GridChangeApplier() self.coarse_image_conv = Conv2d( args.image_channels + args.image_channels + 2, args.unet_args.model_channels, kernel_size=3, stride=1, padding=1) with torch.no_grad(): zero_(self.coarse_image_conv.weight) zero_(self.coarse_image_conv.bias) def check_image(self, image: torch.Tensor): assert len(image.shape) == 4 assert image.shape[1] == self.args.image_channels assert image.shape[2] == self.args.image_size assert image.shape[3] == self.args.image_size def forward(self, rest_image: torch.Tensor, coarse_posed_image: torch.Tensor, coarse_grid_change: torch.Tensor, pose: torch.Tensor) -> List[Tensor]: self.check_image(rest_image) self.check_image(coarse_posed_image) assert len(pose.shape) == 2 assert rest_image.shape[0] == pose.shape[0] assert coarse_posed_image.shape[0] == pose.shape[0] assert coarse_grid_change.shape[0] == pose.shape[0] assert coarse_grid_change.shape[1] == 2 assert coarse_grid_change.shape[2] == self.args.image_size assert coarse_grid_change.shape[3] == self.args.image_size assert pose.shape[1] == self.args.num_pose_parameters warped_image = self.grid_change_applier.apply(coarse_grid_change, rest_image) t = torch.zeros(rest_image.shape[0], 1, device=rest_image.device) feature = torch.cat([coarse_posed_image, warped_image, coarse_grid_change], dim=1) first_conv_addition = self.coarse_image_conv(feature) body_output = self.body(rest_image, t, pose, first_conv_addition) direct = body_output[:, 0:self.args.image_channels, :, :] grid_change = body_output[:, self.args.image_channels:self.args.image_channels + 2, :, :] alpha = torch.sigmoid(body_output[:, self.args.image_channels + 2:self.args.image_channels + 3, :, :]) warped = self.grid_change_applier.apply(grid_change, rest_image) merged = apply_color_change(alpha, direct, warped) return [ merged, alpha, warped, grid_change, direct ] INDEX_MERGED = 0 INDEX_ALPHA = 1 INDEX_WARPED = 2 INDEX_GRID_CHANGE = 3 INDEX_DIRECT = 4 class Upscaler02Factory(ModuleFactory): def __init__(self, args: Upscaler02Args): self.args = args def create(self) -> Module: return Upscaler02(self.args) ================================================ FILE: src/tha4/nn/util.py ================================================ from typing import Optional, Callable, Union from torch.nn import Module from tha4.shion.core.module_factory import ModuleFactory from tha4.nn.init_function import create_init_function from tha4.nn.nonlinearity_factory import resolve_nonlinearity_factory from tha4.nn.normalization import NormalizationLayerFactory from tha4.nn.spectral_norm import apply_spectral_norm def wrap_conv_or_linear_module(module: Module, initialization_method: Union[str, Callable[[Module], Module]], use_spectral_norm: bool): if isinstance(initialization_method, str): init = create_init_function(initialization_method) else: init = initialization_method return 
apply_spectral_norm(init(module), use_spectral_norm) class BlockArgs: def __init__(self, initialization_method: Union[str, Callable[[Module], Module]] = 'he', use_spectral_norm: bool = False, normalization_layer_factory: Optional[NormalizationLayerFactory] = None, nonlinearity_factory: Optional[ModuleFactory] = None): self.nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory) self.normalization_layer_factory = normalization_layer_factory self.use_spectral_norm = use_spectral_norm self.initialization_method = initialization_method def wrap_module(self, module: Module) -> Module: return wrap_conv_or_linear_module(module, self.get_init_func(), self.use_spectral_norm) def get_init_func(self) -> Callable[[Module], Module]: if isinstance(self.initialization_method, str): return create_init_function(self.initialization_method) else: return self.initialization_method ================================================ FILE: src/tha4/poser/__init__.py ================================================ ================================================ FILE: src/tha4/poser/general_poser_02.py ================================================ from typing import List, Optional, Tuple, Dict, Callable import torch from tha4.shion.core.cached_computation import ComputationState from tha4.poser.poser import PoseParameterGroup, Poser from torch import Tensor from torch.nn import Module class GeneralPoser02(Poser): def __init__(self, module_loaders: Dict[str, Callable[[], Module]], device: torch.device, output_length: int, pose_parameters: List[PoseParameterGroup], output_list_func: Callable[[ComputationState], List[Tensor]], subrect: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None, default_output_index: int = 0, image_size: int = 256, dtype: torch.dtype = torch.float): self.dtype = dtype self.image_size = image_size self.default_output_index = default_output_index self.output_list_func = output_list_func self.subrect = subrect self.pose_parameters = pose_parameters self.device = device self.module_loaders = module_loaders self.modules = None self.num_parameters = 0 for pose_parameter in self.pose_parameters: self.num_parameters += pose_parameter.get_arity() self.output_length = output_length def get_image_size(self) -> int: return self.image_size def get_modules(self): if self.modules is None: self.modules = {} for key in self.module_loaders: module = self.module_loaders[key]() self.modules[key] = module module.to(self.device) module.train(False) return self.modules def get_pose_parameter_groups(self) -> List[PoseParameterGroup]: return self.pose_parameters def get_num_parameters(self) -> int: return self.num_parameters def pose(self, image: Tensor, pose: Tensor, output_index: Optional[int] = None) -> Tensor: if output_index is None: output_index = self.default_output_index output_list = self.get_posing_outputs(image, pose) return output_list[output_index] def get_posing_outputs(self, image: Tensor, pose: Tensor) -> List[Tensor]: modules = self.get_modules() if len(image.shape) == 3: image = image.unsqueeze(0) if len(pose.shape) == 1: pose = pose.unsqueeze(0) if self.subrect is not None: image = image[:, :, self.subrect[0][0]:self.subrect[0][1], self.subrect[1][0]:self.subrect[1][1]] batch = [image, pose] state = ComputationState( modules=modules, accumulated_modules={}, batch=batch, outputs={}) return self.output_list_func(state) def get_output_length(self) -> int: return self.output_length def free(self): self.modules = None def get_dtype(self) -> torch.dtype: return self.dtype def 
to(self, device: torch.device) -> 'GeneralPoser02': if device == self.device: return self modules = self.get_modules() self.device = device for key in modules: module = modules[key] module.to(self.device) return self ================================================ FILE: src/tha4/poser/modes/__init__.py ================================================ ================================================ FILE: src/tha4/poser/modes/mode_07.py ================================================ from enum import Enum from typing import List, Dict, Optional import torch from tha4.shion.core.cached_computation import CachedComputationProtocol, ComputationState from tha4.shion.core.load_save import torch_load from tha4.nn.eyebrow_decomposer.eyebrow_decomposer_00 import EyebrowDecomposer00, \ EyebrowDecomposer00Factory, EyebrowDecomposer00Args from tha4.nn.eyebrow_morphing_combiner.eyebrow_morphing_combiner_00 import \ EyebrowMorphingCombiner00Factory, EyebrowMorphingCombiner00Args, EyebrowMorphingCombiner00 from tha4.nn.face_morpher.face_morpher_08 import FaceMorpher08Args, FaceMorpher08Factory from tha4.nn.nonlinearity_factory import ReLUFactory from tha4.nn.normalization import InstanceNorm2dFactory from tha4.nn.util import BlockArgs from tha4.nn.common.unet import UnetArgs, AttentionBlockArgs from tha4.nn.morpher.morpher_00 import Morpher00Args, Morpher00 from tha4.nn.upscaler.upscaler_02 import Upscaler02Args, Upscaler02 from tha4.poser.general_poser_02 import GeneralPoser02 from tha4.poser.modes.pose_parameters import get_pose_parameters from torch import Tensor from torch.nn.functional import interpolate class Network(Enum): eyebrow_decomposer = 1 eyebrow_morphing_combiner = 2 face_morpher = 3 body_morpher = 4 upscaler = 5 @property def outputs_key(self): return f"{self.name}_outputs" class Branch(Enum): face_morphed_half = 1 face_morphed_full = 2 all_outputs = 3 NUM_EYEBROW_PARAMS = 12 NUM_FACE_PARAMS = 27 NUM_ROTATION_PARAMS = 6 class FiveStepPoserComputationProtocol(CachedComputationProtocol): def __init__(self, eyebrow_morphed_image_index: int): super().__init__() self.eyebrow_morphed_image_index = eyebrow_morphed_image_index self.cached_batch_0 = None self.cached_eyebrow_decomposer_output = None def compute_func(self): def func(state: ComputationState) -> List[Tensor]: if self.cached_batch_0 is None: new_batch_0 = True elif state.batch[0].shape[0] != self.cached_batch_0.shape[0]: new_batch_0 = True else: new_batch_0 = torch.max((state.batch[0] - self.cached_batch_0).abs()).item() > 0 if not new_batch_0: state.outputs[Network.eyebrow_decomposer.outputs_key] = self.cached_eyebrow_decomposer_output output = self.get_output(Branch.all_outputs.name, state) if new_batch_0: self.cached_batch_0 = state.batch[0] self.cached_eyebrow_decomposer_output = state.outputs[Network.eyebrow_decomposer.outputs_key] return output return func def compute_output(self, key: str, state: ComputationState) -> List[Tensor]: if key == Network.eyebrow_decomposer.outputs_key: input_image = state.batch[0][:, :, 64:192, 64 + 128:192 + 128] return state.modules[Network.eyebrow_decomposer.name].forward(input_image) elif key == Network.eyebrow_morphing_combiner.outputs_key: eyebrow_decomposer_output = self.get_output(Network.eyebrow_decomposer.outputs_key, state) background_layer = eyebrow_decomposer_output[EyebrowDecomposer00.BACKGROUND_LAYER_INDEX] eyebrow_layer = eyebrow_decomposer_output[EyebrowDecomposer00.EYEBROW_LAYER_INDEX] eyebrow_pose = state.batch[1][:, :NUM_EYEBROW_PARAMS] return 
state.modules[Network.eyebrow_morphing_combiner.name].forward( background_layer, eyebrow_layer, eyebrow_pose) elif key == Network.face_morpher.outputs_key: eyebrow_morphing_combiner_output = self.get_output( Network.eyebrow_morphing_combiner.outputs_key, state) eyebrow_morphed_image = eyebrow_morphing_combiner_output[self.eyebrow_morphed_image_index] input_image = state.batch[0][:, :, 32:32 + 192, (32 + 128):(32 + 192 + 128)].clone() input_image[:, :, 32:32 + 128, 32:32 + 128] = eyebrow_morphed_image face_pose = state.batch[1][:, NUM_EYEBROW_PARAMS:NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS] return state.modules[Network.face_morpher.name].forward(input_image, face_pose) elif key == Branch.face_morphed_full.name: face_morpher_output = self.get_output(Network.face_morpher.outputs_key, state) face_morphed_image = face_morpher_output[0] input_image = state.batch[0].clone() input_image[:, :, 32:32 + 192, 32 + 128:32 + 192 + 128] = face_morphed_image return [input_image] elif key == Branch.face_morphed_half.name: face_morphed_full = self.get_output(Branch.face_morphed_full.name, state)[0] return [ interpolate(face_morphed_full, size=(256, 256), mode='bilinear', align_corners=False) ] elif key == Network.body_morpher.outputs_key: face_morphed_half = self.get_output(Branch.face_morphed_half.name, state)[0] rotation_pose = state.batch[1][:, NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS:] return state.modules[Network.body_morpher.name].forward(face_morphed_half, rotation_pose) elif key == Network.upscaler.outputs_key: rest_image = self.get_output(Branch.face_morphed_full.name, state)[0] body_morpher_outputs = self.get_output( Network.body_morpher.outputs_key, state) half_res_posed_image = body_morpher_outputs[Morpher00.INDEX_MERGED] half_res_grid_change = body_morpher_outputs[Morpher00.INDEX_GRID_CHANGE] coarse_posed_image = interpolate(half_res_posed_image, size=(512, 512), mode='bilinear') coarse_grid_change = interpolate(half_res_grid_change, size=(512, 512), mode='bilinear') rotation_pose = state.batch[1][:, NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS:] return state.modules[Network.upscaler.name].forward( rest_image, coarse_posed_image, coarse_grid_change, rotation_pose) elif key == Branch.all_outputs.name: upscaler_output = self.get_output(Network.upscaler.outputs_key, state) face_morphed_full = self.get_output(Branch.face_morphed_full.name, state) body_morpher_output = self.get_output(Network.body_morpher.outputs_key, state) face_morpher_output = self.get_output(Network.face_morpher.outputs_key, state) eyebrow_morphing_combiner_output = self.get_output(Network.eyebrow_morphing_combiner.outputs_key, state) eyebrow_decomposer_output = self.get_output(Network.eyebrow_decomposer.outputs_key, state) output = upscaler_output \ + face_morphed_full \ + body_morpher_output \ + face_morpher_output \ + eyebrow_morphing_combiner_output \ + eyebrow_decomposer_output return output else: raise RuntimeError("Unsupported key: " + key) def load_eyebrow_decomposer(file_name: str): factory = EyebrowDecomposer00Factory( EyebrowDecomposer00Args( image_size=128, image_channels=4, start_channels=64, bottleneck_image_size=16, num_bottleneck_blocks=6, max_channels=512, block_args=BlockArgs( initialization_method='he', use_spectral_norm=False, normalization_layer_factory=InstanceNorm2dFactory(), nonlinearity_factory=ReLUFactory(inplace=True)))) print("Loading the eyebrow decomposer ... 
", end="") module = factory.create() module.load_state_dict(torch_load(file_name)) print("DONE!!!") return module def load_eyebrow_morphing_combiner(file_name: str): factory = EyebrowMorphingCombiner00Factory( EyebrowMorphingCombiner00Args( image_size=128, image_channels=4, start_channels=64, num_pose_params=12, bottleneck_image_size=16, num_bottleneck_blocks=6, max_channels=512, block_args=BlockArgs( initialization_method='he', use_spectral_norm=False, normalization_layer_factory=InstanceNorm2dFactory(), nonlinearity_factory=ReLUFactory(inplace=True)))) print("Loading the eyebrow morphing conbiner ... ", end="") module = factory.create() module.load_state_dict(torch_load(file_name)) print("DONE!!!") return module def load_face_morpher(file_name: str): factory = FaceMorpher08Factory( FaceMorpher08Args( image_size=192, image_channels=4, num_expression_params=27, start_channels=64, bottleneck_image_size=24, num_bottleneck_blocks=6, max_channels=512, block_args=BlockArgs( initialization_method='he', use_spectral_norm=False, normalization_layer_factory=InstanceNorm2dFactory(), nonlinearity_factory=ReLUFactory(inplace=False) ), output_iris_mouth_grid_change=True, ) ) print("Loading the face morpher ... ", end="") module = factory.create() module.load_state_dict(torch_load(file_name)) print("DONE!!!") return module def apply_color_change(alpha, color_change, image: Tensor) -> Tensor: return color_change * alpha + image * (1 - alpha) def load_morpher_00(file_name: str): unet_args = UnetArgs( in_channels=4, out_channels=7, model_channels=64, level_channel_multipliers=[1, 2, 4, 4, 4], level_use_attention=[False, False, False, False, True], num_res_blocks_per_level=1, num_middle_res_blocks=4, time_embedding_channels=None, cond_input_channels=6, cond_internal_channels=256, attention_block_args=AttentionBlockArgs( num_heads=8, use_new_attention_order=True), dropout_prob=0.0) morpher_00_args = Morpher00Args( image_size=256, image_channels=4, num_pose_parameters=6, unet_args=unet_args) morpher_00 = Morpher00(morpher_00_args) print("Loading the body morpher ... ", end="") morpher_00.load_state_dict(torch_load(file_name)) print("DONE") morpher_00.train(False) return morpher_00 def load_upscaler_02(file_name: str): unet_args = UnetArgs( in_channels=4, out_channels=7, model_channels=32, level_channel_multipliers=[1, 2, 4, 8, 8, 8], level_use_attention=[False, False, False, False, False, True], num_res_blocks_per_level=1, num_middle_res_blocks=4, time_embedding_channels=None, cond_input_channels=6, cond_internal_channels=256, attention_block_args=AttentionBlockArgs( num_heads=8, use_new_attention_order=True), dropout_prob=0.0) upscaler_02_args = Upscaler02Args( image_size=512, image_channels=4, num_pose_parameters=6, unet_args=unet_args) upscaler_02 = Upscaler02(upscaler_02_args) print("Loading the upscaler ... 
", end="") upscaler_02.load_state_dict(torch_load(file_name)) print("DONE") upscaler_02.train(False) return upscaler_02 def create_poser( device: torch.device, module_file_names: Optional[Dict[str, str]] = None, eyebrow_morphed_image_index: int = EyebrowMorphingCombiner00.EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX, default_output_index: int = 0) -> GeneralPoser02: if module_file_names is None: module_file_names = {} if Network.eyebrow_decomposer.name not in module_file_names: file_name = "data/tha4/eyebrow_decomposer.pt" module_file_names[Network.eyebrow_decomposer.name] = file_name if Network.eyebrow_morphing_combiner.name not in module_file_names: file_name = "data/tha4/eyebrow_morphing_combiner.pt" module_file_names[Network.eyebrow_morphing_combiner.name] = file_name if Network.face_morpher.name not in module_file_names: file_name = "data/tha4/face_morpher.pt" module_file_names[Network.face_morpher.name] = file_name if Network.body_morpher.name not in module_file_names: file_name = "data/tha4/body_morpher.pt" module_file_names[Network.body_morpher.name] = file_name if Network.upscaler.name not in module_file_names: file_name = "data/tha4/upscaler.pt" module_file_names[Network.upscaler.name] = file_name loaders = { Network.eyebrow_decomposer.name: lambda: load_eyebrow_decomposer(module_file_names[Network.eyebrow_decomposer.name]), Network.eyebrow_morphing_combiner.name: lambda: load_eyebrow_morphing_combiner(module_file_names[Network.eyebrow_morphing_combiner.name]), Network.face_morpher.name: lambda: load_face_morpher(module_file_names[Network.face_morpher.name]), Network.body_morpher.name: lambda: load_morpher_00(module_file_names[Network.body_morpher.name]), Network.upscaler.name: lambda: load_upscaler_02(module_file_names[Network.upscaler.name]), } return GeneralPoser02( image_size=512, module_loaders=loaders, pose_parameters=get_pose_parameters().get_pose_parameter_groups(), output_list_func=FiveStepPoserComputationProtocol(eyebrow_morphed_image_index).compute_func(), subrect=None, device=device, output_length=5 + 1 + 5 + 8 + 8 + 6, default_output_index=default_output_index) ================================================ FILE: src/tha4/poser/modes/mode_12.py ================================================ from enum import Enum from typing import List, Dict, Optional, Any import torch from tha4.shion.core.cached_computation import CachedComputationProtocol, ComputationState from tha4.shion.core.load_save import torch_load from tha4.nn.eyebrow_decomposer.eyebrow_decomposer_00 import EyebrowDecomposer00, \ EyebrowDecomposer00Factory, EyebrowDecomposer00Args from tha4.nn.eyebrow_morphing_combiner.eyebrow_morphing_combiner_00 import \ EyebrowMorphingCombiner00Factory, EyebrowMorphingCombiner00Args, EyebrowMorphingCombiner00 from tha4.nn.face_morpher.face_morpher_08 import FaceMorpher08Args, FaceMorpher08Factory from tha4.nn.nonlinearity_factory import ReLUFactory from tha4.nn.normalization import InstanceNorm2dFactory from tha4.nn.util import BlockArgs from tha4.poser.general_poser_02 import GeneralPoser02 from tha4.poser.modes.pose_parameters import get_pose_parameters from torch import Tensor class Network(Enum): eyebrow_decomposer = 1 eyebrow_morphing_combiner = 2 face_morpher = 3 @property def outputs_key(self): return f"{self.name}_outputs" class Branch(Enum): face_morphed_half = 1 face_morphed_full = 2 all_outputs = 3 NUM_EYEBROW_PARAMS = 12 NUM_FACE_PARAMS = 27 NUM_ROTATION_PARAMS = 6 class FiveStepPoserComputationProtocol(CachedComputationProtocol): def __init__(self, 
eyebrow_morphed_image_index: int): super().__init__() self.eyebrow_morphed_image_index = eyebrow_morphed_image_index self.cached_batch_0 = None self.cached_eyebrow_decomposer_output = None def compute_func(self): def func(state: ComputationState) -> List[Tensor]: if self.cached_batch_0 is None: new_batch_0 = True elif state.batch[0].shape[0] != self.cached_batch_0.shape[0]: new_batch_0 = True else: new_batch_0 = torch.max((state.batch[0] - self.cached_batch_0).abs()).item() > 0 if not new_batch_0: state.outputs[Network.eyebrow_decomposer.outputs_key] = self.cached_eyebrow_decomposer_output output = self.get_output(Branch.all_outputs.name, state) if new_batch_0: self.cached_batch_0 = state.batch[0] self.cached_eyebrow_decomposer_output = state.outputs[Network.eyebrow_decomposer.outputs_key] return output return func def compute_output(self, key: str, state: ComputationState) -> Any: if key == Network.eyebrow_decomposer.outputs_key: input_image = state.batch[0][:, :, 64:192, 64 + 128:192 + 128] return state.modules[Network.eyebrow_decomposer.name].forward(input_image) elif key == Network.eyebrow_morphing_combiner.outputs_key: eyebrow_decomposer_output = self.get_output(Network.eyebrow_decomposer.outputs_key, state) background_layer = eyebrow_decomposer_output[EyebrowDecomposer00.BACKGROUND_LAYER_INDEX] eyebrow_layer = eyebrow_decomposer_output[EyebrowDecomposer00.EYEBROW_LAYER_INDEX] eyebrow_pose = state.batch[1][:, :NUM_EYEBROW_PARAMS] return state.modules[Network.eyebrow_morphing_combiner.name].forward( background_layer, eyebrow_layer, eyebrow_pose) elif key == Network.face_morpher.outputs_key: eyebrow_morphing_combiner_output = self.get_output( Network.eyebrow_morphing_combiner.outputs_key, state) eyebrow_morphed_image = eyebrow_morphing_combiner_output[self.eyebrow_morphed_image_index] input_image = state.batch[0][:, :, 32:32 + 192, (32 + 128):(32 + 192 + 128)].clone() input_image[:, :, 32:32 + 128, 32:32 + 128] = eyebrow_morphed_image face_pose = state.batch[1][:, NUM_EYEBROW_PARAMS:NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS] return state.modules[Network.face_morpher.name].forward(input_image, face_pose) elif key == Branch.all_outputs.name: face_morpher_output = self.get_output(Network.face_morpher.outputs_key, state) eyebrow_morphing_combiner_output = self.get_output(Network.eyebrow_morphing_combiner.outputs_key, state) eyebrow_decomposer_output = self.get_output(Network.eyebrow_decomposer.outputs_key, state) output = face_morpher_output \ + eyebrow_morphing_combiner_output \ + eyebrow_decomposer_output return output else: raise RuntimeError("Unsupported key: " + key) def load_eyebrow_decomposer(file_name: str): factory = EyebrowDecomposer00Factory( EyebrowDecomposer00Args( image_size=128, image_channels=4, start_channels=64, bottleneck_image_size=16, num_bottleneck_blocks=6, max_channels=512, block_args=BlockArgs( initialization_method='he', use_spectral_norm=False, normalization_layer_factory=InstanceNorm2dFactory(), nonlinearity_factory=ReLUFactory(inplace=True)))) print("Loading the eyebrow decomposer ... 
", end="") module = factory.create() module.load_state_dict(torch_load(file_name)) print("DONE!!!") return module def load_eyebrow_morphing_combiner(file_name: str): factory = EyebrowMorphingCombiner00Factory( EyebrowMorphingCombiner00Args( image_size=128, image_channels=4, start_channels=64, num_pose_params=12, bottleneck_image_size=16, num_bottleneck_blocks=6, max_channels=512, block_args=BlockArgs( initialization_method='he', use_spectral_norm=False, normalization_layer_factory=InstanceNorm2dFactory(), nonlinearity_factory=ReLUFactory(inplace=True)))) print("Loading the eyebrow morphing conbiner ... ", end="") module = factory.create() module.load_state_dict(torch_load(file_name)) print("DONE!!!") return module def load_face_morpher(file_name: str): factory = FaceMorpher08Factory( FaceMorpher08Args( image_size=192, image_channels=4, num_expression_params=27, start_channels=64, bottleneck_image_size=24, num_bottleneck_blocks=6, max_channels=512, block_args=BlockArgs( initialization_method='he', use_spectral_norm=False, normalization_layer_factory=InstanceNorm2dFactory(), nonlinearity_factory=ReLUFactory(inplace=False)), output_iris_mouth_grid_change=True)) print("Loading the face morpher ... ", end="") module = factory.create() module.load_state_dict(torch_load(file_name)) print("DONE!!!") return module def apply_color_change(alpha, color_change, image: Tensor) -> Tensor: return color_change * alpha + image * (1 - alpha) def create_poser( device: torch.device, module_file_names: Optional[Dict[str, str]] = None, eyebrow_morphed_image_index: int = EyebrowMorphingCombiner00.EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX, default_output_index: int = 0) -> GeneralPoser02: if module_file_names is None: module_file_names = {} if Network.eyebrow_decomposer.name not in module_file_names: file_name = "data/tha4/eyebrow_decomposer.pt" module_file_names[Network.eyebrow_decomposer.name] = file_name if Network.eyebrow_morphing_combiner.name not in module_file_names: file_name = "data/tha4/eyebrow_morphing_combiner.pt" module_file_names[Network.eyebrow_morphing_combiner.name] = file_name if Network.face_morpher.name not in module_file_names: file_name = "data/tha4/face_morpher.pt" module_file_names[Network.face_morpher.name] = file_name loaders = { Network.eyebrow_decomposer.name: lambda: load_eyebrow_decomposer(module_file_names[Network.eyebrow_decomposer.name]), Network.eyebrow_morphing_combiner.name: lambda: load_eyebrow_morphing_combiner(module_file_names[Network.eyebrow_morphing_combiner.name]), Network.face_morpher.name: lambda: load_face_morpher(module_file_names[Network.face_morpher.name]), } return GeneralPoser02( image_size=512, module_loaders=loaders, pose_parameters=get_pose_parameters().get_pose_parameter_groups(), output_list_func=FiveStepPoserComputationProtocol(eyebrow_morphed_image_index).compute_func(), subrect=None, device=device, output_length=5 + 5 + 8, default_output_index=default_output_index) ================================================ FILE: src/tha4/poser/modes/mode_14.py ================================================ from dataclasses import dataclass from typing import List, Optional, Dict, Any import torch from tha4.shion.core.cached_computation import CachedComputationProtocol, ComputationState from tha4.shion.core.load_save import torch_load from tha4.nn.siren.face_morpher.siren_face_morpher_00 import SirenFaceMorpher00Args, SirenFaceMorpher00 from tha4.nn.siren.morpher.siren_morpher_03 import SirenMorpher03, SirenMorpher03Args, SirenMorpherLevelArgs from 
tha4.nn.siren.vanilla.siren import SirenArgs from tha4.poser.general_poser_02 import GeneralPoser02 from tha4.poser.modes.pose_parameters import get_pose_parameters from torch import Tensor KEY_FACE_MORPHER = "face_morpher" KEY_BODY_MORPHER = "body_morpher" @dataclass class Keys: face_morpher: str = KEY_FACE_MORPHER face_morpher_output: str = "face_morpher_output" face_morpher_input_image: str = "face_morpher_input_image" face_morpher_input_pose: str = "face_morpher_input_pose" body_morpher_input_image: str = "body_morpher_input_image" body_morpher: str = KEY_BODY_MORPHER body_morpher_output: str = "body_morpher_output" all_outputs: str = "all_outputs" @dataclass class Indices: original_image: int = 0 original_pose: int = 1 class TwoStepPoserComputationProtocol(CachedComputationProtocol): def __init__(self, keys: Optional[Keys] = None, indices: Optional[Indices] = None): super().__init__() if keys is None: keys = Keys() if indices is None: indices = Indices() self.keys = keys self.indices = indices def compute_func(self): def func(state: ComputationState) -> List[Tensor]: return self.get_output(self.keys.all_outputs, state) return func def compute_output(self, key: str, state: ComputationState) -> Any: if key == self.keys.face_morpher_input_image: image = state.batch[self.indices.original_image] center_x = 256 center_y = 128 + 16 return image[:, :, center_y - 64:center_y + 64, center_x - 64:center_x + 64] elif key == self.keys.face_morpher_input_pose: pose = state.batch[self.indices.original_pose] return pose[:, 0:39] elif key == self.keys.face_morpher_output: module = state.modules[self.keys.face_morpher] pose = self.get_output(self.keys.face_morpher_input_pose, state) with torch.no_grad(): return module.forward(pose) elif key == self.keys.body_morpher_input_image: image = state.batch[self.indices.original_image].clone() center_x = 256 center_y = 128 + 16 face_morphed_image = self.get_output(self.keys.face_morpher_output, state) image[:, :, center_y - 64:center_y + 64, center_x - 64:center_x + 64] = face_morphed_image return image elif key == self.keys.body_morpher_output: image = self.get_output(self.keys.body_morpher_input_image, state) pose = state.batch[self.indices.original_pose] body_morpher = state.modules[self.keys.body_morpher] with torch.no_grad(): return body_morpher.forward(image, pose) elif key == self.keys.all_outputs: body_morpher_output = self.get_output(self.keys.body_morpher_output, state) face_morpher_output = self.get_output(self.keys.face_morpher_output, state) return body_morpher_output + [face_morpher_output] else: raise RuntimeError("Unsupported key: " + key) def load_face_morpher(file_name: Optional[str] = None): module = SirenFaceMorpher00( SirenFaceMorpher00Args( image_size=128, image_channels=4, pose_size=39, siren_args=SirenArgs( in_channels=39 + 2, out_channels=4, intermediate_channels=128, num_sine_layers=8))) if file_name is not None: module.load_state_dict(torch_load(file_name)) return module def load_body_morpher(file_name: Optional[str] = None): module = SirenMorpher03( SirenMorpher03Args( image_size=512, image_channels=4, pose_size=45, level_args=[ SirenMorpherLevelArgs( image_size=128, intermediate_channels=360, num_sine_layers=3), SirenMorpherLevelArgs( image_size=256, intermediate_channels=180, num_sine_layers=3), SirenMorpherLevelArgs( image_size=512, intermediate_channels=90, num_sine_layers=3), ])) if file_name is not None: module.load_state_dict(torch_load(file_name)) return module def create_poser( device: torch.device, module_file_names: 
Optional[Dict[str, str]] = None, default_output_index: int = 0) -> GeneralPoser02: if module_file_names is None: module_file_names = {} if KEY_FACE_MORPHER not in module_file_names: file_name = "data/character_models/lambda_00/face_morpher.pt" module_file_names[KEY_FACE_MORPHER] = file_name if KEY_BODY_MORPHER not in module_file_names: file_name = "data/character_models/lambda_00/body_morpher.pt" module_file_names[KEY_BODY_MORPHER] = file_name loaders = { KEY_FACE_MORPHER: lambda: load_face_morpher(module_file_names[KEY_FACE_MORPHER]), KEY_BODY_MORPHER: lambda: load_body_morpher(module_file_names[KEY_BODY_MORPHER]), } return GeneralPoser02( image_size=512, module_loaders=loaders, pose_parameters=get_pose_parameters().get_pose_parameter_groups(), output_list_func=TwoStepPoserComputationProtocol().compute_func(), subrect=None, device=device, output_length=5 + 1, default_output_index=default_output_index) ================================================ FILE: src/tha4/poser/modes/pose_parameters.py ================================================ from tha4.poser.poser import PoseParameters, PoseParameterCategory def get_pose_parameters(): return PoseParameters.Builder() \ .add_parameter_group("eyebrow_troubled", PoseParameterCategory.EYEBROW, arity=2) \ .add_parameter_group("eyebrow_angry", PoseParameterCategory.EYEBROW, arity=2) \ .add_parameter_group("eyebrow_lowered", PoseParameterCategory.EYEBROW, arity=2) \ .add_parameter_group("eyebrow_raised", PoseParameterCategory.EYEBROW, arity=2) \ .add_parameter_group("eyebrow_happy", PoseParameterCategory.EYEBROW, arity=2) \ .add_parameter_group("eyebrow_serious", PoseParameterCategory.EYEBROW, arity=2) \ .add_parameter_group("eye_wink", PoseParameterCategory.EYE, arity=2) \ .add_parameter_group("eye_happy_wink", PoseParameterCategory.EYE, arity=2) \ .add_parameter_group("eye_surprised", PoseParameterCategory.EYE, arity=2) \ .add_parameter_group("eye_relaxed", PoseParameterCategory.EYE, arity=2) \ .add_parameter_group("eye_unimpressed", PoseParameterCategory.EYE, arity=2) \ .add_parameter_group("eye_raised_lower_eyelid", PoseParameterCategory.EYE, arity=2) \ .add_parameter_group("iris_small", PoseParameterCategory.IRIS_MORPH, arity=2) \ .add_parameter_group("mouth_aaa", PoseParameterCategory.MOUTH, arity=1, default_value=1.0) \ .add_parameter_group("mouth_iii", PoseParameterCategory.MOUTH, arity=1) \ .add_parameter_group("mouth_uuu", PoseParameterCategory.MOUTH, arity=1) \ .add_parameter_group("mouth_eee", PoseParameterCategory.MOUTH, arity=1) \ .add_parameter_group("mouth_ooo", PoseParameterCategory.MOUTH, arity=1) \ .add_parameter_group("mouth_delta", PoseParameterCategory.MOUTH, arity=1) \ .add_parameter_group("mouth_lowered_corner", PoseParameterCategory.MOUTH, arity=2) \ .add_parameter_group("mouth_raised_corner", PoseParameterCategory.MOUTH, arity=2) \ .add_parameter_group("mouth_smirk", PoseParameterCategory.MOUTH, arity=1) \ .add_parameter_group("iris_rotation_x", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \ .add_parameter_group("iris_rotation_y", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \ .add_parameter_group("head_x", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ .add_parameter_group("head_y", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ .add_parameter_group("neck_z", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ .add_parameter_group("body_y", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \ 
.add_parameter_group("body_z", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \ .add_parameter_group("breathing", PoseParameterCategory.BREATHING, arity=1, range=(0.0, 1.0)) \ .build() ================================================ FILE: src/tha4/poser/poser.py ================================================ from abc import ABC, abstractmethod from enum import Enum from typing import Tuple, List, Optional import torch from torch import Tensor class PoseParameterCategory(Enum): EYEBROW = 1 EYE = 2 IRIS_MORPH = 3 IRIS_ROTATION = 4 MOUTH = 5 FACE_ROTATION = 6 BODY_ROTATION = 7 BREATHING = 8 class PoseParameterGroup: def __init__(self, group_name: str, parameter_index: int, category: PoseParameterCategory, arity: int = 1, discrete: bool = False, default_value: float = 0.0, range: Optional[Tuple[float, float]] = None): assert arity == 1 or arity == 2 if range is None: range = (0.0, 1.0) if arity == 1: parameter_names = [group_name] else: parameter_names = [group_name + "_left", group_name + "_right"] assert len(parameter_names) == arity self.parameter_names = parameter_names self.range = range self.default_value = default_value self.discrete = discrete self.arity = arity self.category = category self.parameter_index = parameter_index self.group_name = group_name def get_arity(self) -> int: return self.arity def get_group_name(self) -> str: return self.group_name def get_parameter_names(self) -> List[str]: return self.parameter_names def is_discrete(self) -> bool: return self.discrete def get_range(self) -> Tuple[float, float]: return self.range def get_default_value(self): return self.default_value def get_parameter_index(self): return self.parameter_index def get_category(self) -> PoseParameterCategory: return self.category class PoseParameters: def __init__(self, pose_parameter_groups: List[PoseParameterGroup]): self.pose_parameter_groups = pose_parameter_groups def get_parameter_index(self, name: str) -> int: index = 0 for parameter_group in self.pose_parameter_groups: for param_name in parameter_group.parameter_names: if name == param_name: return index index += 1 raise RuntimeError("Cannot find parameter with name %s" % name) def get_parameter_name(self, index: int) -> str: assert index >= 0 and index < self.get_parameter_count() for group in self.pose_parameter_groups: if index < group.get_arity(): return group.get_parameter_names()[index] index -= group.arity raise RuntimeError("Something is wrong here!!!") def get_pose_parameter_groups(self): return self.pose_parameter_groups def get_parameter_count(self): count = 0 for group in self.pose_parameter_groups: count += group.arity return count class Builder: def __init__(self): self.index = 0 self.pose_parameter_groups = [] def add_parameter_group(self, group_name: str, category: PoseParameterCategory, arity: int = 1, discrete: bool = False, default_value: float = 0.0, range: Optional[Tuple[float, float]] = None): self.pose_parameter_groups.append( PoseParameterGroup( group_name, self.index, category, arity, discrete, default_value, range)) self.index += arity return self def build(self) -> 'PoseParameters': return PoseParameters(self.pose_parameter_groups) class Poser(ABC): @abstractmethod def get_image_size(self) -> int: pass @abstractmethod def get_output_length(self) -> int: pass @abstractmethod def get_pose_parameter_groups(self) -> List[PoseParameterGroup]: pass @abstractmethod def get_num_parameters(self) -> int: pass @abstractmethod def pose(self, image: Tensor, pose: Tensor, output_index: int = 0) -> Tensor: 
pass @abstractmethod def get_posing_outputs(self, image: Tensor, pose: Tensor) -> List[Tensor]: pass def get_dtype(self) -> torch.dtype: return torch.float @abstractmethod def to(self, device: torch.device): pass ================================================ FILE: src/tha4/pytasuku/__init__.py ================================================ ================================================ FILE: src/tha4/pytasuku/indexed/__init__.py ================================================ ================================================ FILE: src/tha4/pytasuku/indexed/all_tasks.py ================================================ from typing import Iterable from tha4.pytasuku.workspace import Workspace from tha4.pytasuku.indexed.indexed_tasks import IndexedTasks from tha4.pytasuku.indexed.no_index_command_tasks import NoIndexCommandTasks class AllTasks(NoIndexCommandTasks): def __init__( self, workspace: Workspace, prefix: str, tasks: Iterable[IndexedTasks], command_name: str = "all", define_tasks_immediately: bool = True): super().__init__(workspace, prefix, command_name, define_tasks_immediately) self.tasks = [t for t in tasks] if define_tasks_immediately: self.define_tasks() def execute_run_command(self): for task in self.tasks: self.workspace.run(task.run_command) def execute_clean_command(self): for task in self.tasks: self.workspace.run(task.clean_command) ================================================ FILE: src/tha4/pytasuku/indexed/bundled_indexed_file_tasks.py ================================================ import abc from typing import Iterable, List from tha4.pytasuku.workspace import Workspace from tha4.pytasuku.indexed.indexed_tasks import IndexedTasks from tha4.pytasuku.workspace import do_nothing class BundledIndexedTasks: __metaclass__ = abc.ABCMeta @property @abc.abstractmethod def indexed_tasks_command_names(self) -> Iterable[str]: pass @abc.abstractmethod def get_indexed_tasks(self, command_name) -> IndexedTasks: pass def define_all_tasks_from_list(workspace: Workspace, prefix: str, tasks: List[BundledIndexedTasks]): for command_name in tasks[0].indexed_tasks_command_names: workspace.create_command_task( prefix + "/" + command_name, [x.get_indexed_tasks(command_name).run_command for x in tasks], do_nothing) workspace.create_command_task( prefix + "/" + command_name + "_clean", [x.get_indexed_tasks(command_name).clean_command for x in tasks], do_nothing) ================================================ FILE: src/tha4/pytasuku/indexed/indexed_file_tasks.py ================================================ import abc from typing import List from tha4.pytasuku.workspace import Workspace from tha4.pytasuku.indexed.indexed_tasks import IndexedTasks class IndexedFileTasks(IndexedTasks, abc.ABC): def __init__(self, workspace: Workspace, prefix: str): super().__init__(workspace, prefix) @property @abc.abstractmethod def file_list(self) -> List[str]: pass @abc.abstractmethod def get_file_name(self, *indices: int) -> str: pass ================================================ FILE: src/tha4/pytasuku/indexed/indexed_tasks.py ================================================ import abc from typing import List from tha4.pytasuku.workspace import Workspace class IndexedTasks(abc.ABC): def __init__(self, workspace: Workspace, prefix: str): self.prefix = prefix self.workspace = workspace @property @abc.abstractmethod def run_command(self) -> str: pass @property @abc.abstractmethod def clean_command(self) -> str: pass @property @abc.abstractmethod def shape(self) -> List[int]: pass @property 
@abc.abstractmethod def arity(self) -> int: pass @abc.abstractmethod def define_tasks(self): pass ================================================ FILE: src/tha4/pytasuku/indexed/no_index_command_tasks.py ================================================ import abc from typing import List from tha4.pytasuku.workspace import Workspace from tha4.pytasuku.indexed.indexed_tasks import IndexedTasks class NoIndexCommandTasks(IndexedTasks, abc.ABC): def __init__(self, workspace: Workspace, prefix: str, command_name: str, define_tasks_immediately: bool = True): super().__init__(workspace, prefix) self.command_name = command_name if define_tasks_immediately: self.define_tasks() @property def run_command(self): return self.prefix + "/" + self.command_name @property def clean_command(self): return self.prefix + "/" + self.command_name + "_clean" @property def arity(self) -> int: return 0 @property def shape(self) -> List[int]: return [] @abc.abstractmethod def execute_run_command(self): pass @abc.abstractmethod def execute_clean_command(self): pass def define_tasks(self): self.workspace.create_command_task(self.run_command, [], self.execute_run_command) self.workspace.create_command_task(self.clean_command, [], self.execute_clean_command) ================================================ FILE: src/tha4/pytasuku/indexed/no_index_file_tasks.py ================================================ import abc from typing import List from tha4.pytasuku.workspace import Workspace from tha4.pytasuku.indexed.indexed_file_tasks import IndexedFileTasks from tha4.pytasuku.indexed.util import delete_file class NoIndexFileTasks(IndexedFileTasks, abc.ABC): def __init__(self, workspace: Workspace, prefix: str, command_name: str, define_tasks_immediately: bool = True): super().__init__(workspace, prefix) self.command_name = command_name if define_tasks_immediately: self.define_tasks() @property @abc.abstractmethod def file_name(self): pass @abc.abstractmethod def create_file_task(self): pass def get_file_name(self, *indices: int) -> str: if len(indices) > 0: raise IndexError("NoIndexFileTasks has arity 0, but get_file_name is called with an index.") return self.file_name @property def run_command(self): return self.prefix + "/" + self.command_name @property def clean_command(self): return self.prefix + "/" + self.command_name + "_clean" @property def arity(self) -> int: return 0 @property def shape(self) -> List[int]: return [] @property def file_list(self) -> List[str]: return [self.file_name] def clean(self): delete_file(self.file_name) def define_tasks(self): self.create_file_task() self.workspace.create_command_task(self.run_command, [self.file_name]) self.workspace.create_command_task(self.clean_command, [], lambda: self.clean()) ================================================ FILE: src/tha4/pytasuku/indexed/one_index_file_tasks.py ================================================ import abc from typing import List from tha4.pytasuku.workspace import Workspace from tha4.pytasuku.indexed.indexed_file_tasks import IndexedFileTasks from tha4.pytasuku.indexed.util import delete_file class OneIndexFileTasks(IndexedFileTasks, abc.ABC): def __init__(self, workspace: Workspace, prefix: str, command_name: str, count: int, define_tasks_immediately: bool = True): super().__init__(workspace, prefix) self.command_name = command_name self.count = count self.file_list_ = [] if define_tasks_immediately: self.define_tasks() @property def run_command(self) -> str: return self.prefix + "/" + self.command_name @property def 
clean_command(self) -> str: return self.prefix + "/" + self.command_name + "_clean" @property def shape(self) -> List[int]: return [self.count] @property def arity(self) -> int: return 1 @abc.abstractmethod def file_name(self, index): pass @abc.abstractmethod def create_file_tasks(self, index): pass def get_file_name(self, *indices: int) -> str: if len(indices) != 1: raise IndexError("OneIndexFileTasks has arity 1, but " "get_file_name was called with a different number of indices.") return self.file_name(indices[0]) @property def file_list(self): if len(self.file_list_) == 0: for i in range(self.count): self.file_list_.append(self.file_name(i)) return self.file_list_ def clean(self): for file in self.file_list: delete_file(file) def define_tasks(self): for index in range(self.count): self.create_file_tasks(index) dependencies = self.file_list self.workspace.create_command_task(self.run_command, dependencies) self.workspace.create_command_task(self.clean_command, [], lambda: self.clean()) ================================================ FILE: src/tha4/pytasuku/indexed/simple_no_index_file_tasks.py ================================================ from typing import Callable, List, Optional from tha4.pytasuku.workspace import Workspace from tha4.pytasuku.indexed.no_index_file_tasks import NoIndexFileTasks class SimpleNoIndexFileTasks(NoIndexFileTasks): def __init__(self, workspace: Workspace, prefix: str, command_name: str, file_name: str, run_func: Callable[[], None], dependencies: Optional[List[str]] = None): super().__init__(workspace, prefix, command_name, define_tasks_immediately=False) if dependencies is None: dependencies = [] self.run_func = run_func self._file_name = file_name self.dependencies = dependencies self.define_tasks() @property def file_name(self): return self._file_name def create_file_task(self): self.workspace.create_file_task(self.file_name, self.dependencies, self.run_func) ================================================ FILE: src/tha4/pytasuku/indexed/two_indices_file_tasks.py ================================================ import abc from typing import List from tha4.pytasuku.workspace import Workspace from tha4.pytasuku.indexed.indexed_file_tasks import IndexedFileTasks from tha4.pytasuku.indexed.util import delete_file class TwoIndicesFileTasks(IndexedFileTasks, abc.ABC): def __init__(self, workspace: Workspace, prefix: str, command_name: str, count0: int, count1: int, define_tasks_immediately: bool = True): super().__init__(workspace, prefix) self.count1 = count1 self.count0 = count0 self.command_name = command_name self.file_list_ = [] if define_tasks_immediately: self.define_tasks() @property def run_command(self) -> str: return self.prefix + "/" + self.command_name @property def clean_command(self) -> str: return self.prefix + "/" + self.command_name + "_clean" @property def shape(self) -> List[int]: return [self.count0, self.count1] @property def arity(self) -> int: return 2 @abc.abstractmethod def file_name(self, index0: int, index1: int) -> str: pass @property def file_list(self) -> List[str]: if len(self.file_list_) == 0: for i in range(self.count0): for j in range(self.count1): self.file_list_.append(self.file_name(i, j)) return self.file_list_ @abc.abstractmethod def create_file_tasks(self, index0: int, index1: int): pass def get_file_name(self, *indices: int) -> str: if len(indices) != 2: raise IndexError("TwoIndicesFileTasks.get_file_name requires exactly two indices, " + "but a different number of indices was provided.") return self.file_name(indices[0],
indices[1]) def clean(self): for file in self.file_list: delete_file(file) def define_tasks(self): for index0 in range(self.count0): for index1 in range(self.count1): self.create_file_tasks(index0, index1) self.workspace.create_command_task(self.run_command, self.file_list) self.workspace.create_command_task(self.clean_command, [], lambda: self.clean()) ================================================ FILE: src/tha4/pytasuku/indexed/util.py ================================================ import os from typing import Iterable, Dict, Callable, List from tha4.pytasuku.workspace import Workspace from tha4.pytasuku.indexed.all_tasks import AllTasks from tha4.pytasuku.indexed.indexed_tasks import IndexedTasks def delete_file(file_name): if os.path.exists(file_name): os.remove(file_name) print("[delete] " + file_name) else: print("[not exist] " + file_name) def all_tasks_from_named_tasks_map( workspace: Workspace, prefix: str, tasks: Iterable[Dict[str, IndexedTasks]], define_all_tasks: bool = True) \ -> Dict[str, IndexedTasks]: subtasks = [x for x in tasks] name_to_subtask_list = {} for a_subtasks in subtasks: for name in a_subtasks: if not define_all_tasks and name == "all": continue if name not in name_to_subtask_list: name_to_subtask_list[name] = [] name_to_subtask_list[name].append(a_subtasks[name]) output = {} for name in name_to_subtask_list: output[name] = AllTasks(workspace, prefix, name_to_subtask_list[name], name) return output def create_tasks_hierarchy_helper( workspace: Workspace, prefix: str, tasks_func: Callable[[Workspace, str, List[str]], Dict[str, IndexedTasks]], branches: List[List[str]], path: List[str]): if len(branches) == 0: return tasks_func(workspace, prefix, path) else: tasks = {} for branch in branches[0]: output_tasks = create_tasks_hierarchy_helper( workspace, f"{prefix}/{branch}", tasks_func, branches[1:], path + [branch]) if output_tasks is not None: tasks[branch] = output_tasks return all_tasks_from_named_tasks_map(workspace, prefix, tasks.values()) def create_task_hierarchy( workspace: Workspace, prefix: str, tasks_func: Callable[[Workspace, str, List[str]], Dict[str, IndexedTasks]], branches: List[List[str]]) -> Dict[str, IndexedTasks]: return create_tasks_hierarchy_helper(workspace, prefix, tasks_func, branches, []) def write_done_file(file_name: str): os.makedirs(os.path.dirname(file_name), exist_ok=True) with open(file_name, "wt") as fout: fout.write("DONE!!!") ================================================ FILE: src/tha4/pytasuku/task.py ================================================ import os import logging from typing import List class Task: def __init__(self, workspace: 'Workspace', name: str, dependencies: List[str]): self._workspace = workspace self._name = name self._dependencies = dependencies self._workspace.add_task(self) def run(self): pass @property def can_run(self) -> bool: return True @property def needs_to_be_run(self) -> bool: return False @property def name(self) -> str: return self._name @property def dependencies(self) -> List[str]: return self._dependencies @property def workspace(self) -> 'Workspace': return self._workspace @property def timestamp(self) -> float: return float("inf") class CommandTask(Task): def __init__(self, workspace, name, dependencies): super().__init__(workspace, name, dependencies) @property def needs_to_be_run(self): return True class PlaceholderTask(Task): def __init__(self, workspace, name): super().__init__(workspace, name, []) @property def can_run(self): return False def run(self): raise Exception("A 
placeholder task cannot be run! (%s)" % self.name) @property def needs_to_be_run(self): return not os.path.isfile(self.name) @property def timestamp(self) -> float: if not os.path.isfile(self.name): return float("inf") else: return os.path.getmtime(self.name) class FileTask(Task): def __init__(self, workspace, name, dependencies): super().__init__(workspace, name, dependencies) @property def timestamp(self): return os.path.getmtime(self.name) @property def needs_to_be_run(self): if not os.path.isfile(self.name): logging.info("Task %s will be run because the corresponding file does not exist." % self.name) return True for dep in self.dependencies: if self.workspace.needs_to_run(dep): logging.info("Task %s will be run because dependency %s also needs to be run." % (self.name, dep)) return True else: self_timestamp = self.timestamp dep_task = self.workspace.get_task(dep) if dep_task.timestamp > self_timestamp: if isinstance(dep_task, FileTask) or isinstance(dep_task, PlaceholderTask): logging.info("Task %s needs to be run because task %s has later timestamp." % (self.name, dep)) elif isinstance(dep_task, CommandTask): logging.info("Task %s needs to be run because task %s is a command." % (self.name, dep)) return True return False ================================================ FILE: src/tha4/pytasuku/task_selector_ui.py ================================================ from tkinter import Tk, BOTH, Button, RIGHT, Scrollbar from tkinter.ttk import Frame, Treeview from tha4.pytasuku.workspace import Workspace, PlaceholderTask class TaskSelectorUi(Frame): def __init__(self, root, workspace: Workspace): super().__init__() self.root = root self.workspace = workspace self.master.title("Tasks") self.master.geometry("256x512") treeview_frame = Frame(self) treeview_frame.pack(fill=BOTH, expand=True) self.treeview = Treeview(treeview_frame) self.treeview["columns"] = ("task_name") self.treeview.column("#0", width=256, minwidth=256) self.treeview.heading("#0", text="Tree") self.treeview.heading("task_name", text="Task Name") treeview_vertical_scroll = Scrollbar(treeview_frame, orient='vertical', command=self.treeview.yview) self.treeview.configure(yscrollcommand=treeview_vertical_scroll.set) treeview_vertical_scroll.pack(side=RIGHT, fill='y') self.treeview.pack(fill=BOTH, expand=True) treeview_horizontal_scroll = Scrollbar(treeview_frame, orient='horizontal', command=self.treeview.xview) self.treeview.configure(xscrollcommand=treeview_horizontal_scroll.set) treeview_horizontal_scroll.pack(fill='x') self.add_tree_nodes() self.execute_button = Button(self, text="Execute!", command=self.run_selected_task) self.execute_button.pack(side=RIGHT, padx=5, pady=5) self.pack(fill=BOTH, expand=True) self.selected_task_name = None def add_tree_nodes(self): nodes = {} for task in self.workspace._tasks.values(): if isinstance(task, PlaceholderTask): continue comps = task.name.split('/') for i in range(1, len(comps)): assert len(comps) > 0 prefix = "" index = 0 for comp in comps: index = index + 1 parent = prefix if prefix == "" and comp == "": prefix = "/" elif prefix == "": prefix = prefix + comp elif prefix == "/": prefix = prefix + comp else: prefix = prefix + "/" + comp if prefix in nodes: continue if index == len(comps): data = prefix else: data = "" if prefix == "/": comp = "/" nodes[prefix] = { "name": str(prefix), "display_name": comp, "parent": parent, "data": data } sorted_node_names = sorted(nodes.keys()) node_index = {} for name in sorted_node_names: node = nodes[name] if node["parent"] == "": id = 
self.treeview.insert("", "end", text=node["display_name"], values=node["data"], ) else: parent = node_index[node["parent"]] id = self.treeview.insert(parent, "end", text=node["display_name"], values=node["data"], ) node_index[node["name"]] = id def run_selected_task(self): selection = self.treeview.selection() item = self.treeview.item(selection) if item['values'] == "": return task_name = item["values"][0] self.selected_task_name = task_name self.root.destroy() def run_task_selector_ui(workspace: Workspace): root = Tk() task_selector_ui = TaskSelectorUi(root, workspace=workspace) root.mainloop() task_name = task_selector_ui.selected_task_name if task_name is not None: print("Running", task_name, "...") with workspace.session(): workspace.run(task_name) ================================================ FILE: src/tha4/pytasuku/util.py ================================================ import os.path from typing import List import logging from tha4.pytasuku.workspace import Workspace def create_delete_all_task(workspace: Workspace, name: str, files: List[str]): def delete_all(): for file in files: if os.path.exists(file): logging.info("Removing %s ..." % file) os.remove(file) workspace.create_command_task(name, [], delete_all) ================================================ FILE: src/tha4/pytasuku/workspace.py ================================================ from contextlib import contextmanager from enum import Enum from typing import List from tha4.pytasuku.task import Task, CommandTask, FileTask, PlaceholderTask class WorkspaceState(Enum): OUT_OF_SESSION = 1 IN_SESSION = 2 class NodeState(Enum): IN_STACK = 1 VISITED = 2 class FuncCommandTask(CommandTask): def __init__(self, workspace, name, dependencies, func): super().__init__(workspace, name, dependencies) self._func = func def run(self): self._func() class FuncFileTask(FileTask): def __init__(self, workspace, name, dependencies, func): super().__init__(workspace, name, dependencies) self._func = func def run(self): self._func() def do_nothing(): pass class Workspace: def __init__(self): self._tasks = dict() self._name_to_done = None self._state = WorkspaceState.OUT_OF_SESSION self._modified = False @property def modified(self) -> bool: return self._modified @property def state(self) -> WorkspaceState: return self._state @property def in_session(self) -> bool: return self._state == WorkspaceState.IN_SESSION def task_exists(self, name: str) -> bool: return name in self._tasks def task_exists_and_not_placeholder(self, name: str) -> bool: return self.task_exists(name) and not isinstance(self.get_task(name), PlaceholderTask) def get_task(self, name: str) -> Task: return self._tasks[name] def add_task(self, task): if self.in_session: raise RuntimeError("New tasks can only be created when the workspace is out of session.") if isinstance(task, PlaceholderTask): if not self.task_exists(task.name): self._tasks[task.name] = task self._modified = True else: self._tasks[task.name] = task for dep in task.dependencies: PlaceholderTask(self, dep) self._modified = True def start_session(self): if self.in_session: raise RuntimeError("A session can only be started when the workspace is out of session.") if self.modified: self.check_cycle() self._state = WorkspaceState.IN_SESSION self._name_to_done = dict() self._modified = False def end_session(self): if not self.in_session: raise RuntimeError("A session can only be ended when the workspace is in session.") self._state = WorkspaceState.OUT_OF_SESSION self._name_to_done = None @contextmanager def 
session(self): try: self.start_session() yield finally: self.end_session() def check_cycle(self): node_states = dict() for name in self._tasks: if name not in node_states: self.dfs(name, node_states) def dfs(self, name, node_states): node_states[name] = NodeState.IN_STACK task = self.get_task(name) for dep in task.dependencies: if dep not in node_states: self.dfs(dep, node_states) else: state = node_states[dep] if state == NodeState.IN_STACK: raise RuntimeError("Discovered cyclic dependency!") node_states[name] = NodeState.VISITED def run(self, name): if not self.in_session: raise RuntimeError("A task can only be run when the workspace is in session.") if not self.task_exists(name): raise RuntimeError("Task %s does not exist" % name) self.run_helper(name) def run_helper(self, name): task = self.get_task(name) for dep in task.dependencies: if self.needs_to_run(dep): self.run_helper(dep) if self.needs_to_run(name): task.run() self._name_to_done[name] = True def needs_to_run(self, name): if not self.in_session: raise RuntimeError("You can only check whether a task needs to run when the workspace is in session.") if name in self._name_to_done: return not self._name_to_done[name] task = self.get_task(name) need_to_run_value = task.needs_to_be_run self._name_to_done[name] = not need_to_run_value return need_to_run_value def create_command_task(self, name, dependencies, func=do_nothing): return FuncCommandTask(self, name, dependencies, func) def create_file_task(self, name, dependencies, func): return FuncFileTask(self, name, dependencies, func) def command_task(workspace: Workspace, name: str, dependencies: List[str]): def func(f): workspace.create_command_task(name, dependencies, f) return f return func def file_task(workspace: Workspace, name: str, dependencies: List[str]): def func(f): workspace.create_file_task(name, dependencies, f) return f return func ================================================ FILE: src/tha4/sampleoutput/__init__.py ================================================ ================================================ FILE: src/tha4/sampleoutput/general_sample_output_protocol.py ================================================ import os from enum import Enum from typing import List, Dict import PIL.Image import numpy import torch from tha4.shion.base.dataset.util import get_indexed_batch from tha4.shion.base.image_util import pytorch_rgb_to_numpy_image, pytorch_rgba_to_numpy_image from tha4.shion.core.cached_computation import TensorCachedComputationFunc, ComputationState from tha4.shion.core.training.sample_output_protocol import SampleOutputProtocol from tha4.image_util import grid_change_to_numpy_image from torch.nn import Module from torch.nn.functional import interpolate from torch.utils.data import Dataset class ImageType(Enum): COLOR = 0 ALPHA = 1 GRID_CHANGE = 2 SIGMOID_LOGIT = 3 class SampleImageSpec: def __init__(self, value_func: TensorCachedComputationFunc, image_type: ImageType): self.value_func = value_func self.image_type = image_type class SampleImageSaver: def __init__(self, cell_size: int, image_channels: int, sample_image_specs: List[SampleImageSpec]): super().__init__() self.sample_image_specs = sample_image_specs self.cell_size = cell_size self.image_channels = image_channels def save_sample_output_data(self, state: ComputationState, prefix: str, examples_seen_so_far: int): num_cols = len(self.sample_image_specs) num_rows = state.batch[0].shape[0] output_image = numpy.zeros([self.cell_size * num_rows, self.cell_size * num_cols, self.image_channels])
for col in range(num_cols): spec = self.sample_image_specs[col] images = spec.value_func(state) start_col = col * self.cell_size for image_index in range(num_rows): image = images[image_index].clone().detach() row = image_index start_row = row * self.cell_size if spec.image_type == ImageType.COLOR: c, h, w = image.shape green_screen = torch.ones(3, h, w, device=image.device) * -1.0 green_screen[1, :, :] = 1.0 alpha = (image[3:4, :, :] + 1.0) * 0.5 image[0:3, :, :] = image[0:3, :, :] * alpha + green_screen * (1 - alpha) image[3:4, :, :] = 1.0 image = image.cpu() elif spec.image_type == ImageType.GRID_CHANGE: image = image.cpu() elif spec.image_type == ImageType.SIGMOID_LOGIT: image = torch.sigmoid(image) image = image.repeat(self.image_channels, 1, 1) image = image * 2.0 - 1.0 image = image.cpu() elif spec.image_type == ImageType.ALPHA: if image.shape[0] == 1: image = image.repeat(self.image_channels, 1, 1) image = image * 2.0 - 1.0 image = image.cpu() else: raise RuntimeError(f"Unsupported image type: {spec.image_type}") output_image[start_row:start_row + self.cell_size, start_col:start_col + self.cell_size, :] \ = self.convert_to_numpy_image(image) file_name = "%s/sample_output_%010d.png" % (prefix, examples_seen_so_far) os.makedirs(os.path.dirname(file_name), exist_ok=True) if self.image_channels == 3: mode = 'RGB' else: mode = 'RGBA' pil_image = PIL.Image.fromarray(numpy.uint8(numpy.rint(output_image * 255.0)), mode=mode) pil_image.save(file_name) print("Saved %s" % file_name) def convert_to_numpy_image(self, image: torch.Tensor): image_size = image.shape[-1] if self.cell_size != image_size: image = interpolate(image.unsqueeze(0), size=self.cell_size).squeeze(0) if image.shape[0] == 2: return grid_change_to_numpy_image(image, num_channels=self.image_channels) elif self.image_channels == 3: return pytorch_rgb_to_numpy_image(image) else: return pytorch_rgba_to_numpy_image(image) class GeneralSampleOutputProtocol(SampleOutputProtocol): def __init__(self, sample_image_specs: List[SampleImageSpec], num_images: int = 8, cell_size: int = 256, image_channels: int = 4, examples_per_sample_output: int = 5000, random_seed: int = 1203040687): super().__init__() self.num_images = num_images self.random_seed = random_seed self.examples_per_sample_output = examples_per_sample_output self.sample_image_saver = SampleImageSaver(cell_size, image_channels, sample_image_specs) def get_examples_per_sample_output(self) -> int: return self.examples_per_sample_output def get_random_seed(self) -> int: return self.random_seed def get_sample_output_data(self, validation_dataset: Dataset, device: torch.device) -> dict: example_indices = torch.randint(0, len(validation_dataset), (self.num_images,)) example_indices = [example_indices[i].item() for i in range(self.num_images)] batch = get_indexed_batch(validation_dataset, example_indices, device) return {'batch': batch} def save_sample_output_data(self, modules: Dict[str, Module], accumulated_modules: Dict[str, Module], sample_output_data: dict, prefix: str, examples_seen_so_far: int, device: torch.device): for key in modules: modules[key].train(False) batch = sample_output_data['batch'] state = ComputationState(modules, accumulated_modules, batch, {}) self.sample_image_saver.save_sample_output_data(state, prefix, examples_seen_so_far) ================================================ FILE: src/tha4/sampleoutput/poser_sampler_output_protocol.py ================================================ from typing import Optional, List, Dict import torch from torch.nn import 
Module from torch.utils.data import Dataset from tha4.shion.base.dataset.util import get_indexed_batch from tha4.shion.core.cached_computation import CachedComputationFunc, ComputationState from tha4.shion.core.training.sample_output_protocol import SampleOutputProtocol from tha4.sampleoutput.sample_image_creator import SampleImageSpec, ImageSource, ImageType, SampleImageSaver class PoserSampleOutputProtocol(SampleOutputProtocol): def __init__(self, output_list_func: Optional[CachedComputationFunc] = None, num_images: int = 8, image_size: int = 256, cell_size: int = 256, image_channels: int = 4, examples_per_sample_output: int = 5000, sample_image_specs: Optional[List[SampleImageSpec]] = None, random_seed: int = 1203040687): super().__init__() self.num_images = num_images self.random_seed = random_seed self.examples_per_sample_output = examples_per_sample_output self.output_list_func = output_list_func if sample_image_specs is None: sample_image_specs = [ SampleImageSpec(ImageSource.BATCH, 0, ImageType.COLOR), SampleImageSpec(ImageSource.BATCH, 2, ImageType.COLOR), SampleImageSpec(ImageSource.OUTPUT, 0, ImageType.COLOR), SampleImageSpec(ImageSource.OUTPUT, 1, ImageType.COLOR), SampleImageSpec(ImageSource.OUTPUT, 2, ImageType.ALPHA), SampleImageSpec(ImageSource.BATCH, 3, ImageType.COLOR), SampleImageSpec(ImageSource.OUTPUT, 3, ImageType.COLOR), SampleImageSpec(ImageSource.OUTPUT, 4, ImageType.COLOR), SampleImageSpec(ImageSource.OUTPUT, 5, ImageType.ALPHA), SampleImageSpec(ImageSource.OUTPUT, 6, ImageType.COLOR), SampleImageSpec(ImageSource.BATCH, 4, ImageType.COLOR), SampleImageSpec(ImageSource.OUTPUT, 7, ImageType.COLOR), SampleImageSpec(ImageSource.OUTPUT, 8, ImageType.COLOR), SampleImageSpec(ImageSource.OUTPUT, 9, ImageType.ALPHA), SampleImageSpec(ImageSource.OUTPUT, 10, ImageType.COLOR), ] self.sample_image_saver = SampleImageSaver(image_size, cell_size, image_channels, sample_image_specs) def get_examples_per_sample_output(self) -> int: return self.examples_per_sample_output def get_random_seed(self) -> int: return self.random_seed def get_sample_output_data(self, validation_dataset: Dataset, device: torch.device) -> dict: example_indices = torch.randint(0, len(validation_dataset), (self.num_images,)) example_indices = [example_indices[i].item() for i in range(self.num_images)] batch = get_indexed_batch(validation_dataset, example_indices, device) return {'batch': batch} def save_sample_output_data(self, modules: Dict[str, Module], accumulated_modules: Dict[str, Module], sample_output_data: dict, prefix: str, examples_seen_so_far: int, device: torch.device): for key in modules: modules[key].train(False) batch = sample_output_data['batch'] with torch.no_grad(): outputs = self.output_list_func(ComputationState(modules, accumulated_modules, batch)) self.sample_image_saver.save_sample_output_data(batch, outputs, prefix, examples_seen_so_far) ================================================ FILE: src/tha4/sampleoutput/sample_image_creator.py ================================================ import math import os from enum import Enum from typing import List import numpy import torch from matplotlib import cm from torch import Tensor from torch.nn.functional import interpolate from tha4.shion.base.image_util import save_numpy_image class ImageSource(Enum): BATCH = 0 OUTPUT = 1 class ImageType(Enum): COLOR = 0 ALPHA = 1 GRID_CHANGE = 2 SIGMOID_LOGIT = 3 class SampleImageSpec: def __init__(self, image_source: ImageSource, index: int, image_type: ImageType): self.image_type = image_type 
self.index = index self.image_source = image_source def torch_rgb_to_numpy_image(torch_image: Tensor, min_pixel_value=-1.0, max_pixel_value=1.0): assert torch_image.dim() == 3 assert torch_image.shape[0] == 3 height = torch_image.shape[1] width = torch_image.shape[2] reshaped_image = torch_image.numpy().reshape(3, height * width).transpose().reshape(height, width, 3) numpy_image = (reshaped_image - min_pixel_value) / (max_pixel_value - min_pixel_value) return numpy_image def torch_rgba_to_numpy_image(torch_image: Tensor, min_pixel_value=-1.0, max_pixel_value=1.0): assert torch_image.dim() == 3 assert torch_image.shape[0] == 4 height = torch_image.shape[1] width = torch_image.shape[2] reshaped_image = torch_image.numpy().reshape(4, height * width).transpose().reshape(height, width, 4) numpy_image = (reshaped_image - min_pixel_value) / (max_pixel_value - min_pixel_value) numpy_image = numpy.clip(numpy_image, 0.0, 1.0) return numpy_image def torch_grid_change_to_numpy_image(torch_image, num_channels=3): height = torch_image.shape[1] width = torch_image.shape[2] size_image = (torch_image[0, :, :] ** 2 + torch_image[1, :, :] ** 2).sqrt().view(height, width, 1).numpy() hsv = cm.get_cmap('hsv') angle_image = hsv(((torch.atan2( torch_image[0, :, :].view(height * width), torch_image[1, :, :].view(height * width)).view(height, width) + math.pi) / (2 * math.pi)).numpy()) * 3 numpy_image = size_image * angle_image[:, :, 0:3] if num_channels == 3: return numpy_image elif num_channels == 4: return numpy.concatenate([numpy_image, numpy.ones_like(size_image)], axis=2) else: raise RuntimeError("Unsupported num_channels: " + str(num_channels)) class SampleImageSaver: def __init__(self, image_size: int, cell_size: int, image_channels: int, sample_image_specs: List[SampleImageSpec]): super().__init__() self.sample_image_specs = sample_image_specs self.cell_size = cell_size self.image_channels = image_channels self.image_size = image_size def save_sample_output_image(self, batch: List[Tensor], outputs: List[Tensor], file_name: str): num_cols = len(self.sample_image_specs) num_rows = batch[0].shape[0] output_image = numpy.zeros([self.cell_size * num_rows, self.cell_size * num_cols, self.image_channels]) for image_index in range(num_rows): row = image_index start_row = row * self.cell_size for col in range(num_cols): spec = self.sample_image_specs[col] start_col = col * self.cell_size if spec.image_source == ImageSource.BATCH: image = batch[spec.index][image_index].clone() else: image = outputs[spec.index][image_index].clone() if spec.image_type == ImageType.COLOR: c, h, w = image.shape green_screen = torch.ones(3, h, w, device=image.device) * -1.0 green_screen[1, :, :] = 1.0 alpha = (image[3:4, :, :] + 1.0) * 0.5 image[0:3, :, :] = image[0:3, :, :] * alpha + green_screen * (1 - alpha) image[3:4, :, :] = 1.0 image = image.detach().cpu() elif spec.image_type == ImageType.GRID_CHANGE: image = image.detach().cpu() elif spec.image_type == ImageType.SIGMOID_LOGIT: image = torch.sigmoid(image) image = image.repeat(self.image_channels, 1, 1) image = image * 2.0 - 1.0 image = image.detach().cpu() else: if image.shape[0] == 1: image = image.repeat(self.image_channels, 1, 1) image = image * 2.0 - 1.0 image = image.detach().cpu() output_image[start_row:start_row + self.cell_size, start_col:start_col + self.cell_size, :] \ = self.convert_to_numpy_image(image) os.makedirs(os.path.dirname(file_name), exist_ok=True) save_numpy_image(output_image, file_name, save_straight_alpha=True) def save_sample_output_data(self, batch: 
List[Tensor], outputs: List[Tensor], prefix: str, examples_seen_so_far: int): file_name = "%s/sample_output_%010d.png" % (prefix, examples_seen_so_far) self.save_sample_output_image(batch, outputs, file_name) def convert_to_numpy_image(self, image: torch.Tensor): if self.cell_size != self.image_size: image = interpolate(image.unsqueeze(0), size=self.cell_size).squeeze(0) if image.shape[0] == 2: return torch_grid_change_to_numpy_image(image, num_channels=self.image_channels) elif self.image_channels == 3: return torch_rgb_to_numpy_image(image) else: return torch_rgba_to_numpy_image(image) ================================================ FILE: src/tha4/shion/__init__.py ================================================ ================================================ FILE: src/tha4/shion/base/__init__.py ================================================ ================================================ FILE: src/tha4/shion/base/dataset/__init__.py ================================================ ================================================ FILE: src/tha4/shion/base/dataset/lazy_dataset.py ================================================ from typing import Callable from torch.utils.data import Dataset class LazyDataset(Dataset): def __init__(self, source_func: Callable[[], Dataset]): self.source_func = source_func self.source = None def get_source(self): if self.source is None: self.source = self.source_func() return self.source def __len__(self): return len(self.get_source()) def __getitem__(self, item): return self.get_source()[item] ================================================ FILE: src/tha4/shion/base/dataset/lazy_tensor_dataset.py ================================================ import torch from torch.utils.data import Dataset, TensorDataset from tha4.shion.core.load_save import torch_load class LazyTensorDataset(Dataset): def __init__(self, file_name: str): self.file_name = file_name self.dataset = None def get_dataset(self): if self.dataset is None: data = torch_load(self.file_name) if isinstance(data, torch.Tensor): self.dataset = TensorDataset(data) elif isinstance(data, tuple): self.dataset = TensorDataset(*data) elif isinstance(data, list): self.dataset = TensorDataset(*data) else: raise RuntimeError("Unsupported data type: " + str(type(data))) return self.dataset def __len__(self): dataset = self.get_dataset() return len(dataset) def __getitem__(self, item): dataset = self.get_dataset() return dataset.__getitem__(item) ================================================ FILE: src/tha4/shion/base/dataset/png_in_dir_dataset.py ================================================ import os from torch.nn import functional from torch.utils.data import Dataset from os import listdir from os.path import isfile from tha4.shion.base.image_util import extract_pytorch_image_from_filelike class PngInDirDataset(Dataset): def __init__(self, dir: str, downscale_kernel_size: int = 1, has_alpha=False, scale=2.0, offset=-1.0, premultiply_alpha=True, perfrom_srb_to_linear=True): super().__init__() self.perfrom_srb_to_linear = perfrom_srb_to_linear self.premultiply_alpha = premultiply_alpha self.offset = offset self.scale = scale self.has_alpha = has_alpha self.downscale_kernel_size = downscale_kernel_size self.dir = dir self.file_names = None def get_file_names(self): if self.file_names is None: self.file_names = [os.path.join(self.dir, x) for x in listdir(self.dir)] self.file_names = [x for x in self.file_names if isfile(x) and x[-4:].lower() == ".png"] self.file_names = sorted(self.file_names) return
self.file_names def __len__(self): file_names = self.get_file_names() return len(file_names) def __getitem__(self, item): file_names = self.get_file_names() file_name = file_names[item] image = extract_pytorch_image_from_filelike( file_name, scale=self.scale, offset=self.offset, premultiply_alpha=self.has_alpha and self.premultiply_alpha, perform_srgb_to_linear=self.perfrom_srb_to_linear) if self.downscale_kernel_size == 1: return [image] else: image = functional.avg_pool2d(image.unsqueeze(0), kernel_size=self.downscale_kernel_size).squeeze(0) return [image] ================================================ FILE: src/tha4/shion/base/dataset/util.py ================================================ from typing import List import torch from torch.utils.data import Dataset def get_indexed_batch(dataset: Dataset, example_indices: List[int], device: torch.device): if len(example_indices) == 0: return [] examples = [] for index in range(len(example_indices)): example_index = example_indices[index] raw_example = dataset[example_index] example = [] for x in raw_example: if isinstance(x, torch.Tensor): y = x.to(device).unsqueeze(0) elif isinstance(x, float) or isinstance(x, int): y = torch.tensor([[x]], device=device) else: raise RuntimeError(f"get_indexed_batch: Data of type {type(x)} is not supported.") example.append(y) examples.append(example) k = len(examples[0]) transposed = [[] for i in range(k)] for example in examples: for i in range(k): transposed[i].append(example[i]) return [torch.cat(x, dim=0) for x in transposed] ================================================ FILE: src/tha4/shion/base/dataset/xformed_dataset.py ================================================ from typing import Any, Callable from torch.utils.data import Dataset class XformedDataset(Dataset): def __init__(self, source: Dataset, xform_func: Callable[[Any], Any]): self.xform_func = xform_func self.source = source def __len__(self): return len(self.source) def __getitem__(self, item): return self.xform_func(self.source[item]) ================================================ FILE: src/tha4/shion/base/image_util.py ================================================ import os import PIL.Image import numpy import torch from matplotlib import pyplot from torch import Tensor def numpy_srgb_to_linear(x): x = numpy.clip(x, 0.0, 1.0) return numpy.where(x <= 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4) def numpy_linear_to_srgb(x): x = numpy.clip(x, 0.0, 1.0) return numpy.where(x <= 0.003130804953560372, x * 12.92, 1.055 * (x ** (1.0 / 2.4)) - 0.055) def numpy_alpha_devide(rgb, a, epsilon=1e-5): aaa = numpy.repeat(a, 3, axis=2) aaa_prime = aaa + numpy.where(numpy.abs(aaa) < epsilon, epsilon, 0.0) return numpy.where(numpy.abs(aaa) < epsilon, 0.0, rgb / aaa_prime) def torch_srgb_to_linear(x: torch.Tensor): x = torch.clip(x, 0.0, 1.0) return torch.where(torch.le(x, 0.04045), x / 12.92, ((x + 0.055) / 1.055) ** 2.4) def torch_linear_to_srgb(x): x = torch.clip(x, 0.0, 1.0) return torch.where(torch.le(x, 0.003130804953560372), x * 12.92, 1.055 * (x ** (1.0 / 2.4)) - 0.055) def numpy_image_linear_to_srgb(image): assert image.shape[2] == 3 or image.shape[2] == 4 if image.shape[2] == 3: return numpy_linear_to_srgb(image) else: height, width, _ = image.shape rgb_image = numpy_linear_to_srgb(image[:, :, 0:3]) a_image = image[:, :, 3:4] return numpy.concatenate((rgb_image, a_image), axis=2) def numpy_image_srgb_to_linear(image): assert image.shape[2] == 3 or image.shape[2] == 4 if image.shape[2] == 3: return numpy_srgb_to_linear(image) else: 
height, width, _ = image.shape rgb_image = numpy_srgb_to_linear(image[:, :, 0:3]) a_image = image[:, :, 3:4] return numpy.concatenate((rgb_image, a_image), axis=2) def pytorch_rgb_to_numpy_image(torch_image: Tensor, min_pixel_value=-1.0, max_pixel_value=1.0): assert torch_image.dim() == 3 assert torch_image.shape[0] == 3 height = torch_image.shape[1] width = torch_image.shape[2] reshaped_image = torch_image.numpy().reshape(3, height * width).transpose().reshape(height, width, 3) numpy_image = (reshaped_image - min_pixel_value) / (max_pixel_value - min_pixel_value) return numpy_linear_to_srgb(numpy_image) def pytorch_rgba_to_numpy_image_greenscreen(torch_image: Tensor, min_pixel_value=-1.0, max_pixel_value=1.0, include_alpha=False): height = torch_image.shape[1] width = torch_image.shape[2] numpy_image = (torch_image.numpy().reshape(4, height * width).transpose().reshape(height, width, 4) - min_pixel_value) \ / (max_pixel_value - min_pixel_value) rgb_image = numpy_linear_to_srgb(numpy_image[:, :, 0:3]) a_image = numpy_image[:, :, 3] rgb_image[:, :, 0:3] = rgb_image[:, :, 0:3] * a_image.reshape(a_image.shape[0], a_image.shape[1], 1) rgb_image[:, :, 1] = rgb_image[:, :, 1] + (1 - a_image) if not include_alpha: return rgb_image else: return numpy.concatenate((rgb_image, numpy.ones_like(numpy_image[:, :, 3:4])), axis=2) def pytorch_rgba_to_numpy_image( torch_image: Tensor, min_pixel_value=-1.0, max_pixel_value=1.0, perform_linear_to_srb: bool = True): assert torch_image.dim() == 3 assert torch_image.shape[0] == 4 height = torch_image.shape[1] width = torch_image.shape[2] reshaped_image = torch_image.numpy().reshape(4, height * width).transpose().reshape(height, width, 4) numpy_image = (reshaped_image - min_pixel_value) / (max_pixel_value - min_pixel_value) if perform_linear_to_srb: rgb_image = numpy_linear_to_srgb(numpy_image[:, :, 0:3]) else: rgb_image = numpy.clip(numpy_image[:, :, 0:3], 0.0, 1.0) a_image = numpy.clip(numpy_image[:, :, 3], 0.0, 1.0) rgba_image = numpy.concatenate((rgb_image, a_image.reshape(height, width, 1)), axis=2) return rgba_image def pil_image_has_transparency(pil_image): if pil_image.info.get("transparency", None) is not None: return True if pil_image.mode == "P": transparent = pil_image.info.get("transparency", -1) for _, index in pil_image.getcolors(): if index == transparent: return True elif pil_image.mode == "RGBA": extrema = pil_image.getextrema() if extrema[3][0] < 255: return True return False def extract_numpy_image_from_PIL_image(pil_image, scale=2.0, offset=-1.0, premultiply_alpha=True, perform_srgb_to_linear=True): has_alpha = pil_image_has_transparency(pil_image) if has_alpha and pil_image.mode != 'RGBA': pil_image = pil_image.convert("RGBA") if not has_alpha and pil_image.mode != 'RGB': pil_image = pil_image.convert("RGB") if has_alpha: num_channel = 4 else: num_channel = 3 image_width = pil_image.width image_height = pil_image.height raw_image = numpy.asarray(pil_image, dtype=numpy.float32) image = (raw_image / 255.0).reshape(image_height, image_width, num_channel) if perform_srgb_to_linear: image[:, :, 0:3] = numpy_srgb_to_linear(image[:, :, 0:3]) # Premultiply alpha if has_alpha and premultiply_alpha: image[:, :, 0:3] = image[:, :, 0:3] * image[:, :, 3:4] return image * scale + offset def extract_numpy_image_from_PIL_image_with_pytorch_layout(pil_image, scale=2.0, offset=-1.0, premultiply_alpha=True, perform_srgb_to_linear=True): numpy_image = extract_numpy_image_from_PIL_image( pil_image, scale, offset, premultiply_alpha, perform_srgb_to_linear) 
image_height, image_width, num_channel = numpy_image.shape image = numpy_image \ .reshape(image_height * image_width, num_channel) \ .transpose() \ .reshape(num_channel, image_height, image_width) return image def extract_numpy_image_from_filelike_with_pytorch_layout(file, scale=2.0, offset=-1.0, premultiply_alpha=True): try: pil_image = PIL.Image.open(file) except Exception as e: raise RuntimeError(file) return extract_numpy_image_from_PIL_image_with_pytorch_layout(pil_image, scale, offset, premultiply_alpha) def extract_numpy_image_from_filelike(file, scale=1.0, offset=0.0, premultiply_alpha=True, perform_srgb_to_linear: bool = True): try: pil_image = PIL.Image.open(file) except Exception as e: raise RuntimeError(file) return extract_numpy_image_from_PIL_image(pil_image, scale, offset, premultiply_alpha, perform_srgb_to_linear) def extract_pytorch_image_from_filelike(file, scale=2.0, offset=-1.0, premultiply_alpha=True, perform_srgb_to_linear=True): try: pil_image = PIL.Image.open(file) except Exception as e: raise RuntimeError(file) image = extract_numpy_image_from_PIL_image_with_pytorch_layout(pil_image, scale, offset, premultiply_alpha, perform_srgb_to_linear) return torch.from_numpy(image).float() def extract_pytorch_image_from_PIL_image(pil_image, scale=2.0, offset=-1.0, premultiply_alpha=True, perform_srgb_to_linear=True): image = extract_numpy_image_from_PIL_image_with_pytorch_layout( pil_image, scale, offset, premultiply_alpha, perform_srgb_to_linear) return torch.from_numpy(image).float() def convert_pytorch_image_to_zero_to_one_numpy_image( torch_image: torch.Tensor, scale: float = 2.0, offset: float = -1.0): torch_image = (torch_image - offset) / scale torch_image = torch.permute(torch_image, (1, 2, 0)) numpy_image = torch_image.cpu().numpy() return numpy_image def convert_zero_to_one_numpy_image_to_PIL_image( numpy_image, use_straight_alpha=True, perform_linear_to_srgb=True): if numpy_image.shape[2] == 4: rgb_image = numpy_image[:, :, 0:3] a_image = numpy.clip(numpy_image[:, :, 3:4], 0.0, 1.0) if use_straight_alpha: rgb_image = numpy_alpha_devide(rgb_image, a_image) if perform_linear_to_srgb: rgb_image = numpy_linear_to_srgb(rgb_image) else: rgb_image = numpy.clip(rgb_image, 0.0, 1.0) new_numpy_image = numpy.concatenate((rgb_image, a_image), axis=2) pil_image = PIL.Image.fromarray(numpy.uint8(numpy.rint(new_numpy_image * 255.0)), mode='RGBA') else: if perform_linear_to_srgb: numpy_image = numpy_linear_to_srgb(numpy_image) else: numpy_image = numpy.clip(numpy_image, 0.0, 1.0) pil_image = PIL.Image.fromarray(numpy.uint8(numpy.rint(numpy_image * 255.0)), mode='RGB') return pil_image def save_numpy_image(numpy_image, file_name: str, save_straight_alpha=True, perform_linear_to_srgb=True): pil_image = convert_zero_to_one_numpy_image_to_PIL_image(numpy_image, save_straight_alpha, perform_linear_to_srgb) os.makedirs(os.path.dirname(file_name), exist_ok=True) pil_image.save(file_name) def resize_PIL_image(pil_image, size=(256, 256)): w, h = pil_image.size d = min(w, h) r = ((w - d) // 2, (h - d) // 2, (w + d) // 2, (h + d) // 2) return pil_image.resize(size, resample=PIL.Image.LANCZOS, box=r) ================================================ FILE: src/tha4/shion/base/loss/__init__.py ================================================ ================================================ FILE: src/tha4/shion/base/loss/computed_scale_loss.py ================================================ from typing import Optional, Callable from tha4.shion.core.cached_computation import 
TensorCachedComputationFunc, ComputationState from tha4.shion.core.loss import Loss class ComputedScaleLoss(Loss): def __init__(self, scale_func: TensorCachedComputationFunc, loss: Loss, weight: float = 1.0): self.weight = weight self.loss = loss self.scale_func = scale_func def compute(self, state: ComputationState, log_func: Optional[Callable[[str, float], None]] = None): loss = self.loss.compute(state) scale = self.scale_func(state) loss = self.weight * scale * loss if log_func is not None: log_func("loss", loss.item()) return loss ================================================ FILE: src/tha4/shion/base/loss/computed_scaled_l2_loss.py ================================================ from typing import Callable, Optional from tha4.shion.core.cached_computation import TensorCachedComputationFunc, ComputationState from tha4.shion.core.loss import Loss class ComputedScaledL2Loss(Loss): def __init__(self, expected_func: TensorCachedComputationFunc, actual_func: TensorCachedComputationFunc, element_scale_func: TensorCachedComputationFunc, weight: float = 1.0): self.element_scale_func = element_scale_func self.actual_func = actual_func self.expected_func = expected_func self.weight = weight def compute( self, state: ComputationState, log_func: Optional[Callable[[str, float], None]] = None): element_scale = self.element_scale_func(state) expected = self.expected_func(state) actual = self.actual_func(state) diff = (expected - actual) * element_scale loss = self.weight * (diff ** 2).mean() if log_func is not None: log_func("loss", loss.item()) return loss ================================================ FILE: src/tha4/shion/base/loss/l1_loss.py ================================================ from typing import Callable, Optional import torch from tha4.shion.core.cached_computation import TensorCachedComputationFunc, ComputationState from tha4.shion.core.loss import Loss class L1Loss(Loss): def __init__(self, expected_func: TensorCachedComputationFunc, actual_func: TensorCachedComputationFunc, weight: float = 1.0): self.actual_func = actual_func self.expected_func = expected_func self.weight = weight def compute(self, state: ComputationState, log_func: Optional[Callable[[str, float], None]] = None): expected = self.expected_func(state) actual = self.actual_func(state) loss = self.weight * (expected - actual).abs().mean() if log_func is not None: log_func("loss", loss.item()) return loss class ListL1Loss(Loss): def __init__(self, expected_func: TensorCachedComputationFunc, actual_func: TensorCachedComputationFunc, weight: float = 1.0): self.actual_func = actual_func self.expected_func = expected_func self.weight = weight def compute(self, state: ComputationState, log_func: Optional[Callable[[str, float], None]] = None): expected = self.expected_func(state) actual = self.actual_func(state) assert len(expected) == len(actual) loss = torch.zeros(1, device=expected[0].device) for i in range(len(expected)): loss += (expected[i] - actual[i]).abs().mean() loss = self.weight * loss if log_func is not None: log_func("loss", loss.item()) return loss class MaskedL1Loss(Loss): def __init__(self, expected_func: TensorCachedComputationFunc, actual_func: TensorCachedComputationFunc, mask_func: TensorCachedComputationFunc, weight: float = 1.0): self.mask_func = mask_func self.actual_func = actual_func self.expected_func = expected_func self.weight = weight def compute(self, state: ComputationState, log_func: Optional[Callable[[str, float], None]] = None): mask = self.mask_func(state) expected = 
self.expected_func(state) actual = self.actual_func(state) loss = self.weight * ((expected - actual) * mask).abs().mean() if log_func is not None: log_func("loss", loss.item()) return loss ================================================ FILE: src/tha4/shion/base/loss/l2_loss.py ================================================ from typing import Callable, Optional from tha4.shion.core.cached_computation import TensorCachedComputationFunc, ComputationState from tha4.shion.core.loss import Loss class L2Loss(Loss): def __init__(self, expected_func: TensorCachedComputationFunc, actual_func: TensorCachedComputationFunc, weight: float = 1.0): self.actual_func = actual_func self.expected_func = expected_func self.weight = weight def compute( self, state: ComputationState, log_func: Optional[Callable[[str, float], None]] = None): expected = self.expected_func(state) actual = self.actual_func(state) loss = self.weight * ((expected - actual) ** 2).mean() if log_func is not None: log_func("loss", loss.item()) return loss ================================================ FILE: src/tha4/shion/base/loss/sum_loss.py ================================================ from typing import List, Tuple, Callable, Optional import torch from torch import Tensor from tha4.shion.core.cached_computation import ComputationState from tha4.shion.core.loss import Loss class SumLoss(Loss): def __init__(self, losses: List[Tuple[str, Loss]]): self.losses = losses def compute(self, state: ComputationState, log_func: Optional[Callable[[str, float], None]] = None) -> Tensor: device = state.batch[0].device loss_value = torch.zeros(1, device=device) for loss_spec in self.losses: loss_name = loss_spec[0] loss = loss_spec[1] if log_func is not None: def loss_log_func(name, value): log_func(loss_name + "_" + name, value) else: loss_log_func = None loss_value = loss_value + loss.compute(state, loss_log_func) if log_func is not None: log_func("loss", loss_value.item()) return loss_value ================================================ FILE: src/tha4/shion/base/loss/time_dependently_weighted_loss.py ================================================ from typing import Callable, Optional from torch import Tensor from tha4.shion.core.cached_computation import ComputationState, CachedComputationFunc from tha4.shion.core.loss import Loss class TimeDependentlyWeightedLoss(Loss): def __init__(self, base_loss: Loss, examples_seen_so_far_func: CachedComputationFunc, weight_func: Callable[[int], float]): self.weight_func = weight_func self.examples_seen_so_far_func = examples_seen_so_far_func self.base_loss = base_loss def compute(self, state: ComputationState, log_func: Optional[Callable[[str, float], None]] = None) -> Tensor: base_value = self.base_loss.compute(state) examples_seen_so_far = self.examples_seen_so_far_func(state) weight = self.weight_func(examples_seen_so_far) loss_value = base_value * weight if log_func is not None: log_func("loss", loss_value.item()) return loss_value ================================================ FILE: src/tha4/shion/base/module_accumulators.py ================================================ from typing import Optional import torch from torch.nn import Module from tha4.shion.core.module_accumulator import ModuleAccumulator # Code from https://github.com/rosinality/style-based-gan-pytorch/blob/8437a8bbd106ad4a4691b798ce35d30b5111990b/train.py def accumulate_modules(new_module: Module, accumulated_module: Module, beta=0.99): with torch.no_grad(): new_module_params = dict(new_module.named_parameters()) 
accumulated_module_params = dict(accumulated_module.named_parameters()) for key in new_module_params.keys(): accumulated_module_params[key].mul_(beta).add_(new_module_params[key] * (1 - beta)) new_module_buffers = dict(new_module.named_buffers()) accumulated_module_buffers = dict(accumulated_module.named_buffers()) for key in new_module_buffers.keys(): accumulated_module_buffers[key].copy_(new_module_buffers[key]) class DecayAccumulator(ModuleAccumulator): def __init__(self, decay: float = 0.999): self.decay = decay def accumulate(self, module: Module, output: Module, examples_seen_so_far: Optional[int] = None) -> Module: accumulate_modules(module, output, self.decay) return output ================================================ FILE: src/tha4/shion/base/optimizer_factories.py ================================================ from typing import Tuple, Iterable from torch.nn import Parameter from torch.optim import Optimizer, Adam, AdamW, SparseAdam, RMSprop from tha4.shion.core.optimizer_factory import OptimizerFactory class AdamOptimizerFactory(OptimizerFactory): def __init__(self, betas: Tuple[float, float] = (0.9, 0.999), epsilon: float = 1e-8, weight_decay: float = 0.0): super().__init__() self.weight_decay = weight_decay self.betas = betas self.epsilon = epsilon def create(self, parameters: Iterable[Parameter]) -> Optimizer: return Adam(parameters, betas=self.betas, eps=self.epsilon, weight_decay=self.weight_decay) class AdamWOptimizerFactory(OptimizerFactory): def __init__(self, betas: Tuple[float, float] = (0.9, 0.999), epsilon: float = 1e-8, weight_decay: float = 0.01): super().__init__() self.weight_decay = weight_decay self.betas = betas self.epsilon = epsilon def create(self, parameters: Iterable[Parameter]) -> Optimizer: return AdamW(parameters, betas=self.betas, eps=self.epsilon, weight_decay=self.weight_decay) class SparseAdamOptimizerFactory(OptimizerFactory): def __init__(self, betas: Tuple[float, float] = (0.9, 0.999), epsilon: float = 1e-8): super().__init__() self.betas = betas self.epsilon = epsilon def create(self, parameters: Iterable[Parameter]) -> Optimizer: return SparseAdam(list(parameters), betas=self.betas, eps=self.epsilon) class RMSpropOptimizerFactory(OptimizerFactory): def __init__(self): super().__init__() def create(self, parameters: Iterable[Parameter]) -> Optimizer: return RMSprop(parameters) ================================================ FILE: src/tha4/shion/base/protocol/single_network_from_batch_input_computation_protocol.py ================================================ from typing import Optional, Any, List from tha4.shion.core.cached_computation import CachedComputationProtocol, ComputationState KEY_NETWORK = "network" KEY_NETWORK_OUTPUT = "network_output" class SingleNetworkBatchInputComputationProtocol(CachedComputationProtocol): def __init__(self, key_network: str = KEY_NETWORK, key_network_output: str = KEY_NETWORK_OUTPUT, input_index_to_batch_index: Optional[List[int]] = None): if input_index_to_batch_index is None: input_index_to_batch_index = [0] self.input_index_to_batch_index = input_index_to_batch_index self.key_network_output = key_network_output self.key_network = key_network def compute_output(self, key: str, state: ComputationState) -> Any: if key == self.key_network_output: inputs = [] for batch_index in self.input_index_to_batch_index: inputs.append(state.batch[batch_index]) network = state.modules[self.key_network] return network.forward(*inputs) else: raise RuntimeError("Computing output for key " + key + " is not supported!") 
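The three files above supply the glue used by the training protocols that follow: an OptimizerFactory builds an optimizer for a module's parameters, DecayAccumulator maintains an exponential moving average of a module's weights, and SingleNetworkBatchInputComputationProtocol runs one network on chosen batch elements and memoizes the result inside the ComputationState. A minimal sketch of how the pieces compose (the toy Linear network and the single-tensor batch layout are invented for illustration):

import torch
from torch.nn import Linear

from tha4.shion.base.module_accumulators import DecayAccumulator
from tha4.shion.base.optimizer_factories import AdamOptimizerFactory
from tha4.shion.base.protocol.single_network_from_batch_input_computation_protocol import \
    SingleNetworkBatchInputComputationProtocol
from tha4.shion.core.cached_computation import ComputationState

# Toy stand-ins for a real tha4 network; invented for this sketch.
network = Linear(4, 2)
ema_network = Linear(4, 2)
optimizer = AdamOptimizerFactory().create(network.parameters())

# With the defaults, the protocol looks up the module under "network",
# feeds it batch element 0, and caches the result under "network_output".
protocol = SingleNetworkBatchInputComputationProtocol()
state = ComputationState(modules={"network": network},
                         accumulated_modules={"network": ema_network},
                         batch=[torch.randn(8, 4)])
output_func = protocol.get_output_func("network_output")
assert output_func(state) is output_func(state)  # second call is served from state.outputs

# After an optimizer step, fold the new weights into the running average.
DecayAccumulator(decay=0.999).accumulate(network, ema_network)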
================================================ FILE: src/tha4/shion/base/training/__init__.py ================================================ ================================================ FILE: src/tha4/shion/base/training/single_network.py ================================================ import time from typing import List, Dict, Callable, Any, Optional import torch from torch.nn import Module from torch.nn.utils import clip_grad_norm_ from torch.optim.optimizer import Optimizer from tha4.shion.core.cached_computation import ComputationState from tha4.shion.core.loss import Loss from tha4.shion.core.optimizer_factory import OptimizerFactory from tha4.shion.core.training.training_protocol import TrainingProtocol from tha4.shion.core.training.validation_protocol import ValidationProtocol KEY_NETWORK = "network" class SingleNetworkTrainingProtocol(TrainingProtocol): def __init__(self, check_point_examples: List[int], batch_size: int, learning_rate: Callable[[int], Dict[str, float]], optimizer_factories: Dict[str, OptimizerFactory], module_key: str = KEY_NETWORK, random_seed: int = 39549059840, max_grad_norm: Optional[float] = None): super().__init__() self.max_grad_norm = max_grad_norm self.module_key = module_key self.optimizer_factories = optimizer_factories self.learning_rate = learning_rate self.batch_size = batch_size self.random_seed = random_seed self.check_point_examples = check_point_examples def get_optimizer_factories(self) -> Dict[str, OptimizerFactory]: return self.optimizer_factories def get_checkpoint_examples(self) -> List[int]: return self.check_point_examples def get_random_seed(self) -> int: return self.random_seed def get_batch_size(self) -> int: return self.batch_size def get_learning_rate(self, examples_seen_so_far: int) -> Dict[str, float]: return self.learning_rate(examples_seen_so_far) def run_training_iteration( self, batch: Any, examples_seen_so_far: int, modules: Dict[str, Module], accumulated_modules: Dict[str, Module], optimizers: Dict[str, Optimizer], losses: Dict[str, Loss], create_log_func: Optional[Callable[[str, int], Callable[[str, float], None]]], device: torch.device): module = modules[self.module_key] module.train(True) optimizers[self.module_key].zero_grad(set_to_none=True) if create_log_func is not None: log_func = create_log_func("training_" + self.module_key, examples_seen_so_far) else: log_func = None losses[self.module_key].compute( ComputationState(modules, accumulated_modules, batch), log_func).backward() if self.max_grad_norm is not None: clip_grad_norm_(module.parameters(), self.max_grad_norm) optimizers[self.module_key].step() class SingleNetworkValidationProtocol(ValidationProtocol): def __init__( self, example_per_validation_iteration: int, batch_size: int, module_key: str = KEY_NETWORK): super().__init__() self.module_key = module_key self.batch_size = batch_size self.example_per_validation_iteration = example_per_validation_iteration def get_batch_size(self, ) -> int: return self.batch_size def get_examples_per_validation_iteration(self) -> int: return self.example_per_validation_iteration def run_validation_iteration( self, batch: Any, examples_seen_so_far: int, modules: Dict[str, Module], accumulated_modules: Dict[str, Module], losses: Dict[str, Loss], create_log_func: Callable[[str, int], Callable[[str, float], None]], device: torch.device): module = modules[self.module_key] module.train(False) with torch.no_grad(): log_func = create_log_func("validation_" + self.module_key, examples_seen_so_far) losses[self.module_key].compute( 
ComputationState(modules, accumulated_modules, batch), log_func) ================================================ FILE: src/tha4/shion/base/training/single_network_with_minibatch.py ================================================ import time from typing import List, Dict, Callable, Any, Optional import torch from torch.nn import Module from torch.nn.utils import clip_grad_norm_ from torch.optim.optimizer import Optimizer from tha4.shion.core.cached_computation import ComputationState from tha4.shion.core.loss import Loss from tha4.shion.core.optimizer_factory import OptimizerFactory from tha4.shion.core.training.training_protocol import TrainingProtocol from tha4.shion.core.training.validation_protocol import ValidationProtocol KEY_NETWORK = "network" class SingleNetworkWithMinibatchTrainingProtocol(TrainingProtocol): def __init__(self, check_point_examples: List[int], batch_size: int, minibatch_size: int, learning_rate: Callable[[int], Dict[str, float]], optimizer_factories: Dict[str, OptimizerFactory], module_key: str = KEY_NETWORK, random_seed: int = 39549059840, max_grad_norm: Optional[float] = None): super().__init__() assert batch_size % minibatch_size == 0 self.minibatch_size = minibatch_size self.max_grad_norm = max_grad_norm self.module_key = module_key self.optimizer_factories = optimizer_factories self.learning_rate = learning_rate self.batch_size = batch_size self.random_seed = random_seed self.check_point_examples = check_point_examples def get_optimizer_factories(self) -> Dict[str, OptimizerFactory]: return self.optimizer_factories def get_checkpoint_examples(self) -> List[int]: return self.check_point_examples def get_random_seed(self) -> int: return self.random_seed def get_batch_size(self) -> int: return self.batch_size def get_learning_rate(self, examples_seen_so_far: int) -> Dict[str, float]: return self.learning_rate(examples_seen_so_far) def run_training_iteration( self, batch: Any, examples_seen_so_far: int, modules: Dict[str, Module], accumulated_modules: Dict[str, Module], optimizers: Dict[str, Optimizer], losses: Dict[str, Loss], create_log_func: Optional[Callable[[str, int], Callable[[str, float], None]]], device: torch.device): module = modules[self.module_key] module.train(True) optimizers[self.module_key].zero_grad(set_to_none=True) if create_log_func is not None: log_func = create_log_func("training_" + self.module_key, examples_seen_so_far) else: log_func = None num_minibatch = self.batch_size // self.minibatch_size for minibatch_index in range(num_minibatch): minibatch = [] for item in batch: minibatch.append( item[minibatch_index * self.minibatch_size:(minibatch_index + 1) * self.minibatch_size]) loss = losses[self.module_key].compute( ComputationState(modules, accumulated_modules, minibatch), log_func if minibatch_index == 0 else None) loss = loss / num_minibatch loss.backward() if self.max_grad_norm is not None: clip_grad_norm_(module.parameters(), self.max_grad_norm) optimizers[self.module_key].step() ================================================ FILE: src/tha4/shion/base/training/two_networks_training_protocol.py ================================================ from typing import List, Dict, Callable, Any, Optional import torch from torch.nn import Module from torch.nn.utils import clip_grad_norm_ from torch.optim.optimizer import Optimizer from tha4.shion.core.cached_computation import ComputationState from tha4.shion.core.loss import Loss from tha4.shion.core.optimizer_factory import OptimizerFactory from tha4.shion.core.training.training_protocol 
import TrainingProtocol class TwoNetworksWithMinibatchTrainingProtocol(TrainingProtocol): def __init__(self, check_point_examples: List[int], batch_size: int, learning_rate: Callable[[int], Dict[str, float]], optimizer_factories: Dict[str, OptimizerFactory], key_network_0: str, key_network_1: str, train_network_0: bool = False, random_seed: int = 39549059840, max_grad_norm: Optional[float] = None, minibatch_size: Optional[int] = None): super().__init__() if minibatch_size is None: minibatch_size = batch_size assert batch_size % minibatch_size == 0 self.train_network_0 = train_network_0 self.key_network_1 = key_network_1 self.key_network_0 = key_network_0 self.minibatch_size = minibatch_size self.max_grad_norm = max_grad_norm self.optimizer_factories = optimizer_factories self.learning_rate = learning_rate self.batch_size = batch_size self.random_seed = random_seed self.check_point_examples = check_point_examples def get_optimizer_factories(self) -> Dict[str, OptimizerFactory]: return self.optimizer_factories def get_checkpoint_examples(self) -> List[int]: return self.check_point_examples def get_random_seed(self) -> int: return self.random_seed def get_batch_size(self) -> int: return self.batch_size def get_learning_rate(self, examples_seen_so_far: int) -> Dict[str, float]: return self.learning_rate(examples_seen_so_far) def run_training_iteration( self, batch: Any, examples_seen_so_far: int, modules: Dict[str, Module], accumulated_modules: Dict[str, Module], optimizers: Dict[str, Optimizer], losses: Dict[str, Loss], create_log_func: Optional[Callable[[str, int], Callable[[str, float], None]]], device: torch.device): network_0 = modules[self.key_network_0] network_0.train(self.train_network_0) network_1 = modules[self.key_network_1] network_1.train(True) if self.train_network_0: optimizers[self.key_network_0].zero_grad(set_to_none=True) optimizers[self.key_network_1].zero_grad(set_to_none=True) if create_log_func is not None: network_0_log_func = create_log_func("training_" + self.key_network_0, examples_seen_so_far) network_1_log_func = create_log_func("training_" + self.key_network_1, examples_seen_so_far) else: network_0_log_func = None network_1_log_func = None num_minibatch = self.batch_size // self.minibatch_size for minibatch_index in range(num_minibatch): minibatch = [] for item in batch: minibatch.append( item[minibatch_index * self.minibatch_size:(minibatch_index + 1) * self.minibatch_size]) loss = losses[self.key_network_1].compute( ComputationState(modules, accumulated_modules, minibatch), network_1_log_func if minibatch_index == 0 else None) if self.train_network_0 and self.key_network_0 in losses: loss = loss + losses[self.key_network_0].compute( ComputationState(modules, accumulated_modules, minibatch), network_0_log_func if minibatch_index == 0 else None) loss = loss / num_minibatch loss.backward() if self.max_grad_norm is not None: clip_grad_norm_(network_1.parameters(), self.max_grad_norm) if self.train_network_0: clip_grad_norm_(network_0.parameters(), self.max_grad_norm) optimizers[self.key_network_1].step() if self.train_network_0: optimizers[self.key_network_0].step() ================================================ FILE: src/tha4/shion/core/__init__.py ================================================ ================================================ FILE: src/tha4/shion/core/cached_computation.py ================================================ from abc import ABC, abstractmethod from typing import Callable, Dict, Any, Optional import torch from torch import Tensor 
from torch.nn import Module class ComputationState: def __init__(self, modules: Dict[str, Module], accumulated_modules: Dict[str, Module], batch: Any, outputs: Optional[Dict[str, Any]] = None): if outputs is None: outputs = {} self.outputs = outputs self.batch = batch self.accumulated_modules = accumulated_modules self.modules = modules CachedComputationFunc = Callable[[ComputationState], Any] TensorCachedComputationFunc = Callable[[ComputationState], Tensor] def create_get_item_func(func: CachedComputationFunc, index): def _f(state: ComputationState): output = func(state) return output[index] return _f def create_batch_element_func(index: int) -> TensorCachedComputationFunc: def _f(state: ComputationState) -> Tensor: return state.batch[index] return _f class CachedComputationProtocol(ABC): def get_output(self, key: str, state: ComputationState) -> Any: if key in state.outputs: return state.outputs[key] else: output = self.compute_output(key, state) state.outputs[key] = output return state.outputs[key] @abstractmethod def compute_output(self, key: str, state: ComputationState) -> Any: pass def get_output_func(self, key: str) -> CachedComputationFunc: def func(state: ComputationState): return self.get_output(key, state) return func ComposableCachedComputationStep = Callable[[CachedComputationProtocol, ComputationState], Any] class ComposableCachedComputationProtocol(CachedComputationProtocol): def __init__(self, computation_steps: Optional[Dict[str, ComposableCachedComputationStep]] = None): if computation_steps is None: computation_steps = {} self.computation_steps = computation_steps def compute_output(self, key: str, state: ComputationState) -> Any: if key in self.computation_steps: return self.computation_steps[key](self, state) else: raise RuntimeError("Computing output for key " + key + " is not supported!") def batch_indexing_func(index: int): def _f(protocol: CachedComputationProtocol, state: ComputationState): return state.batch[index] return _f def proxy_func(key: str): def _f(protocol: CachedComputationProtocol, state: ComputationState): return protocol.get_output(key, state) return _f def output_array_indexing_func(key: str, index: int): def _f(protocol: CachedComputationProtocol, state: ComputationState): return protocol.get_output(key, state)[index] return _f def add_step(step_dict: Dict[str, ComposableCachedComputationStep], name: str): def _f(func): step_dict[name] = func return func return _f def zeros_like_func(key: str): def _f(protocol: CachedComputationProtocol, state: ComputationState): prototype = protocol.get_output(key, state) return torch.zeros_like(prototype) return _f ================================================ FILE: src/tha4/shion/core/load_save.py ================================================ import os import torch def torch_save(content, file_name): os.makedirs(os.path.dirname(file_name), exist_ok=True) with open(file_name, 'wb') as f: torch.save(content, f) def torch_load(file_name): with open(file_name, 'rb') as f: return torch.load(f, map_location=lambda storage, loc: storage) ================================================ FILE: src/tha4/shion/core/loss.py ================================================ from abc import ABC, abstractmethod from typing import Callable, Optional from torch import Tensor from tha4.shion.core.cached_computation import ComputationState class Loss(ABC): @abstractmethod def compute( self, state: ComputationState, log_func: Optional[Callable[[str, float], None]] = None) -> Tensor: pass 
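Because a Loss only ever sees a ComputationState, a loss term can be wired up once against batch indices and module keys and then reused unchanged by both the training and validation protocols above. A minimal sketch combining the earlier L2Loss with the helpers from cached_computation.py (the Linear network and the two-element batch layout are invented for illustration):

import torch

from tha4.shion.base.loss.l2_loss import L2Loss
from tha4.shion.core.cached_computation import ComputationState, create_batch_element_func

# Convention assumed here: batch[0] holds the network input, batch[1] the ground truth.
loss = L2Loss(
    expected_func=create_batch_element_func(1),
    actual_func=lambda state: state.modules["network"](state.batch[0]),
    weight=1.0)

network = torch.nn.Linear(4, 4)
state = ComputationState({"network": network}, {}, [torch.randn(8, 4), torch.randn(8, 4)])
value = loss.compute(state, log_func=lambda name, v: print(name, v))  # logs the scalar "loss"
value.backward()  # gradients flow into network.parameters()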
================================================ FILE: src/tha4/shion/core/module_accumulator.py ================================================ from abc import ABC, abstractmethod from typing import Optional from torch.nn import Module class ModuleAccumulator(ABC): @abstractmethod def accumulate(self, module: Module, output: Module, examples_seen_so_far: Optional[int] = None) -> Module: pass ================================================ FILE: src/tha4/shion/core/module_factory.py ================================================ from abc import ABC, abstractmethod from torch.nn import Module class ModuleFactory(ABC): @abstractmethod def create(self) -> Module: pass ================================================ FILE: src/tha4/shion/core/optimizer_factory.py ================================================ from abc import ABC, abstractmethod from typing import Iterable from torch.nn import Parameter class OptimizerFactory(ABC): @abstractmethod def create(self, parameters: Iterable[Parameter]): pass ================================================ FILE: src/tha4/shion/core/training/__init__.py ================================================ ================================================ FILE: src/tha4/shion/core/training/distrib/__init__.py ================================================ ================================================ FILE: src/tha4/shion/core/training/distrib/device_mapper.py ================================================ from typing import Dict import torch class SimpleCudaDeviceMapper: def __call__(self, rank, local_rank): return torch.device("cuda", local_rank) class UserSpecifiedLocalRankToDeviceMapper: def __init__(self, device_map: Dict[int, torch.device]): self.device_map = device_map def __call__(self, rank, local_rank): assert local_rank in self.device_map return self.device_map[local_rank] ================================================ FILE: src/tha4/shion/core/training/distrib/distributed_trainer.py ================================================ import argparse import logging import os.path import time from datetime import datetime from typing import Dict, Optional, Callable, Any import torch import torch.distributed from torch.nn.parallel import DistributedDataParallel from torch.utils.data import Dataset, DataLoader, DistributedSampler from torch.utils.tensorboard import SummaryWriter from tha4.shion.core.load_save import torch_save, torch_load from tha4.shion.core.loss import Loss from tha4.shion.core.module_accumulator import ModuleAccumulator from tha4.shion.core.module_factory import ModuleFactory from tha4.shion.core.training.distrib.device_mapper import SimpleCudaDeviceMapper from tha4.shion.core.training.distrib.distributed_training_states import DistributedTrainingState from tha4.shion.core.training.sample_output_protocol import SampleOutputProtocol from tha4.shion.core.training.training_protocol import TrainingProtocol from tha4.shion.core.training.util import set_learning_rate, create_log_func, get_least_greater_multiple from tha4.shion.core.training.validation_protocol import ValidationProtocol KEY_CHECKPOINT = 'checkpoint' KEY_SNAPSHOT = 'snapshot' KEY_VALIDATION = 'validation' KEY_SAMPLE_OUTPUT = 'sample_output' class DistributedTrainer: def __init__(self, prefix: str, module_factories: Dict[str, ModuleFactory], accumulators: Dict[str, ModuleAccumulator], losses: Dict[str, Loss], training_dataset: Dataset, validation_dataset: Optional[Dataset], training_protocol: TrainingProtocol, validation_protocol: Optional[ValidationProtocol], 
sample_output_protocol: Optional[SampleOutputProtocol], pretrained_module_file_names: Dict[str, str], example_per_snapshot: int, num_data_loader_workers: int = 8, distrib_backend: str = 'gloo'): self.distrib_backend = distrib_backend self.num_data_loader_workers = num_data_loader_workers self.accumulators = accumulators self.sample_output_protocol = sample_output_protocol self.example_per_snapshot = example_per_snapshot self.pretrained_module_file_names = pretrained_module_file_names self.losses = losses self.validation_protocol = validation_protocol self.training_protocol = training_protocol self.module_factories = module_factories self.prefix = prefix self.training_dataset = training_dataset self.validation_dataset = validation_dataset self.checkpoint_examples = self.training_protocol.get_checkpoint_examples() assert len(self.checkpoint_examples) >= 1 assert self.checkpoint_examples[0] > 0 self.checkpoint_examples = [0] + self.checkpoint_examples self.module_names = self.module_factories.keys() assert len(self.module_names) > 0 self.training_data_loader = None self.training_data_loader_iter = None self.training_data_loader_batch_size = None self.training_data_sampler = None self.validation_data_loader = None self.validation_data_loader_iter = None self.validation_data_loader_batch_size = None self.sample_output_data = None self.summary_writer = None self.log_dir = None self.training_state = None def get_sample_output_data_file_name(self): return self.prefix + "/sample_output_data.pt" def save_sample_output_data(self, rank: int, device: torch.device): if rank != 0: return if os.path.exists(self.get_sample_output_data_file_name()): return if self.sample_output_protocol is not None: torch.manual_seed(self.sample_output_protocol.get_random_seed()) sample_output_data = self.sample_output_protocol.get_sample_output_data(self.validation_dataset, device) torch_save(sample_output_data, self.get_sample_output_data_file_name()) else: torch_save({}, self.get_sample_output_data_file_name()) def load_sample_output_data(self, rank: int, device: torch.device): if rank != 0: return None else: self.save_sample_output_data(rank, device) return torch_load(self.get_sample_output_data_file_name()) def get_snapshot_prefix(self) -> str: return self.prefix + "/snapshot" def can_load_training_state(self, prefix: str, world_size: int) -> bool: return DistributedTrainingState.can_load( prefix, self.module_factories, self.accumulators, self.training_protocol.get_optimizer_factories(), world_size) def load_training_state(self, prefix, rank: int, local_rank: int, device: torch.device) -> DistributedTrainingState: return DistributedTrainingState.load( prefix, self.module_factories, self.accumulators, self.training_protocol.get_optimizer_factories(), rank, local_rank, device) @staticmethod def checkpoint_prefix(prefix: str, checkpoint_index: int) -> str: return "%s/checkpoint/%04d" % (prefix, checkpoint_index) def get_checkpoint_prefix(self, checkpoint_index) -> str: return DistributedTrainer.checkpoint_prefix(self.prefix, checkpoint_index) def get_initial_training_state(self, rank: int, local_rank: int, device: torch.device) -> DistributedTrainingState: training_state = DistributedTrainingState.new( self.module_factories, self.accumulators, self.training_protocol.get_optimizer_factories(), self.training_protocol.get_random_seed(), rank, local_rank, device, self.pretrained_module_file_names) logging.info("Created a new initial training state.") return training_state def load_previous_training_state(self, 
target_checkpoint_examples: int, world_size: int, rank: int, local_rank: int, device: torch.device) -> DistributedTrainingState: if self.can_load_training_state(self.get_snapshot_prefix(), world_size): examples_seen_so_far = DistributedTrainingState.get_examples_seen_so_far(self.get_snapshot_prefix()) diff = examples_seen_so_far - target_checkpoint_examples if diff < self.training_protocol.get_batch_size(): return self.load_training_state(self.get_snapshot_prefix(), rank, local_rank, device) num_checkpoints = len(self.checkpoint_examples) for checkpoint_index in range(num_checkpoints - 1, -1, -1): if self.can_load_training_state(self.get_checkpoint_prefix(checkpoint_index), world_size): examples_seen_so_far = DistributedTrainingState.get_examples_seen_so_far( self.get_checkpoint_prefix(checkpoint_index)) diff = examples_seen_so_far - target_checkpoint_examples if diff < self.training_protocol.get_batch_size(): return self.load_training_state( self.get_checkpoint_prefix(checkpoint_index), rank, local_rank, device) training_state = self.get_initial_training_state(rank, local_rank, device) training_state.save(self.get_checkpoint_prefix(0), rank, lambda: self.barrier(local_rank)) training_state = self.load_training_state(self.get_checkpoint_prefix(0), rank, local_rank, device) return training_state def get_log_dir(self): if self.log_dir is None: now = datetime.now() self.log_dir = self.prefix + "/log/" + now.strftime("%Y_%m_%d__%H_%M_%S") return self.log_dir def get_summary_writer(self, rank: int) -> Optional[SummaryWriter]: if rank != 0: return None if self.summary_writer is None: self.summary_writer = SummaryWriter(log_dir=self.get_log_dir()) return self.summary_writer def get_effective_training_epoch_size(self, world_size: int): batch_size = self.training_protocol.get_batch_size() N = len(self.training_dataset) N = (N // world_size) * world_size N = (N // batch_size) * batch_size return N def get_training_epoch_index(self, examples_seen_so_far: int, world_size: int): epoch_size = self.get_effective_training_epoch_size(world_size) batch_size = self.training_protocol.get_batch_size() return (examples_seen_so_far + batch_size * world_size) // epoch_size def get_next_training_batch(self, examples_seen_so_far: int, world_size: int, device: torch.device): batch_size = self.training_protocol.get_batch_size() dataset = self.training_dataset if self.training_data_loader is None: self.training_data_sampler = DistributedSampler( dataset, shuffle=True, drop_last=True) self.training_data_loader = DataLoader( dataset, batch_size=batch_size, sampler=self.training_data_sampler, shuffle=False, num_workers=self.num_data_loader_workers, drop_last=True) if self.training_data_loader_iter is None: epoch_index = self.get_training_epoch_index(examples_seen_so_far, world_size) logging.info(f"Started a new epoch: index = {epoch_index}, examples_seen_so_far = {examples_seen_so_far}") self.training_data_sampler.set_epoch(epoch_index) self.training_data_loader_iter = iter(self.training_data_loader) try: batch = next(self.training_data_loader_iter) except StopIteration: epoch_index = self.get_training_epoch_index(examples_seen_so_far, world_size) logging.info(f"Started a new epoch: index = {epoch_index}, examples_seen_so_far = {examples_seen_so_far}") self.training_data_sampler.set_epoch(epoch_index) self.training_data_loader_iter = iter(self.training_data_loader) batch = next(self.training_data_loader_iter) return [x.to(device) for x in batch] def get_next_checkpoint_num_examples(self, examples_seen_so_far) -> int: 
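        # Returns the first checkpoint milestone strictly greater than the number
        # of examples seen so far. When none remains, next() yields -1, which
        # indexes the final milestone; e.g. with milestones [0, 1000, 5000]
        # (hypothetical), 1200 maps to 5000 and 6000 also maps to 5000.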
next_index = next( (i for i in range(len(self.checkpoint_examples)) if self.checkpoint_examples[i] > examples_seen_so_far), -1) return self.checkpoint_examples[next_index] def get_next_snapshot_num_examples(self, examples_seen_so_far) -> int: return get_least_greater_multiple(examples_seen_so_far, self.example_per_snapshot) def get_next_validation_num_examples(self, examples_seen_so_far) -> int: if self.validation_protocol is None: return -1 return get_least_greater_multiple(examples_seen_so_far, self.validation_protocol.get_examples_per_validation_iteration()) def get_next_sample_output_num_examples(self, examples_seen_so_far) -> int: if self.sample_output_protocol is None: return -1 return get_least_greater_multiple(examples_seen_so_far, self.sample_output_protocol.get_examples_per_sample_output()) def get_next_num_examples(self, examples_seen_so_far) -> Dict[str, int]: return { KEY_CHECKPOINT: self.get_next_checkpoint_num_examples(examples_seen_so_far), KEY_SNAPSHOT: self.get_next_snapshot_num_examples(examples_seen_so_far), KEY_VALIDATION: self.get_next_validation_num_examples(examples_seen_so_far), KEY_SAMPLE_OUTPUT: self.get_next_sample_output_num_examples(examples_seen_so_far) } def get_next_validation_batch(self, device: torch.device): if self.validation_dataset is None: return None if self.validation_data_loader is None: self.validation_data_loader = DataLoader( self.validation_dataset, batch_size=self.validation_protocol.get_batch_size(), shuffle=True, num_workers=1, drop_last=True) if self.validation_data_loader_iter is None: self.validation_data_loader_iter = iter(self.validation_data_loader) try: batch = next(self.validation_data_loader_iter) except StopIteration: self.validation_data_loader_iter = iter(self.validation_data_loader) batch = next(self.validation_data_loader_iter) return [x.to(device) for x in batch] def get_checkpoint_index_to_save(self, examples_seen_so_far: int) -> int: checkpoint_index = 0 for i in range(len(self.checkpoint_examples)): if self.checkpoint_examples[i] <= examples_seen_so_far: checkpoint_index = i return checkpoint_index def barrier(self, local_rank: int): if self.distrib_backend == 'nccl': torch.distributed.barrier(device_ids=[local_rank]) else: torch.distributed.barrier() def train(self, world_size: int, rank: int, local_rank: int, target_checkpoint_examples: Optional[int] = None, device_mapper: Optional[Callable[[int, int], torch.device]] = None): if target_checkpoint_examples is None: target_checkpoint_examples = self.checkpoint_examples[-1] if device_mapper is None: device_mapper = SimpleCudaDeviceMapper() device = device_mapper(rank, local_rank) sample_output_data = self.load_sample_output_data(rank, device) training_state = self.load_previous_training_state( target_checkpoint_examples, world_size, rank, local_rank, device) summary_writer = self.get_summary_writer(rank) if summary_writer is not None: log_func_factory = lambda name, num: create_log_func(summary_writer, name, num) else: log_func_factory = None last_time = time.time() while training_state.examples_seen_so_far < target_checkpoint_examples: # Set the learning rate learning_rate_by_module_name = self.training_protocol.get_learning_rate(training_state.examples_seen_so_far) for module_name in self.module_factories.keys(): if module_name not in learning_rate_by_module_name or module_name not in training_state.optimizers: continue lr = learning_rate_by_module_name[module_name] set_learning_rate(training_state.optimizers[module_name], lr) if summary_writer is not None: 
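                    # Rank 0 owns the SummaryWriter, so the per-module learning
                    # rate is recorded to TensorBoard only on the logging rank.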
summary_writer.add_scalar( module_name + "_learning_rate", lr, training_state.examples_seen_so_far) # One training iteration training_batch = self.get_next_training_batch(training_state.examples_seen_so_far, world_size, device) self.training_protocol.run_training_iteration( training_batch, training_state.examples_seen_so_far, training_state.modules, training_state.accumulated_modules, training_state.optimizers, self.losses, log_func_factory, device) # Accumulate model data for module_name in self.accumulators: new_module = training_state.modules[module_name] if isinstance(new_module, DistributedDataParallel): new_module = new_module.module buffer_module = training_state.accumulated_modules[module_name] self.accumulators[module_name].accumulate( new_module, buffer_module, examples_seen_so_far=training_state.examples_seen_so_far) # Advance the number of examples seen so far next_num_examples = self.get_next_num_examples(training_state.examples_seen_so_far) training_state.examples_seen_so_far += self.training_protocol.get_batch_size() * world_size # Validation iteration if self.validation_protocol is not None \ and training_state.examples_seen_so_far >= next_num_examples[KEY_VALIDATION] \ and rank == 0: validation_batch = self.get_next_validation_batch(device) self.validation_protocol.run_validation_iteration( validation_batch, training_state.examples_seen_so_far, training_state.modules, training_state.accumulated_modules, self.losses, log_func_factory, device) # Save sample output if self.sample_output_protocol is not None \ and training_state.examples_seen_so_far >= next_num_examples[KEY_SAMPLE_OUTPUT]: if rank == 0: self.sample_output_protocol.save_sample_output_data( training_state.modules, training_state.accumulated_modules, sample_output_data, self.prefix + "/sample_outputs", training_state.examples_seen_so_far, device) self.barrier(local_rank) # Save checkpoint if training_state.examples_seen_so_far >= next_num_examples[KEY_CHECKPOINT]: checkpoint_index = self.get_checkpoint_index_to_save(training_state.examples_seen_so_far) training_state.save( self.get_checkpoint_prefix(checkpoint_index), rank, lambda: self.barrier(local_rank)) if next_num_examples[KEY_CHECKPOINT] != next_num_examples[KEY_SNAPSHOT]: training_state.save(self.get_snapshot_prefix(), rank, lambda: self.barrier(local_rank)) # Save snapshot if training_state.examples_seen_so_far >= next_num_examples[KEY_SNAPSHOT]: training_state.save(self.get_snapshot_prefix(), rank, lambda: self.barrier(local_rank)) now = time.time() if now - last_time > 10: logging.info("Showed %d training examples." 
% training_state.examples_seen_so_far) last_time = now @staticmethod def get_default_arg_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description='Training script.') parser.add_argument("--target_checkpoint_examples", type=int) return parser @staticmethod def run_with_args(trainer_factory: Callable[[int, str], 'DistributedTrainer'], args, backend: str = 'gloo', device_mapper: Optional[Callable[[int, int], torch.device]] = None): world_size = int(os.environ['WORLD_SIZE']) rank = int(os.environ['RANK']) local_rank = int(os.environ['LOCAL_RANK']) torch.distributed.init_process_group(backend) trainer = trainer_factory(world_size, backend) trainer.train(world_size, rank, local_rank, args.target_checkpoint_examples, device_mapper) @staticmethod def run(trainer_factory: Callable[[int, str], 'DistributedTrainer'], backend: str = 'gloo', device_mapper: Optional[Callable[[int, int], torch.device]] = None, args: Optional[Any] = None): if args is None: parser = DistributedTrainer.get_default_arg_parser() args = parser.parse_args() DistributedTrainer.run_with_args(trainer_factory, args, backend, device_mapper) ================================================ FILE: src/tha4/shion/core/training/distrib/distributed_training_states.py ================================================ import copy import logging import os from typing import Dict, Optional, Callable import torch from torch.nn import Module from torch.nn.parallel import DistributedDataParallel from torch.optim.optimizer import Optimizer from tha4.shion.core.load_save import torch_save, torch_load from tha4.shion.core.module_accumulator import ModuleAccumulator from tha4.shion.core.module_factory import ModuleFactory from tha4.shion.core.optimizer_factory import OptimizerFactory from tha4.shion.core.training.util import optimizer_to_device class DistributedTrainingState: def __init__(self, examples_seen_so_far: int, modules: Dict[str, Module], accumulated_modules: Dict[str, Module], optimizers: Dict[str, Optimizer]): self.accumulated_modules = accumulated_modules self.optimizers = optimizers self.modules = modules self.examples_seen_so_far = examples_seen_so_far @staticmethod def get_examples_seen_so_far_file_name(prefix) -> str: return prefix + "/examples_seen_so_far.txt" @staticmethod def get_module_file_name(prefix, module_name) -> str: return "%s/module_%s.pt" % (prefix, module_name) @staticmethod def get_accumulated_module_file_name(prefix, module_name) -> str: return "%s/accumulated_%s.pt" % (prefix, module_name) @staticmethod def get_optimizer_file_name(prefix, module_name) -> str: return "%s/optimizer_%s.pt" % (prefix, module_name) @staticmethod def get_rng_state_file_name(prefix, rank: int): return "%s/rng_state_%08d.pt" % (prefix, rank) def mkdir(self, prefix: str): os.makedirs(prefix, exist_ok=True) def save_data(self, prefix: str, rank: int): assert os.path.exists(prefix) torch_save(torch.get_rng_state(), DistributedTrainingState.get_rng_state_file_name(prefix, rank)) logging.info("Saved %s" % DistributedTrainingState.get_rng_state_file_name(prefix, rank)) if rank == 0: logging.info("Saving training state to %s" % prefix) with open(DistributedTrainingState.get_examples_seen_so_far_file_name(prefix), "wt") as fout: fout.write("%d\n" % self.examples_seen_so_far) logging.info("Saved %s" % DistributedTrainingState.get_examples_seen_so_far_file_name(prefix)) for module_name in self.modules: file_name = DistributedTrainingState.get_module_file_name(prefix, module_name) module = self.modules[module_name] if 
isinstance(module, DistributedDataParallel): state_dict = module.module.state_dict() else: state_dict = module.state_dict() torch_save(state_dict, file_name) logging.info("Saved %s" % file_name) for module_name in self.accumulated_modules: file_name = DistributedTrainingState.get_accumulated_module_file_name(prefix, module_name) torch_save(self.accumulated_modules[module_name].state_dict(), file_name) logging.info("Saved %s" % file_name) for module_name in self.optimizers: file_name = DistributedTrainingState.get_optimizer_file_name(prefix, module_name) torch_save(self.optimizers[module_name].state_dict(), file_name) logging.info("Saved %s" % file_name) logging.info("Done saving training state to %s" % prefix) def save(self, prefix: str, rank: int, barrier_func: Callable[[], None]): if rank == 0: self.mkdir(prefix) barrier_func() self.save_data(prefix, rank) barrier_func() @staticmethod def get_examples_seen_so_far(prefix: str) -> int: with open(DistributedTrainingState.get_examples_seen_so_far_file_name(prefix)) as fin: lines = fin.readlines() return int(lines[0]) @staticmethod def load( prefix: str, module_factories: Dict[str, ModuleFactory], accumulators: Dict[str, ModuleAccumulator], optimizer_factories: Dict[str, OptimizerFactory], rank: int, local_rank: int, device: torch.device) -> 'DistributedTrainingState': logging.info("Loading training state from %s" % prefix) with open(DistributedTrainingState.get_examples_seen_so_far_file_name(prefix)) as fin: lines = fin.readlines() examples_seen_so_far = int(lines[0]) logging.info("Loaded %s" % DistributedTrainingState.get_examples_seen_so_far_file_name(prefix)) modules = { module_name: factory.create() for (module_name, factory) in module_factories.items() } for module_name in modules: file_name = DistributedTrainingState.get_module_file_name(prefix, module_name) module = modules[module_name] state_dict = torch_load(file_name) module.load_state_dict(state_dict) module.to(device) modules[module_name] = DistributedDataParallel( module, device_ids=[device.index], output_device=device.index) logging.info("Loaded %s" % file_name) accumulated_modules = {} for module_name in accumulators: module_factory = module_factories[module_name] module = module_factory.create() file_name = DistributedTrainingState.get_accumulated_module_file_name(prefix, module_name) module.load_state_dict(torch_load(file_name)) module.to(device) accumulated_modules[module_name] = module logging.info("Loaded %s" % file_name) optimizers = {} for module_name in optimizer_factories: optimizer = optimizer_factories[module_name].create(modules[module_name].parameters()) file_name = DistributedTrainingState.get_optimizer_file_name(prefix, module_name) optimizer.load_state_dict(torch_load(file_name)) optimizer_to_device(optimizer, device) optimizers[module_name] = optimizer logging.info("Loaded %s" % file_name) torch.set_rng_state(torch_load(DistributedTrainingState.get_rng_state_file_name(prefix, rank))) logging.info("Loaded %s" % DistributedTrainingState.get_rng_state_file_name(prefix, rank)) logging.info("Done loading training state from %s" % prefix) return DistributedTrainingState(examples_seen_so_far, modules, accumulated_modules, optimizers) @staticmethod def new(module_factories: Dict[str, ModuleFactory], accumulators: Dict[str, ModuleAccumulator], optimizer_factories: Dict[str, OptimizerFactory], random_seed: int, rank: int, local_rank: int, device: torch.device, pretrained_module_file_names: Optional[Dict[str, str]] = None) -> 'DistributedTrainingState': 
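        # Builds a fresh state: instantiate each module on the target device, load
        # any pretrained weights, deep-copy the modules that will be accumulated,
        # wrap the live copies in DistributedDataParallel, and seed the RNG with
        # random_seed + rank so every process draws a distinct random stream.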
        examples_seen_so_far = 0
        modules = {
            module_name: factory.create()
            for (module_name, factory) in module_factories.items()
        }
        for module_name in modules:
            modules[module_name].to(device)
        if pretrained_module_file_names is not None:
            for module_name in modules:
                if module_name in pretrained_module_file_names:
                    file_name = pretrained_module_file_names[module_name]
                    modules[module_name].load_state_dict(torch_load(file_name))
                    logging.info("Loaded initial state from %s ..." % file_name)
        accumulated_modules = {}
        for module_name in accumulators:
            accumulated_modules[module_name] = copy.deepcopy(modules[module_name])
        for module_name in modules:
            module = modules[module_name]
            modules[module_name] = DistributedDataParallel(
                module,
                device_ids=[device.index],
                output_device=device.index)
        optimizers = {}
        for module_name in optimizer_factories:
            module = modules[module_name]
            optimizer = optimizer_factories[module_name].create(module.parameters())
            optimizer_to_device(optimizer, device)
            optimizers[module_name] = optimizer
        torch.manual_seed(random_seed + rank)
        return DistributedTrainingState(examples_seen_so_far, modules, accumulated_modules, optimizers)

    @staticmethod
    def can_load(prefix: str,
                 module_factories: Dict[str, ModuleFactory],
                 accumulators: Dict[str, ModuleAccumulator],
                 optimizer_factories: Dict[str, OptimizerFactory],
                 world_size: int) -> bool:
        logging.info(f"Checking directory {prefix}")
        if not os.path.isdir(prefix):
            logging.info(f"Cannot load files in {prefix} because it is not a directory")
            return False
        examples_seen_so_far_file_name = DistributedTrainingState.get_examples_seen_so_far_file_name(prefix)
        if not os.path.isfile(examples_seen_so_far_file_name):
            logging.info(f"Cannot load files in {prefix} because {examples_seen_so_far_file_name} is not a file.")
            return False
        for module_name in module_factories.keys():
            file_name = DistributedTrainingState.get_module_file_name(prefix, module_name)
            if not os.path.isfile(file_name):
                logging.info(f"Cannot load files in {prefix} because {file_name} is not a file.")
                return False
        for module_name in accumulators:
            file_name = DistributedTrainingState.get_accumulated_module_file_name(prefix, module_name)
            if not os.path.isfile(file_name):
                logging.info(f"Cannot load files in {prefix} because {file_name} is not a file.")
                return False
        for module_name in optimizer_factories:
            file_name = DistributedTrainingState.get_optimizer_file_name(prefix, module_name)
            if not os.path.isfile(file_name):
                logging.info(f"Cannot load files in {prefix} because {file_name} is not a file.")
                return False
        for rank in range(world_size):
            file_name = DistributedTrainingState.get_rng_state_file_name(prefix, rank)
            if not os.path.isfile(file_name):
                logging.info(f"Cannot load files in {prefix} because {file_name} is not a file.")
                return False
        return True


================================================
FILE: src/tha4/shion/core/training/distrib/distributed_training_tasks.py
================================================
import logging
import os
import sys
from typing import Callable, List, Optional

from tha4.pytasuku.workspace import Workspace
from tha4.shion.core.training.distrib.distributed_trainer import DistributedTrainer
from tha4.shion.core.training.distrib.distributed_training_states import DistributedTrainingState


def get_torchrun_executable():
    return os.path.dirname(sys.executable) + os.path.sep + "torchrun"


def run_distributed_training_script(
        training_script_file_name: str,
        num_nodes: int,
        node_rank: int,
        num_proc_per_node: int,
        master_addr: str = "127.0.0.1",
        master_port: int = 8888):
    command = f"{get_torchrun_executable()} " \
              f"--nproc_per_node={num_proc_per_node} " \
              f"--nnodes={num_nodes} " \
              f"--node_rank={node_rank} " \
              f"--master_addr={master_addr} " \
              f"--master_port={master_port} " \
              f"{training_script_file_name}"
    logging.info(f"Executing -- {command}")
    os.system(command)


class RdzvConfig:
    def __init__(self, id: int, port: int):
        self.port = port
        self.id = id


def run_standalone_distributed_training_script(
        training_script_file_name: str,
        num_proc_per_node: int,
        target_checkpoint_examples: Optional[int] = None,
        rdzv_config: Optional[RdzvConfig] = None):
    command = f"{get_torchrun_executable()} " \
              f"--nnodes=1 " \
              f"--nproc_per_node={num_proc_per_node} "
    if rdzv_config is not None:
        command += f"--rdzv_endpoint=localhost:{rdzv_config.port} "
        command += "--rdzv_backend=c10d "
        command += f"--rdzv_id={rdzv_config.id} "
    else:
        command += "--standalone "
    command += f"{training_script_file_name} "
    if target_checkpoint_examples is not None:
        command += f"--target_checkpoint_examples {target_checkpoint_examples} "
    logging.info(f"Executing -- {command}")
    os.system(command)


def define_distributed_training_tasks(
        workspace: Workspace,
        prefix: str,
        training_script_file_name: str,
        num_nodes: int,
        num_proc_per_node: int,
        master_addr: str = "127.0.0.1",
        master_port: int = 8888):
    def run_training_script_func(rank: int):
        def _f():
            run_distributed_training_script(
                training_script_file_name,
                num_nodes,
                rank,
                num_proc_per_node,
                master_addr,
                master_port)

        return _f

    for i in range(num_nodes):
        workspace.create_command_task(f"{prefix}/train_node_%06d" % i, [], run_training_script_func(i))


def define_standalone_distributed_training_tasks(
        workspace: Workspace,
        distributed_trainer_func: Callable[[int], DistributedTrainer],
        training_script_file_name: str,
        num_proc_per_node: int,
        dependencies: Optional[List[str]] = None,
        rdzv_config: Optional[RdzvConfig] = None):
    trainer = distributed_trainer_func(1)
    checkpoint_examples = trainer.training_protocol.get_checkpoint_examples()
    assert len(checkpoint_examples) >= 1
    assert checkpoint_examples[0] > 0
    checkpoint_examples = [0] + checkpoint_examples
    if dependencies is None:
        dependencies = []
    module_file_dependencies = dependencies[:]
    for module_name in trainer.pretrained_module_file_names:
        module_file_dependencies.append(trainer.pretrained_module_file_names[module_name])

    def create_train_func(target_checkpoint_examples: int):
        return lambda: run_standalone_distributed_training_script(
            training_script_file_name,
            num_proc_per_node,
            target_checkpoint_examples,
            rdzv_config=rdzv_config)

    train_tasks = []
    for checkpoint_index in range(0, len(checkpoint_examples)):
        for module_name in trainer.module_names:
            module_file_name = DistributedTrainingState.get_module_file_name(
                trainer.get_checkpoint_prefix(checkpoint_index), module_name)
            workspace.create_file_task(
                module_file_name,
                module_file_dependencies,
                create_train_func(checkpoint_examples[checkpoint_index]))
        for module_name in trainer.accumulators:
            accumulated_module_file_name = DistributedTrainingState.get_accumulated_module_file_name(
                trainer.get_checkpoint_prefix(checkpoint_index), module_name)
            workspace.create_file_task(
                accumulated_module_file_name,
                module_file_dependencies,
                create_train_func(checkpoint_examples[checkpoint_index]))
        workspace.create_command_task(
            trainer.get_checkpoint_prefix(checkpoint_index) + "/train_standalone",
            module_file_dependencies,
            create_train_func(checkpoint_examples[checkpoint_index]))
        train_tasks.append(trainer.get_checkpoint_prefix(checkpoint_index) + "/train_standalone")
workspace.create_file_task( trainer.prefix + "/train_standalone", module_file_dependencies, create_train_func(checkpoint_examples[-1])) if __name__ == "__main__": print(os.path.dirname(sys.executable) + os.path.sep + "torchrun") ================================================ FILE: src/tha4/shion/core/training/sample_output_protocol.py ================================================ from abc import ABC, abstractmethod from typing import Dict, Any import torch from torch.nn import Module from torch.utils.data import Dataset class SampleOutputProtocol(ABC): @abstractmethod def get_examples_per_sample_output(self) -> int: pass @abstractmethod def get_random_seed(self) -> int: pass @abstractmethod def get_sample_output_data(self, validation_dataset: Dataset, device: torch.device) -> Any: pass @abstractmethod def save_sample_output_data( self, modules: Dict[str, Module], accumulated_modules: Dict[str, Module], sample_output_data: Any, prefix: str, examples_seen_so_far: int, device: torch.device): pass class AbstractSampleOutputProtocol(SampleOutputProtocol, ABC): def __init__(self, examples_per_sample_output: int, random_seed: int): self.random_seed = random_seed self.examples_per_sample_output = examples_per_sample_output def get_examples_per_sample_output(self) -> int: return self.examples_per_sample_output def get_random_seed(self) -> int: return self.random_seed ================================================ FILE: src/tha4/shion/core/training/single/__init__.py ================================================ ================================================ FILE: src/tha4/shion/core/training/single/training_states.py ================================================ import copy import logging import os from typing import Dict, Optional import torch from torch.nn import Module from torch.optim import Optimizer from tha4.shion.core.load_save import torch_save, torch_load from tha4.shion.core.module_accumulator import ModuleAccumulator from tha4.shion.core.module_factory import ModuleFactory from tha4.shion.core.optimizer_factory import OptimizerFactory from tha4.shion.core.training.util import optimizer_to_device class TrainingState: def __init__(self, examples_seen_so_far: int, modules: Dict[str, Module], accumulated_modules: Dict[str, Module], optimizers: Dict[str, Optimizer]): self.accumulated_modules = accumulated_modules self.optimizers = optimizers self.modules = modules self.examples_seen_so_far = examples_seen_so_far @staticmethod def get_examples_seen_so_far_file_name(prefix) -> str: return prefix + "/examples_seen_so_far.txt" @staticmethod def get_module_file_name(prefix, module_name) -> str: return "%s/module_%s.pt" % (prefix, module_name) @staticmethod def get_accumulated_module_file_name(prefix, module_name) -> str: return "%s/accumulated_%s.pt" % (prefix, module_name) @staticmethod def get_optimizer_file_name(prefix, module_name) -> str: return "%s/optimizer_%s.pt" % (prefix, module_name) @staticmethod def get_rng_state_file_name(prefix): return "%s/rng_state.pt" % prefix def save(self, prefix): logging.info("Saving training state to %s" % prefix) os.makedirs(prefix, exist_ok=True) with open(TrainingState.get_examples_seen_so_far_file_name(prefix), "wt") as fout: fout.write("%d\n" % self.examples_seen_so_far) logging.info("Saved %s" % TrainingState.get_examples_seen_so_far_file_name(prefix)) for module_name in self.modules: file_name = TrainingState.get_module_file_name(prefix, module_name) torch_save(self.modules[module_name].state_dict(), file_name) logging.info("Saved %s" 
% file_name) for module_name in self.accumulated_modules: file_name = TrainingState.get_accumulated_module_file_name(prefix, module_name) torch_save(self.accumulated_modules[module_name].state_dict(), file_name) logging.info("Saved %s" % file_name) for module_name in self.optimizers: file_name = TrainingState.get_optimizer_file_name(prefix, module_name) torch_save(self.optimizers[module_name].state_dict(), file_name) logging.info("Saved %s" % file_name) torch_save(torch.get_rng_state(), TrainingState.get_rng_state_file_name(prefix)) logging.info("Saved %s" % TrainingState.get_rng_state_file_name(prefix)) logging.info("Done saving training state to %s" % prefix) @staticmethod def get_examples_seen_so_far(prefix: str) -> int: with open(TrainingState.get_examples_seen_so_far_file_name(prefix)) as fin: lines = fin.readlines() return int(lines[0]) @staticmethod def load(prefix: str, module_factories: Dict[str, ModuleFactory], accumulators: Dict[str, ModuleAccumulator], optimizer_factories: Dict[str, OptimizerFactory], device: torch.device) -> 'TrainingState': logging.info("Loading training state from %s" % prefix) with open(TrainingState.get_examples_seen_so_far_file_name(prefix)) as fin: lines = fin.readlines() examples_seen_so_far = int(lines[0]) logging.info("Loaded %s" % TrainingState.get_examples_seen_so_far_file_name(prefix)) modules = { module_name: factory.create() for (module_name, factory) in module_factories.items() } for module_name in modules: file_name = TrainingState.get_module_file_name(prefix, module_name) modules[module_name].load_state_dict(torch_load(file_name)) modules[module_name].to(device) logging.info("Loaded %s" % file_name) accumulated_modules = {} for module_name in accumulators: module_factory = module_factories[module_name] module = module_factory.create() file_name = TrainingState.get_accumulated_module_file_name(prefix, module_name) module.load_state_dict(torch_load(file_name)) module.to(device) accumulated_modules[module_name] = module logging.info("Loaded %s" % file_name) optimizers = {} for module_name in optimizer_factories: optimizer = optimizer_factories[module_name].create(modules[module_name].parameters()) file_name = TrainingState.get_optimizer_file_name(prefix, module_name) optimizer.load_state_dict(torch_load(file_name)) optimizer_to_device(optimizer, device) optimizers[module_name] = optimizer logging.info("Loaded %s" % file_name) torch.set_rng_state(torch_load(TrainingState.get_rng_state_file_name(prefix))) logging.info("Loaded %s" % TrainingState.get_rng_state_file_name(prefix)) logging.info("Done loading training state from %s" % prefix) return TrainingState(examples_seen_so_far, modules, accumulated_modules, optimizers) @staticmethod def new(module_factories: Dict[str, ModuleFactory], accumulators: Dict[str, ModuleAccumulator], optimizer_factories: Dict[str, OptimizerFactory], random_seed: int, device: torch.device, pretrained_module_file_names: Optional[Dict[str, str]] = None) -> 'TrainingState': examples_seen_so_far = 0 modules = { module_name: factory.create() for (module_name, factory) in module_factories.items() } for module_name in modules: modules[module_name].to(device) if pretrained_module_file_names is not None: for module_name in modules: if module_name in pretrained_module_file_names: file_name = pretrained_module_file_names[module_name] modules[module_name].load_state_dict(torch_load(file_name)) logging.info("Loaded initial state from %s ..." 
% file_name) accumulated_modules = {} for module_name in accumulators: accumulated_modules[module_name] = copy.deepcopy(modules[module_name]) optimizers = {} for module_name in optimizer_factories: module = modules[module_name] optimizer = optimizer_factories[module_name].create(module.parameters()) optimizer_to_device(optimizer, device) optimizers[module_name] = optimizer torch.manual_seed(random_seed) return TrainingState(examples_seen_so_far, modules, accumulated_modules, optimizers) @staticmethod def can_load(prefix: str, module_factories: Dict[str, ModuleFactory], accumulators: Dict[str, ModuleAccumulator], optimizer_factories: Dict[str, OptimizerFactory]) -> bool: if not os.path.isdir(prefix): return False if not os.path.isfile(TrainingState.get_examples_seen_so_far_file_name(prefix)): return False for module_name in module_factories.keys(): if not os.path.isfile(TrainingState.get_module_file_name(prefix, module_name)): return False for module_name in accumulators: if not os.path.isfile(TrainingState.get_accumulated_module_file_name(prefix, module_name)): return False for module_name in optimizer_factories: if not os.path.isfile(TrainingState.get_optimizer_file_name(prefix, module_name)): return False if not os.path.isfile(TrainingState.get_rng_state_file_name(prefix)): return False return True ================================================ FILE: src/tha4/shion/core/training/single/training_tasks.py ================================================ import logging import time from datetime import datetime from typing import Optional, Dict, List import torch from torch.utils.data import Dataset, DataLoader from torch.utils.tensorboard import SummaryWriter from tha4.pytasuku.workspace import Workspace from tha4.shion.core.load_save import torch_save, torch_load from tha4.shion.core.loss import Loss from tha4.shion.core.module_accumulator import ModuleAccumulator from tha4.shion.core.module_factory import ModuleFactory from tha4.shion.core.training.sample_output_protocol import SampleOutputProtocol from tha4.shion.core.training.single.training_states import TrainingState from tha4.shion.core.training.training_protocol import TrainingProtocol from tha4.shion.core.training.util import get_least_greater_multiple, create_log_func, set_learning_rate from tha4.shion.core.training.validation_protocol import ValidationProtocol KEY_CHECKPOINT = 'checkpoint' KEY_SNAPSHOT = 'snapshot' KEY_VALIDATION = 'validation' KEY_SAMPLE_OUTPUT = 'sample_output' class TrainingTasks: def __init__( self, workspace: Workspace, prefix: str, module_factories: Dict[str, ModuleFactory], accumulators: Dict[str, ModuleAccumulator], losses: Dict[str, Loss], training_dataset: Dataset, validation_dataset: Optional[Dataset], training_protocol: TrainingProtocol, validation_protocol: Optional[ValidationProtocol], sample_output_protocol: Optional[SampleOutputProtocol], pretrained_module_file_names: Dict[str, str], example_per_snapshot: int, device: torch.device, num_data_loader_workers: int = 8, dependencies: Optional[List[str]] = None): super().__init__() self.num_data_loader_workers = num_data_loader_workers self.accumulators = accumulators self.device = device self.sample_output_protocol = sample_output_protocol self.example_per_snapshot = example_per_snapshot self.pretrained_module_file_names = pretrained_module_file_names self.losses = losses self.validation_protocol = validation_protocol self.training_protocol = training_protocol self.module_factories = module_factories self.prefix = prefix self.training_dataset = 
training_dataset self.validation_dataset = validation_dataset self.checkpoint_examples = self.training_protocol.get_checkpoint_examples() assert len(self.checkpoint_examples) >= 1 assert self.checkpoint_examples[0] > 0 self.checkpoint_examples = [0] + self.checkpoint_examples self.module_names = self.module_factories.keys() assert len(self.module_names) > 0 self.training_data_loader = None self.training_data_loader_iter = None self.training_data_loader_batch_size = None self.validation_data_loader = None self.validation_data_loader_iter = None self.validation_data_loader_batch_size = None self.sample_output_data = None self.summary_writer = None self.log_dir = None self.training_state = None if dependencies is None: dependencies = [] self.sample_output_data_task = workspace.create_file_task( self.get_sample_output_data_file_name(), dependencies, self.save_sample_output_data) module_file_dependencies = [self.sample_output_data_task.name] for module_name in pretrained_module_file_names: module_file_dependencies.append(self.pretrained_module_file_names[module_name]) def create_train_func(target_examples: int): return lambda: self.train(target_examples) train_tasks = [] for checkpoint_index in range(1, len(self.checkpoint_examples)): for module_name in self.module_names: module_file_name = TrainingState.get_module_file_name( self.get_checkpoint_prefix(checkpoint_index), module_name) workspace.create_file_task( module_file_name, module_file_dependencies, create_train_func(self.checkpoint_examples[checkpoint_index])) for module_name in self.accumulators: accumulated_module_file_name = TrainingState.get_accumulated_module_file_name( self.get_checkpoint_prefix(checkpoint_index), module_name) workspace.create_file_task( accumulated_module_file_name, module_file_dependencies, create_train_func(self.checkpoint_examples[checkpoint_index])) workspace.create_command_task( self.get_checkpoint_prefix(checkpoint_index) + "/train", module_file_dependencies, create_train_func(self.checkpoint_examples[checkpoint_index])) train_tasks.append(self.get_checkpoint_prefix(checkpoint_index) + "/train") self.train_task = workspace.create_file_task( self.get_train_command_name(), module_file_dependencies, create_train_func(self.checkpoint_examples[-1])) def get_sample_output_data_file_name(self): return self.prefix + "/sample_output_data.pt" def save_sample_output_data(self): if self.sample_output_protocol is not None: torch.manual_seed(self.sample_output_protocol.get_random_seed()) sample_output_data = self.sample_output_protocol.get_sample_output_data(self.validation_dataset, self.device) torch_save(sample_output_data, self.get_sample_output_data_file_name()) else: torch_save({}, self.get_sample_output_data_file_name()) def get_module_file_name(self, checkpoint_index, module_name): return TrainingState.get_module_file_name(self.get_checkpoint_prefix(checkpoint_index), module_name) def get_last_module_file_name(self, module_name): return self.get_module_file_name(len(self.checkpoint_examples) - 1, module_name) def get_log_dir(self): if self.log_dir is None: now = datetime.now() self.log_dir = self.prefix + "/log/" + now.strftime("%Y_%m_%d__%H_%M_%S") return self.log_dir def get_summary_writer(self) -> SummaryWriter: if self.summary_writer is None: self.summary_writer = SummaryWriter(log_dir=self.get_log_dir()) return self.summary_writer def get_train_command_name(self) -> str: return self.prefix + "/train" def get_snapshot_prefix(self) -> str: return self.prefix + "/snapshot" def get_checkpoint_prefix(self, 
checkpoint_index) -> str: return "%s/checkpoint/%04d" % (self.prefix, checkpoint_index) def can_load_training_state(self, prefix) -> bool: return TrainingState.can_load( prefix, self.module_factories, self.accumulators, self.training_protocol.get_optimizer_factories()) def load_training_state(self, prefix) -> TrainingState: return TrainingState.load( prefix, self.module_factories, self.accumulators, self.training_protocol.get_optimizer_factories(), self.device) def get_initial_training_state(self) -> TrainingState: training_state = TrainingState.new( self.module_factories, self.accumulators, self.training_protocol.get_optimizer_factories(), self.training_protocol.get_random_seed(), self.device, self.pretrained_module_file_names) logging.info("Created a new initial training state.") return training_state def load_previous_training_state(self, target_checkpoint_examples: int) -> TrainingState: if self.can_load_training_state(self.get_snapshot_prefix()): examples_seen_so_far = TrainingState.get_examples_seen_so_far(self.get_snapshot_prefix()) diff = examples_seen_so_far - target_checkpoint_examples if diff < self.training_protocol.get_batch_size(): return self.load_training_state(self.get_snapshot_prefix()) num_checkpoints = len(self.checkpoint_examples) for checkpoint_index in range(num_checkpoints - 1, -1, -1): if self.can_load_training_state(self.get_checkpoint_prefix(checkpoint_index)): examples_seen_so_far = TrainingState.get_examples_seen_so_far( self.get_checkpoint_prefix(checkpoint_index)) diff = examples_seen_so_far - target_checkpoint_examples if diff < self.training_protocol.get_batch_size(): return self.load_training_state(self.get_checkpoint_prefix(checkpoint_index)) return self.get_initial_training_state() def get_next_checkpoint_num_examples(self, examples_seen_so_far) -> int: next_index = next( (i for i in range(len(self.checkpoint_examples)) if self.checkpoint_examples[i] > examples_seen_so_far), -1) return self.checkpoint_examples[next_index] def get_next_snapshot_num_examples(self, examples_seen_so_far) -> int: return get_least_greater_multiple(examples_seen_so_far, self.example_per_snapshot) def get_next_validation_num_examples(self, examples_seen_so_far) -> int: if self.validation_protocol is None: return -1 return get_least_greater_multiple(examples_seen_so_far, self.validation_protocol.get_examples_per_validation_iteration()) def get_next_sample_output_num_examples(self, examples_seen_so_far) -> int: if self.sample_output_protocol is None: return -1 return get_least_greater_multiple(examples_seen_so_far, self.sample_output_protocol.get_examples_per_sample_output()) def get_next_num_examples(self, examples_seen_so_far) -> Dict[str, int]: return { KEY_CHECKPOINT: self.get_next_checkpoint_num_examples(examples_seen_so_far), KEY_SNAPSHOT: self.get_next_snapshot_num_examples(examples_seen_so_far), KEY_VALIDATION: self.get_next_validation_num_examples(examples_seen_so_far), KEY_SAMPLE_OUTPUT: self.get_next_sample_output_num_examples(examples_seen_so_far) } def get_checkpoint_index_to_save(self, examples_seen_so_far: int) -> int: checkpoint_index = 0 for i in range(len(self.checkpoint_examples)): if self.checkpoint_examples[i] <= examples_seen_so_far: checkpoint_index = i return checkpoint_index def get_next_training_batch(self): if self.training_data_loader is None: self.training_data_loader = DataLoader( self.training_dataset, batch_size=self.training_protocol.get_batch_size(), shuffle=True, num_workers=self.num_data_loader_workers, drop_last=True) if 
self.training_data_loader_iter is None: self.training_data_loader_iter = iter(self.training_data_loader) try: batch = next(self.training_data_loader_iter) except StopIteration: self.training_data_loader_iter = iter(self.training_data_loader) batch = next(self.training_data_loader_iter) return [x.to(self.device) for x in batch] def get_next_validation_batch(self): if self.validation_dataset is None: return None if self.validation_data_loader is None: self.validation_data_loader = DataLoader( self.validation_dataset, batch_size=self.validation_protocol.get_batch_size(), shuffle=True, num_workers=self.num_data_loader_workers, drop_last=True) if self.validation_data_loader_iter is None: self.validation_data_loader_iter = iter(self.validation_data_loader) try: batch = next(self.validation_data_loader_iter) except StopIteration: self.validation_data_loader_iter = iter(self.validation_data_loader) batch = next(self.validation_data_loader_iter) return [x.to(self.device) for x in batch] def get_checkpoint_index(self, target_checkpoint_examples: int): return self.checkpoint_examples.index(target_checkpoint_examples) def train(self, target_checkpoint_examples: Optional[int] = None): if target_checkpoint_examples is None: target_checkpoint_examples = self.checkpoint_examples[-1] sample_output_data = torch_load(self.get_sample_output_data_file_name()) logging.info("Loaded sampled output data from %s", self.get_sample_output_data_file_name()) training_state = self.load_previous_training_state(target_checkpoint_examples) summary_writer = self.get_summary_writer() last_time = time.time() while training_state.examples_seen_so_far < target_checkpoint_examples: # One training iteration learning_rate = self.training_protocol.get_learning_rate(training_state.examples_seen_so_far) for module_name in self.module_factories.keys(): if module_name not in learning_rate or module_name not in training_state.optimizers: continue lr = learning_rate[module_name] set_learning_rate(training_state.optimizers[module_name], lr) self.get_summary_writer().add_scalar( module_name + "_learning_rate", lr, training_state.examples_seen_so_far) training_batch = self.get_next_training_batch() self.training_protocol.run_training_iteration( training_batch, training_state.examples_seen_so_far, training_state.modules, training_state.accumulated_modules, training_state.optimizers, self.losses, lambda name, num: create_log_func(summary_writer, name, num), self.device) # Accumulate model data for module_name in self.accumulators: new_module = training_state.modules[module_name] buffer_module = training_state.accumulated_modules[module_name] self.accumulators[module_name].accumulate( new_module, buffer_module, training_state.examples_seen_so_far) # Advance the number of examples seen so far next_num_examples = self.get_next_num_examples(training_state.examples_seen_so_far) training_state.examples_seen_so_far += self.training_protocol.get_batch_size() # Validation iteration if self.validation_protocol is not None \ and training_state.examples_seen_so_far >= next_num_examples[KEY_VALIDATION]: validation_batch = self.get_next_validation_batch() self.validation_protocol.run_validation_iteration( validation_batch, training_state.examples_seen_so_far, training_state.modules, training_state.accumulated_modules, self.losses, lambda name, num: create_log_func(summary_writer, name, num), self.device) # Save sample output if self.sample_output_protocol is not None \ and training_state.examples_seen_so_far >= next_num_examples[KEY_SAMPLE_OUTPUT]: 
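                # Render sample outputs with the current and accumulated weights so
                # training progress can be inspected while the run is in flight.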
                self.sample_output_protocol.save_sample_output_data(
                    training_state.modules,
                    training_state.accumulated_modules,
                    sample_output_data,
                    self.prefix + "/sample_outputs",
                    training_state.examples_seen_so_far,
                    self.device)

            # Save checkpoint
            if training_state.examples_seen_so_far >= next_num_examples[KEY_CHECKPOINT]:
                checkpoint_index = self.get_checkpoint_index_to_save(training_state.examples_seen_so_far)
                training_state.save(self.get_checkpoint_prefix(checkpoint_index))
                if next_num_examples[KEY_CHECKPOINT] != next_num_examples[KEY_SNAPSHOT]:
                    training_state.save(self.get_snapshot_prefix())

            # Save snapshot
            if training_state.examples_seen_so_far >= next_num_examples[KEY_SNAPSHOT]:
                training_state.save(self.get_snapshot_prefix())

            now = time.time()
            if now - last_time > 10:
                logging.info("Showed %d training examples." % training_state.examples_seen_so_far)
                last_time = now


================================================
FILE: src/tha4/shion/core/training/swarm/__init__.py
================================================


================================================
FILE: src/tha4/shion/core/training/swarm/swarm_training_tasks.py
================================================
from typing import Callable, Optional, List

from tha4.pytasuku.workspace import Workspace
from tha4.shion.core.training.distrib.distributed_training_tasks import RdzvConfig, \
    run_standalone_distributed_training_script
from tha4.shion.core.training.single.training_states import TrainingState
from tha4.shion.core.training.swarm.swarm_unit_trainer import SwarmUnitTrainer


def define_standalone_swarm_training_tasks(
        workspace: Workspace,
        swarm_unit_trainer_func: Callable[[], SwarmUnitTrainer],
        training_script_file_name: str,
        num_proc_per_node: int,
        dependencies: Optional[List[str]] = None,
        rdzv_config: Optional[RdzvConfig] = None):
    trainer = swarm_unit_trainer_func()
    checkpoint_examples = trainer.training_protocol.get_checkpoint_examples()
    assert len(checkpoint_examples) >= 1
    assert checkpoint_examples[0] > 0
    checkpoint_examples = [0] + checkpoint_examples
    if dependencies is None:
        dependencies = []
    module_file_dependencies = dependencies[:]
    for module_name in trainer.pretrained_module_file_names:
        module_file_dependencies.append(trainer.pretrained_module_file_names[module_name])

    def create_train_func(target_checkpoint_examples: int):
        return lambda: run_standalone_distributed_training_script(
            training_script_file_name,
            num_proc_per_node,
            target_checkpoint_examples,
            rdzv_config=rdzv_config)

    train_tasks = []
    for checkpoint_index in range(0, len(checkpoint_examples)):
        for module_name in trainer.module_names:
            module_file_name = TrainingState.get_module_file_name(
                trainer.get_checkpoint_prefix(checkpoint_index), module_name)
            workspace.create_file_task(
                module_file_name,
                module_file_dependencies,
                create_train_func(checkpoint_examples[checkpoint_index]))
        for module_name in trainer.accumulators:
            accumulated_module_file_name = TrainingState.get_accumulated_module_file_name(
                trainer.get_checkpoint_prefix(checkpoint_index), module_name)
            workspace.create_file_task(
                accumulated_module_file_name,
                module_file_dependencies,
                create_train_func(checkpoint_examples[checkpoint_index]))
        workspace.create_command_task(
            trainer.get_checkpoint_prefix(checkpoint_index) + "/train_standalone",
            module_file_dependencies,
            create_train_func(checkpoint_examples[checkpoint_index]))
        train_tasks.append(trainer.get_checkpoint_prefix(checkpoint_index) + "/train_standalone")

    workspace.create_file_task(
        trainer.prefix + "/train_standalone",
        module_file_dependencies,
create_train_func(checkpoint_examples[-1])) ================================================ FILE: src/tha4/shion/core/training/swarm/swarm_unit_trainer.py ================================================ import argparse import logging import os import time from datetime import datetime from typing import Dict, Optional, Callable import torch.distributed import torch from torch.utils.data import Dataset, DataLoader from torch.utils.tensorboard import SummaryWriter from tha4.shion.core.load_save import torch_save, torch_load from tha4.shion.core.loss import Loss from tha4.shion.core.module_accumulator import ModuleAccumulator from tha4.shion.core.module_factory import ModuleFactory from tha4.shion.core.training.distrib.device_mapper import SimpleCudaDeviceMapper from tha4.shion.core.training.sample_output_protocol import SampleOutputProtocol from tha4.shion.core.training.single.training_states import TrainingState from tha4.shion.core.training.single.training_tasks import KEY_CHECKPOINT, KEY_SNAPSHOT, KEY_VALIDATION, KEY_SAMPLE_OUTPUT from tha4.shion.core.training.training_protocol import TrainingProtocol from tha4.shion.core.training.util import get_least_greater_multiple, create_log_func, set_learning_rate from tha4.shion.core.training.validation_protocol import ValidationProtocol class SwarmUnitTrainer: def __init__(self, prefix: str, module_factories: Dict[str, ModuleFactory], accumulators: Dict[str, ModuleAccumulator], losses: Dict[str, Loss], training_dataset: Dataset, validation_dataset: Optional[Dataset], training_protocol: TrainingProtocol, validation_protocol: Optional[ValidationProtocol], sample_output_protocol: Optional[SampleOutputProtocol], pretrained_module_file_names: Dict[str, str], example_per_snapshot: int, num_data_loader_workers: int = 8): self.num_data_loader_workers = num_data_loader_workers self.accumulators = accumulators self.sample_output_protocol = sample_output_protocol self.example_per_snapshot = example_per_snapshot self.pretrained_module_file_names = pretrained_module_file_names self.losses = losses self.validation_protocol = validation_protocol self.training_protocol = training_protocol self.module_factories = module_factories self.prefix = prefix self.training_dataset = training_dataset self.validation_dataset = validation_dataset self.checkpoint_examples = self.training_protocol.get_checkpoint_examples() assert len(self.checkpoint_examples) >= 1 assert self.checkpoint_examples[0] > 0 self.checkpoint_examples = [0] + self.checkpoint_examples self.module_names = self.module_factories.keys() assert len(self.module_names) > 0 self.training_data_loader = None self.training_data_loader_iter = None self.training_data_loader_batch_size = None self.training_data_sampler = None self.validation_data_loader = None self.validation_data_loader_iter = None self.validation_data_loader_batch_size = None self.sample_output_data = None self.summary_writer = None self.log_dir = None self.training_state = None def get_sample_output_data_file_name(self): return self.prefix + "/sample_output_data.pt" def save_sample_output_data(self, device: torch.device): if os.path.exists(self.get_sample_output_data_file_name()): return if self.sample_output_protocol is not None: torch.manual_seed(self.sample_output_protocol.get_random_seed()) sample_output_data = self.sample_output_protocol.get_sample_output_data(self.validation_dataset, device) torch_save(sample_output_data, self.get_sample_output_data_file_name()) else: torch_save({}, self.get_sample_output_data_file_name()) def 
    def load_sample_output_data(self, device: torch.device):
        self.save_sample_output_data(device)
        return torch_load(self.get_sample_output_data_file_name())

    def get_snapshot_prefix(self) -> str:
        return self.prefix + "/snapshot"

    def can_load_training_state(self, prefix: str) -> bool:
        return TrainingState.can_load(
            prefix,
            self.module_factories,
            self.accumulators,
            self.training_protocol.get_optimizer_factories())

    def load_training_state(self, prefix, device: torch.device) -> TrainingState:
        return TrainingState.load(
            prefix,
            self.module_factories,
            self.accumulators,
            self.training_protocol.get_optimizer_factories(),
            device)

    @staticmethod
    def checkpoint_prefix(prefix: str, checkpoint_index: int) -> str:
        return "%s/checkpoint/%04d" % (prefix, checkpoint_index)

    def get_checkpoint_prefix(self, checkpoint_index) -> str:
        return SwarmUnitTrainer.checkpoint_prefix(self.prefix, checkpoint_index)

    def get_initial_training_state(self, device: torch.device) -> TrainingState:
        training_state = TrainingState.new(
            self.module_factories,
            self.accumulators,
            self.training_protocol.get_optimizer_factories(),
            self.training_protocol.get_random_seed(),
            device,
            self.pretrained_module_file_names)
        logging.info("Created a new initial training state.")
        return training_state

    def load_previous_training_state(self, target_checkpoint_examples: int, device: torch.device) -> TrainingState:
        if self.can_load_training_state(self.get_snapshot_prefix()):
            examples_seen_so_far = TrainingState.get_examples_seen_so_far(self.get_snapshot_prefix())
            diff = examples_seen_so_far - target_checkpoint_examples
            if diff < self.training_protocol.get_batch_size():
                return self.load_training_state(self.get_snapshot_prefix(), device)
        num_checkpoints = len(self.checkpoint_examples)
        for checkpoint_index in range(num_checkpoints - 1, -1, -1):
            if self.can_load_training_state(self.get_checkpoint_prefix(checkpoint_index)):
                examples_seen_so_far = TrainingState.get_examples_seen_so_far(
                    self.get_checkpoint_prefix(checkpoint_index))
                diff = examples_seen_so_far - target_checkpoint_examples
                if diff < self.training_protocol.get_batch_size():
                    return self.load_training_state(
                        self.get_checkpoint_prefix(checkpoint_index), device)
        training_state = self.get_initial_training_state(device)
        training_state.save(self.get_checkpoint_prefix(0))
        training_state = self.load_training_state(self.get_checkpoint_prefix(0), device)
        return training_state

    def get_log_dir(self):
        if self.log_dir is None:
            now = datetime.now()
            self.log_dir = self.prefix + "/log/" + now.strftime("%Y_%m_%d__%H_%M_%S")
        return self.log_dir

    def get_summary_writer(self) -> Optional[SummaryWriter]:
        if self.summary_writer is None:
            self.summary_writer = SummaryWriter(log_dir=self.get_log_dir())
        return self.summary_writer

    def get_next_training_batch(self, device: torch.device):
        if self.training_data_loader is None:
            self.training_data_loader = DataLoader(
                self.training_dataset,
                batch_size=self.training_protocol.get_batch_size(),
                shuffle=True,
                num_workers=self.num_data_loader_workers,
                drop_last=True)
        if self.training_data_loader_iter is None:
            self.training_data_loader_iter = iter(self.training_data_loader)
        try:
            batch = next(self.training_data_loader_iter)
        except StopIteration:
            self.training_data_loader_iter = iter(self.training_data_loader)
            batch = next(self.training_data_loader_iter)
        return [x.to(device) for x in batch]

    def get_next_checkpoint_num_examples(self, examples_seen_so_far) -> int:
        next_index = next(
            (i for i in range(len(self.checkpoint_examples)) if self.checkpoint_examples[i] > examples_seen_so_far),
            -1)
        return self.checkpoint_examples[next_index]
    def get_next_snapshot_num_examples(self, examples_seen_so_far) -> int:
        return get_least_greater_multiple(examples_seen_so_far, self.example_per_snapshot)

    def get_next_validation_num_examples(self, examples_seen_so_far) -> int:
        if self.validation_protocol is None:
            return -1
        return get_least_greater_multiple(
            examples_seen_so_far, self.validation_protocol.get_examples_per_validation_iteration())

    def get_next_sample_output_num_examples(self, examples_seen_so_far) -> int:
        if self.sample_output_protocol is None:
            return -1
        return get_least_greater_multiple(
            examples_seen_so_far, self.sample_output_protocol.get_examples_per_sample_output())

    def get_next_num_examples(self, examples_seen_so_far) -> Dict[str, int]:
        return {
            KEY_CHECKPOINT: self.get_next_checkpoint_num_examples(examples_seen_so_far),
            KEY_SNAPSHOT: self.get_next_snapshot_num_examples(examples_seen_so_far),
            KEY_VALIDATION: self.get_next_validation_num_examples(examples_seen_so_far),
            KEY_SAMPLE_OUTPUT: self.get_next_sample_output_num_examples(examples_seen_so_far)
        }

    def get_next_validation_batch(self, device: torch.device):
        if self.validation_dataset is None:
            return None
        if self.validation_data_loader is None:
            self.validation_data_loader = DataLoader(
                self.validation_dataset,
                batch_size=self.validation_protocol.get_batch_size(),
                shuffle=True,
                num_workers=1,
                drop_last=True)
        if self.validation_data_loader_iter is None:
            self.validation_data_loader_iter = iter(self.validation_data_loader)
        try:
            batch = next(self.validation_data_loader_iter)
        except StopIteration:
            self.validation_data_loader_iter = iter(self.validation_data_loader)
            batch = next(self.validation_data_loader_iter)
        return [x.to(device) for x in batch]

    def get_checkpoint_index_to_save(self, examples_seen_so_far: int) -> int:
        checkpoint_index = 0
        for i in range(len(self.checkpoint_examples)):
            if self.checkpoint_examples[i] <= examples_seen_so_far:
                checkpoint_index = i
        return checkpoint_index

    def train(self,
              rank: int,
              local_rank: int,
              target_checkpoint_examples: Optional[int] = None,
              device_mapper: Optional[Callable[[int, int], torch.device]] = None):
        if target_checkpoint_examples is None:
            target_checkpoint_examples = self.checkpoint_examples[-1]
        if device_mapper is None:
            device_mapper = SimpleCudaDeviceMapper()
        device = device_mapper(rank, local_rank)

        sample_output_data = self.load_sample_output_data(device)
        training_state = self.load_previous_training_state(target_checkpoint_examples, device)

        summary_writer = self.get_summary_writer()
        if summary_writer is not None:
            log_func_factory = lambda name, num: create_log_func(summary_writer, name, num)
        else:
            log_func_factory = None

        last_time = time.time()
        while training_state.examples_seen_so_far < target_checkpoint_examples:
            # Set the learning rate
            learning_rate_by_module_name = self.training_protocol.get_learning_rate(
                training_state.examples_seen_so_far)
            for module_name in self.module_factories.keys():
                if module_name not in learning_rate_by_module_name or module_name not in training_state.optimizers:
                    continue
                lr = learning_rate_by_module_name[module_name]
                set_learning_rate(training_state.optimizers[module_name], lr)
                if summary_writer is not None:
                    summary_writer.add_scalar(
                        module_name + "_learning_rate", lr, training_state.examples_seen_so_far)

            # One training iteration
            training_batch = self.get_next_training_batch(device)
            self.training_protocol.run_training_iteration(
                training_batch,
                training_state.examples_seen_so_far,
                training_state.modules,
                training_state.accumulated_modules,
                training_state.optimizers,
                self.losses,
                log_func_factory,
                device)
            # Accumulate model data
            for module_name in self.accumulators:
                new_module = training_state.modules[module_name]
                buffer_module = training_state.accumulated_modules[module_name]
                self.accumulators[module_name].accumulate(
                    new_module, buffer_module, examples_seen_so_far=training_state.examples_seen_so_far)

            # Advance the number of examples seen so far
            next_num_examples = self.get_next_num_examples(training_state.examples_seen_so_far)
            training_state.examples_seen_so_far += self.training_protocol.get_batch_size()

            # Validation iteration
            if self.validation_protocol is not None \
                    and training_state.examples_seen_so_far >= next_num_examples[KEY_VALIDATION]:
                validation_batch = self.get_next_validation_batch(device)
                self.validation_protocol.run_validation_iteration(
                    validation_batch,
                    training_state.examples_seen_so_far,
                    training_state.modules,
                    training_state.accumulated_modules,
                    self.losses,
                    log_func_factory,
                    device)

            # Save sample output
            if self.sample_output_protocol is not None \
                    and training_state.examples_seen_so_far >= next_num_examples[KEY_SAMPLE_OUTPUT]:
                self.sample_output_protocol.save_sample_output_data(
                    training_state.modules,
                    training_state.accumulated_modules,
                    sample_output_data,
                    self.prefix + "/sample_outputs",
                    training_state.examples_seen_so_far,
                    device)

            # Save checkpoint
            if training_state.examples_seen_so_far >= next_num_examples[KEY_CHECKPOINT]:
                checkpoint_index = self.get_checkpoint_index_to_save(training_state.examples_seen_so_far)
                training_state.save(self.get_checkpoint_prefix(checkpoint_index))
                if next_num_examples[KEY_CHECKPOINT] != next_num_examples[KEY_SNAPSHOT]:
                    training_state.save(self.get_snapshot_prefix())

            # Save snapshot
            if training_state.examples_seen_so_far >= next_num_examples[KEY_SNAPSHOT]:
                training_state.save(self.get_snapshot_prefix())

            now = time.time()
            if now - last_time > 10:
                logging.info("[Rank %d] Showed %d training examples."
                             % (rank, training_state.examples_seen_so_far))
                last_time = now

    @staticmethod
    def run(trainer_factory: Dict[int, Callable[[], 'SwarmUnitTrainer']],
            backend: str = 'gloo',
            device_mapper: Optional[Callable[[int, int], torch.device]] = None):
        parser = argparse.ArgumentParser(description='Training script.')
        parser.add_argument("--target_checkpoint_examples", type=int)
        args = parser.parse_args()

        rank = int(os.environ['RANK'])
        local_rank = int(os.environ['LOCAL_RANK'])
        torch.distributed.init_process_group(backend)

        if rank in trainer_factory:
            trainer = trainer_factory[rank]()
            trainer.train(rank, local_rank, args.target_checkpoint_examples, device_mapper)
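For context, a training script compatible with SwarmUnitTrainer.run could look like the hypothetical sketch below. run reads RANK and LOCAL_RANK from the environment, so the script is meant to be launched by torchrun (which run_standalone_distributed_training_script wraps); the per-rank factories are placeholders, not names from this repository.

# Hedged sketch of a swarm training script entry point.
from tha4.shion.core.training.swarm.swarm_unit_trainer import SwarmUnitTrainer


def make_face_morpher_trainer() -> SwarmUnitTrainer:
    raise NotImplementedError  # hypothetical factory for rank 0


def make_body_morpher_trainer() -> SwarmUnitTrainer:
    raise NotImplementedError  # hypothetical factory for rank 1


if __name__ == "__main__":
    # Each rank trains its own unit; a rank absent from the dictionary joins
    # the process group but trains nothing.
    SwarmUnitTrainer.run(
        trainer_factory={0: make_face_morpher_trainer, 1: make_body_morpher_trainer},
        backend='gloo')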
================================================
FILE: src/tha4/shion/core/training/training_protocol.py
================================================
from abc import ABC, abstractmethod
from typing import Dict, List, Callable, Any, Optional

import torch
from torch.nn import Module
from torch.optim.optimizer import Optimizer

from tha4.shion.core.loss import Loss
from tha4.shion.core.optimizer_factory import OptimizerFactory


class TrainingProtocol(ABC):
    @abstractmethod
    def get_optimizer_factories(self) -> Dict[str, OptimizerFactory]:
        pass

    @abstractmethod
    def get_checkpoint_examples(self) -> List[int]:
        pass

    @abstractmethod
    def get_random_seed(self) -> int:
        pass

    @abstractmethod
    def get_batch_size(self) -> int:
        pass

    @abstractmethod
    def get_learning_rate(self, examples_seen_so_far: int) -> Dict[str, float]:
        pass

    @abstractmethod
    def run_training_iteration(
            self,
            batch: Any,
            examples_seen_so_far: int,
            modules: Dict[str, Module],
            accumulated_modules: Dict[str, Module],
            optimizers: Dict[str, Optimizer],
            losses: Dict[str, Loss],
            create_log_func: Optional[Callable[[str, int], Callable[[str, float], None]]],
            device: torch.device):
        pass


class AbstractTrainingProtocol(TrainingProtocol, ABC):
    def __init__(self,
                 check_point_examples: List[int],
                 batch_size: int,
                 learning_rate: Callable[[int], Dict[str, float]],
                 optimizer_factories: Dict[str, OptimizerFactory],
                 random_seed: int):
        self.random_seed = random_seed
        self.optimizer_factories = optimizer_factories
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.check_point_examples = check_point_examples

    def get_optimizer_factories(self) -> Dict[str, OptimizerFactory]:
        return self.optimizer_factories

    def get_checkpoint_examples(self) -> List[int]:
        return self.check_point_examples

    def get_random_seed(self) -> int:
        return self.random_seed

    def get_batch_size(self) -> int:
        return self.batch_size

    def get_learning_rate(self, examples_seen_so_far: int) -> Dict[str, float]:
        return self.learning_rate(examples_seen_so_far)
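To make the contract concrete, here is a minimal sketch of an AbstractTrainingProtocol subclass, assuming a single module registered under "main", a batch laid out as [input, target], and a plain L1 objective; all of those names and choices are illustrative, not taken from this repository.

from typing import Any, Callable, Dict, Optional

import torch
import torch.nn.functional as F
from torch.nn import Module
from torch.optim.optimizer import Optimizer

from tha4.shion.core.loss import Loss
from tha4.shion.core.training.training_protocol import AbstractTrainingProtocol


class SimpleSupervisedProtocol(AbstractTrainingProtocol):
    # Only run_training_iteration must be supplied; the base class provides
    # the bookkeeping getters.
    def run_training_iteration(
            self,
            batch: Any,
            examples_seen_so_far: int,
            modules: Dict[str, Module],
            accumulated_modules: Dict[str, Module],
            optimizers: Dict[str, Optimizer],
            losses: Dict[str, Loss],
            create_log_func: Optional[Callable[[str, int], Callable[[str, float], None]]],
            device: torch.device):
        image, target = batch[0], batch[1]  # assumed batch layout
        optimizer = optimizers["main"]
        optimizer.zero_grad()
        loss = F.l1_loss(modules["main"](image), target)
        loss.backward()
        optimizer.step()
        if create_log_func is not None:
            create_log_func("training", examples_seen_so_far)("l1_loss", loss.item())


# Construction reuses the base-class plumbing; adam_factory stands in for a
# hypothetical OptimizerFactory instance:
# protocol = SimpleSupervisedProtocol(
#     check_point_examples=[100_000, 1_000_000],
#     batch_size=8,
#     learning_rate=lambda examples_seen: {"main": 1e-4},
#     optimizer_factories={"main": adam_factory},
#     random_seed=0)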
================================================
FILE: src/tha4/shion/core/training/util.py
================================================
from typing import Callable

import torch
from torch.nn import Module
from torch.optim import Optimizer


def optimizer_to_device(optim: Optimizer, device: torch.device):
    for state in optim.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)


def zero_module(module: Module):
    parameters = dict(module.named_parameters())
    for k in parameters.keys():
        parameters[k].data.zero_()


def get_least_greater_multiple(x: int, m: int) -> int:
    """
    :param x: a non-negative integer
    :param m: a positive integer
    :return: the next multiple of m that is greater than x
    """
    assert x >= 0
    assert m > 0
    return (x // m + 1) * m


def create_log_func(summary_writer, prefix: str, examples_seen_so_far: int) -> Callable[[str, float], None]:
    def log_func(tag: str, value: float):
        summary_writer.add_scalar(prefix + "_" + tag, value, examples_seen_so_far)

    return log_func


def set_learning_rate(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


================================================
FILE: src/tha4/shion/core/training/validation_protocol.py
================================================
from abc import ABC, abstractmethod
from typing import Dict, Callable, Any

import torch
from torch.nn import Module

from tha4.shion.core.loss import Loss


class ValidationProtocol(ABC):
    @abstractmethod
    def get_batch_size(self) -> int:
        pass

    @abstractmethod
    def get_examples_per_validation_iteration(self) -> int:
        pass

    @abstractmethod
    def run_validation_iteration(
            self,
            batch: Any,
            examples_seen_so_far: int,
            modules: Dict[str, Module],
            accumulated_modules: Dict[str, Module],
            losses: Dict[str, Loss],
            create_log_func: Callable[[str, int], Callable[[str, float], None]],
            device: torch.device):
        pass


class AbstractValidationProtocol(ValidationProtocol, ABC):
    def __init__(self, example_per_validation_iteration: int, batch_size: int):
        self.batch_size = batch_size
        self.example_per_validation_iteration = example_per_validation_iteration

    def get_batch_size(self) -> int:
        return self.batch_size

    def get_examples_per_validation_iteration(self) -> int:
        return self.example_per_validation_iteration


================================================
FILE: src/tha4/shion/nn00/__init__.py
================================================


================================================
FILE: src/tha4/shion/nn00/block_args.py
================================================
from typing import Optional

from torch.nn import Module, Sequential

from tha4.shion.core.module_factory import ModuleFactory
from tha4.shion.nn00.linear_module_args import LinearModuleArgs
from tha4.shion.nn00.nonlinearity_factories import resolve_nonlinearity_factory
from tha4.shion.nn00.normalization_layer_factories import resolve_normalization_layer_factory
from tha4.shion.nn00.normalization_layer_factory import NormalizationLayerFactory


class BlockArgs:
    def __init__(
            self,
            linear_module_args: Optional[LinearModuleArgs] = None,
            normalization_layer_factory: Optional[NormalizationLayerFactory] = None,
            nonlinearity_factory: Optional[ModuleFactory] = None):
        if linear_module_args is None:
            linear_module_args = LinearModuleArgs()
        self.nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)
        self.normalization_layer_factory = resolve_normalization_layer_factory(normalization_layer_factory)
        self.linear_module_args = linear_module_args
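A small grounding example for the defaults: with no arguments, BlockArgs resolves to plain module wrapping (no initialization override, no spectral norm), instance normalization, and out-of-place ReLU, per the resolve_* helpers it calls.

from tha4.shion.nn00.block_args import BlockArgs

args = BlockArgs()
relu = args.nonlinearity_factory.create()            # ReLU(inplace=False)
norm = args.normalization_layer_factory.create(64)   # InstanceNorm2d(64, affine=True)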
================================================
FILE: src/tha4/shion/nn00/conv.py
================================================
from typing import Optional, Union, Callable

from torch.nn import Conv2d, Module, Sequential, ConvTranspose2d

from tha4.shion.nn00.block_args import BlockArgs
from tha4.shion.nn00.linear_module_args import LinearModuleArgs, wrap_linear_module


def create_conv7(
        in_channels: int,
        out_channels: int,
        bias: bool = False,
        linear_module_args: Optional[LinearModuleArgs] = None) -> Module:
    return wrap_linear_module(
        Conv2d(in_channels, out_channels, kernel_size=7, stride=1, padding=3, bias=bias),
        linear_module_args)


def create_conv3(
        in_channels: int,
        out_channels: int,
        bias: bool = False,
        linear_module_args: Optional[LinearModuleArgs] = None) -> Module:
    return wrap_linear_module(
        Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=bias),
        linear_module_args)


def create_conv1(
        in_channels: int,
        out_channels: int,
        bias: bool = False,
        linear_module_args: Optional[LinearModuleArgs] = None) -> Module:
    return wrap_linear_module(
        Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
        linear_module_args)


def create_conv7_block(
        in_channels: int,
        out_channels: int,
        block_args: Optional[BlockArgs] = None) -> Module:
    if block_args is None:
        block_args = BlockArgs()
    return Sequential(
        create_conv7(
            in_channels, out_channels, bias=False,
            linear_module_args=block_args.linear_module_args),
        block_args.normalization_layer_factory.create(out_channels, affine=True),
        block_args.nonlinearity_factory.create())


def create_conv3_block(
        in_channels: int,
        out_channels: int,
        block_args: Optional[BlockArgs] = None) -> Module:
    if block_args is None:
        block_args = BlockArgs()
    return Sequential(
        create_conv3(
            in_channels, out_channels, bias=False,
            linear_module_args=block_args.linear_module_args),
        block_args.normalization_layer_factory.create(out_channels, affine=True),
        block_args.nonlinearity_factory.create())


def create_downsample_block(
        in_channels: int,
        out_channels: int,
        is_output_1x1: bool = False,
        block_args: Optional[BlockArgs] = None) -> Module:
    if block_args is None:
        block_args = BlockArgs()
    if is_output_1x1:
        return Sequential(
            wrap_linear_module(
                Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
                block_args.linear_module_args),
            block_args.nonlinearity_factory.create())
    else:
        return Sequential(
            wrap_linear_module(
                Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
                block_args.linear_module_args),
            block_args.normalization_layer_factory.create(out_channels, affine=True),
            block_args.nonlinearity_factory.create())


def create_upsample_block(
        in_channels: int,
        out_channels: int,
        block_args: Optional[BlockArgs] = None) -> Module:
    if block_args is None:
        block_args = BlockArgs()
    return Sequential(
        wrap_linear_module(
            ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            linear_module_args=block_args.linear_module_args),
        block_args.normalization_layer_factory.create(out_channels, affine=True),
        block_args.nonlinearity_factory.create())
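A quick shape check for the helpers above (illustrative, using default BlockArgs): the conv blocks preserve spatial size, the stride-2 downsample block halves it, and the transposed-conv upsample block doubles it.

import torch

from tha4.shion.nn00.conv import create_conv3_block, create_downsample_block, create_upsample_block

x = torch.randn(1, 16, 64, 64)
print(create_conv3_block(16, 32)(x).shape)       # torch.Size([1, 32, 64, 64])
print(create_downsample_block(16, 32)(x).shape)  # torch.Size([1, 32, 32, 32])
print(create_upsample_block(16, 32)(x).shape)    # torch.Size([1, 32, 128, 128])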
================================================
FILE: src/tha4/shion/nn00/initialization_funcs.py
================================================
from typing import Callable, Optional

import torch
from torch import zero_
from torch.nn import Module
from torch.nn.init import kaiming_normal_, xavier_normal_, normal_


class HeInitialization:
    def __init__(self, a: int = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'):
        self.nonlinearity = nonlinearity
        self.mode = mode
        self.a = a

    def __call__(self, module: Module) -> Module:
        with torch.no_grad():
            kaiming_normal_(module.weight, a=self.a, mode=self.mode, nonlinearity=self.nonlinearity)
        return module


class NormalInitialization:
    def __init__(self, mean: float = 0.0, std: float = 1.0):
        self.std = std
        self.mean = mean

    def __call__(self, module: Module) -> Module:
        with torch.no_grad():
            normal_(module.weight, self.mean, self.std)
        return module


class XavierInitialization:
    def __init__(self, gain: float = 1.0):
        self.gain = gain

    def __call__(self, module: Module) -> Module:
        with torch.no_grad():
            xavier_normal_(module.weight, self.gain)
        return module


class ZeroInitialization:
    def __call__(self, module: Module) -> Module:
        with torch.no_grad():
            zero_(module.weight)
        return module


class NoInitialization:
    def __call__(self, module: Module) -> Module:
        return module


def resolve_initialization_func(initialization: Optional[Callable[[Module], Module]]):
    if initialization is None:
        return NoInitialization()
    else:
        return initialization


================================================
FILE: src/tha4/shion/nn00/linear_module_args.py
================================================
from typing import Optional, Callable

from torch.nn import Module
from torch.nn.utils import spectral_norm

from tha4.shion.nn00.initialization_funcs import resolve_initialization_func


class LinearModuleArgs:
    def __init__(
            self,
            initialization_func: Optional[Callable[[Module], Module]] = None,
            use_spectral_norm: bool = False):
        self.use_spectral_norm = use_spectral_norm
        self.initialization_func = resolve_initialization_func(initialization_func)

    def wrap_linear_module(self, module: Module) -> Module:
        module = self.initialization_func(module)
        if self.use_spectral_norm:
            module = spectral_norm(module)
        return module


def wrap_linear_module(module: Module, linear_module_args: Optional[LinearModuleArgs] = None):
    if linear_module_args is None:
        linear_module_args = LinearModuleArgs()
    module = linear_module_args.initialization_func(module)
    if linear_module_args.use_spectral_norm:
        module = spectral_norm(module)
    return module


================================================
FILE: src/tha4/shion/nn00/nonlinearity_factories.py
================================================
from typing import Optional

import torch
from torch import Tensor
from torch.nn import Module, ReLU, LeakyReLU, ELU, ReLU6, Hardswish, SiLU, Tanh, Sigmoid

from tha4.shion.core.module_factory import ModuleFactory


class ReLUFactory(ModuleFactory):
    def __init__(self, inplace: bool = False):
        self.inplace = inplace

    def create(self) -> Module:
        return ReLU(self.inplace)


class LeakyReLUFactory(ModuleFactory):
    def __init__(self, inplace: bool = False, negative_slope: float = 1e-2):
        self.negative_slope = negative_slope
        self.inplace = inplace

    def create(self) -> Module:
        return LeakyReLU(inplace=self.inplace, negative_slope=self.negative_slope)


class ELUFactory(ModuleFactory):
    def __init__(self, inplace: bool = False, alpha: float = 1.0):
        self.alpha = alpha
        self.inplace = inplace

    def create(self) -> Module:
        return ELU(inplace=self.inplace, alpha=self.alpha)


class ReLU6Factory(ModuleFactory):
    def __init__(self, inplace: bool = False):
        self.inplace = inplace

    def create(self) -> Module:
        return ReLU6(inplace=self.inplace)


class SiLUFactory(ModuleFactory):
    def __init__(self, inplace: bool = False):
        self.inplace = inplace

    def create(self) -> Module:
        return SiLU(inplace=self.inplace)


class HardswishFactory(ModuleFactory):
    def __init__(self, inplace: bool = False):
        self.inplace = inplace

    def create(self) -> Module:
        return Hardswish(inplace=self.inplace)


class TanhFactory(ModuleFactory):
    def create(self) -> Module:
        return Tanh()


class SigmoidFactory(ModuleFactory):
    def create(self) -> Module:
        return Sigmoid()


class Swish(Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: Tensor):
        return x * torch.sigmoid(x)


class SwishFactory(ModuleFactory):
    def create(self) -> Module:
        return Swish()


def resolve_nonlinearity_factory(nonlinearity_factory: Optional[ModuleFactory]) -> ModuleFactory:
    if nonlinearity_factory is None:
        return ReLUFactory(inplace=False)
    else:
        return nonlinearity_factory
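Putting the files above together, a minimal sketch (not from the repository) of wrapping a convolution with He initialization plus spectral normalization:

import torch
from torch.nn import Conv2d

from tha4.shion.nn00.initialization_funcs import HeInitialization
from tha4.shion.nn00.linear_module_args import LinearModuleArgs, wrap_linear_module

# He-initialize the weight, then apply spectral normalization.
args = LinearModuleArgs(initialization_func=HeInitialization(), use_spectral_norm=True)
conv = wrap_linear_module(Conv2d(3, 16, kernel_size=3, padding=1), args)
print(conv(torch.randn(1, 3, 8, 8)).shape)  # torch.Size([1, 16, 8, 8])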
================================================
FILE: src/tha4/shion/nn00/normalization_layer_factories.py
================================================
from typing import Optional

import torch
from torch.nn import Module, Parameter, BatchNorm2d, InstanceNorm2d, GroupNorm
from torch.nn.functional import layer_norm
from torch.nn.init import normal_, constant_

from tha4.shion.nn00.normalization_layer_factory import NormalizationLayerFactory
from tha4.shion.nn00.pass_through import PassThrough


class Bias2d(Module):
    def __init__(self, num_features: int):
        super().__init__()
        self.num_features = num_features
        self.bias = Parameter(torch.zeros(1, num_features, 1, 1))

    def forward(self, x):
        return x + self.bias


class NoNorm2dFactory(NormalizationLayerFactory):
    def __init__(self):
        super().__init__()

    def create(self, num_features: int, affine: bool = True) -> Module:
        if affine:
            return Bias2d(num_features)
        else:
            return PassThrough()


class BatchNorm2dFactory(NormalizationLayerFactory):
    def __init__(self,
                 weight_mean: Optional[float] = None,
                 weight_std: Optional[float] = None,
                 bias: Optional[float] = None):
        super().__init__()
        self.bias = bias
        self.weight_std = weight_std
        self.weight_mean = weight_mean

    def get_weight_mean(self):
        if self.weight_mean is None:
            return 1.0
        else:
            return self.weight_mean

    def get_weight_std(self):
        if self.weight_std is None:
            return 0.02
        else:
            return self.weight_std

    def create(self, num_features: int, affine: bool = True) -> Module:
        module = BatchNorm2d(num_features=num_features, affine=affine)
        if affine:
            if self.weight_mean is not None or self.weight_std is not None:
                normal_(module.weight, self.get_weight_mean(), self.get_weight_std())
            if self.bias is not None:
                constant_(module.bias, self.bias)
        return module


class InstanceNorm2dFactory(NormalizationLayerFactory):
    def __init__(self):
        super().__init__()

    def create(self, num_features: int, affine: bool = True) -> Module:
        return InstanceNorm2d(num_features=num_features, affine=affine)


class LayerNorm2d(Module):
    def __init__(self, channels: int, affine: bool = True):
        super(LayerNorm2d, self).__init__()
        self.channels = channels
        self.affine = affine
        if self.affine:
            self.weight = Parameter(torch.ones(1, channels, 1, 1))
            self.bias = Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x):
        shape = x.size()[1:]
        y = layer_norm(x, shape)
        # The learnable scale and shift exist only when affine=True.
        if self.affine:
            y = y * self.weight + self.bias
        return y


class LayerNorm2dFactory(NormalizationLayerFactory):
    def __init__(self):
        super().__init__()

    def create(self, num_features: int, affine: bool = True) -> Module:
        return LayerNorm2d(channels=num_features, affine=affine)


class GroupNormFactory(NormalizationLayerFactory):
    def __init__(self, num_groups: int, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.num_groups = num_groups

    def create(self, num_features: int, affine: bool = True) -> Module:
        return GroupNorm(num_channels=num_features, num_groups=self.num_groups, eps=self.eps, affine=affine)


def resolve_normalization_layer_factory(factory: Optional['NormalizationLayerFactory']) -> 'NormalizationLayerFactory':
    if factory is None:
        return InstanceNorm2dFactory()
    else:
        return factory


================================================
FILE: src/tha4/shion/nn00/normalization_layer_factory.py
================================================
from abc import ABC, abstractmethod

from torch.nn import Module


class NormalizationLayerFactory(ABC):
    def __init__(self):
        super().__init__()

    @abstractmethod
    def create(self, num_features: int, affine: bool = True) -> Module:
        pass


================================================
FILE: src/tha4/shion/nn00/pass_through.py
================================================
from torch.nn import Module


class PassThrough(Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x
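All of the factories above build shape-preserving layers from a channel count; NoNorm2dFactory degrades to a learnable per-channel bias (or a PassThrough when affine is off). A small illustrative check:

import torch

from tha4.shion.nn00.normalization_layer_factories import (
    BatchNorm2dFactory, GroupNormFactory, NoNorm2dFactory)

x = torch.randn(2, 32, 16, 16)
for factory in [BatchNorm2dFactory(), GroupNormFactory(num_groups=8), NoNorm2dFactory()]:
    layer = factory.create(num_features=32, affine=True)
    assert layer(x).shape == x.shape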
================================================
FILE: src/tha4/shion/nn00/resnet_block.py
================================================
from typing import Optional

import torch
from torch.nn import Module, Sequential, Parameter

from tha4.shion.nn00.block_args import BlockArgs
from tha4.shion.nn00.conv import create_conv1, create_conv3


class ResnetBlock(Module):
    def __init__(self,
                 num_channels: int,
                 is1x1: bool = False,
                 use_scale_parameter: bool = False,
                 block_args: Optional[BlockArgs] = None):
        super().__init__()
        if block_args is None:
            block_args = BlockArgs()
        self.use_scale_parameter = use_scale_parameter
        if self.use_scale_parameter:
            self.scale = Parameter(torch.zeros(1))
        if is1x1:
            self.resnet_path = Sequential(
                create_conv1(
                    num_channels, num_channels, bias=True,
                    linear_module_args=block_args.linear_module_args),
                block_args.nonlinearity_factory.create(),
                create_conv1(
                    num_channels, num_channels, bias=True,
                    linear_module_args=block_args.linear_module_args))
        else:
            self.resnet_path = Sequential(
                create_conv3(
                    num_channels, num_channels, bias=False,
                    linear_module_args=block_args.linear_module_args),
                block_args.normalization_layer_factory.create(num_channels, affine=True),
                block_args.nonlinearity_factory.create(),
                create_conv3(
                    num_channels, num_channels, bias=False,
                    linear_module_args=block_args.linear_module_args),
                block_args.normalization_layer_factory.create(num_channels, affine=True))

    def forward(self, x):
        if self.use_scale_parameter:
            return x + self.scale * self.resnet_path(x)
        else:
            return x + self.resnet_path(x)
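A closing sketch for ResnetBlock: the block preserves its input shape, and with use_scale_parameter=True the residual path is multiplied by a parameter initialized to zero, so the block starts out as the identity.

import torch

from tha4.shion.nn00.resnet_block import ResnetBlock

x = torch.randn(1, 64, 32, 32)
block = ResnetBlock(num_channels=64, use_scale_parameter=True)
assert block(x).shape == x.shape
assert torch.allclose(block(x), x)  # the zero-initialized scale gates the residual path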