Repository: pkhungurn/talking-head-anime-4-demo
Branch: main
Commit: 320640116abd
Files: 185
Total size: 655.5 KB
Directory structure:
gitextract_k8uli292/
├── .gitignore
├── .python-version
├── LICENSE
├── README.md
├── bin/
│ ├── activate-venv.bat
│ ├── activate-venv.sh
│ ├── run
│ └── run.bat
├── distiller-ui-doc/
│ ├── index.html
│ └── params/
│ ├── body_morpher_batch_size.html
│ ├── body_morpher_random_seed_0.html
│ ├── body_morpher_random_seed_1.html
│ ├── character_image_file_name.html
│ ├── face_mask_image_file_name.html
│ ├── face_morpher_batch_size.html
│ ├── face_morpher_random_seed_0.html
│ ├── face_morpher_random_seed_1.html
│ ├── num_cpu_workers.html
│ ├── num_gpus.html
│ ├── num_training_examples_per_sample_output.html
│ └── prefix.html
├── docs/
│ ├── character_model_ifacialmocap_puppeteer.md
│ ├── character_model_manual_poser.md
│ ├── character_model_mediapipe_puppeteer.md
│ ├── distill.md
│ ├── distiller_ui.md
│ └── full_manual_poser.md
├── poetry/
│ ├── README.md
│ └── pyproject.toml
└── src/
└── tha4/
├── __init__.py
├── app/
│ ├── __init__.py
│ ├── character_model_ifacialmocap_puppeteer.py
│ ├── character_model_manual_poser.py
│ ├── character_model_mediapipe_puppeteer.py
│ ├── distill.py
│ ├── distiller_ui.py
│ └── full_manual_poser.py
├── charmodel/
│ ├── __init__.py
│ └── character_model.py
├── dataset/
│ ├── __init__.py
│ └── image_poses_and_aother_images_dataset.py
├── distiller/
│ ├── __init__.py
│ ├── config_based_training_tasks.py
│ ├── distill_body_morpher.py
│ ├── distill_face_morpher.py
│ ├── distiller_config.py
│ └── ui/
│ ├── __init__.py
│ ├── distiller_config_state.py
│ └── distiller_ui_main_frame.py
├── image_util.py
├── mocap/
│ ├── __init__.py
│ ├── ifacialmocap_constants.py
│ ├── ifacialmocap_pose.py
│ ├── ifacialmocap_pose_converter.py
│ ├── ifacialmocap_pose_converter_25.py
│ ├── ifacialmocap_v2.py
│ ├── mediapipe_constants.py
│ ├── mediapipe_face_pose.py
│ ├── mediapipe_face_pose_converter.py
│ └── mediapipe_face_pose_converter_00.py
├── nn/
│ ├── __init__.py
│ ├── common/
│ │ ├── __init__.py
│ │ ├── conv_block_factory.py
│ │ ├── poser_args.py
│ │ ├── poser_encoder_decoder_00.py
│ │ ├── poser_encoder_decoder_00_separable.py
│ │ ├── resize_conv_encoder_decoder.py
│ │ ├── resize_conv_unet.py
│ │ └── unet.py
│ ├── conv.py
│ ├── eyebrow_decomposer/
│ │ ├── __init__.py
│ │ └── eyebrow_decomposer_00.py
│ ├── eyebrow_morphing_combiner/
│ │ ├── __init__.py
│ │ └── eyebrow_morphing_combiner_00.py
│ ├── face_morpher/
│ │ ├── __init__.py
│ │ └── face_morpher_08.py
│ ├── image_processing_util.py
│ ├── init_function.py
│ ├── morpher/
│ │ ├── __init__.py
│ │ └── morpher_00.py
│ ├── nonlinearity_factory.py
│ ├── normalization.py
│ ├── pass_through.py
│ ├── resnet_block.py
│ ├── resnet_block_seperable.py
│ ├── separable_conv.py
│ ├── siren/
│ │ ├── __init__.py
│ │ ├── face_morpher/
│ │ │ ├── __init__.py
│ │ │ ├── siren_face_morpher_00.py
│ │ │ ├── siren_face_morpher_00_trainer.py
│ │ │ └── siren_face_morpher_protocols_00.py
│ │ ├── morpher/
│ │ │ ├── __init__.py
│ │ │ ├── siren_morpher_03.py
│ │ │ ├── siren_morpher_03_trainer.py
│ │ │ └── siren_morpher_protocols_03.py
│ │ └── vanilla/
│ │ ├── __init__.py
│ │ └── siren.py
│ ├── spectral_norm.py
│ ├── upscaler/
│ │ ├── __init__.py
│ │ └── upscaler_02.py
│ └── util.py
├── poser/
│ ├── __init__.py
│ ├── general_poser_02.py
│ ├── modes/
│ │ ├── __init__.py
│ │ ├── mode_07.py
│ │ ├── mode_12.py
│ │ ├── mode_14.py
│ │ └── pose_parameters.py
│ └── poser.py
├── pytasuku/
│ ├── __init__.py
│ ├── indexed/
│ │ ├── __init__.py
│ │ ├── all_tasks.py
│ │ ├── bundled_indexed_file_tasks.py
│ │ ├── indexed_file_tasks.py
│ │ ├── indexed_tasks.py
│ │ ├── no_index_command_tasks.py
│ │ ├── no_index_file_tasks.py
│ │ ├── one_index_file_tasks.py
│ │ ├── simple_no_index_file_tasks.py
│ │ ├── two_indices_file_tasks.py
│ │ └── util.py
│ ├── task.py
│ ├── task_selector_ui.py
│ ├── util.py
│ └── workspace.py
├── sampleoutput/
│ ├── __init__.py
│ ├── general_sample_output_protocol.py
│ ├── poser_sampler_output_protocol.py
│ └── sample_image_creator.py
└── shion/
├── __init__.py
├── base/
│ ├── __init__.py
│ ├── dataset/
│ │ ├── __init__.py
│ │ ├── lazy_dataset.py
│ │ ├── lazy_tensor_dataset.py
│ │ ├── png_in_dir_dataset.py
│ │ ├── util.py
│ │ └── xformed_dataset.py
│ ├── image_util.py
│ ├── loss/
│ │ ├── __init__.py
│ │ ├── computed_scale_loss.py
│ │ ├── computed_scaled_l2_loss.py
│ │ ├── l1_loss.py
│ │ ├── l2_loss.py
│ │ ├── sum_loss.py
│ │ └── time_dependently_weighted_loss.py
│ ├── module_accumulators.py
│ ├── optimizer_factories.py
│ ├── protocol/
│ │ └── single_network_from_batch_input_computation_protocol.py
│ └── training/
│ ├── __init__.py
│ ├── single_network.py
│ ├── single_network_with_minibatch.py
│ └── two_networks_training_protocol.py
├── core/
│ ├── __init__.py
│ ├── cached_computation.py
│ ├── load_save.py
│ ├── loss.py
│ ├── module_accumulator.py
│ ├── module_factory.py
│ ├── optimizer_factory.py
│ └── training/
│ ├── __init__.py
│ ├── distrib/
│ │ ├── __init__.py
│ │ ├── device_mapper.py
│ │ ├── distributed_trainer.py
│ │ ├── distributed_training_states.py
│ │ └── distributed_training_tasks.py
│ ├── sample_output_protocol.py
│ ├── single/
│ │ ├── __init__.py
│ │ ├── training_states.py
│ │ └── training_tasks.py
│ ├── swarm/
│ │ ├── __init__.py
│ │ ├── swarm_training_tasks.py
│ │ └── swarm_unit_trainer.py
│ ├── training_protocol.py
│ ├── util.py
│ └── validation_protocol.py
└── nn00/
├── __init__.py
├── block_args.py
├── conv.py
├── initialization_funcs.py
├── linear_module_args.py
├── nonlinearity_factories.py
├── normalization_layer_factories.py
├── normalization_layer_factory.py
├── pass_through.py
└── resnet_block.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Compiled class file
*.class
# Log file
*.log
# BlueJ files
*.ctxt
# Mobile Tools for Java (J2ME)
.mtj.tmp/
# Package Files #
*.jar
*.war
*.nar
*.ear
*.zip
*.tar.gz
*.rar
# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
hs_err_pid*
.gradle
.vscode
out/
.idea/
.gradle/
build/
/data/
.idea/*
*.iml
*/.idea/*
*/build
*/.gradle/*
*/out/*
*.pyc
*.pyd
**/.cache/*
*/bin
__pycache__/
./tools/
temp/
*/*/bin/
venv/*
# io dump to ABCI tasks
*.o*
================================================
FILE: .python-version
================================================
3.10.11
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2024 pixiv Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# Demo Code for "Talking Head(?) Anime from a Single Image 4: Improved Model and Its Distillation"
This repository contains demo programs for the "Talking Head(?) Anime from a Single Image 4: Improved Model and Its Distillation" project. Roughly, the project is about a machine learning model that can animate an anime character given only one image. However, the model is too slow to run in real time. So, the project also proposes an algorithm that uses the model to train a small machine learning model, specialized to a particular character image, that can animate the character in real time.
This demo code has two parts.
* **Improved model.** This part gives a model similar to [Version 3](https://github.com/pkhungurn/talking-head-anime-3-demo) of the project. It has one demo program:
* The `full_manual_poser` allows the user to manipulate a character's facial expression and body rotation through a graphical user interface.
There are no real-time demos because the new model is too slow for that.
* **Distillation.** This part allows the user to train small models (which we will refer to as **student models**) to mimic the behavior of the full system with regard to a specific character image. It also allows the user to run these models under various interfaces. The demo programs are:
* `distill` trains a student model given a configuration file, a $512 \times 512$ RGBA character image, and a mask of facial organs.
* `distiller_ui` provides a user-friendly interface to `distill`, allowing you to create training configurations and providing useful documentation.
* `character_model_manual_poser` allows the user to control trained student models with a graphical user interface.
* `character_model_ifacialmocap_puppeteer` allows the user to control trained student models with their facial movement, which is captured by the [iFacialMocap](https://www.ifacialmocap.com/) software. To run this software, you must have an iOS device and, of course, iFacialMocap.
* `character_model_mediapipe_puppeteer` allows the user to control trained student models with their facial movement, which is captured by a web camera and processed by the [Mediapipe FaceLandmarker](https://developers.google.com/mediapipe/solutions/vision/face_landmarker) model. To run this software, you need a web camera.
## Preemptive FAQs
### What is the program to control character images with my facial movement?
There is no such program in this release. If you want one, try the `ifacialmocap_puppeteer` of [Version 3](https://github.com/pkhungurn/talking-head-anime-3-demo).
### OK. I'm confused. Isn't your work about easy VTubing? Are you saying this release cannot do it?
NO. This release does it in a more complicated way. In order to control an image, you need to create a "student model." It is a small (< 2MB) and fast machine learning model that knows how to animate that particular image. Then, the student model can be controlled with facial movement. You can find two student models in the `data/character_models` directory. The [two](https://pkhungurn.github.io/talking-head-anime-4/supplementary/webcam-demo/index.html) [demos](https://pkhungurn.github.io/talking-head-anime-4/supplementary/manual-poser-demo/index.html) on the project website feature 13 student models.
### So, for this release, you can control only these few characters in real time?
No. You can create your own student models.
### How do I create this student model then?
1. You prepare your character image according to the "Constraints on Input Images" section below.
2. You prepare a black-and-white mask image that covers the eyes and the mouth of the character, like [this image](data/images/lambda_00_face_mask.png). You can see how I made it with [GIMP](https://www.gimp.org/) by inspecting this [GIMP file](data/images/lambda_00_face_mask.xcf). (A sketch for sanity-checking such a mask appears after this list.)
3. You use `distiller_ui` to create a configuration file that specifies how the student model should be trained.
4. You use `distiller_ui` or `distill` to start the training process.
5. You wait several tens of hours for the student model to finish training. Last time I tried, it took about 30 hours on a computer with an Nvidia RTX A6000 GPU.
6. After that, you can control the student model with `character_model_ifacialmocap_puppeteer` and `character_model_mediapipe_puppeteer`.
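If you want to sanity-check the mask from step 2 before committing to a long training run, here is a minimal sketch using Pillow and NumPy (both project dependencies). The specification it enforces is taken from the distiller UI documentation; treating any pixel that is not pure black or pure white as an error is my own assumption:
```
import numpy as np
import PIL.Image

def check_face_mask(file_name: str) -> None:
    # The distiller UI documentation asks for a 512 x 512 RGB PNG whose
    # pixels are all pure black (0, 0, 0) or pure white (255, 255, 255).
    image = PIL.Image.open(file_name)
    assert image.size == (512, 512), "mask must be 512 x 512"
    assert image.mode == "RGB", "mask must be RGB with no alpha channel"
    pixels = np.asarray(image)
    is_black = (pixels == 0).all(axis=-1)
    is_white = (pixels == 255).all(axis=-1)
    assert (is_black | is_white).all(), "every pixel must be pure black or white"

check_face_mask("data/images/lambda_00_face_mask.png")
```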
### Why is this release so hard to use?
[Version 3](https://github.com/pkhungurn/talking-head-anime-3-demo) is arguably easier to use because you can give it an image and control it with your facial movement immediately. However, I was not satisfied with its image quality and speed.
In this release, I explore a new way of doing things. I added a new preprocessing stage (i.e., training the student models) that has to be done one time per character image. It allows the image to be animated much faster at a higher image quality level.
In other words, it makes the user's life difficult but the engineer/researcher happy. Patient users who are willing to go through the steps, though, will be rewarded with faster animation.
### Can I use a student model from a web browser?
No. A student model created by `distill` is a [PyTorch](https://pytorch.org/) model, which cannot run directly in the browser. It needs to be converted to the appropriate format ([TensorFlow.js](https://www.tensorflow.org/js)) first, and the [web](https://pkhungurn.github.io/talking-head-anime-4/supplementary/webcam-demo/index.html) [demos](https://pkhungurn.github.io/talking-head-anime-4/supplementary/manual-poser-demo/index.html) use the converted models. However, the conversion code is not included in this repository. I will not release it unless I change my mind.
## Hardware Requirements
All programs require a recent and powerful Nvidia GPU to run. I developed the programs on a machine with an Nvidia RTX A6000. However, anything after the GeForce RTX 2080 should be fine.
The `character_model_ifacialmocap_puppeteer` program requires an iOS device that is capable of computing [blend shape parameters](https://developer.apple.com/documentation/arkit/arfaceanchor/2928251-blendshapes) from a video feed. This means that the device must be able to run iOS 11.0 or higher and must have a TrueDepth front-facing camera. (See [this page](https://developer.apple.com/documentation/arkit/content_anchors/tracking_and_visualizing_faces) for more info.) In other words, if you have the iPhone X or something better, you should be all set. Personally, I have used an iPhone 12 mini.
The `character_model_mediapipe_puppeteer` program requires a web camera.
## Software Requirements
### GPU Driver and CUDA Toolkit
Please update your GPU's device driver and install the [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit) that is compatible with your GPU and is at least as new as the CUDA version targeted by the PyTorch build you will install in the next subsection (CUDA 11.7).
### Python and Python Libraries
All programs are written in the [Python](https://www.python.org/) programming language. The following libraries are required:
* `python` 3.10.11
* `torch` 1.13.1 with CUDA support
* `torchvision` 0.14.1
* `tensorboard` 2.15.1
* `opencv-python` 4.8.1.78
* `wxpython` 4.2.1
* `numpy-quaternion` 2022.4.2
* `pillow` 9.4.0
* `matplotlib` 3.6.3
* `einops` 0.6.0
* `mediapipe` 0.10.3
* `numpy` 1.26.3
* `scipy` 1.12.0
* `omegaconf` 2.3.0
Instead of installing these libraries yourself, you should follow the recommended method to set up a Python environment in the next section.
### iFacialMocap
If you want to use `character_model_ifacialmocap_puppeteer`, you will also need an iOS app called [iFacialMocap](https://www.ifacialmocap.com/) (a 980 yen purchase in the App Store). Your iOS device and your computer must be on the same network. For example, you may connect them to the same wireless router.
## Creating Python Environment
### Installing Python
Please install [Python 3.10.11](https://www.python.org/downloads/release/python-31011/).
I recommend using [`pyenv`](https://github.com/pyenv/pyenv) (or [`pyenv-win`](https://github.com/pyenv-win/pyenv-win) for Windows users) to manage multiple Python versions on your system. If you use `pyenv`, note that this repository has a `.python-version` file indicating that it uses Python 3.10.11, so `pyenv` will switch to Python 3.10.11 automatically once you `cd` into the repository's directory.
Make sure that you can run Python from the command line.
### Installing Poetry
Please install [Poetry](https://python-poetry.org/) 1.7 or later. We will use it to automatically install the required libraries. Again, make sure that you can run it from the command line.
### Cloning the Repository
Please clone the repository to an arbitrary directory in your machine.
### Instruction for Linux/OSX Users
1. Open a shell.
2. `cd` to the directory to which you just cloned the repository.
```
cd SOMEWHERE/talking-head-anime-4-demo
```
3. Use Python to create a virtual environment under the `venv` directory.
```
python -m venv venv --prompt talking-head-anime-4-demo
```
4. Activate the newly created virtual environment. You can either use the script I provide:
```
source bin/activate-venv.sh
```
or do it yourself:
```
source venv/bin/activate
```
5. Use Poetry to install libraries.
```
cd poetry
poetry install
```
### Instruction for Windows Users
1. Open a shell.
2. `cd` to the directory to which you just cloned the repository.
```
cd SOMEWHERE\talking-head-anime-4-demo
```
3. Use Python to create a virtual environment under the `venv` directory.
```
python -m venv venv --prompt talking-head-anime-4-demo
```
4. Activate the newly created virtual environment. You can either use the script I provide:
```
bin\activate-venv.bat
```
or do it yourself:
```
venv\Scripts\activate
```
5. Use Poetry to install libraries.
```
cd poetry
poetry install
```
## Download the Models/Dataset Files
### THA4 Models
Please download [this ZIP file](https://www.dropbox.com/scl/fi/7wec0sur7449iqgtlpi3n/tha4-models.zip?rlkey=0f9d1djmbvjjjn09469s1adx8&dl=0) hosted on Dropbox, and unzip it to the `data/tha4` directory under the repository's directory. In the end, the directory tree should look like the following diagram:
```
+ talking-head-anime-4-demo
+ data
- character_models
- distill_examples
+ tha4
- body_morpher.pt
- eyebrow_decomposer.pt
- eyebrow_morphing_combiner.pt
- face_morpher.pt
- upscaler.pt
- images
- third_party
```
### Pose Dataset
If you want to create your own student models, you also need to download a dataset of poses that are needed for the training process. Download [this `pose_dataset.pt` file](https://www.dropbox.com/scl/fi/du10e6buzr5bslbe025qu/pose_dataset.pt?rlkey=y052g4n3xb14nu2elctzouc5x&dl=0) and save it to the `data` folder. The directory tree should then look like the following diagram:
```
+ talking-head-anime-4-demo
+ data
- character_models
- distill_examples
- tha4
- images
- third_party
- pose_dataset.pt
```
## Running the Programs
The programs are located in the `src/tha4/app` directory. You need to run them from a shell with the provided scripts.
### Instruction for Linux/OSX Users
1. Open a shell.
2. `cd` to the repository's directory.
```
cd SOMEWHERE/talking-head-anime-4-demo
```
3. Run a program.
```
bin/run src/tha4/app/<program>
```
where `<program>` can be replaced with:
* `character_model_ifacialmocap_puppeteer.py`
* `character_model_manual_poser.py`
* `character_model_mediapipe_puppeteer.py`
* `distill.py`
* `distiller_ui.py`
* `full_manual_poser.py`
### Instruction for Windows Users
1. Open a shell.
2. `cd` to the repository's directory.
```
cd SOMEWHERE\talking-head-anime-4-demo
```
3. Run a program.
```
bin\run.bat src\tha4\app\<program>
```
where `<program>` can be replaced with:
* `character_model_ifacialmocap_puppeteer.py`
* `character_model_manual_poser.py`
* `character_model_mediapipe_puppeteer.py`
* `distill.py`
* `distiller_ui.py`
* `full_manual_poser.py`
## Constraints on Input Images
In order for the system to work well, the input image must obey the following constraints:
* It should be of resolution 512 x 512. (If the demo programs receive an input image of any other size, they will resize it to this resolution and also output at this resolution.)
* It must have an alpha channel.
* It must contain only one humanoid character.
* The character should be standing upright and facing forward.
* The character's hands should be below and far from the head.
* The head of the character should roughly be contained in the 128 x 128 box in the middle of the top half of the image.
* The alpha channels of all pixels that do not belong to the character (i.e., background pixels) must be 0.
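Below is a minimal sketch that checks the machine-verifiable constraints with Pillow and NumPy (both project dependencies). The corner spot-check for background transparency is a heuristic of mine, not something the demo programs perform:
```
import numpy as np
import PIL.Image

def check_character_image(file_name: str) -> None:
    image = PIL.Image.open(file_name)
    assert image.size == (512, 512), "image should be 512 x 512"
    assert image.mode == "RGBA", "image must have an alpha channel"
    alpha = np.asarray(image)[:, :, 3]
    # Heuristic: the four corners are almost surely background pixels,
    # so their alpha values should be exactly 0.
    for y, x in [(0, 0), (0, -1), (-1, 0), (-1, -1)]:
        assert alpha[y, x] == 0, "background pixels must have alpha 0"

check_character_image("data/images/lambda_00.png")
```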

## Documentation for the Tools
* [`character_model_ifacialmocap_puppeteer`](docs/character_model_ifacialmocap_puppeteer.md)
* [`character_model_manual_poser`](docs/character_model_manual_poser.md)
* [`character_model_mediapipe_puppeteer`](docs/character_model_mediapipe_puppeteer.md)
* [`distill`](docs/distill.md)
* [`distiller_ui`](docs/distiller_ui.md)
* [`full_manual_poser`](docs/full_manual_poser.md)
## Disclaimer
The author is an employee of [pixiv Inc.](https://www.pixiv.co.jp/) This project is a part of his work as a researcher.
However, this project is NOT a pixiv product. The company will NOT provide any support for this project. The author will try to support the project, but he does not commit to any service level agreement (SLA).
The code is released under the [MIT license](https://github.com/pkhungurn/talking-head-anime-2-demo/blob/master/LICENSE).
The THA4 models and the images under the `data/images` directory are released under the [Creative Commons Attribution-NonCommercial 4.0 International](https://creativecommons.org/licenses/by-nc/4.0/deed.en) license.
This repository redistributes a version of the [Face landmark detection model](https://developers.google.com/mediapipe/solutions/vision/face_landmarker) from the [MediaPipe](https://developers.google.com/mediapipe) project. The model has been released under the [Apache License, Version 2.0](https://www.apache.org/licenses/LICENSE-2.0.html).
================================================
FILE: bin/activate-venv.bat
================================================
venv\Scripts\activate
================================================
FILE: bin/activate-venv.sh
================================================
#! /bin/bash
source venv/bin/activate
================================================
FILE: bin/run
================================================
#! /bin/bash
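# Put the repository's src directory on PYTHONPATH, then run the given
# script with the virtual environment's Python interpreter.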
export PYTHONPATH=$(pwd)/src
venv/bin/python "$@"
================================================
FILE: bin/run.bat
================================================
set PYTHONPATH=%cd%\src
venv\Scripts\python.exe %*
================================================
FILE: distiller-ui-doc/index.html
================================================
Distiller UI Documentation
How to use Distiller UI
This program is called distiller_ui. It allows you to create and modify configurations for the process of distilling the full but slow THA4 system into a student model that can run in real time on computers with moderately powerful GPUs.
Basic Usage
This program manipulates YAML files that are used as configurations for the distillation process. The menus
- File → New
- File → Open
- File → Save
do what they are supposed to do in typical application programs.
You can use the UI in the middle panel to change various parameters of the configuration. If you do not understand the meaning of a parameter, click the "Help" button for that parameter to learn more.
Once you have modified the parameters to your liking, click the "RUN" button at the bottom of the middle panel to carry out the distillation. This will take several tens of hours, so sit back and relax.
The distillation process can be interrupted and resumed at any time. As a result, you do not have to worry that you may lose data if there's a blackout or if you need to free your GPU(s) to do something else. Resuming can be done through this program or through the distill script.
Explanation of Configuration Parameters
================================================
FILE: distiller-ui-doc/params/body_morpher_batch_size.html
================================================
Distiller UI Documentation: body_morpher_batch_size
body_morpher_batch_size
The "batch size" is the number of training examples shown to a machine learning model in one round of parameter update. This parameter is the batch size for training the student body morpher. We recommend you set it to 8. However, if your computer does not have enough GPU RAM, you can reduce the number to any smaller positive integer.
Back to main documentation
================================================
FILE: distiller-ui-doc/params/body_morpher_random_seed_0.html
================================================
Distiller UI Documentation: body_morpher_random_seed_0
body_morpher_random_seed_0
This parameter will be used as a random seed in the process of training the student body morpher. It can be any non-negative integer from 0 to 2^64 - 1. You can specify the number directly, or use the "Randomize" button to specify a random one.
Back to main documentation
================================================
FILE: distiller-ui-doc/params/body_morpher_random_seed_1.html
================================================
Distiller UI Documentation: body_morpher_random_seed_1
body_morpher_random_seed_1
This parameter will be used as a random seed in the process of training the student body morpher. It can be any non-negative integer from 0 to 2^64 - 1. You can specify the number directly, or use the "Randomize" button to specify a random one.
Back to main documentation
================================================
FILE: distiller-ui-doc/params/character_image_file_name.html
================================================
Distiller UI Documentation: character_image_file_name
character_image_file_name
This is the file name of an image of a humanoid character. The image must conform to the following specifications.
- It MUST be in the PNG format.
- It MUST have an alpha channel.
- It MUST be 512 x 512.
- It MUST contain only one humanoid character.
- The character should be standing upright and facing forward.
- The character's hands should be below and far from the head.
- The head of the character should roughly be contained in the 128 x 128 box in the middle of the top half of the image.
- The alpha channels of all pixels that do not belong to the character (i.e., background pixels) must be 0.
Once you have chosen the image, a crop of the character's face will be shown on the right side of the window. In order for the distillation process to work correctly, make sure that all the movable parts of the face (eyes, eyebrows, mouth, jaw line) can be seen in this crop.
The original documentation shows four example crops:
- GOOD: all of the eyes, eyebrows, mouth, and jaw line can be seen in the image.
- NOT GOOD: we cannot see the whole of the jaw line in the image.
- NOT GOOD: we cannot see the whole of the right eye and eyebrow in the image.
- NOT GOOD: we cannot see the whole of the eyebrows in the image.
The data/images directory contains two example images that conform to all the above specifications: data/images/lambda_00.png and data/images/lambda_01.png. Please use them as references.
Back to main documentation
================================================
FILE: distiller-ui-doc/params/face_mask_image_file_name.html
================================================
Distiller UI Documentation: face_mask_image_file_name
face_mask_image_file_name
This is the name of the file containing binary masks of the movable facial organs of the character. It is probably best to look at an example.
A "face mask image" conforms to the following specification.
- It must be in the PNG format.
- It must be 512 x 512.
- It must be an RGB image (i.e., no alpha channel).
- All pixels must be either black (0,0,0) or white (255,255,255).
- The white pixels should cover movable parts of the face.
We recommend creating three rectangles.
- One covers the right eye and eyebrow.
- One covers the left eye and eyebrow.
- One covers the mouth and the jaw line.
The rectangles for the eyes and the eyebrows should extend above the eyes to some extent because the eyebrows can move upward.
Once you have specified the face mask image with the "Change..." button, a crop of the face area will show up on the left side of the window. If the character image has also been specified, an image of the face mask laid over the character's face will also show up. Use this image to check whether the masks are covering everything.
Back to main documentation
================================================
FILE: distiller-ui-doc/params/face_morpher_batch_size.html
================================================
Distiller UI Documentation: face_morpher_batch_size
face_morpher_batch_size
The "batch size" is the number of training examples shown to a machine learning model in one round of parameter update. This parameter is the batch size for training the student face morpher. We recommend you set it to 8. However, if your computer does not have enough GPU RAM, you can reduce the number to any smaller positive integer.
Back to main documentation
================================================
FILE: distiller-ui-doc/params/face_morpher_random_seed_0.html
================================================
Distiller UI Documentation: face_morpher_random_seed_0
face_morpher_random_seed_0
This parameter will be used as a random seed in the process of training the student face morpher. It can be any non-negative integer from 0 to 2^64 - 1. You can specify the number directly, or use the "Randomize" button to specify a random one.
Back to main documentation
================================================
FILE: distiller-ui-doc/params/face_morpher_random_seed_1.html
================================================
Distiller UI Documentation: face_morpher_random_seed_1
face_morpher_random_seed_1
This parameter will be used as a random seed in the process of training the student face morpher. It can be any non-negative integer from 0 to 2^64 - 1. You can specify the number directly, or use the "Randomize" button to specify a random one.
Back to main documentation
================================================
FILE: distiller-ui-doc/params/num_cpu_workers.html
================================================
Distiller UI Documentation: num_cpu_workers
num_cpu_workers
This is the number of worker threads that are used to process pose data during training of the student models. Typically, 1 would be enough, but you can specify up to the number of CPUs your computer has.
Back to main documentation
================================================
FILE: distiller-ui-doc/params/num_gpus.html
================================================
Distiller UI Documentation: num_gpus
num_gpus
This is the number of GPUs that are used to train the student models. Typically, 1 would be enough. However, you can specify up to the number of Nvidia GPUs that your PC has.
Back to main documentation
================================================
FILE: distiller-ui-doc/params/num_training_examples_per_sample_output.html
================================================
Distiller UI Documentation: num_training_examples_per_sample_output
num_training_examples_per_sample_output
During training of a student model, the training process periodically creates "sample outputs" from the model being trained so that the user can see training progress and spot any anomalies.
This parameter specifies how frequently sample outputs are generated. You can have a sample output generated every time the model being trained has been shown 10,000, 100,000, or 1,000,000 training examples. If you do not care about sample outputs, you can also make the process not generate any at all.
Back to main documentation
================================================
FILE: distiller-ui-doc/params/prefix.html
================================================
Distiller UI Documentation: prefix
prefix
prefix is the name of the directory under which the distillation process will store the trained models and other intermediate data. Please choose a subdirectory of the talking-head-anime-4-demo repository's directory.
Back to main documentation
================================================
FILE: docs/character_model_ifacialmocap_puppeteer.md
================================================
# `character_model_ifacialmocap_puppeteer`
This program allows the user to control trained student models with their facial movement, which is captured by the [iFacialMocap](https://www.ifacialmocap.com/) software. You can purchase the software from the App Store for 980 Japanese Yen.
## Invoking the Program
Make sure you have (1) created a Python environment and (2) downloaded the model files as instructed in the [main README file](../README.md).
### Instruction for Linux/OSX Users
1. Open a shell.
2. `cd` to the repository's directory.
```
cd SOMEWHERE/talking-head-anime-4-demo
```
3. Run the program.
```
bin/run src/tha4/app/character_model_ifacialmocap_puppeteer.py
```
### Instruction for Windows Users
1. Open a shell.
2. `cd` to the repository's directory.
```
cd SOMEWHERE\talking-head-anime-4-demo
```
3. Run the program.
```
bin\run.bat src\tha4\app\character_model_ifacialmocap_puppeteer.py
```
## Usage
1. Run iFacialMocap on your iOS device. It should show you the device's IP address. Jot it down. Keep the app open.

2. Invoke the `character_model_ifacialmocap_puppeteer` application.
3. You will see a text box labeled "Capture Device IP." Enter the iOS device's IP address that you jotted down.

4. Click the "START CAPTURE!" button to the right.

If the programs are connected properly, you should see the numbers in the bottom part of the window change when you move your head.

5. Now, you can load a student model, and the character should follow your facial movement.
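For the curious, the connection is a plain UDP exchange: the application sends iFacialMocap a start string and then listens on the same port for pose packets. The following minimal sketch mirrors what `character_model_ifacialmocap_puppeteer.py` does internally (replace the IP address with your device's):
```
import socket

from tha4.mocap.ifacialmocap_v2 import (
    IFACIALMOCAP_PORT, IFACIALMOCAP_START_STRING, parse_ifacialmocap_v2_pose)

ios_device_ip = "192.168.0.1"  # the address shown inside the iFacialMocap app

# Ask iFacialMocap to start streaming pose data to this machine.
out_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
out_socket.sendto(IFACIALMOCAP_START_STRING, (ios_device_ip, IFACIALMOCAP_PORT))
out_socket.close()

# Receive one UDP packet and decode it into a pose dictionary.
in_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
in_socket.bind(("", IFACIALMOCAP_PORT))
pose = parse_ifacialmocap_v2_pose(in_socket.recv(8192).decode("utf-8"))
in_socket.close()
```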
================================================
FILE: docs/character_model_manual_poser.md
================================================
# `character_model_manual_poser`
This program allows the user to control trained student models with a graphical user interface, mostly sliders.
## Invoking the Program
Make sure you have (1) created a Python environment and (2) downloaded the model files as instructed in the [main README file](../README.md).
### Instruction for Linux/OSX Users
1. Open a shell.
2. `cd` to the repository's directory.
```
cd SOMEWHERE/talking-head-anime-4-demo
```
3. Run the program.
```
bin/run src/tha4/app/character_model_manual_poser.py
```
### Instruction for Windows Users
1. Open a shell.
2. `cd` to the repository's directory.
```
cd SOMEWHERE\talking-head-anime-4-demo
```
3. Run the program.
```
bin\run.bat src\tha4\app\character_model_manual_poser.py
```
================================================
FILE: docs/character_model_mediapipe_puppeteer.md
================================================
# `character_model_mediapipe_puppeteer`
This program allows the user to control trained student models with their facial movement, which is captured by a web camera and processed by the [Mediapipe FaceLandmarker](https://developers.google.com/mediapipe/solutions/vision/face_landmarker) model.
## Web Camera
Please make sure that, before you invoke the program, your computer has a web camera plugged in. The program will use a web camera, but it does not allow you to specify which one. If your machine has more than one web camera, turn off all cameras except the one that you want to use.
You can also inspect the [source code](../src/tha4/app/character_model_mediapipe_puppeteer.py) and change the
```
video_capture = cv2.VideoCapture(0)
```
line to choose a particular camera that you want to use.
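If you are unsure which index corresponds to which camera, a quick probe like the following sketch can help (it assumes at most five devices; adjust the range as needed):
```
import cv2

# Probe the first few camera indices and report which ones can be opened.
# Pass the index you want to the cv2.VideoCapture(...) call above.
for index in range(5):
    capture = cv2.VideoCapture(index)
    if capture.isOpened():
        print("camera index %d is available" % index)
    capture.release()
```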
## Invoking the Program
Make sure you have (1) created a Python environment and (2) downloaded the model files as instructed in the [main README file](../README.md).
### Instruction for Linux/OSX Users
1. Open a shell.
2. `cd` to the repository's directory.
```
cd SOMEWHERE/talking-head-anime-4-demo
```
3. Run the program.
```
bin/run src/tha4/app/character_model_mediapipe_puppeteer.py
```
### Instruction for Windows Users
1. Open a shell.
2. `cd` to the repository's directory.
```
cd SOMEWHERE\talking-head-anime-4-demo
```
3. Run the program.
```
bin\run.bat src\tha4\app\character_model_mediapipe_puppeteer.py
```
================================================
FILE: docs/distill.md
================================================
# `distill`
This program trains a student model given a configuration file, a $512 \times 512$ RGBA character image, and a mask of facial organs.
## Invoking the Program
Make sure you have (1) created a Python environment and (2) downloaded the model files as instructed in the [main README file](../README.md).
### Instruction for Linux/OSX Users
1. Open a shell.
2. `cd` to the repository's directory.
```
cd SOMEWHERE/talking-head-anime-4-demo
```
3. Run the program.
```
bin/run src/tha4/app/distill.py <config_file>
```
where `<config_file>` is a configuration file for creating a student model. More on this later.
### Instruction for Windows Users
1. Open a shell.
2. `cd` to the repository's directory.
```
cd SOMEWHERE\talking-head-anime-4-demo
```
3. Run the program.
```
bin\run.bat src\tha4\app\distill.py <config_file>
```
where `<config_file>` is a configuration file for creating a student model. More on this later.
## Configuration File
A configuration file is a [YAML](https://yaml.org/) file that specifies how to create a student model. This repository comes with two valid configuration files that you can peruse:
* [data/distill_examples/lambda_00/config.yaml](../data/distill_examples/lambda_00/config.yaml)
* [data/distill_examples/lambda_01/config.yaml](../data/distill_examples/lambda_01/config.yaml)
I recommend that you use the `distiller_ui` program to create configuration files rather than writing them yourself. Inside the program, you can see what the fields are and what they mean.
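If you want to inspect a configuration programmatically, note that the project depends on [omegaconf](https://github.com/omry/omegaconf), which reads these YAML files directly. A minimal sketch follows; the field names are taken from the distiller UI documentation, and their placement at the top level of the YAML hierarchy is an assumption:
```
from omegaconf import OmegaConf

# Load one of the example distillation configurations and inspect it.
config = OmegaConf.load("data/distill_examples/lambda_00/config.yaml")
print(config.prefix)
print(config.character_image_file_name)
print(config.face_mask_image_file_name)
```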
## What `distill` Outputs
Inside the configuration file, the `prefix` field specifies the directory where the student models should be saved. After `distill` is done with its job, the output directory will look like this:
```
+ <prefix>
+ body_morpher
+ face_morpher
+ character_model
- config.yaml
```
Here:
* `config.yaml` is a copy of the configuration file that you wrote.
* The `character_model` directory contains a trained student model that can be used with `character_model_manual_poser`, `character_model_ifacialmocap_puppeteer`, and `character_model_mediapipe_puppeteer`.
* `body_morpher` is a scratch directory that was used to save intermediate results during the training of a part of the student model.
* `face_morpher` is a scratch directory that was used to save intermediate results during the training of another part of the student model.
You only need what is inside the `character_model` directory. As a result, you can delete the other files after the `character_model` directory has been filled. You can move the directory elsewhere and rename it, as long as the contents inside are not modified.
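For reference, the demo programs consume the `character_model` directory through the `CharacterModel` class. The following sketch mirrors what `character_model_ifacialmocap_puppeteer.py` does when you load a model; the YAML file name below is hypothetical, so point it at the `.yaml` file your directory actually contains:
```
import torch

from tha4.charmodel.character_model import CharacterModel

device = torch.device("cuda:0")
# Hypothetical path: use the YAML file inside your character_model directory.
character_model = CharacterModel.load("data/character_models/lambda_00.yaml")
torch_source_image = character_model.get_character_image(device)
poser = character_model.get_poser(device)
```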
## The Training Process Is Interruptible
Invoking `distill` on a configuration will start a rather long process of training a student model. On a machine with an A6000 GPU, it takes about 30 hours to complete. As a result, it might take several days on machines with less powerful GPUs.
The training process is robust and interruptible. You can stop it at any time by closing the shell window or by pressing `Ctrl+C`. Intermediate results are periodically saved in the scratch directories, ready to be picked up later when you are ready to continue. To resume, just invoke `distill` again with the same configuration file that you started with, and the process will continue from the last saved state.
================================================
FILE: docs/distiller_ui.md
================================================
# `distiller_ui`
This program provides a user-friendly interface to the [`distill`](distill.md) program, allowing you to create training configurations and providing useful documentation.
## Invoking the Program
Make sure you have (1) created a Python environment and (2) downloaded the model files as instructed in the [main README file](../README.md).
### Instruction for Linux/OSX Users
1. Open a shell.
2. `cd` to the repository's directory.
```
cd SOMEWHERE/talking-head-anime-4-demo
```
3. Run the program.
```
bin/run src/tha4/app/distiller_ui.py
```
### Instruction for Windows Users
1. Open a shell.
2. `cd` to the repository's directory.
```
cd SOMEWHERE\talking-head-anime-4-demo
```
3. Run the program.
```
bin\run.bat src\tha4\app\distiller_ui.py
```
## Usage
Please consult the documentation inside the program itself. It is available on the rightmost panel.
================================================
FILE: docs/full_manual_poser.md
================================================
# `full_manual_poser`
This program uses the full version of the Talking Head(?) Anime 4 system to animate character images.
## Invoking the Program
Make sure you have (1) created a Python environment and (2) downloaded the model files as instructed in the [main README file](../README.md).
### Instruction for Linux/OSX Users
1. Open a shell.
2. `cd` to the repository's directory.
```
cd SOMEWHERE/talking-head-anime-4-demo
```
3. Run the program.
```
bin/run src/tha4/app/full_manual_poser.py
```
### Instruction for Windows Users
1. Open a shell.
2. `cd` to the repository's directory.
```
cd SOMEWHERE\talking-head-anime-4-demo
```
3. Run the program.
```
bin\run.bat src\tha4\app\full_manual_poser.py
```
================================================
FILE: poetry/README.md
================================================
================================================
FILE: poetry/pyproject.toml
================================================
[tool.poetry]
name = "talking-head-anime-4-demo"
version = "0.1.0"
description = "Demo code for Talking Head(?) Anime 4"
authors = ["Pramook Khungurn "]
readme = "README.md"
packages = [
{include = "tha4", from = "../src"},
]
[tool.poetry.dependencies]
python = ">=3.10, <3.11"
torch = {version = "1.13.1", source = "torch_cu117"}
torchvision = {version = "0.14.1", source = "torch_cu117"}
tensorboard = "^2.15.1"
opencv-python = "^4.8.1.78"
wxpython = "^4.2.1"
numpy-quaternion = "^2022.4.2"
pillow = "^9.4.0"
matplotlib = "^3.6.3"
einops = "^0.6.0"
mediapipe = "^0.10.3"
numpy = "^1.26.3"
scipy = "^1.12.0"
omegaconf = "^2.3.0"
[[tool.poetry.source]]
name = "torch_cu117"
url = "https://download.pytorch.org/whl/cu117"
priority = "explicit"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
================================================
FILE: src/tha4/__init__.py
================================================
================================================
FILE: src/tha4/app/__init__.py
================================================
================================================
FILE: src/tha4/app/character_model_ifacialmocap_puppeteer.py
================================================
import os
import socket
import sys
import threading
import time
from typing import Optional
import PIL.Image
from tha4.shion.base.image_util import torch_linear_to_srgb
from tha4.image_util import convert_linear_to_srgb
from tha4.mocap.ifacialmocap_pose_converter_25 import create_ifacialmocap_pose_converter
from tha4.app.full_manual_poser import resize_PIL_image
from tha4.charmodel.character_model import CharacterModel
sys.path.append(os.getcwd())
from tha4.mocap.ifacialmocap_pose import create_default_ifacialmocap_pose
from tha4.mocap.ifacialmocap_v2 import IFACIALMOCAP_PORT, IFACIALMOCAP_START_STRING, parse_ifacialmocap_v2_pose
import torch
import wx
from tha4.mocap.ifacialmocap_constants import *
from tha4.mocap.ifacialmocap_pose_converter import IFacialMocapPoseConverter
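# FpsStatistics maintains a sliding window of the last 100 frame-rate samples
# so that the FPS readout shown in the UI is a smoothed average rather than a
# value that jumps around on every frame.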
class FpsStatistics:
def __init__(self):
self.count = 100
self.fps = []
def add_fps(self, fps):
self.fps.append(fps)
while len(self.fps) > self.count:
del self.fps[0]
def get_average_fps(self):
if len(self.fps) == 0:
return 0.0
else:
return sum(self.fps) / len(self.fps)
class MainFrame(wx.Frame):
IMAGE_SIZE = 512
def __init__(self, pose_converter: IFacialMocapPoseConverter, device: torch.device):
super().__init__(None, wx.ID_ANY, "iFacialMocap Puppeteer (Fuji)")
self.poser = None
self.pose_converter = pose_converter
self.device = device
self.ifacialmocap_pose = create_default_ifacialmocap_pose()
self.source_image_bitmap = wx.Bitmap(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE)
self.result_image_bitmap = wx.Bitmap(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE)
self.wx_source_image = None
self.torch_source_image = None
self.last_pose = None
self.fps_statistics = FpsStatistics()
self.last_update_time = None
self.create_receiving_socket()
self.create_ui()
self.create_timers()
self.Bind(wx.EVT_CLOSE, self.on_close)
self.update_source_image_bitmap()
self.update_result_image_bitmap()
def create_receiving_socket(self):
self.receiving_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.receiving_socket.bind(("", IFACIALMOCAP_PORT))
self.receiving_socket.setblocking(False)
def create_timers(self):
self.capture_timer = wx.Timer(self, wx.ID_ANY)
self.Bind(wx.EVT_TIMER, self.update_capture_panel, id=self.capture_timer.GetId())
self.animation_timer = wx.Timer(self, wx.ID_ANY)
self.Bind(wx.EVT_TIMER, self.update_result_image_bitmap, id=self.animation_timer.GetId())
def on_close(self, event: wx.Event):
# Stop the timers
self.animation_timer.Stop()
self.capture_timer.Stop()
# Close receiving socket
self.receiving_socket.close()
# Destroy the windows
self.Destroy()
event.Skip()
def on_start_capture(self, event: wx.Event):
capture_device_ip_address = self.capture_device_ip_text_ctrl.GetValue()
out_socket = None
try:
address = (capture_device_ip_address, IFACIALMOCAP_PORT)
out_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
out_socket.sendto(IFACIALMOCAP_START_STRING, address)
except Exception as e:
message_dialog = wx.MessageDialog(self, str(e), "Error!", wx.OK)
message_dialog.ShowModal()
message_dialog.Destroy()
finally:
if out_socket is not None:
out_socket.close()
def read_ifacialmocap_pose(self):
if not self.animation_timer.IsRunning():
return self.ifacialmocap_pose
socket_bytes = None
while True:
try:
socket_bytes = self.receiving_socket.recv(8192)
except socket.error as e:
break
if socket_bytes is not None:
socket_string = socket_bytes.decode("utf-8")
self.ifacialmocap_pose = parse_ifacialmocap_v2_pose(socket_string)
return self.ifacialmocap_pose
def on_erase_background(self, event: wx.Event):
pass
def create_animation_panel(self, parent):
self.animation_panel = wx.Panel(parent, style=wx.RAISED_BORDER)
self.animation_panel_sizer = wx.BoxSizer(wx.HORIZONTAL)
self.animation_panel.SetSizer(self.animation_panel_sizer)
self.animation_panel.SetAutoLayout(1)
image_size = MainFrame.IMAGE_SIZE
if True:
self.input_panel = wx.Panel(self.animation_panel, size=(image_size, image_size + 128),
style=wx.SIMPLE_BORDER)
self.input_panel_sizer = wx.BoxSizer(wx.VERTICAL)
self.input_panel.SetSizer(self.input_panel_sizer)
self.input_panel.SetAutoLayout(1)
self.animation_panel_sizer.Add(self.input_panel, 0, wx.FIXED_MINSIZE)
self.source_image_panel = wx.Panel(self.input_panel, size=(image_size, image_size), style=wx.SIMPLE_BORDER)
self.source_image_panel.Bind(wx.EVT_PAINT, self.paint_source_image_panel)
self.source_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
self.input_panel_sizer.Add(self.source_image_panel, 0, wx.FIXED_MINSIZE)
self.load_model_button = wx.Button(self.input_panel, wx.ID_ANY, "Load Model")
self.input_panel_sizer.Add(self.load_model_button, 1, wx.EXPAND)
self.load_model_button.Bind(wx.EVT_BUTTON, self.load_model)
self.input_panel_sizer.Fit(self.input_panel)
if True:
self.pose_converter.init_pose_converter_panel(self.animation_panel)
if True:
self.animation_left_panel = wx.Panel(self.animation_panel, style=wx.SIMPLE_BORDER)
self.animation_left_panel_sizer = wx.BoxSizer(wx.VERTICAL)
self.animation_left_panel.SetSizer(self.animation_left_panel_sizer)
self.animation_left_panel.SetAutoLayout(1)
self.animation_panel_sizer.Add(self.animation_left_panel, 0, wx.EXPAND)
self.result_image_panel = wx.Panel(self.animation_left_panel, size=(image_size, image_size),
style=wx.SIMPLE_BORDER)
self.result_image_panel.Bind(wx.EVT_PAINT, self.paint_result_image_panel)
self.result_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
self.animation_left_panel_sizer.Add(self.result_image_panel, 0, wx.FIXED_MINSIZE)
separator = wx.StaticLine(self.animation_left_panel, -1, size=(256, 5))
self.animation_left_panel_sizer.Add(separator, 0, wx.EXPAND)
background_text = wx.StaticText(self.animation_left_panel, label="--- Background ---",
style=wx.ALIGN_CENTER)
self.animation_left_panel_sizer.Add(background_text, 0, wx.EXPAND)
self.output_background_choice = wx.Choice(
self.animation_left_panel,
choices=[
"TRANSPARENT",
"GREEN",
"BLUE",
"BLACK",
"WHITE"
])
self.output_background_choice.SetSelection(0)
self.animation_left_panel_sizer.Add(self.output_background_choice, 0, wx.EXPAND)
separator = wx.StaticLine(self.animation_left_panel, -1, size=(256, 5))
self.animation_left_panel_sizer.Add(separator, 0, wx.EXPAND)
self.fps_text = wx.StaticText(self.animation_left_panel, label="")
self.animation_left_panel_sizer.Add(self.fps_text, wx.SizerFlags().Border())
self.animation_left_panel_sizer.Fit(self.animation_left_panel)
self.animation_panel_sizer.Fit(self.animation_panel)
def create_ui(self):
self.main_sizer = wx.BoxSizer(wx.VERTICAL)
self.SetSizer(self.main_sizer)
self.SetAutoLayout(1)
self.capture_pose_lock = threading.Lock()
self.create_connection_panel(self)
self.main_sizer.Add(self.connection_panel, wx.SizerFlags(0).Expand().Border(wx.ALL, 5))
self.create_animation_panel(self)
self.main_sizer.Add(self.animation_panel, wx.SizerFlags(0).Expand().Border(wx.ALL, 5))
self.create_capture_panel(self)
self.main_sizer.Add(self.capture_panel, wx.SizerFlags(0).Expand().Border(wx.ALL, 5))
self.main_sizer.Fit(self)
def create_connection_panel(self, parent):
self.connection_panel = wx.Panel(parent, style=wx.RAISED_BORDER)
self.connection_panel_sizer = wx.BoxSizer(wx.HORIZONTAL)
self.connection_panel.SetSizer(self.connection_panel_sizer)
self.connection_panel.SetAutoLayout(1)
capture_device_ip_text = wx.StaticText(self.connection_panel, label="Capture Device IP:", style=wx.ALIGN_RIGHT)
self.connection_panel_sizer.Add(capture_device_ip_text, wx.SizerFlags(0).FixedMinSize().Border(wx.ALL, 3))
self.capture_device_ip_text_ctrl = wx.TextCtrl(self.connection_panel, value="192.168.0.1")
self.connection_panel_sizer.Add(self.capture_device_ip_text_ctrl, wx.SizerFlags(1).Expand().Border(wx.ALL, 3))
self.start_capture_button = wx.Button(self.connection_panel, label="START CAPTURE!")
self.connection_panel_sizer.Add(self.start_capture_button, wx.SizerFlags(0).FixedMinSize().Border(wx.ALL, 3))
self.start_capture_button.Bind(wx.EVT_BUTTON, self.on_start_capture)
def create_capture_panel(self, parent):
self.capture_panel = wx.Panel(parent, style=wx.RAISED_BORDER)
self.capture_panel_sizer = wx.FlexGridSizer(cols=5)
for i in range(5):
self.capture_panel_sizer.AddGrowableCol(i)
self.capture_panel.SetSizer(self.capture_panel_sizer)
self.capture_panel.SetAutoLayout(1)
self.rotation_labels = {}
self.rotation_value_labels = {}
rotation_column_0 = self.create_rotation_column(self.capture_panel, RIGHT_EYE_BONE_ROTATIONS)
self.capture_panel_sizer.Add(rotation_column_0, wx.SizerFlags(0).Expand().Border(wx.ALL, 3))
rotation_column_1 = self.create_rotation_column(self.capture_panel, LEFT_EYE_BONE_ROTATIONS)
self.capture_panel_sizer.Add(rotation_column_1, wx.SizerFlags(0).Expand().Border(wx.ALL, 3))
rotation_column_2 = self.create_rotation_column(self.capture_panel, HEAD_BONE_ROTATIONS)
self.capture_panel_sizer.Add(rotation_column_2, wx.SizerFlags(0).Expand().Border(wx.ALL, 3))
def create_rotation_column(self, parent, rotation_names):
column_panel = wx.Panel(parent, style=wx.SIMPLE_BORDER)
column_panel_sizer = wx.FlexGridSizer(cols=2)
column_panel_sizer.AddGrowableCol(1)
column_panel.SetSizer(column_panel_sizer)
column_panel.SetAutoLayout(1)
for rotation_name in rotation_names:
self.rotation_labels[rotation_name] = wx.StaticText(
column_panel, label=rotation_name, style=wx.ALIGN_RIGHT)
column_panel_sizer.Add(self.rotation_labels[rotation_name],
wx.SizerFlags(1).Expand().Border(wx.ALL, 3))
self.rotation_value_labels[rotation_name] = wx.TextCtrl(
column_panel, style=wx.TE_RIGHT)
self.rotation_value_labels[rotation_name].SetValue("0.00")
self.rotation_value_labels[rotation_name].Disable()
column_panel_sizer.Add(self.rotation_value_labels[rotation_name],
wx.SizerFlags(1).Expand().Border(wx.ALL, 3))
column_panel.GetSizer().Fit(column_panel)
return column_panel
def paint_capture_panel(self, event: wx.Event):
self.update_capture_panel(event)
def update_capture_panel(self, event: wx.Event):
data = self.ifacialmocap_pose
for rotation_name in ROTATION_NAMES:
value = data[rotation_name]
self.rotation_value_labels[rotation_name].SetValue("%0.2f" % value)
@staticmethod
def convert_to_100(x):
return int(max(0.0, min(1.0, x)) * 100)
def paint_source_image_panel(self, event: wx.Event):
wx.BufferedPaintDC(self.source_image_panel, self.source_image_bitmap)
def update_source_image_bitmap(self):
dc = wx.MemoryDC()
dc.SelectObject(self.source_image_bitmap)
if self.wx_source_image is None:
self.draw_nothing_yet_string(dc)
else:
dc.Clear()
dc.DrawBitmap(self.wx_source_image, 0, 0, True)
del dc
def draw_nothing_yet_string(self, dc):
dc.Clear()
font = wx.Font(wx.FontInfo(14).Family(wx.FONTFAMILY_SWISS))
dc.SetFont(font)
w, h = dc.GetTextExtent("Nothing yet!")
dc.DrawText("Nothing yet!", (MainFrame.IMAGE_SIZE - w) // 2, (MainFrame.IMAGE_SIZE - h) // 2)
def paint_result_image_panel(self, event: wx.Event):
wx.BufferedPaintDC(self.result_image_panel, self.result_image_bitmap)
def update_result_image_bitmap(self, event: Optional[wx.Event] = None):
ifacialmocap_pose = self.read_ifacialmocap_pose()
current_pose = self.pose_converter.convert(ifacialmocap_pose)
if self.last_pose is not None and self.last_pose == current_pose:
return
self.last_pose = current_pose
if self.torch_source_image is None or self.poser is None:
dc = wx.MemoryDC()
dc.SelectObject(self.result_image_bitmap)
self.draw_nothing_yet_string(dc)
del dc
return
pose = torch.tensor(current_pose, device=self.device, dtype=self.poser.get_dtype())
with torch.no_grad():
output_image = self.poser.pose(self.torch_source_image, pose)[0].float()
output_image = torch.clip((output_image + 1.0) / 2.0, 0.0, 1.0)
output_image = convert_linear_to_srgb(output_image)
background_choice = self.output_background_choice.GetSelection()
if background_choice == 0:
pass
else:
background = torch.zeros(4, output_image.shape[1], output_image.shape[2], device=self.device)
background[3, :, :] = 1.0
if background_choice == 1:
background[1, :, :] = 1.0
output_image = self.blend_with_background(output_image, background)
elif background_choice == 2:
background[2, :, :] = 1.0
output_image = self.blend_with_background(output_image, background)
elif background_choice == 3:
output_image = self.blend_with_background(output_image, background)
else:
background[0:3, :, :] = 1.0
output_image = self.blend_with_background(output_image, background)
c, h, w = output_image.shape
output_image = 255.0 * torch.transpose(output_image.reshape(c, h * w), 0, 1).reshape(h, w, c)
output_image = output_image.byte()
numpy_image = output_image.detach().cpu().numpy()
wx_image = wx.ImageFromBuffer(numpy_image.shape[0],
numpy_image.shape[1],
numpy_image[:, :, 0:3].tobytes(),
numpy_image[:, :, 3].tobytes())
wx_bitmap = wx_image.ConvertToBitmap()
dc = wx.MemoryDC()
dc.SelectObject(self.result_image_bitmap)
dc.Clear()
dc.DrawBitmap(wx_bitmap,
(MainFrame.IMAGE_SIZE - numpy_image.shape[0]) // 2,
(MainFrame.IMAGE_SIZE - numpy_image.shape[1]) // 2, True)
del dc
time_now = time.time_ns()
if self.last_update_time is not None:
elapsed_time = time_now - self.last_update_time
fps = 1.0 / (elapsed_time / 10 ** 9)
if self.torch_source_image is not None:
self.fps_statistics.add_fps(fps)
self.fps_text.SetLabelText("FPS = %0.2f" % self.fps_statistics.get_average_fps())
self.last_update_time = time_now
self.Refresh()
def blend_with_background(self, numpy_image, background):
alpha = numpy_image[3:4, :, :]
color = numpy_image[0:3, :, :]
new_color = color * alpha + (1.0 - alpha) * background[0:3, :, :]
return torch.cat([new_color, background[3:4, :, :]], dim=0)
def load_model(self, event: wx.Event):
dir_name = "data/character_models"
file_dialog = wx.FileDialog(self, "Choose a model", dir_name, "", "*.yaml", wx.FD_OPEN)
if file_dialog.ShowModal() == wx.ID_OK:
character_model_json_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())
try:
self.character_model = CharacterModel.load(character_model_json_file_name)
self.torch_source_image = self.character_model.get_character_image(self.device)
pil_image = resize_PIL_image(
PIL.Image.open(self.character_model.character_image_file_name),
(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE))
w, h = pil_image.size
self.wx_source_image = wx.Bitmap.FromBufferRGBA(w, h, pil_image.convert("RGBA").tobytes())
self.update_source_image_bitmap()
self.poser = self.character_model.get_poser(self.device)
except Exception:
message_dialog = wx.MessageDialog(
self, "Could not load character model " + character_model_json_file_name, "Poser", wx.OK)
message_dialog.ShowModal()
message_dialog.Destroy()
file_dialog.Destroy()
self.Refresh()
if __name__ == "__main__":
device = torch.device('cuda:0')
pose_converter = create_ifacialmocap_pose_converter()
app = wx.App()
main_frame = MainFrame(pose_converter, device)
main_frame.Show(True)
main_frame.capture_timer.Start(10)
main_frame.animation_timer.Start(10)
app.MainLoop()
================================================
FILE: src/tha4/app/character_model_manual_poser.py
================================================
import logging
import os
import sys
import time
from typing import List
from tha4.charmodel.character_model import CharacterModel
from tha4.image_util import resize_PIL_image, convert_output_image_from_torch_to_numpy
from tha4.poser.modes.mode_14 import get_pose_parameters
sys.path.append(os.getcwd())
import PIL.Image
import torch
import wx
from tha4.poser.poser import PoseParameterCategory, PoseParameterGroup
class MorphCategoryControlPanel(wx.Panel):
def __init__(self,
parent,
title: str,
pose_param_category: PoseParameterCategory,
param_groups: List[PoseParameterGroup]):
super().__init__(parent, style=wx.SIMPLE_BORDER)
self.pose_param_category = pose_param_category
self.sizer = wx.BoxSizer(wx.VERTICAL)
self.SetSizer(self.sizer)
self.SetAutoLayout(1)
title_text = wx.StaticText(self, label=title, style=wx.ALIGN_CENTER)
self.sizer.Add(title_text, 0, wx.EXPAND)
self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category]
self.choice = wx.Choice(self, choices=[group.get_group_name() for group in self.param_groups])
if len(self.param_groups) > 0:
self.choice.SetSelection(0)
self.choice.Bind(wx.EVT_CHOICE, self.on_choice_updated)
self.sizer.Add(self.choice, 0, wx.EXPAND)
self.left_slider = wx.Slider(self, minValue=-1000, maxValue=1000, value=-1000, style=wx.HORIZONTAL)
self.sizer.Add(self.left_slider, 0, wx.EXPAND)
self.right_slider = wx.Slider(self, minValue=-1000, maxValue=1000, value=-1000, style=wx.HORIZONTAL)
self.sizer.Add(self.right_slider, 0, wx.EXPAND)
self.checkbox = wx.CheckBox(self, label="Show")
self.checkbox.SetValue(True)
self.sizer.Add(self.checkbox, 0, wx.SHAPED | wx.ALIGN_CENTER)
self.update_ui()
self.sizer.Fit(self)
def update_ui(self):
param_group = self.param_groups[self.choice.GetSelection()]
if param_group.is_discrete():
self.left_slider.Enable(False)
self.right_slider.Enable(False)
self.checkbox.Enable(True)
elif param_group.get_arity() == 1:
self.left_slider.Enable(True)
self.right_slider.Enable(False)
self.checkbox.Enable(False)
else:
self.left_slider.Enable(True)
self.right_slider.Enable(True)
self.checkbox.Enable(False)
def on_choice_updated(self, event: wx.Event):
param_group = self.param_groups[self.choice.GetSelection()]
if param_group.is_discrete():
self.checkbox.SetValue(True)
self.update_ui()
def set_param_value(self, pose: List[float]):
if len(self.param_groups) == 0:
return
selected_morph_index = self.choice.GetSelection()
param_group = self.param_groups[selected_morph_index]
param_index = param_group.get_parameter_index()
if param_group.is_discrete():
if self.checkbox.GetValue():
for i in range(param_group.get_arity()):
pose[param_index + i] = 1.0
else:
param_range = param_group.get_range()
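            # Sliders run from -1000 to 1000; normalize to alpha in [0, 1],
            # then interpolate linearly within the parameter group's range.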
alpha = (self.left_slider.GetValue() + 1000) / 2000.0
pose[param_index] = param_range[0] + (param_range[1] - param_range[0]) * alpha
if param_group.get_arity() == 2:
alpha = (self.right_slider.GetValue() + 1000) / 2000.0
pose[param_index + 1] = param_range[0] + (param_range[1] - param_range[0]) * alpha
class SimpleParamGroupsControlPanel(wx.Panel):
def __init__(self, parent,
pose_param_category: PoseParameterCategory,
param_groups: List[PoseParameterGroup]):
super().__init__(parent, style=wx.SIMPLE_BORDER)
self.sizer = wx.BoxSizer(wx.VERTICAL)
self.SetSizer(self.sizer)
self.SetAutoLayout(1)
self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category]
for param_group in self.param_groups:
assert not param_group.is_discrete()
assert param_group.get_arity() == 1
self.sliders = []
for param_group in self.param_groups:
static_text = wx.StaticText(
self,
label=" ------------ %s ------------ " % param_group.get_group_name(), style=wx.ALIGN_CENTER)
self.sizer.Add(static_text, 0, wx.EXPAND)
            param_range = param_group.get_range()
            min_value = int(param_range[0] * 1000)
            max_value = int(param_range[1] * 1000)
slider = wx.Slider(self, minValue=min_value, maxValue=max_value, value=0, style=wx.HORIZONTAL)
self.sizer.Add(slider, 0, wx.EXPAND)
self.sliders.append(slider)
self.sizer.Fit(self)
def set_param_value(self, pose: List[float]):
if len(self.param_groups) == 0:
return
for param_group_index in range(len(self.param_groups)):
param_group = self.param_groups[param_group_index]
slider = self.sliders[param_group_index]
param_range = param_group.get_range()
param_index = param_group.get_parameter_index()
alpha = (slider.GetValue() - slider.GetMin()) * 1.0 / (slider.GetMax() - slider.GetMin())
pose[param_index] = param_range[0] + (param_range[1] - param_range[0]) * alpha
class MainFrame(wx.Frame):
IMAGE_SIZE = 512
OUTPUT_LENGTH = 6
NUM_PARAMETERS = 45
def __init__(self, device: torch.device):
super().__init__(None, wx.ID_ANY, "Poser")
self.poser = None
self.device = device
self.wx_source_image = None
self.torch_source_image = None
self.main_sizer = wx.BoxSizer(wx.HORIZONTAL)
self.SetSizer(self.main_sizer)
self.SetAutoLayout(1)
self.init_left_panel()
self.init_control_panel()
self.init_right_panel()
self.main_sizer.Fit(self)
self.timer = wx.Timer(self, wx.ID_ANY)
self.Bind(wx.EVT_TIMER, self.update_images, self.timer)
save_image_id = wx.NewIdRef()
self.Bind(wx.EVT_MENU, self.on_save_image, id=save_image_id)
accelerator_table = wx.AcceleratorTable([
(wx.ACCEL_CTRL, ord('S'), save_image_id)
])
self.SetAcceleratorTable(accelerator_table)
self.last_pose = None
self.last_output_index = self.output_index_choice.GetSelection()
self.last_output_numpy_image = None
self.wx_source_image = None
self.torch_source_image = None
self.source_image_bitmap = wx.Bitmap(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE)
self.result_image_bitmap = wx.Bitmap(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE)
self.source_image_dirty = True
def init_left_panel(self):
self.control_panel = wx.Panel(self, style=wx.SIMPLE_BORDER, size=(MainFrame.IMAGE_SIZE, -1))
self.left_panel = wx.Panel(self, style=wx.SIMPLE_BORDER)
left_panel_sizer = wx.BoxSizer(wx.VERTICAL)
self.left_panel.SetSizer(left_panel_sizer)
self.left_panel.SetAutoLayout(1)
self.source_image_panel = wx.Panel(self.left_panel, size=(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE),
style=wx.SIMPLE_BORDER)
self.source_image_panel.Bind(wx.EVT_PAINT, self.paint_source_image_panel)
self.source_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
left_panel_sizer.Add(self.source_image_panel, 0, wx.FIXED_MINSIZE)
self.load_model_button = wx.Button(self.left_panel, wx.ID_ANY, "\nLoad Model\n\n")
left_panel_sizer.Add(self.load_model_button, 1, wx.EXPAND)
self.load_model_button.Bind(wx.EVT_BUTTON, self.load_model)
left_panel_sizer.Fit(self.left_panel)
self.main_sizer.Add(self.left_panel, 0, wx.FIXED_MINSIZE)
def on_erase_background(self, event: wx.Event):
pass
def init_control_panel(self):
self.control_panel_sizer = wx.BoxSizer(wx.VERTICAL)
self.control_panel.SetSizer(self.control_panel_sizer)
self.control_panel.SetMinSize(wx.Size(256, 1))
morph_categories = [
PoseParameterCategory.EYEBROW,
PoseParameterCategory.EYE,
PoseParameterCategory.MOUTH,
PoseParameterCategory.IRIS_MORPH
]
morph_category_titles = {
PoseParameterCategory.EYEBROW: " ------------ Eyebrow ------------ ",
PoseParameterCategory.EYE: " ------------ Eye ------------ ",
PoseParameterCategory.MOUTH: " ------------ Mouth ------------ ",
PoseParameterCategory.IRIS_MORPH: " ------------ Iris morphs ------------ ",
}
self.morph_control_panels = {}
param_groups = get_pose_parameters().get_pose_parameter_groups()
for category in morph_categories:
filtered_param_groups = [group for group in param_groups if group.get_category() == category]
if len(filtered_param_groups) == 0:
continue
control_panel = MorphCategoryControlPanel(
self.control_panel,
morph_category_titles[category],
category,
param_groups)
self.morph_control_panels[category] = control_panel
self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND)
self.non_morph_control_panels = {}
non_morph_categories = [
PoseParameterCategory.IRIS_ROTATION,
PoseParameterCategory.FACE_ROTATION,
PoseParameterCategory.BODY_ROTATION,
PoseParameterCategory.BREATHING
]
for category in non_morph_categories:
filtered_param_groups = [group for group in param_groups if group.get_category() == category]
if len(filtered_param_groups) == 0:
continue
control_panel = SimpleParamGroupsControlPanel(
self.control_panel,
category,
param_groups)
self.non_morph_control_panels[category] = control_panel
self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND)
self.control_panel_sizer.Fit(self.control_panel)
self.main_sizer.Add(self.control_panel, 1, wx.FIXED_MINSIZE)
def init_right_panel(self):
self.right_panel = wx.Panel(self, style=wx.SIMPLE_BORDER)
right_panel_sizer = wx.BoxSizer(wx.VERTICAL)
self.right_panel.SetSizer(right_panel_sizer)
self.right_panel.SetAutoLayout(1)
self.result_image_panel = wx.Panel(self.right_panel,
size=(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE),
style=wx.SIMPLE_BORDER)
self.result_image_panel.Bind(wx.EVT_PAINT, self.paint_result_image_panel)
self.result_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
self.output_index_choice = wx.Choice(
self.right_panel,
choices=[str(i) for i in range(MainFrame.OUTPUT_LENGTH)])
self.output_index_choice.SetSelection(0)
right_panel_sizer.Add(self.result_image_panel, 0, wx.FIXED_MINSIZE)
right_panel_sizer.Add(self.output_index_choice, 0, wx.EXPAND)
self.save_image_button = wx.Button(self.right_panel, wx.ID_ANY, "\nSave Image\n\n")
right_panel_sizer.Add(self.save_image_button, 1, wx.EXPAND)
self.save_image_button.Bind(wx.EVT_BUTTON, self.on_save_image)
right_panel_sizer.Fit(self.right_panel)
self.main_sizer.Add(self.right_panel, 0, wx.FIXED_MINSIZE)
def create_param_category_choice(self, param_category: PoseParameterCategory):
params = []
for param_group in self.poser.get_pose_parameter_groups():
if param_group.get_category() == param_category:
params.append(param_group.get_group_name())
choice = wx.Choice(self.control_panel, choices=params)
if len(params) > 0:
choice.SetSelection(0)
return choice
def load_model(self, event: wx.Event):
dir_name = "data/character_models"
file_dialog = wx.FileDialog(self, "Choose a model", dir_name, "", "*.yaml", wx.FD_OPEN)
if file_dialog.ShowModal() == wx.ID_OK:
character_model_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())
try:
self.character_model = CharacterModel.load(character_model_file_name)
self.torch_source_image = self.character_model.get_character_image(self.device)
pil_image = resize_PIL_image(
PIL.Image.open(self.character_model.character_image_file_name),
(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE))
w, h = pil_image.size
self.wx_source_image = wx.Bitmap.FromBufferRGBA(w, h, pil_image.convert("RGBA").tobytes())
self.poser = self.character_model.get_poser(self.device)
self.source_image_dirty = True
self.Refresh()
self.Update()
            except RuntimeError:
message_dialog = wx.MessageDialog(
self, "Could not load character model " + character_model_file_name, "Poser", wx.OK)
message_dialog.ShowModal()
message_dialog.Destroy()
file_dialog.Destroy()
def paint_source_image_panel(self, event: wx.Event):
wx.BufferedPaintDC(self.source_image_panel, self.source_image_bitmap)
def paint_result_image_panel(self, event: wx.Event):
wx.BufferedPaintDC(self.result_image_panel, self.result_image_bitmap)
def draw_nothing_yet_string_to_bitmap(self, bitmap):
dc = wx.MemoryDC()
dc.SelectObject(bitmap)
dc.Clear()
font = wx.Font(wx.FontInfo(14).Family(wx.FONTFAMILY_SWISS))
dc.SetFont(font)
w, h = dc.GetTextExtent("Nothing yet!")
dc.DrawText("Nothing yet!", (MainFrame.IMAGE_SIZE - w) // 2, (MainFrame.IMAGE_SIZE - - h) // 2)
del dc
def get_current_pose(self):
current_pose = [0.0 for i in range(MainFrame.NUM_PARAMETERS)]
for morph_control_panel in self.morph_control_panels.values():
morph_control_panel.set_param_value(current_pose)
for rotation_control_panel in self.non_morph_control_panels.values():
rotation_control_panel.set_param_value(current_pose)
return current_pose
def update_images(self, event: wx.Event):
current_pose = self.get_current_pose()
if not self.source_image_dirty \
and self.last_pose is not None \
and self.last_pose == current_pose \
and self.last_output_index == self.output_index_choice.GetSelection():
return
self.last_pose = current_pose
self.last_output_index = self.output_index_choice.GetSelection()
if self.torch_source_image is None or self.poser is None:
self.draw_nothing_yet_string_to_bitmap(self.source_image_bitmap)
self.draw_nothing_yet_string_to_bitmap(self.result_image_bitmap)
self.source_image_dirty = False
self.Refresh()
self.Update()
return
if self.source_image_dirty:
dc = wx.MemoryDC()
dc.SelectObject(self.source_image_bitmap)
dc.Clear()
dc.DrawBitmap(self.wx_source_image, 0, 0)
self.source_image_dirty = False
pose = torch.tensor(current_pose, device=self.device)
output_index = self.output_index_choice.GetSelection()
with torch.no_grad():
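            # CUDA events time the GPU work itself, while time.time() also
            # includes host-side Python overhead; elapsed_time() is only valid
            # after the synchronize() below.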
start_cuda_event = torch.cuda.Event(enable_timing=True)
end_cuda_event = torch.cuda.Event(enable_timing=True)
start_cuda_event.record()
start_time = time.time()
output_image = self.poser.pose(self.torch_source_image, pose, output_index)[0].detach().cpu()
end_time = time.time()
end_cuda_event.record()
torch.cuda.synchronize()
print("cuda time (ms):", start_cuda_event.elapsed_time(end_cuda_event))
print("elapsed time (ms):", (end_time - start_time) * 1000.0)
numpy_image = convert_output_image_from_torch_to_numpy(output_image)
self.last_output_numpy_image = numpy_image
wx_image = wx.ImageFromBuffer(
numpy_image.shape[0],
numpy_image.shape[1],
numpy_image[:, :, 0:3].tobytes(),
numpy_image[:, :, 3].tobytes())
wx_bitmap = wx_image.ConvertToBitmap()
dc = wx.MemoryDC()
dc.SelectObject(self.result_image_bitmap)
dc.Clear()
dc.DrawBitmap(wx_bitmap,
(MainFrame.IMAGE_SIZE - numpy_image.shape[0]) // 2,
(MainFrame.IMAGE_SIZE - numpy_image.shape[1]) // 2,
True)
del dc
self.Refresh()
self.Update()
def on_save_image(self, event: wx.Event):
if self.last_output_numpy_image is None:
logging.info("There is no output image to save!!!")
return
dir_name = "data/images"
file_dialog = wx.FileDialog(self, "Choose an image", dir_name, "", "*.png", wx.FD_SAVE)
if file_dialog.ShowModal() == wx.ID_OK:
image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())
try:
if os.path.exists(image_file_name):
message_dialog = wx.MessageDialog(self, f"Override {image_file_name}", "Manual Poser",
wx.YES_NO | wx.ICON_QUESTION)
result = message_dialog.ShowModal()
if result == wx.ID_YES:
self.save_last_numpy_image(image_file_name)
message_dialog.Destroy()
else:
self.save_last_numpy_image(image_file_name)
            except Exception:
message_dialog = wx.MessageDialog(self, f"Could not save {image_file_name}", "Manual Poser", wx.OK)
message_dialog.ShowModal()
message_dialog.Destroy()
file_dialog.Destroy()
def save_last_numpy_image(self, image_file_name):
numpy_image = self.last_output_numpy_image
pil_image = PIL.Image.fromarray(numpy_image, mode='RGBA')
os.makedirs(os.path.dirname(image_file_name), exist_ok=True)
pil_image.save(image_file_name)
if __name__ == "__main__":
device = torch.device('cuda:0')
app = wx.App()
main_frame = MainFrame(device)
main_frame.Show(True)
main_frame.timer.Start(16)
app.MainLoop()
================================================
FILE: src/tha4/app/character_model_mediapipe_puppeteer.py
================================================
import os
import sys
import threading
import time
from typing import Optional
import PIL.Image
import cv2
import mediapipe
from scipy.spatial.transform import Rotation
from tha4.shion.base.image_util import resize_PIL_image
from tha4.charmodel.character_model import CharacterModel
from tha4.image_util import convert_linear_to_srgb
from tha4.mocap.mediapipe_constants import HEAD_ROTATIONS, HEAD_X, HEAD_Y, HEAD_Z
from tha4.mocap.mediapipe_face_pose import MediaPipeFacePose
from tha4.mocap.mediapipe_face_pose_converter_00 import MediaPoseFacePoseConverter00
sys.path.append(os.getcwd())
import torch
import wx
class FpsStatistics:
def __init__(self):
self.count = 100
self.fps = []
def add_fps(self, fps):
self.fps.append(fps)
while len(self.fps) > self.count:
del self.fps[0]
def get_average_fps(self):
if len(self.fps) == 0:
return 0.0
else:
return sum(self.fps) / len(self.fps)
class MainFrame(wx.Frame):
IMAGE_SIZE = 512
def __init__(self,
pose_converter: MediaPoseFacePoseConverter00,
video_capture,
face_landmarker,
device: torch.device):
super().__init__(None, wx.ID_ANY, "THA4 Character Model MediaPipe Puppeteer")
self.face_landmarker = face_landmarker
self.video_capture = video_capture
self.pose_converter = pose_converter
self.device = device
self.source_image_bitmap = wx.Bitmap(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE)
self.result_image_bitmap = wx.Bitmap(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE)
self.webcam_capture_bitmap = wx.Bitmap(256, 192)
self.wx_source_image = None
self.torch_source_image = None
self.last_pose = None
self.mediapipe_face_pose = None
self.fps_statistics = FpsStatistics()
self.last_update_time = None
self.character_model = None
self.poser = None
self.create_ui()
self.create_timers()
self.Bind(wx.EVT_CLOSE, self.on_close)
self.update_source_image_bitmap()
self.update_result_image_bitmap()
def create_timers(self):
self.capture_timer = wx.Timer(self, wx.ID_ANY)
self.Bind(wx.EVT_TIMER, self.update_capture_panel, id=self.capture_timer.GetId())
self.animation_timer = wx.Timer(self, wx.ID_ANY)
self.Bind(wx.EVT_TIMER, self.update_result_image_bitmap, id=self.animation_timer.GetId())
def on_close(self, event: wx.Event):
# Stop the timers
self.animation_timer.Stop()
self.capture_timer.Stop()
# Destroy the windows
self.Destroy()
event.Skip()
def on_erase_background(self, event: wx.Event):
pass
def create_animation_panel(self, parent):
self.animation_panel = wx.Panel(parent, style=wx.RAISED_BORDER)
self.animation_panel_sizer = wx.BoxSizer(wx.HORIZONTAL)
self.animation_panel.SetSizer(self.animation_panel_sizer)
self.animation_panel.SetAutoLayout(1)
image_size = MainFrame.IMAGE_SIZE
if True:
self.input_panel = wx.Panel(self.animation_panel, size=(image_size, image_size + 128),
style=wx.SIMPLE_BORDER)
self.input_panel_sizer = wx.BoxSizer(wx.VERTICAL)
self.input_panel.SetSizer(self.input_panel_sizer)
self.input_panel.SetAutoLayout(1)
self.animation_panel_sizer.Add(self.input_panel, 0, wx.FIXED_MINSIZE)
self.source_image_panel = wx.Panel(self.input_panel, size=(image_size, image_size), style=wx.SIMPLE_BORDER)
self.source_image_panel.Bind(wx.EVT_PAINT, self.paint_source_image_panel)
self.source_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
self.input_panel_sizer.Add(self.source_image_panel, 0, wx.FIXED_MINSIZE)
self.load_model_button = wx.Button(self.input_panel, wx.ID_ANY, "Load Model")
self.input_panel_sizer.Add(self.load_model_button, 1, wx.EXPAND)
self.load_model_button.Bind(wx.EVT_BUTTON, self.load_model)
self.input_panel_sizer.Fit(self.input_panel)
if True:
def current_pose_supplier() -> Optional[MediaPipeFacePose]:
return self.mediapipe_face_pose
self.pose_converter.init_pose_converter_panel(self.animation_panel, current_pose_supplier)
if True:
self.animation_left_panel = wx.Panel(self.animation_panel, style=wx.SIMPLE_BORDER)
self.animation_left_panel_sizer = wx.BoxSizer(wx.VERTICAL)
self.animation_left_panel.SetSizer(self.animation_left_panel_sizer)
self.animation_left_panel.SetAutoLayout(1)
self.animation_panel_sizer.Add(self.animation_left_panel, 0, wx.EXPAND)
self.result_image_panel = wx.Panel(self.animation_left_panel, size=(image_size, image_size),
style=wx.SIMPLE_BORDER)
self.result_image_panel.Bind(wx.EVT_PAINT, self.paint_result_image_panel)
self.result_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
self.animation_left_panel_sizer.Add(self.result_image_panel, 0, wx.FIXED_MINSIZE)
separator = wx.StaticLine(self.animation_left_panel, -1, size=(256, 5))
self.animation_left_panel_sizer.Add(separator, 0, wx.EXPAND)
background_text = wx.StaticText(self.animation_left_panel, label="--- Background ---",
style=wx.ALIGN_CENTER)
self.animation_left_panel_sizer.Add(background_text, 0, wx.EXPAND)
self.output_background_choice = wx.Choice(
self.animation_left_panel,
choices=[
"TRANSPARENT",
"GREEN",
"BLUE",
"BLACK",
"WHITE"
])
self.output_background_choice.SetSelection(0)
self.animation_left_panel_sizer.Add(self.output_background_choice, 0, wx.EXPAND)
separator = wx.StaticLine(self.animation_left_panel, -1, size=(256, 5))
self.animation_left_panel_sizer.Add(separator, 0, wx.EXPAND)
self.fps_text = wx.StaticText(self.animation_left_panel, label="")
self.animation_left_panel_sizer.Add(self.fps_text, wx.SizerFlags().Border())
self.animation_left_panel_sizer.Fit(self.animation_left_panel)
self.animation_panel_sizer.Fit(self.animation_panel)
def create_ui(self):
self.main_sizer = wx.BoxSizer(wx.VERTICAL)
self.SetSizer(self.main_sizer)
self.SetAutoLayout(1)
self.capture_pose_lock = threading.Lock()
self.create_animation_panel(self)
self.main_sizer.Add(self.animation_panel, wx.SizerFlags(0).Expand().Border(wx.ALL, 5))
self.create_capture_panel(self)
self.main_sizer.Add(self.capture_panel, wx.SizerFlags(0).Expand().Border(wx.ALL, 5))
self.main_sizer.Fit(self)
def create_capture_panel(self, parent):
self.capture_panel = wx.Panel(parent, style=wx.RAISED_BORDER)
self.capture_panel_sizer = wx.BoxSizer(wx.HORIZONTAL)
self.capture_panel.SetSizer(self.capture_panel_sizer)
self.capture_panel.SetAutoLayout(1)
self.webcam_capture_panel = wx.Panel(self.capture_panel, size=(256, 192), style=wx.SIMPLE_BORDER)
self.webcam_capture_panel.Bind(wx.EVT_PAINT, self.paint_webcam_capture_panel)
self.webcam_capture_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
self.capture_panel_sizer.Add(self.webcam_capture_panel, wx.SizerFlags(0).FixedMinSize().Border(wx.ALL, 5))
self.rotation_labels = {}
self.rotation_value_labels = {}
rotation_column = self.create_rotation_column(self.capture_panel, HEAD_ROTATIONS)
self.capture_panel_sizer.Add(rotation_column, wx.SizerFlags(0).Expand().Border(wx.ALL, 3))
def paint_webcam_capture_panel(self, event: wx.Event):
wx.BufferedPaintDC(self.webcam_capture_panel, self.webcam_capture_bitmap)
def create_rotation_column(self, parent, rotation_names):
column_panel = wx.Panel(parent, style=wx.SIMPLE_BORDER)
column_panel_sizer = wx.FlexGridSizer(cols=2)
column_panel_sizer.AddGrowableCol(1)
column_panel.SetSizer(column_panel_sizer)
column_panel.SetAutoLayout(1)
for rotation_name in rotation_names:
self.rotation_labels[rotation_name] = wx.StaticText(
column_panel, label=rotation_name, style=wx.ALIGN_RIGHT)
column_panel_sizer.Add(self.rotation_labels[rotation_name],
wx.SizerFlags(1).Expand().Border(wx.ALL, 3))
self.rotation_value_labels[rotation_name] = wx.TextCtrl(
column_panel, style=wx.TE_RIGHT)
self.rotation_value_labels[rotation_name].SetValue("0.00")
self.rotation_value_labels[rotation_name].Disable()
column_panel_sizer.Add(self.rotation_value_labels[rotation_name],
wx.SizerFlags(1).Expand().Border(wx.ALL, 3))
column_panel.GetSizer().Fit(column_panel)
return column_panel
def update_capture_panel(self, event: wx.Event):
there_is_frame, frame = self.video_capture.read()
if not there_is_frame:
dc = wx.MemoryDC()
dc.SelectObject(self.webcam_capture_bitmap)
self.draw_nothing_yet_string(dc)
del dc
return
rgb_frame = cv2.flip(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), 1)
resized_frame = cv2.resize(rgb_frame, (256, 192))
wx_image = wx.ImageFromBuffer(256, 192, resized_frame.tobytes())
wx_bitmap = wx_image.ConvertToBitmap()
dc = wx.MemoryDC()
dc.SelectObject(self.webcam_capture_bitmap)
dc.Clear()
dc.DrawBitmap(wx_bitmap, 0, 0, True)
del dc
self.webcam_capture_panel.Refresh()
time_ms = int(time.time() * 1000)
mediapipe_image = mediapipe.Image(image_format=mediapipe.ImageFormat.SRGB, data=rgb_frame)
detection_result = self.face_landmarker.detect_for_video(mediapipe_image, time_ms)
self.update_mediapipe_face_pose(detection_result)
def update_mediapipe_face_pose(self, detection_result):
if len(detection_result.facial_transformation_matrixes) == 0:
return
xform_matrix = detection_result.facial_transformation_matrixes[0]
blendshape_params = {}
for item in detection_result.face_blendshapes[0]:
blendshape_params[item.category_name] = item.score
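        # The upper-left 3x3 block of the facial transformation matrix is a
        # rotation matrix; scipy converts it to extrinsic x-y-z Euler angles
        # in degrees to populate the head-rotation read-outs.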
M = xform_matrix[0:3, 0:3]
rot = Rotation.from_matrix(M)
euler_angles = rot.as_euler('xyz', degrees=True)
self.rotation_value_labels[HEAD_X].SetValue("%0.2f" % euler_angles[0])
self.rotation_value_labels[HEAD_X].Refresh()
self.rotation_value_labels[HEAD_Y].SetValue("%0.2f" % euler_angles[1])
self.rotation_value_labels[HEAD_Y].Refresh()
self.rotation_value_labels[HEAD_Z].SetValue("%0.2f" % euler_angles[2])
self.rotation_value_labels[HEAD_Z].Refresh()
self.mediapipe_face_pose = MediaPipeFacePose(blendshape_params, xform_matrix)
@staticmethod
def convert_to_100(x):
return int(max(0.0, min(1.0, x)) * 100)
def paint_source_image_panel(self, event: wx.Event):
wx.BufferedPaintDC(self.source_image_panel, self.source_image_bitmap)
def update_source_image_bitmap(self):
dc = wx.MemoryDC()
dc.SelectObject(self.source_image_bitmap)
if self.wx_source_image is None:
self.draw_nothing_yet_string(dc)
else:
dc.Clear()
dc.DrawBitmap(self.wx_source_image, 0, 0, True)
del dc
def draw_nothing_yet_string(self, dc):
dc.Clear()
font = wx.Font(wx.FontInfo(14).Family(wx.FONTFAMILY_SWISS))
dc.SetFont(font)
w, h = dc.GetTextExtent("Nothing yet!")
dc.DrawText("Nothing yet!", (MainFrame.IMAGE_SIZE - w) // 2, (MainFrame.IMAGE_SIZE - h) // 2)
def paint_result_image_panel(self, event: wx.Event):
wx.BufferedPaintDC(self.result_image_panel, self.result_image_bitmap)
def update_result_image_bitmap(self, event: Optional[wx.Event] = None):
if self.mediapipe_face_pose is None or self.poser is None:
dc = wx.MemoryDC()
dc.SelectObject(self.result_image_bitmap)
self.draw_nothing_yet_string(dc)
del dc
return
current_pose = self.pose_converter.convert(self.mediapipe_face_pose)
if self.last_pose is not None and self.last_pose == current_pose:
return
self.last_pose = current_pose
if self.torch_source_image is None:
dc = wx.MemoryDC()
dc.SelectObject(self.result_image_bitmap)
self.draw_nothing_yet_string(dc)
del dc
return
pose = torch.tensor(current_pose, device=self.device, dtype=self.poser.get_dtype())
with torch.no_grad():
output_image = self.poser.pose(self.torch_source_image, pose)[0].float()
output_image = torch.clip((output_image + 1.0) / 2.0, 0.0, 1.0)
output_image = convert_linear_to_srgb(output_image)
background_choice = self.output_background_choice.GetSelection()
if background_choice == 0:
pass
else:
background = torch.zeros(4, output_image.shape[1], output_image.shape[2], device=self.device)
background[3, :, :] = 1.0
if background_choice == 1:
background[1, :, :] = 1.0
output_image = self.blend_with_background(output_image, background)
elif background_choice == 2:
background[2, :, :] = 1.0
output_image = self.blend_with_background(output_image, background)
elif background_choice == 3:
output_image = self.blend_with_background(output_image, background)
else:
background[0:3, :, :] = 1.0
output_image = self.blend_with_background(output_image, background)
c, h, w = output_image.shape
output_image = 255.0 * torch.transpose(output_image.reshape(c, h * w), 0, 1).reshape(h, w, c)
output_image = output_image.byte()
numpy_image = output_image.detach().cpu().numpy()
wx_image = wx.ImageFromBuffer(numpy_image.shape[0],
numpy_image.shape[1],
numpy_image[:, :, 0:3].tobytes(),
numpy_image[:, :, 3].tobytes())
wx_bitmap = wx_image.ConvertToBitmap()
dc = wx.MemoryDC()
dc.SelectObject(self.result_image_bitmap)
dc.Clear()
dc.DrawBitmap(wx_bitmap,
(MainFrame.IMAGE_SIZE - numpy_image.shape[0]) // 2,
(MainFrame.IMAGE_SIZE - numpy_image.shape[1]) // 2, True)
del dc
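        # time.time_ns() returns nanoseconds, so the instantaneous FPS is
        # 1e9 / elapsed; FpsStatistics keeps a running average over the last
        # 100 frames.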
time_now = time.time_ns()
if self.last_update_time is not None:
elapsed_time = time_now - self.last_update_time
fps = 1.0 / (elapsed_time / 10 ** 9)
if self.torch_source_image is not None:
self.fps_statistics.add_fps(fps)
self.fps_text.SetLabelText("FPS = %0.2f" % self.fps_statistics.get_average_fps())
self.last_update_time = time_now
self.Refresh()
def blend_with_background(self, numpy_image, background):
alpha = numpy_image[3:4, :, :]
color = numpy_image[0:3, :, :]
new_color = color * alpha + (1.0 - alpha) * background[0:3, :, :]
return torch.cat([new_color, background[3:4, :, :]], dim=0)
def load_model(self, event: wx.Event):
dir_name = "data/character_models"
file_dialog = wx.FileDialog(self, "Choose a model", dir_name, "", "*.yaml", wx.FD_OPEN)
if file_dialog.ShowModal() == wx.ID_OK:
character_model_json_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())
try:
self.character_model = CharacterModel.load(character_model_json_file_name)
self.torch_source_image = self.character_model.get_character_image(self.device)
pil_image = resize_PIL_image(
PIL.Image.open(self.character_model.character_image_file_name),
(MainFrame.IMAGE_SIZE, MainFrame.IMAGE_SIZE))
w, h = pil_image.size
self.wx_source_image = wx.Bitmap.FromBufferRGBA(w, h, pil_image.convert("RGBA").tobytes())
self.update_source_image_bitmap()
self.poser = self.character_model.get_poser(self.device)
except Exception:
message_dialog = wx.MessageDialog(
self, "Could not load character model " + character_model_json_file_name, "Poser", wx.OK)
message_dialog.ShowModal()
message_dialog.Destroy()
file_dialog.Destroy()
self.Refresh()
if __name__ == "__main__":
device = torch.device("cuda:0")
pose_converter = MediaPoseFacePoseConverter00()
face_landmarker_base_options = mediapipe.tasks.BaseOptions(
model_asset_path='data/thirdparty/mediapipe/face_landmarker_v2_with_blendshapes.task')
options = mediapipe.tasks.vision.FaceLandmarkerOptions(
base_options=face_landmarker_base_options,
running_mode=mediapipe.tasks.vision.RunningMode.VIDEO,
output_face_blendshapes=True,
output_facial_transformation_matrixes=True,
num_faces=1)
face_landmarker = mediapipe.tasks.vision.FaceLandmarker.create_from_options(options)
video_capture = cv2.VideoCapture(0)
app = wx.App()
main_frame = MainFrame(pose_converter, video_capture, face_landmarker, device)
main_frame.Show(True)
main_frame.capture_timer.Start(30)
main_frame.animation_timer.Start(30)
app.MainLoop()
================================================
FILE: src/tha4/app/distill.py
================================================
import argparse
import logging
from tha4.distiller.distiller_config import DistillerConfig
from tha4.pytasuku.workspace import Workspace
def run_config(config_file_name: str):
config = DistillerConfig.load(config_file_name)
logging.basicConfig(level=logging.INFO, force=True)
workspace = Workspace()
config.define_tasks(workspace)
workspace.start_session()
workspace.run(f"{config.prefix}/all")
workspace.end_session()
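# Typical invocation (the config path is illustrative):
#   python src/tha4/app/distill.py --config_file data/distill/example/config.yaml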
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Distillation script.')
parser.add_argument("--config_file", type=str, required=True,
help="The name of the config file for the distillation process.")
args = parser.parse_args()
run_config(args.config_file)
================================================
FILE: src/tha4/app/distiller_ui.py
================================================
import wx
from tha4.app.distill import run_config
from tha4.distiller.ui.distiller_ui_main_frame import DistillerUiMainFrame
if __name__ == "__main__":
app = wx.App()
main_frame = DistillerUiMainFrame()
main_frame.Show(True)
app.MainLoop()
if main_frame.config_file_to_run is not None:
run_config(main_frame.config_file_to_run)
================================================
FILE: src/tha4/app/full_manual_poser.py
================================================
import logging
import os
import sys
import time
from typing import List
from tha4.shion.base.image_util import extract_pytorch_image_from_PIL_image, pytorch_rgba_to_numpy_image, \
pytorch_rgb_to_numpy_image
from tha4.image_util import grid_change_to_numpy_image, resize_PIL_image
sys.path.append(os.getcwd())
import PIL.Image
import numpy
import torch
import wx
from tha4.poser.poser import Poser, PoseParameterCategory, PoseParameterGroup
class MorphCategoryControlPanel(wx.Panel):
def __init__(self,
parent,
title: str,
pose_param_category: PoseParameterCategory,
param_groups: List[PoseParameterGroup]):
super().__init__(parent, style=wx.SIMPLE_BORDER)
self.pose_param_category = pose_param_category
self.sizer = wx.BoxSizer(wx.VERTICAL)
self.SetSizer(self.sizer)
self.SetAutoLayout(1)
title_text = wx.StaticText(self, label=title, style=wx.ALIGN_CENTER)
self.sizer.Add(title_text, 0, wx.EXPAND)
self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category]
self.choice = wx.Choice(self, choices=[group.get_group_name() for group in self.param_groups])
if len(self.param_groups) > 0:
self.choice.SetSelection(0)
self.choice.Bind(wx.EVT_CHOICE, self.on_choice_updated)
self.sizer.Add(self.choice, 0, wx.EXPAND)
self.left_slider = wx.Slider(self, minValue=-1000, maxValue=1000, value=-1000, style=wx.HORIZONTAL)
self.sizer.Add(self.left_slider, 0, wx.EXPAND)
self.right_slider = wx.Slider(self, minValue=-1000, maxValue=1000, value=-1000, style=wx.HORIZONTAL)
self.sizer.Add(self.right_slider, 0, wx.EXPAND)
self.checkbox = wx.CheckBox(self, label="Show")
self.checkbox.SetValue(True)
self.sizer.Add(self.checkbox, 0, wx.SHAPED | wx.ALIGN_CENTER)
self.update_ui()
self.sizer.Fit(self)
def update_ui(self):
param_group = self.param_groups[self.choice.GetSelection()]
if param_group.is_discrete():
self.left_slider.Enable(False)
self.right_slider.Enable(False)
self.checkbox.Enable(True)
elif param_group.get_arity() == 1:
self.left_slider.Enable(True)
self.right_slider.Enable(False)
self.checkbox.Enable(False)
else:
self.left_slider.Enable(True)
self.right_slider.Enable(True)
self.checkbox.Enable(False)
def on_choice_updated(self, event: wx.Event):
param_group = self.param_groups[self.choice.GetSelection()]
if param_group.is_discrete():
self.checkbox.SetValue(True)
self.update_ui()
def set_param_value(self, pose: List[float]):
if len(self.param_groups) == 0:
return
selected_morph_index = self.choice.GetSelection()
param_group = self.param_groups[selected_morph_index]
param_index = param_group.get_parameter_index()
if param_group.is_discrete():
if self.checkbox.GetValue():
for i in range(param_group.get_arity()):
pose[param_index + i] = 1.0
else:
param_range = param_group.get_range()
alpha = (self.left_slider.GetValue() + 1000) / 2000.0
pose[param_index] = param_range[0] + (param_range[1] - param_range[0]) * alpha
if param_group.get_arity() == 2:
alpha = (self.right_slider.GetValue() + 1000) / 2000.0
pose[param_index + 1] = param_range[0] + (param_range[1] - param_range[0]) * alpha
class SimpleParamGroupsControlPanel(wx.Panel):
def __init__(self, parent,
pose_param_category: PoseParameterCategory,
param_groups: List[PoseParameterGroup]):
super().__init__(parent, style=wx.SIMPLE_BORDER)
self.sizer = wx.BoxSizer(wx.VERTICAL)
self.SetSizer(self.sizer)
self.SetAutoLayout(1)
self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category]
for param_group in self.param_groups:
assert not param_group.is_discrete()
assert param_group.get_arity() == 1
self.sliders = []
for param_group in self.param_groups:
static_text = wx.StaticText(
self,
label=" ------------ %s ------------ " % param_group.get_group_name(), style=wx.ALIGN_CENTER)
self.sizer.Add(static_text, 0, wx.EXPAND)
            param_range = param_group.get_range()
            min_value = int(param_range[0] * 1000)
            max_value = int(param_range[1] * 1000)
slider = wx.Slider(self, minValue=min_value, maxValue=max_value, value=0, style=wx.HORIZONTAL)
self.sizer.Add(slider, 0, wx.EXPAND)
self.sliders.append(slider)
self.sizer.Fit(self)
def set_param_value(self, pose: List[float]):
if len(self.param_groups) == 0:
return
for param_group_index in range(len(self.param_groups)):
param_group = self.param_groups[param_group_index]
slider = self.sliders[param_group_index]
param_range = param_group.get_range()
param_index = param_group.get_parameter_index()
alpha = (slider.GetValue() - slider.GetMin()) * 1.0 / (slider.GetMax() - slider.GetMin())
pose[param_index] = param_range[0] + (param_range[1] - param_range[0]) * alpha
def convert_output_image_from_torch_to_numpy(output_image):
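    # Dispatches on channel layout: a tensor whose last dimension is 2 is
    # treated as HWC and permuted to CHW; otherwise the first dimension is the
    # channel count (4 = RGBA, 3 = RGB, 1 = alpha map tiled to gray, 2 = grid
    # change) and the tensor is converted to an HWC numpy image.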
if output_image.shape[2] == 2:
h, w, c = output_image.shape
numpy_image = torch.transpose(output_image.reshape(h * w, c), 0, 1).reshape(c, h, w)
elif output_image.shape[0] == 4:
numpy_image = pytorch_rgba_to_numpy_image(output_image)
elif output_image.shape[0] == 3:
numpy_image = pytorch_rgb_to_numpy_image(output_image)
elif output_image.shape[0] == 1:
c, h, w = output_image.shape
alpha_image = torch.cat([output_image.repeat(3, 1, 1) * 2.0 - 1.0, torch.ones(1, h, w)], dim=0)
numpy_image = pytorch_rgba_to_numpy_image(alpha_image)
elif output_image.shape[0] == 2:
numpy_image = grid_change_to_numpy_image(output_image, num_channels=4)
else:
raise RuntimeError("Unsupported # image channels: %d" % output_image.shape[0])
numpy_image = numpy.uint8(numpy.rint(numpy_image * 255.0))
return numpy_image
class MainFrame(wx.Frame):
def __init__(self, poser: Poser, device: torch.device):
super().__init__(None, wx.ID_ANY, "Poser")
self.poser = poser
self.dtype = self.poser.get_dtype()
self.device = device
self.image_size = self.poser.get_image_size()
self.wx_source_image = None
self.torch_source_image = None
self.main_sizer = wx.BoxSizer(wx.HORIZONTAL)
self.SetSizer(self.main_sizer)
self.SetAutoLayout(1)
self.init_left_panel()
self.init_control_panel()
self.init_right_panel()
self.main_sizer.Fit(self)
self.timer = wx.Timer(self, wx.ID_ANY)
self.Bind(wx.EVT_TIMER, self.update_images, self.timer)
save_image_id = wx.NewIdRef()
self.Bind(wx.EVT_MENU, self.on_save_image, id=save_image_id)
accelerator_table = wx.AcceleratorTable([
(wx.ACCEL_CTRL, ord('S'), save_image_id)
])
self.SetAcceleratorTable(accelerator_table)
self.last_pose = None
self.last_output_index = self.output_index_choice.GetSelection()
self.last_output_numpy_image = None
self.wx_source_image = None
self.torch_source_image = None
self.source_image_bitmap = wx.Bitmap(self.image_size, self.image_size)
self.result_image_bitmap = wx.Bitmap(self.image_size, self.image_size)
self.source_image_dirty = True
def init_left_panel(self):
self.control_panel = wx.Panel(self, style=wx.SIMPLE_BORDER, size=(self.image_size, -1))
self.left_panel = wx.Panel(self, style=wx.SIMPLE_BORDER)
left_panel_sizer = wx.BoxSizer(wx.VERTICAL)
self.left_panel.SetSizer(left_panel_sizer)
self.left_panel.SetAutoLayout(1)
self.source_image_panel = wx.Panel(self.left_panel, size=(self.image_size, self.image_size),
style=wx.SIMPLE_BORDER)
self.source_image_panel.Bind(wx.EVT_PAINT, self.paint_source_image_panel)
self.source_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
left_panel_sizer.Add(self.source_image_panel, 0, wx.FIXED_MINSIZE)
self.load_image_button = wx.Button(self.left_panel, wx.ID_ANY, "\nLoad Image\n\n")
left_panel_sizer.Add(self.load_image_button, 1, wx.EXPAND)
self.load_image_button.Bind(wx.EVT_BUTTON, self.load_image)
left_panel_sizer.Fit(self.left_panel)
self.main_sizer.Add(self.left_panel, 0, wx.FIXED_MINSIZE)
def on_erase_background(self, event: wx.Event):
pass
def init_control_panel(self):
self.control_panel_sizer = wx.BoxSizer(wx.VERTICAL)
self.control_panel.SetSizer(self.control_panel_sizer)
self.control_panel.SetMinSize(wx.Size(256, 1))
morph_categories = [
PoseParameterCategory.EYEBROW,
PoseParameterCategory.EYE,
PoseParameterCategory.MOUTH,
PoseParameterCategory.IRIS_MORPH
]
morph_category_titles = {
PoseParameterCategory.EYEBROW: " ------------ Eyebrow ------------ ",
PoseParameterCategory.EYE: " ------------ Eye ------------ ",
PoseParameterCategory.MOUTH: " ------------ Mouth ------------ ",
PoseParameterCategory.IRIS_MORPH: " ------------ Iris morphs ------------ ",
}
self.morph_control_panels = {}
for category in morph_categories:
param_groups = self.poser.get_pose_parameter_groups()
filtered_param_groups = [group for group in param_groups if group.get_category() == category]
if len(filtered_param_groups) == 0:
continue
control_panel = MorphCategoryControlPanel(
self.control_panel,
morph_category_titles[category],
category,
self.poser.get_pose_parameter_groups())
self.morph_control_panels[category] = control_panel
self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND)
self.non_morph_control_panels = {}
non_morph_categories = [
PoseParameterCategory.IRIS_ROTATION,
PoseParameterCategory.FACE_ROTATION,
PoseParameterCategory.BODY_ROTATION,
PoseParameterCategory.BREATHING
]
for category in non_morph_categories:
param_groups = self.poser.get_pose_parameter_groups()
filtered_param_groups = [group for group in param_groups if group.get_category() == category]
if len(filtered_param_groups) == 0:
continue
control_panel = SimpleParamGroupsControlPanel(
self.control_panel,
category,
self.poser.get_pose_parameter_groups())
self.non_morph_control_panels[category] = control_panel
self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND)
self.control_panel_sizer.Fit(self.control_panel)
self.main_sizer.Add(self.control_panel, 1, wx.FIXED_MINSIZE)
def init_right_panel(self):
self.right_panel = wx.Panel(self, style=wx.SIMPLE_BORDER)
right_panel_sizer = wx.BoxSizer(wx.VERTICAL)
self.right_panel.SetSizer(right_panel_sizer)
self.right_panel.SetAutoLayout(1)
self.result_image_panel = wx.Panel(self.right_panel,
size=(self.image_size, self.image_size),
style=wx.SIMPLE_BORDER)
self.result_image_panel.Bind(wx.EVT_PAINT, self.paint_result_image_panel)
self.result_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
self.output_index_choice = wx.Choice(
self.right_panel,
choices=[str(i) for i in range(self.poser.get_output_length())])
self.output_index_choice.SetSelection(0)
right_panel_sizer.Add(self.result_image_panel, 0, wx.FIXED_MINSIZE)
right_panel_sizer.Add(self.output_index_choice, 0, wx.EXPAND)
self.save_image_button = wx.Button(self.right_panel, wx.ID_ANY, "\nSave Image\n\n")
right_panel_sizer.Add(self.save_image_button, 1, wx.EXPAND)
self.save_image_button.Bind(wx.EVT_BUTTON, self.on_save_image)
right_panel_sizer.Fit(self.right_panel)
self.main_sizer.Add(self.right_panel, 0, wx.FIXED_MINSIZE)
def create_param_category_choice(self, param_category: PoseParameterCategory):
params = []
for param_group in self.poser.get_pose_parameter_groups():
if param_group.get_category() == param_category:
params.append(param_group.get_group_name())
choice = wx.Choice(self.control_panel, choices=params)
if len(params) > 0:
choice.SetSelection(0)
return choice
def load_image(self, event: wx.Event):
dir_name = "data/images"
file_dialog = wx.FileDialog(self, "Choose an image", dir_name, "", "*.png", wx.FD_OPEN)
if file_dialog.ShowModal() == wx.ID_OK:
image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())
try:
pil_image = resize_PIL_image(PIL.Image.open(image_file_name),
(self.poser.get_image_size(), self.poser.get_image_size()))
w, h = pil_image.size
if pil_image.mode != 'RGBA':
self.source_image_string = "Image must have alpha channel!"
self.wx_source_image = None
self.torch_source_image = None
else:
self.wx_source_image = wx.Bitmap.FromBufferRGBA(w, h, pil_image.convert("RGBA").tobytes())
self.torch_source_image = extract_pytorch_image_from_PIL_image(pil_image) \
.to(self.device).to(self.dtype)
self.source_image_dirty = True
self.Refresh()
self.Update()
            except Exception:
message_dialog = wx.MessageDialog(self, "Could not load image " + image_file_name, "Poser", wx.OK)
message_dialog.ShowModal()
message_dialog.Destroy()
file_dialog.Destroy()
def paint_source_image_panel(self, event: wx.Event):
wx.BufferedPaintDC(self.source_image_panel, self.source_image_bitmap)
def paint_result_image_panel(self, event: wx.Event):
wx.BufferedPaintDC(self.result_image_panel, self.result_image_bitmap)
def draw_nothing_yet_string_to_bitmap(self, bitmap):
dc = wx.MemoryDC()
dc.SelectObject(bitmap)
dc.Clear()
font = wx.Font(wx.FontInfo(14).Family(wx.FONTFAMILY_SWISS))
dc.SetFont(font)
w, h = dc.GetTextExtent("Nothing yet!")
dc.DrawText("Nothing yet!", (self.image_size - w) // 2, (self.image_size - - h) // 2)
del dc
def get_current_pose(self):
current_pose = [0.0 for i in range(self.poser.get_num_parameters())]
for morph_control_panel in self.morph_control_panels.values():
morph_control_panel.set_param_value(current_pose)
for rotation_control_panel in self.non_morph_control_panels.values():
rotation_control_panel.set_param_value(current_pose)
return current_pose
def update_images(self, event: wx.Event):
current_pose = self.get_current_pose()
if not self.source_image_dirty \
and self.last_pose is not None \
and self.last_pose == current_pose \
and self.last_output_index == self.output_index_choice.GetSelection():
return
self.last_pose = current_pose
self.last_output_index = self.output_index_choice.GetSelection()
if self.torch_source_image is None:
self.draw_nothing_yet_string_to_bitmap(self.source_image_bitmap)
self.draw_nothing_yet_string_to_bitmap(self.result_image_bitmap)
self.source_image_dirty = False
self.Refresh()
self.Update()
return
if self.source_image_dirty:
dc = wx.MemoryDC()
dc.SelectObject(self.source_image_bitmap)
dc.Clear()
dc.DrawBitmap(self.wx_source_image, 0, 0)
self.source_image_dirty = False
pose = torch.tensor(current_pose, device=self.device, dtype=self.dtype)
output_index = self.output_index_choice.GetSelection()
with torch.no_grad():
start_cuda_event = torch.cuda.Event(enable_timing=True)
end_cuda_event = torch.cuda.Event(enable_timing=True)
start_cuda_event.record()
start_time = time.time()
output_image = self.poser.pose(self.torch_source_image, pose, output_index)[0].detach().cpu()
end_time = time.time()
end_cuda_event.record()
torch.cuda.synchronize()
print("cuda time (ms):", start_cuda_event.elapsed_time(end_cuda_event))
print("elapsed time (ms):", (end_time - start_time) * 1000.0)
numpy_image = convert_output_image_from_torch_to_numpy(output_image)
self.last_output_numpy_image = numpy_image
wx_image = wx.ImageFromBuffer(
numpy_image.shape[0],
numpy_image.shape[1],
numpy_image[:, :, 0:3].tobytes(),
numpy_image[:, :, 3].tobytes())
wx_bitmap = wx_image.ConvertToBitmap()
dc = wx.MemoryDC()
dc.SelectObject(self.result_image_bitmap)
dc.Clear()
dc.DrawBitmap(wx_bitmap,
(self.image_size - numpy_image.shape[0]) // 2,
(self.image_size - numpy_image.shape[1]) // 2,
True)
del dc
self.Refresh()
self.Update()
def on_save_image(self, event: wx.Event):
if self.last_output_numpy_image is None:
logging.info("There is no output image to save!!!")
return
dir_name = "data/images"
file_dialog = wx.FileDialog(self, "Choose an image", dir_name, "", "*.png", wx.FD_SAVE)
if file_dialog.ShowModal() == wx.ID_OK:
image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())
try:
if os.path.exists(image_file_name):
message_dialog = wx.MessageDialog(self, f"Override {image_file_name}", "Manual Poser",
wx.YES_NO | wx.ICON_QUESTION)
result = message_dialog.ShowModal()
if result == wx.ID_YES:
self.save_last_numpy_image(image_file_name)
message_dialog.Destroy()
else:
self.save_last_numpy_image(image_file_name)
            except Exception:
message_dialog = wx.MessageDialog(self, f"Could not save {image_file_name}", "Manual Poser", wx.OK)
message_dialog.ShowModal()
message_dialog.Destroy()
file_dialog.Destroy()
def save_last_numpy_image(self, image_file_name):
numpy_image = self.last_output_numpy_image
pil_image = PIL.Image.fromarray(numpy_image, mode='RGBA')
os.makedirs(os.path.dirname(image_file_name), exist_ok=True)
pil_image.save(image_file_name)
if __name__ == "__main__":
device = torch.device('cuda:0')
try:
import tha4.poser.modes.mode_07
poser = tha4.poser.modes.mode_07.create_poser(device)
except RuntimeError as e:
print(e)
sys.exit()
app = wx.App()
main_frame = MainFrame(poser, device)
main_frame.Show(True)
main_frame.timer.Start(16)
app.MainLoop()
================================================
FILE: src/tha4/charmodel/__init__.py
================================================
================================================
FILE: src/tha4/charmodel/character_model.py
================================================
import json
import os.path
import PIL.Image
import torch
from omegaconf import OmegaConf
from tha4.shion.base.image_util import extract_pytorch_image_from_PIL_image
from tha4.poser.modes.mode_14 import create_poser, KEY_FACE_MORPHER, KEY_BODY_MORPHER
class CharacterModel:
def __init__(self,
character_image_file_name: str,
face_morpher_file_name: str,
body_morpher_file_name: str):
self.body_morpher_file_name = body_morpher_file_name
self.face_morpher_file_name = face_morpher_file_name
self.character_image_file_name = character_image_file_name
self.poser = None
self.character_image = None
def get_poser(self, device: torch.device):
if self.poser is not None:
self.poser.to(device)
else:
self.poser = create_poser(
device,
module_file_names={
KEY_FACE_MORPHER: self.face_morpher_file_name,
KEY_BODY_MORPHER: self.body_morpher_file_name
})
return self.poser
def get_character_image(self, device: torch.device):
if self.character_image is None:
pil_image = PIL.Image.open(self.character_image_file_name)
if pil_image.mode != 'RGBA':
raise RuntimeError("Character image is not an RGBA image!")
self.character_image = extract_pytorch_image_from_PIL_image(pil_image)
self.character_image = self.character_image.to(device)
return self.character_image
def save(self, file_name: str):
        dir_name = os.path.dirname(file_name)
        rel_char_image_file_name = os.path.relpath(self.character_image_file_name, dir_name)
        rel_face_morpher_file_name = os.path.relpath(self.face_morpher_file_name, dir_name)
        rel_body_morpher_file_name = os.path.relpath(self.body_morpher_file_name, dir_name)
data = {
"character_image_file_name": rel_char_image_file_name,
"face_morpher_file_name": rel_face_morpher_file_name,
"body_morpher_file_name": rel_body_morpher_file_name,
}
conf = OmegaConf.create(data)
        os.makedirs(dir_name, exist_ok=True)
with open(file_name, "wt") as fout:
fout.write(OmegaConf.to_yaml(conf))
@staticmethod
def load(file_name: str):
conf = OmegaConf.to_container(OmegaConf.load(file_name))
        dir_name = os.path.dirname(file_name)
        character_image_file_name = os.path.join(dir_name, conf["character_image_file_name"])
        face_morpher_file_name = os.path.join(dir_name, conf["face_morpher_file_name"])
        body_morpher_file_name = os.path.join(dir_name, conf["body_morpher_file_name"])
return CharacterModel(
character_image_file_name,
face_morpher_file_name,
body_morpher_file_name)
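# A minimal usage sketch (the YAML path is hypothetical; point it at a
# character model produced by the distiller):
#
#   device = torch.device("cuda:0")
#   model = CharacterModel.load("data/character_models/example/character_model.yaml")
#   poser = model.get_poser(device)
#   image = model.get_character_image(device)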
================================================
FILE: src/tha4/dataset/__init__.py
================================================
================================================
FILE: src/tha4/dataset/image_poses_and_aother_images_dataset.py
================================================
from typing import List, Callable
from torch import Tensor
from torch.utils.data import Dataset
class ImagePosesAndOtherImagesDataset(Dataset):
def __init__(self,
main_image_func: Callable[[], Tensor],
pose_dataset: Dataset,
other_image_funcs: List[Callable[[], Tensor]]):
self.main_image_func = main_image_func
self.other_image_funcs = other_image_funcs
self.pose_dataset = pose_dataset
self.main_image = None
self.other_images = [None for i in range(len(self.other_image_funcs))]
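    # The main image and each "other" image are built lazily on first access
    # and cached, so every __getitem__ call after the first reuses the same
    # tensors.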
def get_main_image(self):
if self.main_image is None:
self.main_image = self.main_image_func()
return self.main_image
def get_other_image(self, image_index: int):
if self.other_images[image_index] is None:
self.other_images[image_index] = self.other_image_funcs[image_index]()
return self.other_images[image_index]
def __len__(self):
return len(self.pose_dataset)
def __getitem__(self, index):
main_image = self.get_main_image()
pose = self.pose_dataset[index][0]
other_images = [self.get_other_image(i) for i in range(len(self.other_image_funcs))]
return [main_image, pose] + other_images
================================================
FILE: src/tha4/distiller/__init__.py
================================================
================================================
FILE: src/tha4/distiller/config_based_training_tasks.py
================================================
import logging
import os
import sys
from typing import Callable, List, Optional
from tha4.pytasuku.workspace import Workspace
from tha4.shion.core.training.distrib.distributed_trainer import DistributedTrainer
from tha4.shion.core.training.distrib.distributed_training_states import DistributedTrainingState
def get_torchrun_executable():
return os.path.dirname(sys.executable) + os.path.sep + "torchrun"
class RdzvConfig:
def __init__(self, id: int, port: int):
self.port = port
self.id = id
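# Shells out to torchrun for a single-node (optionally multi-process) run.
# An illustrative resulting command line (paths and values are examples only):
#   /path/to/venv/bin/torchrun --nnodes=1 --nproc_per_node=2 --standalone \
#       train_script.py --target_checkpoint_examples 10000 --config_file=config.yaml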
def run_standalone_config_based_training_script(
training_script_file_name: str,
config_file_name: str,
num_proc_per_node: int,
target_checkpoint_examples: Optional[int] = None,
rdzv_config: Optional[RdzvConfig] = None):
command = f"{get_torchrun_executable()} " \
f"--nnodes=1 " \
f"--nproc_per_node={num_proc_per_node} "
if rdzv_config is not None:
command += f"--rdzv_endpoint=localhost:{rdzv_config.port} "
command += "--rdzv_backend=c10d "
command += f"--rdzv_id={rdzv_config.id} "
else:
command += "--standalone "
command += f"{training_script_file_name} "
if target_checkpoint_examples is not None:
command += f"--target_checkpoint_examples {target_checkpoint_examples} "
command += f"--config_file={config_file_name} "
logging.info(f"Executing -- {command}")
os.system(command)
def define_standalone_config_based_training_tasks(
workspace: Workspace,
distributed_trainer_func: Callable[[], DistributedTrainer],
training_script_file_name: str,
config_file_name: str,
num_proc_per_node: int,
dependencies: Optional[List[str]] = None,
rdzv_config: Optional[RdzvConfig] = None):
trainer = distributed_trainer_func()
checkpoint_examples = trainer.training_protocol.get_checkpoint_examples()
assert len(checkpoint_examples) >= 1
assert checkpoint_examples[0] > 0
checkpoint_examples = [0] + checkpoint_examples
if dependencies is None:
dependencies = []
module_file_dependencies = dependencies[:]
for module_name in trainer.pretrained_module_file_names:
module_file_dependencies.append(trainer.pretrained_module_file_names[module_name])
def create_train_func(target_checkpoint_examples: int):
return lambda: run_standalone_config_based_training_script(
training_script_file_name,
config_file_name,
num_proc_per_node,
target_checkpoint_examples,
rdzv_config=rdzv_config)
train_tasks = []
for checkpoint_index in range(0, len(checkpoint_examples)):
for module_name in trainer.module_names:
module_file_name = DistributedTrainingState.get_module_file_name(
trainer.get_checkpoint_prefix(checkpoint_index),
module_name)
workspace.create_file_task(
module_file_name,
module_file_dependencies,
                create_train_func(checkpoint_examples[checkpoint_index]))
for module_name in trainer.accumulators:
accumulated_module_file_name = DistributedTrainingState.get_accumulated_module_file_name(
trainer.get_checkpoint_prefix(checkpoint_index),
module_name)
workspace.create_file_task(
accumulated_module_file_name,
module_file_dependencies,
create_train_func(checkpoint_examples[checkpoint_index]))
workspace.create_command_task(
trainer.get_checkpoint_prefix(checkpoint_index) + "/train_standalone",
module_file_dependencies,
create_train_func(checkpoint_examples[checkpoint_index]))
        train_tasks.append(trainer.get_checkpoint_prefix(checkpoint_index) + "/train_standalone")
    workspace.create_command_task(
trainer.prefix + "/train_standalone",
module_file_dependencies,
create_train_func(checkpoint_examples[-1]))
================================================
FILE: src/tha4/distiller/distill_body_morpher.py
================================================
import logging
from tha4.shion.core.training.distrib.distributed_trainer import DistributedTrainer
from tha4.distiller.distiller_config import DistillerConfig
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
parser = DistributedTrainer.get_default_arg_parser()
parser.add_argument('--config_file', type=str)
args = parser.parse_args()
config_file_name = args.config_file
config = DistillerConfig.load(config_file_name)
DistributedTrainer.run_with_args(config.get_body_morpher_trainer, args)
================================================
FILE: src/tha4/distiller/distill_face_morpher.py
================================================
import logging
from tha4.shion.core.training.distrib.distributed_trainer import DistributedTrainer
from tha4.distiller.distiller_config import DistillerConfig
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
parser = DistributedTrainer.get_default_arg_parser()
parser.add_argument('--config_file', type=str)
args = parser.parse_args()
config_file_name = args.config_file
config = DistillerConfig.load(config_file_name)
DistributedTrainer.run_with_args(config.get_face_morpher_trainer, args)
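# --- Usage note (added for this document; not in the original file) ---
# Like distill_body_morpher.py above, this script is meant to be launched
# through torchrun (see config_based_training_tasks.py), e.g.:
#   torchrun --standalone --nnodes=1 --nproc_per_node=1 \
#       src/tha4/distiller/distill_face_morpher.py --config_file=<prefix>/config.yaml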
================================================
FILE: src/tha4/distiller/distiller_config.py
================================================
import os.path
import shutil
import PIL.Image
from dataclasses import dataclass
from typing import Optional
from omegaconf import OmegaConf
from tha4.charmodel.character_model import CharacterModel
from tha4.pytasuku.workspace import Workspace, file_task
from tha4.distiller.config_based_training_tasks import define_standalone_config_based_training_tasks
from tha4.nn.siren.face_morpher.siren_face_morpher_00_trainer import SirenFaceMorpher00TrainerArgs
from tha4.nn.siren.morpher.siren_morpher_03_trainer import SirenMorpher03TrainerArgs, TrainingPhases, TrainingPhase, \
LossWeights, LossTerm
from tha4.shion.base.image_util import pil_image_has_transparency
POSE_DATASET_FILE_NAME = 'data/pose_dataset.pt'
def copy_file(source_file_name: str, dest_file_name: str):
os.makedirs(os.path.dirname(dest_file_name), exist_ok=True)
shutil.copyfile(source_file_name, dest_file_name)
@dataclass
class DistillerConfig:
prefix: str
character_image_file_name: str
face_mask_image_file_name: str
face_morpher_random_seed_0: int = 12771885812175595441
face_morpher_random_seed_1: int = 14367217090963479175
face_morpher_num_training_examples_per_sample_output: Optional[int] = 10_000
face_morpher_batch_size: int = 8
body_morpher_random_seed_0: int = 2892221210020292507
body_morpher_random_seed_1: int = 9998918537095922080
body_morpher_num_training_examples_per_sample_output: Optional[int] = 10_000
body_morpher_batch_size: int = 8
num_cpu_workers: int = 1
num_gpus: int = 1
def check(self):
DistillerConfig.check_prefix(self.prefix)
DistillerConfig.check_character_image_file_name(self.character_image_file_name)
DistillerConfig.check_face_mask_image_file_name(self.face_mask_image_file_name)
DistillerConfig.check_num_cpu_workers(self.num_cpu_workers)
DistillerConfig.check_num_gpus(self.num_gpus)
DistillerConfig.check_random_seed(self.face_morpher_random_seed_0, "face_morpher_random_seed_0")
DistillerConfig.check_random_seed(self.face_morpher_random_seed_1, "face_morpher_random_seed_1")
DistillerConfig.check_batch_size(self.face_morpher_batch_size, "face_morpher_batch_size")
DistillerConfig.check_num_training_examples_per_sample_output(
self.face_morpher_num_training_examples_per_sample_output,
"face_morpher_num_training_examples_per_sample_output")
DistillerConfig.check_random_seed(self.body_morpher_random_seed_0, "body_morpher_random_seed_0")
DistillerConfig.check_random_seed(self.body_morpher_random_seed_1, "body_morpher_random_seed_1")
DistillerConfig.check_batch_size(self.body_morpher_batch_size, "body_morpher_batch_size")
DistillerConfig.check_num_training_examples_per_sample_output(
self.body_morpher_num_training_examples_per_sample_output,
"body_morpher_num_training_examples_per_sample_output")
@staticmethod
def check_prefix(prefix):
        assert os.path.exists(prefix), f"The directory {prefix} does not exist."
        assert os.path.isdir(prefix), "The 'prefix' must be a directory."
@staticmethod
def check_character_image_file_name(file_name):
_, ext = os.path.splitext(file_name)
assert os.path.isfile(file_name), \
f"The specified character image file name, {file_name}, does not point to a file."
assert ext.lower() == ".png", "The character image file name must have extension '.png'."
image = PIL.Image.open(file_name)
assert pil_image_has_transparency(image), "The character image must have an alpha channel."
assert image.width == 512 and image.height == 512, "The character image must be 512x512."
image.close()
@staticmethod
def check_face_mask_image_file_name(file_name):
_, ext = os.path.splitext(file_name)
assert os.path.isfile(file_name), \
f"The specified face mask image file name, {file_name}, does not point to a file."
assert ext.lower() == ".png", "The face mask image file name must have extension '.png'."
image = PIL.Image.open(file_name)
assert image.width == 512 and image.height == 512, "The face mask image must be 512x512."
assert image.mode == "RGB", "The face mask image must be an RGB image."
for x in range(512):
for y in range(512):
r, g, b = image.getpixel((x, y))
assert (r == 0) or (r == 255), "The R channel of the face mask image must be 0 or 255"
assert (g == 0) or (g == 255), "The G channel of the face mask image must be 0 or 255"
assert (b == 0) or (b == 255), "The B channel of the face mask image must be 0 or 255"
image.close()
@staticmethod
def check_batch_size(value, field_name: str):
assert isinstance(value, int), f"The {field_name} must be an integer."
assert value >= 1, f"The {field_name} must be at least 1."
assert value <= 8, f"The {field_name} must be at most 8."
@staticmethod
def check_num_cpu_workers(value):
        assert value >= 1, "The value of 'num_cpu_workers' must be at least 1."
@staticmethod
def check_num_gpus(value):
assert value >= 1, "The value of 'num_gpus' must be at least 1."
@staticmethod
def check_random_seed(value, field_name: str):
assert isinstance(value, int), f"The {field_name} must be an integer."
assert value >= 0 and value <= 0x_ffff_ffff_ffff_ffff, "A random seed must be between 0 and 2**64-1."
@staticmethod
def check_num_training_examples_per_sample_output(value, field_name):
        assert value in [10_000, 100_000, 1_000_000, None], \
            f"The {field_name} must be 10_000, 100_000, 1_000_000, or None."
def save(self, file_name: str):
conf = OmegaConf.structured(self)
os.makedirs(self.prefix, exist_ok=True)
with open(file_name, "wt") as fout:
fout.write(OmegaConf.to_yaml(conf))
def config_yaml_file_name(self):
return f"{self.prefix}/config.yaml"
def create_config_yaml_file(self):
if os.path.exists(self.config_yaml_file_name()):
return
self.save(self.config_yaml_file_name())
@staticmethod
def load(file_name: str) -> 'DistillerConfig':
conf = OmegaConf.to_container(OmegaConf.load(file_name))
args = DistillerConfig(**conf)
args.check()
return args
def face_morpher_prefix(self):
return f"{self.prefix}/face_morpher"
def get_face_morpher_trainer(self, world_size: Optional[int] = None, backend: str = 'gloo'):
if world_size is None:
world_size = self.num_gpus
args = SirenFaceMorpher00TrainerArgs(
character_file_name=self.character_image_file_name,
face_mask_file_name=self.face_mask_image_file_name,
pose_dataset_file_name=POSE_DATASET_FILE_NAME,
total_worker=self.num_cpu_workers,
num_training_examples_per_sample_output=self.face_morpher_num_training_examples_per_sample_output,
total_batch_size=self.face_morpher_batch_size,
training_random_seed=self.face_morpher_random_seed_0,
sample_output_random_seed=self.face_morpher_random_seed_1)
return args.create_trainer(self.face_morpher_prefix(), world_size, backend)
def body_morpher_prefix(self):
return f"{self.prefix}/body_morpher"
def get_body_morpher_trainer(self, world_size: Optional[int] = None, backend: str = 'gloo'):
if world_size is None:
world_size = self.num_gpus
args = SirenMorpher03TrainerArgs(
character_file_name=self.character_image_file_name,
pose_dataset_file_name=POSE_DATASET_FILE_NAME,
total_worker=self.num_cpu_workers,
num_training_examples_per_sample_output=self.body_morpher_num_training_examples_per_sample_output,
training_random_seed=self.body_morpher_random_seed_0,
sample_output_random_seed=self.body_morpher_random_seed_1,
total_batch_size=self.body_morpher_batch_size,
sample_output_batch_size=1,
training_phases=TrainingPhases([
TrainingPhase(
num_examples_upper_bound=200_000,
learning_rate=1e-4,
loss_weights=LossWeights(weights={
LossTerm.full_blended: 0.25,
LossTerm.full_warped: 0.25,
LossTerm.full_grid_change: 0.5,
LossTerm.full_color_change: 2.0,
})),
TrainingPhase(
num_examples_upper_bound=400_000,
learning_rate=3e-5,
loss_weights=LossWeights(weights={
LossTerm.full_blended: 0.25,
LossTerm.full_warped: 0.25,
LossTerm.full_grid_change: 0.5,
LossTerm.full_color_change: 2.0,
})),
TrainingPhase(
num_examples_upper_bound=600_000,
learning_rate=3e-5,
loss_weights=LossWeights(weights={
LossTerm.full_blended: 1.0,
LossTerm.full_warped: 2.5,
LossTerm.full_grid_change: 5.0,
LossTerm.full_color_change: 1.0,
})),
TrainingPhase(
num_examples_upper_bound=800_000,
learning_rate=1e-5,
loss_weights=LossWeights(weights={
LossTerm.full_blended: 1.0,
LossTerm.full_warped: 2.5,
LossTerm.full_grid_change: 5.0,
LossTerm.full_color_change: 1.0,
})),
TrainingPhase(
num_examples_upper_bound=1_300_000,
learning_rate=1e-5,
loss_weights=LossWeights(weights={
LossTerm.full_blended: 10.0,
LossTerm.full_warped: 1.0,
LossTerm.full_grid_change: 1.0,
LossTerm.full_color_change: 1.0,
})),
TrainingPhase(
num_examples_upper_bound=1_500_000,
learning_rate=3e-6,
loss_weights=LossWeights(weights={
LossTerm.full_blended: 10.0,
LossTerm.full_warped: 1.0,
LossTerm.full_grid_change: 1.0,
LossTerm.full_color_change: 1.0,
})),
]))
return args.create_trainer(self.body_morpher_prefix(), world_size, backend)
def character_model_prefix(self):
return f"{self.prefix}/character_model"
def character_model_face_morpher_file_name(self):
return f"{self.character_model_prefix()}/face_morpher.pt"
def character_model_body_morpher_file_name(self):
return f"{self.character_model_prefix()}/body_morpher.pt"
def character_model_character_png_file_name(self):
return f"{self.character_model_prefix()}/character.png"
def character_model_yaml_file_name(self):
return f"{self.character_model_prefix()}/character_model.yaml"
def define_tasks(self, workspace: Workspace):
workspace.create_file_task(self.config_yaml_file_name(), [], self.create_config_yaml_file)
define_standalone_config_based_training_tasks(
workspace,
self.get_face_morpher_trainer,
"src/tha4/distiller/distill_face_morpher.py",
self.config_yaml_file_name(),
num_proc_per_node=self.num_gpus,
dependencies=[
self.config_yaml_file_name(),
])
define_standalone_config_based_training_tasks(
workspace,
self.get_body_morpher_trainer,
"src/tha4/distiller/distill_body_morpher.py",
self.config_yaml_file_name(),
num_proc_per_node=self.num_gpus,
dependencies=[
self.config_yaml_file_name(),
])
@file_task(workspace, self.character_model_character_png_file_name(), [self.character_image_file_name])
def copy_character_image_file_name():
copy_file(self.character_image_file_name, self.character_model_character_png_file_name())
@file_task(workspace, self.character_model_face_morpher_file_name(), [
f"{self.face_morpher_prefix()}/checkpoint/0010/module_module.pt",
])
def copy_face_morpher():
copy_file(
f"{self.face_morpher_prefix()}/checkpoint/0010/module_module.pt",
self.character_model_face_morpher_file_name())
@file_task(workspace, self.character_model_body_morpher_file_name(), [
f"{self.body_morpher_prefix()}/checkpoint/0015/module_module.pt",
])
        def copy_body_morpher():
copy_file(
f"{self.body_morpher_prefix()}/checkpoint/0015/module_module.pt",
self.character_model_body_morpher_file_name())
@file_task(workspace, self.character_model_yaml_file_name(), [])
def create_character_model_yaml_file():
character_model = CharacterModel(
self.character_model_character_png_file_name(),
self.character_model_face_morpher_file_name(),
self.character_model_body_morpher_file_name())
character_model.save(self.character_model_yaml_file_name())
workspace.create_command_task(
f"{self.prefix}/all",
[
f"{self.face_morpher_prefix()}/train_standalone",
f"{self.body_morpher_prefix()}/train_standalone",
self.character_model_character_png_file_name(),
self.character_model_face_morpher_file_name(),
self.character_model_body_morpher_file_name(),
self.character_model_yaml_file_name(),
])
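# --- Illustrative sketch (added for this document; not in the original file) ---
# Round-tripping a configuration through OmegaConf. The paths are
# hypothetical; DistillerConfig.load() runs check(), so loading only succeeds
# when the prefix directory and both 512x512 PNG files actually exist.
if __name__ == "__main__":
    config = DistillerConfig(
        prefix="data/my_character",
        character_image_file_name="data/my_character/character.png",
        face_mask_image_file_name="data/my_character/face_mask.png")
    config.save(config.config_yaml_file_name())
    reloaded = DistillerConfig.load(config.config_yaml_file_name())
    assert reloaded.face_morpher_batch_size == config.face_morpher_batch_size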
================================================
FILE: src/tha4/distiller/ui/__init__.py
================================================
================================================
FILE: src/tha4/distiller/ui/distiller_config_state.py
================================================
import os.path
from contextlib import contextmanager
from pathlib import PurePath, Path
from typing import Callable, Any, Optional
from tha4.distiller.distiller_config import DistillerConfig
class DistillerConfigState:
def __init__(self):
self.config = DistillerConfig(prefix="", character_image_file_name="", face_mask_image_file_name="")
self.last_saved_timestamp = None
self.dirty = False
def load(self, file_name):
self.config = DistillerConfig.load(file_name)
if os.path.exists(self.config.config_yaml_file_name()):
self.last_saved_timestamp = os.path.getmtime(self.config.config_yaml_file_name())
else:
self.last_saved_timestamp = None
self.dirty = False
def need_to_check_overwrite(self):
        if not os.path.exists(self.config.config_yaml_file_name()):
            return False
        if self.last_saved_timestamp is None:
            return True
if self.last_saved_timestamp < os.path.getmtime(self.config.config_yaml_file_name()):
return True
return False
def save(self):
self.config.save(self.config.config_yaml_file_name())
self.dirty = False
self.last_saved_timestamp = os.path.getmtime(self.config.config_yaml_file_name())
@contextmanager
def updating_value(self, value_func: Callable[[], Any]):
old_value = value_func()
yield
new_value = value_func()
if new_value != old_value:
self.dirty = True
def set_prefix(self, new_value):
with self.updating_value(lambda: self.config.prefix):
new_relative_path = self.get_relative_path_to_cwd(
new_value,
"The prefix directory must be a subdirectory of the talking-head-anime-4-demo's source code directory.")
DistillerConfig.check_prefix(new_relative_path)
self.config.prefix = new_relative_path
def set_character_image_file_name(self, new_value):
with self.updating_value(lambda: self.config.character_image_file_name):
new_relative_path = self.get_relative_path_to_cwd(
new_value,
"The character image file must be under talking-head-anime-4-demo's source code directory.")
DistillerConfig.check_character_image_file_name(new_relative_path)
self.config.character_image_file_name = new_relative_path
def set_face_mask_image_file_name(self, new_value):
with self.updating_value(lambda: self.config.face_mask_image_file_name):
new_relative_path = self.get_relative_path_to_cwd(
new_value,
"The face mask image file must be under talking-head-anime-4-demo's source code directory.")
DistillerConfig.check_face_mask_image_file_name(new_relative_path)
self.config.face_mask_image_file_name = new_relative_path
def set_num_cpu_workers(self, new_value: int):
with self.updating_value(lambda: self.config.num_cpu_workers):
DistillerConfig.check_num_cpu_workers(new_value)
self.config.num_cpu_workers = new_value
def set_num_gpus(self, new_value: int):
with self.updating_value(lambda: self.config.num_gpus):
            DistillerConfig.check_num_gpus(new_value)
self.config.num_gpus = new_value
def set_face_morpher_random_seed_0(self, new_value: int):
with self.updating_value(lambda: self.config.face_morpher_random_seed_0):
DistillerConfig.check_random_seed(new_value, "face_morpher_random_seed_0")
self.config.face_morpher_random_seed_0 = new_value
def set_face_morpher_random_seed_1(self, new_value: int):
with self.updating_value(lambda: self.config.face_morpher_random_seed_1):
DistillerConfig.check_random_seed(new_value, "face_morpher_random_seed_1")
self.config.face_morpher_random_seed_1 = new_value
def set_face_morpher_num_training_examples_per_sample_output(self, new_value: Optional[int]):
with self.updating_value(lambda: self.config.face_morpher_num_training_examples_per_sample_output):
DistillerConfig.check_num_training_examples_per_sample_output(
new_value, "face_morpher_num_training_examples_per_sample_output")
self.config.face_morpher_num_training_examples_per_sample_output = new_value
def set_face_morpher_batch_size(self, new_value: int):
with self.updating_value(lambda: self.config.face_morpher_batch_size):
DistillerConfig.check_batch_size(new_value, "face_morpher_batch_size")
self.config.face_morpher_batch_size = new_value
def set_body_morpher_random_seed_0(self, new_value: int):
with self.updating_value(lambda: self.config.body_morpher_random_seed_0):
DistillerConfig.check_random_seed(new_value, "body_morpher_random_seed_0")
self.config.body_morpher_random_seed_0 = new_value
def set_body_morpher_random_seed_1(self, new_value: int):
with self.updating_value(lambda: self.config.body_morpher_random_seed_1):
DistillerConfig.check_random_seed(new_value, "body_morpher_random_seed_1")
self.config.body_morpher_random_seed_1 = new_value
def set_body_morpher_num_training_examples_per_sample_output(self, new_value: Optional[int]):
with self.updating_value(lambda: self.config.body_morpher_num_training_examples_per_sample_output):
DistillerConfig.check_num_training_examples_per_sample_output(
new_value, "body_morpher_num_training_examples_per_sample_output")
self.config.body_morpher_num_training_examples_per_sample_output = new_value
def set_body_morpher_batch_size(self, new_value: int):
with self.updating_value(lambda: self.config.body_morpher_batch_size):
DistillerConfig.check_batch_size(new_value, "body_morpher_batch_size")
self.config.body_morpher_batch_size = new_value
def get_relative_path_to_cwd(self, file_name: str, message: str):
cwd = os.getcwd()
        assert os.path.commonpath([cwd, file_name]) == cwd, message
cwd_path = Path(cwd).as_posix()
new_path = Path(file_name).as_posix()
new_relative_path = os.path.relpath(str(new_path), cwd_path)
new_relative_path = str(Path(new_relative_path).as_posix())
return new_relative_path
def can_show_character_image(self):
return os.path.isfile(self.config.character_image_file_name)
def can_show_face_mask_image(self):
return os.path.isfile(self.config.face_mask_image_file_name)
def can_show_mask_on_face_image(self):
return self.can_show_character_image() and self.can_show_face_mask_image()
def can_save(self):
return os.path.isdir(self.config.prefix) \
and os.path.isfile(self.config.character_image_file_name) \
and os.path.isfile(self.config.face_mask_image_file_name)
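# --- Illustrative sketch (added for this document; not in the original file) ---
# The updating_value() context manager marks the state dirty only when the
# wrapped mutation actually changes the observed value:
if __name__ == "__main__":
    state = DistillerConfigState()
    state.set_num_cpu_workers(4)
    assert state.dirty                 # 1 -> 4: the config changed
    state.dirty = False
    state.set_num_cpu_workers(4)
    assert not state.dirty             # 4 -> 4: no change, stays clean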
================================================
FILE: src/tha4/distiller/ui/distiller_ui_main_frame.py
================================================
import multiprocessing
import random
from contextlib import contextmanager
from typing import Callable
import PIL.Image
import torch
import wx
import wx.html
import wx.lib.intctrl
from tha4.distiller.ui.distiller_config_state import DistillerConfigState
from tha4.image_util import convert_output_image_from_torch_to_numpy
from tha4.shion.base.image_util import extract_pytorch_image_from_PIL_image
def wx_bind_event(widget, evt):
def f(handler):
widget.Bind(evt, handler)
return handler
return f
class DistillerUiMainFrame(wx.Frame):
PARAM_NAME_STATIC_TEXT_MIN_WIDTH = 400
NUM_TRAINING_EXAMPLES_PER_SAMPLE_OUTPUT_CHOICES = [
"10_000", "100_000", "1_000_000", "Do not generate sample outputs"]
def __init__(self):
super().__init__(None, wx.ID_ANY, "Distiller UI")
self.init_ui()
self.init_menus()
self.init_bitmaps()
self.Bind(wx.EVT_CLOSE, self.on_close)
self.state = DistillerConfigState()
self.update_ui()
self.config_file_to_run = None
def init_ui(self):
main_sizer = wx.BoxSizer(wx.HORIZONTAL)
self.SetSizer(main_sizer)
self.SetAutoLayout(1)
left_panel = self.init_left_panel(self)
main_sizer.Add(left_panel, 0, wx.FIXED_MINSIZE)
middle_panel = self.init_middle_panel(self)
main_sizer.Add(middle_panel, 0, wx.EXPAND)
right_panel = self.init_right_panel(self)
main_sizer.Add(right_panel, 1, wx.EXPAND)
main_sizer.Fit(self)
def init_menus(self):
self.file_menu = wx.Menu()
self.new_menu_id = wx.Window.NewControlId()
self.file_menu.Append(
self.new_menu_id, item="&New\tCTRL+N", helpString="Create a new distiller configuration.")
self.Bind(wx.EVT_MENU, self.on_new, id=self.new_menu_id)
self.open_menu_id = wx.Window.NewControlId()
self.file_menu.Append(
self.open_menu_id, item="&Open\tCTRL+O", helpString="Open a distiller confuguration.")
self.Bind(wx.EVT_MENU, self.on_open, id=self.open_menu_id)
self.save_menu_id = wx.Window.NewControlId()
self.save_menu_item = wx.MenuItem(
self.file_menu, id=self.save_menu_id, text="&Save\tCTRL+S",
helpString="Save the current distiller configuration. Error message will be shown it it is not well formed.")
self.Bind(wx.EVT_MENU, self.on_save, id=self.save_menu_id)
self.file_menu.Append(self.save_menu_item)
self.file_menu.AppendSeparator()
self.exit_menu_id = wx.ID_EXIT
self.file_menu.Append(
self.exit_menu_id, item="E&xit\tCTRL+Q", helpString="Exit the application.")
self.Bind(wx.EVT_MENU, self.on_close, id=self.exit_menu_id)
self.menu_bar = wx.MenuBar()
self.menu_bar.Append(self.file_menu, "&File")
self.SetMenuBar(self.menu_bar)
def init_bitmaps(self):
self.face_image_bitmap = wx.Bitmap(128, 128)
self.face_image_pytorch = None
self.face_mask_image_bitmap = wx.Bitmap(128, 128)
self.face_mask_image_pytorch = None
self.mask_on_face_image_bitmap = wx.Bitmap(128, 128)
self.draw_nothing_yet_string_to_bitmap(self.face_image_bitmap, 128, 128)
self.draw_nothing_yet_string_to_bitmap(self.face_mask_image_bitmap, 128, 128)
self.draw_nothing_yet_string_to_bitmap(self.mask_on_face_image_bitmap, 128, 128)
@contextmanager
def create_panel(self, parent, sizer, *args, **kwargs):
panel = wx.Panel(parent, *args, **kwargs)
panel.SetSizer(sizer)
panel.SetAutoLayout(1)
try:
yield panel, sizer
finally:
sizer.Fit(panel)
def init_left_panel(self, parent):
with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, sizer):
self.face_image_panel = wx.Panel(panel, size=(128, 128), style=wx.SIMPLE_BORDER)
self.face_image_panel.Bind(wx.EVT_PAINT, self.on_face_image_panel_paint)
sizer.Add(self.face_image_panel, 0, wx.EXPAND)
static_text = wx.StaticText(panel, label="Face", style=wx.ALIGN_CENTER)
sizer.Add(static_text, 0, wx.EXPAND)
self.face_mask_image_panel = wx.Panel(panel, size=(128, 128), style=wx.SIMPLE_BORDER)
self.face_mask_image_panel.Bind(wx.EVT_PAINT, self.on_face_mask_image_panel_paint)
sizer.Add(self.face_mask_image_panel, 0, wx.EXPAND)
static_text = wx.StaticText(panel, label="Face mask", style=wx.ALIGN_CENTER)
sizer.Add(static_text, 0, wx.EXPAND)
self.mask_on_face_image_panel = wx.Panel(panel, size=(128, 128), style=wx.SIMPLE_BORDER)
self.mask_on_face_image_panel.Bind(wx.EVT_PAINT, self.on_mask_on_face_image_panel_paint)
sizer.Add(self.mask_on_face_image_panel, 0, wx.EXPAND)
static_text = wx.StaticText(panel, label="Mask upon face", style=wx.ALIGN_CENTER)
sizer.Add(static_text, 0, wx.EXPAND)
return panel
def on_erase_background(self, event):
pass
def on_face_image_panel_paint(self, event):
wx.BufferedPaintDC(self.face_image_panel, self.face_image_bitmap)
def on_face_mask_image_panel_paint(self, event):
wx.BufferedPaintDC(self.face_mask_image_panel, self.face_mask_image_bitmap)
def on_mask_on_face_image_panel_paint(self, event):
wx.BufferedPaintDC(self.mask_on_face_image_panel, self.mask_on_face_image_bitmap)
def init_middle_panel(self, parent):
with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, sizer):
sizer.Add(self.init_prefix_panel(panel), 0, wx.EXPAND)
sizer.Add(self.init_character_image_file_name_panel(panel), 0, wx.EXPAND)
sizer.Add(self.init_face_mask_image_file_name_panel(panel), 0, wx.EXPAND)
sizer.Add(self.init_num_cpu_workers_panel(panel), 0, wx.EXPAND)
sizer.Add(self.init_num_gpus_panel(panel), 0, wx.EXPAND)
sizer.Add(self.init_face_morpher_random_seed_0_panel(panel), 0, wx.EXPAND)
sizer.Add(self.init_face_morpher_random_seed_1_panel(panel), 0, wx.EXPAND)
sizer.Add(self.init_face_morpher_batch_size_panel(panel), 0, wx.EXPAND)
sizer.Add(self.init_body_morpher_random_seed_0_panel(panel), 0, wx.EXPAND)
sizer.Add(self.init_body_morpher_random_seed_1_panel(panel), 0, wx.EXPAND)
sizer.Add(self.init_body_morpher_batch_size_panel(panel), 0, wx.EXPAND)
sizer.Add(self.init_num_training_examples_per_sample_output_panel(panel), 0, wx.EXPAND)
self.run_button = wx.Button(panel, label="RUN")
self.run_button.SetMinSize((-1, 64))
self.run_button.Bind(wx.EVT_BUTTON, self.on_run)
sizer.Add(self.run_button, 1, wx.EXPAND)
return panel
def init_prefix_panel(self, parent):
with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer):
prefix_param_name_panel = self.create_param_name_panel_with_help_button(
panel,
"prefix (i.e. project directory)",
self.create_help_button_func("distiller-ui-doc/params/prefix.html"))
panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND)
with self.create_panel(panel, wx.BoxSizer(wx.HORIZONTAL), style=wx.BORDER_NONE) \
as (prefix_panel, prefix_sizer):
self.prefix_text_ctrl = wx.TextCtrl(prefix_panel, value="")
self.prefix_text_ctrl.SetEditable(False)
prefix_sizer.Add(self.prefix_text_ctrl, 1, wx.EXPAND)
self.prefix_change_button = wx.Button(prefix_panel, label="Change...")
self.prefix_change_button.Bind(wx.EVT_BUTTON, self.on_prefix_change_button)
prefix_sizer.Add(self.prefix_change_button, 0, wx.EXPAND)
panel_sizer.Add(prefix_panel, 1, wx.EXPAND)
return panel
def on_prefix_change_button(self, event):
dir_dialog = wx.DirDialog(self, "Choose a directory.", style=wx.DD_DEFAULT_STYLE | wx.DD_NEW_DIR_BUTTON)
if dir_dialog.ShowModal() != wx.ID_OK:
return
prefix_value = dir_dialog.GetPath()
try:
self.state.set_prefix(prefix_value)
self.update_ui()
except Exception as e:
message_dialog = wx.MessageDialog(self, str(e), "Error", wx.OK | wx.ICON_ERROR)
message_dialog.ShowModal()
def init_character_image_file_name_panel(self, parent):
with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer):
prefix_param_name_panel = self.create_param_name_panel_with_help_button(
panel,
"character_image_file_name",
self.create_help_button_func("distiller-ui-doc/params/character_image_file_name.html"))
panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND)
with self.create_panel(panel, wx.BoxSizer(wx.HORIZONTAL), style=wx.BORDER_NONE) as (sub_panel, sub_sizer):
self.character_image_file_name_text_ctrl = wx.TextCtrl(sub_panel, value="")
self.character_image_file_name_text_ctrl.SetEditable(False)
sub_sizer.Add(self.character_image_file_name_text_ctrl, 1, wx.EXPAND)
self.character_image_change_button = wx.Button(sub_panel, label="Change...")
self.character_image_change_button.Bind(wx.EVT_BUTTON, self.on_character_image_change_button)
sub_sizer.Add(self.character_image_change_button, 0, wx.EXPAND)
panel_sizer.Add(sub_panel, 1, wx.EXPAND)
return panel
def on_character_image_change_button(self, event):
file_dialog = wx.FileDialog(self, "Choose a PNG file", wildcard="*.png", style=wx.FD_OPEN)
if file_dialog.ShowModal() != wx.ID_OK:
return
file_name = file_dialog.GetPath()
try:
self.state.set_character_image_file_name(file_name)
self.update_face_image_bitmap(file_name)
self.update_ui()
except Exception as e:
message_dialog = wx.MessageDialog(self, str(e), "Error", wx.OK | wx.ICON_ERROR)
message_dialog.ShowModal()
def update_face_image_bitmap(self, new_file_name: str):
pil_image = PIL.Image.open(new_file_name)
subimage = pil_image.crop((256 - 64, 80, 256 + 64, 208))
self.face_image_bitmap = wx.Bitmap.FromBufferRGBA(128, 128, subimage.convert("RGBA").tobytes())
self.face_image_pytorch = extract_pytorch_image_from_PIL_image(subimage).to(torch.float)
self.update_mask_on_face_image_bitmap()
def init_face_mask_image_file_name_panel(self, parent):
with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer):
prefix_param_name_panel = self.create_param_name_panel_with_help_button(
panel,
"face_mask_image_file_name",
self.create_help_button_func("distiller-ui-doc/params/face_mask_image_file_name.html"))
panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND)
with self.create_panel(panel, wx.BoxSizer(wx.HORIZONTAL), style=wx.BORDER_NONE) as (sub_panel, sub_sizer):
self.face_mask_image_file_name_text_ctrl = wx.TextCtrl(sub_panel, value="")
self.face_mask_image_file_name_text_ctrl.SetEditable(False)
sub_sizer.Add(self.face_mask_image_file_name_text_ctrl, 1, wx.EXPAND)
self.face_mask_image_file_name_change_button = wx.Button(sub_panel, label="Change...")
self.face_mask_image_file_name_change_button.Bind(wx.EVT_BUTTON, self.on_face_mask_image_change_button)
sub_sizer.Add(self.face_mask_image_file_name_change_button, 0, wx.EXPAND)
panel_sizer.Add(sub_panel, 1, wx.EXPAND)
return panel
def on_face_mask_image_change_button(self, event):
file_dialog = wx.FileDialog(self, "Choose a PNG file", wildcard="*.png", style=wx.FD_OPEN)
if file_dialog.ShowModal() != wx.ID_OK:
return
file_name = file_dialog.GetPath()
try:
self.state.set_face_mask_image_file_name(file_name)
self.update_face_mask_image_bitmap(file_name)
self.update_ui()
except Exception as e:
message_dialog = wx.MessageDialog(self, str(e), "Error", wx.OK | wx.ICON_ERROR)
message_dialog.ShowModal()
def update_face_mask_image_bitmap(self, new_file_name):
pil_image = PIL.Image.open(new_file_name)
subimage = pil_image.crop((256 - 64, 80, 256 + 64, 208))
self.face_mask_image_bitmap = wx.Bitmap.FromBufferRGBA(128, 128, subimage.convert("RGBA").tobytes())
self.face_mask_image_pytorch = extract_pytorch_image_from_PIL_image(subimage).to(torch.float)
self.face_mask_image_pytorch = self.face_mask_image_pytorch[0:1, :, :]
self.update_mask_on_face_image_bitmap()
def update_mask_on_face_image_bitmap(self):
if self.face_image_pytorch is None:
return
if self.face_mask_image_pytorch is None:
return
mask_on_face_image = (0.5 * self.face_image_pytorch) + (0.5 * self.face_mask_image_pytorch)
numpy_image = convert_output_image_from_torch_to_numpy(mask_on_face_image)
        wx_image = wx.ImageFromBuffer(
            numpy_image.shape[1],  # width
            numpy_image.shape[0],  # height
            numpy_image[:, :, 0:3].tobytes(),
            numpy_image[:, :, 3].tobytes())
self.mask_on_face_image_bitmap = wx_image.ConvertToBitmap()
def init_num_cpu_workers_panel(self, parent):
with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer):
prefix_param_name_panel = self.create_param_name_panel_with_help_button(
panel,
"num_cpu_workers",
self.create_help_button_func("distiller-ui-doc/params/num_cpu_workers.html"))
panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND)
num_cpus = multiprocessing.cpu_count()
self.num_cpu_workers_spin_ctrl = wx.SpinCtrl(panel, initial=1, min=1, max=num_cpus)
@wx_bind_event(self.num_cpu_workers_spin_ctrl, wx.EVT_SPINCTRL)
def on_num_cpu_workers_spin_ctrl(event):
self.state.set_num_cpu_workers(self.num_cpu_workers_spin_ctrl.GetValue())
self.Refresh()
panel_sizer.Add(self.num_cpu_workers_spin_ctrl, 1, wx.EXPAND)
return panel
def init_num_gpus_panel(self, parent):
with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer):
prefix_param_name_panel = self.create_param_name_panel_with_help_button(
panel,
"num_gpus",
self.create_help_button_func("distiller-ui-doc/params/num_gpus.html"))
panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND)
num_gpus = torch.cuda.device_count()
self.num_gpus_spin_ctrl = wx.SpinCtrl(panel, initial=1, min=1, max=max(1, num_gpus))
@wx_bind_event(self.num_gpus_spin_ctrl, wx.EVT_SPINCTRL)
            def on_num_gpus_spin_ctrl(event):
self.state.set_num_gpus(self.num_gpus_spin_ctrl.GetValue())
self.Refresh()
panel_sizer.Add(self.num_gpus_spin_ctrl, 1, wx.EXPAND)
return panel
def init_face_morpher_random_seed_0_panel(self, parent):
with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer):
prefix_param_name_panel = self.create_param_name_panel_with_help_button(
panel,
"face_morpher_random_seed_0",
self.create_help_button_func("distiller-ui-doc/params/face_morpher_random_seed_0.html"))
panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND)
with self.create_panel(panel, wx.BoxSizer(wx.HORIZONTAL), style=wx.BORDER_NONE) as (sub_panel, sub_sizer):
initial_value = random.randint(0, 2 ** 64 - 1)
self.face_morpher_random_seed_0_int_ctrl = wx.lib.intctrl.IntCtrl(
sub_panel, value=initial_value, min=0, max=0x_ffff_ffff_ffff_ffff)
@wx_bind_event(self.face_morpher_random_seed_0_int_ctrl, wx.EVT_TEXT)
def on_face_morpher_random_seed_0_int_ctrl_text(event):
self.state.set_face_morpher_random_seed_0(self.face_morpher_random_seed_0_int_ctrl.GetValue())
sub_sizer.Add(self.face_morpher_random_seed_0_int_ctrl, 1, wx.EXPAND)
self.face_morpher_random_seed_0_randomize_button = wx.Button(sub_panel, label="Randomize")
@wx_bind_event(self.face_morpher_random_seed_0_randomize_button, wx.EVT_BUTTON)
def on_face_morpher_random_seed_0_randomize_button(event):
new_value = random.randint(0, 0x_ffff_ffff_ffff_ffff)
self.face_morpher_random_seed_0_int_ctrl.SetValue(new_value)
self.state.set_face_morpher_random_seed_0(new_value)
sub_sizer.Add(self.face_morpher_random_seed_0_randomize_button, 0, wx.EXPAND)
panel_sizer.Add(sub_panel, 1, wx.EXPAND)
return panel
def init_face_morpher_random_seed_1_panel(self, parent):
with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer):
prefix_param_name_panel = self.create_param_name_panel_with_help_button(
panel,
"face_morpher_random_seed_1",
self.create_help_button_func("distiller-ui-doc/params/face_morpher_random_seed_1.html"))
panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND)
with self.create_panel(panel, wx.BoxSizer(wx.HORIZONTAL), style=wx.BORDER_NONE) as (sub_panel, sub_sizer):
initial_value = random.randint(0, 2 ** 64 - 1)
self.face_morpher_random_seed_1_int_ctrl = wx.lib.intctrl.IntCtrl(
sub_panel, value=initial_value, min=0, max=0x_ffff_ffff_ffff_ffff)
@wx_bind_event(self.face_morpher_random_seed_1_int_ctrl, wx.EVT_TEXT)
def on_face_morpher_random_seed_1_int_ctrl_text(event):
self.state.set_face_morpher_random_seed_1(self.face_morpher_random_seed_1_int_ctrl.GetValue())
sub_sizer.Add(self.face_morpher_random_seed_1_int_ctrl, 1, wx.EXPAND)
self.face_morpher_random_seed_1_randomize_button = wx.Button(sub_panel, label="Randomize")
@wx_bind_event(self.face_morpher_random_seed_1_randomize_button, wx.EVT_BUTTON)
def on_face_morpher_random_seed_1_randomize_button(event):
new_value = random.randint(0, 0x_ffff_ffff_ffff_ffff)
self.face_morpher_random_seed_1_int_ctrl.SetValue(new_value)
self.state.set_face_morpher_random_seed_1(new_value)
sub_sizer.Add(self.face_morpher_random_seed_1_randomize_button, 0, wx.EXPAND)
panel_sizer.Add(sub_panel, 1, wx.EXPAND)
return panel
def init_face_morpher_batch_size_panel(self, parent):
with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer):
prefix_param_name_panel = self.create_param_name_panel_with_help_button(
panel,
"face_morpher_batch_size",
self.create_help_button_func("distiller-ui-doc/params/face_morpher_batch_size.html"))
panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND)
self.face_morpher_batch_size_spin_ctrl = wx.SpinCtrl(panel, initial=8, min=1, max=8)
@wx_bind_event(self.face_morpher_batch_size_spin_ctrl, wx.EVT_SPINCTRL)
def on_face_morpher_batch_size_spin_ctrl(event):
self.state.set_face_morpher_batch_size(self.face_morpher_batch_size_spin_ctrl.GetValue())
panel_sizer.Add(self.face_morpher_batch_size_spin_ctrl, 1, wx.EXPAND)
return panel
def init_body_morpher_random_seed_0_panel(self, parent):
with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer):
prefix_param_name_panel = self.create_param_name_panel_with_help_button(
panel,
"body_morpher_random_seed_0",
self.create_help_button_func("distiller-ui-doc/params/body_morpher_random_seed_0.html"))
panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND)
with self.create_panel(panel, wx.BoxSizer(wx.HORIZONTAL), style=wx.BORDER_NONE) as (sub_panel, sub_sizer):
initial_value = random.randint(0, 2 ** 64 - 1)
self.body_morpher_random_seed_0_int_ctrl = wx.lib.intctrl.IntCtrl(
sub_panel, value=initial_value, min=0, max=0x_ffff_ffff_ffff_ffff)
@wx_bind_event(self.body_morpher_random_seed_0_int_ctrl, wx.EVT_TEXT)
def on_body_morpher_random_seed_0_int_ctrl_text(event):
self.state.set_body_morpher_random_seed_0(self.body_morpher_random_seed_0_int_ctrl.GetValue())
sub_sizer.Add(self.body_morpher_random_seed_0_int_ctrl, 1, wx.EXPAND)
self.body_morpher_random_seed_0_randomize_button = wx.Button(sub_panel, label="Randomize")
@wx_bind_event(self.body_morpher_random_seed_0_randomize_button, wx.EVT_BUTTON)
def on_body_morpher_random_seed_0_randomize_button(event):
new_value = random.randint(0, 0x_ffff_ffff_ffff_ffff)
self.body_morpher_random_seed_0_int_ctrl.SetValue(new_value)
self.state.set_body_morpher_random_seed_0(new_value)
sub_sizer.Add(self.body_morpher_random_seed_0_randomize_button, 0, wx.EXPAND)
panel_sizer.Add(sub_panel, 1, wx.EXPAND)
return panel
def init_body_morpher_random_seed_1_panel(self, parent):
with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer):
prefix_param_name_panel = self.create_param_name_panel_with_help_button(
panel,
"body_morpher_random_seed_1",
self.create_help_button_func("distiller-ui-doc/params/body_morpher_random_seed_1.html"))
panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND)
with self.create_panel(panel, wx.BoxSizer(wx.HORIZONTAL), style=wx.BORDER_NONE) as (sub_panel, sub_sizer):
initial_value = random.randint(0, 2 ** 64 - 1)
self.body_morpher_random_seed_1_int_ctrl = wx.lib.intctrl.IntCtrl(
sub_panel, value=initial_value, min=0, max=0x_ffff_ffff_ffff_ffff)
@wx_bind_event(self.body_morpher_random_seed_1_int_ctrl, wx.EVT_TEXT)
def on_body_morpher_random_seed_1_int_ctrl_text(event):
self.state.set_body_morpher_random_seed_1(self.body_morpher_random_seed_1_int_ctrl.GetValue())
sub_sizer.Add(self.body_morpher_random_seed_1_int_ctrl, 1, wx.EXPAND)
self.body_morpher_random_seed_1_randomize_button = wx.Button(sub_panel, label="Randomize")
@wx_bind_event(self.body_morpher_random_seed_1_randomize_button, wx.EVT_BUTTON)
def on_body_morpher_random_seed_1_randomize_button(event):
new_value = random.randint(0, 0x_ffff_ffff_ffff_ffff)
self.body_morpher_random_seed_1_int_ctrl.SetValue(new_value)
self.state.set_body_morpher_random_seed_1(new_value)
sub_sizer.Add(self.body_morpher_random_seed_1_randomize_button, 0, wx.EXPAND)
panel_sizer.Add(sub_panel, 1, wx.EXPAND)
return panel
def init_body_morpher_batch_size_panel(self, parent):
with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer):
prefix_param_name_panel = self.create_param_name_panel_with_help_button(
panel,
"body_morpher_batch_size",
self.create_help_button_func("distiller-ui-doc/params/body_morpher_batch_size.html"))
panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND)
self.body_morpher_batch_size_spin_ctrl = wx.SpinCtrl(panel, initial=8, min=1, max=8)
@wx_bind_event(self.body_morpher_batch_size_spin_ctrl, wx.EVT_SPINCTRL)
def on_body_morpher_batch_size_spin_ctrl(event):
self.state.set_body_morpher_batch_size(self.body_morpher_batch_size_spin_ctrl.GetValue())
panel_sizer.Add(self.body_morpher_batch_size_spin_ctrl, 1, wx.EXPAND)
return panel
def init_num_training_examples_per_sample_output_panel(self, parent):
with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, panel_sizer):
prefix_param_name_panel = self.create_param_name_panel_with_help_button(
panel,
"num_training_examples_per_sample_output",
self.create_help_button_func("distiller-ui-doc/params/num_training_examples_per_sample_output.html"))
panel_sizer.Add(prefix_param_name_panel, 1, wx.EXPAND)
self.num_training_examples_per_sample_output_combobox = \
wx.ComboBox(panel,
value="10_000",
choices=DistillerUiMainFrame.NUM_TRAINING_EXAMPLES_PER_SAMPLE_OUTPUT_CHOICES)
@wx_bind_event(self.num_training_examples_per_sample_output_combobox, wx.EVT_COMBOBOX)
def on_num_training_examples_per_sample_output_combobox(event):
index = self.num_training_examples_per_sample_output_combobox.GetSelection()
if index == 3:
self.state.set_face_morpher_num_training_examples_per_sample_output(None)
self.state.set_body_morpher_num_training_examples_per_sample_output(None)
else:
selected = DistillerUiMainFrame.NUM_TRAINING_EXAMPLES_PER_SAMPLE_OUTPUT_CHOICES[index]
new_value = int(selected)
self.state.set_face_morpher_num_training_examples_per_sample_output(new_value)
self.state.set_body_morpher_num_training_examples_per_sample_output(new_value)
panel_sizer.Add(self.num_training_examples_per_sample_output_combobox, 1, wx.EXPAND)
return panel
def on_close(self, event):
if self.state.dirty:
confirmation_dialog = wx.MessageDialog(
parent=self,
message=f"You have not saved your work. Do you want to exit anyway?",
caption="Confirmation",
style=wx.YES_NO | wx.ICON_QUESTION)
result = confirmation_dialog.ShowModal()
if result == wx.ID_NO:
return
self.Destroy()
def create_help_button_func(self, html_file_name: str):
def init_help_button_func(parent):
button = wx.Button(parent, label="Help")
@wx_bind_event(button, wx.EVT_BUTTON)
def on_prefix_button(event):
self.html_window.LoadPage(html_file_name)
self.Refresh()
return button
return init_help_button_func
def create_param_name_panel_with_help_button(
self, parent, param_name: str, help_button_func: Callable[[wx.Window], wx.Button]):
with self.create_panel(parent, wx.BoxSizer(wx.HORIZONTAL), style=wx.NO_BORDER) \
as (panel, sizer):
title_text_panel = self.create_vertically_centered_text_panel(
panel, param_name, DistillerUiMainFrame.PARAM_NAME_STATIC_TEXT_MIN_WIDTH)
sizer.Add(title_text_panel, 1, wx.EXPAND)
help_button = help_button_func(panel)
sizer.Add(help_button, 0, wx.EXPAND)
return panel
def create_vertically_centered_text_panel(self, parent, text: str, min_width: int):
with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.NO_BORDER) as (panel, sizer):
sizer.AddStretchSpacer(1)
text = wx.StaticText(
panel,
label=text,
style=wx.ALIGN_CENTER)
text.SetMinSize((min_width, -1))
sizer.Add(text, 0, wx.EXPAND)
sizer.AddStretchSpacer(1)
return panel
def init_right_panel(self, parent):
with self.create_panel(parent, wx.BoxSizer(wx.VERTICAL), style=wx.BORDER_SIMPLE) as (panel, sizer):
self.html_window = wx.html.HtmlWindow(panel)
self.html_window.SetMinSize((600, 600))
self.html_window.SetFonts("Times New Roman", "Courier New", sizes=[10, 12, 14, 16, 18, 20, 24])
self.html_window.LoadPage("distiller-ui-doc/index.html")
sizer.Add(self.html_window, 1, wx.EXPAND)
go_to_main_documentation_button = wx.Button(panel, label="Go to Main Documentation")
sizer.Add(go_to_main_documentation_button, 0, wx.EXPAND)
@wx_bind_event(go_to_main_documentation_button, wx.EVT_BUTTON)
def on_go_to_main_documentation_button(event):
self.html_window.LoadPage("distiller-ui-doc/index.html")
self.Refresh()
return panel
def populate_distiller_config(self):
self.state.config.prefix = self.prefix_text_ctrl.GetValue()
self.state.config.character_image_file_name = self.character_image_file_name_text_ctrl.GetValue()
self.state.config.face_mask_image_file_name = self.face_mask_image_file_name_text_ctrl.GetValue()
self.state.config.num_cpu_workers = self.num_cpu_workers_spin_ctrl.GetValue()
self.state.config.num_gpus = self.num_gpus_spin_ctrl.GetValue()
self.state.config.face_morpher_random_seed_0 = self.face_morpher_random_seed_0_int_ctrl.GetValue()
self.state.config.face_morpher_random_seed_1 = self.face_morpher_random_seed_1_int_ctrl.GetValue()
self.state.config.face_morpher_batch_size = self.face_morpher_batch_size_spin_ctrl.GetValue()
self.state.config.body_morpher_random_seed_0 = self.body_morpher_random_seed_0_int_ctrl.GetValue()
self.state.config.body_morpher_random_seed_1 = self.body_morpher_random_seed_1_int_ctrl.GetValue()
self.state.config.body_morpher_batch_size = self.body_morpher_batch_size_spin_ctrl.GetValue()
if self.num_training_examples_per_sample_output_combobox.GetValue() == \
DistillerUiMainFrame.NUM_TRAINING_EXAMPLES_PER_SAMPLE_OUTPUT_CHOICES[-1]:
self.state.config.face_morpher_num_training_examples_per_sample_output = None
self.state.config.body_morpher_num_training_examples_per_sample_output = None
else:
value = int(self.num_training_examples_per_sample_output_combobox.GetValue())
self.state.config.face_morpher_num_training_examples_per_sample_output = value
self.state.config.body_morpher_num_training_examples_per_sample_output = value
def update_ui(self):
self.prefix_text_ctrl.SetValue(self.state.config.prefix)
self.character_image_file_name_text_ctrl.SetValue(self.state.config.character_image_file_name)
self.face_mask_image_file_name_text_ctrl.SetValue(self.state.config.face_mask_image_file_name)
if not self.state.can_show_character_image():
self.draw_nothing_yet_string_to_bitmap(self.face_image_bitmap, 128, 128)
if not self.state.can_show_face_mask_image():
self.draw_nothing_yet_string_to_bitmap(self.face_mask_image_bitmap, 128, 128)
if not self.state.can_show_mask_on_face_image():
self.draw_nothing_yet_string_to_bitmap(self.mask_on_face_image_bitmap, 128, 128)
self.num_cpu_workers_spin_ctrl.SetValue(self.state.config.num_cpu_workers)
self.num_gpus_spin_ctrl.SetValue(self.state.config.num_gpus)
self.face_morpher_random_seed_0_int_ctrl.SetValue(self.state.config.face_morpher_random_seed_0)
self.face_morpher_random_seed_1_int_ctrl.SetValue(self.state.config.face_morpher_random_seed_1)
self.face_morpher_batch_size_spin_ctrl.SetValue(self.state.config.face_morpher_batch_size)
self.body_morpher_random_seed_0_int_ctrl.SetValue(self.state.config.body_morpher_random_seed_0)
self.body_morpher_random_seed_1_int_ctrl.SetValue(self.state.config.body_morpher_random_seed_1)
self.body_morpher_batch_size_spin_ctrl.SetValue(self.state.config.body_morpher_batch_size)
if self.state.config.body_morpher_num_training_examples_per_sample_output is None:
self.num_training_examples_per_sample_output_combobox.SetSelection(3)
else:
choices = [int(x) for x in DistillerUiMainFrame.NUM_TRAINING_EXAMPLES_PER_SAMPLE_OUTPUT_CHOICES[:-1]]
self.num_training_examples_per_sample_output_combobox.SetSelection(
choices.index(self.state.config.body_morpher_num_training_examples_per_sample_output))
self.save_menu_item.Enable(self.state.can_save())
self.Refresh()
def draw_nothing_yet_string_to_bitmap(self, bitmap, width: int, height: int):
dc = wx.MemoryDC()
dc.SelectObject(bitmap)
dc.Clear()
font = wx.Font(wx.FontInfo(14).Family(wx.FONTFAMILY_SWISS))
dc.SetFont(font)
w, h = dc.GetTextExtent("Nothing yet!")
dc.DrawText("Nothing yet!", (width - w) // 2, (height - h) // 2)
del dc
def try_saving(self):
if not self.state.can_save():
message_dialog = wx.MessageDialog(
self,
"Cannot save yet! Please make sure you set the prefix, character_image_file_name, "
"and face_mask_image_file_name first.",
"Error",
wx.OK | wx.ICON_ERROR)
message_dialog.ShowModal()
return False
else:
if self.state.need_to_check_overwrite():
confirmation_dialog = wx.MessageDialog(
parent=self,
message=f"Overwriting {self.state.config.config_yaml_file_name()}?",
caption="Confirmation",
style=wx.YES_NO | wx.CANCEL | wx.ICON_QUESTION)
result = confirmation_dialog.ShowModal()
if result == wx.ID_YES:
self.state.save()
return True
elif result == wx.ID_NO:
return False
else:
return False
else:
self.state.save()
return True
def on_save(self, event):
return self.try_saving()
def on_new(self, event):
if self.state.dirty:
confirmation_dialog = wx.MessageDialog(
parent=self,
message=f"You have not saved the current config. Do you want to proceed?",
caption="Confirmation",
style=wx.YES_NO | wx.ICON_QUESTION)
result = confirmation_dialog.ShowModal()
if result == wx.ID_NO:
return
self.state = DistillerConfigState()
self.update_ui()
def on_open(self, event):
if self.state.dirty:
confirmation_dialog = wx.MessageDialog(
parent=self,
message=f"You have not saved the current config. Do you want to proceed?",
caption="Confirmation",
style=wx.YES_NO | wx.ICON_QUESTION)
result = confirmation_dialog.ShowModal()
if result == wx.ID_NO:
return
file_dialog = wx.FileDialog(self, "Choose a YAML file", wildcard="*.yaml", style=wx.FD_OPEN)
if file_dialog.ShowModal() != wx.ID_OK:
return
file_name = file_dialog.GetPath()
try:
self.state.load(file_name)
self.face_image_pytorch = None
self.face_mask_image_pytorch = None
self.update_face_image_bitmap(self.state.config.character_image_file_name)
self.update_face_mask_image_bitmap(self.state.config.face_mask_image_file_name)
self.update_ui()
except Exception as e:
message_dialog = wx.MessageDialog(self, str(e), "Error", wx.OK | wx.ICON_ERROR)
message_dialog.ShowModal()
def on_run(self, event):
try:
self.state.config.check()
except Exception as e:
message_dialog = wx.MessageDialog(self, str(e), "Error", wx.OK | wx.ICON_ERROR)
message_dialog.ShowModal()
return
if self.state.dirty:
message_dialog = wx.MessageDialog(
self,
"Please save the configuration first.",
"Error",
wx.OK | wx.ICON_ERROR)
message_dialog.ShowModal()
return
self.config_file_to_run = self.state.config.config_yaml_file_name()
self.Destroy()
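# --- Illustrative sketch (added for this document; not in the original file) ---
# A minimal bootstrap, assuming it is run from the repository root so that
# the relative distiller-ui-doc/ pages resolve. The real entry point is
# src/tha4/app/distiller_ui.py.
if __name__ == "__main__":
    app = wx.App()
    frame = DistillerUiMainFrame()
    frame.Show()
    app.MainLoop()
    if frame.config_file_to_run is not None:
        print("Configuration to distill:", frame.config_file_to_run)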
================================================
FILE: src/tha4/image_util.py
================================================
import math
import PIL.Image
import numpy
import torch
from matplotlib import cm
from tha4.shion.base.image_util import numpy_linear_to_srgb, pytorch_rgba_to_numpy_image, pytorch_rgb_to_numpy_image, \
torch_linear_to_srgb
def grid_change_to_numpy_image(torch_image, num_channels=3):
height = torch_image.shape[1]
width = torch_image.shape[2]
size_image = (torch_image[0, :, :] ** 2 + torch_image[1, :, :] ** 2).sqrt().view(height, width, 1).numpy()
hsv = cm.get_cmap('hsv')
angle_image = hsv(((torch.atan2(
torch_image[0, :, :].view(height * width),
torch_image[1, :, :].view(height * width)).view(height, width) + math.pi) / (2 * math.pi)).numpy()) * 3
numpy_image = size_image * angle_image[:, :, 0:3]
rgb_image = numpy_linear_to_srgb(numpy_image)
if num_channels == 3:
return rgb_image
elif num_channels == 4:
return numpy.concatenate([rgb_image, numpy.ones_like(size_image)], axis=2)
else:
raise RuntimeError("Unsupported num_channels: " + str(num_channels))
def resize_PIL_image(pil_image, size=(256, 256)):
w, h = pil_image.size
d = min(w, h)
r = ((w - d) // 2, (h - d) // 2, (w + d) // 2, (h + d) // 2)
return pil_image.resize(size, resample=PIL.Image.LANCZOS, box=r)
def convert_output_image_from_torch_to_numpy(output_image):
if output_image.shape[2] == 2:
h, w, c = output_image.shape
numpy_image = torch.transpose(output_image.reshape(h * w, c), 0, 1).reshape(c, h, w)
elif output_image.shape[0] == 4:
numpy_image = pytorch_rgba_to_numpy_image(output_image)
elif output_image.shape[0] == 3:
numpy_image = pytorch_rgb_to_numpy_image(output_image)
elif output_image.shape[0] == 1:
c, h, w = output_image.shape
alpha_image = torch.cat([output_image.repeat(3, 1, 1) * 2.0 - 1.0, torch.ones(1, h, w)], dim=0)
numpy_image = pytorch_rgba_to_numpy_image(alpha_image)
elif output_image.shape[0] == 2:
numpy_image = grid_change_to_numpy_image(output_image, num_channels=4)
else:
raise RuntimeError("Unsupported # image channels: %d" % output_image.shape[0])
numpy_image = numpy.uint8(numpy.rint(numpy_image * 255.0))
return numpy_image
def convert_linear_to_srgb(image: torch.Tensor) -> torch.Tensor:
rgb_image = torch_linear_to_srgb(image[0:3, :, :])
return torch.cat([rgb_image, image[3:4, :, :]], dim=0)
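# --- Illustrative sketch (added for this document; not in the original file) ---
# Converting a CxHxW torch output image into an HxWx4 uint8 numpy image. As
# in the single-channel branch above, RGB values are assumed to lie in
# [-1, 1] (hence the * 2.0 - 1.0 there), with alpha in [0, 1].
if __name__ == "__main__":
    rgb = torch.rand(3, 8, 8) * 2.0 - 1.0
    alpha = torch.ones(1, 8, 8)
    rgba = torch.cat([rgb, alpha], dim=0)
    numpy_rgba = convert_output_image_from_torch_to_numpy(rgba)
    assert numpy_rgba.shape == (8, 8, 4)
    assert numpy_rgba.dtype == numpy.uint8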
================================================
FILE: src/tha4/mocap/__init__.py
================================================
================================================
FILE: src/tha4/mocap/ifacialmocap_constants.py
================================================
EYE_LOOK_IN_LEFT = "eyeLookInLeft"
EYE_LOOK_OUT_LEFT = "eyeLookOutLeft"
EYE_LOOK_DOWN_LEFT = "eyeLookDownLeft"
EYE_LOOK_UP_LEFT = "eyeLookUpLeft"
EYE_BLINK_LEFT = "eyeBlinkLeft"
EYE_SQUINT_LEFT = "eyeSquintLeft"
EYE_WIDE_LEFT = "eyeWideLeft"
EYE_LOOK_IN_RIGHT = "eyeLookInRight"
EYE_LOOK_OUT_RIGHT = "eyeLookOutRight"
EYE_LOOK_DOWN_RIGHT = "eyeLookDownRight"
EYE_LOOK_UP_RIGHT = "eyeLookUpRight"
EYE_BLINK_RIGHT = "eyeBlinkRight"
EYE_SQUINT_RIGHT = "eyeSquintRight"
EYE_WIDE_RIGHT = "eyeWideRight"
BROW_DOWN_LEFT = "browDownLeft"
BROW_OUTER_UP_LEFT = "browOuterUpLeft"
BROW_DOWN_RIGHT = "browDownRight"
BROW_OUTER_UP_RIGHT = "browOuterUpRight"
BROW_INNER_UP = "browInnerUp"
NOSE_SNEER_LEFT = "noseSneerLeft"
NOSE_SNEER_RIGHT = "noseSneerRight"
CHEEK_SQUINT_LEFT = "cheekSquintLeft"
CHEEK_SQUINT_RIGHT = "cheekSquintRight"
CHEEK_PUFF = "cheekPuff"
MOUTH_LEFT = "mouthLeft"
MOUTH_DIMPLE_LEFT = "mouthDimpleLeft"
MOUTH_FROWN_LEFT = "mouthFrownLeft"
MOUTH_LOWER_DOWN_LEFT = "mouthLowerDownLeft"
MOUTH_PRESS_LEFT = "mouthPressLeft"
MOUTH_SMILE_LEFT = "mouthSmileLeft"
MOUTH_STRETCH_LEFT = "mouthStretchLeft"
MOUTH_UPPER_UP_LEFT = "mouthUpperUpLeft"
MOUTH_RIGHT = "mouthRight"
MOUTH_DIMPLE_RIGHT = "mouthDimpleRight"
MOUTH_FROWN_RIGHT = "mouthFrownRight"
MOUTH_LOWER_DOWN_RIGHT = "mouthLowerDownRight"
MOUTH_PRESS_RIGHT = "mouthPressRight"
MOUTH_SMILE_RIGHT = "mouthSmileRight"
MOUTH_STRETCH_RIGHT = "mouthStretchRight"
MOUTH_UPPER_UP_RIGHT = "mouthUpperUpRight"
MOUTH_CLOSE = "mouthClose"
MOUTH_FUNNEL = "mouthFunnel"
MOUTH_PUCKER = "mouthPucker"
MOUTH_ROLL_LOWER = "mouthRollLower"
MOUTH_ROLL_UPPER = "mouthRollUpper"
MOUTH_SHRUG_LOWER = "mouthShrugLower"
MOUTH_SHRUG_UPPER = "mouthShrugUpper"
JAW_LEFT = "jawLeft"
JAW_RIGHT = "jawRight"
JAW_FORWARD = "jawForward"
JAW_OPEN = "jawOpen"
TONGUE_OUT = "tongueOut"
BLENDSHAPE_NAMES = [
EYE_LOOK_IN_LEFT, # 0
EYE_LOOK_OUT_LEFT, # 1
EYE_LOOK_DOWN_LEFT, # 2
EYE_LOOK_UP_LEFT, # 3
EYE_BLINK_LEFT, # 4
EYE_SQUINT_LEFT, # 5
EYE_WIDE_LEFT, # 6
EYE_LOOK_IN_RIGHT, # 7
EYE_LOOK_OUT_RIGHT, # 8
EYE_LOOK_DOWN_RIGHT, # 9
EYE_LOOK_UP_RIGHT, # 10
EYE_BLINK_RIGHT, # 11
EYE_SQUINT_RIGHT, # 12
EYE_WIDE_RIGHT, # 13
BROW_DOWN_LEFT, # 14
BROW_OUTER_UP_LEFT, # 15
BROW_DOWN_RIGHT, # 16
BROW_OUTER_UP_RIGHT, # 17
BROW_INNER_UP, # 18
NOSE_SNEER_LEFT, # 19
NOSE_SNEER_RIGHT, # 20
CHEEK_SQUINT_LEFT, # 21
CHEEK_SQUINT_RIGHT, # 22
CHEEK_PUFF, # 23
MOUTH_LEFT, # 24
MOUTH_DIMPLE_LEFT, # 25
MOUTH_FROWN_LEFT, # 26
MOUTH_LOWER_DOWN_LEFT, # 27
MOUTH_PRESS_LEFT, # 28
MOUTH_SMILE_LEFT, # 29
MOUTH_STRETCH_LEFT, # 30
MOUTH_UPPER_UP_LEFT, # 31
MOUTH_RIGHT, # 32
MOUTH_DIMPLE_RIGHT, # 33
MOUTH_FROWN_RIGHT, # 34
MOUTH_LOWER_DOWN_RIGHT, # 35
MOUTH_PRESS_RIGHT, # 36
MOUTH_SMILE_RIGHT, # 37
MOUTH_STRETCH_RIGHT, # 38
MOUTH_UPPER_UP_RIGHT, # 39
MOUTH_CLOSE, # 40
MOUTH_FUNNEL, # 41
MOUTH_PUCKER, # 42
MOUTH_ROLL_LOWER, # 43
MOUTH_ROLL_UPPER, # 44
MOUTH_SHRUG_LOWER, # 45
MOUTH_SHRUG_UPPER, # 46
JAW_LEFT, # 47
JAW_RIGHT, # 48
JAW_FORWARD, # 49
JAW_OPEN, # 50
TONGUE_OUT, # 51
]
EYE_LEFT_BLENDSHAPES = [
EYE_LOOK_IN_LEFT, # 0
EYE_LOOK_OUT_LEFT, # 1
EYE_LOOK_DOWN_LEFT, # 2
EYE_LOOK_UP_LEFT, # 3
EYE_BLINK_LEFT, # 4
EYE_SQUINT_LEFT, # 5
EYE_WIDE_LEFT, # 6
]
EYE_RIGHT_BLENDSHAPES = [
EYE_LOOK_IN_RIGHT, # 7
EYE_LOOK_OUT_RIGHT, # 8
EYE_LOOK_DOWN_RIGHT, # 9
EYE_LOOK_UP_RIGHT, # 10
EYE_BLINK_RIGHT, # 11
EYE_SQUINT_RIGHT, # 12
EYE_WIDE_RIGHT, # 13
]
BROW_LEFT_BLENDSHAPES = [
BROW_DOWN_LEFT, # 14
BROW_OUTER_UP_LEFT, # 15
]
BROW_RIGHT_BLENDSHAPES = [
BROW_DOWN_RIGHT, # 16
BROW_OUTER_UP_RIGHT, # 17
]
BROW_BOTH_BLENDSHAPES = [
BROW_INNER_UP, # 18
]
NOSE_BLENDSHAPES = [
NOSE_SNEER_LEFT, # 19
NOSE_SNEER_RIGHT, # 20
]
CHECK_BLENDSHAPES = [
CHEEK_SQUINT_LEFT, # 21
CHEEK_SQUINT_RIGHT, # 22
CHEEK_PUFF, # 23
]
MOUTH_LEFT_BLENDSHAPES = [
MOUTH_LEFT, # 24
MOUTH_DIMPLE_LEFT, # 25
MOUTH_FROWN_LEFT, # 26
MOUTH_LOWER_DOWN_LEFT, # 27
MOUTH_PRESS_LEFT, # 28
MOUTH_SMILE_LEFT, # 29
MOUTH_STRETCH_LEFT, # 30
MOUTH_UPPER_UP_LEFT, # 31
]
MOUTH_RIGHT_BLENDSHAPES = [
MOUTH_RIGHT, # 32
MOUTH_DIMPLE_RIGHT, # 33
MOUTH_FROWN_RIGHT, # 34
MOUTH_LOWER_DOWN_RIGHT, # 35
MOUTH_PRESS_RIGHT, # 36
MOUTH_SMILE_RIGHT, # 37
MOUTH_STRETCH_RIGHT, # 38
MOUTH_UPPER_UP_RIGHT, # 39
]
MOUTH_BOTH_BLENDSHAPES = [
MOUTH_CLOSE, # 40
MOUTH_FUNNEL, # 41
MOUTH_PUCKER, # 42
MOUTH_ROLL_LOWER, # 43
MOUTH_ROLL_UPPER, # 44
MOUTH_SHRUG_LOWER, # 45
MOUTH_SHRUG_UPPER, # 46
]
JAW_BLENDSHAPES = [
JAW_LEFT, # 47
JAW_RIGHT, # 48
JAW_FORWARD, # 49
JAW_OPEN, # 50
]
TONGUE_BLENDSHAPES = [
TONGUE_OUT, # 51
]
COLUMN_0_BLENDSHAPES = EYE_RIGHT_BLENDSHAPES + BROW_RIGHT_BLENDSHAPES + [NOSE_SNEER_RIGHT, CHEEK_SQUINT_RIGHT]
COLUMN_1_BLENDSHAPES = EYE_LEFT_BLENDSHAPES + BROW_LEFT_BLENDSHAPES + [NOSE_SNEER_LEFT, CHEEK_SQUINT_LEFT]
COLUMN_2_BLENDSHAPES = MOUTH_RIGHT_BLENDSHAPES + [JAW_RIGHT]
COLUMN_3_BLENDSHAPES = MOUTH_LEFT_BLENDSHAPES + [JAW_LEFT]
COLUMN_4_BLENDSHAPES = [BROW_INNER_UP, CHEEK_PUFF] + MOUTH_BOTH_BLENDSHAPES + [JAW_FORWARD, JAW_OPEN, TONGUE_OUT]
BLENDSHAPE_COLUMNS = [
COLUMN_0_BLENDSHAPES,
COLUMN_1_BLENDSHAPES,
COLUMN_2_BLENDSHAPES,
COLUMN_3_BLENDSHAPES,
COLUMN_4_BLENDSHAPES,
]
RIGHT_EYE_BONE_X = "rightEyeBoneX"
RIGHT_EYE_BONE_Y = "rightEyeBoneY"
RIGHT_EYE_BONE_Z = "rightEyeBoneZ"
RIGHT_EYE_BONE_ROTATIONS = [RIGHT_EYE_BONE_X, RIGHT_EYE_BONE_Y, RIGHT_EYE_BONE_Z]
LEFT_EYE_BONE_X = "leftEyeBoneX"
LEFT_EYE_BONE_Y = "leftEyeBoneY"
LEFT_EYE_BONE_Z = "leftEyeBoneZ"
LEFT_EYE_BONE_ROTATIONS = [LEFT_EYE_BONE_X, LEFT_EYE_BONE_Y, LEFT_EYE_BONE_Z]
HEAD_BONE_X = "headBoneX"
HEAD_BONE_Y = "headBoneY"
HEAD_BONE_Z = "headBoneZ"
HEAD_BONE_ROTATIONS = [HEAD_BONE_X, HEAD_BONE_Y, HEAD_BONE_Z]
ROTATION_NAMES = RIGHT_EYE_BONE_ROTATIONS + LEFT_EYE_BONE_ROTATIONS + HEAD_BONE_ROTATIONS
RIGHT_EYE_BONE_QUAT = "rightEyeBoneQuat"
LEFT_EYE_BONE_QUAT = "leftEyeBoneQuat"
HEAD_BONE_QUAT = "headBoneQuat"
QUATERNION_NAMES = [
RIGHT_EYE_BONE_QUAT,
LEFT_EYE_BONE_QUAT,
HEAD_BONE_QUAT
]
IFACIALMOCAP_DATETIME_FORMAT = "%Y/%m/%d-%H:%M:%S.%f"
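# A minimal sketch of how these groupings fit together (hypothetical usage, not
# part of the original file): the five COLUMN_*_BLENDSHAPES lists partition the
# 52 blendshape names, which is what lets a UI lay them out in a fixed
# five-column grid.
if __name__ == "__main__":
    covered = [name for column in BLENDSHAPE_COLUMNS for name in column]
    assert sorted(covered) == sorted(BLENDSHAPE_NAMES)  # every blendshape in exactly one column
    for i, column in enumerate(BLENDSHAPE_COLUMNS):
        print(f"column {i}: {len(column)} blendshapes, first = {column[0]}")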
================================================
FILE: src/tha4/mocap/ifacialmocap_pose.py
================================================
from tha4.mocap.ifacialmocap_constants import BLENDSHAPE_NAMES, HEAD_BONE_X, HEAD_BONE_Y, HEAD_BONE_Z, \
HEAD_BONE_QUAT, LEFT_EYE_BONE_X, LEFT_EYE_BONE_Y, LEFT_EYE_BONE_Z, LEFT_EYE_BONE_QUAT, RIGHT_EYE_BONE_X, \
RIGHT_EYE_BONE_Y, RIGHT_EYE_BONE_Z, RIGHT_EYE_BONE_QUAT
def create_default_ifacialmocap_pose():
data = {}
for blendshape_name in BLENDSHAPE_NAMES:
data[blendshape_name] = 0.0
data[HEAD_BONE_X] = 0.0
data[HEAD_BONE_Y] = 0.0
data[HEAD_BONE_Z] = 0.0
data[HEAD_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]
data[LEFT_EYE_BONE_X] = 0.0
data[LEFT_EYE_BONE_Y] = 0.0
data[LEFT_EYE_BONE_Z] = 0.0
data[LEFT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]
data[RIGHT_EYE_BONE_X] = 0.0
data[RIGHT_EYE_BONE_Y] = 0.0
data[RIGHT_EYE_BONE_Z] = 0.0
data[RIGHT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]
return data
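# A minimal sketch of the resulting structure (hypothetical usage, not part of
# the original file): the default pose is a flat dict with 64 entries -- the 52
# blendshapes at 0.0, nine bone rotation angles at 0.0, and three identity
# quaternions, apparently in [x, y, z, w] order.
if __name__ == "__main__":
    pose = create_default_ifacialmocap_pose()
    assert len(pose) == 64
    assert pose[HEAD_BONE_X] == 0.0
    assert pose[HEAD_BONE_QUAT] == [0.0, 0.0, 0.0, 1.0]  # identity quaternion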
================================================
FILE: src/tha4/mocap/ifacialmocap_pose_converter.py
================================================
from abc import ABC, abstractmethod
from typing import Dict, List
class IFacialMocapPoseConverter(ABC):
@abstractmethod
def convert(self, ifacialmocap_pose: Dict[str, float]) -> List[float]:
pass
@abstractmethod
def init_pose_converter_panel(self, parent):
pass
================================================
FILE: src/tha4/mocap/ifacialmocap_pose_converter_25.py
================================================
import math
import time
from enum import Enum
from typing import Optional, Dict, List, Callable
import numpy
import scipy.optimize
import wx
from tha4.mocap.ifacialmocap_constants import MOUTH_SMILE_LEFT, MOUTH_SHRUG_UPPER, MOUTH_SMILE_RIGHT, \
BROW_INNER_UP, BROW_OUTER_UP_RIGHT, BROW_OUTER_UP_LEFT, BROW_DOWN_LEFT, BROW_DOWN_RIGHT, EYE_WIDE_LEFT, \
EYE_WIDE_RIGHT, EYE_BLINK_LEFT, EYE_BLINK_RIGHT, CHEEK_SQUINT_LEFT, CHEEK_SQUINT_RIGHT, EYE_LOOK_IN_LEFT, \
EYE_LOOK_OUT_LEFT, EYE_LOOK_IN_RIGHT, EYE_LOOK_OUT_RIGHT, EYE_LOOK_UP_LEFT, EYE_LOOK_UP_RIGHT, EYE_LOOK_DOWN_RIGHT, \
EYE_LOOK_DOWN_LEFT, HEAD_BONE_X, HEAD_BONE_Y, HEAD_BONE_Z, JAW_OPEN, MOUTH_FROWN_LEFT, MOUTH_FROWN_RIGHT, \
MOUTH_LOWER_DOWN_LEFT, MOUTH_LOWER_DOWN_RIGHT, MOUTH_FUNNEL, MOUTH_PUCKER
from tha4.mocap.ifacialmocap_pose_converter import IFacialMocapPoseConverter
from tha4.poser.modes.pose_parameters import get_pose_parameters
class EyebrowDownMode(Enum):
TROUBLED = 1
ANGRY = 2
LOWERED = 3
SERIOUS = 4
class WinkMode(Enum):
NORMAL = 1
RELAXED = 2
def rad_to_deg(rad):
return rad * 180.0 / math.pi
def deg_to_rad(deg):
return deg * math.pi / 180.0
def clamp(x, min_value, max_value):
return max(min_value, min(max_value, x))
class IFacialMocapPoseConverter25Args:
def __init__(self,
smile_threshold_min: float = 0.4,
smile_threshold_max: float = 0.6,
eyebrow_down_mode: EyebrowDownMode = EyebrowDownMode.ANGRY,
wink_mode: WinkMode = WinkMode.NORMAL,
eye_surprised_max: float = 0.5,
eye_blink_max: float = 0.8,
eyebrow_down_max: float = 0.4,
cheek_squint_min: float = 0.1,
cheek_squint_max: float = 0.7,
eye_rotation_factor: float = 1.0 / 0.75,
jaw_open_min: float = 0.1,
jaw_open_max: float = 0.4,
mouth_frown_max: float = 0.6,
mouth_funnel_min: float = 0.25,
mouth_funnel_max: float = 0.5,
iris_small_left=0.0,
iris_small_right=0.0):
        self.iris_small_right = iris_small_right
        self.iris_small_left = iris_small_left
self.wink_mode = wink_mode
self.mouth_funnel_max = mouth_funnel_max
self.mouth_funnel_min = mouth_funnel_min
self.mouth_frown_max = mouth_frown_max
self.jaw_open_max = jaw_open_max
self.jaw_open_min = jaw_open_min
self.eye_rotation_factor = eye_rotation_factor
self.cheek_squint_max = cheek_squint_max
self.cheek_squint_min = cheek_squint_min
self.eyebrow_down_max = eyebrow_down_max
self.eye_blink_max = eye_blink_max
self.eye_surprised_max = eye_surprised_max
self.eyebrow_down_mode = eyebrow_down_mode
self.smile_threshold_min = smile_threshold_min
self.smile_threshold_max = smile_threshold_max
def set_smile_threshold_min(self, new_value: float):
self.smile_threshold_min = new_value
def set_smile_threshold_max(self, new_value: float):
self.smile_threshold_max = new_value
def set_eye_surprised_max(self, new_value: float):
self.eye_surprised_max = new_value
def set_eye_blink_max(self, new_value: float):
self.eye_blink_max = new_value
def set_eyebrow_down_max(self, new_value: float):
self.eyebrow_down_max = new_value
def set_cheek_squint_min(self, new_value: float):
self.cheek_squint_min = new_value
def set_cheek_squint_max(self, new_value: float):
self.cheek_squint_max = new_value
def set_jaw_open_min(self, new_value: float):
self.jaw_open_min = new_value
def set_jaw_open_max(self, new_value: float):
self.jaw_open_max = new_value
def set_mouth_frown_max(self, new_value: float):
self.mouth_frown_max = new_value
def set_mouth_funnel_min(self, new_value: float):
self.mouth_funnel_min = new_value
    def set_mouth_funnel_max(self, new_value: float):
        self.mouth_funnel_max = new_value
class IFacialMocapPoseConverter25(IFacialMocapPoseConverter):
def __init__(self, args: Optional[IFacialMocapPoseConverter25Args] = None):
super().__init__()
if args is None:
args = IFacialMocapPoseConverter25Args()
self.args = args
pose_parameters = get_pose_parameters()
self.pose_size = 45
self.eyebrow_troubled_left_index = pose_parameters.get_parameter_index("eyebrow_troubled_left")
self.eyebrow_troubled_right_index = pose_parameters.get_parameter_index("eyebrow_troubled_right")
self.eyebrow_angry_left_index = pose_parameters.get_parameter_index("eyebrow_angry_left")
self.eyebrow_angry_right_index = pose_parameters.get_parameter_index("eyebrow_angry_right")
self.eyebrow_happy_left_index = pose_parameters.get_parameter_index("eyebrow_happy_left")
self.eyebrow_happy_right_index = pose_parameters.get_parameter_index("eyebrow_happy_right")
self.eyebrow_raised_left_index = pose_parameters.get_parameter_index("eyebrow_raised_left")
self.eyebrow_raised_right_index = pose_parameters.get_parameter_index("eyebrow_raised_right")
self.eyebrow_lowered_left_index = pose_parameters.get_parameter_index("eyebrow_lowered_left")
self.eyebrow_lowered_right_index = pose_parameters.get_parameter_index("eyebrow_lowered_right")
self.eyebrow_serious_left_index = pose_parameters.get_parameter_index("eyebrow_serious_left")
self.eyebrow_serious_right_index = pose_parameters.get_parameter_index("eyebrow_serious_right")
self.eye_surprised_left_index = pose_parameters.get_parameter_index("eye_surprised_left")
self.eye_surprised_right_index = pose_parameters.get_parameter_index("eye_surprised_right")
self.eye_wink_left_index = pose_parameters.get_parameter_index("eye_wink_left")
self.eye_wink_right_index = pose_parameters.get_parameter_index("eye_wink_right")
self.eye_happy_wink_left_index = pose_parameters.get_parameter_index("eye_happy_wink_left")
self.eye_happy_wink_right_index = pose_parameters.get_parameter_index("eye_happy_wink_right")
self.eye_relaxed_left_index = pose_parameters.get_parameter_index("eye_relaxed_left")
self.eye_relaxed_right_index = pose_parameters.get_parameter_index("eye_relaxed_right")
self.eye_raised_lower_eyelid_left_index = pose_parameters.get_parameter_index("eye_raised_lower_eyelid_left")
self.eye_raised_lower_eyelid_right_index = pose_parameters.get_parameter_index("eye_raised_lower_eyelid_right")
self.iris_small_left_index = pose_parameters.get_parameter_index("iris_small_left")
self.iris_small_right_index = pose_parameters.get_parameter_index("iris_small_right")
self.iris_rotation_x_index = pose_parameters.get_parameter_index("iris_rotation_x")
self.iris_rotation_y_index = pose_parameters.get_parameter_index("iris_rotation_y")
self.head_x_index = pose_parameters.get_parameter_index("head_x")
self.head_y_index = pose_parameters.get_parameter_index("head_y")
self.neck_z_index = pose_parameters.get_parameter_index("neck_z")
self.mouth_aaa_index = pose_parameters.get_parameter_index("mouth_aaa")
self.mouth_iii_index = pose_parameters.get_parameter_index("mouth_iii")
self.mouth_uuu_index = pose_parameters.get_parameter_index("mouth_uuu")
self.mouth_eee_index = pose_parameters.get_parameter_index("mouth_eee")
self.mouth_ooo_index = pose_parameters.get_parameter_index("mouth_ooo")
self.mouth_lowered_corner_left_index = pose_parameters.get_parameter_index("mouth_lowered_corner_left")
self.mouth_lowered_corner_right_index = pose_parameters.get_parameter_index("mouth_lowered_corner_right")
self.mouth_raised_corner_left_index = pose_parameters.get_parameter_index("mouth_raised_corner_left")
self.mouth_raised_corner_right_index = pose_parameters.get_parameter_index("mouth_raised_corner_right")
self.body_y_index = pose_parameters.get_parameter_index("body_y")
self.body_z_index = pose_parameters.get_parameter_index("body_z")
self.breathing_index = pose_parameters.get_parameter_index("breathing")
self.breathing_start_time = time.time()
self.panel = None
def init_pose_converter_panel(self, parent):
self.panel = wx.Panel(parent, style=wx.SIMPLE_BORDER)
self.panel_sizer = wx.BoxSizer(wx.VERTICAL)
self.panel.SetSizer(self.panel_sizer)
self.panel.SetAutoLayout(1)
parent.GetSizer().Add(self.panel, 0, wx.EXPAND)
if True:
eyebrow_down_mode_text = wx.StaticText(self.panel, label=" --- Eyebrow Down Mode --- ",
style=wx.ALIGN_CENTER)
self.panel_sizer.Add(eyebrow_down_mode_text, 0, wx.EXPAND)
self.eyebrow_down_mode_choice = wx.Choice(
self.panel,
choices=[
"ANGRY",
"TROUBLED",
"SERIOUS",
"LOWERED",
])
self.eyebrow_down_mode_choice.SetSelection(0)
self.panel_sizer.Add(self.eyebrow_down_mode_choice, 0, wx.EXPAND)
self.eyebrow_down_mode_choice.Bind(wx.EVT_CHOICE, self.change_eyebrow_down_mode)
separator = wx.StaticLine(self.panel, -1, size=(256, 5))
self.panel_sizer.Add(separator, 0, wx.EXPAND)
if True:
wink_mode_text = wx.StaticText(self.panel, label=" --- Wink Mode --- ", style=wx.ALIGN_CENTER)
self.panel_sizer.Add(wink_mode_text, 0, wx.EXPAND)
self.wink_mode_choice = wx.Choice(
self.panel,
choices=[
"NORMAL",
"RELAXED",
])
self.wink_mode_choice.SetSelection(0)
self.panel_sizer.Add(self.wink_mode_choice, 0, wx.EXPAND)
self.wink_mode_choice.Bind(wx.EVT_CHOICE, self.change_wink_mode)
separator = wx.StaticLine(self.panel, -1, size=(256, 5))
self.panel_sizer.Add(separator, 0, wx.EXPAND)
if True:
iris_size_text = wx.StaticText(self.panel, label=" --- Iris Size --- ", style=wx.ALIGN_CENTER)
self.panel_sizer.Add(iris_size_text, 0, wx.EXPAND)
self.iris_left_slider = wx.Slider(self.panel, minValue=0, maxValue=1000, value=0, style=wx.HORIZONTAL)
self.panel_sizer.Add(self.iris_left_slider, 0, wx.EXPAND)
self.iris_left_slider.Bind(wx.EVT_SLIDER, self.change_iris_size)
self.iris_right_slider = wx.Slider(self.panel, minValue=0, maxValue=1000, value=0, style=wx.HORIZONTAL)
self.panel_sizer.Add(self.iris_right_slider, 0, wx.EXPAND)
self.iris_right_slider.Bind(wx.EVT_SLIDER, self.change_iris_size)
self.iris_right_slider.Enable(False)
self.link_left_right_irises = wx.CheckBox(
self.panel, label="Use same value for both sides")
self.link_left_right_irises.SetValue(True)
self.panel_sizer.Add(self.link_left_right_irises, wx.SizerFlags().CenterHorizontal().Border())
self.link_left_right_irises.Bind(wx.EVT_CHECKBOX, self.link_left_right_irises_clicked)
separator = wx.StaticLine(self.panel, -1, size=(256, 5))
self.panel_sizer.Add(separator, 0, wx.EXPAND)
if True:
breathing_frequency_text = wx.StaticText(
self.panel, label=" --- Breathing --- ", style=wx.ALIGN_CENTER)
self.panel_sizer.Add(breathing_frequency_text, 0, wx.EXPAND)
self.restart_breathing_cycle_button = wx.Button(self.panel, label="Restart Breathing Cycle")
self.restart_breathing_cycle_button.Bind(wx.EVT_BUTTON, self.restart_breathing_cycle_clicked)
self.panel_sizer.Add(self.restart_breathing_cycle_button, 0, wx.EXPAND)
self.breathing_frequency_slider = wx.Slider(
self.panel, minValue=0, maxValue=60, value=20, style=wx.HORIZONTAL)
self.panel_sizer.Add(self.breathing_frequency_slider, 0, wx.EXPAND)
self.breathing_gauge = wx.Gauge(self.panel, style=wx.GA_HORIZONTAL, range=1000)
self.panel_sizer.Add(self.breathing_gauge, 0, wx.EXPAND)
if True:
separator = wx.StaticLine(self.panel, -1, size=(256, 5))
self.panel_sizer.Add(separator, 0, wx.EXPAND)
            conversion_parameters_text = wx.StaticText(
                self.panel, label="--- Conversion Parameters ---", style=wx.ALIGN_CENTER)
            self.panel_sizer.Add(conversion_parameters_text, 0, wx.EXPAND)
conversion_param_panel = wx.Panel(self.panel)
self.panel_sizer.Add(conversion_param_panel, 0, wx.EXPAND)
conversion_panel_sizer = wx.FlexGridSizer(cols=2)
conversion_panel_sizer.AddGrowableCol(1)
conversion_param_panel.SetSizer(conversion_panel_sizer)
conversion_param_panel.SetAutoLayout(1)
            self.smile_threshold_min_spin = self.create_spin_control(
                conversion_param_panel,
                "Smile Threshold Min:", self.args.smile_threshold_min, self.args.set_smile_threshold_min)
            self.smile_threshold_max_spin = self.create_spin_control(
                conversion_param_panel,
                "Smile Threshold Max:", self.args.smile_threshold_max, self.args.set_smile_threshold_max)
self.eye_surprised_max_spin = self.create_spin_control(
conversion_param_panel,
"Eye Surprised Max:", self.args.eye_surprised_max, self.args.set_eye_surprised_max)
self.eye_blink_max_spin = self.create_spin_control(
conversion_param_panel,
"Eye Blink Max:", self.args.eye_blink_max, self.args.set_eye_blink_max)
self.eyebrow_down_max_spin = self.create_spin_control(
conversion_param_panel,
"Eyebrow Down Max:", self.args.eyebrow_down_max, self.args.set_eyebrow_down_max)
self.cheek_squint_min_spin = self.create_spin_control(
conversion_param_panel,
"Cheek Squint Min:", self.args.cheek_squint_min, self.args.set_cheek_squint_min)
self.cheek_squint_max_spin = self.create_spin_control(
conversion_param_panel,
"Cheek Squint Max:", self.args.cheek_squint_max, self.args.set_cheek_squint_max)
self.jaw_open_min_spin = self.create_spin_control(
conversion_param_panel,
"Jaw Open Min:", self.args.jaw_open_min, self.args.set_jaw_open_min)
self.jaw_open_max_spin = self.create_spin_control(
conversion_param_panel,
"Jaw Open Max:", self.args.jaw_open_max, self.args.set_jaw_open_max)
self.mouth_frown_max_spin = self.create_spin_control(
conversion_param_panel,
"Mouth Frown Max:", self.args.mouth_frown_max, self.args.set_mouth_frown_max)
self.mouth_funnel_min_spin = self.create_spin_control(
conversion_param_panel,
"Mouth Funnel Min:", self.args.mouth_funnel_min, self.args.set_mouth_funnel_min)
self.mouth_funnel_max_spin = self.create_spin_control(
conversion_param_panel,
"Mouth Funnel Max:", self.args.mouth_funnel_max, self.args.set_mouth_funnel_max)
self.panel_sizer.Fit(self.panel)
def create_spin_control(self, parent, label: str, initial_value: float, set_func: Callable[[float], None]):
sizer = parent.GetSizer()
text = wx.StaticText(parent, label=label)
sizer.Add(text, wx.SizerFlags().Right().Border(wx.ALL, 2))
spin_ctrl = wx.SpinCtrlDouble(
parent,
wx.ID_ANY,
min=0.0,
max=1.0,
initial=initial_value,
inc=0.01)
sizer.Add(spin_ctrl, wx.SizerFlags().Border(wx.ALL, 2).Expand())
def handler(event: wx.Event):
new_value = spin_ctrl.GetValue()
set_func(new_value)
spin_ctrl.Bind(wx.EVT_SPINCTRLDOUBLE, handler)
return spin_ctrl
def restart_breathing_cycle_clicked(self, event: wx.Event):
self.breathing_start_time = time.time()
def change_eyebrow_down_mode(self, event: wx.Event):
selected_index = self.eyebrow_down_mode_choice.GetSelection()
if selected_index == 0:
self.args.eyebrow_down_mode = EyebrowDownMode.ANGRY
elif selected_index == 1:
self.args.eyebrow_down_mode = EyebrowDownMode.TROUBLED
elif selected_index == 2:
self.args.eyebrow_down_mode = EyebrowDownMode.SERIOUS
else:
self.args.eyebrow_down_mode = EyebrowDownMode.LOWERED
def change_wink_mode(self, event: wx.Event):
selected_index = self.wink_mode_choice.GetSelection()
if selected_index == 0:
self.args.wink_mode = WinkMode.NORMAL
else:
self.args.wink_mode = WinkMode.RELAXED
def change_iris_size(self, event: wx.Event):
if self.link_left_right_irises.GetValue():
left_value = self.iris_left_slider.GetValue()
right_value = self.iris_right_slider.GetValue()
if left_value != right_value:
self.iris_right_slider.SetValue(left_value)
self.args.iris_small_left = left_value / 1000.0
self.args.iris_small_right = left_value / 1000.0
else:
self.args.iris_small_left = self.iris_left_slider.GetValue() / 1000.0
self.args.iris_small_right = self.iris_right_slider.GetValue() / 1000.0
def link_left_right_irises_clicked(self, event: wx.Event):
if self.link_left_right_irises.GetValue():
self.iris_right_slider.Enable(False)
else:
self.iris_right_slider.Enable(True)
self.change_iris_size(event)
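    # Splits a normalized rotation parameter between two joints: the first
    # component saturates at +/- threshold, and any excess spills into the second.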
def decompose_head_body_param(self, param, threshold=2.0 / 3):
if abs(param) < threshold:
return (param, 0.0)
else:
if param < 0:
sign = -1.0
else:
sign = 1.0
return (threshold * sign, (abs(param) - threshold) * sign)
def convert(self, ifacialmocap_pose: Dict[str, float]) -> List[float]:
pose = [0.0 for i in range(self.pose_size)]
smile_value = \
(ifacialmocap_pose[MOUTH_SMILE_LEFT] + ifacialmocap_pose[MOUTH_SMILE_RIGHT]) / 2.0 \
+ ifacialmocap_pose[MOUTH_SHRUG_UPPER]
if self.args.smile_threshold_min >= self.args.smile_threshold_max:
smile_degree = 0.0
else:
if smile_value < self.args.smile_threshold_min:
smile_degree = 0.0
elif smile_value > self.args.smile_threshold_max:
smile_degree = 1.0
else:
smile_degree = (smile_value - self.args.smile_threshold_min) / (
self.args.smile_threshold_max - self.args.smile_threshold_min)
# Eyebrow
if True:
brow_inner_up = ifacialmocap_pose[BROW_INNER_UP]
brow_outer_up_right = ifacialmocap_pose[BROW_OUTER_UP_RIGHT]
brow_outer_up_left = ifacialmocap_pose[BROW_OUTER_UP_LEFT]
brow_up_left = clamp(brow_inner_up + brow_outer_up_left, 0.0, 1.0)
brow_up_right = clamp(brow_inner_up + brow_outer_up_right, 0.0, 1.0)
pose[self.eyebrow_raised_left_index] = brow_up_left
pose[self.eyebrow_raised_right_index] = brow_up_right
if self.args.eyebrow_down_max <= 0.0:
brow_down_left = 0.0
brow_down_right = 0.0
else:
brow_down_left = (1.0 - smile_degree) \
* clamp(ifacialmocap_pose[BROW_DOWN_LEFT] / self.args.eyebrow_down_max, 0.0, 1.0)
brow_down_right = (1.0 - smile_degree) \
* clamp(ifacialmocap_pose[BROW_DOWN_RIGHT] / self.args.eyebrow_down_max, 0.0, 1.0)
if self.args.eyebrow_down_mode == EyebrowDownMode.TROUBLED:
pose[self.eyebrow_troubled_left_index] = brow_down_left
pose[self.eyebrow_troubled_right_index] = brow_down_right
elif self.args.eyebrow_down_mode == EyebrowDownMode.ANGRY:
pose[self.eyebrow_angry_left_index] = brow_down_left
pose[self.eyebrow_angry_right_index] = brow_down_right
elif self.args.eyebrow_down_mode == EyebrowDownMode.LOWERED:
pose[self.eyebrow_lowered_left_index] = brow_down_left
pose[self.eyebrow_lowered_right_index] = brow_down_right
elif self.args.eyebrow_down_mode == EyebrowDownMode.SERIOUS:
pose[self.eyebrow_serious_left_index] = brow_down_left
pose[self.eyebrow_serious_right_index] = brow_down_right
brow_happy_value = clamp(smile_value, 0.0, 1.0) * smile_degree
pose[self.eyebrow_happy_left_index] = brow_happy_value
pose[self.eyebrow_happy_right_index] = brow_happy_value
# Eye
if True:
# Surprised
if self.args.eye_surprised_max <= 0.0:
pose[self.eye_surprised_left_index] = 0.0
pose[self.eye_surprised_right_index] = 0.0
else:
pose[self.eye_surprised_left_index] = clamp(
ifacialmocap_pose[EYE_WIDE_LEFT] / self.args.eye_surprised_max, 0.0, 1.0)
pose[self.eye_surprised_right_index] = clamp(
ifacialmocap_pose[EYE_WIDE_RIGHT] / self.args.eye_surprised_max, 0.0, 1.0)
# Wink
if self.args.wink_mode == WinkMode.NORMAL:
wink_left_index = self.eye_wink_left_index
wink_right_index = self.eye_wink_right_index
else:
wink_left_index = self.eye_relaxed_left_index
wink_right_index = self.eye_relaxed_right_index
if self.args.eye_blink_max <= 0:
pose[wink_left_index] = 0.0
pose[wink_right_index] = 0.0
pose[self.eye_happy_wink_left_index] = 0.0
pose[self.eye_happy_wink_right_index] = 0.0
else:
pose[wink_left_index] = (1.0 - smile_degree) * clamp(
ifacialmocap_pose[EYE_BLINK_LEFT] / self.args.eye_blink_max, 0.0, 1.0)
pose[wink_right_index] = (1.0 - smile_degree) * clamp(
ifacialmocap_pose[EYE_BLINK_RIGHT] / self.args.eye_blink_max, 0.0, 1.0)
pose[self.eye_happy_wink_left_index] = smile_degree * clamp(
ifacialmocap_pose[EYE_BLINK_LEFT] / self.args.eye_blink_max, 0.0, 1.0)
pose[self.eye_happy_wink_right_index] = smile_degree * clamp(
ifacialmocap_pose[EYE_BLINK_RIGHT] / self.args.eye_blink_max, 0.0, 1.0)
# Lower eyelid
cheek_squint_denom = self.args.cheek_squint_max - self.args.cheek_squint_min
if cheek_squint_denom <= 0.0:
pose[self.eye_raised_lower_eyelid_left_index] = 0.0
pose[self.eye_raised_lower_eyelid_right_index] = 0.0
else:
pose[self.eye_raised_lower_eyelid_left_index] = \
clamp(
(ifacialmocap_pose[CHEEK_SQUINT_LEFT] - self.args.cheek_squint_min) / cheek_squint_denom,
0.0, 1.0)
pose[self.eye_raised_lower_eyelid_right_index] = \
clamp(
(ifacialmocap_pose[CHEEK_SQUINT_RIGHT] - self.args.cheek_squint_min) / cheek_squint_denom,
0.0, 1.0)
# Iris rotation
if True:
eye_rotation_y = (ifacialmocap_pose[EYE_LOOK_IN_LEFT]
- ifacialmocap_pose[EYE_LOOK_OUT_LEFT]
- ifacialmocap_pose[EYE_LOOK_IN_RIGHT]
+ ifacialmocap_pose[EYE_LOOK_OUT_RIGHT]) / 2.0 * self.args.eye_rotation_factor
pose[self.iris_rotation_y_index] = clamp(eye_rotation_y, -1.0, 1.0)
eye_rotation_x = (ifacialmocap_pose[EYE_LOOK_UP_LEFT]
+ ifacialmocap_pose[EYE_LOOK_UP_RIGHT]
- ifacialmocap_pose[EYE_LOOK_DOWN_LEFT]
- ifacialmocap_pose[EYE_LOOK_DOWN_RIGHT]) / 2.0 * self.args.eye_rotation_factor
pose[self.iris_rotation_x_index] = clamp(eye_rotation_x, -1.0, 1.0)
# Iris size
if True:
pose[self.iris_small_left_index] = self.args.iris_small_left
pose[self.iris_small_right_index] = self.args.iris_small_right
# Head rotation
if True:
x_param = clamp(-ifacialmocap_pose[HEAD_BONE_X] * 180.0 / math.pi, -15.0, 15.0) / 15.0
pose[self.head_x_index] = x_param
y_param = clamp(-ifacialmocap_pose[HEAD_BONE_Y] * 180.0 / math.pi, -10.0, 10.0) / 10.0
pose[self.head_y_index] = y_param
pose[self.body_y_index] = y_param
z_param = clamp(ifacialmocap_pose[HEAD_BONE_Z] * 180.0 / math.pi, -15.0, 15.0) / 15.0
pose[self.neck_z_index] = z_param
pose[self.body_z_index] = z_param
# Mouth
if True:
jaw_open_denom = self.args.jaw_open_max - self.args.jaw_open_min
if jaw_open_denom <= 0:
mouth_open = 0.0
else:
mouth_open = clamp((ifacialmocap_pose[JAW_OPEN] - self.args.jaw_open_min) / jaw_open_denom, 0.0, 1.0)
pose[self.mouth_aaa_index] = mouth_open
pose[self.mouth_raised_corner_left_index] = clamp(smile_value, 0.0, 1.0)
pose[self.mouth_raised_corner_right_index] = clamp(smile_value, 0.0, 1.0)
is_mouth_open = mouth_open > 0.0
if not is_mouth_open:
                if self.args.mouth_frown_max <= 0:
mouth_frown_value = 0.0
else:
mouth_frown_value = clamp(
(ifacialmocap_pose[MOUTH_FROWN_LEFT] + ifacialmocap_pose[
MOUTH_FROWN_RIGHT]) / self.args.mouth_frown_max, 0.0, 1.0)
pose[self.mouth_lowered_corner_left_index] = mouth_frown_value
pose[self.mouth_lowered_corner_right_index] = mouth_frown_value
else:
mouth_lower_down = clamp(
ifacialmocap_pose[MOUTH_LOWER_DOWN_LEFT] + ifacialmocap_pose[MOUTH_LOWER_DOWN_RIGHT], 0.0, 1.0)
mouth_funnel = ifacialmocap_pose[MOUTH_FUNNEL]
mouth_pucker = ifacialmocap_pose[MOUTH_PUCKER]
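                # Express the observed mouth state as a mixture of four canonical
                # mouth shapes (aaa, iii, uuu, ooo), each a point in the feature
                # space (open, lower_down, funnel, pucker). The mixture weights
                # are found by minimizing the L2 residual plus a small L1 penalty
                # for sparsity, with each weight bounded to [0, 1].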
mouth_point = [mouth_open, mouth_lower_down, mouth_funnel, mouth_pucker]
aaa_point = [1.0, 1.0, 0.0, 0.0]
iii_point = [0.0, 1.0, 0.0, 0.0]
uuu_point = [0.5, 0.3, 0.25, 0.75]
ooo_point = [1.0, 0.5, 0.5, 0.4]
decomp = numpy.array([0, 0, 0, 0])
M = numpy.array([
aaa_point,
iii_point,
uuu_point,
ooo_point
])
def loss(decomp):
return numpy.linalg.norm(numpy.matmul(decomp, M) - mouth_point) \
+ 0.01 * numpy.linalg.norm(decomp, ord=1)
opt_result = scipy.optimize.minimize(
loss, decomp, bounds=[(0.0, 1.0), (0.0, 1.0), (0.0, 1.0), (0.0, 1.0)])
decomp = opt_result["x"]
restricted_decomp = [decomp.item(0), decomp.item(1), decomp.item(2), decomp.item(3)]
pose[self.mouth_aaa_index] = restricted_decomp[0]
pose[self.mouth_iii_index] = restricted_decomp[1]
mouth_funnel_denom = self.args.mouth_funnel_max - self.args.mouth_funnel_min
if mouth_funnel_denom <= 0:
ooo_alpha = 0.0
uo_value = 0.0
else:
ooo_alpha = clamp((mouth_funnel - self.args.mouth_funnel_min) / mouth_funnel_denom, 0.0, 1.0)
uo_value = clamp(restricted_decomp[2] + restricted_decomp[3], 0.0, 1.0)
pose[self.mouth_uuu_index] = uo_value * (1.0 - ooo_alpha)
pose[self.mouth_ooo_index] = uo_value * ooo_alpha
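        # Breathing: a raised-cosine cycle in [0, 1] whose period comes from the
        # frequency slider (breaths per minute); zero frequency resets the cycle.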
if self.panel is not None:
frequency = self.breathing_frequency_slider.GetValue()
if frequency == 0:
value = 0.0
pose[self.breathing_index] = value
self.breathing_start_time = time.time()
else:
period = 60.0 / frequency
now = time.time()
diff = now - self.breathing_start_time
frac = (diff % period) / period
value = (-math.cos(2 * math.pi * frac) + 1.0) / 2.0
pose[self.breathing_index] = value
self.breathing_gauge.SetValue(int(1000 * value))
return pose
def create_ifacialmocap_pose_converter(
args: Optional[IFacialMocapPoseConverter25Args] = None) -> IFacialMocapPoseConverter:
return IFacialMocapPoseConverter25(args)
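# A minimal end-to-end sketch (hypothetical usage, assuming wx, numpy, and scipy
# are installed): with no wx panel created, convert() skips the breathing
# branch, so the converter can be exercised headlessly.
if __name__ == "__main__":
    from tha4.mocap.ifacialmocap_pose import create_default_ifacialmocap_pose
    from tha4.mocap.ifacialmocap_constants import JAW_OPEN
    converter = create_ifacialmocap_pose_converter()
    mocap_pose = create_default_ifacialmocap_pose()
    mocap_pose[JAW_OPEN] = 0.3  # partially open jaw
    tha_pose = converter.convert(mocap_pose)
    assert len(tha_pose) == 45  # one entry per THA pose parameter
    print(tha_pose[converter.mouth_aaa_index])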
================================================
FILE: src/tha4/mocap/ifacialmocap_v2.py
================================================
import math
from tha4.mocap.ifacialmocap_constants import BLENDSHAPE_NAMES, HEAD_BONE_X, HEAD_BONE_Y, HEAD_BONE_Z, \
RIGHT_EYE_BONE_X, RIGHT_EYE_BONE_Y, RIGHT_EYE_BONE_Z, LEFT_EYE_BONE_X, LEFT_EYE_BONE_Y, LEFT_EYE_BONE_Z, \
HEAD_BONE_QUAT, LEFT_EYE_BONE_QUAT, RIGHT_EYE_BONE_QUAT
IFACIALMOCAP_PORT = 49983
IFACIALMOCAP_START_STRING = "iFacialMocap_sahuasouryya9218sauhuiayeta91555dy3719|sendDataVersion=v2".encode('utf-8')
def parse_ifacialmocap_v2_pose(ifacialmocap_output):
output = {}
parts = ifacialmocap_output.split("|")
for part in parts:
part = part.strip()
if len(part) == 0:
continue
if "&" in part:
components = part.split("&")
assert len(components) == 2
key = components[0]
value = float(components[1]) / 100.0
if key.endswith("_L"):
key = key[:-2] + "Left"
elif key.endswith("_R"):
key = key[:-2] + "Right"
if key in BLENDSHAPE_NAMES:
output[key] = value
elif part.startswith("=head#"):
components = part[len("=head#"):].split(",")
assert len(components) == 6
output[HEAD_BONE_X] = float(components[0]) * math.pi / 180
output[HEAD_BONE_Y] = float(components[1]) * math.pi / 180
output[HEAD_BONE_Z] = float(components[2]) * math.pi / 180
elif part.startswith("rightEye#"):
components = part[len("rightEye#"):].split(",")
output[RIGHT_EYE_BONE_X] = float(components[0]) * math.pi / 180
output[RIGHT_EYE_BONE_Y] = float(components[1]) * math.pi / 180
output[RIGHT_EYE_BONE_Z] = float(components[2]) * math.pi / 180
elif part.startswith("leftEye#"):
components = part[len("leftEye#"):].split(",")
output[LEFT_EYE_BONE_X] = float(components[0]) * math.pi / 180
output[LEFT_EYE_BONE_Y] = float(components[1]) * math.pi / 180
output[LEFT_EYE_BONE_Z] = float(components[2]) * math.pi / 180
output[HEAD_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]
output[LEFT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]
output[RIGHT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]
return output
def parse_ifacialmocap_v1_pose(ifacialmocap_output):
output = {}
parts = ifacialmocap_output.split("|")
for part in parts:
part = part.strip()
if len(part) == 0:
continue
if part.startswith("=head#"):
components = part[len("=head#"):].split(",")
assert len(components) == 6
output[HEAD_BONE_X] = float(components[0]) * math.pi / 180
output[HEAD_BONE_Y] = float(components[1]) * math.pi / 180
output[HEAD_BONE_Z] = float(components[2]) * math.pi / 180
elif part.startswith("rightEye#"):
components = part[len("rightEye#"):].split(",")
output[RIGHT_EYE_BONE_X] = float(components[0]) * math.pi / 180
output[RIGHT_EYE_BONE_Y] = float(components[1]) * math.pi / 180
output[RIGHT_EYE_BONE_Z] = float(components[2]) * math.pi / 180
elif part.startswith("leftEye#"):
components = part[len("leftEye#"):].split(",")
output[LEFT_EYE_BONE_X] = float(components[0]) * math.pi / 180
output[LEFT_EYE_BONE_Y] = float(components[1]) * math.pi / 180
output[LEFT_EYE_BONE_Z] = float(components[2]) * math.pi / 180
else:
components = part.split("-")
assert len(components) == 2
key = components[0]
value = float(components[1]) / 100.0
if key.endswith("_L"):
key = key[:-2] + "Left"
elif key.endswith("_R"):
key = key[:-2] + "Right"
if key in BLENDSHAPE_NAMES:
output[key] = value
output[HEAD_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]
output[LEFT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]
output[RIGHT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]
return output
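# A rough capture-loop sketch (hypothetical; the phone's address is a
# placeholder and error handling is omitted): send the start string to the
# iFacialMocap app once over UDP, then parse each datagram it streams back.
if __name__ == "__main__":
    import socket
    PHONE_IP = "192.168.0.42"  # placeholder: the device running iFacialMocap
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    sock.bind(("", IFACIALMOCAP_PORT))
    sock.sendto(IFACIALMOCAP_START_STRING, (PHONE_IP, IFACIALMOCAP_PORT))
    while True:
        data, _ = sock.recvfrom(8192)
        pose = parse_ifacialmocap_v2_pose(data.decode("utf-8"))
        print(pose.get(HEAD_BONE_X, 0.0))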
================================================
FILE: src/tha4/mocap/mediapipe_constants.py
================================================
EYE_LOOK_IN_LEFT = "eyeLookInLeft"
EYE_LOOK_OUT_LEFT = "eyeLookOutLeft"
EYE_LOOK_DOWN_LEFT = "eyeLookDownLeft"
EYE_LOOK_UP_LEFT = "eyeLookUpLeft"
EYE_BLINK_LEFT = "eyeBlinkLeft"
EYE_SQUINT_LEFT = "eyeSquintLeft"
EYE_WIDE_LEFT = "eyeWideLeft"
EYE_LOOK_IN_RIGHT = "eyeLookInRight"
EYE_LOOK_OUT_RIGHT = "eyeLookOutRight"
EYE_LOOK_DOWN_RIGHT = "eyeLookDownRight"
EYE_LOOK_UP_RIGHT = "eyeLookUpRight"
EYE_BLINK_RIGHT = "eyeBlinkRight"
EYE_SQUINT_RIGHT = "eyeSquintRight"
EYE_WIDE_RIGHT = "eyeWideRight"
BROW_DOWN_LEFT = "browDownLeft"
BROW_OUTER_UP_LEFT = "browOuterUpLeft"
BROW_DOWN_RIGHT = "browDownRight"
BROW_OUTER_UP_RIGHT = "browOuterUpRight"
BROW_INNER_UP = "browInnerUp"
NOSE_SNEER_LEFT = "noseSneerLeft"
NOSE_SNEER_RIGHT = "noseSneerRight"
CHEEK_SQUINT_LEFT = "cheekSquintLeft"
CHEEK_SQUINT_RIGHT = "cheekSquintRight"
CHEEK_PUFF = "cheekPuff"
MOUTH_LEFT = "mouthLeft"
MOUTH_DIMPLE_LEFT = "mouthDimpleLeft"
MOUTH_FROWN_LEFT = "mouthFrownLeft"
MOUTH_LOWER_DOWN_LEFT = "mouthLowerDownLeft"
MOUTH_PRESS_LEFT = "mouthPressLeft"
MOUTH_SMILE_LEFT = "mouthSmileLeft"
MOUTH_STRETCH_LEFT = "mouthStretchLeft"
MOUTH_UPPER_UP_LEFT = "mouthUpperUpLeft"
MOUTH_RIGHT = "mouthRight"
MOUTH_DIMPLE_RIGHT = "mouthDimpleRight"
MOUTH_FROWN_RIGHT = "mouthFrownRight"
MOUTH_LOWER_DOWN_RIGHT = "mouthLowerDownRight"
MOUTH_PRESS_RIGHT = "mouthPressRight"
MOUTH_SMILE_RIGHT = "mouthSmileRight"
MOUTH_STRETCH_RIGHT = "mouthStretchRight"
MOUTH_UPPER_UP_RIGHT = "mouthUpperUpRight"
MOUTH_CLOSE = "mouthClose"
MOUTH_FUNNEL = "mouthFunnel"
MOUTH_PUCKER = "mouthPucker"
MOUTH_ROLL_LOWER = "mouthRollLower"
MOUTH_ROLL_UPPER = "mouthRollUpper"
MOUTH_SHRUG_LOWER = "mouthShrugLower"
MOUTH_SHRUG_UPPER = "mouthShrugUpper"
JAW_LEFT = "jawLeft"
JAW_RIGHT = "jawRight"
JAW_FORWARD = "jawForward"
JAW_OPEN = "jawOpen"
NEUTRAL = "_neutral"
BLENDSHAPE_NAMES = [
EYE_LOOK_IN_LEFT, # 0
EYE_LOOK_OUT_LEFT, # 1
EYE_LOOK_DOWN_LEFT, # 2
EYE_LOOK_UP_LEFT, # 3
EYE_BLINK_LEFT, # 4
EYE_SQUINT_LEFT, # 5
EYE_WIDE_LEFT, # 6
EYE_LOOK_IN_RIGHT, # 7
EYE_LOOK_OUT_RIGHT, # 8
EYE_LOOK_DOWN_RIGHT, # 9
EYE_LOOK_UP_RIGHT, # 10
EYE_BLINK_RIGHT, # 11
EYE_SQUINT_RIGHT, # 12
EYE_WIDE_RIGHT, # 13
BROW_DOWN_LEFT, # 14
BROW_OUTER_UP_LEFT, # 15
BROW_DOWN_RIGHT, # 16
BROW_OUTER_UP_RIGHT, # 17
BROW_INNER_UP, # 18
NOSE_SNEER_LEFT, # 19
NOSE_SNEER_RIGHT, # 20
CHEEK_SQUINT_LEFT, # 21
CHEEK_SQUINT_RIGHT, # 22
CHEEK_PUFF, # 23
MOUTH_LEFT, # 24
MOUTH_DIMPLE_LEFT, # 25
MOUTH_FROWN_LEFT, # 26
MOUTH_LOWER_DOWN_LEFT, # 27
MOUTH_PRESS_LEFT, # 28
MOUTH_SMILE_LEFT, # 29
MOUTH_STRETCH_LEFT, # 30
MOUTH_UPPER_UP_LEFT, # 31
MOUTH_RIGHT, # 32
MOUTH_DIMPLE_RIGHT, # 33
MOUTH_FROWN_RIGHT, # 34
MOUTH_LOWER_DOWN_RIGHT, # 35
MOUTH_PRESS_RIGHT, # 36
MOUTH_SMILE_RIGHT, # 37
MOUTH_STRETCH_RIGHT, # 38
MOUTH_UPPER_UP_RIGHT, # 39
MOUTH_CLOSE, # 40
MOUTH_FUNNEL, # 41
MOUTH_PUCKER, # 42
MOUTH_ROLL_LOWER, # 43
MOUTH_ROLL_UPPER, # 44
MOUTH_SHRUG_LOWER, # 45
MOUTH_SHRUG_UPPER, # 46
JAW_LEFT, # 47
JAW_RIGHT, # 48
JAW_FORWARD, # 49
JAW_OPEN, # 50
NEUTRAL, # 51
]
EYE_LEFT_BLENDSHAPES = [
EYE_LOOK_IN_LEFT, # 0
EYE_LOOK_OUT_LEFT, # 1
EYE_LOOK_DOWN_LEFT, # 2
EYE_LOOK_UP_LEFT, # 3
EYE_BLINK_LEFT, # 4
EYE_SQUINT_LEFT, # 5
EYE_WIDE_LEFT, # 6
]
EYE_RIGHT_BLENDSHAPES = [
EYE_LOOK_IN_RIGHT, # 7
EYE_LOOK_OUT_RIGHT, # 8
EYE_LOOK_DOWN_RIGHT, # 9
EYE_LOOK_UP_RIGHT, # 10
EYE_BLINK_RIGHT, # 11
EYE_SQUINT_RIGHT, # 12
EYE_WIDE_RIGHT, # 13
]
BROW_LEFT_BLENDSHAPES = [
BROW_DOWN_LEFT, # 14
BROW_OUTER_UP_LEFT, # 15
]
BROW_RIGHT_BLENDSHAPES = [
BROW_DOWN_RIGHT, # 16
BROW_OUTER_UP_RIGHT, # 17
]
BROW_BOTH_BLENDSHAPES = [
BROW_INNER_UP, # 18
]
NOSE_BLENDSHAPES = [
NOSE_SNEER_LEFT, # 19
NOSE_SNEER_RIGHT, # 20
]
CHECK_BLENDSHAPES = [
CHEEK_SQUINT_LEFT, # 21
CHEEK_SQUINT_RIGHT, # 22
CHEEK_PUFF, # 23
]
MOUTH_LEFT_BLENDSHAPES = [
MOUTH_LEFT, # 24
MOUTH_DIMPLE_LEFT, # 25
MOUTH_FROWN_LEFT, # 26
MOUTH_LOWER_DOWN_LEFT, # 27
MOUTH_PRESS_LEFT, # 28
MOUTH_SMILE_LEFT, # 29
MOUTH_STRETCH_LEFT, # 30
MOUTH_UPPER_UP_LEFT, # 31
]
MOUTH_RIGHT_BLENDSHAPES = [
MOUTH_RIGHT, # 32
MOUTH_DIMPLE_RIGHT, # 33
MOUTH_FROWN_RIGHT, # 34
MOUTH_LOWER_DOWN_RIGHT, # 35
MOUTH_PRESS_RIGHT, # 36
MOUTH_SMILE_RIGHT, # 37
MOUTH_STRETCH_RIGHT, # 38
MOUTH_UPPER_UP_RIGHT, # 39
]
MOUTH_BOTH_BLENDSHAPES = [
MOUTH_CLOSE, # 40
MOUTH_FUNNEL, # 41
MOUTH_PUCKER, # 42
MOUTH_ROLL_LOWER, # 43
MOUTH_ROLL_UPPER, # 44
MOUTH_SHRUG_LOWER, # 45
MOUTH_SHRUG_UPPER, # 46
]
JAW_BLENDSHAPES = [
JAW_LEFT, # 47
JAW_RIGHT, # 48
JAW_FORWARD, # 49
JAW_OPEN, # 50
]
NEUTRAL_BLENDSHAPES = [
NEUTRAL, # 51
]
COLUMN_0_BLENDSHAPES = EYE_RIGHT_BLENDSHAPES + BROW_RIGHT_BLENDSHAPES + [NOSE_SNEER_RIGHT, CHEEK_SQUINT_RIGHT]
COLUMN_1_BLENDSHAPES = EYE_LEFT_BLENDSHAPES + BROW_LEFT_BLENDSHAPES + [NOSE_SNEER_LEFT, CHEEK_SQUINT_LEFT]
COLUMN_2_BLENDSHAPES = MOUTH_RIGHT_BLENDSHAPES + [JAW_RIGHT]
COLUMN_3_BLENDSHAPES = MOUTH_LEFT_BLENDSHAPES + [JAW_LEFT]
COLUMN_4_BLENDSHAPES = [BROW_INNER_UP, CHEEK_PUFF] + MOUTH_BOTH_BLENDSHAPES + [JAW_FORWARD, JAW_OPEN, NEUTRAL]
BLENDSHAPE_COLUMNS = [
COLUMN_0_BLENDSHAPES,
COLUMN_1_BLENDSHAPES,
COLUMN_2_BLENDSHAPES,
COLUMN_3_BLENDSHAPES,
COLUMN_4_BLENDSHAPES,
]
HEAD_X = "headX"
HEAD_Y = "headY"
HEAD_Z = "headZ"
HEAD_ROTATIONS = [HEAD_X, HEAD_Y, HEAD_Z]
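# A minimal sketch of how this module relates to ifacialmocap_constants
# (hypothetical check, not part of the original file): the two blendshape lists
# agree slot-for-slot except the last entry, where MediaPipe reports "_neutral"
# instead of iFacialMocap's "tongueOut".
if __name__ == "__main__":
    import tha4.mocap.ifacialmocap_constants as ifm
    assert BLENDSHAPE_NAMES[:51] == ifm.BLENDSHAPE_NAMES[:51]
    assert BLENDSHAPE_NAMES[51] == NEUTRAL
    assert ifm.BLENDSHAPE_NAMES[51] == ifm.TONGUE_OUT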
================================================
FILE: src/tha4/mocap/mediapipe_face_pose.py
================================================
import json
import os
from typing import Optional, Dict
import numpy
class MediaPipeFacePose:
KEY_BLENDSHAPE_PARAMS = "blendshape_params"
KEY_XFORM_MATRIX = "xform_matrix"
def __init__(self, blendshape_params: Optional[Dict[str, float]], xform_matrix: Optional[numpy.ndarray]):
if blendshape_params is None:
blendshape_params = {}
        if xform_matrix is None:
            xform_matrix = numpy.zeros((4, 4))
            for i in range(4):
                xform_matrix[i, i] = 1.0
        self.blendshape_params = blendshape_params
        self.xform_matrix = xform_matrix
def get_json(self):
return {
MediaPipeFacePose.KEY_BLENDSHAPE_PARAMS: self.blendshape_params.copy(),
MediaPipeFacePose.KEY_XFORM_MATRIX: self.xform_matrix.tolist()
}
def save(self, file_name: str):
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, "wt") as fout:
fout.write(json.dumps(self.get_json()))
@staticmethod
def load(file_name: str):
with open(file_name, "rt") as fin:
s = fin.read()
json_data = json.loads(s)
        return MediaPipeFacePose(
            json_data[MediaPipeFacePose.KEY_BLENDSHAPE_PARAMS],
            xform_matrix=numpy.array(json_data[MediaPipeFacePose.KEY_XFORM_MATRIX]))
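# A minimal save/load round-trip sketch (hypothetical usage; the path is a
# placeholder): poses serialize to JSON, with the transform stored as a nested list.
if __name__ == "__main__":
    pose = MediaPipeFacePose({"jawOpen": 0.25}, numpy.identity(4))
    pose.save("data/pose_example.json")  # placeholder path
    loaded = MediaPipeFacePose.load("data/pose_example.json")
    assert loaded.blendshape_params["jawOpen"] == 0.25
    assert numpy.allclose(loaded.xform_matrix, numpy.identity(4))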
================================================
FILE: src/tha4/mocap/mediapipe_face_pose_converter.py
================================================
from abc import ABC, abstractmethod
from typing import List, Callable, Optional
from tha4.mocap.mediapipe_face_pose import MediaPipeFacePose
class MediaPipeFacePoseConverter(ABC):
@abstractmethod
def convert(self, mediapipe_face_pose: MediaPipeFacePose) -> List[float]:
pass
@abstractmethod
def init_pose_converter_panel(
self,
parent,
current_pose_supplier: Callable[[], Optional[MediaPipeFacePose]]):
pass
================================================
FILE: src/tha4/mocap/mediapipe_face_pose_converter_00.py
================================================
import math
import time
from enum import Enum
from typing import Optional, List, Callable
import numpy
import scipy.optimize
import wx
from scipy.spatial.transform import Rotation
from tha4.poser.modes.pose_parameters import get_pose_parameters
from tha4.mocap.mediapipe_constants import MOUTH_SMILE_LEFT, MOUTH_SHRUG_UPPER, MOUTH_SMILE_RIGHT, \
BROW_INNER_UP, BROW_OUTER_UP_RIGHT, BROW_OUTER_UP_LEFT, BROW_DOWN_LEFT, BROW_DOWN_RIGHT, EYE_WIDE_LEFT, \
EYE_WIDE_RIGHT, EYE_BLINK_LEFT, EYE_BLINK_RIGHT, CHEEK_SQUINT_LEFT, CHEEK_SQUINT_RIGHT, EYE_LOOK_IN_LEFT, \
EYE_LOOK_OUT_LEFT, EYE_LOOK_IN_RIGHT, EYE_LOOK_OUT_RIGHT, EYE_LOOK_UP_LEFT, EYE_LOOK_UP_RIGHT, EYE_LOOK_DOWN_RIGHT, \
EYE_LOOK_DOWN_LEFT, JAW_OPEN, MOUTH_FROWN_LEFT, MOUTH_FROWN_RIGHT, \
MOUTH_LOWER_DOWN_LEFT, MOUTH_LOWER_DOWN_RIGHT, MOUTH_FUNNEL, MOUTH_PUCKER
from tha4.mocap.mediapipe_face_pose import MediaPipeFacePose
from tha4.mocap.mediapipe_face_pose_converter import MediaPipeFacePoseConverter
class EyebrowDownMode(Enum):
TROUBLED = 1
ANGRY = 2
LOWERED = 3
SERIOUS = 4
class WinkMode(Enum):
NORMAL = 1
RELAXED = 2
def rad_to_deg(rad):
return rad * 180.0 / math.pi
def deg_to_rad(deg):
return deg * math.pi / 180.0
def clamp(x, min_value, max_value):
return max(min_value, min(max_value, x))
class MediaPipeFacePoseConverter00Args:
def __init__(self,
smile_threshold_min: float = 0.4,
smile_threshold_max: float = 0.6,
eyebrow_down_mode: EyebrowDownMode = EyebrowDownMode.ANGRY,
wink_mode: WinkMode = WinkMode.NORMAL,
eye_surprised_max: float = 0.5,
eye_blink_max: float = 0.8,
eyebrow_down_max: float = 0.4,
cheek_squint_min: float = 0.1,
cheek_squint_max: float = 0.7,
eye_rotation_factor: float = 1.0 / 0.75,
jaw_open_min: float = 0.1,
jaw_open_max: float = 0.4,
mouth_frown_max: float = 0.6,
mouth_funnel_min: float = 0.25,
mouth_funnel_max: float = 0.5,
iris_small_left=0.0,
iris_small_right=0.0,
head_x_offset=0.0,
head_y_offset=0.0,
head_z_offset=0.0):
        self.iris_small_right = iris_small_right
        self.iris_small_left = iris_small_left
self.wink_mode = wink_mode
self.mouth_funnel_max = mouth_funnel_max
self.mouth_funnel_min = mouth_funnel_min
self.mouth_frown_max = mouth_frown_max
self.jaw_open_max = jaw_open_max
self.jaw_open_min = jaw_open_min
self.eye_rotation_factor = eye_rotation_factor
self.cheek_squint_max = cheek_squint_max
self.cheek_squint_min = cheek_squint_min
self.eyebrow_down_max = eyebrow_down_max
self.eye_blink_max = eye_blink_max
self.eye_surprised_max = eye_surprised_max
self.smile_threshold_min = smile_threshold_min
self.smile_threshold_max = smile_threshold_max
self.head_z_offset = head_z_offset
self.head_y_offset = head_y_offset
self.head_x_offset = head_x_offset
self.eyebrow_down_mode = eyebrow_down_mode
def set_smile_threshold_min(self, new_value: float):
self.smile_threshold_min = new_value
def set_smile_threshold_max(self, new_value: float):
self.smile_threshold_max = new_value
def set_eye_surprised_max(self, new_value: float):
self.eye_surprised_max = new_value
def set_eye_blink_max(self, new_value: float):
self.eye_blink_max = new_value
def set_eyebrow_down_max(self, new_value: float):
self.eyebrow_down_max = new_value
def set_cheek_squint_min(self, new_value: float):
self.cheek_squint_min = new_value
def set_cheek_squint_max(self, new_value: float):
self.cheek_squint_max = new_value
def set_jaw_open_min(self, new_value: float):
self.jaw_open_min = new_value
def set_jaw_open_max(self, new_value: float):
self.jaw_open_max = new_value
def set_mouth_frown_max(self, new_value: float):
self.mouth_frown_max = new_value
def set_mouth_funnel_min(self, new_value: float):
self.mouth_funnel_min = new_value
    def set_mouth_funnel_max(self, new_value: float):
        self.mouth_funnel_max = new_value
class MediaPoseFacePoseConverter00(MediaPipeFacePoseConverter):
def __init__(self, args: Optional[MediaPipeFacePoseConverter00Args] = None):
super().__init__()
if args is None:
args = MediaPipeFacePoseConverter00Args()
self.args = args
pose_parameters = get_pose_parameters()
self.pose_size = 45
self.eyebrow_troubled_left_index = pose_parameters.get_parameter_index("eyebrow_troubled_left")
self.eyebrow_troubled_right_index = pose_parameters.get_parameter_index("eyebrow_troubled_right")
self.eyebrow_angry_left_index = pose_parameters.get_parameter_index("eyebrow_angry_left")
self.eyebrow_angry_right_index = pose_parameters.get_parameter_index("eyebrow_angry_right")
self.eyebrow_happy_left_index = pose_parameters.get_parameter_index("eyebrow_happy_left")
self.eyebrow_happy_right_index = pose_parameters.get_parameter_index("eyebrow_happy_right")
self.eyebrow_raised_left_index = pose_parameters.get_parameter_index("eyebrow_raised_left")
self.eyebrow_raised_right_index = pose_parameters.get_parameter_index("eyebrow_raised_right")
self.eyebrow_lowered_left_index = pose_parameters.get_parameter_index("eyebrow_lowered_left")
self.eyebrow_lowered_right_index = pose_parameters.get_parameter_index("eyebrow_lowered_right")
self.eyebrow_serious_left_index = pose_parameters.get_parameter_index("eyebrow_serious_left")
self.eyebrow_serious_right_index = pose_parameters.get_parameter_index("eyebrow_serious_right")
self.eye_surprised_left_index = pose_parameters.get_parameter_index("eye_surprised_left")
self.eye_surprised_right_index = pose_parameters.get_parameter_index("eye_surprised_right")
self.eye_wink_left_index = pose_parameters.get_parameter_index("eye_wink_left")
self.eye_wink_right_index = pose_parameters.get_parameter_index("eye_wink_right")
self.eye_happy_wink_left_index = pose_parameters.get_parameter_index("eye_happy_wink_left")
self.eye_happy_wink_right_index = pose_parameters.get_parameter_index("eye_happy_wink_right")
self.eye_relaxed_left_index = pose_parameters.get_parameter_index("eye_relaxed_left")
self.eye_relaxed_right_index = pose_parameters.get_parameter_index("eye_relaxed_right")
self.eye_raised_lower_eyelid_left_index = pose_parameters.get_parameter_index("eye_raised_lower_eyelid_left")
self.eye_raised_lower_eyelid_right_index = pose_parameters.get_parameter_index("eye_raised_lower_eyelid_right")
self.iris_small_left_index = pose_parameters.get_parameter_index("iris_small_left")
self.iris_small_right_index = pose_parameters.get_parameter_index("iris_small_right")
self.iris_rotation_x_index = pose_parameters.get_parameter_index("iris_rotation_x")
self.iris_rotation_y_index = pose_parameters.get_parameter_index("iris_rotation_y")
self.head_x_index = pose_parameters.get_parameter_index("head_x")
self.head_y_index = pose_parameters.get_parameter_index("head_y")
self.neck_z_index = pose_parameters.get_parameter_index("neck_z")
self.mouth_aaa_index = pose_parameters.get_parameter_index("mouth_aaa")
self.mouth_iii_index = pose_parameters.get_parameter_index("mouth_iii")
self.mouth_uuu_index = pose_parameters.get_parameter_index("mouth_uuu")
self.mouth_eee_index = pose_parameters.get_parameter_index("mouth_eee")
self.mouth_ooo_index = pose_parameters.get_parameter_index("mouth_ooo")
self.mouth_lowered_corner_left_index = pose_parameters.get_parameter_index("mouth_lowered_corner_left")
self.mouth_lowered_corner_right_index = pose_parameters.get_parameter_index("mouth_lowered_corner_right")
self.mouth_raised_corner_left_index = pose_parameters.get_parameter_index("mouth_raised_corner_left")
self.mouth_raised_corner_right_index = pose_parameters.get_parameter_index("mouth_raised_corner_right")
self.body_y_index = pose_parameters.get_parameter_index("body_y")
self.body_z_index = pose_parameters.get_parameter_index("body_z")
self.breathing_index = pose_parameters.get_parameter_index("breathing")
self.breathing_start_time = time.time()
self.panel = None
self.current_pose_supplier = None
def init_pose_converter_panel(
self,
parent,
current_pose_supplier: Callable[[], Optional[MediaPipeFacePose]]):
self.panel = wx.Panel(parent, style=wx.SIMPLE_BORDER)
self.panel_sizer = wx.BoxSizer(wx.VERTICAL)
self.panel.SetSizer(self.panel_sizer)
self.panel.SetAutoLayout(1)
parent.GetSizer().Add(self.panel, 0, wx.EXPAND)
self.current_pose_supplier = current_pose_supplier
if True:
eyebrow_down_mode_text = wx.StaticText(self.panel, label=" --- Eyebrow Down Mode --- ",
style=wx.ALIGN_CENTER)
self.panel_sizer.Add(eyebrow_down_mode_text, 0, wx.EXPAND)
self.eyebrow_down_mode_choice = wx.Choice(
self.panel,
choices=[
"ANGRY",
"TROUBLED",
"SERIOUS",
"LOWERED",
])
self.eyebrow_down_mode_choice.SetSelection(0)
self.panel_sizer.Add(self.eyebrow_down_mode_choice, 0, wx.EXPAND)
self.eyebrow_down_mode_choice.Bind(wx.EVT_CHOICE, self.change_eyebrow_down_mode)
if True:
separator = wx.StaticLine(self.panel, -1, size=(256, 5))
self.panel_sizer.Add(separator, 0, wx.EXPAND)
wink_mode_text = wx.StaticText(self.panel, label=" --- Wink Mode --- ", style=wx.ALIGN_CENTER)
self.panel_sizer.Add(wink_mode_text, 0, wx.EXPAND)
self.wink_mode_choice = wx.Choice(
self.panel,
choices=[
"NORMAL",
"RELAXED",
])
self.wink_mode_choice.SetSelection(0)
self.panel_sizer.Add(self.wink_mode_choice, 0, wx.EXPAND)
self.wink_mode_choice.Bind(wx.EVT_CHOICE, self.change_wink_mode)
if True:
separator = wx.StaticLine(self.panel, -1, size=(256, 5))
self.panel_sizer.Add(separator, 0, wx.EXPAND)
iris_size_text = wx.StaticText(self.panel, label=" --- Iris Size --- ", style=wx.ALIGN_CENTER)
self.panel_sizer.Add(iris_size_text, 0, wx.EXPAND)
self.iris_left_slider = wx.Slider(self.panel, minValue=0, maxValue=1000, value=0, style=wx.HORIZONTAL)
self.panel_sizer.Add(self.iris_left_slider, 0, wx.EXPAND)
self.iris_left_slider.Bind(wx.EVT_SLIDER, self.change_iris_size)
self.iris_right_slider = wx.Slider(self.panel, minValue=0, maxValue=1000, value=0, style=wx.HORIZONTAL)
self.panel_sizer.Add(self.iris_right_slider, 0, wx.EXPAND)
self.iris_right_slider.Bind(wx.EVT_SLIDER, self.change_iris_size)
self.iris_right_slider.Enable(False)
self.link_left_right_irises = wx.CheckBox(
self.panel, label="Use same value for both sides")
self.link_left_right_irises.SetValue(True)
self.panel_sizer.Add(self.link_left_right_irises, wx.SizerFlags().CenterHorizontal().Border())
self.link_left_right_irises.Bind(wx.EVT_CHECKBOX, self.link_left_right_irises_clicked)
if True:
separator = wx.StaticLine(self.panel, -1, size=(256, 5))
self.panel_sizer.Add(separator, 0, wx.EXPAND)
breathing_frequency_text = wx.StaticText(
self.panel, label=" --- Breathing --- ", style=wx.ALIGN_CENTER)
self.panel_sizer.Add(breathing_frequency_text, 0, wx.EXPAND)
self.restart_breathing_cycle_button = wx.Button(self.panel, label="Restart Breathing Cycle")
self.restart_breathing_cycle_button.Bind(wx.EVT_BUTTON, self.restart_breathing_cycle_clicked)
self.panel_sizer.Add(self.restart_breathing_cycle_button, 0, wx.EXPAND)
self.breathing_frequency_slider = wx.Slider(
self.panel, minValue=0, maxValue=60, value=20, style=wx.HORIZONTAL)
self.panel_sizer.Add(self.breathing_frequency_slider, 0, wx.EXPAND)
self.breathing_gauge = wx.Gauge(self.panel, style=wx.GA_HORIZONTAL, range=1000)
self.panel_sizer.Add(self.breathing_gauge, 0, wx.EXPAND)
if True:
separator = wx.StaticLine(self.panel, -1, size=(256, 5))
self.panel_sizer.Add(separator, 0, wx.EXPAND)
face_orientation_text = wx.StaticText(
self.panel, label="--- Face Orientation ---", style=wx.ALIGN_CENTER)
self.panel_sizer.Add(face_orientation_text, 0, wx.EXPAND)
self.calibrate_face_orientation_button = wx.Button(self.panel, label="Calibrate (I'm looking forward)")
self.calibrate_face_orientation_button.Bind(wx.EVT_BUTTON, self.calibrate_face_orientation_clicked)
self.panel_sizer.Add(self.calibrate_face_orientation_button, 0, wx.EXPAND)
if True:
separator = wx.StaticLine(self.panel, -1, size=(256, 5))
self.panel_sizer.Add(separator, 0, wx.EXPAND)
            conversion_parameters_text = wx.StaticText(
                self.panel, label="--- Conversion Parameters ---", style=wx.ALIGN_CENTER)
            self.panel_sizer.Add(conversion_parameters_text, 0, wx.EXPAND)
conversion_param_panel = wx.Panel(self.panel)
self.panel_sizer.Add(conversion_param_panel, 0, wx.EXPAND)
conversion_panel_sizer = wx.FlexGridSizer(cols=2)
conversion_panel_sizer.AddGrowableCol(1)
conversion_param_panel.SetSizer(conversion_panel_sizer)
conversion_param_panel.SetAutoLayout(1)
            self.smile_threshold_min_spin = self.create_spin_control(
                conversion_param_panel,
                "Smile Threshold Min:", self.args.smile_threshold_min, self.args.set_smile_threshold_min)
            self.smile_threshold_max_spin = self.create_spin_control(
                conversion_param_panel,
                "Smile Threshold Max:", self.args.smile_threshold_max, self.args.set_smile_threshold_max)
self.eye_surprised_max_spin = self.create_spin_control(
conversion_param_panel,
"Eye Surprised Max:", self.args.eye_surprised_max, self.args.set_eye_surprised_max)
self.eye_blink_max_spin = self.create_spin_control(
conversion_param_panel,
"Eye Blink Max:", self.args.eye_blink_max, self.args.set_eye_blink_max)
self.eyebrow_down_max_spin = self.create_spin_control(
conversion_param_panel,
"Eyebrow Down Max:", self.args.eyebrow_down_max, self.args.set_eyebrow_down_max)
self.cheek_squint_min_spin = self.create_spin_control(
conversion_param_panel,
"Cheek Squint Min:", self.args.cheek_squint_min, self.args.set_cheek_squint_min)
self.cheek_squint_max_spin = self.create_spin_control(
conversion_param_panel,
"Cheek Squint Max:", self.args.cheek_squint_max, self.args.set_cheek_squint_max)
self.jaw_open_min_spin = self.create_spin_control(
conversion_param_panel,
"Jaw Open Min:", self.args.jaw_open_min, self.args.set_jaw_open_min)
self.jaw_open_max_spin = self.create_spin_control(
conversion_param_panel,
"Jaw Open Max:", self.args.jaw_open_max, self.args.set_jaw_open_max)
self.mouth_frown_max_spin = self.create_spin_control(
conversion_param_panel,
"Mouth Frown Max:", self.args.mouth_frown_max, self.args.set_mouth_frown_max)
self.mouth_funnel_min_spin = self.create_spin_control(
conversion_param_panel,
"Mouth Funnel Min:", self.args.mouth_funnel_min, self.args.set_mouth_funnel_min)
self.mouth_funnel_max_spin = self.create_spin_control(
conversion_param_panel,
"Mouth Funnel Max:", self.args.mouth_funnel_max, self.args.set_mouth_funnel_max)
self.panel_sizer.Fit(self.panel)
def create_spin_control(self, parent, label: str, initial_value: float, set_func: Callable[[float], None]):
sizer = parent.GetSizer()
text = wx.StaticText(parent, label=label)
sizer.Add(text, wx.SizerFlags().Right().Border(wx.ALL, 2))
spin_ctrl = wx.SpinCtrlDouble(
parent,
wx.ID_ANY,
min=0.0,
max=1.0,
initial=initial_value,
inc=0.01)
sizer.Add(spin_ctrl, wx.SizerFlags().Border(wx.ALL, 2).Expand())
def handler(event: wx.Event):
new_value = spin_ctrl.GetValue()
set_func(new_value)
spin_ctrl.Bind(wx.EVT_SPINCTRLDOUBLE, handler)
return spin_ctrl
def extract_euler_angles(self, mediapipe_face_pose: MediaPipeFacePose):
M = mediapipe_face_pose.xform_matrix[0:3, 0:3]
rot = Rotation.from_matrix(M)
return rot.as_euler('xyz', degrees=False)
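    # Calibration: record the current head orientation as the "looking forward"
    # reference; convert() subtracts these offsets from every subsequent frame.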
def calibrate_face_orientation_clicked(self, event: wx.Event):
if self.current_pose_supplier is None:
return
mediapipe_face_pose = self.current_pose_supplier()
if mediapipe_face_pose is None:
return
euler_angles = self.extract_euler_angles(mediapipe_face_pose)
self.args.head_x_offset = euler_angles[0]
self.args.head_y_offset = euler_angles[1]
self.args.head_z_offset = euler_angles[2]
def restart_breathing_cycle_clicked(self, event: wx.Event):
self.breathing_start_time = time.time()
def change_eyebrow_down_mode(self, event: wx.Event):
selected_index = self.eyebrow_down_mode_choice.GetSelection()
if selected_index == 0:
self.args.eyebrow_down_mode = EyebrowDownMode.ANGRY
elif selected_index == 1:
self.args.eyebrow_down_mode = EyebrowDownMode.TROUBLED
elif selected_index == 2:
self.args.eyebrow_down_mode = EyebrowDownMode.SERIOUS
else:
self.args.eyebrow_down_mode = EyebrowDownMode.LOWERED
def change_wink_mode(self, event: wx.Event):
selected_index = self.wink_mode_choice.GetSelection()
if selected_index == 0:
self.args.wink_mode = WinkMode.NORMAL
else:
self.args.wink_mode = WinkMode.RELAXED
def change_iris_size(self, event: wx.Event):
if self.link_left_right_irises.GetValue():
left_value = self.iris_left_slider.GetValue()
right_value = self.iris_right_slider.GetValue()
if left_value != right_value:
self.iris_right_slider.SetValue(left_value)
self.args.iris_small_left = left_value / 1000.0
self.args.iris_small_right = left_value / 1000.0
else:
self.args.iris_small_left = self.iris_left_slider.GetValue() / 1000.0
self.args.iris_small_right = self.iris_right_slider.GetValue() / 1000.0
def link_left_right_irises_clicked(self, event: wx.Event):
if self.link_left_right_irises.GetValue():
self.iris_right_slider.Enable(False)
else:
self.iris_right_slider.Enable(True)
self.change_iris_size(event)
def decompose_head_body_param(self, param, threshold=2.0 / 3):
if abs(param) < threshold:
return (param, 0.0)
else:
if param < 0:
sign = -1.0
else:
sign = 1.0
return (threshold * sign, (abs(param) - threshold) * sign)
def convert(self, mediapipe_face_pose: MediaPipeFacePose) -> List[float]:
pose = [0.0 for i in range(self.pose_size)]
blendshape_params = mediapipe_face_pose.blendshape_params
smile_value = \
(blendshape_params[MOUTH_SMILE_LEFT] + blendshape_params[MOUTH_SMILE_RIGHT]) / 2.0 \
+ blendshape_params[MOUTH_SHRUG_UPPER]
if self.args.smile_threshold_min >= self.args.smile_threshold_max:
smile_degree = 0.0
else:
if smile_value < self.args.smile_threshold_min:
smile_degree = 0.0
elif smile_value > self.args.smile_threshold_max:
smile_degree = 1.0
else:
smile_degree = (smile_value - self.args.smile_threshold_min) / (
self.args.smile_threshold_max - self.args.smile_threshold_min)
# Eyebrow
if True:
brow_inner_up = blendshape_params[BROW_INNER_UP]
brow_outer_up_right = blendshape_params[BROW_OUTER_UP_RIGHT]
brow_outer_up_left = blendshape_params[BROW_OUTER_UP_LEFT]
brow_up_left = clamp(brow_inner_up + brow_outer_up_left, 0.0, 1.0)
brow_up_right = clamp(brow_inner_up + brow_outer_up_right, 0.0, 1.0)
pose[self.eyebrow_raised_left_index] = brow_up_left
pose[self.eyebrow_raised_right_index] = brow_up_right
if self.args.eyebrow_down_max <= 0.0:
brow_down_left = 0.0
brow_down_right = 0.0
else:
brow_down_left = (1.0 - smile_degree) \
* clamp(blendshape_params[BROW_DOWN_LEFT] / self.args.eyebrow_down_max, 0.0, 1.0)
brow_down_right = (1.0 - smile_degree) \
* clamp(blendshape_params[BROW_DOWN_RIGHT] / self.args.eyebrow_down_max, 0.0, 1.0)
if self.args.eyebrow_down_mode == EyebrowDownMode.TROUBLED:
pose[self.eyebrow_troubled_left_index] = brow_down_left
pose[self.eyebrow_troubled_right_index] = brow_down_right
elif self.args.eyebrow_down_mode == EyebrowDownMode.ANGRY:
pose[self.eyebrow_angry_left_index] = brow_down_left
pose[self.eyebrow_angry_right_index] = brow_down_right
elif self.args.eyebrow_down_mode == EyebrowDownMode.LOWERED:
pose[self.eyebrow_lowered_left_index] = brow_down_left
pose[self.eyebrow_lowered_right_index] = brow_down_right
elif self.args.eyebrow_down_mode == EyebrowDownMode.SERIOUS:
pose[self.eyebrow_serious_left_index] = brow_down_left
pose[self.eyebrow_serious_right_index] = brow_down_right
brow_happy_value = clamp(smile_value, 0.0, 1.0) * smile_degree
pose[self.eyebrow_happy_left_index] = brow_happy_value
pose[self.eyebrow_happy_right_index] = brow_happy_value
# Eye
if True:
# Surprised
if self.args.eye_surprised_max <= 0.0:
pose[self.eye_surprised_left_index] = 0.0
pose[self.eye_surprised_right_index] = 0.0
else:
pose[self.eye_surprised_left_index] = clamp(
blendshape_params[EYE_WIDE_LEFT] / self.args.eye_surprised_max, 0.0, 1.0)
pose[self.eye_surprised_right_index] = clamp(
blendshape_params[EYE_WIDE_RIGHT] / self.args.eye_surprised_max, 0.0, 1.0)
# Wink
if self.args.wink_mode == WinkMode.NORMAL:
wink_left_index = self.eye_wink_left_index
wink_right_index = self.eye_wink_right_index
else:
wink_left_index = self.eye_relaxed_left_index
wink_right_index = self.eye_relaxed_right_index
if self.args.eye_blink_max <= 0:
pose[wink_left_index] = 0.0
pose[wink_right_index] = 0.0
pose[self.eye_happy_wink_left_index] = 0.0
pose[self.eye_happy_wink_right_index] = 0.0
else:
pose[wink_left_index] = (1.0 - smile_degree) * clamp(
blendshape_params[EYE_BLINK_LEFT] / self.args.eye_blink_max, 0.0, 1.0)
pose[wink_right_index] = (1.0 - smile_degree) * clamp(
blendshape_params[EYE_BLINK_RIGHT] / self.args.eye_blink_max, 0.0, 1.0)
pose[self.eye_happy_wink_left_index] = smile_degree * clamp(
blendshape_params[EYE_BLINK_LEFT] / self.args.eye_blink_max, 0.0, 1.0)
pose[self.eye_happy_wink_right_index] = smile_degree * clamp(
blendshape_params[EYE_BLINK_RIGHT] / self.args.eye_blink_max, 0.0, 1.0)
# Lower eyelid
cheek_squint_denom = self.args.cheek_squint_max - self.args.cheek_squint_min
if cheek_squint_denom <= 0.0:
pose[self.eye_raised_lower_eyelid_left_index] = 0.0
pose[self.eye_raised_lower_eyelid_right_index] = 0.0
else:
pose[self.eye_raised_lower_eyelid_left_index] = \
clamp(
(blendshape_params[CHEEK_SQUINT_LEFT] - self.args.cheek_squint_min) / cheek_squint_denom,
0.0, 1.0)
pose[self.eye_raised_lower_eyelid_right_index] = \
clamp(
(blendshape_params[CHEEK_SQUINT_RIGHT] - self.args.cheek_squint_min) / cheek_squint_denom,
0.0, 1.0)
# Iris rotation
if True:
eye_rotation_y = (blendshape_params[EYE_LOOK_IN_LEFT]
- blendshape_params[EYE_LOOK_OUT_LEFT]
- blendshape_params[EYE_LOOK_IN_RIGHT]
+ blendshape_params[EYE_LOOK_OUT_RIGHT]) / 2.0 * self.args.eye_rotation_factor
pose[self.iris_rotation_y_index] = clamp(eye_rotation_y, -1.0, 1.0)
eye_rotation_x = (blendshape_params[EYE_LOOK_UP_LEFT]
+ blendshape_params[EYE_LOOK_UP_RIGHT]
- blendshape_params[EYE_LOOK_DOWN_LEFT]
- blendshape_params[EYE_LOOK_DOWN_RIGHT]) / 2.0 * self.args.eye_rotation_factor
pose[self.iris_rotation_x_index] = clamp(eye_rotation_x, -1.0, 1.0)
# Iris size
if True:
pose[self.iris_small_left_index] = self.args.iris_small_left
pose[self.iris_small_right_index] = self.args.iris_small_right
# Head rotation
if True:
euler_angles = self.extract_euler_angles(mediapipe_face_pose)
euler_angles[0] -= self.args.head_x_offset
euler_angles[1] -= self.args.head_y_offset
euler_angles[2] -= self.args.head_z_offset
            # Convert radians to degrees, clamp to the rig's range, and normalize to [-1, 1].
            x_param = clamp(-euler_angles[0] * 180.0 / math.pi, -15.0, 15.0) / 15.0
pose[self.head_x_index] = x_param
y_param = clamp(-euler_angles[1] * 180.0 / math.pi, -10.0, 10.0) / 10.0
pose[self.head_y_index] = y_param
pose[self.body_y_index] = y_param
z_param = clamp(euler_angles[2] * 180.0 / math.pi, -15.0, 15.0) / 15.0
pose[self.neck_z_index] = z_param
pose[self.body_z_index] = z_param
# Mouth
if True:
jaw_open_denom = self.args.jaw_open_max - self.args.jaw_open_min
if jaw_open_denom <= 0:
mouth_open = 0.0
else:
mouth_open = clamp((blendshape_params[JAW_OPEN] - self.args.jaw_open_min) / jaw_open_denom, 0.0, 1.0)
pose[self.mouth_aaa_index] = mouth_open
pose[self.mouth_raised_corner_left_index] = clamp(smile_value, 0.0, 1.0)
pose[self.mouth_raised_corner_right_index] = clamp(smile_value, 0.0, 1.0)
is_mouth_open = mouth_open > 0.0
if not is_mouth_open:
if self.args.mouth_frown_max <= 0:
mouth_frown_value = 0.0
else:
mouth_frown_value = clamp(
(blendshape_params[MOUTH_FROWN_LEFT] + blendshape_params[
MOUTH_FROWN_RIGHT]) / self.args.mouth_frown_max, 0.0, 1.0)
pose[self.mouth_lowered_corner_left_index] = mouth_frown_value
pose[self.mouth_lowered_corner_right_index] = mouth_frown_value
else:
mouth_lower_down = clamp(
blendshape_params[MOUTH_LOWER_DOWN_LEFT] + blendshape_params[MOUTH_LOWER_DOWN_RIGHT], 0.0, 1.0)
mouth_funnel = blendshape_params[MOUTH_FUNNEL]
mouth_pucker = blendshape_params[MOUTH_PUCKER]
                # Treat the mouth as a point in (open, lower_down, funnel, pucker) space
                # and decompose it over four canonical mouth shapes below.
                mouth_point = [mouth_open, mouth_lower_down, mouth_funnel, mouth_pucker]
aaa_point = [1.0, 1.0, 0.0, 0.0]
iii_point = [0.0, 1.0, 0.0, 0.0]
uuu_point = [0.5, 0.3, 0.25, 0.75]
ooo_point = [1.0, 0.5, 0.5, 0.4]
decomp = numpy.array([0, 0, 0, 0])
M = numpy.array([
aaa_point,
iii_point,
uuu_point,
ooo_point
])
def loss(decomp):
return numpy.linalg.norm(numpy.matmul(decomp, M) - mouth_point) \
+ 0.01 * numpy.linalg.norm(decomp, ord=1)
opt_result = scipy.optimize.minimize(
loss, decomp, bounds=[(0.0, 1.0), (0.0, 1.0), (0.0, 1.0), (0.0, 1.0)])
decomp = opt_result["x"]
restricted_decomp = [decomp.item(0), decomp.item(1), decomp.item(2), decomp.item(3)]
pose[self.mouth_aaa_index] = restricted_decomp[0]
pose[self.mouth_iii_index] = restricted_decomp[1]
mouth_funnel_denom = self.args.mouth_funnel_max - self.args.mouth_funnel_min
if mouth_funnel_denom <= 0:
ooo_alpha = 0.0
uo_value = 0.0
else:
ooo_alpha = clamp((mouth_funnel - self.args.mouth_funnel_min) / mouth_funnel_denom, 0.0, 1.0)
uo_value = clamp(restricted_decomp[2] + restricted_decomp[3], 0.0, 1.0)
pose[self.mouth_uuu_index] = uo_value * (1.0 - ooo_alpha)
pose[self.mouth_ooo_index] = uo_value * ooo_alpha
if self.panel is not None:
frequency = self.breathing_frequency_slider.GetValue()
if frequency == 0:
value = 0.0
pose[self.breathing_index] = value
self.breathing_start_time = time.time()
else:
period = 60.0 / frequency
now = time.time()
diff = now - self.breathing_start_time
frac = (diff % period) / period
                # Raised-cosine breathing cycle: 0 at each period start, 1 at the midpoint.
                value = (-math.cos(2 * math.pi * frac) + 1.0) / 2.0
pose[self.breathing_index] = value
self.breathing_gauge.SetValue(int(1000 * value))
return pose
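# The block above decomposes the measured mouth point over four canonical
# mouth shapes (aaa/iii/uuu/ooo) via L1-regularized least squares. The sketch
# below replays that computation standalone; the sample input values are
# arbitrary illustrations, not captured tracking data.
if __name__ == "__main__":
    import numpy
    import scipy.optimize

    M = numpy.array([
        [1.0, 1.0, 0.0, 0.0],    # aaa
        [0.0, 1.0, 0.0, 0.0],    # iii
        [0.5, 0.3, 0.25, 0.75],  # uuu
        [1.0, 0.5, 0.5, 0.4],    # ooo
    ])
    mouth_point = numpy.array([0.6, 0.4, 0.2, 0.5])  # (open, lower_down, funnel, pucker)

    def loss(decomp):
        # Reconstruction error plus a small L1 penalty that keeps the mix sparse.
        return numpy.linalg.norm(numpy.matmul(decomp, M) - mouth_point) \
            + 0.01 * numpy.linalg.norm(decomp, ord=1)

    opt_result = scipy.optimize.minimize(
        loss, numpy.zeros(4), bounds=[(0.0, 1.0)] * 4)
    print(opt_result["x"])  # per-shape weights, each in [0, 1]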
================================================
FILE: src/tha4/nn/__init__.py
================================================
================================================
FILE: src/tha4/nn/common/__init__.py
================================================
================================================
FILE: src/tha4/nn/common/conv_block_factory.py
================================================
from typing import Optional
from tha4.nn.conv import create_conv7_block_from_block_args, create_conv3_block_from_block_args, \
create_downsample_block_from_block_args, create_conv3
from tha4.nn.resnet_block import ResnetBlock
from tha4.nn.resnet_block_seperable import ResnetBlockSeparable
from tha4.nn.separable_conv import create_separable_conv7_block, create_separable_conv3_block, \
create_separable_downsample_block, create_separable_conv3
from tha4.nn.util import BlockArgs
class ConvBlockFactory:
def __init__(self,
block_args: BlockArgs,
use_separable_convolution: bool = False):
self.use_separable_convolution = use_separable_convolution
self.block_args = block_args
def create_conv3(self,
in_channels: int,
out_channels: int,
bias: bool,
initialization_method: Optional[str] = None):
if initialization_method is None:
initialization_method = self.block_args.initialization_method
if self.use_separable_convolution:
return create_separable_conv3(
in_channels, out_channels, bias, initialization_method, self.block_args.use_spectral_norm)
else:
return create_conv3(
in_channels, out_channels, bias, initialization_method, self.block_args.use_spectral_norm)
def create_conv7_block(self, in_channels: int, out_channels: int):
if self.use_separable_convolution:
return create_separable_conv7_block(in_channels, out_channels, self.block_args)
else:
return create_conv7_block_from_block_args(in_channels, out_channels, self.block_args)
def create_conv3_block(self, in_channels: int, out_channels: int):
if self.use_separable_convolution:
return create_separable_conv3_block(in_channels, out_channels, self.block_args)
else:
return create_conv3_block_from_block_args(in_channels, out_channels, self.block_args)
def create_downsample_block(self, in_channels: int, out_channels: int, is_output_1x1: bool):
if self.use_separable_convolution:
return create_separable_downsample_block(in_channels, out_channels, is_output_1x1, self.block_args)
else:
            # Pass this factory's block_args through, as the other create_* methods do.
            return create_downsample_block_from_block_args(
                in_channels, out_channels, is_output_1x1, block_args=self.block_args)
def create_resnet_block(self, num_channels: int, is_1x1: bool):
if self.use_separable_convolution:
return ResnetBlockSeparable.create(num_channels, is_1x1, block_args=self.block_args)
else:
return ResnetBlock.create(num_channels, is_1x1, block_args=self.block_args)
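# Illustrative usage sketch (channel counts are arbitrary): the factory lets a
# single flag switch every block it emits between standard and
# depthwise-separable convolutions.
if __name__ == "__main__":
    import torch

    factory = ConvBlockFactory(BlockArgs(), use_separable_convolution=True)
    block = factory.create_conv3_block(in_channels=16, out_channels=32)
    y = block(torch.zeros(1, 16, 64, 64))
    print(tuple(y.shape))  # (1, 32, 64, 64)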
================================================
FILE: src/tha4/nn/common/poser_args.py
================================================
from typing import Optional
from torch.nn import Sigmoid, Sequential, Tanh
from tha4.nn.conv import create_conv3, create_conv3_from_block_args
from tha4.nn.nonlinearity_factory import ReLUFactory
from tha4.nn.normalization import InstanceNorm2dFactory
from tha4.nn.util import BlockArgs
class PoserArgs00:
def __init__(self,
image_size: int,
input_image_channels: int,
output_image_channels: int,
start_channels: int,
num_pose_params: int,
block_args: Optional[BlockArgs] = None):
self.num_pose_params = num_pose_params
self.start_channels = start_channels
self.output_image_channels = output_image_channels
self.input_image_channels = input_image_channels
self.image_size = image_size
if block_args is None:
self.block_args = BlockArgs(
normalization_layer_factory=InstanceNorm2dFactory(),
nonlinearity_factory=ReLUFactory(inplace=True))
else:
self.block_args = block_args
def create_alpha_block(self):
from torch.nn import Sequential
return Sequential(
create_conv3(
in_channels=self.start_channels,
out_channels=1,
bias=True,
initialization_method=self.block_args.initialization_method,
use_spectral_norm=False),
Sigmoid())
def create_all_channel_alpha_block(self):
from torch.nn import Sequential
return Sequential(
create_conv3(
in_channels=self.start_channels,
out_channels=self.output_image_channels,
bias=True,
initialization_method=self.block_args.initialization_method,
use_spectral_norm=False),
Sigmoid())
def create_color_change_block(self):
return Sequential(
create_conv3_from_block_args(
in_channels=self.start_channels,
out_channels=self.output_image_channels,
bias=True,
block_args=self.block_args),
Tanh())
def create_grid_change_block(self):
return create_conv3(
in_channels=self.start_channels,
out_channels=2,
bias=False,
initialization_method='zero',
use_spectral_norm=False)
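# Illustrative sketch (sizes are arbitrary, not a shipped configuration): the
# three standard output heads differ in channel count, activation, and
# initialization.
if __name__ == "__main__":
    args = PoserArgs00(
        image_size=256,
        input_image_channels=4,
        output_image_channels=4,
        start_channels=16,
        num_pose_params=67)
    alpha_head = args.create_alpha_block()         # start_channels -> 1, Sigmoid
    color_head = args.create_color_change_block()  # start_channels -> 4, Tanh
    grid_head = args.create_grid_change_block()    # start_channels -> 2, zero-initialized, no bias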
================================================
FILE: src/tha4/nn/common/poser_encoder_decoder_00.py
================================================
import math
from typing import Optional, List
import torch
from torch import Tensor
from torch.nn import ModuleList, Module
from tha4.nn.common.poser_args import PoserArgs00
from tha4.nn.conv import create_conv3_block_from_block_args, create_downsample_block_from_block_args, \
create_upsample_block_from_block_args
from tha4.nn.nonlinearity_factory import ReLUFactory
from tha4.nn.normalization import InstanceNorm2dFactory
from tha4.nn.resnet_block import ResnetBlock
from tha4.nn.util import BlockArgs
class PoserEncoderDecoder00Args(PoserArgs00):
def __init__(self,
image_size: int,
input_image_channels: int,
output_image_channels: int,
                 num_pose_params: int,
start_channels: int,
bottleneck_image_size,
num_bottleneck_blocks,
max_channels: int,
block_args: Optional[BlockArgs] = None):
super().__init__(
image_size, input_image_channels, output_image_channels, start_channels, num_pose_params, block_args)
self.max_channels = max_channels
self.num_bottleneck_blocks = num_bottleneck_blocks
self.bottleneck_image_size = bottleneck_image_size
assert bottleneck_image_size > 1
if block_args is None:
self.block_args = BlockArgs(
normalization_layer_factory=InstanceNorm2dFactory(),
nonlinearity_factory=ReLUFactory(inplace=True))
else:
self.block_args = block_args
class PoserEncoderDecoder00(Module):
def __init__(self, args: PoserEncoderDecoder00Args):
super().__init__()
self.args = args
self.num_levels = int(math.log2(args.image_size // args.bottleneck_image_size)) + 1
self.downsample_blocks = ModuleList()
self.downsample_blocks.append(
create_conv3_block_from_block_args(
args.input_image_channels,
args.start_channels,
args.block_args))
current_image_size = args.image_size
current_num_channels = args.start_channels
while current_image_size > args.bottleneck_image_size:
next_image_size = current_image_size // 2
next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
self.downsample_blocks.append(create_downsample_block_from_block_args(
in_channels=current_num_channels,
out_channels=next_num_channels,
is_output_1x1=False,
block_args=args.block_args))
current_image_size = next_image_size
current_num_channels = next_num_channels
assert len(self.downsample_blocks) == self.num_levels
self.bottleneck_blocks = ModuleList()
self.bottleneck_blocks.append(create_conv3_block_from_block_args(
in_channels=current_num_channels + args.num_pose_params,
out_channels=current_num_channels,
block_args=args.block_args))
for i in range(1, args.num_bottleneck_blocks):
self.bottleneck_blocks.append(
ResnetBlock.create(
num_channels=current_num_channels,
is1x1=False,
block_args=args.block_args))
self.upsample_blocks = ModuleList()
while current_image_size < args.image_size:
next_image_size = current_image_size * 2
next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
self.upsample_blocks.append(create_upsample_block_from_block_args(
in_channels=current_num_channels,
out_channels=next_num_channels,
block_args=args.block_args))
current_image_size = next_image_size
current_num_channels = next_num_channels
def get_num_output_channels_from_level(self, level: int):
return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level))
def get_num_output_channels_from_image_size(self, image_size: int):
return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels)
def forward(self, image: Tensor, pose: Optional[Tensor] = None) -> List[Tensor]:
if self.args.num_pose_params != 0:
assert pose is not None
else:
assert pose is None
outputs = []
feature = image
outputs.append(feature)
for block in self.downsample_blocks:
feature = block(feature)
outputs.append(feature)
if pose is not None:
n, c = pose.shape
pose = pose.view(n, c, 1, 1).repeat(1, 1, self.args.bottleneck_image_size, self.args.bottleneck_image_size)
feature = torch.cat([feature, pose], dim=1)
for block in self.bottleneck_blocks:
feature = block(feature)
outputs.append(feature)
for block in self.upsample_blocks:
feature = block(feature)
outputs.append(feature)
outputs.reverse()
return outputs
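# Shape sketch with arbitrary illustrative sizes: forward() collects one
# feature map per stage and reverses the list, so index 0 is the final
# full-resolution feature and the last entry is the untouched input.
if __name__ == "__main__":
    args = PoserEncoderDecoder00Args(
        image_size=128,
        input_image_channels=4,
        output_image_channels=4,
        num_pose_params=12,
        start_channels=16,
        bottleneck_image_size=16,
        num_bottleneck_blocks=3,
        max_channels=256)
    module = PoserEncoderDecoder00(args)
    outputs = module(torch.zeros(1, 4, 128, 128), torch.zeros(1, 12))
    print(tuple(outputs[0].shape))   # (1, 16, 128, 128)
    print(tuple(outputs[-1].shape))  # (1, 4, 128, 128)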
================================================
FILE: src/tha4/nn/common/poser_encoder_decoder_00_separable.py
================================================
import math
from typing import Optional, List
import torch
from torch import Tensor
from torch.nn import ModuleList, Module
from tha4.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00Args
from tha4.nn.resnet_block_seperable import ResnetBlockSeparable
from tha4.nn.separable_conv import create_separable_conv3_block, create_separable_downsample_block, \
create_separable_upsample_block
class PoserEncoderDecoder00Separable(Module):
def __init__(self, args: PoserEncoderDecoder00Args):
super().__init__()
self.args = args
self.num_levels = int(math.log2(args.image_size // args.bottleneck_image_size)) + 1
self.downsample_blocks = ModuleList()
self.downsample_blocks.append(
create_separable_conv3_block(
args.input_image_channels,
args.start_channels,
args.block_args))
current_image_size = args.image_size
current_num_channels = args.start_channels
while current_image_size > args.bottleneck_image_size:
next_image_size = current_image_size // 2
next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
self.downsample_blocks.append(create_separable_downsample_block(
in_channels=current_num_channels,
out_channels=next_num_channels,
is_output_1x1=False,
block_args=args.block_args))
current_image_size = next_image_size
current_num_channels = next_num_channels
assert len(self.downsample_blocks) == self.num_levels
self.bottleneck_blocks = ModuleList()
self.bottleneck_blocks.append(create_separable_conv3_block(
in_channels=current_num_channels + args.num_pose_params,
out_channels=current_num_channels,
block_args=args.block_args))
for i in range(1, args.num_bottleneck_blocks):
self.bottleneck_blocks.append(
ResnetBlockSeparable.create(
num_channels=current_num_channels,
is1x1=False,
block_args=args.block_args))
self.upsample_blocks = ModuleList()
while current_image_size < args.image_size:
next_image_size = current_image_size * 2
next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
self.upsample_blocks.append(create_separable_upsample_block(
in_channels=current_num_channels,
out_channels=next_num_channels,
block_args=args.block_args))
current_image_size = next_image_size
current_num_channels = next_num_channels
def get_num_output_channels_from_level(self, level: int):
return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level))
def get_num_output_channels_from_image_size(self, image_size: int):
return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels)
def forward(self, image: Tensor, pose: Optional[Tensor] = None) -> List[Tensor]:
if self.args.num_pose_params != 0:
assert pose is not None
else:
assert pose is None
outputs = []
feature = image
outputs.append(feature)
for block in self.downsample_blocks:
feature = block(feature)
outputs.append(feature)
if pose is not None:
n, c = pose.shape
pose = pose.view(n, c, 1, 1).repeat(1, 1, self.args.bottleneck_image_size, self.args.bottleneck_image_size)
feature = torch.cat([feature, pose], dim=1)
for block in self.bottleneck_blocks:
feature = block(feature)
outputs.append(feature)
for block in self.upsample_blocks:
feature = block(feature)
outputs.append(feature)
outputs.reverse()
return outputs
================================================
FILE: src/tha4/nn/common/resize_conv_encoder_decoder.py
================================================
import math
from typing import Optional, List
import torch
from torch import Tensor
from torch.nn import Module, ModuleList, Sequential, Upsample
from tha4.nn.common.conv_block_factory import ConvBlockFactory
from tha4.nn.nonlinearity_factory import LeakyReLUFactory
from tha4.nn.normalization import InstanceNorm2dFactory
from tha4.nn.util import BlockArgs
class ResizeConvEncoderDecoderArgs:
def __init__(self,
image_size: int,
input_channels: int,
start_channels: int,
bottleneck_image_size,
num_bottleneck_blocks,
max_channels: int,
block_args: Optional[BlockArgs] = None,
upsample_mode: str = 'bilinear',
use_separable_convolution=False):
self.use_separable_convolution = use_separable_convolution
self.upsample_mode = upsample_mode
self.block_args = block_args
self.max_channels = max_channels
self.num_bottleneck_blocks = num_bottleneck_blocks
self.bottleneck_image_size = bottleneck_image_size
self.start_channels = start_channels
self.image_size = image_size
self.input_channels = input_channels
class ResizeConvEncoderDecoder(Module):
def __init__(self, args: ResizeConvEncoderDecoderArgs):
super().__init__()
self.args = args
self.num_levels = int(math.log2(args.image_size // args.bottleneck_image_size)) + 1
conv_block_factory = ConvBlockFactory(args.block_args, args.use_separable_convolution)
self.downsample_blocks = ModuleList()
self.downsample_blocks.append(conv_block_factory.create_conv7_block(args.input_channels, args.start_channels))
current_image_size = args.image_size
current_num_channels = args.start_channels
while current_image_size > args.bottleneck_image_size:
next_image_size = current_image_size // 2
next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
self.downsample_blocks.append(conv_block_factory.create_downsample_block(
in_channels=current_num_channels,
out_channels=next_num_channels,
is_output_1x1=False))
current_image_size = next_image_size
current_num_channels = next_num_channels
assert len(self.downsample_blocks) == self.num_levels
self.bottleneck_blocks = ModuleList()
for i in range(args.num_bottleneck_blocks):
self.bottleneck_blocks.append(conv_block_factory.create_resnet_block(current_num_channels, is_1x1=False))
self.output_image_sizes = [current_image_size]
self.output_num_channels = [current_num_channels]
self.upsample_blocks = ModuleList()
if args.upsample_mode == 'nearest':
align_corners = None
else:
align_corners = False
while current_image_size < args.image_size:
next_image_size = current_image_size * 2
next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
self.upsample_blocks.append(
Sequential(
Upsample(scale_factor=2, mode=args.upsample_mode, align_corners=align_corners),
conv_block_factory.create_conv3_block(
in_channels=current_num_channels, out_channels=next_num_channels)))
current_image_size = next_image_size
current_num_channels = next_num_channels
self.output_image_sizes.append(current_image_size)
self.output_num_channels.append(current_num_channels)
def get_num_output_channels_from_level(self, level: int):
return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level))
def get_num_output_channels_from_image_size(self, image_size: int):
return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels)
def forward(self, feature: Tensor) -> List[Tensor]:
outputs = []
for block in self.downsample_blocks:
feature = block(feature)
for block in self.bottleneck_blocks:
feature = block(feature)
outputs.append(feature)
for block in self.upsample_blocks:
feature = block(feature)
outputs.append(feature)
return outputs
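# Illustrative sketch (arbitrary sizes): upsampling here is interpolation
# followed by a 3x3 conv block rather than a transposed convolution, and
# forward() returns features from the bottleneck resolution back up to the
# input resolution.
if __name__ == "__main__":
    args = ResizeConvEncoderDecoderArgs(
        image_size=128,
        input_channels=4,
        start_channels=16,
        bottleneck_image_size=16,
        num_bottleneck_blocks=2,
        max_channels=256,
        block_args=BlockArgs(
            normalization_layer_factory=InstanceNorm2dFactory(),
            nonlinearity_factory=LeakyReLUFactory(negative_slope=0.2, inplace=True)))
    net = ResizeConvEncoderDecoder(args)
    features = net(torch.zeros(1, 4, 128, 128))
    print([f.shape[-1] for f in features])  # coarsest (16) up to full resolution (128)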
================================================
FILE: src/tha4/nn/common/resize_conv_unet.py
================================================
from typing import Optional, List
import torch
from torch import Tensor
from torch.nn import ModuleList, Module, Upsample
from tha4.nn.common.conv_block_factory import ConvBlockFactory
from tha4.nn.nonlinearity_factory import ReLUFactory
from tha4.nn.normalization import InstanceNorm2dFactory
from tha4.nn.util import BlockArgs
class ResizeConvUNetArgs:
def __init__(self,
image_size: int,
input_channels: int,
start_channels: int,
bottleneck_image_size: int,
num_bottleneck_blocks: int,
max_channels: int,
upsample_mode: str = 'bilinear',
block_args: Optional[BlockArgs] = None,
use_separable_convolution: bool = False):
if block_args is None:
block_args = BlockArgs(
normalization_layer_factory=InstanceNorm2dFactory(),
nonlinearity_factory=ReLUFactory(inplace=False))
self.use_separable_convolution = use_separable_convolution
self.block_args = block_args
self.upsample_mode = upsample_mode
self.max_channels = max_channels
self.num_bottleneck_blocks = num_bottleneck_blocks
self.bottleneck_image_size = bottleneck_image_size
self.input_channels = input_channels
self.start_channels = start_channels
self.image_size = image_size
class ResizeConvUNet(Module):
def __init__(self, args: ResizeConvUNetArgs):
super().__init__()
self.args = args
conv_block_factory = ConvBlockFactory(args.block_args, args.use_separable_convolution)
self.downsample_blocks = ModuleList()
self.downsample_blocks.append(conv_block_factory.create_conv3_block(
self.args.input_channels,
self.args.start_channels))
current_channels = self.args.start_channels
current_size = self.args.image_size
size_to_channel = {
current_size: current_channels
}
while current_size > self.args.bottleneck_image_size:
next_size = current_size // 2
next_channels = min(self.args.max_channels, current_channels * 2)
self.downsample_blocks.append(conv_block_factory.create_downsample_block(
current_channels,
next_channels,
is_output_1x1=False))
current_size = next_size
current_channels = next_channels
size_to_channel[current_size] = current_channels
self.bottleneck_blocks = ModuleList()
for i in range(self.args.num_bottleneck_blocks):
self.bottleneck_blocks.append(conv_block_factory.create_resnet_block(current_channels, is_1x1=False))
self.output_image_sizes = [current_size]
self.output_num_channels = [current_channels]
self.upsample_blocks = ModuleList()
while current_size < self.args.image_size:
next_size = current_size * 2
next_channels = size_to_channel[next_size]
self.upsample_blocks.append(conv_block_factory.create_conv3_block(
current_channels + next_channels,
next_channels))
current_size = next_size
current_channels = next_channels
self.output_image_sizes.append(current_size)
self.output_num_channels.append(current_channels)
if args.upsample_mode == 'nearest':
align_corners = None
else:
align_corners = False
self.double_resolution = Upsample(scale_factor=2, mode=args.upsample_mode, align_corners=align_corners)
def forward(self, feature: Tensor) -> List[Tensor]:
downsampled_features = []
for block in self.downsample_blocks:
feature = block(feature)
downsampled_features.append(feature)
for block in self.bottleneck_blocks:
feature = block(feature)
outputs = [feature]
for i in range(0, len(self.upsample_blocks)):
feature = self.double_resolution(feature)
feature = torch.cat([feature, downsampled_features[-i - 2]], dim=1)
feature = self.upsample_blocks[i](feature)
outputs.append(feature)
return outputs
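# Shape sketch with arbitrary sizes: each upsampling step concatenates the
# doubled-resolution feature with the matching downsampled feature (a standard
# U-Net skip), which is why the upsample blocks above take
# current_channels + next_channels as input.
if __name__ == "__main__":
    net = ResizeConvUNet(ResizeConvUNetArgs(
        image_size=64,
        input_channels=4,
        start_channels=8,
        bottleneck_image_size=8,
        num_bottleneck_blocks=1,
        max_channels=64))
    outputs = net(torch.zeros(1, 4, 64, 64))
    print([f.shape[-1] for f in outputs])  # [8, 16, 32, 64]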
================================================
FILE: src/tha4/nn/common/unet.py
================================================
import math
from enum import Enum
from typing import Optional, List
import torch
from torch import zero_, Tensor
from torch.nn import Module, GroupNorm, Sequential, SiLU, Conv2d, AvgPool2d, Linear, Dropout, ModuleList
from torch.nn.functional import interpolate
from tha4.shion.core.module_factory import ModuleFactory
class Identity(Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x
class IdentityFactory(ModuleFactory):
def create(self) -> Module:
return Identity()
def init_to_zero(module: Module):
with torch.no_grad():
zero_(module.weight)
zero_(module.bias)
return module
class Upsample(Module):
def __init__(self, in_channels: int, out_channels: Optional[int] = None, use_conv: bool = False):
super().__init__()
if out_channels is None:
out_channels = in_channels
self.in_channels = in_channels
if use_conv or in_channels != out_channels:
self.postprocess = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else:
self.postprocess = Identity()
def forward(self, x):
assert x.shape[1] == self.in_channels
return self.postprocess(interpolate(x, scale_factor=2, mode="nearest"))
class Downsample(Module):
def __init__(self, in_channels: int, out_channels: Optional[int] = None, use_conv: bool = False):
super().__init__()
if out_channels is None:
out_channels = in_channels
self.in_channels = in_channels
if use_conv or in_channels != out_channels:
self.op = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
else:
self.op = AvgPool2d(kernel_size=2, stride=2)
def forward(self, x):
assert x.shape[1] == self.in_channels
return self.op(x)
def GroupNorm32(channels):
return GroupNorm(min(32, channels), channels)
class SamplingMode(Enum):
SAME_RESOLUTION = 0
UPSAMPLING = 1
DOWNSAMPING = 2
class ResBlockArgs:
def __init__(self,
dropout_prob: float,
use_cond0: bool = True,
use_cond1: bool = False,
init_conditioned_residual_to_zero: bool = False,
use_conv_on_skip_connection: bool = False):
assert not use_cond1 or use_cond0
self.use_conv_on_skip_connection = use_conv_on_skip_connection
self.use_cond1 = use_cond1
self.use_cond0 = use_cond0
self.init_conditioned_residual_to_zero = init_conditioned_residual_to_zero
self.dropout_prob = dropout_prob
def apply_scaleshift(x: Tensor, scaleshift: Tensor, condition_bias: float = 1.0) -> Tensor:
assert len(scaleshift.shape) == 2
assert len(x.shape) == 4
assert x.shape[0] == scaleshift.shape[0]
assert 2 * x.shape[1] == scaleshift.shape[1]
scaleshift = scaleshift.reshape(scaleshift.shape[0], scaleshift.shape[1], 1, 1)
scale, shift = torch.chunk(scaleshift, 2, dim=1)
return x * (condition_bias + scale) + shift
class ResBlock(Module):
def __init__(self,
in_channels: int,
out_channels: int,
cond0_channels: Optional[int] = None,
cond1_channels: Optional[int] = None,
sampling_mode: SamplingMode = SamplingMode.SAME_RESOLUTION,
dropout_prob: float = 0.1,
condition_bias: float = 1.0):
super().__init__()
assert cond0_channels is not None or cond1_channels is None
self.in_channels = in_channels
self.out_channels = out_channels
self.sampling_mode = sampling_mode
self.cond0_channels = cond0_channels
self.cond1_channels = cond1_channels
self.condition_bias = condition_bias
if sampling_mode == SamplingMode.UPSAMPLING:
self.x_resample = Upsample(in_channels)
self.h_resample = Upsample(in_channels)
elif sampling_mode == SamplingMode.DOWNSAMPING:
self.x_resample = Downsample(in_channels)
self.h_resample = Downsample(in_channels)
else:
self.x_resample = Identity()
self.h_resample = Identity()
self.nonlinear = SiLU()
# Layers before conditioning
self.norm0 = GroupNorm32(in_channels)
self.conv0 = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
# Conditioning layers
if cond0_channels is not None:
self.cond0_layers = Sequential(
SiLU(),
Linear(cond0_channels, 2 * out_channels))
self.norm1 = GroupNorm32(out_channels)
self.dropout = Dropout(dropout_prob)
self.conv1 = init_to_zero(Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1))
        if cond1_channels is not None:
            self.cond1_layers = Sequential(
                SiLU(),
                # This head consumes cond1, so its input width must be cond1_channels
                # (the original passed cond0_channels here).
                Linear(cond1_channels, 2 * out_channels))
# Skip layer
if in_channels == out_channels:
self.skip = Identity()
else:
self.skip = Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x: Tensor, cond0: Optional[Tensor] = None, cond1: Optional[Tensor] = None) -> Tensor:
assert self.cond0_channels is None or cond0 is not None
assert self.cond1_channels is None or cond1 is not None
h = self.conv0(self.h_resample(self.nonlinear(self.norm0(x))))
if self.cond0_channels is not None:
h = self.norm1(h)
h = apply_scaleshift(h, self.cond0_layers(cond0), self.condition_bias)
if self.cond1_channels is not None:
h = apply_scaleshift(h, self.cond1_layers(cond1), self.condition_bias)
h = self.conv1(self.dropout(self.nonlinear(h)))
return self.skip(self.x_resample(x)) + h
class AttentionBlockArgs:
def __init__(self,
num_heads: Optional[int] = 1,
num_head_channels: Optional[int] = None,
use_new_attention_order: bool = False):
self.use_new_attention_order = use_new_attention_order
self.num_head_channels = num_head_channels
self.num_heads = num_heads
def qkv_attention_legacy(qkv: torch.Tensor, num_heads: int):
assert len(qkv.shape) == 3
B, W, L = qkv.shape
H = num_heads
assert W % (3 * H) == 0
C = W // (3 * H)
q, k, v = qkv.reshape(B * H, C * 3, L).split(C, dim=1)
scale = 1.0 / math.sqrt(math.sqrt(C))
weight = torch.einsum('bct,bcs->bts', q * scale, k * scale)
weight = torch.softmax(weight, dim=-1)
output = torch.einsum("bts,bcs->bct", weight, v)
return output.reshape(B, H * C, L)
def qkv_attention(qkv: torch.Tensor, num_heads: int):
B, W, L = qkv.shape
H = num_heads
assert W % (3 * H) == 0
C = W // (3 * H)
q, k, v = qkv.chunk(3, dim=1)
scale = 1.0 / math.sqrt(math.sqrt(C))
weight = torch.einsum("bct,bcs->bts", (q * scale).view(B * H, C, L), (k * scale).view(B * H, C, L))
weight = torch.softmax(weight, dim=-1)
output = torch.einsum("bts,bcs->bct", weight, v.reshape(B * H, C, L))
return output.reshape(B, H * C, L)
class AttentionBlock(Module):
def __init__(self,
num_channels: int,
args: AttentionBlockArgs):
super().__init__()
self.use_new_attention_order = args.use_new_attention_order
if args.num_head_channels is None:
assert args.num_heads is not None
assert num_channels % args.num_heads == 0
self.num_heads = args.num_heads
self.num_head_channels = num_channels // self.num_heads
elif args.num_heads is None:
assert args.num_head_channels is not None
assert num_channels % args.num_head_channels == 0
self.num_heads = num_channels // args.num_head_channels
self.num_head_channels = args.num_head_channels
self.norm = GroupNorm32(num_channels)
self.qkv = Conv2d(num_channels, 3 * num_channels, kernel_size=1, stride=1, padding=0)
self.conv = Conv2d(num_channels, num_channels, kernel_size=1, stride=1, padding=0)
with torch.no_grad():
zero_(self.conv.weight)
zero_(self.conv.bias)
def forward(self, x: torch.Tensor):
assert len(x.shape) == 4
B, C, H, W = x.shape
qkv = self.qkv(self.norm(x)).reshape(B, 3 * C, H * W)
if self.use_new_attention_order:
h = qkv_attention(qkv, self.num_heads)
else:
h = qkv_attention_legacy(qkv, self.num_heads)
h = self.conv(h.reshape(B, C, H, W))
return x + h
class Arity3To1(Module):
def __init__(self, module: Module):
super().__init__()
self.module = module
def forward(self, x: Tensor, y: Optional[Tensor] = None, z: Optional[Tensor] = None):
return self.module(x)
class DownsamplingBlock(Module):
def __init__(self,
in_channels: int,
out_channels: int,
cond0_channels: Optional[int],
cond1_channels: Optional[int],
num_res_blocks: int,
dropout_prob: float,
use_attention: bool,
perform_downsampling: bool,
resample_with_res_block: bool,
use_conv_to_resample: bool,
attention_block_args: AttentionBlockArgs,
condition_bias: float = 1.0):
super().__init__()
self.use_attention = use_attention
self.res_blocks = ModuleList()
self.attention_blocks = ModuleList()
self.perform_downsampling = perform_downsampling
self.output_channels = []
for j in range(num_res_blocks):
self.res_blocks.append(ResBlock(
in_channels=in_channels if j == 0 else out_channels,
out_channels=out_channels,
cond0_channels=cond0_channels,
cond1_channels=cond1_channels,
dropout_prob=dropout_prob,
condition_bias=condition_bias))
if use_attention:
self.attention_blocks.append(AttentionBlock(out_channels, attention_block_args))
self.output_channels.append(out_channels)
if perform_downsampling:
if resample_with_res_block:
self.downsample = ResBlock(
in_channels=out_channels,
out_channels=out_channels,
cond0_channels=cond0_channels,
cond1_channels=cond1_channels,
dropout_prob=dropout_prob,
sampling_mode=SamplingMode.DOWNSAMPING,
condition_bias=condition_bias)
else:
                # use_conv must be passed by keyword; the second positional
                # parameter of Downsample is out_channels.
                self.downsample = Arity3To1(Downsample(out_channels, use_conv=use_conv_to_resample))
self.output_channels.append(out_channels)
def forward(self, h: Tensor, cond0: Optional[Tensor] = None, cond1: Optional[Tensor] = None) -> List[Tensor]:
hs = []
for i in range(len(self.res_blocks)):
h = self.res_blocks[i].forward(h, cond0, cond1)
if self.use_attention:
h = self.attention_blocks[i].forward(h)
hs.append(h)
if self.perform_downsampling:
hs.append(self.downsample(h, cond0, cond1))
return hs
class UpsamplingBlock(Module):
def __init__(self,
in_channels: int,
out_channels: int,
cond0_channels: Optional[int],
cond1_channels: Optional[int],
num_resnet_blocks: int,
skip_channels: List[int],
dropout_prob: float,
use_attention: bool,
perform_upsampling: bool,
resample_with_res_block: bool,
use_conv_to_resample: bool,
attention_block_args: AttentionBlockArgs,
condition_bias: float = 1.0):
super().__init__()
self.use_attention = use_attention
self.resnet_blocks = ModuleList()
self.attention_blocks = ModuleList()
self.perform_upsampling = perform_upsampling
for i in range(num_resnet_blocks):
self.resnet_blocks.append(ResBlock(
in_channels=(in_channels if i == 0 else out_channels) + skip_channels[i],
out_channels=out_channels,
cond0_channels=cond0_channels,
cond1_channels=cond1_channels,
dropout_prob=dropout_prob,
condition_bias=condition_bias))
if use_attention:
self.attention_blocks.append(AttentionBlock(out_channels, attention_block_args))
if perform_upsampling:
if resample_with_res_block:
self.upsample = ResBlock(
in_channels=out_channels,
out_channels=out_channels,
cond0_channels=cond0_channels,
cond1_channels=cond1_channels,
sampling_mode=SamplingMode.UPSAMPLING,
dropout_prob=dropout_prob,
condition_bias=condition_bias)
else:
                # use_conv must be passed by keyword; the second positional
                # parameter of Upsample is out_channels.
                self.upsample = Arity3To1(Upsample(out_channels, use_conv=use_conv_to_resample))
def forward(self,
h: Tensor,
skips: List[Tensor],
cond0: Optional[Tensor] = None,
cond1: Optional[Tensor] = None) -> Tensor:
for i in range(len(self.resnet_blocks)):
h = self.resnet_blocks[i].forward(torch.concat([h, skips[i]], dim=1), cond0, cond1)
if self.use_attention:
h = self.attention_blocks[i].forward(h)
if self.perform_upsampling:
h = self.upsample.forward(h, cond0, cond1)
return h
def compute_timestep_embedding(t: Tensor, out_channels: int):
assert len(t.shape) == 2
b, c = t.shape
assert c == 1
half_channels = out_channels // 2
scale = -math.log(10000.0) / (half_channels - 1)
log_times = scale * torch.arange(0, half_channels, device=t.device)
times = torch.exp(log_times).reshape(1, half_channels) * t
t_emb = torch.cat([torch.cos(times), torch.sin(times)], dim=1)
    if out_channels % 2 == 1:
        # Pad a single zero on the right so the width is exactly out_channels.
        t_emb = torch.nn.functional.pad(t_emb, (0, 1), mode='constant')
return t_emb
class TimeEmbedding(Module):
def __init__(self, out_channels: int):
super().__init__()
self.out_channels = out_channels
def forward(self, t: Tensor):
return compute_timestep_embedding(t, self.out_channels)
class UnetArgs:
def __init__(self,
in_channels: int = 3,
out_channels: int = 3,
model_channels: int = 64,
level_channel_multipliers: Optional[List[int]] = None,
level_use_attention: Optional[List[bool]] = None,
num_res_blocks_per_level: int = 2,
num_middle_res_blocks: int = 2,
time_embedding_channels: Optional[int] = None,
cond_input_channels: int = 4,
cond_internal_channels: int = 512,
attention_block_args: Optional[AttentionBlockArgs] = None,
dropout_prob: float = 0.1,
resample_with_res_block: bool = True,
use_conv_to_resample=False,
condition_bias: float = 1.0):
        if time_embedding_channels is None:
            time_embedding_channels = model_channels
        if level_channel_multipliers is None:
            level_channel_multipliers = [1, 2, 4, 8]
        if level_use_attention is None:
            level_use_attention = [False for _ in level_channel_multipliers]
        # Validate only after the defaults above have been filled in; asserting
        # first would crash on the None defaults.
        assert len(level_channel_multipliers) == len(level_use_attention)
        assert not use_conv_to_resample or not resample_with_res_block
if attention_block_args is None:
attention_block_args = AttentionBlockArgs(
num_heads=1,
num_head_channels=None,
use_new_attention_order=False)
self.condition_bias = condition_bias
self.use_conv_to_resample = use_conv_to_resample
self.resample_with_res_block = resample_with_res_block
self.cond_internal_channels = cond_internal_channels
self.dropout_prob = dropout_prob
self.attention_block_args = attention_block_args
self.time_embedding_channels = time_embedding_channels
self.num_res_blocks_per_level = num_res_blocks_per_level
self.level_use_attention = level_use_attention
self.level_channel_multipliers = level_channel_multipliers
self.model_channels = model_channels
self.out_channels = out_channels
self.in_channels = in_channels
self.num_levels = len(level_channel_multipliers)
self.num_middle_res_blocks = num_middle_res_blocks
self.cond_input_channels = cond_input_channels
class Unet(Module):
def __init__(self, args: UnetArgs):
super().__init__()
self.args = args
self.time_embed = Sequential(
TimeEmbedding(self.args.time_embedding_channels),
Linear(self.args.time_embedding_channels, self.args.cond_internal_channels),
SiLU(),
Linear(self.args.cond_internal_channels, self.args.cond_internal_channels))
self.cond_embed = Sequential(
Linear(self.args.cond_input_channels, self.args.cond_internal_channels),
SiLU(),
Linear(self.args.cond_internal_channels, self.args.cond_internal_channels))
self.first_conv = Conv2d(args.in_channels, args.model_channels, kernel_size=3, stride=1, padding=1)
current_channels = args.model_channels
channels = [current_channels]
# Downsampling blocks
self.down_blocks = ModuleList()
for i in range(args.num_levels):
out_channels = args.model_channels * args.level_channel_multipliers[i]
perform_downsampling = i < args.num_levels - 1
down_block = DownsamplingBlock(
in_channels=current_channels,
out_channels=out_channels,
cond0_channels=args.cond_internal_channels,
cond1_channels=args.cond_internal_channels,
num_res_blocks=args.num_res_blocks_per_level,
dropout_prob=args.dropout_prob,
use_attention=args.level_use_attention[i],
perform_downsampling=perform_downsampling,
attention_block_args=args.attention_block_args,
resample_with_res_block=args.resample_with_res_block,
use_conv_to_resample=args.use_conv_to_resample,
condition_bias=args.condition_bias)
self.down_blocks.append(down_block)
current_channels = out_channels
channels += down_block.output_channels
# Middle blocks
self.middle_blocks = ModuleList()
for i in range(self.args.num_middle_res_blocks - 1):
self.middle_blocks.append(ResBlock(
in_channels=current_channels,
out_channels=current_channels,
cond0_channels=args.cond_internal_channels,
cond1_channels=args.cond_internal_channels,
dropout_prob=args.dropout_prob,
condition_bias=args.condition_bias))
self.middle_blocks.append(
Arity3To1(AttentionBlock(num_channels=current_channels, args=args.attention_block_args)))
self.middle_blocks.append(ResBlock(
in_channels=current_channels,
out_channels=current_channels,
cond0_channels=args.cond_internal_channels,
cond1_channels=args.cond_internal_channels,
dropout_prob=args.dropout_prob,
condition_bias=args.condition_bias))
# Upsampling blocks
self.up_blocks = ModuleList()
for i in reversed(range(args.num_levels)):
skip_channels = []
for j in range(args.num_res_blocks_per_level + 1):
skip_channels.append(channels.pop())
perform_upsampling = i > 0
out_channels = args.model_channels * args.level_channel_multipliers[i]
up_block = UpsamplingBlock(
in_channels=current_channels,
out_channels=out_channels,
cond0_channels=args.cond_internal_channels,
cond1_channels=args.cond_internal_channels,
num_resnet_blocks=args.num_res_blocks_per_level + 1,
skip_channels=skip_channels,
dropout_prob=args.dropout_prob,
use_attention=args.level_use_attention[i],
perform_upsampling=perform_upsampling,
attention_block_args=args.attention_block_args,
resample_with_res_block=args.resample_with_res_block,
use_conv_to_resample=args.use_conv_to_resample,
condition_bias=args.condition_bias)
self.up_blocks.append(up_block)
current_channels = out_channels
assert len(channels) == 0
self.last = Sequential(
GroupNorm32(current_channels),
SiLU(),
init_to_zero(Conv2d(current_channels, args.out_channels, kernel_size=3, stride=1, padding=1)))
def forward(self, x: Tensor, t: Tensor, cond: Tensor):
t_emb = self.time_embed(t)
cond_emb = self.cond_embed(cond)
hs = [self.first_conv(x)]
for block in self.down_blocks:
hs += block.forward(hs[-1], t_emb, cond_emb)
h = hs[-1]
for block in self.middle_blocks:
h = block(h, t_emb, cond_emb)
for block in self.up_blocks:
skips = []
for i in range(self.args.num_res_blocks_per_level + 1):
skips.append(hs.pop())
h = block.forward(h, skips, t_emb, cond_emb)
assert len(hs) == 0
return self.last(h)
class UnetWithFirstConvAddition(Module):
def __init__(self, args: UnetArgs):
super().__init__()
self.args = args
self.time_embed = Sequential(
TimeEmbedding(self.args.time_embedding_channels),
Linear(self.args.time_embedding_channels, self.args.cond_internal_channels),
SiLU(),
Linear(self.args.cond_internal_channels, self.args.cond_internal_channels))
self.cond_embed = Sequential(
Linear(self.args.cond_input_channels, self.args.cond_internal_channels),
SiLU(),
Linear(self.args.cond_internal_channels, self.args.cond_internal_channels))
self.first_conv = Conv2d(args.in_channels, args.model_channels, kernel_size=3, stride=1, padding=1)
current_channels = args.model_channels
channels = [current_channels]
# Downsampling blocks
self.down_blocks = ModuleList()
for i in range(args.num_levels):
out_channels = args.model_channels * args.level_channel_multipliers[i]
perform_downsampling = i < args.num_levels - 1
down_block = DownsamplingBlock(
in_channels=current_channels,
out_channels=out_channels,
cond0_channels=args.cond_internal_channels,
cond1_channels=args.cond_internal_channels,
num_res_blocks=args.num_res_blocks_per_level,
dropout_prob=args.dropout_prob,
use_attention=args.level_use_attention[i],
perform_downsampling=perform_downsampling,
attention_block_args=args.attention_block_args,
resample_with_res_block=args.resample_with_res_block,
use_conv_to_resample=args.use_conv_to_resample,
condition_bias=args.condition_bias)
self.down_blocks.append(down_block)
current_channels = out_channels
channels += down_block.output_channels
# Middle blocks
self.middle_blocks = ModuleList()
for i in range(self.args.num_middle_res_blocks - 1):
self.middle_blocks.append(ResBlock(
in_channels=current_channels,
out_channels=current_channels,
cond0_channels=args.cond_internal_channels,
cond1_channels=args.cond_internal_channels,
dropout_prob=args.dropout_prob,
condition_bias=args.condition_bias))
self.middle_blocks.append(
Arity3To1(AttentionBlock(num_channels=current_channels, args=args.attention_block_args)))
self.middle_blocks.append(ResBlock(
in_channels=current_channels,
out_channels=current_channels,
cond0_channels=args.cond_internal_channels,
cond1_channels=args.cond_internal_channels,
dropout_prob=args.dropout_prob,
condition_bias=args.condition_bias))
# Upsampling blocks
self.up_blocks = ModuleList()
for i in reversed(range(args.num_levels)):
skip_channels = []
for j in range(args.num_res_blocks_per_level + 1):
skip_channels.append(channels.pop())
perform_upsampling = i > 0
out_channels = args.model_channels * args.level_channel_multipliers[i]
up_block = UpsamplingBlock(
in_channels=current_channels,
out_channels=out_channels,
cond0_channels=args.cond_internal_channels,
cond1_channels=args.cond_internal_channels,
num_resnet_blocks=args.num_res_blocks_per_level + 1,
skip_channels=skip_channels,
dropout_prob=args.dropout_prob,
use_attention=args.level_use_attention[i],
perform_upsampling=perform_upsampling,
attention_block_args=args.attention_block_args,
resample_with_res_block=args.resample_with_res_block,
use_conv_to_resample=args.use_conv_to_resample,
condition_bias=args.condition_bias)
self.up_blocks.append(up_block)
current_channels = out_channels
assert len(channels) == 0
self.last = Sequential(
GroupNorm32(current_channels),
SiLU(),
init_to_zero(Conv2d(current_channels, args.out_channels, kernel_size=3, stride=1, padding=1)))
def forward(self, x: Tensor, t: Tensor, cond: Tensor, first_conv_addition: Tensor):
t_emb = self.time_embed(t)
cond_emb = self.cond_embed(cond)
first_conv = self.first_conv(x)
hs = [first_conv + first_conv_addition]
for block in self.down_blocks:
hs += block.forward(hs[-1], t_emb, cond_emb)
h = hs[-1]
for block in self.middle_blocks:
h = block(h, t_emb, cond_emb)
for block in self.up_blocks:
skips = []
for i in range(self.args.num_res_blocks_per_level + 1):
skips.append(hs.pop())
h = block.forward(h, skips, t_emb, cond_emb)
assert len(hs) == 0
return self.last(h)
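# Two illustrative sketches (arbitrary sizes). First, the sinusoidal timestep
# embedding: a batch of scalar timesteps becomes cos/sin features of width
# out_channels. Second, a small Unet run end to end; the output keeps the
# spatial size of the input.
if __name__ == "__main__":
    t = torch.full((8, 1), 0.25)
    emb = compute_timestep_embedding(t, 64)
    print(tuple(emb.shape))  # (8, 64)

    net = Unet(UnetArgs(
        in_channels=3,
        out_channels=3,
        model_channels=32,
        level_channel_multipliers=[1, 2, 4],
        level_use_attention=[False, False, True]))
    out = net(torch.zeros(1, 3, 32, 32), torch.full((1, 1), 0.5), torch.zeros(1, 4))
    print(tuple(out.shape))  # (1, 3, 32, 32)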
================================================
FILE: src/tha4/nn/conv.py
================================================
from typing import Optional, Union, Callable
from torch.nn import Conv2d, Module, Sequential, ConvTranspose2d
from tha4.shion.core.module_factory import ModuleFactory
from tha4.nn.nonlinearity_factory import resolve_nonlinearity_factory
from tha4.nn.normalization import NormalizationLayerFactory
from tha4.nn.util import wrap_conv_or_linear_module, BlockArgs
def create_conv7(in_channels: int, out_channels: int,
bias: bool = False,
initialization_method: Union[str, Callable[[Module], Module]] = 'he',
use_spectral_norm: bool = False) -> Module:
return wrap_conv_or_linear_module(
Conv2d(in_channels, out_channels, kernel_size=7, stride=1, padding=3, bias=bias),
initialization_method,
use_spectral_norm)
def create_conv7_from_block_args(in_channels: int,
out_channels: int,
bias: bool = False,
block_args: Optional[BlockArgs] = None) -> Module:
if block_args is None:
block_args = BlockArgs()
return create_conv7(
in_channels, out_channels, bias,
block_args.initialization_method,
block_args.use_spectral_norm)
def create_conv3(in_channels: int,
out_channels: int,
bias: bool = False,
initialization_method: Union[str, Callable[[Module], Module]] = 'he',
use_spectral_norm: bool = False) -> Module:
return wrap_conv_or_linear_module(
Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=bias),
initialization_method,
use_spectral_norm)
def create_conv3_from_block_args(in_channels: int, out_channels: int,
bias: bool = False,
block_args: Optional[BlockArgs] = None):
if block_args is None:
block_args = BlockArgs()
return create_conv3(in_channels, out_channels, bias,
block_args.initialization_method,
block_args.use_spectral_norm)
def create_conv1(in_channels: int, out_channels: int,
initialization_method: Union[str, Callable[[Module], Module]] = 'he',
bias: bool = False,
use_spectral_norm: bool = False) -> Module:
return wrap_conv_or_linear_module(
Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
initialization_method,
use_spectral_norm)
def create_conv1_from_block_args(in_channels: int,
out_channels: int,
bias: bool = False,
block_args: Optional[BlockArgs] = None) -> Module:
if block_args is None:
block_args = BlockArgs()
return create_conv1(
in_channels=in_channels,
out_channels=out_channels,
initialization_method=block_args.initialization_method,
bias=bias,
use_spectral_norm=block_args.use_spectral_norm)
def create_conv7_block(in_channels: int, out_channels: int,
initialization_method: Union[str, Callable[[Module], Module]] = 'he',
nonlinearity_factory: Optional[ModuleFactory] = None,
normalization_layer_factory: Optional[NormalizationLayerFactory] = None,
use_spectral_norm: bool = False) -> Module:
nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)
return Sequential(
create_conv7(in_channels, out_channels,
bias=False, initialization_method=initialization_method, use_spectral_norm=use_spectral_norm),
NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(out_channels, affine=True),
resolve_nonlinearity_factory(nonlinearity_factory).create())
def create_conv7_block_from_block_args(
in_channels: int, out_channels: int,
block_args: Optional[BlockArgs] = None) -> Module:
if block_args is None:
block_args = BlockArgs()
return create_conv7_block(in_channels, out_channels,
block_args.initialization_method,
block_args.nonlinearity_factory,
block_args.normalization_layer_factory,
block_args.use_spectral_norm)
def create_conv3_block(in_channels: int, out_channels: int,
initialization_method: Union[str, Callable[[Module], Module]] = 'he',
nonlinearity_factory: Optional[ModuleFactory] = None,
normalization_layer_factory: Optional[NormalizationLayerFactory] = None,
use_spectral_norm: bool = False) -> Module:
nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)
return Sequential(
create_conv3(in_channels, out_channels,
bias=False, initialization_method=initialization_method, use_spectral_norm=use_spectral_norm),
NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(out_channels, affine=True),
resolve_nonlinearity_factory(nonlinearity_factory).create())
def create_conv3_block_from_block_args(
in_channels: int, out_channels: int, block_args: Optional[BlockArgs] = None):
if block_args is None:
block_args = BlockArgs()
return create_conv3_block(in_channels, out_channels,
block_args.initialization_method,
block_args.nonlinearity_factory,
block_args.normalization_layer_factory,
block_args.use_spectral_norm)
def create_downsample_block(in_channels: int, out_channels: int,
is_output_1x1: bool = False,
initialization_method: Union[str, Callable[[Module], Module]] = 'he',
nonlinearity_factory: Optional[ModuleFactory] = None,
normalization_layer_factory: Optional[NormalizationLayerFactory] = None,
use_spectral_norm: bool = False) -> Module:
if is_output_1x1:
return Sequential(
wrap_conv_or_linear_module(
Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
initialization_method,
use_spectral_norm),
resolve_nonlinearity_factory(nonlinearity_factory).create())
else:
return Sequential(
wrap_conv_or_linear_module(
Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
initialization_method,
use_spectral_norm),
NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(out_channels, affine=True),
resolve_nonlinearity_factory(nonlinearity_factory).create())
def create_downsample_block_from_block_args(in_channels: int, out_channels: int,
is_output_1x1: bool = False,
block_args: Optional[BlockArgs] = None):
if block_args is None:
block_args = BlockArgs()
return create_downsample_block(
in_channels, out_channels,
is_output_1x1,
block_args.initialization_method,
block_args.nonlinearity_factory,
block_args.normalization_layer_factory,
block_args.use_spectral_norm)
def create_upsample_block(in_channels: int,
out_channels: int,
initialization_method: Union[str, Callable[[Module], Module]] = 'he',
nonlinearity_factory: Optional[ModuleFactory] = None,
normalization_layer_factory: Optional[NormalizationLayerFactory] = None,
use_spectral_norm: bool = False) -> Module:
nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)
return Sequential(
wrap_conv_or_linear_module(
ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
initialization_method,
use_spectral_norm),
NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(out_channels, affine=True),
resolve_nonlinearity_factory(nonlinearity_factory).create())
def create_upsample_block_from_block_args(in_channels: int,
out_channels: int,
block_args: Optional[BlockArgs] = None) -> Module:
if block_args is None:
block_args = BlockArgs()
return create_upsample_block(in_channels, out_channels,
block_args.initialization_method,
block_args.nonlinearity_factory,
block_args.normalization_layer_factory,
block_args.use_spectral_norm)
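# Illustrative usage sketch (arbitrary channel counts): each create_*_block
# helper wraps a convolution with the normalization layer and nonlinearity
# resolved from BlockArgs, so blocks stay consistent across the codebase.
if __name__ == "__main__":
    import torch

    block = create_conv3_block_from_block_args(16, 32, BlockArgs())
    y = block(torch.zeros(1, 16, 64, 64))
    print(tuple(y.shape))  # (1, 32, 64, 64)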
================================================
FILE: src/tha4/nn/eyebrow_decomposer/__init__.py
================================================
================================================
FILE: src/tha4/nn/eyebrow_decomposer/eyebrow_decomposer_00.py
================================================
from typing import List, Optional
import torch
from torch import Tensor
from torch.nn import Module
from tha4.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00Args, PoserEncoderDecoder00
from tha4.nn.image_processing_util import apply_color_change
from tha4.shion.core.module_factory import ModuleFactory
from tha4.nn.nonlinearity_factory import ReLUFactory
from tha4.nn.normalization import InstanceNorm2dFactory
from tha4.nn.util import BlockArgs
class EyebrowDecomposer00Args(PoserEncoderDecoder00Args):
def __init__(self,
image_size: int = 128,
image_channels: int = 4,
start_channels: int = 64,
bottleneck_image_size=16,
num_bottleneck_blocks=6,
max_channels: int = 512,
block_args: Optional[BlockArgs] = None):
super().__init__(
image_size,
image_channels,
image_channels,
0,
start_channels,
bottleneck_image_size,
num_bottleneck_blocks,
max_channels,
block_args)
class EyebrowDecomposer00(Module):
def __init__(self, args: EyebrowDecomposer00Args):
super().__init__()
self.args = args
self.body = PoserEncoderDecoder00(args)
self.background_layer_alpha = self.args.create_alpha_block()
self.background_layer_color_change = self.args.create_color_change_block()
self.eyebrow_layer_alpha = self.args.create_alpha_block()
self.eyebrow_layer_color_change = self.args.create_color_change_block()
def forward(self, image: Tensor, *args) -> List[Tensor]:
feature = self.body(image)[0]
background_layer_alpha = self.background_layer_alpha(feature)
background_layer_color_change = self.background_layer_color_change(feature)
background_layer_1 = apply_color_change(background_layer_alpha, background_layer_color_change, image)
eyebrow_layer_alpha = self.eyebrow_layer_alpha(feature)
eyebrow_layer_color_change = self.eyebrow_layer_color_change(feature)
eyebrow_layer = apply_color_change(eyebrow_layer_alpha, image, eyebrow_layer_color_change)
return [
eyebrow_layer, # 0
eyebrow_layer_alpha, # 1
eyebrow_layer_color_change, # 2
background_layer_1, # 3
background_layer_alpha, # 4
background_layer_color_change, # 5
]
EYEBROW_LAYER_INDEX = 0
EYEBROW_LAYER_ALPHA_INDEX = 1
EYEBROW_LAYER_COLOR_CHANGE_INDEX = 2
BACKGROUND_LAYER_INDEX = 3
BACKGROUND_LAYER_ALPHA_INDEX = 4
BACKGROUND_LAYER_COLOR_CHANGE_INDEX = 5
OUTPUT_LENGTH = 6
class EyebrowDecomposer00Factory(ModuleFactory):
def __init__(self, args: EyebrowDecomposer00Args):
super().__init__()
self.args = args
def create(self) -> Module:
return EyebrowDecomposer00(self.args)
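# Usage sketch with default arguments: forward() returns a fixed-order list,
# and the *_INDEX constants above name its slots.
if __name__ == "__main__":
    decomposer = EyebrowDecomposer00(EyebrowDecomposer00Args())
    outputs = decomposer(torch.zeros(1, 4, 128, 128))
    assert len(outputs) == OUTPUT_LENGTH
    eyebrow_layer = outputs[EYEBROW_LAYER_INDEX]
    background_layer = outputs[BACKGROUND_LAYER_INDEX]
    print(tuple(eyebrow_layer.shape))  # (1, 4, 128, 128)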
================================================
FILE: src/tha4/nn/eyebrow_morphing_combiner/__init__.py
================================================
================================================
FILE: src/tha4/nn/eyebrow_morphing_combiner/eyebrow_morphing_combiner_00.py
================================================
from typing import List, Optional
import torch
from torch import Tensor
from torch.nn import Module
from tha4.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00Args, PoserEncoderDecoder00
from tha4.nn.image_processing_util import apply_color_change, apply_grid_change, apply_rgb_change
from tha4.shion.core.module_factory import ModuleFactory
from tha4.nn.nonlinearity_factory import ReLUFactory
from tha4.nn.normalization import InstanceNorm2dFactory
from tha4.nn.util import BlockArgs
class EyebrowMorphingCombiner00Args(PoserEncoderDecoder00Args):
def __init__(self,
image_size: int = 128,
image_channels: int = 4,
num_pose_params: int = 12,
start_channels: int = 64,
bottleneck_image_size=16,
num_bottleneck_blocks=6,
max_channels: int = 512,
block_args: Optional[BlockArgs] = None):
super().__init__(
image_size,
2 * image_channels,
image_channels,
num_pose_params,
start_channels,
bottleneck_image_size,
num_bottleneck_blocks,
max_channels,
block_args)
class EyebrowMorphingCombiner00(Module):
def __init__(self, args: EyebrowMorphingCombiner00Args):
super().__init__()
self.args = args
self.body = PoserEncoderDecoder00(args)
self.morphed_eyebrow_layer_grid_change = self.args.create_grid_change_block()
self.morphed_eyebrow_layer_alpha = self.args.create_alpha_block()
self.morphed_eyebrow_layer_color_change = self.args.create_color_change_block()
self.combine_alpha = self.args.create_alpha_block()
def forward(self, background_layer: Tensor, eyebrow_layer: Tensor, pose: Tensor, *args) -> List[Tensor]:
combined_image = torch.cat([background_layer, eyebrow_layer], dim=1)
feature = self.body(combined_image, pose)[0]
morphed_eyebrow_layer_grid_change = self.morphed_eyebrow_layer_grid_change(feature)
morphed_eyebrow_layer_alpha = self.morphed_eyebrow_layer_alpha(feature)
morphed_eyebrow_layer_color_change = self.morphed_eyebrow_layer_color_change(feature)
warped_eyebrow_layer = apply_grid_change(morphed_eyebrow_layer_grid_change, eyebrow_layer)
morphed_eyebrow_layer = apply_color_change(
morphed_eyebrow_layer_alpha, morphed_eyebrow_layer_color_change, warped_eyebrow_layer)
combine_alpha = self.combine_alpha(feature)
eyebrow_image = apply_rgb_change(combine_alpha, morphed_eyebrow_layer, background_layer)
eyebrow_image_no_combine_alpha = apply_rgb_change(
(morphed_eyebrow_layer[:, 3:4, :, :] + 1.0) / 2.0, morphed_eyebrow_layer, background_layer)
return [
eyebrow_image, # 0
combine_alpha, # 1
eyebrow_image_no_combine_alpha, # 2
morphed_eyebrow_layer, # 3
morphed_eyebrow_layer_alpha, # 4
morphed_eyebrow_layer_color_change, # 5
warped_eyebrow_layer, # 6
morphed_eyebrow_layer_grid_change, # 7
]
EYEBROW_IMAGE_INDEX = 0
COMBINE_ALPHA_INDEX = 1
EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX = 2
MORPHED_EYEBROW_LAYER_INDEX = 3
MORPHED_EYEBROW_LAYER_ALPHA_INDEX = 4
MORPHED_EYEBROW_LAYER_COLOR_CHANGE_INDEX = 5
WARPED_EYEBROW_LAYER_INDEX = 6
MORPHED_EYEBROW_LAYER_GRID_CHANGE_INDEX = 7
OUTPUT_LENGTH = 8
class EyebrowMorphingCombiner00Factory(ModuleFactory):
def __init__(self, args: EyebrowMorphingCombiner00Args):
super().__init__()
self.args = args
def create(self) -> Module:
return EyebrowMorphingCombiner00(self.args)
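# Usage sketch: the combiner consumes the two layers produced by
# EyebrowDecomposer00 together with a 12-dimensional pose vector and returns
# the eight tensors indexed by the constants above. A smoke test with the
# default arguments:
if __name__ == '__main__':
    combiner = EyebrowMorphingCombiner00(EyebrowMorphingCombiner00Args())
    background_layer = torch.zeros(1, 4, 128, 128)
    eyebrow_layer = torch.zeros(1, 4, 128, 128)
    pose = torch.zeros(1, 12)
    outputs = combiner.forward(background_layer, eyebrow_layer, pose)
    assert len(outputs) == OUTPUT_LENGTH
    print(tuple(outputs[EYEBROW_IMAGE_INDEX].shape))  # expected: (1, 4, 128, 128)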
================================================
FILE: src/tha4/nn/face_morpher/__init__.py
================================================
================================================
FILE: src/tha4/nn/face_morpher/face_morpher_08.py
================================================
import math
from typing import List, Optional
import torch
from torch import Tensor
from torch.nn import ModuleList, Sequential, Sigmoid, Tanh, Module
from torch.nn.functional import affine_grid, grid_sample
from tha4.shion.core.module_factory import ModuleFactory
from tha4.nn.conv import create_conv3_block_from_block_args, \
create_downsample_block_from_block_args, create_upsample_block_from_block_args, create_conv3_from_block_args, \
create_conv3
from tha4.nn.nonlinearity_factory import LeakyReLUFactory
from tha4.nn.normalization import InstanceNorm2dFactory
from tha4.nn.resnet_block import ResnetBlock
from tha4.nn.util import BlockArgs
class FaceMorpher08Args:
def __init__(self,
image_size: int = 256,
image_channels: int = 4,
num_expression_params: int = 67,
start_channels: int = 16,
                 bottleneck_image_size: int = 4,
                 num_bottleneck_blocks: int = 3,
max_channels: int = 512,
block_args: Optional[BlockArgs] = None,
output_iris_mouth_grid_change: bool = False):
self.max_channels = max_channels
self.num_bottleneck_blocks = num_bottleneck_blocks
assert bottleneck_image_size > 1
self.bottleneck_image_size = bottleneck_image_size
self.start_channels = start_channels
self.image_channels = image_channels
self.num_expression_params = num_expression_params
self.image_size = image_size
self.output_iris_mouth_grid_change = output_iris_mouth_grid_change
if block_args is None:
self.block_args = BlockArgs(
normalization_layer_factory=InstanceNorm2dFactory(),
nonlinearity_factory=LeakyReLUFactory(negative_slope=0.2, inplace=True))
else:
self.block_args = block_args
class FaceMorpher08(Module):
def __init__(self, args: FaceMorpher08Args):
super().__init__()
self.args = args
self.num_levels = int(math.log2(args.image_size // args.bottleneck_image_size)) + 1
self.downsample_blocks = ModuleList()
self.downsample_blocks.append(
create_conv3_block_from_block_args(
args.image_channels,
args.start_channels,
args.block_args))
current_image_size = args.image_size
current_num_channels = args.start_channels
while current_image_size > args.bottleneck_image_size:
next_image_size = current_image_size // 2
next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
self.downsample_blocks.append(create_downsample_block_from_block_args(
in_channels=current_num_channels,
out_channels=next_num_channels,
is_output_1x1=False,
block_args=args.block_args))
current_image_size = next_image_size
current_num_channels = next_num_channels
assert len(self.downsample_blocks) == self.num_levels
self.bottleneck_blocks = ModuleList()
self.bottleneck_blocks.append(create_conv3_block_from_block_args(
in_channels=current_num_channels + args.num_expression_params,
out_channels=current_num_channels,
block_args=args.block_args))
for i in range(1, args.num_bottleneck_blocks):
self.bottleneck_blocks.append(
ResnetBlock.create(
num_channels=current_num_channels,
is1x1=False,
block_args=args.block_args))
self.upsample_blocks = ModuleList()
while current_image_size < args.image_size:
next_image_size = current_image_size * 2
next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
self.upsample_blocks.append(create_upsample_block_from_block_args(
in_channels=current_num_channels,
out_channels=next_num_channels,
block_args=args.block_args))
current_image_size = next_image_size
current_num_channels = next_num_channels
self.iris_mouth_grid_change = self.create_grid_change_block()
self.iris_mouth_color_change = self.create_color_change_block()
self.iris_mouth_alpha = self.create_alpha_block()
self.eye_color_change = self.create_color_change_block()
self.eye_alpha = self.create_alpha_block()
def create_alpha_block(self):
return Sequential(
create_conv3(
in_channels=self.args.start_channels,
out_channels=1,
bias=True,
initialization_method=self.args.block_args.initialization_method,
use_spectral_norm=False),
Sigmoid())
def create_color_change_block(self):
return Sequential(
create_conv3_from_block_args(
in_channels=self.args.start_channels,
out_channels=self.args.image_channels,
bias=True,
block_args=self.args.block_args),
Tanh())
def create_grid_change_block(self):
return create_conv3(
in_channels=self.args.start_channels,
out_channels=2,
bias=False,
initialization_method='zero',
use_spectral_norm=False)
def get_num_output_channels_from_level(self, level: int):
return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level))
def get_num_output_channels_from_image_size(self, image_size: int):
return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels)
def merge_down(self, top_layer: Tensor, bottom_layer: Tensor):
top_layer_rgb = top_layer[:, 0:3, :, :]
top_layer_a = top_layer[:, 3:4, :, :]
return bottom_layer * (1 - top_layer_a) + torch.cat([top_layer_rgb * top_layer_a, top_layer_a], dim=1)
def apply_grid_change(self, grid_change, image: Tensor) -> Tensor:
n, c, h, w = image.shape
device = grid_change.device
grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2)
identity = torch.tensor(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
device=device,
dtype=grid_change.dtype).unsqueeze(0).repeat(n, 1, 1)
base_grid = affine_grid(identity, [n, c, h, w], align_corners=False)
grid = base_grid + grid_change
resampled_image = grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=False)
return resampled_image
def apply_color_change(self, alpha, color_change, image: Tensor) -> Tensor:
return color_change * alpha + image * (1 - alpha)
def forward(self, image: Tensor, pose: Tensor, *args) -> List[Tensor]:
feature = image
for block in self.downsample_blocks:
feature = block(feature)
n, c = pose.shape
pose = pose.view(n, c, 1, 1).repeat(1, 1, self.args.bottleneck_image_size, self.args.bottleneck_image_size)
feature = torch.cat([feature, pose], dim=1)
for block in self.bottleneck_blocks:
feature = block(feature)
for block in self.upsample_blocks:
feature = block(feature)
iris_mouth_grid_change = self.iris_mouth_grid_change(feature)
iris_mouth_image_0 = self.apply_grid_change(iris_mouth_grid_change, image)
iris_mouth_color_change = self.iris_mouth_color_change(feature)
iris_mouth_alpha = self.iris_mouth_alpha(feature)
iris_mouth_image_1 = self.apply_color_change(iris_mouth_alpha, iris_mouth_color_change, iris_mouth_image_0)
eye_color_change = self.eye_color_change(feature)
eye_alpha = self.eye_alpha(feature)
output_image = self.apply_color_change(eye_alpha, eye_color_change, iris_mouth_image_1.detach())
outputs = [
output_image, # 0
eye_alpha, # 1
eye_color_change, # 2
iris_mouth_image_1, # 3
iris_mouth_alpha, # 4
iris_mouth_color_change, # 5
iris_mouth_image_0, # 6
]
if self.args.output_iris_mouth_grid_change:
outputs.append(iris_mouth_grid_change)
return outputs
OUTPUT_IMAGE_INDEX = 0
EYE_ALPHA_INDEX = 1
EYE_COLOR_CHANGE_INDEX = 2
IRIS_MOUTH_IMAGE_1_INDEX = 3
IRIS_MOUTH_ALPHA_INDEX = 4
IRIS_MOUTH_COLOR_CHANGE_INDEX = 5
IRIS_MOUTH_IMAGE_0_INDEX = 6
IRIS_MOUTH_GRID_CHANGE_INDEX = 7
class FaceMorpher08Factory(ModuleFactory):
def __init__(self, args: FaceMorpher08Args):
super().__init__()
self.args = args
def create(self) -> Module:
return FaceMorpher08(self.args)
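# Usage sketch: FaceMorpher08 downsamples to a 4x4 bottleneck, injects the 67
# expression parameters there, and decodes back to full resolution. A smoke
# test with the default arguments:
if __name__ == '__main__':
    morpher = FaceMorpher08(FaceMorpher08Args())
    image = torch.zeros(1, 4, 256, 256)
    pose = torch.zeros(1, 67)
    outputs = morpher.forward(image, pose)
    print(len(outputs))  # 7 (8 when output_iris_mouth_grid_change is True)
    print(tuple(outputs[OUTPUT_IMAGE_INDEX].shape))  # (1, 4, 256, 256)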
================================================
FILE: src/tha4/nn/image_processing_util.py
================================================
import torch
from torch import Tensor
from torch.nn.functional import affine_grid, grid_sample
def apply_rgb_change(alpha: Tensor, color_change: Tensor, image: Tensor):
image_rgb = image[:, 0:3, :, :]
color_change_rgb = color_change[:, 0:3, :, :]
output_rgb = color_change_rgb * alpha + image_rgb * (1 - alpha)
return torch.cat([output_rgb, image[:, 3:4, :, :]], dim=1)
def apply_grid_change(grid_change, image: Tensor) -> Tensor:
n, c, h, w = image.shape
device = grid_change.device
grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2)
identity = torch.tensor(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
dtype=grid_change.dtype,
device=device).unsqueeze(0).repeat(n, 1, 1)
base_grid = affine_grid(identity, [n, c, h, w], align_corners=False)
grid = base_grid + grid_change
resampled_image = grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=False)
return resampled_image
class GridChangeApplier:
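    # Caches the identity affine matrix per (batch size, device) so repeated
    # calls, e.g. once per animation frame, avoid rebuilding and re-uploading it.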
def __init__(self):
self.last_n = None
self.last_device = None
self.last_identity = None
def apply(self, grid_change: Tensor, image: Tensor, align_corners: bool = False) -> Tensor:
n, c, h, w = image.shape
device = grid_change.device
grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2)
if n == self.last_n and device == self.last_device:
identity = self.last_identity
else:
identity = torch.tensor(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
dtype=grid_change.dtype,
device=device,
requires_grad=False) \
.unsqueeze(0).repeat(n, 1, 1)
self.last_identity = identity
self.last_n = n
self.last_device = device
base_grid = affine_grid(identity, [n, c, h, w], align_corners=align_corners)
grid = base_grid + grid_change
resampled_image = grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=align_corners)
return resampled_image
def apply_color_change(alpha, color_change, image: Tensor) -> Tensor:
return color_change * alpha + image * (1 - alpha)
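# Usage sketch: a grid change is an offset added to the identity sampling
# grid, so an all-zero grid change is the identity warp and should reproduce
# the input up to interpolation error.
if __name__ == '__main__':
    applier = GridChangeApplier()
    image = torch.rand(1, 4, 64, 64)
    zero_grid_change = torch.zeros(1, 2, 64, 64)
    warped = applier.apply(zero_grid_change, image)
    print(torch.allclose(warped, image, atol=1e-5))  # expected: True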
================================================
FILE: src/tha4/nn/init_function.py
================================================
from typing import Callable
import torch
from torch import zero_
from torch.nn import Module
from torch.nn.init import kaiming_normal_, xavier_normal_, normal_
def create_init_function(method: str = 'none') -> Callable[[Module], Module]:
def init(module: Module):
if method == 'none':
return module
elif method == 'he':
kaiming_normal_(module.weight)
return module
elif method == 'xavier':
xavier_normal_(module.weight)
return module
elif method == 'dcgan':
normal_(module.weight, 0.0, 0.02)
return module
elif method == 'dcgan_001':
normal_(module.weight, 0.0, 0.01)
return module
elif method == "zero":
with torch.no_grad():
zero_(module.weight)
return module
else:
raise ("Invalid initialization method %s" % method)
return init
class HeInitialization:
def __init__(self, a: int = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'):
self.nonlinearity = nonlinearity
self.mode = mode
self.a = a
def __call__(self, module: Module) -> Module:
with torch.no_grad():
kaiming_normal_(module.weight, a=self.a, mode=self.mode, nonlinearity=self.nonlinearity)
return module
class NormalInitialization:
def __init__(self, mean: float = 0.0, std: float = 1.0):
self.std = std
self.mean = mean
def __call__(self, module: Module) -> Module:
with torch.no_grad():
normal_(module.weight, self.mean, self.std)
return module
class XavierInitialization:
def __init__(self, gain: float = 1.0):
self.gain = gain
def __call__(self, module: Module) -> Module:
with torch.no_grad():
xavier_normal_(module.weight, self.gain)
return module
class ZeroInitialization:
def __call__(self, module: Module) -> Module:
        with torch.no_grad():
            zero_(module.weight)
return module
class NoInitialization:
def __call__(self, module: Module) -> Module:
return module
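# Usage sketch: both the string-keyed create_init_function and the callable
# classes below return the module they initialize, so initialization composes
# with module construction in a single expression.
if __name__ == '__main__':
    from torch.nn import Conv2d
    conv = create_init_function('he')(Conv2d(3, 8, kernel_size=3, padding=1))
    zeroed = ZeroInitialization()(Conv2d(3, 8, kernel_size=3, padding=1))
    print(float(zeroed.weight.abs().sum()))  # 0.0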
================================================
FILE: src/tha4/nn/morpher/__init__.py
================================================
================================================
FILE: src/tha4/nn/morpher/morpher_00.py
================================================
from typing import List
import torch
from torch import Tensor
from torch.nn import Module
from tha4.shion.core.module_factory import ModuleFactory
from tha4.nn.image_processing_util import GridChangeApplier
from tha4.nn.common.unet import UnetArgs, Unet, AttentionBlockArgs
def apply_color_change(alpha, color_change, image: Tensor) -> Tensor:
return color_change * alpha + image * (1 - alpha)
class Morpher00Args:
def __init__(self,
image_size: int,
image_channels: int,
num_pose_parameters: int,
unet_args: UnetArgs):
assert unet_args.in_channels == image_channels
assert unet_args.out_channels == (
image_channels + # direct
2 + # warp
1 # alpha
)
assert unet_args.cond_input_channels == num_pose_parameters
self.image_channels = image_channels
self.image_size = image_size
self.num_pose_parameters = num_pose_parameters
self.unet_args = unet_args
class Morpher00(Module):
def __init__(self, args: Morpher00Args):
super().__init__()
self.args = args
self.body = Unet(args.unet_args)
self.grid_change_applier = GridChangeApplier()
def forward(self, image: torch.Tensor, pose: torch.Tensor) -> List[Tensor]:
assert len(image.shape) == 4
assert image.shape[1] == self.args.image_channels
assert image.shape[2] == self.args.image_size
assert image.shape[3] == self.args.image_size
assert len(pose.shape) == 2
assert image.shape[0] == pose.shape[0]
assert pose.shape[1] == self.args.num_pose_parameters
t = torch.zeros(image.shape[0], 1, device=image.device)
body_output = self.body(image, t, pose)
direct = body_output[:, 0:self.args.image_channels, :, :]
grid_change = body_output[:, self.args.image_channels:self.args.image_channels + 2, :, :]
alpha = torch.sigmoid(body_output[:, self.args.image_channels + 2:self.args.image_channels + 3, :, :])
warped = self.grid_change_applier.apply(grid_change, image)
merged = apply_color_change(alpha, direct, warped)
return [
merged,
alpha,
warped,
grid_change,
direct
]
INDEX_MERGED = 0
INDEX_ALPHA = 1
INDEX_WARPED = 2
INDEX_GRID_CHANGE = 3
INDEX_DIRECT = 4
class Morpher00Factory(ModuleFactory):
def __init__(self, args: Morpher00Args):
self.args = args
def create(self) -> Module:
return Morpher00(self.args)
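# Channel-layout sketch: the asserts in Morpher00Args pin down how forward
# splits the U-Net output. With image_channels=4 the body emits 7 channels:
# 0-3 the direct RGBA prediction, 4-5 the warp field, 6 the pre-sigmoid alpha.
if __name__ == '__main__':
    image_channels = 4
    unet_out_channels = image_channels + 2 + 1  # direct + warp + alpha
    print(unet_out_channels)  # 7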
================================================
FILE: src/tha4/nn/nonlinearity_factory.py
================================================
from typing import Optional
from torch.nn import Module, ReLU, LeakyReLU, ELU, ReLU6, Hardswish, SiLU, Tanh, Sigmoid
from tha4.shion.core.module_factory import ModuleFactory
class ReLUFactory(ModuleFactory):
def __init__(self, inplace: bool = False):
self.inplace = inplace
def create(self) -> Module:
return ReLU(self.inplace)
class LeakyReLUFactory(ModuleFactory):
def __init__(self, inplace: bool = False, negative_slope: float = 1e-2):
self.negative_slope = negative_slope
self.inplace = inplace
def create(self) -> Module:
return LeakyReLU(inplace=self.inplace, negative_slope=self.negative_slope)
class ELUFactory(ModuleFactory):
def __init__(self, inplace: bool = False, alpha: float = 1.0):
self.alpha = alpha
self.inplace = inplace
def create(self) -> Module:
return ELU(inplace=self.inplace, alpha=self.alpha)
class ReLU6Factory(ModuleFactory):
def __init__(self, inplace: bool = False):
self.inplace = inplace
def create(self) -> Module:
return ReLU6(inplace=self.inplace)
class SiLUFactory(ModuleFactory):
def __init__(self, inplace: bool = False):
self.inplace = inplace
def create(self) -> Module:
return SiLU(inplace=self.inplace)
class HardswishFactory(ModuleFactory):
def __init__(self, inplace: bool = False):
self.inplace = inplace
def create(self) -> Module:
return Hardswish(inplace=self.inplace)
class TanhFactory(ModuleFactory):
def create(self) -> Module:
return Tanh()
class SigmoidFactory(ModuleFactory):
def create(self) -> Module:
return Sigmoid()
def resolve_nonlinearity_factory(nonlinearity_factory: Optional[ModuleFactory]) -> ModuleFactory:
    if nonlinearity_factory is None:
        return ReLUFactory(inplace=False)
    else:
        return nonlinearity_factory
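# Usage sketch: every factory defers module creation, so a single factory can
# stamp out fresh activation instances for many blocks without sharing state;
# resolve_nonlinearity_factory supplies the ReLU default when none is given.
if __name__ == '__main__':
    factory = resolve_nonlinearity_factory(None)
    print(type(factory.create()).__name__)  # ReLU
    leaky = LeakyReLUFactory(negative_slope=0.2, inplace=True)
    print(type(leaky.create()).__name__)    # LeakyReLU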
================================================
FILE: src/tha4/nn/normalization.py
================================================
from abc import ABC, abstractmethod
from typing import Optional
import torch
from torch import layer_norm
from torch.nn import Module, BatchNorm2d, InstanceNorm2d, Parameter
from torch.nn.init import normal_, constant_
from tha4.nn.pass_through import PassThrough
class PixelNormalization(Module):
def __init__(self, epsilon=1e-8):
super().__init__()
self.epsilon = epsilon
def forward(self, x):
return x / torch.sqrt((x ** 2).mean(dim=1, keepdim=True) + self.epsilon)
class NormalizationLayerFactory(ABC):
def __init__(self):
super().__init__()
@abstractmethod
def create(self, num_features: int, affine: bool = True) -> Module:
pass
@staticmethod
def resolve_2d(factory: Optional['NormalizationLayerFactory']) -> 'NormalizationLayerFactory':
if factory is None:
return InstanceNorm2dFactory()
else:
return factory
class Bias2d(Module):
def __init__(self, num_features: int):
super().__init__()
self.num_features = num_features
self.bias = Parameter(torch.zeros(1, num_features, 1, 1))
def forward(self, x):
return x + self.bias
class NoNorm2dFactory(NormalizationLayerFactory):
def __init__(self):
super().__init__()
def create(self, num_features: int, affine: bool = True) -> Module:
if affine:
return Bias2d(num_features)
else:
return PassThrough()
class BatchNorm2dFactory(NormalizationLayerFactory):
def __init__(self,
weight_mean: Optional[float] = None,
weight_std: Optional[float] = None,
bias: Optional[float] = None):
super().__init__()
self.bias = bias
self.weight_std = weight_std
self.weight_mean = weight_mean
def get_weight_mean(self):
if self.weight_mean is None:
return 1.0
else:
return self.weight_mean
def get_weight_std(self):
if self.weight_std is None:
return 0.02
else:
return self.weight_std
def create(self, num_features: int, affine: bool = True) -> Module:
module = BatchNorm2d(num_features=num_features, affine=affine)
if affine:
if self.weight_mean is not None or self.weight_std is not None:
normal_(module.weight, self.get_weight_mean(), self.get_weight_std())
if self.bias is not None:
constant_(module.bias, self.bias)
return module
class InstanceNorm2dFactory(NormalizationLayerFactory):
def __init__(self):
super().__init__()
def create(self, num_features: int, affine: bool = True) -> Module:
return InstanceNorm2d(num_features=num_features, affine=affine)
class PixelNormFactory(NormalizationLayerFactory):
def __init__(self):
super().__init__()
def create(self, num_features: int, affine: bool = True) -> Module:
return PixelNormalization()
class LayerNorm2d(Module):
def __init__(self, channels: int, affine: bool = True):
        super().__init__()
self.channels = channels
self.affine = affine
if self.affine:
self.weight = Parameter(torch.ones(1, channels, 1, 1))
self.bias = Parameter(torch.zeros(1, channels, 1, 1))
    def forward(self, x):
        shape = x.size()[1:]
        y = layer_norm(x, shape)
        if self.affine:
            y = y * self.weight + self.bias
        return y
class LayerNorm2dFactory(NormalizationLayerFactory):
def __init__(self):
super().__init__()
def create(self, num_features: int, affine: bool = True) -> Module:
return LayerNorm2d(channels=num_features, affine=affine)
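# Usage sketch: every factory shares the create(num_features, affine)
# interface, so blocks can swap normalization layers through configuration
# alone. All of the layers below preserve the input shape.
if __name__ == '__main__':
    x = torch.rand(2, 8, 16, 16)
    factories = [InstanceNorm2dFactory(), BatchNorm2dFactory(), NoNorm2dFactory(), LayerNorm2dFactory()]
    for factory in factories:
        layer = factory.create(num_features=8, affine=True)
        print(type(layer).__name__, tuple(layer(x).shape))  # (2, 8, 16, 16)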
================================================
FILE: src/tha4/nn/pass_through.py
================================================
from torch.nn import Module
class PassThrough(Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x
================================================
FILE: src/tha4/nn/resnet_block.py
================================================
from typing import Optional
import torch
from torch.nn import Module, Sequential, Parameter
from tha4.shion.core.module_factory import ModuleFactory
from tha4.nn.conv import create_conv1, create_conv3
from tha4.nn.nonlinearity_factory import resolve_nonlinearity_factory
from tha4.nn.normalization import NormalizationLayerFactory
from tha4.nn.util import BlockArgs
class ResnetBlock(Module):
@staticmethod
def create(num_channels: int,
is1x1: bool = False,
use_scale_parameters: bool = False,
block_args: Optional[BlockArgs] = None):
if block_args is None:
block_args = BlockArgs()
return ResnetBlock(num_channels,
is1x1,
block_args.initialization_method,
block_args.nonlinearity_factory,
block_args.normalization_layer_factory,
block_args.use_spectral_norm,
use_scale_parameters)
def __init__(self,
num_channels: int,
is1x1: bool = False,
initialization_method: str = 'he',
                 nonlinearity_factory: Optional[ModuleFactory] = None,
normalization_layer_factory: Optional[NormalizationLayerFactory] = None,
use_spectral_norm: bool = False,
use_scale_parameter: bool = False):
super().__init__()
self.use_scale_parameter = use_scale_parameter
if self.use_scale_parameter:
self.scale = Parameter(torch.zeros(1))
nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)
if is1x1:
self.resnet_path = Sequential(
create_conv1(num_channels, num_channels, initialization_method,
bias=True,
use_spectral_norm=use_spectral_norm),
nonlinearity_factory.create(),
create_conv1(num_channels, num_channels, initialization_method,
bias=True,
use_spectral_norm=use_spectral_norm))
else:
self.resnet_path = Sequential(
create_conv3(num_channels, num_channels,
bias=False, initialization_method=initialization_method,
use_spectral_norm=use_spectral_norm),
NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(num_channels, affine=True),
nonlinearity_factory.create(),
create_conv3(num_channels, num_channels,
bias=False, initialization_method=initialization_method,
use_spectral_norm=use_spectral_norm),
NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(num_channels, affine=True))
def forward(self, x):
if self.use_scale_parameter:
return x + self.scale * self.resnet_path(x)
else:
return x + self.resnet_path(x)
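# Usage sketch: the block is a standard residual unit, output = x + F(x)
# (optionally scaled by a learned parameter), so it preserves shape and can
# be stacked at the bottleneck without extra bookkeeping.
if __name__ == '__main__':
    block = ResnetBlock.create(num_channels=16)
    x = torch.rand(1, 16, 8, 8)
    print(tuple(block(x).shape))  # (1, 16, 8, 8)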
================================================
FILE: src/tha4/nn/resnet_block_seperable.py
================================================
from typing import Optional
import torch
from torch.nn import Module, Sequential, Parameter
from tha4.shion.core.module_factory import ModuleFactory
from tha4.nn.conv import create_conv1
from tha4.nn.nonlinearity_factory import resolve_nonlinearity_factory
from tha4.nn.normalization import NormalizationLayerFactory
from tha4.nn.separable_conv import create_separable_conv3
from tha4.nn.util import BlockArgs
class ResnetBlockSeparable(Module):
@staticmethod
def create(num_channels: int,
is1x1: bool = False,
use_scale_parameters: bool = False,
block_args: Optional[BlockArgs] = None):
if block_args is None:
block_args = BlockArgs()
return ResnetBlockSeparable(
num_channels,
is1x1,
block_args.initialization_method,
block_args.nonlinearity_factory,
block_args.normalization_layer_factory,
block_args.use_spectral_norm,
use_scale_parameters)
def __init__(self,
num_channels: int,
is1x1: bool = False,
initialization_method: str = 'he',
                 nonlinearity_factory: Optional[ModuleFactory] = None,
normalization_layer_factory: Optional[NormalizationLayerFactory] = None,
use_spectral_norm: bool = False,
use_scale_parameter: bool = False):
super().__init__()
self.use_scale_parameter = use_scale_parameter
if self.use_scale_parameter:
self.scale = Parameter(torch.zeros(1))
nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)
if is1x1:
self.resnet_path = Sequential(
create_conv1(num_channels, num_channels, initialization_method,
bias=True,
use_spectral_norm=use_spectral_norm),
nonlinearity_factory.create(),
create_conv1(num_channels, num_channels, initialization_method,
bias=True,
use_spectral_norm=use_spectral_norm))
else:
self.resnet_path = Sequential(
create_separable_conv3(
num_channels, num_channels,
bias=False, initialization_method=initialization_method,
use_spectral_norm=use_spectral_norm),
NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(num_channels, affine=True),
nonlinearity_factory.create(),
create_separable_conv3(
num_channels, num_channels,
bias=False, initialization_method=initialization_method,
use_spectral_norm=use_spectral_norm),
NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(num_channels, affine=True))
def forward(self, x):
if self.use_scale_parameter:
return x + self.scale * self.resnet_path(x)
else:
return x + self.resnet_path(x)
================================================
FILE: src/tha4/nn/separable_conv.py
================================================
from typing import Optional
from torch.nn import Sequential, Conv2d, ConvTranspose2d, Module
from tha4.nn.normalization import NormalizationLayerFactory
from tha4.nn.util import BlockArgs, wrap_conv_or_linear_module
def create_separable_conv3(in_channels: int, out_channels: int,
bias: bool = False,
initialization_method='he',
use_spectral_norm: bool = False) -> Module:
return Sequential(
wrap_conv_or_linear_module(
Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=False, groups=in_channels),
initialization_method,
use_spectral_norm),
wrap_conv_or_linear_module(
Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
initialization_method,
use_spectral_norm))
def create_separable_conv7(in_channels: int, out_channels: int,
bias: bool = False,
initialization_method='he',
use_spectral_norm: bool = False) -> Module:
return Sequential(
wrap_conv_or_linear_module(
Conv2d(in_channels, in_channels, kernel_size=7, stride=1, padding=3, bias=False, groups=in_channels),
initialization_method,
use_spectral_norm),
wrap_conv_or_linear_module(
Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
initialization_method,
use_spectral_norm))
def create_separable_conv3_block(
in_channels: int, out_channels: int, block_args: Optional[BlockArgs] = None):
if block_args is None:
block_args = BlockArgs()
return Sequential(
wrap_conv_or_linear_module(
Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=False, groups=in_channels),
block_args.initialization_method,
block_args.use_spectral_norm),
wrap_conv_or_linear_module(
Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
block_args.initialization_method,
block_args.use_spectral_norm),
NormalizationLayerFactory.resolve_2d(block_args.normalization_layer_factory).create(out_channels, affine=True),
block_args.nonlinearity_factory.create())
def create_separable_conv7_block(
in_channels: int, out_channels: int, block_args: Optional[BlockArgs] = None):
if block_args is None:
block_args = BlockArgs()
return Sequential(
wrap_conv_or_linear_module(
Conv2d(in_channels, in_channels, kernel_size=7, stride=1, padding=3, bias=False, groups=in_channels),
block_args.initialization_method,
block_args.use_spectral_norm),
wrap_conv_or_linear_module(
Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
block_args.initialization_method,
block_args.use_spectral_norm),
NormalizationLayerFactory.resolve_2d(block_args.normalization_layer_factory).create(out_channels, affine=True),
block_args.nonlinearity_factory.create())
def create_separable_downsample_block(
in_channels: int, out_channels: int, is_output_1x1: bool, block_args: Optional[BlockArgs] = None):
if block_args is None:
block_args = BlockArgs()
if is_output_1x1:
return Sequential(
wrap_conv_or_linear_module(
Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1, bias=False, groups=in_channels),
block_args.initialization_method,
block_args.use_spectral_norm),
wrap_conv_or_linear_module(
Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
block_args.initialization_method,
block_args.use_spectral_norm),
block_args.nonlinearity_factory.create())
else:
return Sequential(
wrap_conv_or_linear_module(
Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1, bias=False, groups=in_channels),
block_args.initialization_method,
block_args.use_spectral_norm),
wrap_conv_or_linear_module(
Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
block_args.initialization_method,
block_args.use_spectral_norm),
NormalizationLayerFactory.resolve_2d(block_args.normalization_layer_factory)
.create(out_channels, affine=True),
block_args.nonlinearity_factory.create())
def create_separable_upsample_block(
in_channels: int, out_channels: int, block_args: Optional[BlockArgs] = None):
if block_args is None:
block_args = BlockArgs()
return Sequential(
wrap_conv_or_linear_module(
ConvTranspose2d(
in_channels, in_channels, kernel_size=4, stride=2, padding=1, bias=False, groups=in_channels),
block_args.initialization_method,
block_args.use_spectral_norm),
wrap_conv_or_linear_module(
Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
block_args.initialization_method,
block_args.use_spectral_norm),
NormalizationLayerFactory.resolve_2d(block_args.normalization_layer_factory)
.create(out_channels, affine=True),
block_args.nonlinearity_factory.create())
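# Parameter-count sketch: a depthwise 3x3 followed by a pointwise 1x1
# replaces a dense 3x3 convolution, which is where these blocks save weights.
# For 64 -> 64 channels: 64*3*3 + 64*64 = 4_672 versus 64*64*3*3 = 36_864.
if __name__ == '__main__':
    from tha4.nn.conv import create_conv3
    dense = create_conv3(64, 64, bias=False, initialization_method='he', use_spectral_norm=False)
    separable = create_separable_conv3(64, 64, bias=False)
    def count(module: Module) -> int:
        return sum(p.numel() for p in module.parameters())
    print(count(dense), count(separable))  # 36864 4672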
================================================
FILE: src/tha4/nn/siren/__init__.py
================================================
================================================
FILE: src/tha4/nn/siren/face_morpher/__init__.py
================================================
================================================
FILE: src/tha4/nn/siren/face_morpher/siren_face_morpher_00.py
================================================
from typing import Optional, List
import torch
from torch import Tensor
from torch.nn import Module
from torch.nn.functional import affine_grid
from tha4.shion.core.module_factory import ModuleFactory
from tha4.nn.siren.vanilla.siren import SirenArgs, Siren
class SirenFaceMorpher00Args:
def __init__(self,
image_size: int,
image_channels: int,
pose_size: int,
siren_args: SirenArgs):
assert siren_args.in_channels == pose_size + 2
assert siren_args.out_channels == image_channels
assert not siren_args.use_tanh
self.siren_args = siren_args
self.pose_size = pose_size
self.image_size = image_size
self.image_channels = image_channels
class SirenFaceMorpher00(Module):
def __init__(self, args: SirenFaceMorpher00Args):
super().__init__()
self.args = args
self.siren = Siren(self.args.siren_args)
def forward(self, pose: Tensor, position: Optional[Tensor] = None) -> Tensor:
n, p = pose.shape[0], pose.shape[1]
device = pose.device
if position is None:
h, w = self.args.image_size, self.args.image_size
identity = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device).unsqueeze(0)
position = affine_grid(identity, [1, 1, h, w], align_corners=False) \
.view(1, h * w, 2)
position = torch.transpose(position, dim0=1, dim1=2).view(1, 2, h, w) \
.repeat(n, 1, 1, 1)
h, w = position.shape[2], position.shape[3]
pose_image = pose.view(n, p, 1, 1).repeat(1, 1, h, w)
siren_input = torch.cat([position, pose_image], dim=1)
return self.siren.forward(siren_input)
class SirenFaceMorpher00Factory(ModuleFactory):
def __init__(self, args: SirenFaceMorpher00Args):
self.args = args
def create(self) -> Module:
return SirenFaceMorpher00(self.args)
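# Usage sketch: the SIREN receives the (x, y) position grid plus the pose
# broadcast to every pixel, so in_channels must equal pose_size + 2, exactly
# as the asserts above enforce. Assumes Siren maps each pixel independently,
# so the spatial size is preserved.
if __name__ == '__main__':
    morpher = SirenFaceMorpher00(SirenFaceMorpher00Args(
        image_size=128,
        image_channels=4,
        pose_size=39,
        siren_args=SirenArgs(
            in_channels=39 + 2,
            out_channels=4,
            intermediate_channels=128,
            num_sine_layers=8)))
    print(tuple(morpher.forward(torch.zeros(1, 39)).shape))  # expected: (1, 4, 128, 128)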
================================================
FILE: src/tha4/nn/siren/face_morpher/siren_face_morpher_00_trainer.py
================================================
from typing import Dict, List, Optional, Callable
import torch
from tha4.shion.base.dataset.lazy_tensor_dataset import LazyTensorDataset
from tha4.shion.base.image_util import extract_pytorch_image_from_filelike
from tha4.shion.base.loss.l1_loss import L1Loss, MaskedL1Loss
from tha4.shion.base.loss.sum_loss import SumLoss
from tha4.shion.base.optimizer_factories import AdamOptimizerFactory
from tha4.shion.core.training.distrib.distributed_trainer import DistributedTrainer
from tha4.dataset.image_poses_and_aother_images_dataset import ImagePosesAndOtherImagesDataset
from tha4.nn.siren.face_morpher.siren_face_morpher_00 import SirenFaceMorpher00Factory, SirenFaceMorpher00Args
from tha4.nn.siren.face_morpher.siren_face_morpher_protocols_00 import SirenFaceMorpherComputationProtocol00, \
SirenFaceMorpherSampleOutputProtocol00
from tha4.nn.siren.morpher.siren_morpher_protocols_03 import SirenMorpherTrainingProtocol03
from tha4.nn.siren.vanilla.siren import SirenArgs
from tha4.poser.poser import Poser
from torch import Tensor
KEY_MODULE = "module"
KEY_POSER = "poser"
def get_poser():
import tha4.poser.modes.mode_12
poser = tha4.poser.modes.mode_12.create_poser(torch.device('cpu'))
return poser
class SirenFaceMorpher00TrainerArgs:
def __init__(self,
character_file_name: str,
face_mask_file_name: str,
pose_dataset_file_name: str,
num_training_total_examples: int = 1_000_000,
num_training_examples_per_checkpoint: int = 100_000,
num_training_examples_lr_boundaries: Optional[List[int]] = None,
num_training_examples_per_sample_output: Optional[int] = 5_000,
num_training_examples_per_snapshot: int = 10_000,
total_batch_size: int = 8,
training_random_seed: int = 2965603729,
sample_output_random_seed: int = 3522651501,
total_worker: int = 16,
poser_func: Optional[Callable[[], Poser]] = None,
base_learning_rate: float = 1e-4):
assert num_training_total_examples % num_training_examples_per_checkpoint == 0
if num_training_examples_lr_boundaries is None:
num_training_examples_lr_boundaries = [
int(num_training_examples_per_checkpoint * 2),
int(num_training_examples_per_checkpoint * 5),
int(num_training_examples_per_checkpoint * 8),
]
for x in num_training_examples_lr_boundaries:
assert x % num_training_examples_per_snapshot == 0
if poser_func is None:
poser_func = get_poser
self.face_mask_file_name = face_mask_file_name
self.base_learning_rate = base_learning_rate
self.poser_func = poser_func
self.total_worker = total_worker
self.num_training_examples_per_snapshot = num_training_examples_per_snapshot
self.num_training_examples_per_sample_output = num_training_examples_per_sample_output
self.sample_output_random_seed = sample_output_random_seed
self.training_random_seed = training_random_seed
self.total_batch_size = total_batch_size
self.num_training_total_examples = num_training_total_examples
self.num_training_examples_per_checkpoint = num_training_examples_per_checkpoint
self.num_training_examples_lr_boundaries = num_training_examples_lr_boundaries
self.pose_dataset_file_name = pose_dataset_file_name
self.character_file_name = character_file_name
def get_character_image(self):
return extract_pytorch_image_from_filelike(
self.character_file_name,
scale=2.0,
offset=-1.0,
premultiply_alpha=True,
perform_srgb_to_linear=True)
def get_face_mask_image(self):
loaded_image = extract_pytorch_image_from_filelike(
self.face_mask_file_name,
scale=1.0,
offset=0.0,
premultiply_alpha=True,
perform_srgb_to_linear=True)
output_image = torch.zeros(4, 128, 128)
center_x = 256
center_y = 128 + 16
for i in range(4):
output_image[i, :, :] = loaded_image[0, center_y - 64:center_y + 64, center_x - 64:center_x + 64]
return output_image
def get_training_dataset(self):
return ImagePosesAndOtherImagesDataset(
main_image_func=self.get_character_image,
other_image_funcs=[self.get_face_mask_image],
pose_dataset=LazyTensorDataset(self.pose_dataset_file_name))
def get_module_factory(self):
return SirenFaceMorpher00Factory(
SirenFaceMorpher00Args(
image_size=128,
image_channels=4,
pose_size=39,
siren_args=SirenArgs(
in_channels=39 + 2,
out_channels=4,
intermediate_channels=128,
num_sine_layers=8)))
def transform_pose_to_module_input(self, pose: Tensor):
return pose[:, 0:39]
def transform_original_image_to_module_input(self, image: Tensor):
center_x = 256
center_y = 128 + 16
return image[:, :, center_y - 64:center_y + 64, center_x - 64:center_x + 64]
def transform_poser_posed_image_to_groundtruth(self, image: Tensor):
center_x = 96
center_y = 96 + 16
return image[:, :, center_y - 64:center_y + 64, center_x - 64:center_x + 64]
def get_training_computation_protocol(self):
return SirenFaceMorpherComputationProtocol00(
transform_pose_to_module_input_func=self.transform_pose_to_module_input,
transform_original_image_to_module_input_func=self.transform_original_image_to_module_input,
transform_poser_posed_image_to_groundtruth_func=self.transform_poser_posed_image_to_groundtruth)
def get_learning_rate(self, examples_seen_so_far) -> Dict[str, float]:
if examples_seen_so_far < self.num_training_examples_lr_boundaries[0]:
return {
KEY_MODULE: self.base_learning_rate,
}
elif examples_seen_so_far < self.num_training_examples_lr_boundaries[1]:
return {
KEY_MODULE: self.base_learning_rate / 3.0,
}
elif examples_seen_so_far < self.num_training_examples_lr_boundaries[2]:
return {
KEY_MODULE: self.base_learning_rate / 10.0,
}
else:
return {
KEY_MODULE: self.base_learning_rate / 30.0,
}
def get_optimizer_factories(self):
return {
KEY_MODULE: AdamOptimizerFactory(betas=(0.9, 0.999)),
}
def get_poser(self):
return self.poser_func()
def get_training_protocol(self, world_size: int):
total_examples = self.num_training_total_examples
per_checkpoint_examples = self.num_training_examples_per_checkpoint
num_checkpoints = total_examples // per_checkpoint_examples
batch_size = self.total_batch_size // world_size
return SirenMorpherTrainingProtocol03(
check_point_examples=[per_checkpoint_examples * (i + 1) for i in range(num_checkpoints)],
batch_size=batch_size,
learning_rate=self.get_learning_rate,
optimizer_factories=self.get_optimizer_factories(),
random_seed=self.training_random_seed,
poser_func=self.get_poser,
key_module=KEY_MODULE,
key_poser=KEY_POSER)
def get_sample_output_protocol(self):
return SirenFaceMorpherSampleOutputProtocol00(
num_images=8,
image_size=128,
images_per_row=2,
examples_per_sample_output=self.num_training_examples_per_sample_output,
computation_protocol=self.get_training_computation_protocol(),
poser_func=self.get_poser,
random_seed=self.sample_output_random_seed)
def get_loss(self):
protocol = self.get_training_computation_protocol()
return SumLoss([
(
'full',
L1Loss(
expected_func=protocol.get_output_func(protocol.keys.groundtruth_posed_image),
actual_func=protocol.get_output_func(protocol.keys.predicted_posed_image),
weight=1.0)
),
(
'eye_mouth',
MaskedL1Loss(
expected_func=protocol.get_output_func(protocol.keys.groundtruth_posed_image),
actual_func=protocol.get_output_func(protocol.keys.predicted_posed_image),
mask_func=protocol.get_output_func(protocol.keys.eye_mouth_mask),
weight=20.0)
),
])
def create_trainer(self, prefix: str, world_size: int, distrib_backend: str = 'gloo'):
if self.num_training_examples_per_sample_output is not None:
sample_output_protocol = self.get_sample_output_protocol()
else:
sample_output_protocol = None
return DistributedTrainer(
prefix=prefix,
module_factories={
KEY_MODULE: self.get_module_factory(),
},
accumulators={},
losses={
KEY_MODULE: self.get_loss(),
},
training_dataset=self.get_training_dataset(),
validation_dataset=self.get_training_dataset(),
training_protocol=self.get_training_protocol(world_size),
validation_protocol=None,
sample_output_protocol=sample_output_protocol,
pretrained_module_file_names={},
example_per_snapshot=self.num_training_examples_per_snapshot,
num_data_loader_workers=max(1, self.total_worker // world_size),
distrib_backend=distrib_backend)
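# Schedule sketch: with the defaults above (100_000 examples per checkpoint,
# base_learning_rate=1e-4) the boundaries become [200_000, 500_000, 800_000],
# and the learning rate steps down by factors of 3, 10, and 30. The file
# names are hypothetical placeholders; constructing the args touches no files
# because the images are loaded lazily.
if __name__ == '__main__':
    trainer_args = SirenFaceMorpher00TrainerArgs(
        character_file_name='data/character.png',        # hypothetical path
        face_mask_file_name='data/face_mask.png',        # hypothetical path
        pose_dataset_file_name='data/pose_dataset.pt')   # hypothetical path
    for n in [0, 200_000, 500_000, 800_000]:
        print(n, trainer_args.get_learning_rate(n))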
================================================
FILE: src/tha4/nn/siren/face_morpher/siren_face_morpher_protocols_00.py
================================================
import os
from dataclasses import dataclass
from typing import Dict, Any, Optional, Callable
import PIL.Image
import numpy
import torch
from tha4.shion.base.dataset.util import get_indexed_batch
from tha4.shion.base.image_util import pytorch_rgba_to_numpy_image
from tha4.shion.core.cached_computation import CachedComputationProtocol, ComputationState, \
ComposableCachedComputationProtocol, batch_indexing_func, add_step
from tha4.shion.core.training.sample_output_protocol import SampleOutputProtocol
from tha4.poser.general_poser_02 import GeneralPoser02
from torch import Tensor
from torch.nn import Module
from torch.utils.data import Dataset
KEY_MODULE = "module"
KEY_POSER = "poser"
@dataclass
class SirenMorpherProtocol00Keys:
module: str = KEY_MODULE
module_output: str = "module_output"
poser: str = KEY_POSER
poser_output: str = "poser_output"
original_image: str = "original_image"
original_pose: str = "original_pose"
module_input_image: str = "module_input_image"
module_input_pose: str = "module_input_pose"
groundtruth_posed_image: str = 'groundtruth_posed_image'
predicted_posed_image: str = 'predicted_posed_image'
eye_mouth_mask: str = 'eye_mouth_mask'
@dataclass
class SirenMorpherProtocol00Indices:
batch_original_image: int = 0
batch_pose: int = 1
batch_eye_mouth_mask: int = 2
poser_posed_image: int = 0
class SirenFaceMorpherComputationProtocol00(ComposableCachedComputationProtocol):
def __init__(self,
transform_pose_to_module_input_func: Callable[[Tensor], Tensor],
transform_original_image_to_module_input_func: Callable[[Tensor], Tensor],
transform_poser_posed_image_to_groundtruth_func: Callable[[Tensor], Tensor],
keys: Optional[SirenMorpherProtocol00Keys] = None,
indices: Optional[SirenMorpherProtocol00Indices] = None):
super().__init__()
if keys is None:
keys = SirenMorpherProtocol00Keys()
if indices is None:
indices = SirenMorpherProtocol00Indices()
self.keys = keys
self.indices = indices
self.transform_image_to_module_input_func = transform_original_image_to_module_input_func
self.transform_pose_to_module_input_func = transform_pose_to_module_input_func
self.transform_poser_posed_image_to_groundtruth_func = transform_poser_posed_image_to_groundtruth_func
self.computation_steps[keys.original_image] = batch_indexing_func(indices.batch_original_image)
self.computation_steps[keys.original_pose] = batch_indexing_func(indices.batch_pose)
@add_step(self.computation_steps, keys.module_input_pose)
def get_module_input_pose(protocol: CachedComputationProtocol, state: ComputationState):
original_pose = protocol.get_output(keys.original_pose, state)
return transform_pose_to_module_input_func(original_pose)
@add_step(self.computation_steps, keys.module_input_image)
def get_module_input_image(protocol: CachedComputationProtocol, state: ComputationState):
original_image = protocol.get_output(keys.original_image, state)
return transform_original_image_to_module_input_func(original_image)
@add_step(self.computation_steps, keys.poser_output)
def get_poser_output(protocol: CachedComputationProtocol, state: ComputationState):
with torch.no_grad():
poser = state.modules[keys.poser]
pose = protocol.get_output(keys.original_pose, state)
image = protocol.get_output(keys.original_image, state)
return poser.get_posing_outputs(image, pose)
@add_step(self.computation_steps, keys.groundtruth_posed_image)
def get_groundtruth_posed_image(protocol: CachedComputationProtocol, state: ComputationState):
poser_output = protocol.get_output(keys.poser_output, state)
poser_posed_image = poser_output[indices.poser_posed_image]
return transform_poser_posed_image_to_groundtruth_func(poser_posed_image)
@add_step(self.computation_steps, keys.module_output)
def get_module_output(protocol: CachedComputationProtocol, state: ComputationState):
module_input_pose = protocol.get_output(keys.module_input_pose, state)
module = state.modules[keys.module]
return module.forward(module_input_pose)
@add_step(self.computation_steps, keys.predicted_posed_image)
def get_predicted_image(protocol: CachedComputationProtocol, state: ComputationState):
return protocol.get_output(keys.module_output, state)
self.computation_steps[keys.eye_mouth_mask] = batch_indexing_func(indices.batch_eye_mouth_mask)
class SirenFaceMorpherSampleOutputProtocol00(SampleOutputProtocol):
def __init__(self,
num_images: int,
image_size: int,
images_per_row: int,
examples_per_sample_output: int,
computation_protocol: SirenFaceMorpherComputationProtocol00,
poser_func: Callable[[], GeneralPoser02],
random_seed: int = 54859395058,
batch_size: Optional[int] = None):
if batch_size is None:
batch_size = num_images
self.batch_size = batch_size
self.poser_func = poser_func
self.random_seed = random_seed
self.examples_per_sample_output = examples_per_sample_output
self.images_per_row = images_per_row
self.image_size = image_size
self.num_images = num_images
self.computation_protocol = computation_protocol
def get_examples_per_sample_output(self) -> int:
return self.examples_per_sample_output
def get_random_seed(self) -> int:
return self.random_seed
def get_sample_output_data(self, validation_dataset: Dataset, device: torch.device) -> dict:
example_indices = torch.randint(0, len(validation_dataset), (self.num_images,))
example_indices = [example_indices[i].item() for i in range(self.num_images)]
batch = get_indexed_batch(validation_dataset, example_indices, device)
poser = self.poser_func()
poser.to(device)
with torch.no_grad():
ground_truth = poser.pose(
batch[self.computation_protocol.indices.batch_original_image],
batch[self.computation_protocol.indices.batch_pose])
return {
'batch': batch,
'ground_truth': ground_truth
}
def save_sample_output_data(self,
modules: Dict[str, Module],
accumulated_modules: Dict[str, Module],
sample_output_data: Any,
prefix: str,
examples_seen_so_far: int,
device: torch.device):
batch = sample_output_data['batch']
ground_truth = sample_output_data['ground_truth']
ground_truth = self.computation_protocol.transform_poser_posed_image_to_groundtruth_func(ground_truth)
module = modules[self.computation_protocol.keys.module]
module.train(False)
if self.batch_size == self.num_images:
with torch.no_grad():
state = ComputationState(
modules=modules,
accumulated_modules=accumulated_modules,
batch=batch,
outputs={})
poser_output_images = self.computation_protocol.get_output(
self.computation_protocol.keys.predicted_posed_image, state)
else:
poser_output_images_list = []
start = 0
while start < self.num_images:
end = start + self.batch_size
end = min(self.num_images, end)
minibatch = [batch[i][start:end] for i in range(len(batch))]
state = ComputationState(
modules=modules,
accumulated_modules=accumulated_modules,
batch=minibatch,
outputs={})
with torch.no_grad():
poser_output_images = self.computation_protocol.get_output(
self.computation_protocol.keys.predicted_posed_image, state)
poser_output_images_list.append(poser_output_images)
start = end
poser_output_images = torch.cat(poser_output_images_list, dim=0)
num_rows = self.num_images // self.images_per_row
if self.num_images % self.images_per_row > 0:
num_rows += 1
num_cols = 2 * self.images_per_row
image_channels = 4
output_image = numpy.zeros([self.image_size * num_rows, self.image_size * num_cols, image_channels])
for image_index in range(self.num_images):
row = image_index // self.images_per_row
start_row = row * self.image_size
col = 2 * (image_index % self.images_per_row)
start_col = col * self.image_size
output_image[start_row:start_row + self.image_size, start_col:start_col + self.image_size, :] \
= pytorch_rgba_to_numpy_image(ground_truth[image_index].detach().cpu())
start_col += self.image_size
output_image[start_row:start_row + self.image_size, start_col:start_col + self.image_size, :] \
= pytorch_rgba_to_numpy_image(poser_output_images[image_index].detach().cpu())
file_name = "%s/sample_output_%010d.png" % (prefix, examples_seen_so_far)
os.makedirs(os.path.dirname(file_name), exist_ok=True)
pil_image = PIL.Image.fromarray(numpy.uint8(numpy.rint(output_image * 255.0)), mode='RGBA')
pil_image.save(file_name)
print("Saved %s" % file_name)
================================================
FILE: src/tha4/nn/siren/morpher/__init__.py
================================================
================================================
FILE: src/tha4/nn/siren/morpher/siren_morpher_03.py
================================================
from typing import List, Optional, Callable
import torch
from torch import Tensor
from torch.nn import Module, ModuleList, Sequential, Conv2d
from torch.nn.functional import affine_grid, interpolate
from tha4.shion.core.module_factory import ModuleFactory
from tha4.shion.nn00.initialization_funcs import HeInitialization
from tha4.nn.image_processing_util import GridChangeApplier
from tha4.nn.siren.vanilla.siren import SineLinearLayer
class SirenMorpherLevelArgs:
def __init__(self,
image_size: int,
intermediate_channels: int,
num_sine_layers: int):
assert num_sine_layers >= 2
self.image_size = image_size
self.num_sine_layers = num_sine_layers
self.intermediate_channels = intermediate_channels
class SirenMorpher03Args:
def __init__(self,
image_size: int,
image_channels: int,
pose_size: int,
level_args: List[SirenMorpherLevelArgs],
init_func: Optional[Callable[[Module], Module]] = None):
assert len(level_args) >= 2
if init_func is None:
init_func = HeInitialization()
self.image_size = image_size
self.init_func = init_func
self.level_args = level_args
self.pose_size = pose_size
self.image_channels = image_channels
class SirenMorpher03(Module):
def __init__(self, args: SirenMorpher03Args):
super().__init__()
self.args = args
self.siren_layers = ModuleList()
for i in range(len(args.level_args)):
level_args = args.level_args[i]
layers = []
if i == 0:
layers.append(SineLinearLayer(
in_channels=args.pose_size + 2,
out_channels=level_args.intermediate_channels,
is_first=True))
else:
layers.append(SineLinearLayer(
in_channels=level_args.intermediate_channels + args.pose_size + 2,
out_channels=level_args.intermediate_channels,
is_first=False))
for j in range(1, level_args.num_sine_layers - 1):
layers.append(SineLinearLayer(
in_channels=level_args.intermediate_channels,
out_channels=level_args.intermediate_channels,
is_first=False))
if i == len(args.level_args) - 1:
out_channels = level_args.intermediate_channels
else:
out_channels = args.level_args[i + 1].intermediate_channels
layers.append(SineLinearLayer(
in_channels=level_args.intermediate_channels,
out_channels=out_channels,
is_first=False))
self.siren_layers.append(Sequential(*layers))
self.last_linear = args.init_func(Conv2d(
args.level_args[-1].intermediate_channels,
args.image_channels + 2 + 1,
kernel_size=1,
stride=1,
padding=0,
bias=True))
self.grid_change_applier = GridChangeApplier()
def get_position_grid(self, n: int, image_size: int, device: torch.device):
h, w = image_size, image_size
identity = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device).unsqueeze(0)
position = affine_grid(identity, [1, 1, h, w], align_corners=False) \
.view(1, h * w, 2)
position = torch.transpose(position, dim0=1, dim1=2).view(1, 2, h, w) \
.repeat(n, 1, 1, 1)
return position
def get_pose_image(self, pose: Tensor, image_size: int):
n, p = pose.shape[0], pose.shape[1]
h, w = image_size, image_size
pose_image = pose.view(n, p, 1, 1).repeat(1, 1, h, w)
return pose_image
def forward(self, image: Tensor, pose: Tensor) -> List[Tensor]:
n = pose.shape[0]
device = pose.device
x = None
for i in range(len(self.args.level_args)):
args = self.args.level_args[i]
position_and_pose = torch.cat([
self.get_position_grid(n, args.image_size, device),
self.get_pose_image(pose, args.image_size)
], dim=1)
if i == 0:
x = self.siren_layers[i].forward(position_and_pose)
else:
x = interpolate(x, size=(args.image_size, args.image_size), mode='bilinear')
x = torch.cat([x, position_and_pose], dim=1)
x = self.siren_layers[i].forward(x)
siren_output = self.last_linear(x)
grid_change = siren_output[:, 0:2, :, :]
alpha = siren_output[:, 2:3, :, :]
color_change = siren_output[:, 3:, :, :]
warped_image = self.grid_change_applier.apply(grid_change, image, align_corners=False)
blended_image = (1 - alpha) * warped_image + alpha * color_change
return [
blended_image,
alpha,
color_change,
warped_image,
grid_change
]
INDEX_BLENDED_IMAGE = 0
INDEX_ALPHA = 1
INDEX_COLOR_CHANGE = 2
INDEX_WARPED_IMAGE = 3
INDEX_GRID_CHANGE = 4
class SirenMorpher03Factory(ModuleFactory):
def __init__(self, args: SirenMorpher03Args):
self.args = args
def create(self):
return SirenMorpher03(self.args)
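# Usage sketch: the levels form a coarse-to-fine SIREN pyramid. Each level's
# features are bilinearly upsampled, re-concatenated with the position grid
# and pose, and refined; the final 1x1 convolution emits a 2-channel grid
# change, a 1-channel alpha, and an image_channels-channel color change.
# A reduced two-level configuration (smaller than the trainer's defaults):
if __name__ == '__main__':
    morpher = SirenMorpher03(SirenMorpher03Args(
        image_size=256,
        image_channels=4,
        pose_size=45,
        level_args=[
            SirenMorpherLevelArgs(image_size=128, intermediate_channels=64, num_sine_layers=3),
            SirenMorpherLevelArgs(image_size=256, intermediate_channels=32, num_sine_layers=3),
        ]))
    outputs = morpher.forward(torch.zeros(1, 4, 256, 256), torch.zeros(1, 45))
    print(tuple(outputs[INDEX_BLENDED_IMAGE].shape))  # expected: (1, 4, 256, 256)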
================================================
FILE: src/tha4/nn/siren/morpher/siren_morpher_03_trainer.py
================================================
from enum import Enum
from typing import Dict, List, Optional, Callable
import torch
from tha4.shion.base.dataset.lazy_tensor_dataset import LazyTensorDataset
from tha4.shion.base.image_util import extract_pytorch_image_from_filelike
from tha4.shion.base.loss.l1_loss import L1Loss
from tha4.shion.base.loss.sum_loss import SumLoss
from tha4.shion.base.loss.time_dependently_weighted_loss import TimeDependentlyWeightedLoss
from tha4.shion.base.optimizer_factories import AdamOptimizerFactory
from tha4.shion.core.training.distrib.distributed_trainer import DistributedTrainer
from tha4.dataset.image_poses_and_aother_images_dataset import ImagePosesAndOtherImagesDataset
from tha4.nn.siren.morpher.siren_morpher_03 import SirenMorpherLevelArgs, SirenMorpher03Factory, SirenMorpher03Args
from tha4.nn.siren.morpher.siren_morpher_protocols_03 import SirenMorpherComputationProtocol03, \
SirenMorpherProtocol03Indices, KEY_MODULE, KEY_POSER, KEY_EXAMPLES_SEEN_SO_FAR, SirenMorpherTrainingProtocol03, \
SirenMorpherSampleOutputProtocol
from tha4.poser.poser import Poser
def get_poser():
import tha4.poser.modes.mode_07
poser = tha4.poser.modes.mode_07.create_poser(torch.device('cpu'))
return poser
class LossTerm(Enum):
full_blended = 1
full_warped = 2
full_grid_change = 3
full_color_change = 4
def get_loss(self, protocol: SirenMorpherComputationProtocol03):
if self == LossTerm.full_blended:
return L1Loss(
expected_func=protocol.get_output_func(protocol.keys.groundtruth_posed_image),
actual_func=protocol.get_output_func(protocol.keys.predicted_posed_image))
elif self == LossTerm.full_warped:
return L1Loss(
expected_func=protocol.get_output_func(protocol.keys.groundtruth_warped_image),
actual_func=protocol.get_output_func(protocol.keys.predicted_warped_image))
elif self == LossTerm.full_grid_change:
return L1Loss(
expected_func=protocol.get_output_func(protocol.keys.groundtruth_grid_change),
actual_func=protocol.get_output_func(protocol.keys.predicted_grid_change))
elif self == LossTerm.full_color_change:
return L1Loss(
expected_func=protocol.get_output_func(protocol.keys.groundtruth_posed_image),
actual_func=protocol.get_output_func(protocol.keys.predicted_color_change))
else:
raise RuntimeError(f"Unsupported loss term {self}")
class LossWeights:
def __init__(self, weights: Optional[Dict[LossTerm, float]] = None):
self.weights = {}
for term in LossTerm:
self.weights[term] = 0.0
if weights is not None:
for term in LossTerm:
if term in weights:
self.weights[term] = weights[term]
class TrainingPhase:
def __init__(self,
num_examples_upper_bound: int,
learning_rate: float,
loss_weights: LossWeights):
self.loss_weights = loss_weights
self.learning_rate = learning_rate
self.num_examples_upper_bound = num_examples_upper_bound
class LearningRateFunc:
def __init__(self, phases: List[TrainingPhase], keys: List[str]):
self.phases = phases
self.keys = keys
def make_learning_rate_dict(self, keys: List[str], value: float):
output = {}
for key in keys:
output[key] = value
return output
def __call__(self, examples_seen_so_far: int) -> Dict[str, float]:
for i in range(len(self.phases) - 1):
if examples_seen_so_far < self.phases[i].num_examples_upper_bound:
return self.make_learning_rate_dict(self.keys, self.phases[i].learning_rate)
return self.make_learning_rate_dict(self.keys, self.phases[-1].learning_rate)
class LossWeightFunc:
def __init__(self, phases: List[TrainingPhase], term: LossTerm):
self.term = term
self.phases = phases
def __call__(self, examples_seen_so_far: int) -> float:
for i in range(len(self.phases) - 1):
if examples_seen_so_far < self.phases[i].num_examples_upper_bound:
return self.phases[i].loss_weights.weights[self.term]
return self.phases[-1].loss_weights.weights[self.term]
class TrainingPhases:
def __init__(self, phases: List[TrainingPhase]):
assert len(phases) > 0
for i in range(1, len(phases)):
assert phases[i - 1].num_examples_upper_bound < phases[i].num_examples_upper_bound
self.phases = phases
def make_learning_rate_dict(self, keys: List[str], value: float):
output = {}
for key in keys:
output[key] = value
return output
def get_learning_rate_func(self, keys: List[str]):
return LearningRateFunc(self.phases, keys)
def get_loss_weight_func(self, term: LossTerm) -> Callable[[int], float]:
return LossWeightFunc(self.phases, term)
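# Usage sketch (hypothetical values): a two-phase schedule that supervises the
# warped image and grid change first, then switches to the final blended image.
#
#   phases = TrainingPhases([
#       TrainingPhase(
#           num_examples_upper_bound=100_000,
#           learning_rate=1e-4,
#           loss_weights=LossWeights({LossTerm.full_warped: 1.0,
#                                     LossTerm.full_grid_change: 1.0})),
#       TrainingPhase(
#           num_examples_upper_bound=500_000,
#           learning_rate=3e-5,
#           loss_weights=LossWeights({LossTerm.full_blended: 1.0})),
#   ])
#   lr_func = phases.get_learning_rate_func([KEY_MODULE])
#   lr_func(50_000)   # -> {KEY_MODULE: 1e-4}
#   lr_func(200_000)  # -> {KEY_MODULE: 3e-5}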
class SirenMorpher03TrainerArgs:
def __init__(self,
character_file_name: str,
pose_dataset_file_name: str,
training_phases: TrainingPhases,
num_training_examples_per_checkpoint: int = 100_000,
num_training_examples_per_sample_output: Optional[int] = 10_000,
num_training_examples_per_snapshot: int = 10_000,
total_batch_size: int = 8,
training_random_seed: int = 2965603729,
sample_output_random_seed: int = 3522651501,
total_worker: int = 8,
poser_func: Optional[Callable[[], Poser]] = None,
sample_output_batch_size: Optional[int] = None,
pretrained_module_file_name: Optional[str] = None):
for phase in training_phases.phases:
assert phase.num_examples_upper_bound % num_training_examples_per_checkpoint == 0
if poser_func is None:
poser_func = get_poser
self.training_phases = training_phases
self.pretrained_module_file_name = pretrained_module_file_name
self.sample_output_batch_size = sample_output_batch_size
self.poser_func = poser_func
self.total_worker = total_worker
self.num_training_examples_per_snapshot = num_training_examples_per_snapshot
self.num_training_examples_per_sample_output = num_training_examples_per_sample_output
self.sample_output_random_seed = sample_output_random_seed
self.training_random_seed = training_random_seed
self.total_batch_size = total_batch_size
self.num_training_examples_per_checkpoint = num_training_examples_per_checkpoint
self.pose_dataset_file_name = pose_dataset_file_name
self.character_file_name = character_file_name
def get_character_image(self):
return extract_pytorch_image_from_filelike(
self.character_file_name,
scale=2.0,
offset=-1.0,
premultiply_alpha=True,
perform_srgb_to_linear=True)
def get_training_dataset(self):
return ImagePosesAndOtherImagesDataset(
main_image_func=self.get_character_image,
pose_dataset=LazyTensorDataset(self.pose_dataset_file_name),
other_image_funcs=[])
def get_module_factory(self):
return SirenMorpher03Factory(
SirenMorpher03Args(
image_size=512,
image_channels=4,
pose_size=45,
level_args=[
SirenMorpherLevelArgs(
image_size=128,
intermediate_channels=360,
num_sine_layers=3),
SirenMorpherLevelArgs(
image_size=256,
intermediate_channels=180,
num_sine_layers=3),
SirenMorpherLevelArgs(
image_size=512,
intermediate_channels=90,
num_sine_layers=3),
]))
def get_training_computation_protocol(self):
return SirenMorpherComputationProtocol03(
indices=SirenMorpherProtocol03Indices(
batch_image=0,
batch_pose=1,
batch_face_mask=2))
def get_optimizer_factories(self):
return {
KEY_MODULE: AdamOptimizerFactory(betas=(0.9, 0.999)),
}
def get_poser(self):
return self.poser_func()
def get_training_protocol(self, world_size: int):
total_examples = self.training_phases.phases[-1].num_examples_upper_bound
per_checkpoint_examples = self.num_training_examples_per_checkpoint
num_checkpoints = total_examples // per_checkpoint_examples
batch_size = self.total_batch_size // world_size
return SirenMorpherTrainingProtocol03(
check_point_examples=[per_checkpoint_examples * (i + 1) for i in range(num_checkpoints)],
batch_size=batch_size,
learning_rate=self.training_phases.get_learning_rate_func([KEY_MODULE]),
optimizer_factories=self.get_optimizer_factories(),
random_seed=self.training_random_seed,
poser_func=self.get_poser,
key_module=KEY_MODULE,
key_poser=KEY_POSER)
def get_sample_output_protocol(self):
return SirenMorpherSampleOutputProtocol(
num_images=4,
image_size=512,
examples_per_sample_output=self.num_training_examples_per_sample_output,
computation_protocol=self.get_training_computation_protocol(),
poser_func=self.get_poser,
random_seed=self.sample_output_random_seed,
batch_size=self.sample_output_batch_size,
batch_pose_index=1,
batch_image_index=0)
def get_loss(self):
protocol = self.get_training_computation_protocol()
losses = []
for term in LossTerm:
base_loss = term.get_loss(protocol)
loss = TimeDependentlyWeightedLoss(
base_loss,
examples_seen_so_far_func=lambda state: state.outputs[KEY_EXAMPLES_SEEN_SO_FAR],
weight_func=self.training_phases.get_loss_weight_func(term))
losses.append((term.name, loss))
return SumLoss(losses)
def create_trainer(self, prefix: str, world_size: int, distrib_backend: str = 'gloo'):
if self.num_training_examples_per_sample_output is not None:
sample_output_protocol = self.get_sample_output_protocol()
else:
sample_output_protocol = None
pretrained_module_file_names = {}
if self.pretrained_module_file_name is not None:
pretrained_module_file_names[KEY_MODULE] = self.pretrained_module_file_name
return DistributedTrainer(
prefix=prefix,
module_factories={
KEY_MODULE: self.get_module_factory(),
},
accumulators={},
losses={
KEY_MODULE: self.get_loss(),
},
training_dataset=self.get_training_dataset(),
validation_dataset=self.get_training_dataset(),
training_protocol=self.get_training_protocol(world_size),
validation_protocol=None,
sample_output_protocol=sample_output_protocol,
pretrained_module_file_names=pretrained_module_file_names,
example_per_snapshot=self.num_training_examples_per_snapshot,
num_data_loader_workers=max(1, self.total_worker // world_size),
distrib_backend=distrib_backend)
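# Usage sketch (hypothetical paths): distilling the body morpher of a single
# character, given a TrainingPhases instance `phases` like the one above.
#
#   args = SirenMorpher03TrainerArgs(
#       character_file_name="data/character.png",
#       pose_dataset_file_name="data/pose_dataset.pt",
#       training_phases=phases)
#   trainer = args.create_trainer(prefix="output/body_morpher", world_size=1)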
================================================
FILE: src/tha4/nn/siren/morpher/siren_morpher_protocols_03.py
================================================
from dataclasses import dataclass
from typing import Optional, List, Callable, Dict, Any
import torch
from tha4.shion.base.dataset.util import get_indexed_batch
from tha4.shion.core.cached_computation import output_array_indexing_func, add_step, ComputationState, \
CachedComputationProtocol, ComposableCachedComputationProtocol, batch_indexing_func, proxy_func
from tha4.shion.core.loss import Loss
from tha4.shion.core.optimizer_factory import OptimizerFactory
from tha4.shion.core.training.sample_output_protocol import SampleOutputProtocol
from tha4.shion.core.training.training_protocol import AbstractTrainingProtocol
from tha4.nn.image_processing_util import GridChangeApplier
from tha4.nn.siren.morpher.siren_morpher_03 import SirenMorpher03
from tha4.poser.general_poser_02 import GeneralPoser02
from tha4.sampleoutput.sample_image_creator import SampleImageSpec, ImageSource, ImageType, SampleImageSaver
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import Dataset
KEY_MODULE = "module"
KEY_POSER = "poser"
KEY_EXAMPLES_SEEN_SO_FAR = "examples_seen_so_far"
@dataclass
class SirenMorpherProtocol03Keys:
module: str = KEY_MODULE
module_output: str = "module_output"
poser: str = KEY_POSER
poser_output: str = "poser_output"
image: str = "image"
pose: str = "pose"
face_mask: str = 'face_mask'
groundtruth_posed_image: str = 'groundtruth_posed_image'
groundtruth_grid_change: str = 'groundtruth_grid_change'
groundtruth_posed_face_mask: str = 'groundtruth_posed_face_mask'
predicted_posed_image: str = 'predicted_posed_image'
module_input_image: str = "module_input_image"
predicted_grid_change: str = "predicted_grid_change"
predicted_color_change: str = "predicted_color_change"
predicted_warped_image: str = "predicted_warped_image"
predicted_alpha: str = "predicted_alpha"
groundtruth_alpha: str = "groundtruth_alpha"
groundtruth_warped_image: str = "groundtruth_warped_image"
zero: str = "zero"
@dataclass
class SirenMorpherProtocol03Indices:
batch_image: int = 0
batch_face_mask: int = 1
batch_pose: int = 2
poser_posed_image: int = 0
poser_grid_change: int = 3
poser_output_module_input_image_index: int = 5
poser_alpha: int = 1
poser_warped_image: int = 2
module_blended_image: int = SirenMorpher03.INDEX_BLENDED_IMAGE
module_grid_change: int = SirenMorpher03.INDEX_GRID_CHANGE
module_color_change: int = SirenMorpher03.INDEX_COLOR_CHANGE
module_warped_image: int = SirenMorpher03.INDEX_WARPED_IMAGE
module_alpha: int = SirenMorpher03.INDEX_ALPHA
class SirenMorpherComputationProtocol03(ComposableCachedComputationProtocol):
def __init__(self,
keys: Optional[SirenMorpherProtocol03Keys] = None,
indices: Optional[SirenMorpherProtocol03Indices] = None):
super().__init__()
if keys is None:
keys = SirenMorpherProtocol03Keys()
if indices is None:
indices = SirenMorpherProtocol03Indices()
self.keys = keys
self.indices = indices
self.computation_steps[keys.image] = batch_indexing_func(indices.batch_image)
self.computation_steps[keys.pose] = batch_indexing_func(indices.batch_pose)
self.computation_steps[keys.face_mask] = batch_indexing_func(indices.batch_face_mask)
self.grid_change_applier = GridChangeApplier()
        # NOTE: the step and binding registered below are superseded further
        # down in this constructor, where keys.module_output is redefined to
        # take the module input image as well as the pose and
        # keys.predicted_posed_image is re-bound to an indexed output of that
        # redefined step; they have no effect.
        @add_step(self.computation_steps, keys.module_output)
        def get_module_output(protocol: CachedComputationProtocol, state: ComputationState):
            pose = protocol.get_output(keys.pose, state)
            module = state.modules[self.keys.module]
            return module.forward(pose)
        self.computation_steps[keys.predicted_posed_image] = proxy_func(keys.module_output)
@add_step(self.computation_steps, keys.poser_output)
def get_poser_output(protocol: CachedComputationProtocol, state: ComputationState):
with torch.no_grad():
poser = state.modules[keys.poser]
pose = protocol.get_output(keys.pose, state)
image = protocol.get_output(keys.image, state)
return poser.get_posing_outputs(image, pose)
@add_step(self.computation_steps, keys.groundtruth_posed_image)
def get_groundtruth_posed_image(protocol: CachedComputationProtocol, state: ComputationState):
return protocol.get_output(keys.poser_output, state)[indices.poser_posed_image]
@add_step(self.computation_steps, keys.groundtruth_grid_change)
        def get_groundtruth_grid_change(protocol: CachedComputationProtocol, state: ComputationState):
return protocol.get_output(keys.poser_output, state)[indices.poser_grid_change]
@add_step(self.computation_steps, keys.groundtruth_posed_face_mask)
def get_groundtruth_posed_face_mask(protocol: CachedComputationProtocol, state: ComputationState):
face_mask = protocol.get_output(keys.face_mask, state)
groundtruth_grid_change = protocol.get_output(keys.groundtruth_grid_change, state)
with torch.no_grad():
return self.grid_change_applier.apply(groundtruth_grid_change, face_mask)
@add_step(self.computation_steps, keys.module_input_image)
def get_module_input_image(protocol: CachedComputationProtocol, state: ComputationState):
poser_output = protocol.get_output(keys.poser_output, state)
return poser_output[indices.poser_output_module_input_image_index]
@add_step(self.computation_steps, keys.module_output)
def get_module_output(protocol: CachedComputationProtocol, state: ComputationState):
image = protocol.get_output(keys.module_input_image, state)
pose = protocol.get_output(keys.pose, state)
module = state.modules[self.keys.module]
return module.forward(image, pose)
self.computation_steps[keys.predicted_posed_image] = output_array_indexing_func(
keys.module_output, indices.module_blended_image)
self.computation_steps[keys.predicted_grid_change] = output_array_indexing_func(
keys.module_output, indices.module_grid_change)
self.computation_steps[keys.predicted_color_change] = output_array_indexing_func(
keys.module_output, indices.module_color_change)
self.computation_steps[keys.predicted_warped_image] = output_array_indexing_func(
keys.module_output, indices.module_warped_image)
self.computation_steps[keys.predicted_alpha] = output_array_indexing_func(
keys.module_output, indices.module_alpha)
self.computation_steps[keys.groundtruth_alpha] = output_array_indexing_func(
keys.poser_output, indices.poser_alpha)
self.computation_steps[keys.groundtruth_warped_image] = output_array_indexing_func(
keys.poser_output, indices.poser_warped_image)
@add_step(self.computation_steps, keys.zero)
def get_zero(protocol: CachedComputationProtocol, state: ComputationState):
pose = protocol.get_output(keys.pose, state)
device = pose.device
return torch.zeros(1, device=device)
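# In summary, the protocol above wires three stages together: the batch is
# split into (image, pose, face_mask); the teacher poser runs under
# torch.no_grad() to produce the groundtruth_* outputs; and the student
# module consumes the teacher's module input image plus the pose to produce
# the predicted_* outputs that the trainer's L1 loss terms compare against.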
class SirenMorpherTrainingProtocol03(AbstractTrainingProtocol):
def __init__(self,
check_point_examples: List[int],
batch_size: int,
learning_rate: Callable[[int], Dict[str, float]],
optimizer_factories: Dict[str, OptimizerFactory],
random_seed: int,
poser_func: Callable[[], GeneralPoser02],
key_module: str,
key_poser: str = KEY_POSER,
key_examples_seen_so_far: str = KEY_EXAMPLES_SEEN_SO_FAR):
super().__init__(check_point_examples, batch_size, learning_rate, optimizer_factories, random_seed)
self.key_examples_seen_so_far = key_examples_seen_so_far
self.key_poser = key_poser
self.key_module = key_module
self.poser_func = poser_func
self.poser = None
def run_training_iteration(
self,
batch: Any,
examples_seen_so_far: int,
modules: Dict[str, Module],
accumulated_modules: Dict[str, Module],
optimizers: Dict[str, Optimizer],
losses: Dict[str, Loss],
create_log_func: Optional[Callable[[str, int], Callable[[str, float], None]]],
device: torch.device):
if self.poser is None:
self.poser = self.poser_func()
self.poser.to(device)
module = modules[self.key_module]
module.train(True)
module_optimizer = optimizers[self.key_module]
module_optimizer.zero_grad(set_to_none=True)
loss = losses[self.key_module]
if create_log_func is not None:
log_func = create_log_func(f"training_{self.key_module}", examples_seen_so_far)
else:
log_func = None
state = ComputationState(
modules={
**modules,
self.key_poser: self.poser,
},
accumulated_modules=accumulated_modules,
batch=batch,
outputs={
self.key_examples_seen_so_far: examples_seen_so_far,
})
loss_value = loss.compute(state, log_func)
loss_value.backward()
module_optimizer.step()
class SirenMorpherSampleOutputProtocol(SampleOutputProtocol):
def __init__(self,
num_images: int,
image_size: int,
examples_per_sample_output: int,
computation_protocol,
poser_func: Callable[[], GeneralPoser02],
random_seed: int = 54859395058,
batch_image_index: int = 0,
batch_pose_index: int = 2,
batch_size: Optional[int] = None,
sample_image_specs: Optional[List[SampleImageSpec]] = None,
cell_size: Optional[int] = None):
if batch_size is None:
batch_size = num_images
if sample_image_specs is None:
sample_image_specs = [
SampleImageSpec(
ImageSource.BATCH,
computation_protocol.indices.poser_posed_image,
ImageType.COLOR),
SampleImageSpec(
ImageSource.OUTPUT,
computation_protocol.indices.module_blended_image,
ImageType.COLOR),
SampleImageSpec(
ImageSource.OUTPUT,
computation_protocol.indices.module_alpha,
ImageType.ALPHA),
SampleImageSpec(
ImageSource.OUTPUT,
computation_protocol.indices.module_color_change,
ImageType.COLOR),
SampleImageSpec(
ImageSource.OUTPUT,
computation_protocol.indices.module_warped_image,
ImageType.COLOR),
SampleImageSpec(
ImageSource.BATCH,
computation_protocol.indices.poser_grid_change,
ImageType.GRID_CHANGE),
SampleImageSpec(
ImageSource.OUTPUT,
computation_protocol.indices.module_grid_change,
ImageType.GRID_CHANGE),
]
if cell_size is None:
cell_size = image_size
self.batch_size = batch_size
self.batch_pose_index = batch_pose_index
self.batch_image_index = batch_image_index
self.poser_func = poser_func
self.random_seed = random_seed
self.examples_per_sample_output = examples_per_sample_output
self.image_size = image_size
self.num_images = num_images
self.computation_protocol = computation_protocol
self.cell_size = cell_size
self.sample_image_saver = SampleImageSaver(image_size, cell_size, 4, sample_image_specs)
def get_examples_per_sample_output(self) -> int:
return self.examples_per_sample_output
def get_random_seed(self) -> int:
return self.random_seed
def get_sample_output_data(self, validation_dataset: Dataset, device: torch.device) -> dict:
example_indices = torch.randint(0, len(validation_dataset), (self.num_images,))
example_indices = [example_indices[i].item() for i in range(self.num_images)]
batch = get_indexed_batch(validation_dataset, example_indices, device)
poser = self.poser_func()
poser.to(device)
with torch.no_grad():
ground_truth = poser.get_posing_outputs(batch[self.batch_image_index], batch[self.batch_pose_index])
return {
'batch': batch,
'ground_truth': ground_truth
}
def save_sample_output_data(self,
modules: Dict[str, Module],
accumulated_modules: Dict[str, Module],
sample_output_data: Any,
prefix: str,
examples_seen_so_far: int,
device: torch.device):
batch = sample_output_data['batch']
ground_truth = sample_output_data['ground_truth']
module = modules[self.computation_protocol.keys.module]
module.train(False)
if self.batch_size == self.num_images:
with torch.no_grad():
state = ComputationState(
modules=modules,
accumulated_modules=accumulated_modules,
batch=batch,
outputs={
self.computation_protocol.keys.poser_output: ground_truth,
})
module_outputs = self.computation_protocol.get_output(
self.computation_protocol.keys.module_output, state)
else:
module_outputs_list = []
start = 0
while start < self.num_images:
end = start + self.batch_size
end = min(self.num_images, end)
minibatch = [batch[i][start:end] for i in range(len(batch))]
ground_truth_batch = [ground_truth[i][start:end] for i in range(len(ground_truth))]
state = ComputationState(
modules=modules,
accumulated_modules=accumulated_modules,
batch=minibatch,
outputs={
self.computation_protocol.keys.poser_output: ground_truth_batch
})
with torch.no_grad():
module_outputs = self.computation_protocol.get_output(
self.computation_protocol.keys.module_output, state)
module_outputs_list.append(module_outputs)
start = end
module_outputs = []
for i in range(len(module_outputs_list[0])):
tensor_list = []
for j in range(len(module_outputs_list)):
tensor_list.append(module_outputs_list[j][i])
module_output = torch.cat(tensor_list, dim=0)
module_outputs.append(module_output)
self.sample_image_saver.save_sample_output_data(ground_truth, module_outputs, prefix, examples_seen_so_far)
================================================
FILE: src/tha4/nn/siren/vanilla/__init__.py
================================================
================================================
FILE: src/tha4/nn/siren/vanilla/siren.py
================================================
import math
from typing import Callable, Optional, List
import torch
from torch import Tensor
from torch.nn import Module, Conv2d, ModuleList
from tha4.shion.core.module_factory import ModuleFactory
from tha4.shion.nn00.initialization_funcs import HeInitialization
class SineLinearLayer(Module):
def __init__(self,
in_channels: int,
out_channels: int,
is_first=False,
omega_0=30.0):
super().__init__()
self.out_channels = out_channels
self.in_channels = in_channels
self.omega_0 = omega_0
self.is_first = is_first
self.linear = Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=True)
with torch.no_grad():
if self.is_first:
self.linear.weight.uniform_(-1 / in_channels, 1.0 / in_channels)
else:
self.linear.weight.uniform_(
-math.sqrt(6.0 / in_channels) / self.omega_0,
math.sqrt(6.0 / in_channels) / self.omega_0)
def forward(self, x: Tensor):
return torch.sin(self.omega_0 * self.linear(x))
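# The uniform weight ranges above follow the SIREN initialization scheme
# (Sitzmann et al., 2020): U(-1/n, 1/n) for the first layer and
# U(-sqrt(6/n) / omega_0, sqrt(6/n) / omega_0) for subsequent layers,
# with the frequency factor omega_0 = 30.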
class SirenArgs:
def __init__(
self,
in_channels: int,
out_channels: int,
intermediate_channels: int,
num_sine_layers: int,
use_tanh: bool = False,
init_func: Optional[Callable[[Module], Module]] = None):
if init_func is None:
init_func = HeInitialization()
self.init_func = init_func
self.use_tanh = use_tanh
assert num_sine_layers >= 1
self.intermediate_channels = intermediate_channels
self.num_sine_layers = num_sine_layers
self.out_channels = out_channels
self.in_channels = in_channels
class Siren(Module):
def __init__(self, args: SirenArgs):
super().__init__()
self.args = args
self.sine_layers = ModuleList()
self.sine_layers.append(
SineLinearLayer(
in_channels=args.in_channels, out_channels=args.intermediate_channels, is_first=True))
for i in range(args.num_sine_layers - 1):
self.sine_layers.append(
SineLinearLayer(
in_channels=args.intermediate_channels,
out_channels=args.intermediate_channels,
is_first=False))
self.last_linear = args.init_func(Conv2d(
args.intermediate_channels,
args.out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=True))
def forward(self, x: Tensor) -> Tensor:
for i in range(self.args.num_sine_layers):
x = self.sine_layers[i].forward(x)
x = self.last_linear(x)
if self.args.use_tanh:
return torch.tanh(x)
else:
return x
class SirenFactory(ModuleFactory):
def __init__(self, args: SirenArgs):
super().__init__()
self.args = args
def create(self) -> Module:
return Siren(self.args)
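# Usage sketch (hypothetical sizes): a SIREN that maps a 2-channel coordinate
# grid to a 4-channel image. The layers are 1x1 convolutions, so inputs have
# shape (N, in_channels, H, W).
#
#   siren = SirenFactory(SirenArgs(
#       in_channels=2,
#       out_channels=4,
#       intermediate_channels=128,
#       num_sine_layers=8)).create()
#   y = siren(torch.rand(1, 2, 64, 64))  # -> shape (1, 4, 64, 64)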
================================================
FILE: src/tha4/nn/spectral_norm.py
================================================
from torch.nn import Module
from torch.nn.utils import spectral_norm
def apply_spectral_norm(module: Module, use_spectral_norm: bool = False) -> Module:
    if use_spectral_norm:
return spectral_norm(module)
else:
return module
================================================
FILE: src/tha4/nn/upscaler/__init__.py
================================================
================================================
FILE: src/tha4/nn/upscaler/upscaler_02.py
================================================
from typing import List
import torch
from torch import Tensor, zero_
from torch.nn import Module, Conv2d
from tha4.shion.core.module_factory import ModuleFactory
from tha4.nn.image_processing_util import GridChangeApplier
from tha4.nn.common.unet import UnetArgs, AttentionBlockArgs, UnetWithFirstConvAddition
class Upscaler02Args:
def __init__(self,
image_size: int,
image_channels: int,
num_pose_parameters: int,
unet_args: UnetArgs):
assert unet_args.in_channels == (
image_channels
)
assert unet_args.out_channels == (
image_channels + # direct
2 + # warp
1 # alpha
)
assert unet_args.cond_input_channels == num_pose_parameters
self.image_channels = image_channels
self.image_size = image_size
self.num_pose_parameters = num_pose_parameters
self.unet_args = unet_args
def apply_color_change(alpha: Tensor, color_change: Tensor, image: Tensor) -> Tensor:
return color_change * alpha + image * (1 - alpha)
class Upscaler02(Module):
def __init__(self, args: Upscaler02Args):
super().__init__()
self.args = args
self.body = UnetWithFirstConvAddition(args.unet_args)
self.grid_change_applier = GridChangeApplier()
self.coarse_image_conv = Conv2d(
args.image_channels + args.image_channels + 2,
args.unet_args.model_channels,
kernel_size=3,
stride=1,
padding=1)
with torch.no_grad():
zero_(self.coarse_image_conv.weight)
zero_(self.coarse_image_conv.bias)
def check_image(self, image: torch.Tensor):
assert len(image.shape) == 4
assert image.shape[1] == self.args.image_channels
assert image.shape[2] == self.args.image_size
assert image.shape[3] == self.args.image_size
def forward(self,
rest_image: torch.Tensor,
coarse_posed_image: torch.Tensor,
coarse_grid_change: torch.Tensor,
pose: torch.Tensor) -> List[Tensor]:
self.check_image(rest_image)
self.check_image(coarse_posed_image)
assert len(pose.shape) == 2
assert rest_image.shape[0] == pose.shape[0]
assert coarse_posed_image.shape[0] == pose.shape[0]
assert coarse_grid_change.shape[0] == pose.shape[0]
assert coarse_grid_change.shape[1] == 2
assert coarse_grid_change.shape[2] == self.args.image_size
assert coarse_grid_change.shape[3] == self.args.image_size
assert pose.shape[1] == self.args.num_pose_parameters
warped_image = self.grid_change_applier.apply(coarse_grid_change, rest_image)
t = torch.zeros(rest_image.shape[0], 1, device=rest_image.device)
feature = torch.cat([coarse_posed_image, warped_image, coarse_grid_change], dim=1)
first_conv_addition = self.coarse_image_conv(feature)
body_output = self.body(rest_image, t, pose, first_conv_addition)
direct = body_output[:, 0:self.args.image_channels, :, :]
grid_change = body_output[:, self.args.image_channels:self.args.image_channels + 2, :, :]
alpha = torch.sigmoid(body_output[:, self.args.image_channels + 2:self.args.image_channels + 3, :, :])
warped = self.grid_change_applier.apply(grid_change, rest_image)
merged = apply_color_change(alpha, direct, warped)
return [
merged,
alpha,
warped,
grid_change,
direct
]
INDEX_MERGED = 0
INDEX_ALPHA = 1
INDEX_WARPED = 2
INDEX_GRID_CHANGE = 3
INDEX_DIRECT = 4
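# How Upscaler02.forward above refines the half-resolution result: the rest
# image is warped by the (upsampled) coarse grid change; the concatenation of
# [coarse posed image, warped image, grid change] passes through
# coarse_image_conv -- zero-initialized, so the conditioning starts as a
# no-op -- and is fed to the U-Net as a first-conv addition (per
# UnetWithFirstConvAddition); the U-Net then predicts direct colors, a refined
# grid change, and an alpha map that blends the direct colors with the
# re-warped rest image.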
class Upscaler02Factory(ModuleFactory):
def __init__(self, args: Upscaler02Args):
self.args = args
def create(self) -> Module:
return Upscaler02(self.args)
================================================
FILE: src/tha4/nn/util.py
================================================
from typing import Optional, Callable, Union
from torch.nn import Module
from tha4.shion.core.module_factory import ModuleFactory
from tha4.nn.init_function import create_init_function
from tha4.nn.nonlinearity_factory import resolve_nonlinearity_factory
from tha4.nn.normalization import NormalizationLayerFactory
from tha4.nn.spectral_norm import apply_spectral_norm
def wrap_conv_or_linear_module(module: Module,
initialization_method: Union[str, Callable[[Module], Module]],
use_spectral_norm: bool):
if isinstance(initialization_method, str):
init = create_init_function(initialization_method)
else:
init = initialization_method
return apply_spectral_norm(init(module), use_spectral_norm)
class BlockArgs:
def __init__(self,
initialization_method: Union[str, Callable[[Module], Module]] = 'he',
use_spectral_norm: bool = False,
normalization_layer_factory: Optional[NormalizationLayerFactory] = None,
nonlinearity_factory: Optional[ModuleFactory] = None):
self.nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)
self.normalization_layer_factory = normalization_layer_factory
self.use_spectral_norm = use_spectral_norm
self.initialization_method = initialization_method
def wrap_module(self, module: Module) -> Module:
return wrap_conv_or_linear_module(module, self.get_init_func(), self.use_spectral_norm)
def get_init_func(self) -> Callable[[Module], Module]:
if isinstance(self.initialization_method, str):
return create_init_function(self.initialization_method)
else:
return self.initialization_method
================================================
FILE: src/tha4/poser/__init__.py
================================================
================================================
FILE: src/tha4/poser/general_poser_02.py
================================================
from typing import List, Optional, Tuple, Dict, Callable
import torch
from tha4.shion.core.cached_computation import ComputationState
from tha4.poser.poser import PoseParameterGroup, Poser
from torch import Tensor
from torch.nn import Module
class GeneralPoser02(Poser):
def __init__(self,
module_loaders: Dict[str, Callable[[], Module]],
device: torch.device,
output_length: int,
pose_parameters: List[PoseParameterGroup],
output_list_func: Callable[[ComputationState], List[Tensor]],
subrect: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None,
default_output_index: int = 0,
image_size: int = 256,
dtype: torch.dtype = torch.float):
self.dtype = dtype
self.image_size = image_size
self.default_output_index = default_output_index
self.output_list_func = output_list_func
self.subrect = subrect
self.pose_parameters = pose_parameters
self.device = device
self.module_loaders = module_loaders
self.modules = None
self.num_parameters = 0
for pose_parameter in self.pose_parameters:
self.num_parameters += pose_parameter.get_arity()
self.output_length = output_length
def get_image_size(self) -> int:
return self.image_size
def get_modules(self):
if self.modules is None:
self.modules = {}
for key in self.module_loaders:
module = self.module_loaders[key]()
self.modules[key] = module
module.to(self.device)
module.train(False)
return self.modules
def get_pose_parameter_groups(self) -> List[PoseParameterGroup]:
return self.pose_parameters
def get_num_parameters(self) -> int:
return self.num_parameters
def pose(self, image: Tensor, pose: Tensor, output_index: Optional[int] = None) -> Tensor:
if output_index is None:
output_index = self.default_output_index
output_list = self.get_posing_outputs(image, pose)
return output_list[output_index]
def get_posing_outputs(self, image: Tensor, pose: Tensor) -> List[Tensor]:
modules = self.get_modules()
if len(image.shape) == 3:
image = image.unsqueeze(0)
if len(pose.shape) == 1:
pose = pose.unsqueeze(0)
if self.subrect is not None:
image = image[:, :, self.subrect[0][0]:self.subrect[0][1], self.subrect[1][0]:self.subrect[1][1]]
batch = [image, pose]
state = ComputationState(
modules=modules,
accumulated_modules={},
batch=batch,
outputs={})
return self.output_list_func(state)
def get_output_length(self) -> int:
return self.output_length
def free(self):
self.modules = None
def get_dtype(self) -> torch.dtype:
return self.dtype
def to(self, device: torch.device) -> 'GeneralPoser02':
if device == self.device:
return self
modules = self.get_modules()
self.device = device
for key in modules:
module = modules[key]
module.to(self.device)
return self
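# Usage sketch: poses are flat tensors of length get_num_parameters(); a
# single unbatched image (C, H, W) and a single pose vector are accepted,
# since get_posing_outputs unsqueezes them into batches of one.
#
#   poser = tha4.poser.modes.mode_07.create_poser(torch.device('cpu'))
#   posed_image = poser.pose(image, pose)  # uses default_output_index = 0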
================================================
FILE: src/tha4/poser/modes/__init__.py
================================================
================================================
FILE: src/tha4/poser/modes/mode_07.py
================================================
from enum import Enum
from typing import List, Dict, Optional
import torch
from tha4.shion.core.cached_computation import CachedComputationProtocol, ComputationState
from tha4.shion.core.load_save import torch_load
from tha4.nn.eyebrow_decomposer.eyebrow_decomposer_00 import EyebrowDecomposer00, \
EyebrowDecomposer00Factory, EyebrowDecomposer00Args
from tha4.nn.eyebrow_morphing_combiner.eyebrow_morphing_combiner_00 import \
EyebrowMorphingCombiner00Factory, EyebrowMorphingCombiner00Args, EyebrowMorphingCombiner00
from tha4.nn.face_morpher.face_morpher_08 import FaceMorpher08Args, FaceMorpher08Factory
from tha4.nn.nonlinearity_factory import ReLUFactory
from tha4.nn.normalization import InstanceNorm2dFactory
from tha4.nn.util import BlockArgs
from tha4.nn.common.unet import UnetArgs, AttentionBlockArgs
from tha4.nn.morpher.morpher_00 import Morpher00Args, Morpher00
from tha4.nn.upscaler.upscaler_02 import Upscaler02Args, Upscaler02
from tha4.poser.general_poser_02 import GeneralPoser02
from tha4.poser.modes.pose_parameters import get_pose_parameters
from torch import Tensor
from torch.nn.functional import interpolate
class Network(Enum):
eyebrow_decomposer = 1
eyebrow_morphing_combiner = 2
face_morpher = 3
body_morpher = 4
upscaler = 5
@property
def outputs_key(self):
return f"{self.name}_outputs"
class Branch(Enum):
face_morphed_half = 1
face_morphed_full = 2
all_outputs = 3
NUM_EYEBROW_PARAMS = 12
NUM_FACE_PARAMS = 27
NUM_ROTATION_PARAMS = 6
class FiveStepPoserComputationProtocol(CachedComputationProtocol):
def __init__(self, eyebrow_morphed_image_index: int):
super().__init__()
self.eyebrow_morphed_image_index = eyebrow_morphed_image_index
self.cached_batch_0 = None
self.cached_eyebrow_decomposer_output = None
def compute_func(self):
def func(state: ComputationState) -> List[Tensor]:
if self.cached_batch_0 is None:
new_batch_0 = True
elif state.batch[0].shape[0] != self.cached_batch_0.shape[0]:
new_batch_0 = True
else:
new_batch_0 = torch.max((state.batch[0] - self.cached_batch_0).abs()).item() > 0
if not new_batch_0:
state.outputs[Network.eyebrow_decomposer.outputs_key] = self.cached_eyebrow_decomposer_output
output = self.get_output(Branch.all_outputs.name, state)
if new_batch_0:
self.cached_batch_0 = state.batch[0]
self.cached_eyebrow_decomposer_output = state.outputs[Network.eyebrow_decomposer.outputs_key]
return output
return func
def compute_output(self, key: str, state: ComputationState) -> List[Tensor]:
if key == Network.eyebrow_decomposer.outputs_key:
input_image = state.batch[0][:, :, 64:192, 64 + 128:192 + 128]
return state.modules[Network.eyebrow_decomposer.name].forward(input_image)
elif key == Network.eyebrow_morphing_combiner.outputs_key:
eyebrow_decomposer_output = self.get_output(Network.eyebrow_decomposer.outputs_key, state)
background_layer = eyebrow_decomposer_output[EyebrowDecomposer00.BACKGROUND_LAYER_INDEX]
eyebrow_layer = eyebrow_decomposer_output[EyebrowDecomposer00.EYEBROW_LAYER_INDEX]
eyebrow_pose = state.batch[1][:, :NUM_EYEBROW_PARAMS]
return state.modules[Network.eyebrow_morphing_combiner.name].forward(
background_layer,
eyebrow_layer,
eyebrow_pose)
elif key == Network.face_morpher.outputs_key:
eyebrow_morphing_combiner_output = self.get_output(
Network.eyebrow_morphing_combiner.outputs_key, state)
eyebrow_morphed_image = eyebrow_morphing_combiner_output[self.eyebrow_morphed_image_index]
input_image = state.batch[0][:, :, 32:32 + 192, (32 + 128):(32 + 192 + 128)].clone()
input_image[:, :, 32:32 + 128, 32:32 + 128] = eyebrow_morphed_image
face_pose = state.batch[1][:, NUM_EYEBROW_PARAMS:NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS]
return state.modules[Network.face_morpher.name].forward(input_image, face_pose)
elif key == Branch.face_morphed_full.name:
face_morpher_output = self.get_output(Network.face_morpher.outputs_key, state)
face_morphed_image = face_morpher_output[0]
input_image = state.batch[0].clone()
input_image[:, :, 32:32 + 192, 32 + 128:32 + 192 + 128] = face_morphed_image
return [input_image]
elif key == Branch.face_morphed_half.name:
face_morphed_full = self.get_output(Branch.face_morphed_full.name, state)[0]
return [
interpolate(face_morphed_full, size=(256, 256), mode='bilinear', align_corners=False)
]
elif key == Network.body_morpher.outputs_key:
face_morphed_half = self.get_output(Branch.face_morphed_half.name, state)[0]
rotation_pose = state.batch[1][:, NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS:]
return state.modules[Network.body_morpher.name].forward(face_morphed_half, rotation_pose)
elif key == Network.upscaler.outputs_key:
rest_image = self.get_output(Branch.face_morphed_full.name, state)[0]
body_morpher_outputs = self.get_output(
Network.body_morpher.outputs_key, state)
half_res_posed_image = body_morpher_outputs[Morpher00.INDEX_MERGED]
half_res_grid_change = body_morpher_outputs[Morpher00.INDEX_GRID_CHANGE]
coarse_posed_image = interpolate(half_res_posed_image, size=(512, 512), mode='bilinear')
coarse_grid_change = interpolate(half_res_grid_change, size=(512, 512), mode='bilinear')
rotation_pose = state.batch[1][:, NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS:]
return state.modules[Network.upscaler.name].forward(
rest_image, coarse_posed_image, coarse_grid_change, rotation_pose)
elif key == Branch.all_outputs.name:
upscaler_output = self.get_output(Network.upscaler.outputs_key, state)
face_morphed_full = self.get_output(Branch.face_morphed_full.name, state)
body_morpher_output = self.get_output(Network.body_morpher.outputs_key, state)
face_morpher_output = self.get_output(Network.face_morpher.outputs_key, state)
eyebrow_morphing_combiner_output = self.get_output(Network.eyebrow_morphing_combiner.outputs_key, state)
eyebrow_decomposer_output = self.get_output(Network.eyebrow_decomposer.outputs_key, state)
output = upscaler_output \
+ face_morphed_full \
+ body_morpher_output \
+ face_morpher_output \
+ eyebrow_morphing_combiner_output \
+ eyebrow_decomposer_output
return output
else:
raise RuntimeError("Unsupported key: " + key)
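# Note on caching: compute_func above recomputes the eyebrow decomposer only
# when batch[0] (the character image) changes, since that stage depends solely
# on the input image. For a fixed character posed frame after frame, the
# decomposer therefore runs only once.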
def load_eyebrow_decomposer(file_name: str):
factory = EyebrowDecomposer00Factory(
EyebrowDecomposer00Args(
image_size=128,
image_channels=4,
start_channels=64,
bottleneck_image_size=16,
num_bottleneck_blocks=6,
max_channels=512,
block_args=BlockArgs(
initialization_method='he',
use_spectral_norm=False,
normalization_layer_factory=InstanceNorm2dFactory(),
nonlinearity_factory=ReLUFactory(inplace=True))))
print("Loading the eyebrow decomposer ... ", end="")
module = factory.create()
module.load_state_dict(torch_load(file_name))
print("DONE!!!")
return module
def load_eyebrow_morphing_combiner(file_name: str):
factory = EyebrowMorphingCombiner00Factory(
EyebrowMorphingCombiner00Args(
image_size=128,
image_channels=4,
start_channels=64,
num_pose_params=12,
bottleneck_image_size=16,
num_bottleneck_blocks=6,
max_channels=512,
block_args=BlockArgs(
initialization_method='he',
use_spectral_norm=False,
normalization_layer_factory=InstanceNorm2dFactory(),
nonlinearity_factory=ReLUFactory(inplace=True))))
print("Loading the eyebrow morphing conbiner ... ", end="")
module = factory.create()
module.load_state_dict(torch_load(file_name))
print("DONE!!!")
return module
def load_face_morpher(file_name: str):
factory = FaceMorpher08Factory(
FaceMorpher08Args(
image_size=192,
image_channels=4,
num_expression_params=27,
start_channels=64,
bottleneck_image_size=24,
num_bottleneck_blocks=6,
max_channels=512,
block_args=BlockArgs(
initialization_method='he',
use_spectral_norm=False,
normalization_layer_factory=InstanceNorm2dFactory(),
nonlinearity_factory=ReLUFactory(inplace=False)
),
output_iris_mouth_grid_change=True,
)
)
print("Loading the face morpher ... ", end="")
module = factory.create()
module.load_state_dict(torch_load(file_name))
print("DONE!!!")
return module
def apply_color_change(alpha: Tensor, color_change: Tensor, image: Tensor) -> Tensor:
return color_change * alpha + image * (1 - alpha)
def load_morpher_00(file_name: str):
unet_args = UnetArgs(
in_channels=4,
out_channels=7,
model_channels=64,
level_channel_multipliers=[1, 2, 4, 4, 4],
level_use_attention=[False, False, False, False, True],
num_res_blocks_per_level=1,
num_middle_res_blocks=4,
time_embedding_channels=None,
cond_input_channels=6,
cond_internal_channels=256,
attention_block_args=AttentionBlockArgs(
num_heads=8,
use_new_attention_order=True),
dropout_prob=0.0)
morpher_00_args = Morpher00Args(
image_size=256,
image_channels=4,
num_pose_parameters=6,
unet_args=unet_args)
morpher_00 = Morpher00(morpher_00_args)
print("Loading the body morpher ... ", end="")
morpher_00.load_state_dict(torch_load(file_name))
print("DONE")
morpher_00.train(False)
return morpher_00
def load_upscaler_02(file_name: str):
unet_args = UnetArgs(
in_channels=4,
out_channels=7,
model_channels=32,
level_channel_multipliers=[1, 2, 4, 8, 8, 8],
level_use_attention=[False, False, False, False, False, True],
num_res_blocks_per_level=1,
num_middle_res_blocks=4,
time_embedding_channels=None,
cond_input_channels=6,
cond_internal_channels=256,
attention_block_args=AttentionBlockArgs(
num_heads=8,
use_new_attention_order=True),
dropout_prob=0.0)
upscaler_02_args = Upscaler02Args(
image_size=512,
image_channels=4,
num_pose_parameters=6,
unet_args=unet_args)
upscaler_02 = Upscaler02(upscaler_02_args)
print("Loading the upscaler ... ", end="")
upscaler_02.load_state_dict(torch_load(file_name))
print("DONE")
upscaler_02.train(False)
return upscaler_02
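# create_poser below assembles the five networks. Branch.all_outputs
# concatenates, in order: the upscaler outputs (5 tensors), the full-resolution
# face-morphed image (1), the body morpher outputs, the face morpher outputs,
# the eyebrow morphing combiner outputs, and the eyebrow decomposer outputs --
# apparently what the output_length expression 5 + 1 + 5 + 8 + 8 + 6 tallies,
# term by term. Index 0, the upscaler's merged image, is the default output.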
def create_poser(
device: torch.device,
module_file_names: Optional[Dict[str, str]] = None,
eyebrow_morphed_image_index: int = EyebrowMorphingCombiner00.EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX,
default_output_index: int = 0) -> GeneralPoser02:
if module_file_names is None:
module_file_names = {}
if Network.eyebrow_decomposer.name not in module_file_names:
file_name = "data/tha4/eyebrow_decomposer.pt"
module_file_names[Network.eyebrow_decomposer.name] = file_name
if Network.eyebrow_morphing_combiner.name not in module_file_names:
file_name = "data/tha4/eyebrow_morphing_combiner.pt"
module_file_names[Network.eyebrow_morphing_combiner.name] = file_name
if Network.face_morpher.name not in module_file_names:
file_name = "data/tha4/face_morpher.pt"
module_file_names[Network.face_morpher.name] = file_name
if Network.body_morpher.name not in module_file_names:
file_name = "data/tha4/body_morpher.pt"
module_file_names[Network.body_morpher.name] = file_name
if Network.upscaler.name not in module_file_names:
file_name = "data/tha4/upscaler.pt"
module_file_names[Network.upscaler.name] = file_name
loaders = {
Network.eyebrow_decomposer.name:
lambda: load_eyebrow_decomposer(module_file_names[Network.eyebrow_decomposer.name]),
Network.eyebrow_morphing_combiner.name:
lambda: load_eyebrow_morphing_combiner(module_file_names[Network.eyebrow_morphing_combiner.name]),
Network.face_morpher.name:
lambda: load_face_morpher(module_file_names[Network.face_morpher.name]),
Network.body_morpher.name:
lambda: load_morpher_00(module_file_names[Network.body_morpher.name]),
Network.upscaler.name:
lambda: load_upscaler_02(module_file_names[Network.upscaler.name]),
}
return GeneralPoser02(
image_size=512,
module_loaders=loaders,
pose_parameters=get_pose_parameters().get_pose_parameter_groups(),
output_list_func=FiveStepPoserComputationProtocol(eyebrow_morphed_image_index).compute_func(),
subrect=None,
device=device,
output_length=5 + 1 + 5 + 8 + 8 + 6,
default_output_index=default_output_index)
================================================
FILE: src/tha4/poser/modes/mode_12.py
================================================
from enum import Enum
from typing import List, Dict, Optional, Any
import torch
from tha4.shion.core.cached_computation import CachedComputationProtocol, ComputationState
from tha4.shion.core.load_save import torch_load
from tha4.nn.eyebrow_decomposer.eyebrow_decomposer_00 import EyebrowDecomposer00, \
EyebrowDecomposer00Factory, EyebrowDecomposer00Args
from tha4.nn.eyebrow_morphing_combiner.eyebrow_morphing_combiner_00 import \
EyebrowMorphingCombiner00Factory, EyebrowMorphingCombiner00Args, EyebrowMorphingCombiner00
from tha4.nn.face_morpher.face_morpher_08 import FaceMorpher08Args, FaceMorpher08Factory
from tha4.nn.nonlinearity_factory import ReLUFactory
from tha4.nn.normalization import InstanceNorm2dFactory
from tha4.nn.util import BlockArgs
from tha4.poser.general_poser_02 import GeneralPoser02
from tha4.poser.modes.pose_parameters import get_pose_parameters
from torch import Tensor
class Network(Enum):
eyebrow_decomposer = 1
eyebrow_morphing_combiner = 2
face_morpher = 3
@property
def outputs_key(self):
return f"{self.name}_outputs"
class Branch(Enum):
face_morphed_half = 1
face_morphed_full = 2
all_outputs = 3
NUM_EYEBROW_PARAMS = 12
NUM_FACE_PARAMS = 27
NUM_ROTATION_PARAMS = 6
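# Despite its name (shared with mode_07), the protocol below runs only the
# three face-side networks; body morphing and upscaling are not part of this
# mode.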
class FiveStepPoserComputationProtocol(CachedComputationProtocol):
def __init__(self, eyebrow_morphed_image_index: int):
super().__init__()
self.eyebrow_morphed_image_index = eyebrow_morphed_image_index
self.cached_batch_0 = None
self.cached_eyebrow_decomposer_output = None
def compute_func(self):
def func(state: ComputationState) -> List[Tensor]:
if self.cached_batch_0 is None:
new_batch_0 = True
elif state.batch[0].shape[0] != self.cached_batch_0.shape[0]:
new_batch_0 = True
else:
new_batch_0 = torch.max((state.batch[0] - self.cached_batch_0).abs()).item() > 0
if not new_batch_0:
state.outputs[Network.eyebrow_decomposer.outputs_key] = self.cached_eyebrow_decomposer_output
output = self.get_output(Branch.all_outputs.name, state)
if new_batch_0:
self.cached_batch_0 = state.batch[0]
self.cached_eyebrow_decomposer_output = state.outputs[Network.eyebrow_decomposer.outputs_key]
return output
return func
def compute_output(self, key: str, state: ComputationState) -> Any:
if key == Network.eyebrow_decomposer.outputs_key:
input_image = state.batch[0][:, :, 64:192, 64 + 128:192 + 128]
return state.modules[Network.eyebrow_decomposer.name].forward(input_image)
elif key == Network.eyebrow_morphing_combiner.outputs_key:
eyebrow_decomposer_output = self.get_output(Network.eyebrow_decomposer.outputs_key, state)
background_layer = eyebrow_decomposer_output[EyebrowDecomposer00.BACKGROUND_LAYER_INDEX]
eyebrow_layer = eyebrow_decomposer_output[EyebrowDecomposer00.EYEBROW_LAYER_INDEX]
eyebrow_pose = state.batch[1][:, :NUM_EYEBROW_PARAMS]
return state.modules[Network.eyebrow_morphing_combiner.name].forward(
background_layer,
eyebrow_layer,
eyebrow_pose)
elif key == Network.face_morpher.outputs_key:
eyebrow_morphing_combiner_output = self.get_output(
Network.eyebrow_morphing_combiner.outputs_key, state)
eyebrow_morphed_image = eyebrow_morphing_combiner_output[self.eyebrow_morphed_image_index]
input_image = state.batch[0][:, :, 32:32 + 192, (32 + 128):(32 + 192 + 128)].clone()
input_image[:, :, 32:32 + 128, 32:32 + 128] = eyebrow_morphed_image
face_pose = state.batch[1][:, NUM_EYEBROW_PARAMS:NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS]
return state.modules[Network.face_morpher.name].forward(input_image, face_pose)
elif key == Branch.all_outputs.name:
face_morpher_output = self.get_output(Network.face_morpher.outputs_key, state)
eyebrow_morphing_combiner_output = self.get_output(Network.eyebrow_morphing_combiner.outputs_key, state)
eyebrow_decomposer_output = self.get_output(Network.eyebrow_decomposer.outputs_key, state)
output = face_morpher_output \
+ eyebrow_morphing_combiner_output \
+ eyebrow_decomposer_output
return output
else:
raise RuntimeError("Unsupported key: " + key)
def load_eyebrow_decomposer(file_name: str):
factory = EyebrowDecomposer00Factory(
EyebrowDecomposer00Args(
image_size=128,
image_channels=4,
start_channels=64,
bottleneck_image_size=16,
num_bottleneck_blocks=6,
max_channels=512,
block_args=BlockArgs(
initialization_method='he',
use_spectral_norm=False,
normalization_layer_factory=InstanceNorm2dFactory(),
nonlinearity_factory=ReLUFactory(inplace=True))))
print("Loading the eyebrow decomposer ... ", end="")
module = factory.create()
module.load_state_dict(torch_load(file_name))
print("DONE!!!")
return module
def load_eyebrow_morphing_combiner(file_name: str):
factory = EyebrowMorphingCombiner00Factory(
EyebrowMorphingCombiner00Args(
image_size=128,
image_channels=4,
start_channels=64,
num_pose_params=12,
bottleneck_image_size=16,
num_bottleneck_blocks=6,
max_channels=512,
block_args=BlockArgs(
initialization_method='he',
use_spectral_norm=False,
normalization_layer_factory=InstanceNorm2dFactory(),
nonlinearity_factory=ReLUFactory(inplace=True))))
print("Loading the eyebrow morphing conbiner ... ", end="")
module = factory.create()
module.load_state_dict(torch_load(file_name))
print("DONE!!!")
return module
def load_face_morpher(file_name: str):
factory = FaceMorpher08Factory(
FaceMorpher08Args(
image_size=192,
image_channels=4,
num_expression_params=27,
start_channels=64,
bottleneck_image_size=24,
num_bottleneck_blocks=6,
max_channels=512,
block_args=BlockArgs(
initialization_method='he',
use_spectral_norm=False,
normalization_layer_factory=InstanceNorm2dFactory(),
nonlinearity_factory=ReLUFactory(inplace=False)),
output_iris_mouth_grid_change=True))
print("Loading the face morpher ... ", end="")
module = factory.create()
module.load_state_dict(torch_load(file_name))
print("DONE!!!")
return module
def apply_color_change(alpha: Tensor, color_change: Tensor, image: Tensor) -> Tensor:
return color_change * alpha + image * (1 - alpha)
def create_poser(
device: torch.device,
module_file_names: Optional[Dict[str, str]] = None,
eyebrow_morphed_image_index: int = EyebrowMorphingCombiner00.EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX,
default_output_index: int = 0) -> GeneralPoser02:
if module_file_names is None:
module_file_names = {}
if Network.eyebrow_decomposer.name not in module_file_names:
file_name = "data/tha4/eyebrow_decomposer.pt"
module_file_names[Network.eyebrow_decomposer.name] = file_name
if Network.eyebrow_morphing_combiner.name not in module_file_names:
file_name = "data/tha4/eyebrow_morphing_combiner.pt"
module_file_names[Network.eyebrow_morphing_combiner.name] = file_name
if Network.face_morpher.name not in module_file_names:
file_name = "data/tha4/face_morpher.pt"
module_file_names[Network.face_morpher.name] = file_name
loaders = {
Network.eyebrow_decomposer.name:
lambda: load_eyebrow_decomposer(module_file_names[Network.eyebrow_decomposer.name]),
Network.eyebrow_morphing_combiner.name:
lambda: load_eyebrow_morphing_combiner(module_file_names[Network.eyebrow_morphing_combiner.name]),
Network.face_morpher.name:
lambda: load_face_morpher(module_file_names[Network.face_morpher.name]),
}
return GeneralPoser02(
image_size=512,
module_loaders=loaders,
pose_parameters=get_pose_parameters().get_pose_parameter_groups(),
output_list_func=FiveStepPoserComputationProtocol(eyebrow_morphed_image_index).compute_func(),
subrect=None,
device=device,
output_length=5 + 5 + 8,
default_output_index=default_output_index)
================================================
FILE: src/tha4/poser/modes/mode_14.py
================================================
from dataclasses import dataclass
from typing import List, Optional, Dict, Any
import torch
from tha4.shion.core.cached_computation import CachedComputationProtocol, ComputationState
from tha4.shion.core.load_save import torch_load
from tha4.nn.siren.face_morpher.siren_face_morpher_00 import SirenFaceMorpher00Args, SirenFaceMorpher00
from tha4.nn.siren.morpher.siren_morpher_03 import SirenMorpher03, SirenMorpher03Args, SirenMorpherLevelArgs
from tha4.nn.siren.vanilla.siren import SirenArgs
from tha4.poser.general_poser_02 import GeneralPoser02
from tha4.poser.modes.pose_parameters import get_pose_parameters
from torch import Tensor
KEY_FACE_MORPHER = "face_morpher"
KEY_BODY_MORPHER = "body_morpher"
@dataclass
class Keys:
face_morpher: str = KEY_FACE_MORPHER
face_morpher_output: str = "face_morpher_output"
face_morpher_input_image: str = "face_morpher_input_image"
face_morpher_input_pose: str = "face_morpher_input_pose"
body_morpher_input_image: str = "body_morpher_input_image"
body_morpher: str = KEY_BODY_MORPHER
body_morpher_output: str = "body_morpher_output"
all_outputs: str = "all_outputs"
@dataclass
class Indices:
original_image: int = 0
original_pose: int = 1
class TwoStepPoserComputationProtocol(CachedComputationProtocol):
def __init__(self, keys: Optional[Keys] = None, indices: Optional[Indices] = None):
super().__init__()
if keys is None:
keys = Keys()
if indices is None:
indices = Indices()
self.keys = keys
self.indices = indices
def compute_func(self):
def func(state: ComputationState) -> List[Tensor]:
return self.get_output(self.keys.all_outputs, state)
return func
def compute_output(self, key: str, state: ComputationState) -> Any:
if key == self.keys.face_morpher_input_image:
image = state.batch[self.indices.original_image]
center_x = 256
center_y = 128 + 16
return image[:, :, center_y - 64:center_y + 64, center_x - 64:center_x + 64]
elif key == self.keys.face_morpher_input_pose:
pose = state.batch[self.indices.original_pose]
return pose[:, 0:39]
elif key == self.keys.face_morpher_output:
module = state.modules[self.keys.face_morpher]
pose = self.get_output(self.keys.face_morpher_input_pose, state)
with torch.no_grad():
return module.forward(pose)
elif key == self.keys.body_morpher_input_image:
image = state.batch[self.indices.original_image].clone()
center_x = 256
center_y = 128 + 16
face_morphed_image = self.get_output(self.keys.face_morpher_output, state)
image[:, :, center_y - 64:center_y + 64, center_x - 64:center_x + 64] = face_morphed_image
return image
elif key == self.keys.body_morpher_output:
image = self.get_output(self.keys.body_morpher_input_image, state)
pose = state.batch[self.indices.original_pose]
body_morpher = state.modules[self.keys.body_morpher]
with torch.no_grad():
return body_morpher.forward(image, pose)
elif key == self.keys.all_outputs:
body_morpher_output = self.get_output(self.keys.body_morpher_output, state)
face_morpher_output = self.get_output(self.keys.face_morpher_output, state)
return body_morpher_output + [face_morpher_output]
else:
raise RuntimeError("Unsupported key: " + key)
def load_face_morpher(file_name: Optional[str] = None):
module = SirenFaceMorpher00(
SirenFaceMorpher00Args(
image_size=128,
image_channels=4,
pose_size=39,
siren_args=SirenArgs(
in_channels=39 + 2,
out_channels=4,
intermediate_channels=128,
num_sine_layers=8)))
if file_name is not None:
module.load_state_dict(torch_load(file_name))
return module
def load_body_morpher(file_name: Optional[str] = None):
module = SirenMorpher03(
SirenMorpher03Args(
image_size=512,
image_channels=4,
pose_size=45,
level_args=[
SirenMorpherLevelArgs(
image_size=128,
intermediate_channels=360,
num_sine_layers=3),
SirenMorpherLevelArgs(
image_size=256,
intermediate_channels=180,
num_sine_layers=3),
SirenMorpherLevelArgs(
image_size=512,
intermediate_channels=90,
num_sine_layers=3),
]))
if file_name is not None:
module.load_state_dict(torch_load(file_name))
return module
def create_poser(
device: torch.device,
module_file_names: Optional[Dict[str, str]] = None,
default_output_index: int = 0) -> GeneralPoser02:
if module_file_names is None:
module_file_names = {}
if KEY_FACE_MORPHER not in module_file_names:
file_name = "data/character_models/lambda_00/face_morpher.pt"
module_file_names[KEY_FACE_MORPHER] = file_name
if KEY_BODY_MORPHER not in module_file_names:
file_name = "data/character_models/lambda_00/body_morpher.pt"
module_file_names[KEY_BODY_MORPHER] = file_name
loaders = {
KEY_FACE_MORPHER:
lambda: load_face_morpher(module_file_names[KEY_FACE_MORPHER]),
KEY_BODY_MORPHER:
lambda: load_body_morpher(module_file_names[KEY_BODY_MORPHER]),
}
return GeneralPoser02(
image_size=512,
module_loaders=loaders,
pose_parameters=get_pose_parameters().get_pose_parameter_groups(),
output_list_func=TwoStepPoserComputationProtocol().compute_func(),
subrect=None,
device=device,
output_length=5 + 1,
default_output_index=default_output_index)
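# mode_14 is the distilled two-network character model: a SIREN face morpher
# that generates the 128x128 face region directly from the pose (the character
# is baked into its weights), and a SIREN body morpher that poses the full
# 512x512 image. It mirrors the Poser interface of the teacher modes.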
================================================
FILE: src/tha4/poser/modes/pose_parameters.py
================================================
from tha4.poser.poser import PoseParameters, PoseParameterCategory
def get_pose_parameters():
return PoseParameters.Builder() \
.add_parameter_group("eyebrow_troubled", PoseParameterCategory.EYEBROW, arity=2) \
.add_parameter_group("eyebrow_angry", PoseParameterCategory.EYEBROW, arity=2) \
.add_parameter_group("eyebrow_lowered", PoseParameterCategory.EYEBROW, arity=2) \
.add_parameter_group("eyebrow_raised", PoseParameterCategory.EYEBROW, arity=2) \
.add_parameter_group("eyebrow_happy", PoseParameterCategory.EYEBROW, arity=2) \
.add_parameter_group("eyebrow_serious", PoseParameterCategory.EYEBROW, arity=2) \
.add_parameter_group("eye_wink", PoseParameterCategory.EYE, arity=2) \
.add_parameter_group("eye_happy_wink", PoseParameterCategory.EYE, arity=2) \
.add_parameter_group("eye_surprised", PoseParameterCategory.EYE, arity=2) \
.add_parameter_group("eye_relaxed", PoseParameterCategory.EYE, arity=2) \
.add_parameter_group("eye_unimpressed", PoseParameterCategory.EYE, arity=2) \
.add_parameter_group("eye_raised_lower_eyelid", PoseParameterCategory.EYE, arity=2) \
.add_parameter_group("iris_small", PoseParameterCategory.IRIS_MORPH, arity=2) \
.add_parameter_group("mouth_aaa", PoseParameterCategory.MOUTH, arity=1, default_value=1.0) \
.add_parameter_group("mouth_iii", PoseParameterCategory.MOUTH, arity=1) \
.add_parameter_group("mouth_uuu", PoseParameterCategory.MOUTH, arity=1) \
.add_parameter_group("mouth_eee", PoseParameterCategory.MOUTH, arity=1) \
.add_parameter_group("mouth_ooo", PoseParameterCategory.MOUTH, arity=1) \
.add_parameter_group("mouth_delta", PoseParameterCategory.MOUTH, arity=1) \
.add_parameter_group("mouth_lowered_corner", PoseParameterCategory.MOUTH, arity=2) \
.add_parameter_group("mouth_raised_corner", PoseParameterCategory.MOUTH, arity=2) \
.add_parameter_group("mouth_smirk", PoseParameterCategory.MOUTH, arity=1) \
.add_parameter_group("iris_rotation_x", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("iris_rotation_y", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("head_x", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("head_y", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("neck_z", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("body_y", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("body_z", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("breathing", PoseParameterCategory.BREATHING, arity=1, range=(0.0, 1.0)) \
.build()
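# The groups above total 45 parameters: 12 eyebrow parameters (6 groups of
# arity 2), 27 face parameters (eye, iris, mouth, and iris rotation), and
# 6 rotation/breathing parameters -- matching NUM_EYEBROW_PARAMS,
# NUM_FACE_PARAMS, and NUM_ROTATION_PARAMS in the poser modes and the
# pose_size of 45 used by the SIREN morphers.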
================================================
FILE: src/tha4/poser/poser.py
================================================
from abc import ABC, abstractmethod
from enum import Enum
from typing import Tuple, List, Optional
import torch
from torch import Tensor
class PoseParameterCategory(Enum):
EYEBROW = 1
EYE = 2
IRIS_MORPH = 3
IRIS_ROTATION = 4
MOUTH = 5
FACE_ROTATION = 6
BODY_ROTATION = 7
BREATHING = 8
class PoseParameterGroup:
def __init__(self,
group_name: str,
parameter_index: int,
category: PoseParameterCategory,
arity: int = 1,
discrete: bool = False,
default_value: float = 0.0,
range: Optional[Tuple[float, float]] = None):
assert arity == 1 or arity == 2
if range is None:
range = (0.0, 1.0)
if arity == 1:
parameter_names = [group_name]
else:
parameter_names = [group_name + "_left", group_name + "_right"]
assert len(parameter_names) == arity
self.parameter_names = parameter_names
self.range = range
self.default_value = default_value
self.discrete = discrete
self.arity = arity
self.category = category
self.parameter_index = parameter_index
self.group_name = group_name
def get_arity(self) -> int:
return self.arity
def get_group_name(self) -> str:
return self.group_name
def get_parameter_names(self) -> List[str]:
return self.parameter_names
def is_discrete(self) -> bool:
return self.discrete
def get_range(self) -> Tuple[float, float]:
return self.range
def get_default_value(self):
return self.default_value
def get_parameter_index(self):
return self.parameter_index
def get_category(self) -> PoseParameterCategory:
return self.category
class PoseParameters:
def __init__(self, pose_parameter_groups: List[PoseParameterGroup]):
self.pose_parameter_groups = pose_parameter_groups
def get_parameter_index(self, name: str) -> int:
index = 0
for parameter_group in self.pose_parameter_groups:
for param_name in parameter_group.parameter_names:
if name == param_name:
return index
index += 1
raise RuntimeError("Cannot find parameter with name %s" % name)
def get_parameter_name(self, index: int) -> str:
assert index >= 0 and index < self.get_parameter_count()
for group in self.pose_parameter_groups:
if index < group.get_arity():
return group.get_parameter_names()[index]
index -= group.arity
raise RuntimeError("Something is wrong here!!!")
def get_pose_parameter_groups(self):
return self.pose_parameter_groups
def get_parameter_count(self):
count = 0
for group in self.pose_parameter_groups:
count += group.arity
return count
class Builder:
def __init__(self):
self.index = 0
self.pose_parameter_groups = []
def add_parameter_group(self,
group_name: str,
category: PoseParameterCategory,
arity: int = 1,
discrete: bool = False,
default_value: float = 0.0,
range: Optional[Tuple[float, float]] = None):
self.pose_parameter_groups.append(
PoseParameterGroup(
group_name,
self.index,
category,
arity,
discrete,
default_value,
range))
self.index += arity
return self
def build(self) -> 'PoseParameters':
return PoseParameters(self.pose_parameter_groups)
class Poser(ABC):
@abstractmethod
def get_image_size(self) -> int:
pass
@abstractmethod
def get_output_length(self) -> int:
pass
@abstractmethod
def get_pose_parameter_groups(self) -> List[PoseParameterGroup]:
pass
@abstractmethod
def get_num_parameters(self) -> int:
pass
@abstractmethod
def pose(self, image: Tensor, pose: Tensor, output_index: int = 0) -> Tensor:
pass
@abstractmethod
def get_posing_outputs(self, image: Tensor, pose: Tensor) -> List[Tensor]:
pass
def get_dtype(self) -> torch.dtype:
return torch.float
@abstractmethod
def to(self, device: torch.device):
pass
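# Minimal sketch of the Builder contract (illustrative): each group receives
# the running parameter index, which then advances by the group's arity.
if __name__ == "__main__":
    params = PoseParameters.Builder() \
        .add_parameter_group("eye_wink", PoseParameterCategory.EYE, arity=2) \
        .add_parameter_group("breathing", PoseParameterCategory.BREATHING, arity=1, range=(0.0, 1.0)) \
        .build()
    assert params.get_parameter_count() == 3
    assert params.get_parameter_name(0) == "eye_wink_left"
    assert params.get_parameter_name(2) == "breathing"
    assert params.get_parameter_index("breathing") == 2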
================================================
FILE: src/tha4/pytasuku/__init__.py
================================================
================================================
FILE: src/tha4/pytasuku/indexed/__init__.py
================================================
================================================
FILE: src/tha4/pytasuku/indexed/all_tasks.py
================================================
from typing import Iterable
from tha4.pytasuku.workspace import Workspace
from tha4.pytasuku.indexed.indexed_tasks import IndexedTasks
from tha4.pytasuku.indexed.no_index_command_tasks import NoIndexCommandTasks
class AllTasks(NoIndexCommandTasks):
def __init__(
self,
workspace: Workspace, prefix: str,
tasks: Iterable[IndexedTasks],
command_name: str = "all",
define_tasks_immediately: bool = True):
        super().__init__(workspace, prefix, command_name, define_tasks_immediately=False)
        self.tasks = list(tasks)
if define_tasks_immediately:
self.define_tasks()
def execute_run_command(self):
for task in self.tasks:
self.workspace.run(task.run_command)
def execute_clean_command(self):
for task in self.tasks:
self.workspace.run(task.clean_command)
================================================
FILE: src/tha4/pytasuku/indexed/bundled_indexed_file_tasks.py
================================================
import abc
from typing import Iterable, List
from tha4.pytasuku.workspace import Workspace
from tha4.pytasuku.indexed.indexed_tasks import IndexedTasks
from tha4.pytasuku.workspace import do_nothing
class BundledIndexedTasks:
__metaclass__ = abc.ABCMeta
@property
@abc.abstractmethod
def indexed_tasks_command_names(self) -> Iterable[str]:
pass
@abc.abstractmethod
def get_indexed_tasks(self, command_name) -> IndexedTasks:
pass
def define_all_tasks_from_list(workspace: Workspace, prefix: str, tasks: List[BundledIndexedTasks]):
for command_name in tasks[0].indexed_tasks_command_names:
workspace.create_command_task(
prefix + "/" + command_name,
[x.get_indexed_tasks(command_name).run_command for x in tasks],
do_nothing)
workspace.create_command_task(
prefix + "/" + command_name + "_clean",
[x.get_indexed_tasks(command_name).clean_command for x in tasks],
do_nothing)
================================================
FILE: src/tha4/pytasuku/indexed/indexed_file_tasks.py
================================================
import abc
from typing import List
from tha4.pytasuku.workspace import Workspace
from tha4.pytasuku.indexed.indexed_tasks import IndexedTasks
class IndexedFileTasks(IndexedTasks, abc.ABC):
def __init__(self, workspace: Workspace, prefix: str):
super().__init__(workspace, prefix)
@property
@abc.abstractmethod
def file_list(self) -> List[str]:
pass
@abc.abstractmethod
def get_file_name(self, *indices: int) -> str:
pass
================================================
FILE: src/tha4/pytasuku/indexed/indexed_tasks.py
================================================
import abc
from typing import List
from tha4.pytasuku.workspace import Workspace
class IndexedTasks(abc.ABC):
def __init__(self, workspace: Workspace, prefix: str):
self.prefix = prefix
self.workspace = workspace
@property
@abc.abstractmethod
def run_command(self) -> str:
pass
@property
@abc.abstractmethod
def clean_command(self) -> str:
pass
@property
@abc.abstractmethod
def shape(self) -> List[int]:
pass
@property
@abc.abstractmethod
def arity(self) -> int:
pass
@abc.abstractmethod
def define_tasks(self):
pass
================================================
FILE: src/tha4/pytasuku/indexed/no_index_command_tasks.py
================================================
import abc
from typing import List
from tha4.pytasuku.workspace import Workspace
from tha4.pytasuku.indexed.indexed_tasks import IndexedTasks
class NoIndexCommandTasks(IndexedTasks, abc.ABC):
def __init__(self, workspace: Workspace, prefix: str, command_name: str, define_tasks_immediately: bool = True):
super().__init__(workspace, prefix)
self.command_name = command_name
if define_tasks_immediately:
self.define_tasks()
@property
def run_command(self):
return self.prefix + "/" + self.command_name
@property
def clean_command(self):
return self.prefix + "/" + self.command_name + "_clean"
@property
def arity(self) -> int:
return 0
@property
def shape(self) -> List[int]:
return []
@abc.abstractmethod
def execute_run_command(self):
pass
@abc.abstractmethod
def execute_clean_command(self):
pass
def define_tasks(self):
self.workspace.create_command_task(self.run_command, [], self.execute_run_command)
self.workspace.create_command_task(self.clean_command, [], self.execute_clean_command)
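# Sketch of a concrete subclass (illustrative): execute_run_command and
# execute_clean_command are the only hooks a command-style task must provide.
if __name__ == "__main__":
    class PrintTasks(NoIndexCommandTasks):
        def execute_run_command(self):
            print("running", self.run_command)

        def execute_clean_command(self):
            print("cleaning", self.clean_command)

    workspace = Workspace()
    tasks = PrintTasks(workspace, "demo", "print")
    with workspace.session():
        workspace.run(tasks.run_command)  # prints "running demo/print"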
================================================
FILE: src/tha4/pytasuku/indexed/no_index_file_tasks.py
================================================
import abc
from typing import List
from tha4.pytasuku.workspace import Workspace
from tha4.pytasuku.indexed.indexed_file_tasks import IndexedFileTasks
from tha4.pytasuku.indexed.util import delete_file
class NoIndexFileTasks(IndexedFileTasks, abc.ABC):
def __init__(self, workspace: Workspace, prefix: str, command_name: str, define_tasks_immediately: bool = True):
super().__init__(workspace, prefix)
self.command_name = command_name
if define_tasks_immediately:
self.define_tasks()
@property
@abc.abstractmethod
def file_name(self):
pass
@abc.abstractmethod
def create_file_task(self):
pass
def get_file_name(self, *indices: int) -> str:
if len(indices) > 0:
raise IndexError("NoIndexFileTasks has arity 0, but get_file_name is called with an index.")
return self.file_name
@property
def run_command(self):
return self.prefix + "/" + self.command_name
@property
def clean_command(self):
return self.prefix + "/" + self.command_name + "_clean"
@property
def arity(self) -> int:
return 0
@property
def shape(self) -> List[int]:
return []
@property
def file_list(self) -> List[str]:
return [self.file_name]
def clean(self):
delete_file(self.file_name)
def define_tasks(self):
self.create_file_task()
self.workspace.create_command_task(self.run_command, [self.file_name])
self.workspace.create_command_task(self.clean_command, [], lambda: self.clean())
================================================
FILE: src/tha4/pytasuku/indexed/one_index_file_tasks.py
================================================
import abc
from typing import List
from tha4.pytasuku.workspace import Workspace
from tha4.pytasuku.indexed.indexed_file_tasks import IndexedFileTasks
from tha4.pytasuku.indexed.util import delete_file
class OneIndexFileTasks(IndexedFileTasks, abc.ABC):
def __init__(self, workspace: Workspace, prefix: str, command_name: str, count: int,
define_tasks_immediately: bool = True):
super().__init__(workspace, prefix)
self.command_name = command_name
self.count = count
self.file_list_ = []
if define_tasks_immediately:
self.define_tasks()
@property
def run_command(self) -> str:
return self.prefix + "/" + self.command_name
@property
def clean_command(self) -> str:
return self.prefix + "/" + self.command_name + "_clean"
@property
def shape(self) -> List[int]:
return [self.count]
@property
def arity(self) -> int:
return 1
@abc.abstractmethod
def file_name(self, index):
pass
@abc.abstractmethod
def create_file_tasks(self, index):
pass
def get_file_name(self, *indices: int) -> str:
if len(indices) != 1:
raise IndexError("OneIndexFileTasks has arity 1, but "
"get_file_name does not get the appropriate number of arguments.")
return self.file_name(indices[0])
@property
def file_list(self):
if len(self.file_list_) == 0:
for i in range(self.count):
self.file_list_.append(self.file_name(i))
return self.file_list_
def clean(self):
for file in self.file_list:
delete_file(file)
def define_tasks(self):
for index in range(self.count):
self.create_file_tasks(index)
dependencies = self.file_list
self.workspace.create_command_task(self.run_command, dependencies)
self.workspace.create_command_task(self.clean_command, [], lambda: self.clean())
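# Sketch of a concrete subclass (illustrative): one file task per index, with
# the aggregate run/clean commands supplied by this base class.
if __name__ == "__main__":
    import os

    class NumberedTextFileTasks(OneIndexFileTasks):
        def file_name(self, index):
            return "%s/file_%03d.txt" % (self.prefix, index)

        def create_file_tasks(self, index):
            def write():
                os.makedirs(self.prefix, exist_ok=True)
                with open(self.file_name(index), "w") as fout:
                    fout.write(str(index))
            self.workspace.create_file_task(self.file_name(index), [], write)

    workspace = Workspace()
    tasks = NumberedTextFileTasks(workspace, "data/demo", "numbered", count=3)
    with workspace.session():
        workspace.run(tasks.run_command)  # creates data/demo/file_000.txt ... file_002.txt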
================================================
FILE: src/tha4/pytasuku/indexed/simple_no_index_file_tasks.py
================================================
from typing import Callable, List, Optional
from tha4.pytasuku.workspace import Workspace
from tha4.pytasuku.indexed.no_index_file_tasks import NoIndexFileTasks
class SimpleNoIndexFileTasks(NoIndexFileTasks):
def __init__(self,
workspace: Workspace,
prefix: str,
command_name: str,
file_name: str,
run_func: Callable[[], None],
dependencies: Optional[List[str]] = None):
super().__init__(workspace, prefix, command_name, define_tasks_immediately=False)
if dependencies is None:
dependencies = []
self.run_func = run_func
self._file_name = file_name
self.dependencies = dependencies
self.define_tasks()
@property
def file_name(self):
return self._file_name
def create_file_task(self):
self.workspace.create_file_task(self.file_name, self.dependencies, self.run_func)
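# Usage sketch (illustrative): a single-output task needs only a file name and
# a function; the run/clean commands come from NoIndexFileTasks. Construction
# alone registers the file task and its commands with the workspace.
if __name__ == "__main__":
    import os

    def write_hello():
        os.makedirs("data/demo", exist_ok=True)
        with open("data/demo/hello.txt", "w") as fout:
            fout.write("hello")

    workspace = Workspace()
    SimpleNoIndexFileTasks(workspace, "data/demo", "hello", "data/demo/hello.txt", write_hello)
    with workspace.session():
        workspace.run("data/demo/hello")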
================================================
FILE: src/tha4/pytasuku/indexed/two_indices_file_tasks.py
================================================
import abc
from typing import List
from tha4.pytasuku.workspace import Workspace
from tha4.pytasuku.indexed.indexed_file_tasks import IndexedFileTasks
from tha4.pytasuku.indexed.util import delete_file
class TwoIndicesFileTasks(IndexedFileTasks, abc.ABC):
def __init__(self, workspace: Workspace, prefix: str, command_name: str,
count0: int, count1: int, define_tasks_immediately: bool = True):
super().__init__(workspace, prefix)
self.count1 = count1
self.count0 = count0
self.command_name = command_name
self.file_list_ = []
if define_tasks_immediately:
self.define_tasks()
@property
def run_command(self) -> str:
return self.prefix + "/" + self.command_name
@property
def clean_command(self) -> str:
return self.prefix + "/" + self.command_name + "_clean"
@property
def shape(self) -> List[int]:
return [self.count0, self.count1]
@property
def arity(self) -> int:
return 2
@abc.abstractmethod
def file_name(self, index0: int, index1: int) -> str:
pass
@property
def file_list(self) -> List[str]:
if len(self.file_list_) == 0:
for i in range(self.count0):
for j in range(self.count1):
self.file_list_.append(self.file_name(i, j))
return self.file_list_
@abc.abstractmethod
def create_file_tasks(self, index0: int, index1: int):
pass
def get_file_name(self, *indices: int) -> str:
if len(indices) != 2:
raise IndexError("TwoIndicesFileTasks.get_file_name require two indices, " +
"but not exactly 2 indices were provide")
return self.file_name(indices[0], indices[1])
def clean(self):
for file in self.file_list:
delete_file(file)
def define_tasks(self):
for index0 in range(self.count0):
for index1 in range(self.count1):
self.create_file_tasks(index0, index1)
self.workspace.create_command_task(self.run_command, self.file_list)
self.workspace.create_command_task(self.clean_command, [], lambda: self.clean())
================================================
FILE: src/tha4/pytasuku/indexed/util.py
================================================
import os
from typing import Iterable, Dict, Callable, List
from tha4.pytasuku.workspace import Workspace
from tha4.pytasuku.indexed.all_tasks import AllTasks
from tha4.pytasuku.indexed.indexed_tasks import IndexedTasks
def delete_file(file_name):
if os.path.exists(file_name):
os.remove(file_name)
print("[delete] " + file_name)
else:
print("[not exist] " + file_name)
def all_tasks_from_named_tasks_map(
workspace: Workspace,
prefix: str,
tasks: Iterable[Dict[str, IndexedTasks]],
define_all_tasks: bool = True) \
-> Dict[str, IndexedTasks]:
subtasks = [x for x in tasks]
name_to_subtask_list = {}
for a_subtasks in subtasks:
for name in a_subtasks:
if not define_all_tasks and name == "all":
continue
if name not in name_to_subtask_list:
name_to_subtask_list[name] = []
name_to_subtask_list[name].append(a_subtasks[name])
output = {}
for name in name_to_subtask_list:
output[name] = AllTasks(workspace, prefix, name_to_subtask_list[name], name)
return output
def create_tasks_hierarchy_helper(
workspace: Workspace,
prefix: str,
tasks_func: Callable[[Workspace, str, List[str]], Dict[str, IndexedTasks]],
branches: List[List[str]],
path: List[str]):
if len(branches) == 0:
return tasks_func(workspace, prefix, path)
else:
tasks = {}
for branch in branches[0]:
output_tasks = create_tasks_hierarchy_helper(
workspace,
f"{prefix}/{branch}",
tasks_func,
branches[1:],
path + [branch])
if output_tasks is not None:
tasks[branch] = output_tasks
return all_tasks_from_named_tasks_map(workspace, prefix, tasks.values())
def create_task_hierarchy(
workspace: Workspace,
prefix: str,
tasks_func: Callable[[Workspace, str, List[str]], Dict[str, IndexedTasks]],
branches: List[List[str]]) -> Dict[str, IndexedTasks]:
return create_tasks_hierarchy_helper(workspace, prefix, tasks_func, branches, [])
def write_done_file(file_name: str):
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, "wt") as fout:
fout.write("DONE!!!")
================================================
FILE: src/tha4/pytasuku/task.py
================================================
import os
import logging
from typing import List
class Task:
def __init__(self, workspace: 'Workspace', name: str, dependencies: List[str]):
self._workspace = workspace
self._name = name
self._dependencies = dependencies
self._workspace.add_task(self)
def run(self):
pass
@property
def can_run(self) -> bool:
return True
@property
def needs_to_be_run(self) -> bool:
return False
@property
def name(self) -> str:
return self._name
@property
def dependencies(self) -> List[str]:
return self._dependencies
@property
def workspace(self) -> 'Workspace':
return self._workspace
@property
def timestamp(self) -> float:
return float("inf")
class CommandTask(Task):
def __init__(self, workspace, name, dependencies):
super().__init__(workspace, name, dependencies)
@property
def needs_to_be_run(self):
return True
class PlaceholderTask(Task):
def __init__(self, workspace, name):
super().__init__(workspace, name, [])
@property
def can_run(self):
return False
def run(self):
raise Exception("A placeholder task cannot be run! (%s)" % self.name)
@property
def needs_to_be_run(self):
return not os.path.isfile(self.name)
@property
def timestamp(self) -> float:
if not os.path.isfile(self.name):
return float("inf")
else:
return os.path.getmtime(self.name)
class FileTask(Task):
def __init__(self, workspace, name, dependencies):
super().__init__(workspace, name, dependencies)
@property
def timestamp(self):
return os.path.getmtime(self.name)
@property
def needs_to_be_run(self):
if not os.path.isfile(self.name):
logging.info("Task %s will be run because the corresponding file does not exist." % self.name)
return True
for dep in self.dependencies:
if self.workspace.needs_to_run(dep):
logging.info("Task %s will be run because dependency %s also needs to be run." % (self.name, dep))
return True
else:
self_timestamp = self.timestamp
dep_task = self.workspace.get_task(dep)
if dep_task.timestamp > self_timestamp:
if isinstance(dep_task, FileTask) or isinstance(dep_task, PlaceholderTask):
logging.info("Task %s needs to be run because task %s has later timestamp." %
(self.name, dep))
elif isinstance(dep_task, CommandTask):
logging.info("Task %s needs to be run because task %s is a command." % (self.name, dep))
return True
return False
================================================
FILE: src/tha4/pytasuku/task_selector_ui.py
================================================
from tkinter import Tk, BOTH, Button, RIGHT, Scrollbar
from tkinter.ttk import Frame, Treeview
from tha4.pytasuku.workspace import Workspace, PlaceholderTask
class TaskSelectorUi(Frame):
def __init__(self, root, workspace: Workspace):
super().__init__()
self.root = root
self.workspace = workspace
self.master.title("Tasks")
self.master.geometry("256x512")
treeview_frame = Frame(self)
treeview_frame.pack(fill=BOTH, expand=True)
self.treeview = Treeview(treeview_frame)
self.treeview["columns"] = ("task_name")
self.treeview.column("#0", width=256, minwidth=256)
self.treeview.heading("#0", text="Tree")
self.treeview.heading("task_name", text="Task Name")
treeview_vertical_scroll = Scrollbar(treeview_frame,
orient='vertical',
command=self.treeview.yview)
self.treeview.configure(yscrollcommand=treeview_vertical_scroll.set)
treeview_vertical_scroll.pack(side=RIGHT, fill='y')
self.treeview.pack(fill=BOTH, expand=True)
treeview_horizontal_scroll = Scrollbar(treeview_frame,
orient='horizontal',
command=self.treeview.xview)
self.treeview.configure(xscrollcommand=treeview_horizontal_scroll.set)
treeview_horizontal_scroll.pack(fill='x')
self.add_tree_nodes()
self.execute_button = Button(self, text="Execute!", command=self.run_selected_task)
self.execute_button.pack(side=RIGHT, padx=5, pady=5)
self.pack(fill=BOTH, expand=True)
self.selected_task_name = None
def add_tree_nodes(self):
nodes = {}
for task in self.workspace._tasks.values():
if isinstance(task, PlaceholderTask):
continue
comps = task.name.split('/')
            assert len(comps) > 0
prefix = ""
index = 0
for comp in comps:
index = index + 1
parent = prefix
if prefix == "" and comp == "":
prefix = "/"
elif prefix == "":
prefix = prefix + comp
elif prefix == "/":
prefix = prefix + comp
else:
prefix = prefix + "/" + comp
if prefix in nodes:
continue
if index == len(comps):
data = prefix
else:
data = ""
if prefix == "/":
comp = "/"
nodes[prefix] = {
"name": str(prefix),
"display_name": comp,
"parent": parent,
"data": data
}
sorted_node_names = sorted(nodes.keys())
node_index = {}
for name in sorted_node_names:
node = nodes[name]
if node["parent"] == "":
id = self.treeview.insert("", "end", text=node["display_name"], values=node["data"], )
else:
parent = node_index[node["parent"]]
                id = self.treeview.insert(parent, "end", text=node["display_name"], values=node["data"])
node_index[node["name"]] = id
def run_selected_task(self):
selection = self.treeview.selection()
item = self.treeview.item(selection)
if item['values'] == "":
return
task_name = item["values"][0]
self.selected_task_name = task_name
self.root.destroy()
def run_task_selector_ui(workspace: Workspace):
root = Tk()
task_selector_ui = TaskSelectorUi(root, workspace=workspace)
root.mainloop()
task_name = task_selector_ui.selected_task_name
if task_name is not None:
print("Running", task_name, "...")
with workspace.session():
workspace.run(task_name)
================================================
FILE: src/tha4/pytasuku/util.py
================================================
import os.path
from typing import List
import logging
from tha4.pytasuku.workspace import Workspace
def create_delete_all_task(workspace: Workspace, name: str, files: List[str]):
def delete_all():
for file in files:
if os.path.exists(file):
logging.info("Removing %s ..." % file)
os.remove(file)
workspace.create_command_task(name, [], delete_all)
================================================
FILE: src/tha4/pytasuku/workspace.py
================================================
from contextlib import contextmanager
from enum import Enum
from typing import List
from tha4.pytasuku.task import Task, CommandTask, FileTask, PlaceholderTask
class WorkspaceState(Enum):
OUT_OF_SESSION = 1
IN_SESSION = 2
class NodeState(Enum):
IN_STACK = 1
VISITED = 2
class FuncCommandTask(CommandTask):
def __init__(self, workspace, name, dependencies, func):
super().__init__(workspace, name, dependencies)
self._func = func
def run(self):
self._func()
class FuncFileTask(FileTask):
def __init__(self, workspace, name, dependencies, func):
super().__init__(workspace, name, dependencies)
self._func = func
def run(self):
self._func()
def do_nothing():
pass
class Workspace:
def __init__(self):
self._tasks = dict()
self._name_to_done = None
self._state = WorkspaceState.OUT_OF_SESSION
self._modified = False
@property
def modified(self) -> bool:
return self._modified
@property
def state(self) -> WorkspaceState:
return self._state
@property
def in_session(self) -> bool:
return self._state == WorkspaceState.IN_SESSION
def task_exists(self, name: str) -> bool:
return name in self._tasks
def task_exists_and_not_placeholder(self, name: str) -> bool:
return self.task_exists(name) and not isinstance(self.get_task(name), PlaceholderTask)
def get_task(self, name: str) -> Task:
return self._tasks[name]
def add_task(self, task):
if self.in_session:
raise RuntimeError("New tasks can only be created when the workspace is out of session.")
if isinstance(task, PlaceholderTask):
if not self.task_exists(task.name):
self._tasks[task.name] = task
self._modified = True
else:
self._tasks[task.name] = task
for dep in task.dependencies:
PlaceholderTask(self, dep)
self._modified = True
def start_session(self):
if self.in_session:
raise RuntimeError("A session can only be started when the workspace is out of session.")
if self.modified:
self.check_cycle()
self._state = WorkspaceState.IN_SESSION
self._name_to_done = dict()
self._modified = False
def end_session(self):
if not self.in_session:
raise RuntimeError("A session can only be ended when the workspace is in session.")
self._state = WorkspaceState.OUT_OF_SESSION
self._name_to_done = None
@contextmanager
def session(self):
try:
self.start_session()
yield
finally:
self.end_session()
def check_cycle(self):
node_states = dict()
for name in self._tasks:
if name not in node_states:
self.dfs(name, node_states)
def dfs(self, name, node_states):
node_states[name] = NodeState.IN_STACK
task = self.get_task(name)
for dep in task.dependencies:
if dep not in node_states:
self.dfs(dep, node_states)
else:
state = node_states[dep]
if state == NodeState.IN_STACK:
raise RuntimeError("Dicovered cyclic dependency!")
node_states[name] = NodeState.VISITED
def run(self, name):
if not self.in_session:
raise RuntimeError("A task can only be run when the workspace is in session.")
if not self.task_exists(name):
raise RuntimeError("Task %s does not exists" % name)
self.run_helper(name)
def run_helper(self, name):
task = self.get_task(name)
for dep in task.dependencies:
if self.needs_to_run(dep):
self.run_helper(dep)
if self.needs_to_run(name):
task.run()
self._name_to_done[name] = True
def needs_to_run(self, name):
if not self.in_session:
raise RuntimeError("You can only check whether a task needs to run when the workspace is in session.")
if name in self._name_to_done:
return not self._name_to_done[name]
task = self.get_task(name)
need_to_run_value = task.needs_to_be_run
self._name_to_done[name] = not need_to_run_value
return need_to_run_value
def create_command_task(self, name, dependencies, func=do_nothing):
return FuncCommandTask(self, name, dependencies, func)
def create_file_task(self, name, dependencies, func):
return FuncFileTask(self, name, dependencies, func)
def command_task(workspace: Workspace, name: str, dependencies: List[str]):
def func(f):
workspace.create_command_task(name, dependencies, f)
return f
return func
def file_task(workspace: Workspace, name: str, dependencies: List[str]):
def func(f):
workspace.create_file_task(name, dependencies, f)
return f
return func
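# End-to-end sketch (illustrative): a file task plus a command task that
# depends on it. Command tasks always run; file tasks run only when their
# file is missing or older than a dependency.
if __name__ == "__main__":
    import os

    workspace = Workspace()

    @file_task(workspace, "data/demo/out.txt", [])
    def write_out():
        os.makedirs("data/demo", exist_ok=True)
        with open("data/demo/out.txt", "w") as fout:
            fout.write("done")

    @command_task(workspace, "demo/report", ["data/demo/out.txt"])
    def report():
        print("out.txt size:", os.path.getsize("data/demo/out.txt"))

    with workspace.session():
        workspace.run("demo/report")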
================================================
FILE: src/tha4/sampleoutput/__init__.py
================================================
================================================
FILE: src/tha4/sampleoutput/general_sample_output_protocol.py
================================================
import os
from enum import Enum
from typing import List, Dict
import PIL.Image
import numpy
import torch
from tha4.shion.base.dataset.util import get_indexed_batch
from tha4.shion.base.image_util import pytorch_rgb_to_numpy_image, pytorch_rgba_to_numpy_image
from tha4.shion.core.cached_computation import TensorCachedComputationFunc, ComputationState
from tha4.shion.core.training.sample_output_protocol import SampleOutputProtocol
from tha4.image_util import grid_change_to_numpy_image
from torch.nn import Module
from torch.nn.functional import interpolate
from torch.utils.data import Dataset
class ImageType(Enum):
COLOR = 0
ALPHA = 1
GRID_CHANGE = 2
SIGMOID_LOGIT = 3
class SampleImageSpec:
def __init__(self, value_func: TensorCachedComputationFunc, image_type: ImageType):
self.value_func = value_func
self.image_type = image_type
class SampleImageSaver:
def __init__(self,
cell_size: int,
image_channels: int,
sample_image_specs: List[SampleImageSpec]):
super().__init__()
self.sample_image_specs = sample_image_specs
self.cell_size = cell_size
self.image_channels = image_channels
def save_sample_output_data(self,
state: ComputationState,
prefix: str,
examples_seen_so_far: int):
num_cols = len(self.sample_image_specs)
num_rows = state.batch[0].shape[0]
output_image = numpy.zeros([self.cell_size * num_rows, self.cell_size * num_cols, self.image_channels])
for col in range(num_cols):
spec = self.sample_image_specs[col]
images = spec.value_func(state)
start_col = col * self.cell_size
for image_index in range(num_rows):
image = images[image_index].clone().detach()
row = image_index
start_row = row * self.cell_size
if spec.image_type == ImageType.COLOR:
c, h, w = image.shape
green_screen = torch.ones(3, h, w, device=image.device) * -1.0
green_screen[1, :, :] = 1.0
alpha = (image[3:4, :, :] + 1.0) * 0.5
image[0:3, :, :] = image[0:3, :, :] * alpha + green_screen * (1 - alpha)
image[3:4, :, :] = 1.0
image = image.cpu()
elif spec.image_type == ImageType.GRID_CHANGE:
image = image.cpu()
elif spec.image_type == ImageType.SIGMOID_LOGIT:
image = torch.sigmoid(image)
image = image.repeat(self.image_channels, 1, 1)
image = image * 2.0 - 1.0
image = image.cpu()
elif spec.image_type == ImageType.ALPHA:
if image.shape[0] == 1:
image = image.repeat(self.image_channels, 1, 1)
image = image * 2.0 - 1.0
image = image.cpu()
else:
raise RuntimeError(f"Unsupported image type: {spec.image_type}")
output_image[start_row:start_row + self.cell_size, start_col:start_col + self.cell_size, :] \
= self.convert_to_numpy_image(image)
file_name = "%s/sample_output_%010d.png" % (prefix, examples_seen_so_far)
os.makedirs(os.path.dirname(file_name), exist_ok=True)
if self.image_channels == 3:
mode = 'RGB'
else:
mode = 'RGBA'
pil_image = PIL.Image.fromarray(numpy.uint8(numpy.rint(output_image * 255.0)), mode=mode)
pil_image.save(file_name)
print("Saved %s" % file_name)
def convert_to_numpy_image(self, image: torch.Tensor):
image_size = image.shape[-1]
if self.cell_size != image_size:
image = interpolate(image.unsqueeze(0), size=self.cell_size).squeeze(0)
if image.shape[0] == 2:
return grid_change_to_numpy_image(image, num_channels=self.image_channels)
elif self.image_channels == 3:
return pytorch_rgb_to_numpy_image(image)
else:
return pytorch_rgba_to_numpy_image(image)
class GeneralSampleOutputProtocol(SampleOutputProtocol):
def __init__(self,
sample_image_specs: List[SampleImageSpec],
num_images: int = 8,
cell_size: int = 256,
image_channels: int = 4,
examples_per_sample_output: int = 5000,
random_seed: int = 1203040687):
super().__init__()
self.num_images = num_images
self.random_seed = random_seed
self.examples_per_sample_output = examples_per_sample_output
self.sample_image_saver = SampleImageSaver(cell_size, image_channels, sample_image_specs)
def get_examples_per_sample_output(self) -> int:
return self.examples_per_sample_output
def get_random_seed(self) -> int:
return self.random_seed
def get_sample_output_data(self, validation_dataset: Dataset, device: torch.device) -> dict:
example_indices = torch.randint(0, len(validation_dataset), (self.num_images,))
example_indices = [example_indices[i].item() for i in range(self.num_images)]
batch = get_indexed_batch(validation_dataset, example_indices, device)
return {'batch': batch}
def save_sample_output_data(self,
modules: Dict[str, Module],
accumulated_modules: Dict[str, Module],
sample_output_data: dict, prefix: str,
examples_seen_so_far: int,
device: torch.device):
for key in modules:
modules[key].train(False)
batch = sample_output_data['batch']
state = ComputationState(modules, accumulated_modules, batch, {})
self.sample_image_saver.save_sample_output_data(state, prefix, examples_seen_so_far)
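# Construction sketch (illustrative): each column of the sample-output grid is
# a SampleImageSpec whose value_func extracts a tensor from the computation
# state. The lambda below assumes batch element 0 is an RGBA image in [-1, 1],
# which is how ImageType.COLOR is rendered above.
if __name__ == "__main__":
    specs = [SampleImageSpec(lambda state: state.batch[0], ImageType.COLOR)]
    protocol = GeneralSampleOutputProtocol(specs, num_images=4, cell_size=128)
    print(protocol.get_examples_per_sample_output())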
================================================
FILE: src/tha4/sampleoutput/poser_sampler_output_protocol.py
================================================
from typing import Optional, List, Dict
import torch
from torch.nn import Module
from torch.utils.data import Dataset
from tha4.shion.base.dataset.util import get_indexed_batch
from tha4.shion.core.cached_computation import CachedComputationFunc, ComputationState
from tha4.shion.core.training.sample_output_protocol import SampleOutputProtocol
from tha4.sampleoutput.sample_image_creator import SampleImageSpec, ImageSource, ImageType, SampleImageSaver
class PoserSampleOutputProtocol(SampleOutputProtocol):
def __init__(self,
output_list_func: Optional[CachedComputationFunc] = None,
num_images: int = 8,
image_size: int = 256,
cell_size: int = 256,
image_channels: int = 4,
examples_per_sample_output: int = 5000,
sample_image_specs: Optional[List[SampleImageSpec]] = None,
random_seed: int = 1203040687):
super().__init__()
self.num_images = num_images
self.random_seed = random_seed
self.examples_per_sample_output = examples_per_sample_output
self.output_list_func = output_list_func
if sample_image_specs is None:
sample_image_specs = [
SampleImageSpec(ImageSource.BATCH, 0, ImageType.COLOR),
SampleImageSpec(ImageSource.BATCH, 2, ImageType.COLOR),
SampleImageSpec(ImageSource.OUTPUT, 0, ImageType.COLOR),
SampleImageSpec(ImageSource.OUTPUT, 1, ImageType.COLOR),
SampleImageSpec(ImageSource.OUTPUT, 2, ImageType.ALPHA),
SampleImageSpec(ImageSource.BATCH, 3, ImageType.COLOR),
SampleImageSpec(ImageSource.OUTPUT, 3, ImageType.COLOR),
SampleImageSpec(ImageSource.OUTPUT, 4, ImageType.COLOR),
SampleImageSpec(ImageSource.OUTPUT, 5, ImageType.ALPHA),
SampleImageSpec(ImageSource.OUTPUT, 6, ImageType.COLOR),
SampleImageSpec(ImageSource.BATCH, 4, ImageType.COLOR),
SampleImageSpec(ImageSource.OUTPUT, 7, ImageType.COLOR),
SampleImageSpec(ImageSource.OUTPUT, 8, ImageType.COLOR),
SampleImageSpec(ImageSource.OUTPUT, 9, ImageType.ALPHA),
SampleImageSpec(ImageSource.OUTPUT, 10, ImageType.COLOR),
]
self.sample_image_saver = SampleImageSaver(image_size, cell_size, image_channels, sample_image_specs)
def get_examples_per_sample_output(self) -> int:
return self.examples_per_sample_output
def get_random_seed(self) -> int:
return self.random_seed
def get_sample_output_data(self, validation_dataset: Dataset, device: torch.device) -> dict:
example_indices = torch.randint(0, len(validation_dataset), (self.num_images,))
example_indices = [example_indices[i].item() for i in range(self.num_images)]
batch = get_indexed_batch(validation_dataset, example_indices, device)
return {'batch': batch}
def save_sample_output_data(self,
modules: Dict[str, Module],
accumulated_modules: Dict[str, Module],
sample_output_data: dict, prefix: str,
examples_seen_so_far: int,
device: torch.device):
for key in modules:
modules[key].train(False)
batch = sample_output_data['batch']
with torch.no_grad():
outputs = self.output_list_func(ComputationState(modules, accumulated_modules, batch))
self.sample_image_saver.save_sample_output_data(batch, outputs, prefix, examples_seen_so_far)
================================================
FILE: src/tha4/sampleoutput/sample_image_creator.py
================================================
import math
import os
from enum import Enum
from typing import List
import numpy
import torch
from matplotlib import cm
from torch import Tensor
from torch.nn.functional import interpolate
from tha4.shion.base.image_util import save_numpy_image
class ImageSource(Enum):
BATCH = 0
OUTPUT = 1
class ImageType(Enum):
COLOR = 0
ALPHA = 1
GRID_CHANGE = 2
SIGMOID_LOGIT = 3
class SampleImageSpec:
def __init__(self, image_source: ImageSource, index: int, image_type: ImageType):
self.image_type = image_type
self.index = index
self.image_source = image_source
def torch_rgb_to_numpy_image(torch_image: Tensor, min_pixel_value=-1.0, max_pixel_value=1.0):
assert torch_image.dim() == 3
assert torch_image.shape[0] == 3
height = torch_image.shape[1]
width = torch_image.shape[2]
reshaped_image = torch_image.numpy().reshape(3, height * width).transpose().reshape(height, width, 3)
numpy_image = (reshaped_image - min_pixel_value) / (max_pixel_value - min_pixel_value)
return numpy_image
def torch_rgba_to_numpy_image(torch_image: Tensor, min_pixel_value=-1.0, max_pixel_value=1.0):
assert torch_image.dim() == 3
assert torch_image.shape[0] == 4
height = torch_image.shape[1]
width = torch_image.shape[2]
reshaped_image = torch_image.numpy().reshape(4, height * width).transpose().reshape(height, width, 4)
numpy_image = (reshaped_image - min_pixel_value) / (max_pixel_value - min_pixel_value)
numpy_image = numpy.clip(numpy_image, 0.0, 1.0)
return numpy_image
def torch_grid_change_to_numpy_image(torch_image, num_channels=3):
height = torch_image.shape[1]
width = torch_image.shape[2]
size_image = (torch_image[0, :, :] ** 2 + torch_image[1, :, :] ** 2).sqrt().view(height, width, 1).numpy()
hsv = cm.get_cmap('hsv')
angle_image = hsv(((torch.atan2(
torch_image[0, :, :].view(height * width),
torch_image[1, :, :].view(height * width)).view(height, width) + math.pi) / (2 * math.pi)).numpy()) * 3
numpy_image = size_image * angle_image[:, :, 0:3]
if num_channels == 3:
return numpy_image
elif num_channels == 4:
return numpy.concatenate([numpy_image, numpy.ones_like(size_image)], axis=2)
else:
raise RuntimeError("Unsupported num_channels: " + str(num_channels))
class SampleImageSaver:
def __init__(self,
image_size: int,
cell_size: int,
image_channels: int,
sample_image_specs: List[SampleImageSpec]):
super().__init__()
self.sample_image_specs = sample_image_specs
self.cell_size = cell_size
self.image_channels = image_channels
self.image_size = image_size
def save_sample_output_image(self, batch: List[Tensor], outputs: List[Tensor], file_name: str):
num_cols = len(self.sample_image_specs)
num_rows = batch[0].shape[0]
output_image = numpy.zeros([self.cell_size * num_rows, self.cell_size * num_cols, self.image_channels])
for image_index in range(num_rows):
row = image_index
start_row = row * self.cell_size
for col in range(num_cols):
spec = self.sample_image_specs[col]
start_col = col * self.cell_size
if spec.image_source == ImageSource.BATCH:
image = batch[spec.index][image_index].clone()
else:
image = outputs[spec.index][image_index].clone()
if spec.image_type == ImageType.COLOR:
c, h, w = image.shape
green_screen = torch.ones(3, h, w, device=image.device) * -1.0
green_screen[1, :, :] = 1.0
alpha = (image[3:4, :, :] + 1.0) * 0.5
image[0:3, :, :] = image[0:3, :, :] * alpha + green_screen * (1 - alpha)
image[3:4, :, :] = 1.0
image = image.detach().cpu()
elif spec.image_type == ImageType.GRID_CHANGE:
image = image.detach().cpu()
elif spec.image_type == ImageType.SIGMOID_LOGIT:
image = torch.sigmoid(image)
image = image.repeat(self.image_channels, 1, 1)
image = image * 2.0 - 1.0
image = image.detach().cpu()
else:
if image.shape[0] == 1:
image = image.repeat(self.image_channels, 1, 1)
image = image * 2.0 - 1.0
image = image.detach().cpu()
output_image[start_row:start_row + self.cell_size, start_col:start_col + self.cell_size, :] \
= self.convert_to_numpy_image(image)
os.makedirs(os.path.dirname(file_name), exist_ok=True)
save_numpy_image(output_image, file_name, save_straight_alpha=True)
def save_sample_output_data(self,
batch: List[Tensor],
outputs: List[Tensor],
prefix: str,
examples_seen_so_far: int):
file_name = "%s/sample_output_%010d.png" % (prefix, examples_seen_so_far)
self.save_sample_output_image(batch, outputs, file_name)
def convert_to_numpy_image(self, image: torch.Tensor):
if self.cell_size != self.image_size:
image = interpolate(image.unsqueeze(0), size=self.cell_size).squeeze(0)
if image.shape[0] == 2:
return torch_grid_change_to_numpy_image(image, num_channels=self.image_channels)
elif self.image_channels == 3:
return torch_rgb_to_numpy_image(image)
else:
return torch_rgba_to_numpy_image(image)
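# Sketch (illustrative): visualize a random warp field with the grid-change
# colormap; the two channels are treated as a 2D offset per pixel, with hue
# encoding direction and brightness encoding magnitude.
if __name__ == "__main__":
    grid_change = torch.randn(2, 64, 64) * 0.1
    numpy_image = torch_grid_change_to_numpy_image(grid_change, num_channels=4)
    print(numpy_image.shape)  # (64, 64, 4)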
================================================
FILE: src/tha4/shion/__init__.py
================================================
================================================
FILE: src/tha4/shion/base/__init__.py
================================================
================================================
FILE: src/tha4/shion/base/dataset/__init__.py
================================================
================================================
FILE: src/tha4/shion/base/dataset/lazy_dataset.py
================================================
from typing import Callable
from torch.utils.data import Dataset
class LazyDataset(Dataset):
def __init__(self, source_func: Callable[[], Dataset]):
self.source_func = source_func
self.source = None
def get_source(self):
if self.source is None:
self.source = self.source_func()
return self.source
def __len__(self):
return len(self.get_source())
def __getitem__(self, item):
return self.get_source()[item]
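# Usage sketch (illustrative): the source dataset is built on first access,
# so expensive datasets can be declared up front and paid for lazily.
if __name__ == "__main__":
    import torch
    from torch.utils.data import TensorDataset

    def make_source() -> Dataset:
        print("building source")
        return TensorDataset(torch.arange(10))

    dataset = LazyDataset(make_source)
    print(len(dataset))  # triggers "building source"
    print(dataset[3])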
================================================
FILE: src/tha4/shion/base/dataset/lazy_tensor_dataset.py
================================================
import torch
from torch.utils.data import Dataset, TensorDataset
from tha4.shion.core.load_save import torch_load
class LazyTensorDataset(Dataset):
def __init__(self, file_name: str):
self.file_name = file_name
self.dataset = None
def get_dataset(self):
if self.dataset is None:
data = torch_load(self.file_name)
if isinstance(data, torch.Tensor):
self.dataset = TensorDataset(data)
elif isinstance(data, tuple):
self.dataset = TensorDataset(*data)
elif isinstance(data, list):
self.dataset = TensorDataset(*data)
else:
raise RuntimeError("Unsupported data type: " + type(data))
return self.dataset
def __len__(self):
dataset = self.get_dataset()
return len(dataset)
def __getitem__(self, item):
dataset = self.get_dataset()
return dataset.__getitem__(item)
================================================
FILE: src/tha4/shion/base/dataset/png_in_dir_dataset.py
================================================
import os
from torch.nn import functional
from torch.utils.data import Dataset
from os import listdir
from os.path import isfile
from tha4.shion.base.image_util import extract_pytorch_image_from_filelike
class PngInDirDataset(Dataset):
def __init__(self, dir: str,
downscale_kernel_size: int = 1,
has_alpha=False,
scale=2.0,
offset=-1.0,
premultiply_alpha=True,
                 perform_srgb_to_linear=True):
super().__init__()
        self.perform_srgb_to_linear = perform_srgb_to_linear
self.premultiply_alpha = premultiply_alpha
self.offset = offset
self.scale = scale
self.has_alpha = has_alpha
self.downscale_kernel_size = downscale_kernel_size
self.dir = dir
self.file_names = None
def get_file_names(self):
if self.file_names is None:
self.file_names = [os.path.join(self.dir, x) for x in listdir(self.dir)]
self.file_names = [x for x in self.file_names if isfile(x) and x[-4:].lower() == ".png"]
self.file_names = sorted(self.file_names)
return self.file_names
def __len__(self):
file_names = self.get_file_names()
return len(file_names)
def __getitem__(self, item):
file_names = self.get_file_names()
file_name = file_names[item]
image = extract_pytorch_image_from_filelike(
file_name,
scale=self.scale,
offset=self.offset,
premultiply_alpha=self.has_alpha and self.premultiply_alpha,
            perform_srgb_to_linear=self.perform_srgb_to_linear)
if self.downscale_kernel_size == 1:
return [image]
else:
image = functional.avg_pool2d(image.unsqueeze(0), kernel_size=self.downscale_kernel_size).squeeze(0)
return [image]
================================================
FILE: src/tha4/shion/base/dataset/util.py
================================================
from typing import List
import torch
from torch.utils.data import Dataset
def get_indexed_batch(dataset: Dataset, example_indices: List[int], device: torch.device):
if len(example_indices) == 0:
return []
examples = []
for index in range(len(example_indices)):
example_index = example_indices[index]
raw_example = dataset[example_index]
example = []
for x in raw_example:
if isinstance(x, torch.Tensor):
y = x.to(device).unsqueeze(0)
elif isinstance(x, float) or isinstance(x, int):
y = torch.tensor([[x]], device=device)
else:
raise RuntimeError(f"get_indexed_batch: Data of type {type(x)} is not supported.")
example.append(y)
examples.append(example)
k = len(examples[0])
transposed = [[] for i in range(k)]
for example in examples:
for i in range(k):
transposed[i].append(example[i])
return [torch.cat(x, dim=0) for x in transposed]
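# Usage sketch (illustrative): gather specific examples into a batch, one
# concatenated tensor per dataset column.
if __name__ == "__main__":
    from torch.utils.data import TensorDataset

    dataset = TensorDataset(torch.randn(100, 3), torch.randn(100, 1))
    batch = get_indexed_batch(dataset, [0, 17, 42], torch.device("cpu"))
    print(batch[0].shape, batch[1].shape)  # torch.Size([3, 3]) torch.Size([3, 1])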
================================================
FILE: src/tha4/shion/base/dataset/xformed_dataset.py
================================================
from typing import Any, Callable
from torch.utils.data import Dataset
class XformedDataset(Dataset):
def __init__(self, source: Dataset, xform_func: Callable[[Any], Any]):
self.xform_func = xform_func
self.source = source
def __len__(self):
return len(self.source)
def __getitem__(self, item):
return self.xform_func(self.source[item])
================================================
FILE: src/tha4/shion/base/image_util.py
================================================
import os
import PIL.Image
import numpy
import torch
from matplotlib import pyplot
from torch import Tensor
def numpy_srgb_to_linear(x):
x = numpy.clip(x, 0.0, 1.0)
return numpy.where(x <= 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)
def numpy_linear_to_srgb(x):
x = numpy.clip(x, 0.0, 1.0)
return numpy.where(x <= 0.003130804953560372, x * 12.92, 1.055 * (x ** (1.0 / 2.4)) - 0.055)
def numpy_alpha_divide(rgb, a, epsilon=1e-5):
aaa = numpy.repeat(a, 3, axis=2)
aaa_prime = aaa + numpy.where(numpy.abs(aaa) < epsilon, epsilon, 0.0)
return numpy.where(numpy.abs(aaa) < epsilon, 0.0, rgb / aaa_prime)
def torch_srgb_to_linear(x: torch.Tensor):
x = torch.clip(x, 0.0, 1.0)
return torch.where(torch.le(x, 0.04045), x / 12.92, ((x + 0.055) / 1.055) ** 2.4)
def torch_linear_to_srgb(x):
x = torch.clip(x, 0.0, 1.0)
return torch.where(torch.le(x, 0.003130804953560372), x * 12.92, 1.055 * (x ** (1.0 / 2.4)) - 0.055)
def numpy_image_linear_to_srgb(image):
assert image.shape[2] == 3 or image.shape[2] == 4
if image.shape[2] == 3:
return numpy_linear_to_srgb(image)
else:
height, width, _ = image.shape
rgb_image = numpy_linear_to_srgb(image[:, :, 0:3])
a_image = image[:, :, 3:4]
return numpy.concatenate((rgb_image, a_image), axis=2)
def numpy_image_srgb_to_linear(image):
assert image.shape[2] == 3 or image.shape[2] == 4
if image.shape[2] == 3:
return numpy_srgb_to_linear(image)
else:
height, width, _ = image.shape
rgb_image = numpy_srgb_to_linear(image[:, :, 0:3])
a_image = image[:, :, 3:4]
return numpy.concatenate((rgb_image, a_image), axis=2)
def pytorch_rgb_to_numpy_image(torch_image: Tensor, min_pixel_value=-1.0, max_pixel_value=1.0):
assert torch_image.dim() == 3
assert torch_image.shape[0] == 3
height = torch_image.shape[1]
width = torch_image.shape[2]
reshaped_image = torch_image.numpy().reshape(3, height * width).transpose().reshape(height, width, 3)
numpy_image = (reshaped_image - min_pixel_value) / (max_pixel_value - min_pixel_value)
return numpy_linear_to_srgb(numpy_image)
def pytorch_rgba_to_numpy_image_greenscreen(torch_image: Tensor,
min_pixel_value=-1.0,
max_pixel_value=1.0,
include_alpha=False):
height = torch_image.shape[1]
width = torch_image.shape[2]
numpy_image = (torch_image.numpy().reshape(4, height * width).transpose().reshape(height, width,
4) - min_pixel_value) \
/ (max_pixel_value - min_pixel_value)
rgb_image = numpy_linear_to_srgb(numpy_image[:, :, 0:3])
a_image = numpy_image[:, :, 3]
rgb_image[:, :, 0:3] = rgb_image[:, :, 0:3] * a_image.reshape(a_image.shape[0], a_image.shape[1], 1)
rgb_image[:, :, 1] = rgb_image[:, :, 1] + (1 - a_image)
if not include_alpha:
return rgb_image
else:
return numpy.concatenate((rgb_image, numpy.ones_like(numpy_image[:, :, 3:4])), axis=2)
def pytorch_rgba_to_numpy_image(
torch_image: Tensor,
min_pixel_value=-1.0,
max_pixel_value=1.0,
        perform_linear_to_srgb: bool = True):
assert torch_image.dim() == 3
assert torch_image.shape[0] == 4
height = torch_image.shape[1]
width = torch_image.shape[2]
reshaped_image = torch_image.numpy().reshape(4, height * width).transpose().reshape(height, width, 4)
numpy_image = (reshaped_image - min_pixel_value) / (max_pixel_value - min_pixel_value)
    if perform_linear_to_srgb:
rgb_image = numpy_linear_to_srgb(numpy_image[:, :, 0:3])
else:
rgb_image = numpy.clip(numpy_image[:, :, 0:3], 0.0, 1.0)
a_image = numpy.clip(numpy_image[:, :, 3], 0.0, 1.0)
rgba_image = numpy.concatenate((rgb_image, a_image.reshape(height, width, 1)), axis=2)
return rgba_image
def pil_image_has_transparency(pil_image):
if pil_image.info.get("transparency", None) is not None:
return True
if pil_image.mode == "P":
transparent = pil_image.info.get("transparency", -1)
for _, index in pil_image.getcolors():
if index == transparent:
return True
elif pil_image.mode == "RGBA":
extrema = pil_image.getextrema()
if extrema[3][0] < 255:
return True
return False
def extract_numpy_image_from_PIL_image(pil_image, scale=2.0, offset=-1.0,
premultiply_alpha=True,
perform_srgb_to_linear=True):
has_alpha = pil_image_has_transparency(pil_image)
if has_alpha and pil_image.mode != 'RGBA':
pil_image = pil_image.convert("RGBA")
if not has_alpha and pil_image.mode != 'RGB':
pil_image = pil_image.convert("RGB")
if has_alpha:
num_channel = 4
else:
num_channel = 3
image_width = pil_image.width
image_height = pil_image.height
raw_image = numpy.asarray(pil_image, dtype=numpy.float32)
image = (raw_image / 255.0).reshape(image_height, image_width, num_channel)
if perform_srgb_to_linear:
image[:, :, 0:3] = numpy_srgb_to_linear(image[:, :, 0:3])
# Premultiply alpha
if has_alpha and premultiply_alpha:
image[:, :, 0:3] = image[:, :, 0:3] * image[:, :, 3:4]
return image * scale + offset
def extract_numpy_image_from_PIL_image_with_pytorch_layout(pil_image, scale=2.0, offset=-1.0,
premultiply_alpha=True,
perform_srgb_to_linear=True):
numpy_image = extract_numpy_image_from_PIL_image(
pil_image, scale, offset, premultiply_alpha, perform_srgb_to_linear)
image_height, image_width, num_channel = numpy_image.shape
image = numpy_image \
.reshape(image_height * image_width, num_channel) \
.transpose() \
.reshape(num_channel, image_height, image_width)
return image
def extract_numpy_image_from_filelike_with_pytorch_layout(file, scale=2.0, offset=-1.0, premultiply_alpha=True):
try:
pil_image = PIL.Image.open(file)
except Exception as e:
        raise RuntimeError("Could not open image file: %s" % file) from e
return extract_numpy_image_from_PIL_image_with_pytorch_layout(pil_image, scale, offset, premultiply_alpha)
def extract_numpy_image_from_filelike(file, scale=1.0, offset=0.0,
premultiply_alpha=True,
perform_srgb_to_linear: bool = True):
try:
pil_image = PIL.Image.open(file)
except Exception as e:
        raise RuntimeError("Could not open image file: %s" % file) from e
return extract_numpy_image_from_PIL_image(pil_image, scale, offset, premultiply_alpha, perform_srgb_to_linear)
def extract_pytorch_image_from_filelike(file, scale=2.0, offset=-1.0, premultiply_alpha=True,
perform_srgb_to_linear=True):
try:
pil_image = PIL.Image.open(file)
except Exception as e:
        raise RuntimeError("Could not open image file: %s" % file) from e
image = extract_numpy_image_from_PIL_image_with_pytorch_layout(pil_image, scale, offset, premultiply_alpha,
perform_srgb_to_linear)
return torch.from_numpy(image).float()
def extract_pytorch_image_from_PIL_image(pil_image, scale=2.0, offset=-1.0, premultiply_alpha=True,
perform_srgb_to_linear=True):
image = extract_numpy_image_from_PIL_image_with_pytorch_layout(
pil_image, scale, offset, premultiply_alpha, perform_srgb_to_linear)
return torch.from_numpy(image).float()
def convert_pytorch_image_to_zero_to_one_numpy_image(
torch_image: torch.Tensor,
scale: float = 2.0,
offset: float = -1.0):
torch_image = (torch_image - offset) / scale
torch_image = torch.permute(torch_image, (1, 2, 0))
numpy_image = torch_image.cpu().numpy()
return numpy_image
def convert_zero_to_one_numpy_image_to_PIL_image(
numpy_image, use_straight_alpha=True, perform_linear_to_srgb=True):
if numpy_image.shape[2] == 4:
rgb_image = numpy_image[:, :, 0:3]
a_image = numpy.clip(numpy_image[:, :, 3:4], 0.0, 1.0)
if use_straight_alpha:
            rgb_image = numpy_alpha_divide(rgb_image, a_image)
if perform_linear_to_srgb:
rgb_image = numpy_linear_to_srgb(rgb_image)
else:
rgb_image = numpy.clip(rgb_image, 0.0, 1.0)
new_numpy_image = numpy.concatenate((rgb_image, a_image), axis=2)
pil_image = PIL.Image.fromarray(numpy.uint8(numpy.rint(new_numpy_image * 255.0)), mode='RGBA')
else:
if perform_linear_to_srgb:
numpy_image = numpy_linear_to_srgb(numpy_image)
else:
numpy_image = numpy.clip(numpy_image, 0.0, 1.0)
pil_image = PIL.Image.fromarray(numpy.uint8(numpy.rint(numpy_image * 255.0)), mode='RGB')
return pil_image
def save_numpy_image(numpy_image, file_name: str, save_straight_alpha=True, perform_linear_to_srgb=True):
pil_image = convert_zero_to_one_numpy_image_to_PIL_image(numpy_image, save_straight_alpha, perform_linear_to_srgb)
os.makedirs(os.path.dirname(file_name), exist_ok=True)
pil_image.save(file_name)
def resize_PIL_image(pil_image, size=(256, 256)):
w, h = pil_image.size
d = min(w, h)
r = ((w - d) // 2, (h - d) // 2, (w + d) // 2, (h + d) // 2)
return pil_image.resize(size, resample=PIL.Image.LANCZOS, box=r)
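# Round-trip sketch (illustrative): load a PNG into the [-1, 1] premultiplied
# linear representation used throughout this module, then save it back with
# straight alpha and sRGB conversion. "character.png" is a placeholder path.
if __name__ == "__main__":
    torch_image = extract_pytorch_image_from_filelike("character.png")
    numpy_image = convert_pytorch_image_to_zero_to_one_numpy_image(torch_image)
    save_numpy_image(numpy_image, "data/character_roundtrip.png")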
================================================
FILE: src/tha4/shion/base/loss/__init__.py
================================================
================================================
FILE: src/tha4/shion/base/loss/computed_scale_loss.py
================================================
from typing import Optional, Callable
from tha4.shion.core.cached_computation import TensorCachedComputationFunc, ComputationState
from tha4.shion.core.loss import Loss
class ComputedScaleLoss(Loss):
def __init__(self,
scale_func: TensorCachedComputationFunc,
loss: Loss,
weight: float = 1.0):
self.weight = weight
self.loss = loss
self.scale_func = scale_func
def compute(self, state: ComputationState, log_func: Optional[Callable[[str, float], None]] = None):
loss = self.loss.compute(state)
scale = self.scale_func(state)
loss = self.weight * scale * loss
if log_func is not None:
log_func("loss", loss.item())
return loss
================================================
FILE: src/tha4/shion/base/loss/computed_scaled_l2_loss.py
================================================
from typing import Callable, Optional
from tha4.shion.core.cached_computation import TensorCachedComputationFunc, ComputationState
from tha4.shion.core.loss import Loss
class ComputedScaledL2Loss(Loss):
def __init__(self,
expected_func: TensorCachedComputationFunc,
actual_func: TensorCachedComputationFunc,
element_scale_func: TensorCachedComputationFunc,
weight: float = 1.0):
self.element_scale_func = element_scale_func
self.actual_func = actual_func
self.expected_func = expected_func
self.weight = weight
def compute(
self,
state: ComputationState,
log_func: Optional[Callable[[str, float], None]] = None):
element_scale = self.element_scale_func(state)
expected = self.expected_func(state)
actual = self.actual_func(state)
diff = (expected - actual) * element_scale
loss = self.weight * (diff ** 2).mean()
if log_func is not None:
log_func("loss", loss.item())
return loss
================================================
FILE: src/tha4/shion/base/loss/l1_loss.py
================================================
from typing import Callable, Optional
import torch
from tha4.shion.core.cached_computation import TensorCachedComputationFunc, ComputationState
from tha4.shion.core.loss import Loss
class L1Loss(Loss):
def __init__(self,
expected_func: TensorCachedComputationFunc,
actual_func: TensorCachedComputationFunc,
weight: float = 1.0):
self.actual_func = actual_func
self.expected_func = expected_func
self.weight = weight
def compute(self, state: ComputationState, log_func: Optional[Callable[[str, float], None]] = None):
expected = self.expected_func(state)
actual = self.actual_func(state)
loss = self.weight * (expected - actual).abs().mean()
if log_func is not None:
log_func("loss", loss.item())
return loss
class ListL1Loss(Loss):
def __init__(self,
expected_func: TensorCachedComputationFunc,
actual_func: TensorCachedComputationFunc,
weight: float = 1.0):
self.actual_func = actual_func
self.expected_func = expected_func
self.weight = weight
def compute(self, state: ComputationState, log_func: Optional[Callable[[str, float], None]] = None):
expected = self.expected_func(state)
actual = self.actual_func(state)
assert len(expected) == len(actual)
loss = torch.zeros(1, device=expected[0].device)
for i in range(len(expected)):
loss += (expected[i] - actual[i]).abs().mean()
loss = self.weight * loss
if log_func is not None:
log_func("loss", loss.item())
return loss
class MaskedL1Loss(Loss):
def __init__(self,
expected_func: TensorCachedComputationFunc,
actual_func: TensorCachedComputationFunc,
mask_func: TensorCachedComputationFunc,
weight: float = 1.0):
self.mask_func = mask_func
self.actual_func = actual_func
self.expected_func = expected_func
self.weight = weight
def compute(self, state: ComputationState, log_func: Optional[Callable[[str, float], None]] = None):
mask = self.mask_func(state)
expected = self.expected_func(state)
actual = self.actual_func(state)
loss = self.weight * ((expected - actual) * mask).abs().mean()
if log_func is not None:
log_func("loss", loss.item())
return loss
================================================
FILE: src/tha4/shion/base/loss/l2_loss.py
================================================
from typing import Callable, Optional
from tha4.shion.core.cached_computation import TensorCachedComputationFunc, ComputationState
from tha4.shion.core.loss import Loss
class L2Loss(Loss):
def __init__(self,
expected_func: TensorCachedComputationFunc,
actual_func: TensorCachedComputationFunc,
weight: float = 1.0):
self.actual_func = actual_func
self.expected_func = expected_func
self.weight = weight
def compute(
self,
state: ComputationState,
log_func: Optional[Callable[[str, float], None]] = None):
expected = self.expected_func(state)
actual = self.actual_func(state)
loss = self.weight * ((expected - actual) ** 2).mean()
if log_func is not None:
log_func("loss", loss.item())
return loss
================================================
FILE: src/tha4/shion/base/loss/sum_loss.py
================================================
from typing import List, Tuple, Callable, Optional
import torch
from torch import Tensor
from tha4.shion.core.cached_computation import ComputationState
from tha4.shion.core.loss import Loss
class SumLoss(Loss):
def __init__(self, losses: List[Tuple[str, Loss]]):
self.losses = losses
def compute(self,
state: ComputationState,
log_func: Optional[Callable[[str, float], None]] = None) -> Tensor:
device = state.batch[0].device
loss_value = torch.zeros(1, device=device)
for loss_spec in self.losses:
loss_name = loss_spec[0]
loss = loss_spec[1]
if log_func is not None:
def loss_log_func(name, value):
log_func(loss_name + "_" + name, value)
else:
loss_log_func = None
loss_value = loss_value + loss.compute(state, loss_log_func)
if log_func is not None:
log_func("loss", loss_value.item())
return loss_value
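# A composition sketch with made-up weights: each term logs under
# "<name>_loss" via the wrapping loss_log_func, and the total under "loss".
if __name__ == "__main__":
    from tha4.shion.base.loss.l1_loss import L1Loss
    from tha4.shion.base.loss.l2_loss import L2Loss
    from tha4.shion.core.cached_computation import create_batch_element_func

    total = SumLoss([
        ("l1", L1Loss(create_batch_element_func(1), create_batch_element_func(0))),
        ("l2", L2Loss(create_batch_element_func(1), create_batch_element_func(0), weight=0.1)),
    ])
    state = ComputationState({}, {}, [torch.rand(2, 3), torch.rand(2, 3)])
    total.compute(state, log_func=lambda name, value: print(name, value))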
================================================
FILE: src/tha4/shion/base/loss/time_dependently_weighted_loss.py
================================================
from typing import Callable, Optional
from torch import Tensor
from tha4.shion.core.cached_computation import ComputationState, CachedComputationFunc
from tha4.shion.core.loss import Loss
class TimeDependentlyWeightedLoss(Loss):
def __init__(self,
base_loss: Loss,
examples_seen_so_far_func: CachedComputationFunc,
weight_func: Callable[[int], float]):
self.weight_func = weight_func
self.examples_seen_so_far_func = examples_seen_so_far_func
self.base_loss = base_loss
def compute(self,
state: ComputationState,
log_func: Optional[Callable[[str, float], None]] = None) -> Tensor:
base_value = self.base_loss.compute(state)
examples_seen_so_far = self.examples_seen_so_far_func(state)
weight = self.weight_func(examples_seen_so_far)
loss_value = base_value * weight
if log_func is not None:
log_func("loss", loss_value.item())
return loss_value
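# A warm-up sketch: the wrapped loss's weight ramps linearly from 0 to 1 over
# the first 100,000 examples. The examples_seen_so_far_func below simply reads
# a count stashed in state.outputs under a made-up key.
if __name__ == "__main__":
    import torch
    from tha4.shion.base.loss.l1_loss import L1Loss
    from tha4.shion.core.cached_computation import create_batch_element_func

    warmed_up = TimeDependentlyWeightedLoss(
        base_loss=L1Loss(create_batch_element_func(1), create_batch_element_func(0)),
        examples_seen_so_far_func=lambda state: state.outputs["examples_seen_so_far"],
        weight_func=lambda n: min(1.0, n / 100_000.0))
    state = ComputationState(
        {}, {}, [torch.rand(2, 3), torch.rand(2, 3)],
        outputs={"examples_seen_so_far": 25_000})
    warmed_up.compute(state, log_func=lambda name, value: print(name, value))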
================================================
FILE: src/tha4/shion/base/module_accumulators.py
================================================
from typing import Optional
import torch
from torch.nn import Module
from tha4.shion.core.module_accumulator import ModuleAccumulator
# Code from https://github.com/rosinality/style-based-gan-pytorch/blob/8437a8bbd106ad4a4691b798ce35d30b5111990b/train.py
def accumulate_modules(new_module: Module, accumulated_module: Module, beta=0.99):
with torch.no_grad():
new_module_params = dict(new_module.named_parameters())
accumulated_module_params = dict(accumulated_module.named_parameters())
for key in new_module_params.keys():
accumulated_module_params[key].mul_(beta).add_(new_module_params[key] * (1 - beta))
new_module_buffers = dict(new_module.named_buffers())
accumulated_module_buffers = dict(accumulated_module.named_buffers())
for key in new_module_buffers.keys():
accumulated_module_buffers[key].copy_(new_module_buffers[key])
class DecayAccumulator(ModuleAccumulator):
def __init__(self, decay: float = 0.999):
self.decay = decay
def accumulate(self, module: Module, output: Module, examples_seen_so_far: Optional[int] = None) -> Module:
accumulate_modules(module, output, self.decay)
return output
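# A quick check of the exponential moving average, assuming two initially
# identical Linear modules: after one accumulation step with decay 0.9, each
# accumulated parameter equals 0.9 * old + 0.1 * new.
if __name__ == "__main__":
    import copy
    from torch.nn import Linear

    new_module = Linear(4, 4)
    ema_module = copy.deepcopy(new_module)
    with torch.no_grad():
        for p in new_module.parameters():
            p.add_(1.0)  # perturb the "training" copy
    DecayAccumulator(decay=0.9).accumulate(new_module, ema_module)
    print(next(ema_module.parameters())[0, 0].item())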
================================================
FILE: src/tha4/shion/base/optimizer_factories.py
================================================
from typing import Tuple, Iterable
from torch.nn import Parameter
from torch.optim import Optimizer, Adam, AdamW, SparseAdam, RMSprop
from tha4.shion.core.optimizer_factory import OptimizerFactory
class AdamOptimizerFactory(OptimizerFactory):
def __init__(self, betas: Tuple[float, float] = (0.9, 0.999), epsilon: float = 1e-8, weight_decay: float = 0.0):
super().__init__()
self.weight_decay = weight_decay
self.betas = betas
self.epsilon = epsilon
def create(self, parameters: Iterable[Parameter]) -> Optimizer:
return Adam(parameters, betas=self.betas, eps=self.epsilon, weight_decay=self.weight_decay)
class AdamWOptimizerFactory(OptimizerFactory):
def __init__(self, betas: Tuple[float, float] = (0.9, 0.999), epsilon: float = 1e-8, weight_decay: float = 0.01):
super().__init__()
self.weight_decay = weight_decay
self.betas = betas
self.epsilon = epsilon
def create(self, parameters: Iterable[Parameter]) -> Optimizer:
return AdamW(parameters, betas=self.betas, eps=self.epsilon, weight_decay=self.weight_decay)
class SparseAdamOptimizerFactory(OptimizerFactory):
def __init__(self, betas: Tuple[float, float] = (0.9, 0.999), epsilon: float = 1e-8):
super().__init__()
self.betas = betas
self.epsilon = epsilon
def create(self, parameters: Iterable[Parameter]) -> Optimizer:
return SparseAdam(list(parameters), betas=self.betas, eps=self.epsilon)
class RMSpropOptimizerFactory(OptimizerFactory):
def __init__(self):
super().__init__()
def create(self, parameters: Iterable[Parameter]) -> Optimizer:
return RMSprop(parameters)
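# A minimal sketch: factories defer optimizer construction until a module's
# parameters exist, so training protocols can carry hyper-parameters around
# without holding a live model. The Linear module below is illustrative.
if __name__ == "__main__":
    from torch.nn import Linear

    factory = AdamWOptimizerFactory(betas=(0.9, 0.99), weight_decay=0.05)
    optimizer = factory.create(Linear(8, 8).parameters())
    print(type(optimizer).__name__)  # AdamW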
================================================
FILE: src/tha4/shion/base/protocol/single_network_from_batch_input_computation_protocol.py
================================================
from typing import Optional, Any, List
from tha4.shion.core.cached_computation import CachedComputationProtocol, ComputationState
KEY_NETWORK = "network"
KEY_NETWORK_OUTPUT = "network_output"
class SingleNetworkBatchInputComputationProtocol(CachedComputationProtocol):
def __init__(self,
key_network: str = KEY_NETWORK,
key_network_output: str = KEY_NETWORK_OUTPUT,
input_index_to_batch_index: Optional[List[int]] = None):
if input_index_to_batch_index is None:
input_index_to_batch_index = [0]
self.input_index_to_batch_index = input_index_to_batch_index
self.key_network_output = key_network_output
self.key_network = key_network
def compute_output(self, key: str, state: ComputationState) -> Any:
if key == self.key_network_output:
inputs = []
for batch_index in self.input_index_to_batch_index:
inputs.append(state.batch[batch_index])
network = state.modules[self.key_network]
return network.forward(*inputs)
else:
raise RuntimeError("Computing output for key " + key + " is not supported!")
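# A minimal sketch: the protocol runs the module stored under key_network on
# the selected batch elements and caches the result under key_network_output,
# so several loss terms can share one forward pass. The module and batch below
# are illustrative.
if __name__ == "__main__":
    import torch
    from torch.nn import Linear

    protocol = SingleNetworkBatchInputComputationProtocol()
    state = ComputationState(
        modules={KEY_NETWORK: Linear(4, 2)},
        accumulated_modules={},
        batch=[torch.rand(3, 4)])
    output_func = protocol.get_output_func(KEY_NETWORK_OUTPUT)
    print(output_func(state).shape)  # cached in state.outputs after this call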
================================================
FILE: src/tha4/shion/base/training/__init__.py
================================================
================================================
FILE: src/tha4/shion/base/training/single_network.py
================================================
from typing import List, Dict, Callable, Any, Optional
import torch
from torch.nn import Module
from torch.nn.utils import clip_grad_norm_
from torch.optim.optimizer import Optimizer
from tha4.shion.core.cached_computation import ComputationState
from tha4.shion.core.loss import Loss
from tha4.shion.core.optimizer_factory import OptimizerFactory
from tha4.shion.core.training.training_protocol import TrainingProtocol
from tha4.shion.core.training.validation_protocol import ValidationProtocol
KEY_NETWORK = "network"
class SingleNetworkTrainingProtocol(TrainingProtocol):
def __init__(self,
check_point_examples: List[int],
batch_size: int,
learning_rate: Callable[[int], Dict[str, float]],
optimizer_factories: Dict[str, OptimizerFactory],
module_key: str = KEY_NETWORK,
random_seed: int = 39549059840,
max_grad_norm: Optional[float] = None):
super().__init__()
self.max_grad_norm = max_grad_norm
self.module_key = module_key
self.optimizer_factories = optimizer_factories
self.learning_rate = learning_rate
self.batch_size = batch_size
self.random_seed = random_seed
self.check_point_examples = check_point_examples
def get_optimizer_factories(self) -> Dict[str, OptimizerFactory]:
return self.optimizer_factories
def get_checkpoint_examples(self) -> List[int]:
return self.check_point_examples
def get_random_seed(self) -> int:
return self.random_seed
def get_batch_size(self) -> int:
return self.batch_size
def get_learning_rate(self, examples_seen_so_far: int) -> Dict[str, float]:
return self.learning_rate(examples_seen_so_far)
def run_training_iteration(
self,
batch: Any,
examples_seen_so_far: int,
modules: Dict[str, Module],
accumulated_modules: Dict[str, Module],
optimizers: Dict[str, Optimizer],
losses: Dict[str, Loss],
create_log_func: Optional[Callable[[str, int], Callable[[str, float], None]]],
device: torch.device):
module = modules[self.module_key]
module.train(True)
optimizers[self.module_key].zero_grad(set_to_none=True)
if create_log_func is not None:
log_func = create_log_func("training_" + self.module_key, examples_seen_so_far)
else:
log_func = None
losses[self.module_key].compute(
ComputationState(modules, accumulated_modules, batch),
log_func).backward()
if self.max_grad_norm is not None:
clip_grad_norm_(module.parameters(), self.max_grad_norm)
optimizers[self.module_key].step()
class SingleNetworkValidationProtocol(ValidationProtocol):
def __init__(
self,
example_per_validation_iteration: int,
batch_size: int,
module_key: str = KEY_NETWORK):
super().__init__()
self.module_key = module_key
self.batch_size = batch_size
self.example_per_validation_iteration = example_per_validation_iteration
    def get_batch_size(self) -> int:
return self.batch_size
def get_examples_per_validation_iteration(self) -> int:
return self.example_per_validation_iteration
def run_validation_iteration(
self,
batch: Any,
examples_seen_so_far: int,
modules: Dict[str, Module],
accumulated_modules: Dict[str, Module],
losses: Dict[str, Loss],
create_log_func: Callable[[str, int], Callable[[str, float], None]],
device: torch.device):
module = modules[self.module_key]
module.train(False)
with torch.no_grad():
log_func = create_log_func("validation_" + self.module_key, examples_seen_so_far)
losses[self.module_key].compute(
ComputationState(modules, accumulated_modules, batch),
log_func)
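# A construction sketch with made-up numbers, assuming a piecewise-constant
# learning-rate schedule keyed by the default module name.
if __name__ == "__main__":
    from tha4.shion.base.optimizer_factories import AdamOptimizerFactory

    protocol = SingleNetworkTrainingProtocol(
        check_point_examples=[100_000, 500_000, 1_000_000],
        batch_size=8,
        learning_rate=lambda n: {KEY_NETWORK: 1e-4 if n < 500_000 else 1e-5},
        optimizer_factories={KEY_NETWORK: AdamOptimizerFactory()},
        max_grad_norm=1.0)
    print(protocol.get_learning_rate(750_000))  # {'network': 1e-05}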
================================================
FILE: src/tha4/shion/base/training/single_network_with_minibatch.py
================================================
from typing import List, Dict, Callable, Any, Optional
import torch
from torch.nn import Module
from torch.nn.utils import clip_grad_norm_
from torch.optim.optimizer import Optimizer
from tha4.shion.core.cached_computation import ComputationState
from tha4.shion.core.loss import Loss
from tha4.shion.core.optimizer_factory import OptimizerFactory
from tha4.shion.core.training.training_protocol import TrainingProtocol
from tha4.shion.core.training.validation_protocol import ValidationProtocol
KEY_NETWORK = "network"
class SingleNetworkWithMinibatchTrainingProtocol(TrainingProtocol):
def __init__(self,
check_point_examples: List[int],
batch_size: int,
minibatch_size: int,
learning_rate: Callable[[int], Dict[str, float]],
optimizer_factories: Dict[str, OptimizerFactory],
module_key: str = KEY_NETWORK,
random_seed: int = 39549059840,
max_grad_norm: Optional[float] = None):
super().__init__()
assert batch_size % minibatch_size == 0
self.minibatch_size = minibatch_size
self.max_grad_norm = max_grad_norm
self.module_key = module_key
self.optimizer_factories = optimizer_factories
self.learning_rate = learning_rate
self.batch_size = batch_size
self.random_seed = random_seed
self.check_point_examples = check_point_examples
def get_optimizer_factories(self) -> Dict[str, OptimizerFactory]:
return self.optimizer_factories
def get_checkpoint_examples(self) -> List[int]:
return self.check_point_examples
def get_random_seed(self) -> int:
return self.random_seed
def get_batch_size(self) -> int:
return self.batch_size
def get_learning_rate(self, examples_seen_so_far: int) -> Dict[str, float]:
return self.learning_rate(examples_seen_so_far)
def run_training_iteration(
self,
batch: Any,
examples_seen_so_far: int,
modules: Dict[str, Module],
accumulated_modules: Dict[str, Module],
optimizers: Dict[str, Optimizer],
losses: Dict[str, Loss],
create_log_func: Optional[Callable[[str, int], Callable[[str, float], None]]],
device: torch.device):
module = modules[self.module_key]
module.train(True)
optimizers[self.module_key].zero_grad(set_to_none=True)
if create_log_func is not None:
log_func = create_log_func("training_" + self.module_key, examples_seen_so_far)
else:
log_func = None
num_minibatch = self.batch_size // self.minibatch_size
for minibatch_index in range(num_minibatch):
minibatch = []
for item in batch:
minibatch.append(
item[minibatch_index * self.minibatch_size:(minibatch_index + 1) * self.minibatch_size])
loss = losses[self.module_key].compute(
ComputationState(modules, accumulated_modules, minibatch),
log_func if minibatch_index == 0 else None)
loss = loss / num_minibatch
loss.backward()
if self.max_grad_norm is not None:
clip_grad_norm_(module.parameters(), self.max_grad_norm)
optimizers[self.module_key].step()
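# A sketch of the gradient-accumulation arithmetic above: with batch_size=32
# and minibatch_size=8, one iteration runs four forward/backward passes, scales
# each minibatch loss by 1/4, and steps the optimizer once, matching a single
# pass over the full batch in expectation.
if __name__ == "__main__":
    batch_size, minibatch_size = 32, 8
    num_minibatch = batch_size // minibatch_size
    for minibatch_index in range(num_minibatch):
        start = minibatch_index * minibatch_size
        end = (minibatch_index + 1) * minibatch_size
        print(f"pass {minibatch_index}: items [{start}, {end}), loss scaled by 1/{num_minibatch}")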
================================================
FILE: src/tha4/shion/base/training/two_networks_training_protocol.py
================================================
from typing import List, Dict, Callable, Any, Optional
import torch
from torch.nn import Module
from torch.nn.utils import clip_grad_norm_
from torch.optim.optimizer import Optimizer
from tha4.shion.core.cached_computation import ComputationState
from tha4.shion.core.loss import Loss
from tha4.shion.core.optimizer_factory import OptimizerFactory
from tha4.shion.core.training.training_protocol import TrainingProtocol
class TwoNetworksWithMinibatchTrainingProtocol(TrainingProtocol):
def __init__(self,
check_point_examples: List[int],
batch_size: int,
learning_rate: Callable[[int], Dict[str, float]],
optimizer_factories: Dict[str, OptimizerFactory],
key_network_0: str,
key_network_1: str,
train_network_0: bool = False,
random_seed: int = 39549059840,
max_grad_norm: Optional[float] = None,
minibatch_size: Optional[int] = None):
super().__init__()
if minibatch_size is None:
minibatch_size = batch_size
assert batch_size % minibatch_size == 0
self.train_network_0 = train_network_0
self.key_network_1 = key_network_1
self.key_network_0 = key_network_0
self.minibatch_size = minibatch_size
self.max_grad_norm = max_grad_norm
self.optimizer_factories = optimizer_factories
self.learning_rate = learning_rate
self.batch_size = batch_size
self.random_seed = random_seed
self.check_point_examples = check_point_examples
def get_optimizer_factories(self) -> Dict[str, OptimizerFactory]:
return self.optimizer_factories
def get_checkpoint_examples(self) -> List[int]:
return self.check_point_examples
def get_random_seed(self) -> int:
return self.random_seed
def get_batch_size(self) -> int:
return self.batch_size
def get_learning_rate(self, examples_seen_so_far: int) -> Dict[str, float]:
return self.learning_rate(examples_seen_so_far)
def run_training_iteration(
self,
batch: Any,
examples_seen_so_far: int,
modules: Dict[str, Module],
accumulated_modules: Dict[str, Module],
optimizers: Dict[str, Optimizer],
losses: Dict[str, Loss],
create_log_func: Optional[Callable[[str, int], Callable[[str, float], None]]],
device: torch.device):
network_0 = modules[self.key_network_0]
network_0.train(self.train_network_0)
network_1 = modules[self.key_network_1]
network_1.train(True)
if self.train_network_0:
optimizers[self.key_network_0].zero_grad(set_to_none=True)
optimizers[self.key_network_1].zero_grad(set_to_none=True)
if create_log_func is not None:
network_0_log_func = create_log_func("training_" + self.key_network_0, examples_seen_so_far)
network_1_log_func = create_log_func("training_" + self.key_network_1, examples_seen_so_far)
else:
network_0_log_func = None
network_1_log_func = None
num_minibatch = self.batch_size // self.minibatch_size
for minibatch_index in range(num_minibatch):
minibatch = []
for item in batch:
minibatch.append(
item[minibatch_index * self.minibatch_size:(minibatch_index + 1) * self.minibatch_size])
loss = losses[self.key_network_1].compute(
ComputationState(modules, accumulated_modules, minibatch),
network_1_log_func if minibatch_index == 0 else None)
if self.train_network_0 and self.key_network_0 in losses:
loss = loss + losses[self.key_network_0].compute(
ComputationState(modules, accumulated_modules, minibatch),
network_0_log_func if minibatch_index == 0 else None)
loss = loss / num_minibatch
loss.backward()
if self.max_grad_norm is not None:
clip_grad_norm_(network_1.parameters(), self.max_grad_norm)
if self.train_network_0:
clip_grad_norm_(network_0.parameters(), self.max_grad_norm)
optimizers[self.key_network_1].step()
if self.train_network_0:
optimizers[self.key_network_0].step()
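# A construction sketch for a frozen-teacher setup (train_network_0=False):
# only the student's optimizer steps, while the teacher stays in eval mode.
# The key names, schedule, and sizes below are illustrative.
if __name__ == "__main__":
    from tha4.shion.base.optimizer_factories import AdamOptimizerFactory

    protocol = TwoNetworksWithMinibatchTrainingProtocol(
        check_point_examples=[100_000],
        batch_size=16,
        learning_rate=lambda n: {"student": 1e-4},
        optimizer_factories={"student": AdamOptimizerFactory()},
        key_network_0="teacher",
        key_network_1="student",
        minibatch_size=4)
    print(protocol.get_batch_size() // protocol.minibatch_size)  # 4 passes per step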
================================================
FILE: src/tha4/shion/core/__init__.py
================================================
================================================
FILE: src/tha4/shion/core/cached_computation.py
================================================
from abc import ABC, abstractmethod
from typing import Callable, Dict, Any, Optional
import torch
from torch import Tensor
from torch.nn import Module
class ComputationState:
def __init__(self,
modules: Dict[str, Module],
accumulated_modules: Dict[str, Module],
batch: Any,
outputs: Optional[Dict[str, Any]] = None):
if outputs is None:
outputs = {}
self.outputs = outputs
self.batch = batch
self.accumulated_modules = accumulated_modules
self.modules = modules
CachedComputationFunc = Callable[[ComputationState], Any]
TensorCachedComputationFunc = Callable[[ComputationState], Tensor]
def create_get_item_func(func: CachedComputationFunc, index):
def _f(state: ComputationState):
output = func(state)
return output[index]
return _f
def create_batch_element_func(index: int) -> TensorCachedComputationFunc:
def _f(state: ComputationState) -> Tensor:
return state.batch[index]
return _f
class CachedComputationProtocol(ABC):
def get_output(self, key: str, state: ComputationState) -> Any:
if key in state.outputs:
return state.outputs[key]
else:
output = self.compute_output(key, state)
state.outputs[key] = output
return state.outputs[key]
@abstractmethod
def compute_output(self, key: str, state: ComputationState) -> Any:
pass
def get_output_func(self, key: str) -> CachedComputationFunc:
def func(state: ComputationState):
return self.get_output(key, state)
return func
ComposableCachedComputationStep = Callable[[CachedComputationProtocol, ComputationState], Any]
class ComposableCachedComputationProtocol(CachedComputationProtocol):
def __init__(self, computation_steps: Optional[Dict[str, ComposableCachedComputationStep]] = None):
if computation_steps is None:
computation_steps = {}
self.computation_steps = computation_steps
def compute_output(self, key: str, state: ComputationState) -> Any:
if key in self.computation_steps:
return self.computation_steps[key](self, state)
else:
raise RuntimeError("Computing output for key " + key + " is not supported!")
def batch_indexing_func(index: int):
def _f(protocol: CachedComputationProtocol, state: ComputationState):
return state.batch[index]
return _f
def proxy_func(key: str):
def _f(protocol: CachedComputationProtocol, state: ComputationState):
return protocol.get_output(key, state)
return _f
def output_array_indexing_func(key: str, index: int):
def _f(protocol: CachedComputationProtocol, state: ComputationState):
return protocol.get_output(key, state)[index]
return _f
def add_step(step_dict: Dict[str, ComposableCachedComputationStep], name: str):
def _f(func):
step_dict[name] = func
return func
return _f
def zeros_like_func(key: str):
def _f(protocol: CachedComputationProtocol, state: ComputationState):
prototype = protocol.get_output(key, state)
return torch.zeros_like(prototype)
return _f
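# A minimal sketch of a composable protocol: add_step registers named steps,
# batch_indexing_func pulls raw inputs out of the batch, and every computed
# value is cached in state.outputs by key. The step names are made up.
if __name__ == "__main__":
    steps: Dict[str, ComposableCachedComputationStep] = {}
    steps["input"] = batch_indexing_func(0)

    @add_step(steps, "double")
    def _double(protocol: CachedComputationProtocol, state: ComputationState):
        return 2 * protocol.get_output("input", state)

    protocol = ComposableCachedComputationProtocol(steps)
    state = ComputationState({}, {}, [torch.ones(2)])
    print(protocol.get_output("double", state))  # tensor([2., 2.])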
================================================
FILE: src/tha4/shion/core/load_save.py
================================================
import os
import torch
def torch_save(content, file_name):
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, 'wb') as f:
torch.save(content, f)
def torch_load(file_name):
with open(file_name, 'rb') as f:
return torch.load(f, map_location=lambda storage, loc: storage)
================================================
FILE: src/tha4/shion/core/loss.py
================================================
from abc import ABC, abstractmethod
from typing import Callable, Optional
from torch import Tensor
from tha4.shion.core.cached_computation import ComputationState
class Loss(ABC):
@abstractmethod
def compute(
self,
state: ComputationState,
log_func: Optional[Callable[[str, float], None]] = None) -> Tensor:
pass
================================================
FILE: src/tha4/shion/core/module_accumulator.py
================================================
from abc import ABC, abstractmethod
from typing import Optional
from torch.nn import Module
class ModuleAccumulator(ABC):
@abstractmethod
def accumulate(self, module: Module, output: Module, examples_seen_so_far: Optional[int] = None) -> Module:
pass
================================================
FILE: src/tha4/shion/core/module_factory.py
================================================
from abc import ABC, abstractmethod
from torch.nn import Module
class ModuleFactory(ABC):
@abstractmethod
def create(self) -> Module:
pass
================================================
FILE: src/tha4/shion/core/optimizer_factory.py
================================================
from abc import ABC, abstractmethod
from typing import Iterable
from torch.nn import Parameter
class OptimizerFactory(ABC):
@abstractmethod
def create(self, parameters: Iterable[Parameter]):
pass
================================================
FILE: src/tha4/shion/core/training/__init__.py
================================================
================================================
FILE: src/tha4/shion/core/training/distrib/__init__.py
================================================
================================================
FILE: src/tha4/shion/core/training/distrib/device_mapper.py
================================================
from typing import Dict
import torch
class SimpleCudaDeviceMapper:
def __call__(self, rank, local_rank):
return torch.device("cuda", local_rank)
class UserSpecifiedLocalRankToDeviceMapper:
def __init__(self, device_map: Dict[int, torch.device]):
self.device_map = device_map
def __call__(self, rank, local_rank):
assert local_rank in self.device_map
return self.device_map[local_rank]
================================================
FILE: src/tha4/shion/core/training/distrib/distributed_trainer.py
================================================
import argparse
import logging
import os.path
import time
from datetime import datetime
from typing import Dict, Optional, Callable, Any
import torch
import torch.distributed
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from tha4.shion.core.load_save import torch_save, torch_load
from tha4.shion.core.loss import Loss
from tha4.shion.core.module_accumulator import ModuleAccumulator
from tha4.shion.core.module_factory import ModuleFactory
from tha4.shion.core.training.distrib.device_mapper import SimpleCudaDeviceMapper
from tha4.shion.core.training.distrib.distributed_training_states import DistributedTrainingState
from tha4.shion.core.training.sample_output_protocol import SampleOutputProtocol
from tha4.shion.core.training.training_protocol import TrainingProtocol
from tha4.shion.core.training.util import set_learning_rate, create_log_func, get_least_greater_multiple
from tha4.shion.core.training.validation_protocol import ValidationProtocol
KEY_CHECKPOINT = 'checkpoint'
KEY_SNAPSHOT = 'snapshot'
KEY_VALIDATION = 'validation'
KEY_SAMPLE_OUTPUT = 'sample_output'
class DistributedTrainer:
def __init__(self,
prefix: str,
module_factories: Dict[str, ModuleFactory],
accumulators: Dict[str, ModuleAccumulator],
losses: Dict[str, Loss],
training_dataset: Dataset,
validation_dataset: Optional[Dataset],
training_protocol: TrainingProtocol,
validation_protocol: Optional[ValidationProtocol],
sample_output_protocol: Optional[SampleOutputProtocol],
pretrained_module_file_names: Dict[str, str],
example_per_snapshot: int,
num_data_loader_workers: int = 8,
distrib_backend: str = 'gloo'):
self.distrib_backend = distrib_backend
self.num_data_loader_workers = num_data_loader_workers
self.accumulators = accumulators
self.sample_output_protocol = sample_output_protocol
self.example_per_snapshot = example_per_snapshot
self.pretrained_module_file_names = pretrained_module_file_names
self.losses = losses
self.validation_protocol = validation_protocol
self.training_protocol = training_protocol
self.module_factories = module_factories
self.prefix = prefix
self.training_dataset = training_dataset
self.validation_dataset = validation_dataset
self.checkpoint_examples = self.training_protocol.get_checkpoint_examples()
assert len(self.checkpoint_examples) >= 1
assert self.checkpoint_examples[0] > 0
self.checkpoint_examples = [0] + self.checkpoint_examples
self.module_names = self.module_factories.keys()
assert len(self.module_names) > 0
self.training_data_loader = None
self.training_data_loader_iter = None
self.training_data_loader_batch_size = None
self.training_data_sampler = None
self.validation_data_loader = None
self.validation_data_loader_iter = None
self.validation_data_loader_batch_size = None
self.sample_output_data = None
self.summary_writer = None
self.log_dir = None
self.training_state = None
def get_sample_output_data_file_name(self):
return self.prefix + "/sample_output_data.pt"
def save_sample_output_data(self, rank: int, device: torch.device):
if rank != 0:
return
if os.path.exists(self.get_sample_output_data_file_name()):
return
if self.sample_output_protocol is not None:
torch.manual_seed(self.sample_output_protocol.get_random_seed())
sample_output_data = self.sample_output_protocol.get_sample_output_data(self.validation_dataset, device)
torch_save(sample_output_data, self.get_sample_output_data_file_name())
else:
torch_save({}, self.get_sample_output_data_file_name())
def load_sample_output_data(self, rank: int, device: torch.device):
if rank != 0:
return None
else:
self.save_sample_output_data(rank, device)
return torch_load(self.get_sample_output_data_file_name())
def get_snapshot_prefix(self) -> str:
return self.prefix + "/snapshot"
def can_load_training_state(self, prefix: str, world_size: int) -> bool:
return DistributedTrainingState.can_load(
prefix,
self.module_factories,
self.accumulators,
self.training_protocol.get_optimizer_factories(),
world_size)
def load_training_state(self, prefix, rank: int, local_rank: int, device: torch.device) -> DistributedTrainingState:
return DistributedTrainingState.load(
prefix,
self.module_factories,
self.accumulators,
self.training_protocol.get_optimizer_factories(),
rank,
local_rank,
device)
@staticmethod
def checkpoint_prefix(prefix: str, checkpoint_index: int) -> str:
return "%s/checkpoint/%04d" % (prefix, checkpoint_index)
def get_checkpoint_prefix(self, checkpoint_index) -> str:
return DistributedTrainer.checkpoint_prefix(self.prefix, checkpoint_index)
def get_initial_training_state(self, rank: int, local_rank: int, device: torch.device) -> DistributedTrainingState:
training_state = DistributedTrainingState.new(
self.module_factories,
self.accumulators,
self.training_protocol.get_optimizer_factories(),
self.training_protocol.get_random_seed(),
rank,
local_rank,
device,
self.pretrained_module_file_names)
logging.info("Created a new initial training state.")
return training_state
def load_previous_training_state(self,
target_checkpoint_examples: int,
world_size: int,
rank: int,
local_rank: int,
device: torch.device) -> DistributedTrainingState:
if self.can_load_training_state(self.get_snapshot_prefix(), world_size):
examples_seen_so_far = DistributedTrainingState.get_examples_seen_so_far(self.get_snapshot_prefix())
diff = examples_seen_so_far - target_checkpoint_examples
if diff < self.training_protocol.get_batch_size():
return self.load_training_state(self.get_snapshot_prefix(), rank, local_rank, device)
num_checkpoints = len(self.checkpoint_examples)
for checkpoint_index in range(num_checkpoints - 1, -1, -1):
if self.can_load_training_state(self.get_checkpoint_prefix(checkpoint_index), world_size):
examples_seen_so_far = DistributedTrainingState.get_examples_seen_so_far(
self.get_checkpoint_prefix(checkpoint_index))
diff = examples_seen_so_far - target_checkpoint_examples
if diff < self.training_protocol.get_batch_size():
return self.load_training_state(
self.get_checkpoint_prefix(checkpoint_index), rank, local_rank, device)
training_state = self.get_initial_training_state(rank, local_rank, device)
training_state.save(self.get_checkpoint_prefix(0), rank, lambda: self.barrier(local_rank))
training_state = self.load_training_state(self.get_checkpoint_prefix(0), rank, local_rank, device)
return training_state
def get_log_dir(self):
if self.log_dir is None:
now = datetime.now()
self.log_dir = self.prefix + "/log/" + now.strftime("%Y_%m_%d__%H_%M_%S")
return self.log_dir
def get_summary_writer(self, rank: int) -> Optional[SummaryWriter]:
if rank != 0:
return None
if self.summary_writer is None:
self.summary_writer = SummaryWriter(log_dir=self.get_log_dir())
return self.summary_writer
def get_effective_training_epoch_size(self, world_size: int):
batch_size = self.training_protocol.get_batch_size()
N = len(self.training_dataset)
N = (N // world_size) * world_size
N = (N // batch_size) * batch_size
return N
def get_training_epoch_index(self, examples_seen_so_far: int, world_size: int):
epoch_size = self.get_effective_training_epoch_size(world_size)
batch_size = self.training_protocol.get_batch_size()
return (examples_seen_so_far + batch_size * world_size) // epoch_size
def get_next_training_batch(self, examples_seen_so_far: int, world_size: int, device: torch.device):
batch_size = self.training_protocol.get_batch_size()
dataset = self.training_dataset
if self.training_data_loader is None:
self.training_data_sampler = DistributedSampler(
dataset,
shuffle=True,
drop_last=True)
self.training_data_loader = DataLoader(
dataset,
batch_size=batch_size,
sampler=self.training_data_sampler,
shuffle=False,
num_workers=self.num_data_loader_workers,
drop_last=True)
if self.training_data_loader_iter is None:
epoch_index = self.get_training_epoch_index(examples_seen_so_far, world_size)
logging.info(f"Started a new epoch: index = {epoch_index}, examples_seen_so_far = {examples_seen_so_far}")
self.training_data_sampler.set_epoch(epoch_index)
self.training_data_loader_iter = iter(self.training_data_loader)
try:
batch = next(self.training_data_loader_iter)
except StopIteration:
epoch_index = self.get_training_epoch_index(examples_seen_so_far, world_size)
logging.info(f"Started a new epoch: index = {epoch_index}, examples_seen_so_far = {examples_seen_so_far}")
self.training_data_sampler.set_epoch(epoch_index)
self.training_data_loader_iter = iter(self.training_data_loader)
batch = next(self.training_data_loader_iter)
return [x.to(device) for x in batch]
def get_next_checkpoint_num_examples(self, examples_seen_so_far) -> int:
next_index = next(
(i for i in range(len(self.checkpoint_examples)) if self.checkpoint_examples[i] > examples_seen_so_far),
-1)
return self.checkpoint_examples[next_index]
def get_next_snapshot_num_examples(self, examples_seen_so_far) -> int:
return get_least_greater_multiple(examples_seen_so_far, self.example_per_snapshot)
def get_next_validation_num_examples(self, examples_seen_so_far) -> int:
if self.validation_protocol is None:
return -1
return get_least_greater_multiple(examples_seen_so_far,
self.validation_protocol.get_examples_per_validation_iteration())
def get_next_sample_output_num_examples(self, examples_seen_so_far) -> int:
if self.sample_output_protocol is None:
return -1
return get_least_greater_multiple(examples_seen_so_far,
self.sample_output_protocol.get_examples_per_sample_output())
def get_next_num_examples(self, examples_seen_so_far) -> Dict[str, int]:
return {
KEY_CHECKPOINT: self.get_next_checkpoint_num_examples(examples_seen_so_far),
KEY_SNAPSHOT: self.get_next_snapshot_num_examples(examples_seen_so_far),
KEY_VALIDATION: self.get_next_validation_num_examples(examples_seen_so_far),
KEY_SAMPLE_OUTPUT: self.get_next_sample_output_num_examples(examples_seen_so_far)
}
def get_next_validation_batch(self, device: torch.device):
if self.validation_dataset is None:
return None
if self.validation_data_loader is None:
self.validation_data_loader = DataLoader(
self.validation_dataset,
batch_size=self.validation_protocol.get_batch_size(),
shuffle=True,
num_workers=1,
drop_last=True)
if self.validation_data_loader_iter is None:
self.validation_data_loader_iter = iter(self.validation_data_loader)
try:
batch = next(self.validation_data_loader_iter)
except StopIteration:
self.validation_data_loader_iter = iter(self.validation_data_loader)
batch = next(self.validation_data_loader_iter)
return [x.to(device) for x in batch]
def get_checkpoint_index_to_save(self, examples_seen_so_far: int) -> int:
checkpoint_index = 0
for i in range(len(self.checkpoint_examples)):
if self.checkpoint_examples[i] <= examples_seen_so_far:
checkpoint_index = i
return checkpoint_index
def barrier(self, local_rank: int):
if self.distrib_backend == 'nccl':
torch.distributed.barrier(device_ids=[local_rank])
else:
torch.distributed.barrier()
def train(self,
world_size: int,
rank: int,
local_rank: int,
target_checkpoint_examples: Optional[int] = None,
device_mapper: Optional[Callable[[int, int], torch.device]] = None):
if target_checkpoint_examples is None:
target_checkpoint_examples = self.checkpoint_examples[-1]
if device_mapper is None:
device_mapper = SimpleCudaDeviceMapper()
device = device_mapper(rank, local_rank)
sample_output_data = self.load_sample_output_data(rank, device)
training_state = self.load_previous_training_state(
target_checkpoint_examples, world_size, rank, local_rank, device)
summary_writer = self.get_summary_writer(rank)
if summary_writer is not None:
log_func_factory = lambda name, num: create_log_func(summary_writer, name, num)
else:
log_func_factory = None
last_time = time.time()
while training_state.examples_seen_so_far < target_checkpoint_examples:
# Set the learning rate
learning_rate_by_module_name = self.training_protocol.get_learning_rate(training_state.examples_seen_so_far)
for module_name in self.module_factories.keys():
if module_name not in learning_rate_by_module_name or module_name not in training_state.optimizers:
continue
lr = learning_rate_by_module_name[module_name]
set_learning_rate(training_state.optimizers[module_name], lr)
if summary_writer is not None:
summary_writer.add_scalar(
module_name + "_learning_rate", lr, training_state.examples_seen_so_far)
# One training iteration
training_batch = self.get_next_training_batch(training_state.examples_seen_so_far, world_size, device)
self.training_protocol.run_training_iteration(
training_batch,
training_state.examples_seen_so_far,
training_state.modules,
training_state.accumulated_modules,
training_state.optimizers,
self.losses,
log_func_factory,
device)
# Accumulate model data
for module_name in self.accumulators:
new_module = training_state.modules[module_name]
if isinstance(new_module, DistributedDataParallel):
new_module = new_module.module
buffer_module = training_state.accumulated_modules[module_name]
self.accumulators[module_name].accumulate(
new_module, buffer_module, examples_seen_so_far=training_state.examples_seen_so_far)
# Advance the number of examples seen so far
next_num_examples = self.get_next_num_examples(training_state.examples_seen_so_far)
training_state.examples_seen_so_far += self.training_protocol.get_batch_size() * world_size
# Validation iteration
if self.validation_protocol is not None \
and training_state.examples_seen_so_far >= next_num_examples[KEY_VALIDATION] \
and rank == 0:
validation_batch = self.get_next_validation_batch(device)
self.validation_protocol.run_validation_iteration(
validation_batch,
training_state.examples_seen_so_far,
training_state.modules,
training_state.accumulated_modules,
self.losses,
log_func_factory,
device)
# Save sample output
if self.sample_output_protocol is not None \
and training_state.examples_seen_so_far >= next_num_examples[KEY_SAMPLE_OUTPUT]:
if rank == 0:
self.sample_output_protocol.save_sample_output_data(
training_state.modules,
training_state.accumulated_modules,
sample_output_data,
self.prefix + "/sample_outputs",
training_state.examples_seen_so_far,
device)
self.barrier(local_rank)
# Save checkpoint
if training_state.examples_seen_so_far >= next_num_examples[KEY_CHECKPOINT]:
checkpoint_index = self.get_checkpoint_index_to_save(training_state.examples_seen_so_far)
training_state.save(
self.get_checkpoint_prefix(checkpoint_index), rank, lambda: self.barrier(local_rank))
if next_num_examples[KEY_CHECKPOINT] != next_num_examples[KEY_SNAPSHOT]:
training_state.save(self.get_snapshot_prefix(), rank, lambda: self.barrier(local_rank))
# Save snapshot
if training_state.examples_seen_so_far >= next_num_examples[KEY_SNAPSHOT]:
training_state.save(self.get_snapshot_prefix(), rank, lambda: self.barrier(local_rank))
now = time.time()
if now - last_time > 10:
logging.info("Showed %d training examples." % training_state.examples_seen_so_far)
last_time = now
@staticmethod
def get_default_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description='Training script.')
parser.add_argument("--target_checkpoint_examples", type=int)
return parser
@staticmethod
def run_with_args(trainer_factory: Callable[[int, str], 'DistributedTrainer'],
args,
backend: str = 'gloo',
device_mapper: Optional[Callable[[int, int], torch.device]] = None):
world_size = int(os.environ['WORLD_SIZE'])
rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
torch.distributed.init_process_group(backend)
trainer = trainer_factory(world_size, backend)
trainer.train(world_size, rank, local_rank, args.target_checkpoint_examples, device_mapper)
@staticmethod
def run(trainer_factory: Callable[[int, str], 'DistributedTrainer'],
backend: str = 'gloo',
device_mapper: Optional[Callable[[int, int], torch.device]] = None,
args: Optional[Any] = None):
if args is None:
parser = DistributedTrainer.get_default_arg_parser()
args = parser.parse_args()
DistributedTrainer.run_with_args(trainer_factory, args, backend, device_mapper)
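# An end-to-end sketch with made-up names and sizes, meant to be launched as
#   torchrun --standalone --nproc_per_node=1 this_script.py
# It regresses a Linear model onto row sums, and assumes one CUDA device since
# the default device mapper targets cuda:<local_rank>.
if __name__ == "__main__":
    from torch.nn import Linear
    from torch.utils.data import TensorDataset
    from tha4.shion.base.loss.l2_loss import L2Loss
    from tha4.shion.base.optimizer_factories import AdamOptimizerFactory
    from tha4.shion.base.protocol.single_network_from_batch_input_computation_protocol import \
        SingleNetworkBatchInputComputationProtocol, KEY_NETWORK, KEY_NETWORK_OUTPUT
    from tha4.shion.base.training.single_network import SingleNetworkTrainingProtocol
    from tha4.shion.core.cached_computation import create_batch_element_func
    from tha4.shion.core.module_factory import ModuleFactory

    class LinearFactory(ModuleFactory):
        def create(self):
            return Linear(4, 1)

    computation = SingleNetworkBatchInputComputationProtocol()
    inputs = torch.rand(256, 4)
    trainer = DistributedTrainer(
        prefix="data/_linear_example",
        module_factories={KEY_NETWORK: LinearFactory()},
        accumulators={},
        losses={KEY_NETWORK: L2Loss(
            expected_func=create_batch_element_func(1),
            actual_func=computation.get_output_func(KEY_NETWORK_OUTPUT))},
        training_dataset=TensorDataset(inputs, inputs.sum(dim=1, keepdim=True)),
        validation_dataset=None,
        training_protocol=SingleNetworkTrainingProtocol(
            check_point_examples=[1024],
            batch_size=32,
            learning_rate=lambda n: {KEY_NETWORK: 1e-3},
            optimizer_factories={KEY_NETWORK: AdamOptimizerFactory()}),
        validation_protocol=None,
        sample_output_protocol=None,
        pretrained_module_file_names={},
        example_per_snapshot=512)
    DistributedTrainer.run(lambda world_size, backend: trainer, backend='gloo')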
================================================
FILE: src/tha4/shion/core/training/distrib/distributed_training_states.py
================================================
import copy
import logging
import os
from typing import Dict, Optional, Callable
import torch
from torch.nn import Module
from torch.nn.parallel import DistributedDataParallel
from torch.optim.optimizer import Optimizer
from tha4.shion.core.load_save import torch_save, torch_load
from tha4.shion.core.module_accumulator import ModuleAccumulator
from tha4.shion.core.module_factory import ModuleFactory
from tha4.shion.core.optimizer_factory import OptimizerFactory
from tha4.shion.core.training.util import optimizer_to_device
class DistributedTrainingState:
def __init__(self,
examples_seen_so_far: int,
modules: Dict[str, Module],
accumulated_modules: Dict[str, Module],
optimizers: Dict[str, Optimizer]):
self.accumulated_modules = accumulated_modules
self.optimizers = optimizers
self.modules = modules
self.examples_seen_so_far = examples_seen_so_far
@staticmethod
def get_examples_seen_so_far_file_name(prefix) -> str:
return prefix + "/examples_seen_so_far.txt"
@staticmethod
def get_module_file_name(prefix, module_name) -> str:
return "%s/module_%s.pt" % (prefix, module_name)
@staticmethod
def get_accumulated_module_file_name(prefix, module_name) -> str:
return "%s/accumulated_%s.pt" % (prefix, module_name)
@staticmethod
def get_optimizer_file_name(prefix, module_name) -> str:
return "%s/optimizer_%s.pt" % (prefix, module_name)
@staticmethod
def get_rng_state_file_name(prefix, rank: int):
return "%s/rng_state_%08d.pt" % (prefix, rank)
def mkdir(self, prefix: str):
os.makedirs(prefix, exist_ok=True)
def save_data(self, prefix: str, rank: int):
assert os.path.exists(prefix)
torch_save(torch.get_rng_state(), DistributedTrainingState.get_rng_state_file_name(prefix, rank))
logging.info("Saved %s" % DistributedTrainingState.get_rng_state_file_name(prefix, rank))
if rank == 0:
logging.info("Saving training state to %s" % prefix)
with open(DistributedTrainingState.get_examples_seen_so_far_file_name(prefix), "wt") as fout:
fout.write("%d\n" % self.examples_seen_so_far)
logging.info("Saved %s" % DistributedTrainingState.get_examples_seen_so_far_file_name(prefix))
for module_name in self.modules:
file_name = DistributedTrainingState.get_module_file_name(prefix, module_name)
module = self.modules[module_name]
if isinstance(module, DistributedDataParallel):
state_dict = module.module.state_dict()
else:
state_dict = module.state_dict()
torch_save(state_dict, file_name)
logging.info("Saved %s" % file_name)
for module_name in self.accumulated_modules:
file_name = DistributedTrainingState.get_accumulated_module_file_name(prefix, module_name)
torch_save(self.accumulated_modules[module_name].state_dict(), file_name)
logging.info("Saved %s" % file_name)
for module_name in self.optimizers:
file_name = DistributedTrainingState.get_optimizer_file_name(prefix, module_name)
torch_save(self.optimizers[module_name].state_dict(), file_name)
logging.info("Saved %s" % file_name)
logging.info("Done saving training state to %s" % prefix)
def save(self, prefix: str, rank: int, barrier_func: Callable[[], None]):
if rank == 0:
self.mkdir(prefix)
barrier_func()
self.save_data(prefix, rank)
barrier_func()
@staticmethod
def get_examples_seen_so_far(prefix: str) -> int:
with open(DistributedTrainingState.get_examples_seen_so_far_file_name(prefix)) as fin:
lines = fin.readlines()
return int(lines[0])
@staticmethod
def load(
prefix: str,
module_factories: Dict[str, ModuleFactory],
accumulators: Dict[str, ModuleAccumulator],
optimizer_factories: Dict[str, OptimizerFactory],
rank: int,
local_rank: int,
device: torch.device) -> 'DistributedTrainingState':
logging.info("Loading training state from %s" % prefix)
with open(DistributedTrainingState.get_examples_seen_so_far_file_name(prefix)) as fin:
lines = fin.readlines()
examples_seen_so_far = int(lines[0])
logging.info("Loaded %s" % DistributedTrainingState.get_examples_seen_so_far_file_name(prefix))
modules = {
module_name: factory.create()
for (module_name, factory) in module_factories.items()
}
for module_name in modules:
file_name = DistributedTrainingState.get_module_file_name(prefix, module_name)
module = modules[module_name]
state_dict = torch_load(file_name)
module.load_state_dict(state_dict)
module.to(device)
modules[module_name] = DistributedDataParallel(
module,
device_ids=[device.index],
output_device=device.index)
logging.info("Loaded %s" % file_name)
accumulated_modules = {}
for module_name in accumulators:
module_factory = module_factories[module_name]
module = module_factory.create()
file_name = DistributedTrainingState.get_accumulated_module_file_name(prefix, module_name)
module.load_state_dict(torch_load(file_name))
module.to(device)
accumulated_modules[module_name] = module
logging.info("Loaded %s" % file_name)
optimizers = {}
for module_name in optimizer_factories:
optimizer = optimizer_factories[module_name].create(modules[module_name].parameters())
file_name = DistributedTrainingState.get_optimizer_file_name(prefix, module_name)
optimizer.load_state_dict(torch_load(file_name))
optimizer_to_device(optimizer, device)
optimizers[module_name] = optimizer
logging.info("Loaded %s" % file_name)
torch.set_rng_state(torch_load(DistributedTrainingState.get_rng_state_file_name(prefix, rank)))
logging.info("Loaded %s" % DistributedTrainingState.get_rng_state_file_name(prefix, rank))
logging.info("Done loading training state from %s" % prefix)
return DistributedTrainingState(examples_seen_so_far, modules, accumulated_modules, optimizers)
@staticmethod
def new(module_factories: Dict[str, ModuleFactory],
accumulators: Dict[str, ModuleAccumulator],
optimizer_factories: Dict[str, OptimizerFactory],
random_seed: int,
rank: int,
local_rank: int,
device: torch.device,
pretrained_module_file_names: Optional[Dict[str, str]] = None) -> 'DistributedTrainingState':
examples_seen_so_far = 0
modules = {
module_name: factory.create()
for (module_name, factory) in module_factories.items()
}
for module_name in modules:
modules[module_name].to(device)
if pretrained_module_file_names is not None:
for module_name in modules:
if module_name in pretrained_module_file_names:
file_name = pretrained_module_file_names[module_name]
modules[module_name].load_state_dict(torch_load(file_name))
logging.info("Loaded initial state from %s ..." % file_name)
accumulated_modules = {}
for module_name in accumulators:
accumulated_modules[module_name] = copy.deepcopy(modules[module_name])
for module_name in modules:
module = modules[module_name]
modules[module_name] = DistributedDataParallel(
module,
device_ids=[device.index],
output_device=device.index)
optimizers = {}
for module_name in optimizer_factories:
module = modules[module_name]
optimizer = optimizer_factories[module_name].create(module.parameters())
optimizer_to_device(optimizer, device)
optimizers[module_name] = optimizer
torch.manual_seed(random_seed + rank)
return DistributedTrainingState(examples_seen_so_far, modules, accumulated_modules, optimizers)
@staticmethod
def can_load(prefix: str,
module_factories: Dict[str, ModuleFactory],
accumulators: Dict[str, ModuleAccumulator],
optimizer_factories: Dict[str, OptimizerFactory],
world_size: int) -> bool:
logging.info(f"Checking directory {prefix}")
if not os.path.isdir(prefix):
logging.info(f"Cannot load files in {prefix} because it is not a directory")
return False
examples_seen_so_far_file_name = DistributedTrainingState.get_examples_seen_so_far_file_name(prefix)
if not os.path.isfile(examples_seen_so_far_file_name):
logging.info(f"Cannot load files in {prefix} because {examples_seen_so_far_file_name} is not a file.")
return False
for module_name in module_factories.keys():
file_name = DistributedTrainingState.get_module_file_name(prefix, module_name)
if not os.path.isfile(file_name):
logging.info(f"Cannot load files in {prefix} because {file_name} is not a file.")
return False
for module_name in accumulators:
file_name = DistributedTrainingState.get_accumulated_module_file_name(prefix, module_name)
if not os.path.isfile(file_name):
logging.info(f"Cannot load files in {prefix} because {file_name} is not a file.")
return False
for module_name in optimizer_factories:
file_name = DistributedTrainingState.get_optimizer_file_name(prefix, module_name)
if not os.path.isfile(file_name):
logging.info(f"Cannot load files in {prefix} because {file_name} is not a file.")
return False
for rank in range(world_size):
file_name = DistributedTrainingState.get_rng_state_file_name(prefix, rank)
if not os.path.isfile(file_name):
logging.info(f"Cannot load files in {prefix} because {file_name} is not a file.")
return False
return True
================================================
FILE: src/tha4/shion/core/training/distrib/distributed_training_tasks.py
================================================
import logging
import os
import sys
from typing import Callable, List, Optional
from tha4.pytasuku.workspace import Workspace
from tha4.shion.core.training.distrib.distributed_trainer import DistributedTrainer
from tha4.shion.core.training.distrib.distributed_training_states import DistributedTrainingState
def get_torchrun_executable():
return os.path.dirname(sys.executable) + os.path.sep + "torchrun"
def run_distributed_training_script(
training_script_file_name: str,
num_nodes: int,
node_rank: int,
num_proc_per_node: int,
        master_addr: str = "127.0.0.1",
master_port: int = 8888):
command = f"{get_torchrun_executable()} " \
f"--nproc_per_node={num_proc_per_node} " \
f"--nnodes={num_nodes} " \
f"--node_rank={node_rank} " \
f"--master_addr={master_addr} " \
f"--master_port={master_port} " \
f"{training_script_file_name}"
logging.info(f"Executing -- {command}")
os.system(command)
class RdzvConfig:
def __init__(self, id: int, port: int):
self.port = port
self.id = id
def run_standalone_distributed_training_script(
training_script_file_name: str,
num_proc_per_node: int,
target_checkpoint_examples: Optional[int] = None,
rdzv_config: Optional[RdzvConfig] = None):
command = f"{get_torchrun_executable()} " \
f"--nnodes=1 " \
f"--nproc_per_node={num_proc_per_node} "
if rdzv_config is not None:
command += f"--rdzv_endpoint=localhost:{rdzv_config.port} "
command += "--rdzv_backend=c10d "
command += f"--rdzv_id={rdzv_config.id} "
else:
command += "--standalone "
command += f"{training_script_file_name} "
if target_checkpoint_examples is not None:
command += f"--target_checkpoint_examples {target_checkpoint_examples} "
logging.info(f"Executing -- {command}")
os.system(command)
def define_distributed_training_tasks(
workspace: Workspace,
prefix: str,
training_script_file_name: str,
num_nodes: int,
num_proc_per_node: int,
        master_addr: str = "127.0.0.1",
master_port: int = 8888):
def run_training_script_func(rank: int):
def _f():
run_distributed_training_script(
training_script_file_name,
num_nodes,
rank,
num_proc_per_node, master_addr,
master_port)
return _f
for i in range(num_nodes):
workspace.create_command_task(f"{prefix}/train_node_%06d" % i, [], run_training_script_func(i))
def define_standalone_distributed_training_tasks(
workspace: Workspace,
distributed_trainer_func: Callable[[int], DistributedTrainer],
training_script_file_name: str,
num_proc_per_node: int,
dependencies: Optional[List[str]] = None,
rdzv_config: Optional[RdzvConfig] = None):
trainer = distributed_trainer_func(1)
checkpoint_examples = trainer.training_protocol.get_checkpoint_examples()
assert len(checkpoint_examples) >= 1
assert checkpoint_examples[0] > 0
checkpoint_examples = [0] + checkpoint_examples
if dependencies is None:
dependencies = []
module_file_dependencies = dependencies[:]
for module_name in trainer.pretrained_module_file_names:
module_file_dependencies.append(trainer.pretrained_module_file_names[module_name])
def create_train_func(target_checkpoint_examples: int):
return lambda: run_standalone_distributed_training_script(
training_script_file_name,
num_proc_per_node,
target_checkpoint_examples,
rdzv_config=rdzv_config)
train_tasks = []
for checkpoint_index in range(0, len(checkpoint_examples)):
for module_name in trainer.module_names:
module_file_name = DistributedTrainingState.get_module_file_name(
trainer.get_checkpoint_prefix(checkpoint_index),
module_name)
workspace.create_file_task(
module_file_name,
module_file_dependencies,
                create_train_func(checkpoint_examples[checkpoint_index]))
for module_name in trainer.accumulators:
accumulated_module_file_name = DistributedTrainingState.get_accumulated_module_file_name(
trainer.get_checkpoint_prefix(checkpoint_index),
module_name)
workspace.create_file_task(
accumulated_module_file_name,
module_file_dependencies,
create_train_func(checkpoint_examples[checkpoint_index]))
workspace.create_command_task(
trainer.get_checkpoint_prefix(checkpoint_index) + "/train_standalone",
module_file_dependencies,
create_train_func(checkpoint_examples[checkpoint_index]))
train_tasks.append(trainer.get_checkpoint_prefix(checkpoint_index) + "/train_standlone")
    workspace.create_command_task(
trainer.prefix + "/train_standalone",
module_file_dependencies,
create_train_func(checkpoint_examples[-1]))
if __name__ == "__main__":
print(os.path.dirname(sys.executable) + os.path.sep + "torchrun")
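# For reference, with num_proc_per_node=2, target_checkpoint_examples=1000000,
# and no rdzv_config, run_standalone_distributed_training_script assembles a
# command of roughly this shape (paths and script name are placeholders):
#
#   /path/to/venv/bin/torchrun --nnodes=1 --nproc_per_node=2 --standalone \
#       my_training_script.py --target_checkpoint_examples 1000000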
================================================
FILE: src/tha4/shion/core/training/sample_output_protocol.py
================================================
from abc import ABC, abstractmethod
from typing import Dict, Any
import torch
from torch.nn import Module
from torch.utils.data import Dataset
class SampleOutputProtocol(ABC):
@abstractmethod
def get_examples_per_sample_output(self) -> int:
pass
@abstractmethod
def get_random_seed(self) -> int:
pass
@abstractmethod
def get_sample_output_data(self, validation_dataset: Dataset, device: torch.device) -> Any:
pass
@abstractmethod
def save_sample_output_data(
self,
modules: Dict[str, Module],
accumulated_modules: Dict[str, Module],
sample_output_data: Any,
prefix: str,
examples_seen_so_far: int,
device: torch.device):
pass
class AbstractSampleOutputProtocol(SampleOutputProtocol, ABC):
def __init__(self, examples_per_sample_output: int, random_seed: int):
self.random_seed = random_seed
self.examples_per_sample_output = examples_per_sample_output
def get_examples_per_sample_output(self) -> int:
return self.examples_per_sample_output
def get_random_seed(self) -> int:
return self.random_seed
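# A subclass sketch: AbstractSampleOutputProtocol supplies the bookkeeping, so
# a concrete protocol only decides what to snapshot and how to render it. The
# class below (batch choice, file naming) is illustrative, not from the repo.
class FirstExampleSampleOutputProtocol(AbstractSampleOutputProtocol):
    def get_sample_output_data(self, validation_dataset: Dataset, device: torch.device) -> Any:
        # Pin one fixed example so successive sample outputs stay comparable.
        return validation_dataset[0]

    def save_sample_output_data(self,
                                modules: Dict[str, Module],
                                accumulated_modules: Dict[str, Module],
                                sample_output_data: Any,
                                prefix: str,
                                examples_seen_so_far: int,
                                device: torch.device):
        from tha4.shion.core.load_save import torch_save
        # A real protocol would run the (accumulated) model on the pinned
        # example here and save rendered images instead of raw data.
        torch_save(sample_output_data, "%s/%012d.pt" % (prefix, examples_seen_so_far))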
================================================
FILE: src/tha4/shion/core/training/single/__init__.py
================================================
================================================
FILE: src/tha4/shion/core/training/single/training_states.py
================================================
import copy
import logging
import os
from typing import Dict, Optional
import torch
from torch.nn import Module
from torch.optim import Optimizer
from tha4.shion.core.load_save import torch_save, torch_load
from tha4.shion.core.module_accumulator import ModuleAccumulator
from tha4.shion.core.module_factory import ModuleFactory
from tha4.shion.core.optimizer_factory import OptimizerFactory
from tha4.shion.core.training.util import optimizer_to_device
class TrainingState:
def __init__(self,
examples_seen_so_far: int,
modules: Dict[str, Module],
accumulated_modules: Dict[str, Module],
optimizers: Dict[str, Optimizer]):
self.accumulated_modules = accumulated_modules
self.optimizers = optimizers
self.modules = modules
self.examples_seen_so_far = examples_seen_so_far
@staticmethod
def get_examples_seen_so_far_file_name(prefix) -> str:
return prefix + "/examples_seen_so_far.txt"
@staticmethod
def get_module_file_name(prefix, module_name) -> str:
return "%s/module_%s.pt" % (prefix, module_name)
@staticmethod
def get_accumulated_module_file_name(prefix, module_name) -> str:
return "%s/accumulated_%s.pt" % (prefix, module_name)
@staticmethod
def get_optimizer_file_name(prefix, module_name) -> str:
return "%s/optimizer_%s.pt" % (prefix, module_name)
@staticmethod
def get_rng_state_file_name(prefix):
return "%s/rng_state.pt" % prefix
def save(self, prefix):
logging.info("Saving training state to %s" % prefix)
os.makedirs(prefix, exist_ok=True)
with open(TrainingState.get_examples_seen_so_far_file_name(prefix), "wt") as fout:
fout.write("%d\n" % self.examples_seen_so_far)
logging.info("Saved %s" % TrainingState.get_examples_seen_so_far_file_name(prefix))
for module_name in self.modules:
file_name = TrainingState.get_module_file_name(prefix, module_name)
torch_save(self.modules[module_name].state_dict(), file_name)
logging.info("Saved %s" % file_name)
for module_name in self.accumulated_modules:
file_name = TrainingState.get_accumulated_module_file_name(prefix, module_name)
torch_save(self.accumulated_modules[module_name].state_dict(), file_name)
logging.info("Saved %s" % file_name)
for module_name in self.optimizers:
file_name = TrainingState.get_optimizer_file_name(prefix, module_name)
torch_save(self.optimizers[module_name].state_dict(), file_name)
logging.info("Saved %s" % file_name)
torch_save(torch.get_rng_state(), TrainingState.get_rng_state_file_name(prefix))
logging.info("Saved %s" % TrainingState.get_rng_state_file_name(prefix))
logging.info("Done saving training state to %s" % prefix)
@staticmethod
def get_examples_seen_so_far(prefix: str) -> int:
with open(TrainingState.get_examples_seen_so_far_file_name(prefix)) as fin:
lines = fin.readlines()
return int(lines[0])
@staticmethod
def load(prefix: str,
module_factories: Dict[str, ModuleFactory],
accumulators: Dict[str, ModuleAccumulator],
optimizer_factories: Dict[str, OptimizerFactory],
device: torch.device) -> 'TrainingState':
logging.info("Loading training state from %s" % prefix)
with open(TrainingState.get_examples_seen_so_far_file_name(prefix)) as fin:
lines = fin.readlines()
examples_seen_so_far = int(lines[0])
logging.info("Loaded %s" % TrainingState.get_examples_seen_so_far_file_name(prefix))
modules = {
module_name: factory.create()
for (module_name, factory) in module_factories.items()
}
for module_name in modules:
file_name = TrainingState.get_module_file_name(prefix, module_name)
modules[module_name].load_state_dict(torch_load(file_name))
modules[module_name].to(device)
logging.info("Loaded %s" % file_name)
accumulated_modules = {}
for module_name in accumulators:
module_factory = module_factories[module_name]
module = module_factory.create()
file_name = TrainingState.get_accumulated_module_file_name(prefix, module_name)
module.load_state_dict(torch_load(file_name))
module.to(device)
accumulated_modules[module_name] = module
logging.info("Loaded %s" % file_name)
optimizers = {}
for module_name in optimizer_factories:
optimizer = optimizer_factories[module_name].create(modules[module_name].parameters())
file_name = TrainingState.get_optimizer_file_name(prefix, module_name)
optimizer.load_state_dict(torch_load(file_name))
optimizer_to_device(optimizer, device)
optimizers[module_name] = optimizer
logging.info("Loaded %s" % file_name)
torch.set_rng_state(torch_load(TrainingState.get_rng_state_file_name(prefix)))
logging.info("Loaded %s" % TrainingState.get_rng_state_file_name(prefix))
logging.info("Done loading training state from %s" % prefix)
return TrainingState(examples_seen_so_far, modules, accumulated_modules, optimizers)
@staticmethod
def new(module_factories: Dict[str, ModuleFactory],
accumulators: Dict[str, ModuleAccumulator],
optimizer_factories: Dict[str, OptimizerFactory],
random_seed: int,
device: torch.device,
pretrained_module_file_names: Optional[Dict[str, str]] = None) -> 'TrainingState':
examples_seen_so_far = 0
modules = {
module_name: factory.create()
for (module_name, factory) in module_factories.items()
}
for module_name in modules:
modules[module_name].to(device)
if pretrained_module_file_names is not None:
for module_name in modules:
if module_name in pretrained_module_file_names:
file_name = pretrained_module_file_names[module_name]
modules[module_name].load_state_dict(torch_load(file_name))
logging.info("Loaded initial state from %s ..." % file_name)
accumulated_modules = {}
for module_name in accumulators:
accumulated_modules[module_name] = copy.deepcopy(modules[module_name])
optimizers = {}
for module_name in optimizer_factories:
module = modules[module_name]
optimizer = optimizer_factories[module_name].create(module.parameters())
optimizer_to_device(optimizer, device)
optimizers[module_name] = optimizer
torch.manual_seed(random_seed)
return TrainingState(examples_seen_so_far, modules, accumulated_modules, optimizers)
@staticmethod
def can_load(prefix: str,
module_factories: Dict[str, ModuleFactory],
accumulators: Dict[str, ModuleAccumulator],
optimizer_factories: Dict[str, OptimizerFactory]) -> bool:
if not os.path.isdir(prefix):
return False
if not os.path.isfile(TrainingState.get_examples_seen_so_far_file_name(prefix)):
return False
for module_name in module_factories.keys():
if not os.path.isfile(TrainingState.get_module_file_name(prefix, module_name)):
return False
for module_name in accumulators:
if not os.path.isfile(TrainingState.get_accumulated_module_file_name(prefix, module_name)):
return False
for module_name in optimizer_factories:
if not os.path.isfile(TrainingState.get_optimizer_file_name(prefix, module_name)):
return False
if not os.path.isfile(TrainingState.get_rng_state_file_name(prefix)):
return False
return True
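# On-disk layout produced by save(prefix), for reference:
#
#   <prefix>/examples_seen_so_far.txt
#   <prefix>/module_<name>.pt          one file per module
#   <prefix>/accumulated_<name>.pt     one file per accumulated module
#   <prefix>/optimizer_<name>.pt       one file per optimizer
#   <prefix>/rng_state.pt
#
# A minimal save/can_load round-trip sketch; the factory stubs below are
# illustrative and assume create() is the only method the factories need.
if __name__ == "__main__":
    import tempfile
    import torch.nn
    import torch.optim

    class _LinearFactory(ModuleFactory):
        def create(self) -> Module:
            return torch.nn.Linear(4, 4)

    class _SgdFactory(OptimizerFactory):
        def create(self, parameters) -> Optimizer:
            return torch.optim.SGD(parameters, lr=1e-3)

    demo_prefix = tempfile.mkdtemp()
    state = TrainingState.new(
        module_factories={'main': _LinearFactory()},
        accumulators={},
        optimizer_factories={'main': _SgdFactory()},
        random_seed=0,
        device=torch.device('cpu'))
    state.save(demo_prefix)
    assert TrainingState.can_load(
        demo_prefix, {'main': _LinearFactory()}, {}, {'main': _SgdFactory()})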
================================================
FILE: src/tha4/shion/core/training/single/training_tasks.py
================================================
import logging
import time
from datetime import datetime
from typing import Optional, Dict, List
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from tha4.pytasuku.workspace import Workspace
from tha4.shion.core.load_save import torch_save, torch_load
from tha4.shion.core.loss import Loss
from tha4.shion.core.module_accumulator import ModuleAccumulator
from tha4.shion.core.module_factory import ModuleFactory
from tha4.shion.core.training.sample_output_protocol import SampleOutputProtocol
from tha4.shion.core.training.single.training_states import TrainingState
from tha4.shion.core.training.training_protocol import TrainingProtocol
from tha4.shion.core.training.util import get_least_greater_multiple, create_log_func, set_learning_rate
from tha4.shion.core.training.validation_protocol import ValidationProtocol
KEY_CHECKPOINT = 'checkpoint'
KEY_SNAPSHOT = 'snapshot'
KEY_VALIDATION = 'validation'
KEY_SAMPLE_OUTPUT = 'sample_output'
class TrainingTasks:
def __init__(
self,
workspace: Workspace,
prefix: str,
module_factories: Dict[str, ModuleFactory],
accumulators: Dict[str, ModuleAccumulator],
losses: Dict[str, Loss],
training_dataset: Dataset,
validation_dataset: Optional[Dataset],
training_protocol: TrainingProtocol,
validation_protocol: Optional[ValidationProtocol],
sample_output_protocol: Optional[SampleOutputProtocol],
pretrained_module_file_names: Dict[str, str],
example_per_snapshot: int,
device: torch.device,
num_data_loader_workers: int = 8,
dependencies: Optional[List[str]] = None):
super().__init__()
self.num_data_loader_workers = num_data_loader_workers
self.accumulators = accumulators
self.device = device
self.sample_output_protocol = sample_output_protocol
self.example_per_snapshot = example_per_snapshot
self.pretrained_module_file_names = pretrained_module_file_names
self.losses = losses
self.validation_protocol = validation_protocol
self.training_protocol = training_protocol
self.module_factories = module_factories
self.prefix = prefix
self.training_dataset = training_dataset
self.validation_dataset = validation_dataset
self.checkpoint_examples = self.training_protocol.get_checkpoint_examples()
assert len(self.checkpoint_examples) >= 1
assert self.checkpoint_examples[0] > 0
self.checkpoint_examples = [0] + self.checkpoint_examples
self.module_names = self.module_factories.keys()
assert len(self.module_names) > 0
self.training_data_loader = None
self.training_data_loader_iter = None
self.training_data_loader_batch_size = None
self.validation_data_loader = None
self.validation_data_loader_iter = None
self.validation_data_loader_batch_size = None
self.sample_output_data = None
self.summary_writer = None
self.log_dir = None
self.training_state = None
if dependencies is None:
dependencies = []
self.sample_output_data_task = workspace.create_file_task(
self.get_sample_output_data_file_name(), dependencies, self.save_sample_output_data)
module_file_dependencies = [self.sample_output_data_task.name]
for module_name in pretrained_module_file_names:
module_file_dependencies.append(self.pretrained_module_file_names[module_name])
def create_train_func(target_examples: int):
return lambda: self.train(target_examples)
train_tasks = []
for checkpoint_index in range(1, len(self.checkpoint_examples)):
for module_name in self.module_names:
module_file_name = TrainingState.get_module_file_name(
self.get_checkpoint_prefix(checkpoint_index),
module_name)
workspace.create_file_task(
module_file_name,
module_file_dependencies,
create_train_func(self.checkpoint_examples[checkpoint_index]))
for module_name in self.accumulators:
accumulated_module_file_name = TrainingState.get_accumulated_module_file_name(
self.get_checkpoint_prefix(checkpoint_index),
module_name)
workspace.create_file_task(
accumulated_module_file_name,
module_file_dependencies,
create_train_func(self.checkpoint_examples[checkpoint_index]))
workspace.create_command_task(
self.get_checkpoint_prefix(checkpoint_index) + "/train",
module_file_dependencies,
create_train_func(self.checkpoint_examples[checkpoint_index]))
train_tasks.append(self.get_checkpoint_prefix(checkpoint_index) + "/train")
self.train_task = workspace.create_command_task(
self.get_train_command_name(),
module_file_dependencies,
create_train_func(self.checkpoint_examples[-1]))
def get_sample_output_data_file_name(self):
return self.prefix + "/sample_output_data.pt"
def save_sample_output_data(self):
if self.sample_output_protocol is not None:
torch.manual_seed(self.sample_output_protocol.get_random_seed())
sample_output_data = self.sample_output_protocol.get_sample_output_data(self.validation_dataset,
self.device)
torch_save(sample_output_data, self.get_sample_output_data_file_name())
else:
torch_save({}, self.get_sample_output_data_file_name())
def get_module_file_name(self, checkpoint_index, module_name):
return TrainingState.get_module_file_name(self.get_checkpoint_prefix(checkpoint_index), module_name)
def get_last_module_file_name(self, module_name):
return self.get_module_file_name(len(self.checkpoint_examples) - 1, module_name)
def get_log_dir(self):
if self.log_dir is None:
now = datetime.now()
self.log_dir = self.prefix + "/log/" + now.strftime("%Y_%m_%d__%H_%M_%S")
return self.log_dir
def get_summary_writer(self) -> SummaryWriter:
if self.summary_writer is None:
self.summary_writer = SummaryWriter(log_dir=self.get_log_dir())
return self.summary_writer
def get_train_command_name(self) -> str:
return self.prefix + "/train"
def get_snapshot_prefix(self) -> str:
return self.prefix + "/snapshot"
def get_checkpoint_prefix(self, checkpoint_index) -> str:
return "%s/checkpoint/%04d" % (self.prefix, checkpoint_index)
def can_load_training_state(self, prefix) -> bool:
return TrainingState.can_load(
prefix,
self.module_factories,
self.accumulators,
self.training_protocol.get_optimizer_factories())
def load_training_state(self, prefix) -> TrainingState:
return TrainingState.load(
prefix,
self.module_factories,
self.accumulators,
self.training_protocol.get_optimizer_factories(),
self.device)
def get_initial_training_state(self) -> TrainingState:
training_state = TrainingState.new(
self.module_factories,
self.accumulators,
self.training_protocol.get_optimizer_factories(),
self.training_protocol.get_random_seed(),
self.device,
self.pretrained_module_file_names)
logging.info("Created a new initial training state.")
return training_state
def load_previous_training_state(self, target_checkpoint_examples: int) -> TrainingState:
if self.can_load_training_state(self.get_snapshot_prefix()):
examples_seen_so_far = TrainingState.get_examples_seen_so_far(self.get_snapshot_prefix())
diff = examples_seen_so_far - target_checkpoint_examples
if diff < self.training_protocol.get_batch_size():
return self.load_training_state(self.get_snapshot_prefix())
num_checkpoints = len(self.checkpoint_examples)
for checkpoint_index in range(num_checkpoints - 1, -1, -1):
if self.can_load_training_state(self.get_checkpoint_prefix(checkpoint_index)):
examples_seen_so_far = TrainingState.get_examples_seen_so_far(
self.get_checkpoint_prefix(checkpoint_index))
diff = examples_seen_so_far - target_checkpoint_examples
if diff < self.training_protocol.get_batch_size():
return self.load_training_state(self.get_checkpoint_prefix(checkpoint_index))
return self.get_initial_training_state()
def get_next_checkpoint_num_examples(self, examples_seen_so_far) -> int:
next_index = next(
(i for i in range(len(self.checkpoint_examples)) if self.checkpoint_examples[i] > examples_seen_so_far),
-1)
return self.checkpoint_examples[next_index]
def get_next_snapshot_num_examples(self, examples_seen_so_far) -> int:
return get_least_greater_multiple(examples_seen_so_far, self.example_per_snapshot)
def get_next_validation_num_examples(self, examples_seen_so_far) -> int:
if self.validation_protocol is None:
return -1
return get_least_greater_multiple(examples_seen_so_far,
self.validation_protocol.get_examples_per_validation_iteration())
def get_next_sample_output_num_examples(self, examples_seen_so_far) -> int:
if self.sample_output_protocol is None:
return -1
return get_least_greater_multiple(examples_seen_so_far,
self.sample_output_protocol.get_examples_per_sample_output())
def get_next_num_examples(self, examples_seen_so_far) -> Dict[str, int]:
return {
KEY_CHECKPOINT: self.get_next_checkpoint_num_examples(examples_seen_so_far),
KEY_SNAPSHOT: self.get_next_snapshot_num_examples(examples_seen_so_far),
KEY_VALIDATION: self.get_next_validation_num_examples(examples_seen_so_far),
KEY_SAMPLE_OUTPUT: self.get_next_sample_output_num_examples(examples_seen_so_far)
}
def get_checkpoint_index_to_save(self, examples_seen_so_far: int) -> int:
checkpoint_index = 0
for i in range(len(self.checkpoint_examples)):
if self.checkpoint_examples[i] <= examples_seen_so_far:
checkpoint_index = i
return checkpoint_index
def get_next_training_batch(self):
if self.training_data_loader is None:
self.training_data_loader = DataLoader(
self.training_dataset,
batch_size=self.training_protocol.get_batch_size(),
shuffle=True,
num_workers=self.num_data_loader_workers,
drop_last=True)
if self.training_data_loader_iter is None:
self.training_data_loader_iter = iter(self.training_data_loader)
try:
batch = next(self.training_data_loader_iter)
except StopIteration:
self.training_data_loader_iter = iter(self.training_data_loader)
batch = next(self.training_data_loader_iter)
return [x.to(self.device) for x in batch]
def get_next_validation_batch(self):
if self.validation_dataset is None:
return None
if self.validation_data_loader is None:
self.validation_data_loader = DataLoader(
self.validation_dataset,
batch_size=self.validation_protocol.get_batch_size(),
shuffle=True,
num_workers=self.num_data_loader_workers,
drop_last=True)
if self.validation_data_loader_iter is None:
self.validation_data_loader_iter = iter(self.validation_data_loader)
try:
batch = next(self.validation_data_loader_iter)
except StopIteration:
self.validation_data_loader_iter = iter(self.validation_data_loader)
batch = next(self.validation_data_loader_iter)
return [x.to(self.device) for x in batch]
def get_checkpoint_index(self, target_checkpoint_examples: int):
return self.checkpoint_examples.index(target_checkpoint_examples)
def train(self, target_checkpoint_examples: Optional[int] = None):
if target_checkpoint_examples is None:
target_checkpoint_examples = self.checkpoint_examples[-1]
sample_output_data = torch_load(self.get_sample_output_data_file_name())
logging.info("Loaded sampled output data from %s", self.get_sample_output_data_file_name())
training_state = self.load_previous_training_state(target_checkpoint_examples)
summary_writer = self.get_summary_writer()
last_time = time.time()
while training_state.examples_seen_so_far < target_checkpoint_examples:
# One training iteration
learning_rate = self.training_protocol.get_learning_rate(training_state.examples_seen_so_far)
for module_name in self.module_factories.keys():
if module_name not in learning_rate or module_name not in training_state.optimizers:
continue
lr = learning_rate[module_name]
set_learning_rate(training_state.optimizers[module_name], lr)
self.get_summary_writer().add_scalar(
module_name + "_learning_rate", lr, training_state.examples_seen_so_far)
training_batch = self.get_next_training_batch()
self.training_protocol.run_training_iteration(
training_batch,
training_state.examples_seen_so_far,
training_state.modules,
training_state.accumulated_modules,
training_state.optimizers,
self.losses,
lambda name, num: create_log_func(summary_writer, name, num),
self.device)
# Accumulate model data
for module_name in self.accumulators:
new_module = training_state.modules[module_name]
buffer_module = training_state.accumulated_modules[module_name]
self.accumulators[module_name].accumulate(
new_module,
buffer_module,
training_state.examples_seen_so_far)
# Advance the number of examples seen so far
next_num_examples = self.get_next_num_examples(training_state.examples_seen_so_far)
training_state.examples_seen_so_far += self.training_protocol.get_batch_size()
# Validation iteration
if self.validation_protocol is not None \
and training_state.examples_seen_so_far >= next_num_examples[KEY_VALIDATION]:
validation_batch = self.get_next_validation_batch()
self.validation_protocol.run_validation_iteration(
validation_batch,
training_state.examples_seen_so_far,
training_state.modules,
training_state.accumulated_modules,
self.losses,
lambda name, num: create_log_func(summary_writer, name, num),
self.device)
# Save sample output
if self.sample_output_protocol is not None \
and training_state.examples_seen_so_far >= next_num_examples[KEY_SAMPLE_OUTPUT]:
self.sample_output_protocol.save_sample_output_data(
training_state.modules,
training_state.accumulated_modules,
sample_output_data,
self.prefix + "/sample_outputs",
training_state.examples_seen_so_far,
self.device)
# Save checkpoint
if training_state.examples_seen_so_far >= next_num_examples[KEY_CHECKPOINT]:
checkpoint_index = self.get_checkpoint_index_to_save(training_state.examples_seen_so_far)
training_state.save(self.get_checkpoint_prefix(checkpoint_index))
if next_num_examples[KEY_CHECKPOINT] != next_num_examples[KEY_SNAPSHOT]:
training_state.save(self.get_snapshot_prefix())
# Save snapshot
if training_state.examples_seen_so_far >= next_num_examples[KEY_SNAPSHOT]:
training_state.save(self.get_snapshot_prefix())
now = time.time()
if now - last_time > 10:
logging.info("Showed %d training examples." % training_state.examples_seen_so_far)
last_time = now
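# Task names defined by the constructor above, for reference:
#
#   <prefix>/sample_output_data.pt               file task: cached sample-output inputs
#   <prefix>/checkpoint/<NNNN>/module_<name>.pt  file task, one per checkpoint and module
#   <prefix>/checkpoint/<NNNN>/train             command task: train up to that checkpoint
#   <prefix>/train                               train up to the last checkpoint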
================================================
FILE: src/tha4/shion/core/training/swarm/__init__.py
================================================
================================================
FILE: src/tha4/shion/core/training/swarm/swarm_training_tasks.py
================================================
from typing import Callable, Optional, List
from tha4.pytasuku.workspace import Workspace
from tha4.shion.core.training.distrib.distributed_training_tasks import RdzvConfig, \
run_standalone_distributed_training_script
from tha4.shion.core.training.single.training_states import TrainingState
from tha4.shion.core.training.swarm.swarm_unit_trainer import SwarmUnitTrainer
def define_standalone_swarm_training_tasks(
workspace: Workspace,
swarm_unit_trainer_func: Callable[[], SwarmUnitTrainer],
training_script_file_name: str,
num_proc_per_node: int,
dependencies: Optional[List[str]] = None,
rdzv_config: Optional[RdzvConfig] = None):
trainer = swarm_unit_trainer_func()
checkpoint_examples = trainer.training_protocol.get_checkpoint_examples()
assert len(checkpoint_examples) >= 1
assert checkpoint_examples[0] > 0
checkpoint_examples = [0] + checkpoint_examples
if dependencies is None:
dependencies = []
module_file_dependencies = dependencies[:]
for module_name in trainer.pretrained_module_file_names:
module_file_dependencies.append(trainer.pretrained_module_file_names[module_name])
def create_train_func(target_checkpoint_examples: int):
return lambda: run_standalone_distributed_training_script(
training_script_file_name,
num_proc_per_node,
target_checkpoint_examples,
rdzv_config=rdzv_config)
train_tasks = []
for checkpoint_index in range(0, len(checkpoint_examples)):
for module_name in trainer.module_names:
module_file_name = TrainingState.get_module_file_name(
trainer.get_checkpoint_prefix(checkpoint_index),
module_name)
workspace.create_file_task(
module_file_name,
module_file_dependencies,
create_train_func(checkpoint_examples[checkpoint_index]))
for module_name in trainer.accumulators:
accumulated_module_file_name = TrainingState.get_accumulated_module_file_name(
trainer.get_checkpoint_prefix(checkpoint_index),
module_name)
workspace.create_file_task(
accumulated_module_file_name,
module_file_dependencies,
create_train_func(checkpoint_examples[checkpoint_index]))
workspace.create_command_task(
trainer.get_checkpoint_prefix(checkpoint_index) + "/train_standalone",
module_file_dependencies,
create_train_func(checkpoint_examples[checkpoint_index]))
train_tasks.append(trainer.get_checkpoint_prefix(checkpoint_index) + "/train_standalone")
workspace.create_command_task(
trainer.prefix + "/train_standalone",
module_file_dependencies,
create_train_func(checkpoint_examples[-1]))
================================================
FILE: src/tha4/shion/core/training/swarm/swarm_unit_trainer.py
================================================
import argparse
import logging
import os
import time
from datetime import datetime
from typing import Dict, Optional, Callable
import torch.distributed
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from tha4.shion.core.load_save import torch_save, torch_load
from tha4.shion.core.loss import Loss
from tha4.shion.core.module_accumulator import ModuleAccumulator
from tha4.shion.core.module_factory import ModuleFactory
from tha4.shion.core.training.distrib.device_mapper import SimpleCudaDeviceMapper
from tha4.shion.core.training.sample_output_protocol import SampleOutputProtocol
from tha4.shion.core.training.single.training_states import TrainingState
from tha4.shion.core.training.single.training_tasks import KEY_CHECKPOINT, KEY_SNAPSHOT, KEY_VALIDATION, KEY_SAMPLE_OUTPUT
from tha4.shion.core.training.training_protocol import TrainingProtocol
from tha4.shion.core.training.util import get_least_greater_multiple, create_log_func, set_learning_rate
from tha4.shion.core.training.validation_protocol import ValidationProtocol
class SwarmUnitTrainer:
def __init__(self,
prefix: str,
module_factories: Dict[str, ModuleFactory],
accumulators: Dict[str, ModuleAccumulator],
losses: Dict[str, Loss],
training_dataset: Dataset,
validation_dataset: Optional[Dataset],
training_protocol: TrainingProtocol,
validation_protocol: Optional[ValidationProtocol],
sample_output_protocol: Optional[SampleOutputProtocol],
pretrained_module_file_names: Dict[str, str],
example_per_snapshot: int,
num_data_loader_workers: int = 8):
self.num_data_loader_workers = num_data_loader_workers
self.accumulators = accumulators
self.sample_output_protocol = sample_output_protocol
self.example_per_snapshot = example_per_snapshot
self.pretrained_module_file_names = pretrained_module_file_names
self.losses = losses
self.validation_protocol = validation_protocol
self.training_protocol = training_protocol
self.module_factories = module_factories
self.prefix = prefix
self.training_dataset = training_dataset
self.validation_dataset = validation_dataset
self.checkpoint_examples = self.training_protocol.get_checkpoint_examples()
assert len(self.checkpoint_examples) >= 1
assert self.checkpoint_examples[0] > 0
self.checkpoint_examples = [0] + self.checkpoint_examples
self.module_names = self.module_factories.keys()
assert len(self.module_names) > 0
self.training_data_loader = None
self.training_data_loader_iter = None
self.training_data_loader_batch_size = None
self.training_data_sampler = None
self.validation_data_loader = None
self.validation_data_loader_iter = None
self.validation_data_loader_batch_size = None
self.sample_output_data = None
self.summary_writer = None
self.log_dir = None
self.training_state = None
def get_sample_output_data_file_name(self):
return self.prefix + "/sample_output_data.pt"
def save_sample_output_data(self, device: torch.device):
if os.path.exists(self.get_sample_output_data_file_name()):
return
if self.sample_output_protocol is not None:
torch.manual_seed(self.sample_output_protocol.get_random_seed())
sample_output_data = self.sample_output_protocol.get_sample_output_data(self.validation_dataset, device)
torch_save(sample_output_data, self.get_sample_output_data_file_name())
else:
torch_save({}, self.get_sample_output_data_file_name())
def load_sample_output_data(self, device: torch.device):
self.save_sample_output_data(device)
return torch_load(self.get_sample_output_data_file_name())
def get_snapshot_prefix(self) -> str:
return self.prefix + "/snapshot"
def can_load_training_state(self, prefix: str) -> bool:
return TrainingState.can_load(
prefix,
self.module_factories,
self.accumulators,
self.training_protocol.get_optimizer_factories())
def load_training_state(self, prefix, device: torch.device) -> TrainingState:
return TrainingState.load(
prefix,
self.module_factories,
self.accumulators,
self.training_protocol.get_optimizer_factories(),
device)
@staticmethod
def checkpoint_prefix(prefix: str, checkpoint_index: int) -> str:
return "%s/checkpoint/%04d" % (prefix, checkpoint_index)
def get_checkpoint_prefix(self, checkpoint_index) -> str:
return SwarmUnitTrainer.checkpoint_prefix(self.prefix, checkpoint_index)
def get_initial_training_state(self, device: torch.device) -> TrainingState:
training_state = TrainingState.new(
self.module_factories,
self.accumulators,
self.training_protocol.get_optimizer_factories(),
self.training_protocol.get_random_seed(),
device,
self.pretrained_module_file_names)
logging.info("Created a new initial training state.")
return training_state
def load_previous_training_state(self,
target_checkpoint_examples: int,
device: torch.device) -> TrainingState:
if self.can_load_training_state(self.get_snapshot_prefix()):
examples_seen_so_far = TrainingState.get_examples_seen_so_far(self.get_snapshot_prefix())
diff = examples_seen_so_far - target_checkpoint_examples
if diff < self.training_protocol.get_batch_size():
return self.load_training_state(self.get_snapshot_prefix(), device)
num_checkpoints = len(self.checkpoint_examples)
for checkpoint_index in range(num_checkpoints - 1, -1, -1):
if self.can_load_training_state(self.get_checkpoint_prefix(checkpoint_index)):
examples_seen_so_far = TrainingState.get_examples_seen_so_far(
self.get_checkpoint_prefix(checkpoint_index))
diff = examples_seen_so_far - target_checkpoint_examples
if diff < self.training_protocol.get_batch_size():
return self.load_training_state(
self.get_checkpoint_prefix(checkpoint_index), device)
training_state = self.get_initial_training_state(device)
training_state.save(self.get_checkpoint_prefix(0))
training_state = self.load_training_state(self.get_checkpoint_prefix(0), device)
return training_state
def get_log_dir(self):
if self.log_dir is None:
now = datetime.now()
self.log_dir = self.prefix + "/log/" + now.strftime("%Y_%m_%d__%H_%M_%S")
return self.log_dir
def get_summary_writer(self) -> Optional[SummaryWriter]:
if self.summary_writer is None:
self.summary_writer = SummaryWriter(log_dir=self.get_log_dir())
return self.summary_writer
def get_next_training_batch(self, device: torch.device):
if self.training_data_loader is None:
self.training_data_loader = DataLoader(
self.training_dataset,
batch_size=self.training_protocol.get_batch_size(),
shuffle=True,
num_workers=self.num_data_loader_workers,
drop_last=True)
if self.training_data_loader_iter is None:
self.training_data_loader_iter = iter(self.training_data_loader)
try:
batch = next(self.training_data_loader_iter)
except StopIteration:
self.training_data_loader_iter = iter(self.training_data_loader)
batch = next(self.training_data_loader_iter)
return [x.to(device) for x in batch]
def get_next_checkpoint_num_examples(self, examples_seen_so_far) -> int:
next_index = next(
(i for i in range(len(self.checkpoint_examples)) if self.checkpoint_examples[i] > examples_seen_so_far),
-1)
return self.checkpoint_examples[next_index]
def get_next_snapshot_num_examples(self, examples_seen_so_far) -> int:
return get_least_greater_multiple(examples_seen_so_far, self.example_per_snapshot)
def get_next_validation_num_examples(self, examples_seen_so_far) -> int:
if self.validation_protocol is None:
return -1
return get_least_greater_multiple(examples_seen_so_far,
self.validation_protocol.get_examples_per_validation_iteration())
def get_next_sample_output_num_examples(self, examples_seen_so_far) -> int:
if self.sample_output_protocol is None:
return -1
return get_least_greater_multiple(examples_seen_so_far,
self.sample_output_protocol.get_examples_per_sample_output())
def get_next_num_examples(self, examples_seen_so_far) -> Dict[str, int]:
return {
KEY_CHECKPOINT: self.get_next_checkpoint_num_examples(examples_seen_so_far),
KEY_SNAPSHOT: self.get_next_snapshot_num_examples(examples_seen_so_far),
KEY_VALIDATION: self.get_next_validation_num_examples(examples_seen_so_far),
KEY_SAMPLE_OUTPUT: self.get_next_sample_output_num_examples(examples_seen_so_far)
}
def get_next_validation_batch(self, device: torch.device):
if self.validation_dataset is None:
return None
if self.validation_data_loader is None:
self.validation_data_loader = DataLoader(
self.validation_dataset,
batch_size=self.validation_protocol.get_batch_size(),
shuffle=True,
num_workers=1,
drop_last=True)
if self.validation_data_loader_iter is None:
self.validation_data_loader_iter = iter(self.validation_data_loader)
try:
batch = next(self.validation_data_loader_iter)
except StopIteration:
self.validation_data_loader_iter = iter(self.validation_data_loader)
batch = next(self.validation_data_loader_iter)
return [x.to(device) for x in batch]
def get_checkpoint_index_to_save(self, examples_seen_so_far: int) -> int:
checkpoint_index = 0
for i in range(len(self.checkpoint_examples)):
if self.checkpoint_examples[i] <= examples_seen_so_far:
checkpoint_index = i
return checkpoint_index
def train(self,
rank: int,
local_rank: int,
target_checkpoint_examples: Optional[int] = None,
device_mapper: Optional[Callable[[int, int], torch.device]] = None):
if target_checkpoint_examples is None:
target_checkpoint_examples = self.checkpoint_examples[-1]
if device_mapper is None:
device_mapper = SimpleCudaDeviceMapper()
device = device_mapper(rank, local_rank)
sample_output_data = self.load_sample_output_data(device)
training_state = self.load_previous_training_state(
target_checkpoint_examples, device)
summary_writer = self.get_summary_writer()
if summary_writer is not None:
log_func_factory = lambda name, num: create_log_func(summary_writer, name, num)
else:
log_func_factory = None
last_time = time.time()
while training_state.examples_seen_so_far < target_checkpoint_examples:
# Set the learning rate
learning_rate_by_module_name = self.training_protocol.get_learning_rate(training_state.examples_seen_so_far)
for module_name in self.module_factories.keys():
if module_name not in learning_rate_by_module_name or module_name not in training_state.optimizers:
continue
lr = learning_rate_by_module_name[module_name]
set_learning_rate(training_state.optimizers[module_name], lr)
if summary_writer is not None:
summary_writer.add_scalar(
module_name + "_learning_rate", lr, training_state.examples_seen_so_far)
# One training iteration
training_batch = self.get_next_training_batch(device)
self.training_protocol.run_training_iteration(
training_batch,
training_state.examples_seen_so_far,
training_state.modules,
training_state.accumulated_modules,
training_state.optimizers,
self.losses,
log_func_factory,
device)
# Accumulate model data
for module_name in self.accumulators:
new_module = training_state.modules[module_name]
buffer_module = training_state.accumulated_modules[module_name]
self.accumulators[module_name].accumulate(
new_module, buffer_module, examples_seen_so_far=training_state.examples_seen_so_far)
# Advance the number of examples seen so far
next_num_examples = self.get_next_num_examples(training_state.examples_seen_so_far)
training_state.examples_seen_so_far += self.training_protocol.get_batch_size()
# Validation iteration
if self.validation_protocol is not None \
and training_state.examples_seen_so_far >= next_num_examples[KEY_VALIDATION]:
validation_batch = self.get_next_validation_batch(device)
self.validation_protocol.run_validation_iteration(
validation_batch,
training_state.examples_seen_so_far,
training_state.modules,
training_state.accumulated_modules,
self.losses,
log_func_factory,
device)
# Save sample output
if self.sample_output_protocol is not None \
and training_state.examples_seen_so_far >= next_num_examples[KEY_SAMPLE_OUTPUT]:
self.sample_output_protocol.save_sample_output_data(
training_state.modules,
training_state.accumulated_modules,
sample_output_data,
self.prefix + "/sample_outputs",
training_state.examples_seen_so_far,
device)
# Save checkpoint
if training_state.examples_seen_so_far >= next_num_examples[KEY_CHECKPOINT]:
checkpoint_index = self.get_checkpoint_index_to_save(training_state.examples_seen_so_far)
training_state.save(self.get_checkpoint_prefix(checkpoint_index))
if next_num_examples[KEY_CHECKPOINT] != next_num_examples[KEY_SNAPSHOT]:
training_state.save(self.get_snapshot_prefix())
# Save snapshot
if training_state.examples_seen_so_far >= next_num_examples[KEY_SNAPSHOT]:
training_state.save(self.get_snapshot_prefix())
now = time.time()
if now - last_time > 10:
logging.info("[Rank %d] Showed %d training examples." % (rank, training_state.examples_seen_so_far))
last_time = now
@staticmethod
def run(trainer_factory: Dict[int, Callable[[], 'SwarmUnitTrainer']],
backend: str = 'gloo',
device_mapper: Optional[Callable[[int, int], torch.device]] = None):
parser = argparse.ArgumentParser(description='Training script.')
parser.add_argument("--target_checkpoint_examples", type=int)
args = parser.parse_args()
rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
torch.distributed.init_process_group(backend)
if rank in trainer_factory:
trainer = trainer_factory[rank]()
trainer.train(rank, local_rank, args.target_checkpoint_examples, device_mapper)
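# An illustrative launch pattern; the factory names below are hypothetical.
# Each process started by torchrun calls SwarmUnitTrainer.run with a map from
# global rank to the unit that rank should train:
#
#   if __name__ == "__main__":
#       SwarmUnitTrainer.run({
#           0: make_face_morpher_unit_trainer,   # hypothetical factory
#           1: make_body_morpher_unit_trainer,   # hypothetical factory
#       })
#
# launched as, e.g.,
#   torchrun --standalone --nproc_per_node=2 train_swarm.py --target_checkpoint_examples 1000000
# so that the RANK and LOCAL_RANK environment variables are set per process.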
================================================
FILE: src/tha4/shion/core/training/training_protocol.py
================================================
from abc import ABC, abstractmethod
from typing import Dict, List, Callable, Any, Optional
import torch
from torch.nn import Module
from torch.optim.optimizer import Optimizer
from tha4.shion.core.loss import Loss
from tha4.shion.core.optimizer_factory import OptimizerFactory
class TrainingProtocol(ABC):
@abstractmethod
def get_optimizer_factories(self) -> Dict[str, OptimizerFactory]:
pass
@abstractmethod
def get_checkpoint_examples(self) -> List[int]:
pass
@abstractmethod
def get_random_seed(self) -> int:
pass
@abstractmethod
def get_batch_size(self) -> int:
pass
@abstractmethod
def get_learning_rate(self, examples_seen_so_far: int) -> Dict[str, float]:
pass
@abstractmethod
def run_training_iteration(
self,
batch: Any,
examples_seen_so_far: int,
modules: Dict[str, Module],
accumulated_modules: Dict[str, Module],
optimizers: Dict[str, Optimizer],
losses: Dict[str, Loss],
create_log_func: Optional[Callable[[str, int], Callable[[str, float], None]]],
device: torch.device):
pass
class AbstractTrainingProtocol(TrainingProtocol, ABC):
def __init__(self,
check_point_examples: List[int],
batch_size: int,
learning_rate: Callable[[int], Dict[str, float]],
optimizer_factories: Dict[str, OptimizerFactory],
random_seed: int):
self.random_seed = random_seed
self.optimizer_factories = optimizer_factories
self.learning_rate = learning_rate
self.batch_size = batch_size
self.check_point_examples = check_point_examples
def get_optimizer_factories(self) -> Dict[str, OptimizerFactory]:
return self.optimizer_factories
def get_checkpoint_examples(self) -> List[int]:
return self.check_point_examples
def get_random_seed(self) -> int:
return self.random_seed
def get_batch_size(self) -> int:
return self.batch_size
def get_learning_rate(self, examples_seen_so_far: int) -> Dict[str, float]:
return self.learning_rate(examples_seen_so_far)
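# A minimal concrete protocol sketch built on AbstractTrainingProtocol: a
# single module and optimizer, both registered under the assumed name 'main',
# trained with a plain L1 objective instead of the Loss registry.
if __name__ == "__main__":
    import torch.nn.functional as F

    class L1TrainingProtocol(AbstractTrainingProtocol):
        def run_training_iteration(self, batch, examples_seen_so_far, modules,
                                   accumulated_modules, optimizers, losses,
                                   create_log_func, device):
            image, target = (x.to(device) for x in batch)
            optimizers['main'].zero_grad()
            loss = F.l1_loss(modules['main'](image), target)
            loss.backward()
            optimizers['main'].step()
            if create_log_func is not None:
                create_log_func('training', examples_seen_so_far)('l1', loss.item())

    # Constructor arguments are illustrative; supply a real OptimizerFactory
    # per module in practice.
    protocol = L1TrainingProtocol(
        check_point_examples=[10000, 20000],
        batch_size=8,
        learning_rate=lambda n: {'main': 1e-4},
        optimizer_factories={},
        random_seed=0)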
================================================
FILE: src/tha4/shion/core/training/util.py
================================================
from typing import Callable
import torch
from torch.nn import Module
from torch.optim import Optimizer
def optimizer_to_device(optim: Optimizer, device: torch.device):
for state in optim.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(device)
def zero_module(module: Module):
parameters = dict(module.named_parameters())
for k in parameters.keys():
parameters[k].data.zero_()
def get_least_greater_multiple(x: int, m: int) -> int:
"""
:param x: a non-negative integer
:param m: a positive integer
:return: the next multiple of m that is greater than x
"""
assert x >= 0
assert m > 0
return (x // m + 1) * m
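# For example: get_least_greater_multiple(0, 5000) == 5000,
# get_least_greater_multiple(5000, 5000) == 10000 (strictly greater, so a
# multiple of m itself advances), and get_least_greater_multiple(5001, 5000) == 10000.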
def create_log_func(summary_writer, prefix: str, examples_seen_so_far: int) -> Callable[[str, float], None]:
def log_func(tag: str, value: float):
summary_writer.add_scalar(prefix + "_" + tag, value, examples_seen_so_far)
return log_func
def set_learning_rate(optimizer: Optimizer, lr: float):
for param_group in optimizer.param_groups:
param_group['lr'] = lr
================================================
FILE: src/tha4/shion/core/training/validation_protocol.py
================================================
from abc import ABC, abstractmethod
from typing import Dict, Callable, Any
import torch
from torch.nn import Module
from tha4.shion.core.loss import Loss
class ValidationProtocol(ABC):
@abstractmethod
def get_batch_size(self) -> int:
pass
@abstractmethod
def get_examples_per_validation_iteration(self) -> int:
pass
@abstractmethod
def run_validation_iteration(
self,
batch: Any,
examples_seen_so_far: int,
modules: Dict[str, Module],
accumulated_modules: Dict[str, Module],
losses: Dict[str, Loss],
create_log_func: Callable[[str, int], Callable[[str, float], None]],
device: torch.device):
pass
class AbstractValidationProtocol(ValidationProtocol, ABC):
def __init__(self,
example_per_validation_iteration: int,
batch_size: int):
self.batch_size = batch_size
self.example_per_validation_iteration = example_per_validation_iteration
def get_batch_size(self) -> int:
return self.batch_size
def get_examples_per_validation_iteration(self) -> int:
return self.example_per_validation_iteration
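# A minimal concrete protocol sketch built on AbstractValidationProtocol; the
# module name 'main' is an assumption, and a plain L1 metric stands in for the
# Loss registry.
if __name__ == "__main__":
    import torch.nn.functional as F

    class L1ValidationProtocol(AbstractValidationProtocol):
        def run_validation_iteration(self, batch, examples_seen_so_far, modules,
                                     accumulated_modules, losses, create_log_func,
                                     device):
            image, target = (x.to(device) for x in batch)
            with torch.no_grad():
                metric = F.l1_loss(modules['main'](image), target).item()
            create_log_func('validation', examples_seen_so_far)('l1', metric)

    protocol = L1ValidationProtocol(example_per_validation_iteration=50000, batch_size=8)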
================================================
FILE: src/tha4/shion/nn00/__init__.py
================================================
================================================
FILE: src/tha4/shion/nn00/block_args.py
================================================
from typing import Optional
from tha4.shion.core.module_factory import ModuleFactory
from tha4.shion.nn00.linear_module_args import LinearModuleArgs
from tha4.shion.nn00.nonlinearity_factories import resolve_nonlinearity_factory
from tha4.shion.nn00.normalization_layer_factories import resolve_normalization_layer_factory
from tha4.shion.nn00.normalization_layer_factory import NormalizationLayerFactory
class BlockArgs:
def __init__(
self,
linear_module_args: Optional[LinearModuleArgs] = None,
normalization_layer_factory: Optional[NormalizationLayerFactory] = None,
nonlinearity_factory: Optional[ModuleFactory] = None):
if linear_module_args is None:
linear_module_args = LinearModuleArgs()
self.nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)
self.normalization_layer_factory = resolve_normalization_layer_factory(normalization_layer_factory)
self.linear_module_args = linear_module_args
================================================
FILE: src/tha4/shion/nn00/conv.py
================================================
from typing import Optional
from torch.nn import Conv2d, Module, Sequential, ConvTranspose2d
from tha4.shion.nn00.block_args import BlockArgs
from tha4.shion.nn00.linear_module_args import LinearModuleArgs, wrap_linear_module
def create_conv7(
in_channels: int,
out_channels: int,
bias: bool = False,
linear_module_args: Optional[LinearModuleArgs] = None) -> Module:
return wrap_linear_module(
Conv2d(in_channels, out_channels, kernel_size=7, stride=1, padding=3, bias=bias),
linear_module_args)
def create_conv3(in_channels: int,
out_channels: int,
bias: bool = False,
linear_module_args: Optional[LinearModuleArgs] = None) -> Module:
return wrap_linear_module(
Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=bias),
linear_module_args)
def create_conv1(
in_channels: int, out_channels: int,
bias: bool = False,
linear_module_args: Optional[LinearModuleArgs] = None) -> Module:
return wrap_linear_module(
Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
linear_module_args)
def create_conv7_block(
in_channels: int,
out_channels: int,
block_args: Optional[BlockArgs] = None) -> Module:
if block_args is None:
block_args = BlockArgs()
return Sequential(
create_conv7(
in_channels,
out_channels,
bias=False,
linear_module_args=block_args.linear_module_args),
block_args.normalization_layer_factory.create(out_channels, affine=True),
block_args.nonlinearity_factory.create())
def create_conv3_block(
in_channels: int,
out_channels: int,
block_args: Optional[BlockArgs] = None) -> Module:
if block_args is None:
block_args = BlockArgs()
return Sequential(
create_conv3(
in_channels,
out_channels,
bias=False,
linear_module_args=block_args.linear_module_args),
block_args.normalization_layer_factory.create(out_channels, affine=True),
block_args.nonlinearity_factory.create())
def create_downsample_block(
in_channels: int,
out_channels: int,
is_output_1x1: bool = False,
block_args: Optional[BlockArgs] = None) -> Module:
if block_args is None:
block_args = BlockArgs()
if is_output_1x1:
return Sequential(
wrap_linear_module(
Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
block_args.linear_module_args),
block_args.nonlinearity_factory.create())
else:
return Sequential(
wrap_linear_module(
Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
block_args.linear_module_args),
block_args.normalization_layer_factory.create(out_channels, affine=True),
block_args.nonlinearity_factory.create())
def create_upsample_block(
in_channels: int,
out_channels: int,
block_args: Optional[BlockArgs] = None) -> Module:
if block_args is None:
block_args = BlockArgs()
return Sequential(
wrap_linear_module(
ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
linear_module_args=block_args.linear_module_args),
block_args.normalization_layer_factory.create(out_channels, affine=True),
block_args.nonlinearity_factory.create())
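# A small smoke-test sketch: chain the helpers above into a three-level
# encoder; the channel counts and the 64x64 input size are illustrative.
if __name__ == "__main__":
    import torch
    encoder = Sequential(
        create_conv7_block(3, 16),
        create_downsample_block(16, 32),
        create_downsample_block(32, 64))
    x = torch.zeros(1, 3, 64, 64)
    assert encoder(x).shape == (1, 64, 16, 16)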
================================================
FILE: src/tha4/shion/nn00/initialization_funcs.py
================================================
from typing import Callable, Optional
import torch
from torch import zero_
from torch.nn import Module
from torch.nn.init import kaiming_normal_, xavier_normal_, normal_
class HeInitialization:
def __init__(self, a: float = 0.0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'):
self.nonlinearity = nonlinearity
self.mode = mode
self.a = a
def __call__(self, module: Module) -> Module:
with torch.no_grad():
kaiming_normal_(module.weight, a=self.a, mode=self.mode, nonlinearity=self.nonlinearity)
return module
class NormalInitialization:
def __init__(self, mean: float = 0.0, std: float = 1.0):
self.std = std
self.mean = mean
def __call__(self, module: Module) -> Module:
with torch.no_grad():
normal_(module.weight, self.mean, self.std)
return module
class XavierInitialization:
def __init__(self, gain: float = 1.0):
self.gain = gain
def __call__(self, module: Module) -> Module:
with torch.no_grad():
xavier_normal_(module.weight, self.gain)
return module
class ZeroInitialization:
def __call__(self, module: Module) -> Module:
with torch.no_grad():
zero_(module.weight)
return module
class NoInitialization:
def __call__(self, module: Module) -> Module:
return module
def resolve_initialization_func(initialization: Optional[Callable[[Module], Module]]):
if initialization is None:
return NoInitialization()
else:
return initialization
================================================
FILE: src/tha4/shion/nn00/linear_module_args.py
================================================
from typing import Optional, Callable
from torch.nn import Module
from torch.nn.utils import spectral_norm
from tha4.shion.nn00.initialization_funcs import resolve_initialization_func
class LinearModuleArgs:
def __init__(
self,
initialization_func: Optional[Callable[[Module], Module]] = None,
use_spectral_norm: bool = False):
self.use_spectral_norm = use_spectral_norm
self.initialization_func = resolve_initialization_func(initialization_func)
def wrap_linear_module(self, module: Module) -> Module:
module = self.initialization_func(module)
if self.use_spectral_norm:
module = spectral_norm(module)
return module
def wrap_linear_module(module: Module, linear_module_args: Optional[LinearModuleArgs] = None):
if linear_module_args is None:
linear_module_args = LinearModuleArgs()
module = linear_module_args.initialization_func(module)
if linear_module_args.use_spectral_norm:
module = spectral_norm(module)
return module
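# A minimal sketch: He-initialize a 3x3 convolution and wrap it with spectral
# normalization through LinearModuleArgs.
if __name__ == "__main__":
    from torch.nn import Conv2d
    from tha4.shion.nn00.initialization_funcs import HeInitialization
    conv = wrap_linear_module(
        Conv2d(3, 8, kernel_size=3, padding=1),
        LinearModuleArgs(initialization_func=HeInitialization(),
                         use_spectral_norm=True))
    # spectral_norm reparameterizes the weight as weight_orig
    assert hasattr(conv, 'weight_orig')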
================================================
FILE: src/tha4/shion/nn00/nonlinearity_factories.py
================================================
from typing import Optional
import torch
from torch import Tensor
from torch.nn import Module, ReLU, LeakyReLU, ELU, ReLU6, Hardswish, SiLU, Tanh, Sigmoid
from tha4.shion.core.module_factory import ModuleFactory
class ReLUFactory(ModuleFactory):
def __init__(self, inplace: bool = False):
self.inplace = inplace
def create(self) -> Module:
return ReLU(self.inplace)
class LeakyReLUFactory(ModuleFactory):
def __init__(self, inplace: bool = False, negative_slope: float = 1e-2):
self.negative_slope = negative_slope
self.inplace = inplace
def create(self) -> Module:
return LeakyReLU(inplace=self.inplace, negative_slope=self.negative_slope)
class ELUFactory(ModuleFactory):
def __init__(self, inplace: bool = False, alpha: float = 1.0):
self.alpha = alpha
self.inplace = inplace
def create(self) -> Module:
return ELU(inplace=self.inplace, alpha=self.alpha)
class ReLU6Factory(ModuleFactory):
def __init__(self, inplace: bool = False):
self.inplace = inplace
def create(self) -> Module:
return ReLU6(inplace=self.inplace)
class SiLUFactory(ModuleFactory):
def __init__(self, inplace: bool = False):
self.inplace = inplace
def create(self) -> Module:
return SiLU(inplace=self.inplace)
class HardswishFactory(ModuleFactory):
def __init__(self, inplace: bool = False):
self.inplace = inplace
def create(self) -> Module:
return Hardswish(inplace=self.inplace)
class TanhFactory(ModuleFactory):
def create(self) -> Module:
return Tanh()
class SigmoidFactory(ModuleFactory):
def create(self) -> Module:
return Sigmoid()
class Swish(Module):
def __init__(self):
super().__init__()
def forward(self, x: Tensor):
return x * torch.sigmoid(x)
class SwishFactory(ModuleFactory):
def create(self) -> Module:
return Swish()
def resolve_nonlinearity_factory(nonlinearity_factory: Optional[ModuleFactory]) -> ModuleFactory:
if nonlinearity_factory is None:
return ReLUFactory(inplace=False)
else:
return nonlinearity_factory
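# Swish as defined above is numerically the same function as torch's built-in
# SiLU, x * sigmoid(x); a quick check:
if __name__ == "__main__":
    x = torch.linspace(-4.0, 4.0, steps=9)
    assert torch.allclose(SwishFactory().create()(x), SiLUFactory().create()(x))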
================================================
FILE: src/tha4/shion/nn00/normalization_layer_factories.py
================================================
from typing import Optional
import torch
from torch.nn import Module, Parameter, BatchNorm2d, InstanceNorm2d, GroupNorm
from torch.nn.functional import layer_norm
from torch.nn.init import normal_, constant_
from tha4.shion.nn00.normalization_layer_factory import NormalizationLayerFactory
from tha4.shion.nn00.pass_through import PassThrough
class Bias2d(Module):
def __init__(self, num_features: int):
super().__init__()
self.num_features = num_features
self.bias = Parameter(torch.zeros(1, num_features, 1, 1))
def forward(self, x):
return x + self.bias
class NoNorm2dFactory(NormalizationLayerFactory):
def __init__(self):
super().__init__()
def create(self, num_features: int, affine: bool = True) -> Module:
if affine:
return Bias2d(num_features)
else:
return PassThrough()
class BatchNorm2dFactory(NormalizationLayerFactory):
def __init__(self,
weight_mean: Optional[float] = None,
weight_std: Optional[float] = None,
bias: Optional[float] = None):
super().__init__()
self.bias = bias
self.weight_std = weight_std
self.weight_mean = weight_mean
def get_weight_mean(self):
if self.weight_mean is None:
return 1.0
else:
return self.weight_mean
def get_weight_std(self):
if self.weight_std is None:
return 0.02
else:
return self.weight_std
def create(self, num_features: int, affine: bool = True) -> Module:
module = BatchNorm2d(num_features=num_features, affine=affine)
if affine:
if self.weight_mean is not None or self.weight_std is not None:
normal_(module.weight, self.get_weight_mean(), self.get_weight_std())
if self.bias is not None:
constant_(module.bias, self.bias)
return module
class InstanceNorm2dFactory(NormalizationLayerFactory):
def __init__(self):
super().__init__()
def create(self, num_features: int, affine: bool = True) -> Module:
return InstanceNorm2d(num_features=num_features, affine=affine)
class LayerNorm2d(Module):
def __init__(self, channels: int, affine: bool = True):
super(LayerNorm2d, self).__init__()
self.channels = channels
self.affine = affine
if self.affine:
self.weight = Parameter(torch.ones(1, channels, 1, 1))
self.bias = Parameter(torch.zeros(1, channels, 1, 1))
def forward(self, x):
shape = x.size()[1:]
y = layer_norm(x, shape)
if self.affine:
    y = y * self.weight + self.bias
return y
class LayerNorm2dFactory(NormalizationLayerFactory):
def __init__(self):
super().__init__()
def create(self, num_features: int, affine: bool = True) -> Module:
return LayerNorm2d(channels=num_features, affine=affine)
class GroupNormFactory(NormalizationLayerFactory):
def __init__(self, num_groups: int, eps=1e-6):
super().__init__()
self.eps = eps
self.num_groups = num_groups
def create(self, num_features: int, affine: bool = True) -> Module:
return GroupNorm(num_channels=num_features, num_groups=self.num_groups, eps=self.eps, affine=affine)
def resolve_normalization_layer_factory(factory: Optional['NormalizationLayerFactory']) -> 'NormalizationLayerFactory':
if factory is None:
return InstanceNorm2dFactory()
else:
return factory
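# A small sketch: every factory above yields a layer that preserves the
# (N, C, H, W) shape for the requested number of features.
if __name__ == "__main__":
    x = torch.randn(2, 4, 8, 8)
    for factory in [NoNorm2dFactory(), BatchNorm2dFactory(), InstanceNorm2dFactory(),
                    LayerNorm2dFactory(), GroupNormFactory(num_groups=2)]:
        assert factory.create(num_features=4)(x).shape == x.shape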
================================================
FILE: src/tha4/shion/nn00/normalization_layer_factory.py
================================================
from abc import ABC, abstractmethod
from torch.nn import Module
class NormalizationLayerFactory(ABC):
def __init__(self):
super().__init__()
@abstractmethod
def create(self, num_features: int, affine: bool = True) -> Module:
pass
================================================
FILE: src/tha4/shion/nn00/pass_through.py
================================================
from torch.nn import Module
class PassThrough(Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x
================================================
FILE: src/tha4/shion/nn00/resnet_block.py
================================================
from typing import Optional
import torch
from torch.nn import Module, Sequential, Parameter
from tha4.shion.nn00.block_args import BlockArgs
from tha4.shion.nn00.conv import create_conv1, create_conv3
class ResnetBlock(Module):
def __init__(self,
num_channels: int,
is1x1: bool = False,
use_scale_parameter: bool = False,
block_args: Optional[BlockArgs] = None):
super().__init__()
if block_args is None:
block_args = BlockArgs()
self.use_scale_parameter = use_scale_parameter
if self.use_scale_parameter:
self.scale = Parameter(torch.zeros(1))
if is1x1:
self.resnet_path = Sequential(
create_conv1(
num_channels,
num_channels,
bias=True,
linear_module_args=block_args.linear_module_args),
block_args.nonlinearity_factory.create(),
create_conv1(
num_channels,
num_channels,
bias=True,
linear_module_args=block_args.linear_module_args))
else:
self.resnet_path = Sequential(
create_conv3(
num_channels,
num_channels,
bias=False,
linear_module_args=block_args.linear_module_args),
block_args.normalization_layer_factory.create(num_channels, affine=True),
block_args.nonlinearity_factory.create(),
create_conv3(
num_channels,
num_channels,
bias=False,
linear_module_args=block_args.linear_module_args),
block_args.normalization_layer_factory.create(num_channels, affine=True))
def forward(self, x):
if self.use_scale_parameter:
return x + self.scale * self.resnet_path(x)
else:
return x + self.resnet_path(x)
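# A minimal sketch: the block preserves shape, and with use_scale_parameter=True
# it starts out as the identity mapping because scale is initialized to zero.
if __name__ == "__main__":
    x = torch.randn(2, 8, 16, 16)
    block = ResnetBlock(num_channels=8, use_scale_parameter=True)
    assert torch.equal(block(x), x)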