Repository: XingangPan/DragGAN
Branch: main
Commit: 336f120ce126
Files: 186
Total size: 1.5 MB
Directory structure:
gitextract_5nncgy8o/
├── .gitignore
├── Dockerfile
├── LICENSE.txt
├── README.md
├── dnnlib/
│ ├── __init__.py
│ └── util.py
├── environment.yml
├── gen_images.py
├── gradio_utils/
│ ├── __init__.py
│ └── utils.py
├── gui_utils/
│ ├── __init__.py
│ ├── gl_utils.py
│ ├── glfw_window.py
│ ├── imgui_utils.py
│ ├── imgui_window.py
│ └── text_utils.py
├── legacy.py
├── requirements.txt
├── stylegan_human/
│ ├── .gitignore
│ ├── PP_HumanSeg/
│ │ ├── deploy/
│ │ │ └── infer.py
│ │ ├── export_model/
│ │ │ └── download_export_model.py
│ │ └── pretrained_model/
│ │ └── download_pretrained_model.py
│ ├── README.md
│ ├── __init__.py
│ ├── alignment.py
│ ├── bg_white.py
│ ├── dnnlib/
│ │ ├── __init__.py
│ │ ├── tflib/
│ │ │ ├── __init__.py
│ │ │ ├── autosummary.py
│ │ │ ├── custom_ops.py
│ │ │ ├── network.py
│ │ │ ├── ops/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── fused_bias_act.cu
│ │ │ │ ├── fused_bias_act.py
│ │ │ │ ├── upfirdn_2d.cu
│ │ │ │ └── upfirdn_2d.py
│ │ │ ├── optimizer.py
│ │ │ └── tfutil.py
│ │ └── util.py
│ ├── docs/
│ │ └── Dataset.md
│ ├── edit/
│ │ ├── __init__.py
│ │ ├── edit_config.py
│ │ └── edit_helper.py
│ ├── edit.py
│ ├── environment.yml
│ ├── generate.py
│ ├── insetgan.py
│ ├── interpolation.py
│ ├── latent_direction/
│ │ └── ss_statics/
│ │ ├── bottom_length_statis/
│ │ │ ├── 3/
│ │ │ │ └── statis.csv
│ │ │ ├── 4/
│ │ │ │ └── statis.csv
│ │ │ └── 5/
│ │ │ └── statis.csv
│ │ └── upper_length_statis/
│ │ └── 5/
│ │ └── statis.csv
│ ├── legacy.py
│ ├── openpose/
│ │ ├── model/
│ │ │ └── .gitkeep
│ │ └── src/
│ │ ├── __init__.py
│ │ ├── body.py
│ │ ├── model.py
│ │ └── util.py
│ ├── pti/
│ │ ├── pti_configs/
│ │ │ ├── __init__.py
│ │ │ ├── global_config.py
│ │ │ ├── hyperparameters.py
│ │ │ └── paths_config.py
│ │ ├── pti_models/
│ │ │ ├── __init__.py
│ │ │ └── e4e/
│ │ │ ├── __init__.py
│ │ │ ├── encoders/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── helpers.py
│ │ │ │ ├── model_irse.py
│ │ │ │ └── psp_encoders.py
│ │ │ ├── latent_codes_pool.py
│ │ │ ├── psp.py
│ │ │ └── stylegan2/
│ │ │ ├── __init__.py
│ │ │ ├── model.py
│ │ │ └── op/
│ │ │ ├── __init__.py
│ │ │ ├── fused_act.py
│ │ │ ├── fused_bias_act.cpp
│ │ │ ├── fused_bias_act_kernel.cu
│ │ │ ├── upfirdn2d.cpp
│ │ │ ├── upfirdn2d.py
│ │ │ └── upfirdn2d_kernel.cu
│ │ └── training/
│ │ ├── __init__.py
│ │ ├── coaches/
│ │ │ ├── __init__.py
│ │ │ ├── base_coach.py
│ │ │ ├── localitly_regulizer.py
│ │ │ ├── multi_id_coach.py
│ │ │ └── single_id_coach.py
│ │ └── projectors/
│ │ ├── __init__.py
│ │ ├── w_plus_projector.py
│ │ └── w_projector.py
│ ├── run_pti.py
│ ├── style_mixing.py
│ ├── stylemixing_video.py
│ ├── torch_utils/
│ │ ├── __init__.py
│ │ ├── custom_ops.py
│ │ ├── misc.py
│ │ ├── models.py
│ │ ├── models_face.py
│ │ ├── op_edit/
│ │ │ ├── __init__.py
│ │ │ ├── fused_act.py
│ │ │ ├── fused_bias_act.cpp
│ │ │ ├── fused_bias_act_kernel.cu
│ │ │ ├── upfirdn2d.cpp
│ │ │ ├── upfirdn2d.py
│ │ │ └── upfirdn2d_kernel.cu
│ │ ├── ops/
│ │ │ ├── __init__.py
│ │ │ ├── bias_act.cpp
│ │ │ ├── bias_act.cu
│ │ │ ├── bias_act.h
│ │ │ ├── bias_act.py
│ │ │ ├── conv2d_gradfix.py
│ │ │ ├── conv2d_resample.py
│ │ │ ├── filtered_lrelu.cpp
│ │ │ ├── filtered_lrelu.cu
│ │ │ ├── filtered_lrelu.h
│ │ │ ├── filtered_lrelu.py
│ │ │ ├── filtered_lrelu_ns.cu
│ │ │ ├── filtered_lrelu_rd.cu
│ │ │ ├── filtered_lrelu_wr.cu
│ │ │ ├── fma.py
│ │ │ ├── grid_sample_gradfix.py
│ │ │ ├── upfirdn2d.cpp
│ │ │ ├── upfirdn2d.cu
│ │ │ ├── upfirdn2d.h
│ │ │ └── upfirdn2d.py
│ │ ├── persistence.py
│ │ └── training_stats.py
│ ├── training/
│ │ ├── __init__.py
│ │ ├── augment.py
│ │ ├── dataset.py
│ │ ├── loss.py
│ │ ├── networks_stylegan2.py
│ │ ├── networks_stylegan3.py
│ │ └── training_loop.py
│ ├── training_scripts/
│ │ ├── sg2/
│ │ │ ├── train.py
│ │ │ └── training/
│ │ │ ├── dataset.py
│ │ │ └── networks.py
│ │ └── sg3/
│ │ ├── train.py
│ │ └── training/
│ │ ├── dataset.py
│ │ ├── networks_stylegan2.py
│ │ └── networks_stylegan3.py
│ └── utils/
│ ├── ImagesDataset.py
│ ├── __init__.py
│ ├── data_utils.py
│ ├── face_alignment.py
│ ├── log_utils.py
│ ├── models_utils.py
│ └── util.py
├── torch_utils/
│ ├── __init__.py
│ ├── custom_ops.py
│ ├── misc.py
│ ├── ops/
│ │ ├── __init__.py
│ │ ├── bias_act.cpp
│ │ ├── bias_act.cu
│ │ ├── bias_act.h
│ │ ├── bias_act.py
│ │ ├── conv2d_gradfix.py
│ │ ├── conv2d_resample.py
│ │ ├── filtered_lrelu.cpp
│ │ ├── filtered_lrelu.cu
│ │ ├── filtered_lrelu.h
│ │ ├── filtered_lrelu.py
│ │ ├── filtered_lrelu_ns.cu
│ │ ├── filtered_lrelu_rd.cu
│ │ ├── filtered_lrelu_wr.cu
│ │ ├── fma.py
│ │ ├── grid_sample_gradfix.py
│ │ ├── upfirdn2d.cpp
│ │ ├── upfirdn2d.cu
│ │ ├── upfirdn2d.h
│ │ └── upfirdn2d.py
│ ├── persistence.py
│ └── training_stats.py
├── training/
│ ├── __init__.py
│ ├── augment.py
│ ├── dataset.py
│ ├── loss.py
│ ├── networks_stylegan2.py
│ ├── networks_stylegan3.py
│ └── training_loop.py
├── visualizer_drag.py
├── visualizer_drag_gradio.py
└── viz/
├── __init__.py
├── capture_widget.py
├── drag_widget.py
├── latent_widget.py
├── pickle_widget.py
└── renderer.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Created by .ignore support plugin (hsz.mobi)
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# IPython Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# dotenv
.env
# virtualenv
venv/
ENV/
# Spyder project settings
.spyderproject
# Rope project settings
.ropeproject
### VirtualEnv template
# Virtualenv
# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
.Python
[Bb]in
[Ii]nclude
[Ll]ib
[Ll]ib64
[Ll]ocal
[Ss]cripts
!scripts/
pyvenv.cfg
.venv
pip-selfcheck.json
### JetBrains template
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# User-specific stuff:
.idea/workspace.xml
.idea/tasks.xml
.idea/dictionaries
.idea/vcs.xml
.idea/jsLibraryMappings.xml
# Sensitive or high-churn files:
.idea/dataSources.ids
.idea/dataSources.xml
.idea/dataSources.local.xml
.idea/sqlDataSources.xml
.idea/dynamic.xml
.idea/uiDesigner.xml
# Gradle:
.idea/gradle.xml
.idea/libraries
# Mongo Explorer plugin:
.idea/mongoSettings.xml
.idea/
## File-based project format:
*.iws
## Plugin-specific files:
# IntelliJ
/out/
# mpeltonen/sbt-idea plugin
.idea_modules/
# JIRA plugin
atlassian-ide-plugin.xml
# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties
# Mac related
.DS_Store
checkpoints
================================================
FILE: Dockerfile
================================================
FROM nvcr.io/nvidia/pytorch:23.05-py3
ENV PYTHONDONTWRITEBYTECODE 1
ENV PYTHONUNBUFFERED 1
RUN apt-get update && apt-get install -y --no-install-recommends \
make \
pkgconf \
xz-utils \
xorg-dev \
libgl1-mesa-dev \
libglu1-mesa-dev \
libxrandr-dev \
libxinerama-dev \
libxcursor-dev \
libxi-dev \
libxxf86vm-dev \
&& rm -rf /var/lib/apt/lists/*
RUN pip install --no-cache-dir --upgrade pip
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
WORKDIR /workspace
RUN (printf '#!/bin/bash\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh
ENTRYPOINT ["/entry.sh"]
================================================
FILE: LICENSE.txt
================================================
Copyright (c) 2021, NVIDIA Corporation & affiliates. All rights reserved.
NVIDIA Source Code License for StyleGAN3
=======================================================================
1. Definitions
"Licensor" means any person or entity that distributes its Work.
"Software" means the original work of authorship made available under
this License.
"Work" means the Software and any additions to or derivative works of
the Software that are made available under this License.
The terms "reproduce," "reproduction," "derivative works," and
"distribution" have the meaning as provided under U.S. copyright law;
provided, however, that for the purposes of this License, derivative
works shall not include works that remain separable from, or merely
link (or bind by name) to the interfaces of, the Work.
Works, including the Software, are "made available" under this License
by including in or with the Work either (a) a copyright notice
referencing the applicability of this License to the Work, or (b) a
copy of this License.
2. License Grants
2.1 Copyright Grant. Subject to the terms and conditions of this
License, each Licensor grants to you a perpetual, worldwide,
non-exclusive, royalty-free, copyright license to reproduce,
prepare derivative works of, publicly display, publicly perform,
sublicense and distribute its Work and any resulting derivative
works in any form.
3. Limitations
3.1 Redistribution. You may reproduce or distribute the Work only
if (a) you do so under this License, (b) you include a complete
copy of this License with your distribution, and (c) you retain
without modification any copyright, patent, trademark, or
attribution notices that are present in the Work.
3.2 Derivative Works. You may specify that additional or different
terms apply to the use, reproduction, and distribution of your
derivative works of the Work ("Your Terms") only if (a) Your Terms
provide that the use limitation in Section 3.3 applies to your
derivative works, and (b) you identify the specific derivative
works that are subject to Your Terms. Notwithstanding Your Terms,
this License (including the redistribution requirements in Section
3.1) will continue to apply to the Work itself.
3.3 Use Limitation. The Work and any derivative works thereof only
may be used or intended for use non-commercially. Notwithstanding
the foregoing, NVIDIA and its affiliates may use the Work and any
derivative works commercially. As used herein, "non-commercially"
means for research or evaluation purposes only.
3.4 Patent Claims. If you bring or threaten to bring a patent claim
against any Licensor (including any claim, cross-claim or
counterclaim in a lawsuit) to enforce any patents that you allege
are infringed by any Work, then your rights under this License from
such Licensor (including the grant in Section 2.1) will terminate
immediately.
3.5 Trademarks. This License does not grant any rights to use any
Licensor’s or its affiliates’ names, logos, or trademarks, except
as necessary to reproduce the notices described in this License.
3.6 Termination. If you violate any term of this License, then your
rights under this License (including the grant in Section 2.1) will
terminate immediately.
4. Disclaimer of Warranty.
THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
THIS LICENSE.
5. Limitation of Liability.
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
THE POSSIBILITY OF SUCH DAMAGES.
=======================================================================
================================================
FILE: README.md
================================================
Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold
Xingang Pan
·
Ayush Tewari
·
Thomas Leimkühler
·
Lingjie Liu
·
Abhimitra Meka
·
Christian Theobalt
SIGGRAPH 2023 Conference Proceedings
## Web Demos
[DragGAN web demo on OpenXLab](https://openxlab.org.cn/apps/detail/XingangPan/DragGAN)
## Requirements
If you have a CUDA-capable graphics card, please follow the requirements of [NVlabs/stylegan3](https://github.com/NVlabs/stylegan3#requirements).
The usual installation steps involve the following commands, which should set up the correct CUDA version and all the Python packages:
```
conda env create -f environment.yml
conda activate stylegan3
```
Then install the additional requirements
```
pip install -r requirements.txt
```
Otherwise (for GPU acceleration on macOS with Apple Silicon M1/M2, or for CPU only), try the following:
```sh
cat environment.yml | \
grep -v -E 'nvidia|cuda' > environment-no-nvidia.yml && \
conda env create -f environment-no-nvidia.yml
conda activate stylegan3
# On MacOS
export PYTORCH_ENABLE_MPS_FALLBACK=1
```
## Run Gradio visualizer in Docker
The provided Docker image is based on the NGC PyTorch repository. To quickly try out the visualizer in Docker, run the following:
```sh
# before you build the docker container, make sure you have cloned this repo, and downloaded the pretrained model by `python scripts/download_model.py`.
docker build . -t draggan:latest
docker run -p 7860:7860 -v "$PWD":/workspace/src -it draggan:latest bash
# (Use GPU) If you want to use your NVIDIA GPU for acceleration in Docker, add the `--gpus all` flag, like:
# docker run --gpus all -p 7860:7860 -v "$PWD":/workspace/src -it draggan:latest bash
cd src && python visualizer_drag_gradio.py --listen
```
Now you can open a shared link from Gradio (printed in the terminal console).
Beware the Docker image takes about 25GB of disk space!
## Download pre-trained StyleGAN2 weights
To download pre-trained weights, simply run:
```
python scripts/download_model.py
```
If you want to try StyleGAN-Human and the Landscapes HQ (LHQ) dataset, please download weights from these links: [StyleGAN-Human](https://drive.google.com/file/d/1dlFEHbu-WzQWJl7nBBZYcTyo000H9hVm/view?usp=sharing), [LHQ](https://drive.google.com/file/d/16twEf0T9QINAEoMsWefoWiyhcTd-aiWc/view?usp=sharing), and put them under `./checkpoints`.
Feel free to try other pretrained StyleGAN.
## Run DragGAN GUI
To start the DragGAN GUI, simply run:
```sh
sh scripts/gui.sh
```
If you are using Windows, you can run:
```
.\scripts\gui.bat
```
This GUI supports editing GAN-generated images. To edit a real image, you need to first perform GAN inversion using tools like [PTI](https://github.com/danielroich/PTI). Then load the new latent code and model weights to the GUI.
You can also run the DragGAN Gradio demo, which works on both Windows and Linux:
```sh
python visualizer_drag_gradio.py
```
## Acknowledgement
This code is developed based on [StyleGAN3](https://github.com/NVlabs/stylegan3). Part of the code is borrowed from [StyleGAN-Human](https://github.com/stylegan-human/StyleGAN-Human).
(cheers to the community as well)
## License
The code related to the DragGAN algorithm is licensed under [CC-BY-NC](https://creativecommons.org/licenses/by-nc/4.0/).
However, most of this project is available under separate license terms: all code used or modified from [StyleGAN3](https://github.com/NVlabs/stylegan3) is under the [Nvidia Source Code License](https://github.com/NVlabs/stylegan3/blob/main/LICENSE.txt).
Any form of use and derivative of this code must preserve the watermarking functionality showing "AI Generated".
## BibTeX
```bibtex
@inproceedings{pan2023draggan,
title={Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold},
author={Pan, Xingang and Tewari, Ayush and Leimk{\"u}hler, Thomas and Liu, Lingjie and Meka, Abhimitra and Theobalt, Christian},
booktitle = {ACM SIGGRAPH 2023 Conference Proceedings},
year={2023}
}
```
================================================
FILE: dnnlib/__init__.py
================================================
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from .util import EasyDict, make_cache_dir_path
================================================
FILE: dnnlib/util.py
================================================
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Miscellaneous utility classes and functions."""
import ctypes
import fnmatch
import importlib
import inspect
import numpy as np
import os
import shutil
import sys
import types
import io
import pickle
import re
import requests
import html
import hashlib
import glob
import tempfile
import urllib
import urllib.request
import uuid
from distutils.util import strtobool
from typing import Any, List, Tuple, Union
# Util classes
# ------------------------------------------------------------------------------------------
class EasyDict(dict):
    """A dict whose entries can also be read, written and deleted as attributes."""

    def __getattr__(self, name: str) -> Any:
        # Python only calls __getattr__ after normal attribute lookup has
        # failed, so regular dict methods are never shadowed by keys.
        try:
            value = self[name]
        except KeyError:
            raise AttributeError(name)
        return value

    def __setattr__(self, name: str, value: Any) -> None:
        # Attribute assignment stores the value as a dict entry.
        self.__setitem__(name, value)

    def __delattr__(self, name: str) -> None:
        # Attribute deletion removes the dict entry (KeyError if absent).
        self.__delitem__(name)
class Logger:
    """Tee for the standard streams.

    While an instance is active it replaces both ``sys.stdout`` and
    ``sys.stderr``. Everything written to either one is forwarded to the
    original stdout and, when a file name was given, also to that file.
    Flushing after every write is optional.
    """

    def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
        # No file mirroring unless a file name was supplied.
        self.file = None if file_name is None else open(file_name, file_mode)
        self.should_flush = should_flush

        # Remember the streams being replaced so close() can restore them.
        self.stdout, self.stderr = sys.stdout, sys.stderr
        sys.stdout = sys.stderr = self

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def write(self, text: Union[str, bytes]) -> None:
        """Forward ``text`` to stdout (and the log file), flushing if configured."""
        if isinstance(text, bytes):
            text = text.decode()
        if len(text) == 0:  # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
            return

        log_file = self.file
        if log_file is not None:
            log_file.write(text)
        self.stdout.write(text)

        if self.should_flush:
            self.flush()

    def flush(self) -> None:
        """Flush the log file (when one is open) and the original stdout."""
        log_file = self.file
        if log_file is not None:
            log_file.flush()
        self.stdout.flush()

    def close(self) -> None:
        """Flush, undo the stdout/stderr redirection and close the log file."""
        self.flush()

        # With several loggers alive, only restore a stream that still points
        # at this instance so loggers closed out of order do not clobber
        # each other.
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr

        if self.file is None:
            return
        self.file.close()
        self.file = None
# Cache directories
# ------------------------------------------------------------------------------------------
# Explicit cache root set through set_cache_dir(); None means "not configured".
_dnnlib_cache_dir = None


def set_cache_dir(path: str) -> None:
    """Set the cache root that make_cache_dir_path() uses in preference to the environment."""
    global _dnnlib_cache_dir
    _dnnlib_cache_dir = path


def make_cache_dir_path(*paths: str) -> str:
    """Join ``paths`` onto the cache root and return the resulting path.

    The root is the first of these that is available: the directory given to
    set_cache_dir(), ``$DNNLIB_CACHE_DIR``, ``$HOME/.cache/dnnlib``,
    ``$USERPROFILE/.cache/dnnlib``, or ``.cache/dnnlib`` inside the system
    temporary directory. Nothing is created on disk.
    """
    if _dnnlib_cache_dir is not None:
        return os.path.join(_dnnlib_cache_dir, *paths)

    env = os.environ
    if 'DNNLIB_CACHE_DIR' in env:
        root_parts = [env['DNNLIB_CACHE_DIR']]
    elif 'HOME' in env:
        root_parts = [env['HOME'], '.cache', 'dnnlib']
    elif 'USERPROFILE' in env:
        root_parts = [env['USERPROFILE'], '.cache', 'dnnlib']
    else:
        root_parts = [tempfile.gettempdir(), '.cache', 'dnnlib']
    return os.path.join(*root_parts, *paths)
# Small util functions
# ------------------------------------------------------------------------------------------
def format_time(seconds: Union[int, float]) -> str:
    """Render a duration in seconds as a short human-readable string.

    The value is rounded to whole seconds first. Depending on magnitude the
    result looks like ``"42s"``, ``"3m 07s"``, ``"2h 05m 09s"`` or
    ``"1d 03h 15m"`` (seconds are dropped once days are shown).
    """
    total = int(np.rint(seconds))
    if total < 60:
        return f"{total}s"

    minutes, secs = divmod(total, 60)
    if minutes < 60:
        return f"{minutes}m {secs:02}s"

    hours, mins = divmod(minutes, 60)
    if hours < 24:
        return f"{hours}h {mins:02}m {secs:02}s"

    days, hrs = divmod(hours, 24)
    return f"{days}d {hrs:02}h {mins:02}m"
def format_time_brief(seconds: Union[int, float]) -> str:
    """Render a duration in seconds using at most its two largest units.

    The value is rounded to whole seconds first. Depending on magnitude the
    result looks like ``"42s"``, ``"3m 07s"``, ``"2h 05m"`` or ``"1d 03h"``.
    """
    total = int(np.rint(seconds))
    if total < 60:
        return f"{total}s"

    minutes, secs = divmod(total, 60)
    if minutes < 60:
        return f"{minutes}m {secs:02}s"

    hours, mins = divmod(minutes, 60)
    if hours < 24:
        return f"{hours}h {mins:02}m"

    days, hrs = divmod(hours, 24)
    return f"{days}d {hrs:02}h"
def ask_yes_no(question: str) -> bool:
    """Ask the user the question until the user inputs a valid answer.

    The prompt is printed followed by ``[y/n]`` and one line is read from
    standard input. Answers are matched case-insensitively, ignoring
    surrounding whitespace. Accepted values are y/yes/t/true/on/1 and
    n/no/f/false/off/0, the same vocabulary as ``distutils.util.strtobool``.

    Returns:
        ``True`` for an affirmative answer, ``False`` for a negative one.
    """
    # The vocabulary is inlined so this function returns a real bool and does
    # not depend on distutils, which was removed in Python 3.12.
    affirmative = {"y", "yes", "t", "true", "on", "1"}
    negative = {"n", "no", "f", "false", "off", "0"}

    while True:
        print("{0} [y/n]".format(question))
        answer = input().strip().lower()
        if answer in affirmative:
            return True
        if answer in negative:
            return False
        # Anything else is not a valid answer: ask again.
def tuple_product(t: Tuple) -> Any:
    """Multiply all elements of ``t`` together; an empty tuple yields 1."""
    product = 1
    for factor in t:
        product *= factor
    return product
# Type name -> ctypes type used by get_dtype_and_ctype().
_str_to_ctype = dict(
    uint8=ctypes.c_ubyte,
    uint16=ctypes.c_uint16,
    uint32=ctypes.c_uint32,
    uint64=ctypes.c_uint64,
    int8=ctypes.c_byte,
    int16=ctypes.c_int16,
    int32=ctypes.c_int32,
    int64=ctypes.c_int64,
    float32=ctypes.c_float,
    float64=ctypes.c_double,
)


def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
    """Return the NumPy dtype and the ctypes type named by ``type_obj``.

    ``type_obj`` is either the type name itself (a string) or an object that
    exposes the name through a ``__name__`` or ``name`` attribute, checked in
    that order. The two returned types are asserted to have the same size in
    bytes.

    Raises:
        RuntimeError: if no type name can be derived from ``type_obj``.
        AssertionError: if the name is not one of the supported type names.
    """
    if isinstance(type_obj, str):
        type_str = type_obj
    else:
        for attr in ("__name__", "name"):
            if hasattr(type_obj, attr):
                type_str = getattr(type_obj, attr)
                break
        else:
            raise RuntimeError("Cannot infer type name from input")

    assert type_str in _str_to_ctype

    my_ctype = _str_to_ctype[type_str]
    my_dtype = np.dtype(type_str)
    assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
    return my_dtype, my_ctype
def is_pickleable(obj: Any) -> bool:
    """Return True when ``obj`` can be serialized with ``pickle.dump``.

    The check is performed by actually pickling the object into an in-memory
    buffer. Any ordinary exception raised while pickling means "not
    pickleable"; ``KeyboardInterrupt`` and ``SystemExit`` are deliberately
    not swallowed so the process can still be interrupted during the check.
    """
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
    except Exception:
        return False
    return True
# Functionality to import modules/objects by name, and call functions by name
# ------------------------------------------------------------------------------------------
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Searches for the underlying module behind the name to some python object.

    The dotted name is split at every possible position, longest module part
    first, and the first split whose module imports and whose remainder
    resolves as an attribute path wins. The shorthands ``np.`` and ``tf.``
    are expanded to ``numpy.`` and ``tensorflow.``.

    Returns:
        The module and the object name (original name with module part removed).

    Raises:
        ImportError: if a candidate module fails to import for a reason other
            than not existing, or if nothing matches at all.
        AttributeError: if a module imports but the attribute path is missing.
    """
    # allow convenience shorthands, substitute them by full names
    # (the dot is escaped so that e.g. "npx.foo" is not mistaken for "np.")
    obj_name = re.sub(r"^np\.", "numpy.", obj_name)
    obj_name = re.sub(r"^tf\.", "tensorflow.", obj_name)

    # list alternatives for (module_name, local_obj_name)
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]

    # try each alternative in turn
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
            return module, local_obj_name
        except Exception:
            # Deliberate best effort: this split did not work, try the next one.
            continue

    # maybe some of the modules themselves contain errors?
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name) # may raise ImportError
        except ImportError as err:
            # A plain "module does not exist" is expected here; anything else
            # is a genuine error inside the module and is surfaced.
            if not str(err).startswith("No module named '" + module_name + "'"):
                raise

    # maybe the requested attribute is missing?
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
        except ImportError:
            pass

    # we are out of luck, but we have no idea why
    raise ImportError(obj_name)


def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Traverses the object name and returns the last (rightmost) python object.

    An empty ``obj_name`` returns ``module`` itself. A missing attribute along
    the dotted path raises ``AttributeError``.
    """
    if obj_name == '':
        return module
    obj = module
    for part in obj_name.split("."):
        obj = getattr(obj, part)
    return obj
def get_obj_by_name(name: str) -> Any:
    """Resolve a dotted name to the python object it refers to."""
    owner, attr_path = get_module_from_obj_name(name)
    resolved = get_obj_from_module(owner, attr_path)
    return resolved
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Look up the callable named ``func_name`` and call it with the remaining arguments.

    ``func_name`` is keyword-only and required; the positional and other
    keyword arguments are passed through unchanged. The call's result is returned.
    """
    assert func_name is not None
    target = get_obj_by_name(func_name)
    assert callable(target)
    outcome = target(*args, **kwargs)
    return outcome
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
    """Look up the class named ``class_name`` and instantiate it with the remaining arguments."""
    instance = call_func_by_name(*args, func_name=class_name, **kwargs)
    return instance
def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Return the directory of the source file of the module that provides ``obj_name``."""
    owner = get_module_from_obj_name(obj_name)[0]
    module_file = inspect.getfile(owner)
    return os.path.dirname(module_file)
def is_top_level_function(obj: Any) -> bool:
    """Tell whether ``obj`` is callable and its name is bound at the top level of its module.

    The test is by name only: it checks that ``obj.__name__`` is a key of the
    namespace of the module recorded in ``obj.__module__``.
    """
    if not callable(obj):
        return False
    own_name = obj.__name__
    module_namespace = sys.modules[obj.__module__].__dict__
    return own_name in module_namespace
def get_top_level_function_name(obj: Any) -> str:
    """Build the ``module.function`` name of a top-level function.

    For functions defined in ``__main__`` the module part is replaced by the
    script's file name without directory and extension.
    """
    assert is_top_level_function(obj)
    module_name = obj.__module__
    if module_name == '__main__':
        main_file = sys.modules[module_name].__file__
        module_name = os.path.splitext(os.path.basename(main_file))[0]
    return module_name + "." + obj.__name__
# File system helpers
# ------------------------------------------------------------------------------------------
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """Walk ``dir_path`` and collect every file that is not ignored.

    ``ignores`` holds ``fnmatch`` patterns that are matched against bare file
    and directory names. A matching directory is skipped together with
    everything below it.

    Returns:
        A list of ``(path, relative_path)`` tuples. ``path`` is the file's
        path as produced by walking ``dir_path``; ``relative_path`` is
        relative to ``dir_path`` and, when ``add_base_to_relative`` is set,
        prefixed with the last component of ``dir_path``.
    """
    assert os.path.isdir(dir_path)
    patterns = [] if ignores is None else ignores
    base_name = os.path.basename(os.path.normpath(dir_path))

    def is_ignored(name: str) -> bool:
        return any(fnmatch.fnmatch(name, pattern) for pattern in patterns)

    collected = []
    for root, dirs, files in os.walk(dir_path, topdown=True):
        # Prune in place: os.walk only honours edits made to this very list.
        dirs[:] = [d for d in dirs if not is_ignored(d)]

        for file_name in files:
            if is_ignored(file_name):
                continue
            absolute = os.path.join(root, file_name)
            relative = os.path.relpath(absolute, dir_path)
            if add_base_to_relative:
                relative = os.path.join(base_name, relative)
            collected.append((absolute, relative))

    return collected
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Takes in a list of tuples of (src, dst) paths and copies files.

    Will create all necessary directories. A destination without a directory
    part (a bare file name) is copied into the current working directory.
    """
    for entry in files:
        src, dst = entry[0], entry[1]
        target_dir_name = os.path.dirname(dst)

        # dirname is '' for a bare file name; os.makedirs('') would raise.
        # exist_ok avoids the race between checking for and creating the dir.
        if target_dir_name:
            os.makedirs(target_dir_name, exist_ok=True)

        shutil.copyfile(src, dst)
# URL helpers
# ------------------------------------------------------------------------------------------
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
    """Determine whether the given object is a valid URL string.

    A value qualifies when it is a string containing ``://`` whose parsed
    form has a scheme and a network location containing a dot, both as given
    and after being joined with ``/``. ``file://`` URLs are accepted only
    when ``allow_file_urls`` is true.
    """
    if not isinstance(obj, str) or "://" not in obj:
        return False
    if allow_file_urls and obj.startswith('file://'):
        return True
    try:
        res = urllib.parse.urlparse(obj)
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
        res = urllib.parse.urlparse(urllib.parse.urljoin(obj, "/"))
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
    except ValueError:
        # urllib signals malformed URLs (e.g. a broken IPv6 literal) with
        # ValueError; everything else is a real error and must not be hidden.
        return False
    return True
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data.

    Plain paths and ``file://`` URLs are opened directly without any network
    access. Other URLs are fetched with ``requests`` and, when ``cache`` is
    true, stored in ``cache_dir`` under a name prefixed with the MD5 of the URL.

    Args:
        url: Local path, ``file://`` URL or remote URL.
        cache_dir: Directory for cached downloads; defaults to
            ``make_cache_dir_path('downloads')``.
        num_attempts: Maximum number of download attempts (at least 1).
        verbose: Print progress messages to stdout.
        return_filename: Return the local file name instead of an open file
            object. Requires ``cache`` for remote URLs.
        cache: Look up and store downloads in the cache directory.

    Returns:
        A binary file object, or a file name when ``return_filename`` is set.

    Raises:
        Exception: whatever the final download attempt raised.
    """
    assert num_attempts >= 1
    assert not (return_filename and (not cache))

    # Doesn't look like an URL scheme so interpret it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs. This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid. Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname()
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    assert is_url(url)

    # Lookup from cache.
    if cache_dir is None:
        cache_dir = make_cache_dir_path('downloads')

    # The cache key is derived from the URL as requested, before any redirect
    # performed below rewrites `url`.
    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            filename = cache_files[0]
            return filename if return_filename else open(filename, "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    # Small responses may be Google Drive interstitial pages
                    # rather than the requested file.
                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                # Retry against the confirmed download link.
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except Exception:
                # Only ordinary errors are retried; KeyboardInterrupt and
                # SystemExit propagate immediately.
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache.
    if cache:
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        # The name falls back to the whole URL; cap it so the cache file name
        # stays within common file-system limits. Lookups match on the MD5
        # prefix only, so previously cached files are still found.
        safe_name = safe_name[:128]
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file) # atomic
        if return_filename:
            return cache_file

    # Return data as file object.
    assert not return_filename
    return io.BytesIO(url_data)
================================================
FILE: environment.yml
================================================
name: stylegan3
channels:
- pytorch
- nvidia
dependencies:
- python >= 3.8
- pip
- numpy>=1.25
- click>=8.0
- pillow=9.4.0
- scipy=1.11.1
- pytorch>=2.0.1
- torchvision>=0.15.2
- cudatoolkit=11.1
- requests=2.26.0
- tqdm=4.62.2
- ninja=1.10.2
- matplotlib=3.4.2
- imageio=2.9.0
- pip:
- imgui==2.0.0
- glfw==2.6.1
- gradio==3.35.2
- pyopengl==3.1.5
- imageio-ffmpeg==0.4.3
# pyspng is currently broken on MacOS (see https://github.com/nurpax/pyspng/pull/6 for instance)
- pyspng-seunglab
================================================
FILE: gen_images.py
================================================
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Generate images using pretrained network pickle."""
import os
import re
from typing import List, Optional, Tuple, Union
import click
import dnnlib
import numpy as np
import PIL.Image
import torch
import legacy
#----------------------------------------------------------------------------
def parse_range(s: Union[str, List]) -> List[int]:
    '''Parse a comma separated list of numbers or ranges and return a list of ints.

    Ranges are written 'lo-hi' and are inclusive at both ends. Whitespace
    around each comma-separated element is ignored. A value that is already
    a list is returned unchanged.

    Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]

    Raises:
        ValueError: if an element is neither an integer nor a 'lo-hi' range.
    '''
    if isinstance(s, list):
        return s
    ranges = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for p in s.split(','):
        # Strip first so that ' 5-10' is still recognised as a range.
        p = p.strip()
        m = range_re.match(p)
        if m:
            ranges.extend(range(int(m.group(1)), int(m.group(2)) + 1))
        else:
            ranges.append(int(p))
    return ranges
#----------------------------------------------------------------------------
def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
    '''Convert a string of the form 'a,b' into a pair of floats.

    A value that is already a tuple is handed back untouched.

    Example:
        '0,1' returns (0,1)

    Raises:
        ValueError: if the string does not contain exactly two
            comma-separated components, or a component is not a float.
    '''
    if isinstance(s, tuple):
        return s
    components = s.split(',')
    if len(components) != 2:
        raise ValueError(f'cannot parse 2-vector {s}')
    first, second = components
    return (float(first), float(second))
#----------------------------------------------------------------------------
def make_transform(translate: Tuple[float,float], angle: float):
    """Build a 3x3 homogeneous 2D transform from a translation and an angle.

    Args:
        translate: (x, y) offsets placed in the last column of the matrix.
        angle: rotation angle in degrees.

    Returns:
        A 3x3 float numpy array [[c, s, tx], [-s, c, ty], [0, 0, 1]].
    """
    radians = angle/360.0*np.pi*2
    sin_a = np.sin(radians)
    cos_a = np.cos(radians)
    matrix = np.eye(3)
    # The bottom row keeps the identity values (0, 0, 1).
    matrix[0] = (cos_a, sin_a, translate[0])
    matrix[1] = (-sin_a, cos_a, translate[1])
    return matrix
#----------------------------------------------------------------------------
@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
@click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2')
@click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE')
@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
def generate_images(
    network_pkl: str,
    seeds: List[int],
    truncation_psi: float,
    noise_mode: str,
    outdir: str,
    translate: Tuple[float,float],
    rotate: float,
    class_idx: Optional[int]
):
    """Generate images using pretrained network pickle.
    Examples:
    \b
    # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left).
    python gen_images.py --outdir=out --trunc=1 --seeds=2 \\
    --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
    \b
    # Generate uncurated images with truncation using the MetFaces-U dataset
    python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\
    --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl
    """
    # NOTE: the docstring above doubles as the click --help text, so it is left as-is.
    print('Loading networks from "%s"...' % network_pkl)
    # Device preference order: CUDA, then Apple MPS, then CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
    # float32 on MPS, float64 everywhere else (presumably because MPS lacks float64 -- TODO confirm).
    dtype = torch.float32 if device.type == 'mps' else torch.float64
    with dnnlib.util.open_url(network_pkl) as f:
        # Only the 'G_ema' entry of the loaded pickle is used.
        G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype) # type: ignore
        # import pickle
        # G = legacy.load_network_pkl(f)
        # output = open('checkpoints/stylegan2-car-config-f-pt.pkl', 'wb')
        # pickle.dump(G, output)

    os.makedirs(outdir, exist_ok=True)

    # Labels.
    # One-hot class label of width G.c_dim; stays all-zero (and empty) for unconditional networks.
    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
        if class_idx is None:
            raise click.ClickException('Must specify class label with --class when using a conditional network')
        label[:, class_idx] = 1
    else:
        if class_idx is not None:
            print ('warn: --class=lbl ignored when running on an unconditional network')

    # Generate images.
    for seed_idx, seed in enumerate(seeds):
        # NOTE(review): seed_idx is zero-based, so the progress reads "(0/N)" for the first seed.
        print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
        # Latent code is drawn from a RandomState seeded per image, so output is reproducible per seed.
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device, dtype=dtype)

        # Construct an inverse rotation/translation matrix and pass to the generator. The
        # generator expects this matrix as an inverse to avoid potentially failing numerical
        # operations in the network.
        if hasattr(G.synthesis, 'input'):
            m = make_transform(translate, rotate)
            m = np.linalg.inv(m)
            G.synthesis.input.transform.copy_(torch.from_numpy(m))

        img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
        # Reorder dim 1 to the end (presumably NCHW -> NHWC) and map to 8-bit via x * 127.5 + 128, clamped to [0, 255].
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        # Only the first image of the batch is written, named after its seed.
        PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
#----------------------------------------------------------------------------
# Script entry point: click parses the command line and supplies the arguments.
if __name__ == "__main__":
    generate_images() # pylint: disable=no-value-for-parameter
#----------------------------------------------------------------------------
================================================
FILE: gradio_utils/__init__.py
================================================
from .utils import (ImageMask, draw_mask_on_image, draw_points_on_image,
get_latest_points_pair, get_valid_mask,
on_change_single_global_state)
# Public API of the package: exactly the names imported from .utils above.
__all__ = [
    'draw_mask_on_image', 'draw_points_on_image',
    'on_change_single_global_state', 'get_latest_points_pair',
    'get_valid_mask', 'ImageMask'
]
================================================
FILE: gradio_utils/utils.py
================================================
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
class ImageMask(gr.components.Image):
    """
    Image component preset for mask sketching.

    Sets: source="upload", tool="sketch", interactive=False
    """

    is_template = True

    def __init__(self, **kwargs):
        super().__init__(source="upload",
                         tool="sketch",
                         interactive=False,
                         **kwargs)

    def preprocess(self, x):
        """Wrap a bare image payload into the {'image', 'mask'} dict form.

        When the payload is not already a dict, a mask of all ones with a
        fully opaque last channel is synthesised at the decoded image's size,
        so the parent `preprocess` always receives an image/mask pair.
        `None` is passed through unchanged.
        """
        if x is None:
            return x
        if (self.tool == "sketch" and self.source in ["upload", "webcam"]
                and not isinstance(x, dict)):
            decode_image = gr.processing_utils.decode_base64_to_image(x)
            width, height = decode_image.size
            mask = np.ones((height, width, 4), dtype=np.uint8)
            mask[..., -1] = 255
            # Run the array through postprocess so it has the same form as `x`
            # (presumably the encoded payload format -- verify against gradio).
            mask = self.postprocess(mask)
            x = {'image': x, 'mask': mask}
        return super().preprocess(x)
def get_valid_mask(mask: np.ndarray):
    """Reduce a gr.Image mask (0 to 255, RGBA) to a single-channel mask.

    A 3-D array is converted to 8-bit grayscale first. The values are then
    divided by 255 only when the maximum is exactly 255; otherwise the array
    is returned as it is.
    """
    if mask.ndim == 3:
        grayscale = Image.fromarray(mask).convert('L')
        mask = np.array(grayscale)
    return mask / 255 if mask.max() == 255 else mask
def draw_points_on_image(image,
                         points,
                         curr_point=None,
                         highlight_all=True,
                         radius_scale=0.01):
    """Overlay start/target point pairs on `image` and return an RGB copy.

    For every entry of `points` a yellow line joins start and target (when
    both exist), the start is drawn as a red disc and the target as a blue
    disc. The entry whose key equals `curr_point` additionally gets the
    labels "p" and "t". Entries that are neither current nor covered by
    `highlight_all` are drawn with a low alpha (35).
    """
    overlay = Image.new("RGBA", image.size, 0)
    painter = ImageDraw.Draw(overlay)

    def _disc(center, radius, color, label):
        # Filled circle around `center`, plus an optional one-letter label.
        cx, cy = center
        painter.ellipse(
            (cx - radius, cy - radius, cx + radius, cy + radius),
            fill=color,
        )
        if label is not None:
            painter.text(center, label, align="center", fill=(0, 0, 0))

    for key, entry in points.items():
        is_current = curr_point is not None and curr_point == key
        if is_current or highlight_all:
            start_color, target_color = (255, 0, 0), (0, 0, 255)
        else:
            start_color, target_color = (255, 0, 0, 35), (0, 0, 255, 35)

        # The disc radius scales with the image width.
        radius = int(image.size[0] * radius_scale)

        # A temporary start position, when present, overrides the stored one.
        start = entry.get("start_temp", entry["start"])
        target = entry["target"]
        has_start = start is not None
        has_target = target is not None

        start_xy = (int(start[0]), int(start[1])) if has_start else None
        target_xy = (int(target[0]), int(target[1])) if has_target else None

        # Draw order matters: line first, then the two discs on top of it.
        if has_start and has_target:
            painter.line(
                (start_xy[0], start_xy[1], target_xy[0], target_xy[1]),
                fill=(255, 255, 0),
                width=2,
            )
        if has_start:
            _disc(start_xy, radius, start_color, "p" if is_current else None)
        if has_target:
            _disc(target_xy, radius, target_color, "t" if is_current else None)

    base = image.convert("RGBA")
    return Image.alpha_composite(base, overlay).convert("RGB")
def draw_mask_on_image(image, mask):
    """Alpha-composite `mask` over `image` and return an RGB image.

    The mask is scaled by 255, replicated into three colour channels and
    given a constant alpha of 45 before being composited.
    """
    gray = np.uint8(mask * 255)
    rows, cols = gray.shape[0], gray.shape[1]
    color_planes = np.tile(gray[..., None], [1, 1, 3])
    alpha_plane = np.full((rows, cols, 1), 45, dtype=np.uint8)
    rgba = np.concatenate((color_planes, alpha_plane), axis=-1)
    mask_layer = Image.fromarray(rgba).convert("RGBA")
    base = image.convert("RGBA")
    return Image.alpha_composite(base, mask_layer).convert("RGB")
def on_change_single_global_state(keys,
                                  value,
                                  global_state,
                                  map_transform=None):
    """Write `value` into the nested mapping `global_state` and return it.

    Args:
        keys: a single string key, or a sequence of keys describing the path
            to the entry to overwrite.
        value: the new value.
        global_state: the mapping to update in place.
        map_transform: optional callable applied to `value` before storing.
    """
    new_value = value if map_transform is None else map_transform(value)

    if isinstance(keys, str):
        container, leaf = global_state, keys
    else:
        # Walk down every key but the last to reach the parent container.
        container = global_state
        for step in keys[:-1]:
            container = container[step]
        leaf = keys[-1]

    container[leaf] = new_value
    return global_state
def get_latest_points_pair(points_dict):
    """Return the largest key of `points_dict`, or None if it is empty or None."""
    if points_dict:
        return max(points_dict.keys())
    return None
================================================
FILE: gui_utils/__init__.py
================================================
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
# empty
================================================
FILE: gui_utils/gl_utils.py
================================================
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import math
import os
import functools
import contextlib
import numpy as np
import OpenGL.GL as gl
import OpenGL.GL.ARB.texture_float
import dnnlib
#----------------------------------------------------------------------------
def init_egl():
    """Create an EGL OpenGL context bound to a 1x1 pbuffer and make it current.

    Every step is checked with an assert. Raises KeyError if the
    PYOPENGL_PLATFORM environment variable is not set at all.
    """
    assert os.environ['PYOPENGL_PLATFORM'] == 'egl' # Must be set before importing OpenGL.
    import OpenGL.EGL as egl
    import ctypes

    # Initialize EGL.
    display = egl.eglGetDisplay(egl.EGL_DEFAULT_DISPLAY)
    assert display != egl.EGL_NO_DISPLAY
    major = ctypes.c_int32()
    minor = ctypes.c_int32()
    ok = egl.eglInitialize(display, major, minor)
    assert ok
    # Require EGL version 1.4 or newer.
    assert major.value * 10 + minor.value >= 14

    # Choose config.
    config_attribs = [
        egl.EGL_RENDERABLE_TYPE, egl.EGL_OPENGL_BIT,
        egl.EGL_SURFACE_TYPE, egl.EGL_PBUFFER_BIT,
        egl.EGL_NONE
    ]
    # Request a single matching config.
    configs = (ctypes.c_int32 * 1)()
    num_configs = ctypes.c_int32()
    ok = egl.eglChooseConfig(display, config_attribs, configs, 1, num_configs)
    assert ok
    assert num_configs.value == 1
    config = configs[0]

    # Create dummy pbuffer surface.
    surface_attribs = [
        egl.EGL_WIDTH, 1,
        egl.EGL_HEIGHT, 1,
        egl.EGL_NONE
    ]
    surface = egl.eglCreatePbufferSurface(display, config, surface_attribs)
    assert surface != egl.EGL_NO_SURFACE

    # Setup GL context.
    ok = egl.eglBindAPI(egl.EGL_OPENGL_API)
    assert ok
    context = egl.eglCreateContext(display, config, egl.EGL_NO_CONTEXT, None)
    assert context != egl.EGL_NO_CONTEXT
    # The same pbuffer serves as both draw and read surface.
    ok = egl.eglMakeCurrent(display, surface, surface, context)
    assert ok
#----------------------------------------------------------------------------
# Lookup table keyed by (numpy dtype name, channel count). Each entry gives the
# GL component type, pixel format and internal format for that combination.
# Only uint8 and float32 with 1 to 4 channels are listed.
_texture_formats = {
    ('uint8',   1): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE,       internalformat=gl.GL_LUMINANCE8),
    ('uint8',   2): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE_ALPHA, internalformat=gl.GL_LUMINANCE8_ALPHA8),
    ('uint8',   3): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGB,             internalformat=gl.GL_RGB8),
    ('uint8',   4): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGBA,            internalformat=gl.GL_RGBA8),
    ('float32', 1): dnnlib.EasyDict(type=gl.GL_FLOAT,         format=gl.GL_LUMINANCE,       internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE32F_ARB),
    ('float32', 2): dnnlib.EasyDict(type=gl.GL_FLOAT,         format=gl.GL_LUMINANCE_ALPHA, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE_ALPHA32F_ARB),
    ('float32', 3): dnnlib.EasyDict(type=gl.GL_FLOAT,         format=gl.GL_RGB,             internalformat=gl.GL_RGB32F),
    ('float32', 4): dnnlib.EasyDict(type=gl.GL_FLOAT,         format=gl.GL_RGBA,            internalformat=gl.GL_RGBA32F),
}
def get_texture_format(dtype, channels):
    """Look up the GL format entry for a dtype / channel-count pair.

    Raises KeyError when the combination is not in `_texture_formats`.
    """
    key = (np.dtype(dtype).name, int(channels))
    return _texture_formats[key]
#----------------------------------------------------------------------------
def prepare_texture_data(image):
    """Coerce `image` into an array suitable for texture upload.

    The input is converted to a numpy array, a 2-D array gets a trailing
    channel axis of length one, and float64 data is cast to float32.
    """
    pixels = np.asarray(image)
    if pixels.ndim == 2:
        pixels = pixels[:, :, np.newaxis]
    return pixels.astype('float32') if pixels.dtype.name == 'float64' else pixels
#----------------------------------------------------------------------------
def draw_pixels(image, *, pos=0, zoom=1, align=0, rint=True):
    """Draw `image` with glDrawPixels at the given position.

    Args:
        image: array-like pixel data; normalised by prepare_texture_data.
        pos: scalar or 2-vector raster position.
        zoom: scalar or 2-vector pixel zoom factor.
        align: scalar or 2-vector; the position is shifted back by
            `size * align`, where size is the zoomed image size.
        rint: round the final position to the nearest integer.
    """
    # Scalars are broadcast to (x, y) pairs.
    pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
    zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2])
    align = np.broadcast_to(np.asarray(align, dtype='float32'), [2])
    image = prepare_texture_data(image)
    height, width, channels = image.shape
    size = zoom * [width, height]
    pos = pos - size * align
    if rint:
        pos = np.rint(pos)
    fmt = get_texture_format(image.dtype, channels)

    # Save the GL state touched below so it can be restored afterwards.
    gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_PIXEL_MODE_BIT)
    gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
    gl.glRasterPos2f(pos[0], pos[1])
    # The Y zoom is negated, i.e. rows are drawn in the opposite vertical direction.
    gl.glPixelZoom(zoom[0], -zoom[1])
    # Rows are tightly packed (1-byte alignment).
    gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
    gl.glDrawPixels(width, height, fmt.format, fmt.type, image)
    gl.glPopClientAttrib()
    gl.glPopAttrib()
#----------------------------------------------------------------------------
def read_pixels(width, height, *, pos=0, dtype='uint8', channels=3):
    """Read a `width` x `height` block of pixels with glReadPixels.

    Args:
        width, height: size of the block to read.
        pos: scalar or 2-vector start position; rounded to integers.
        dtype: numpy dtype of the returned array.
        channels: number of channels per pixel.

    Returns:
        An array of shape [height, width, channels], flipped vertically
        relative to the order in which glReadPixels filled it.
    """
    pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
    dtype = np.dtype(dtype)
    fmt = get_texture_format(dtype, channels)
    image = np.empty([height, width, channels], dtype=dtype)

    gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
    # Rows are tightly packed (1-byte alignment).
    gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
    gl.glReadPixels(int(np.round(pos[0])), int(np.round(pos[1])), width, height, fmt.format, fmt.type, image)
    gl.glPopClientAttrib()
    return np.flipud(image)
#----------------------------------------------------------------------------
class Texture:
    """Wrapper around a GL_TEXTURE_2D object.

    The texture is created either from an `image` (size, channel count and
    dtype are taken from it) or from explicit `width`/`height` with optional
    `channels` (default 3) and `dtype` (default uint8).
    """

    def __init__(self, *, image=None, width=None, height=None, channels=None, dtype=None, bilinear=True, mipmap=True):
        self.gl_id = None
        self.bilinear = bilinear
        self.mipmap = mipmap

        # Determine size and dtype.
        if image is not None:
            image = prepare_texture_data(image)
            self.height, self.width, self.channels = image.shape
            self.dtype = image.dtype
        else:
            assert width is not None and height is not None
            self.width = width
            self.height = height
            self.channels = channels if channels is not None else 3
            # Always store an np.dtype instance, matching the image branch above.
            self.dtype = np.dtype(dtype) if dtype is not None else np.dtype(np.uint8)

        # Validate size and dtype.
        assert isinstance(self.width, int) and self.width >= 0
        assert isinstance(self.height, int) and self.height >= 0
        assert isinstance(self.channels, int) and self.channels >= 1
        # Any explicitly passed argument must agree with what was derived.
        assert self.is_compatible(width=width, height=height, channels=channels, dtype=dtype)

        # Create texture object.
        self.gl_id = gl.glGenTextures(1)
        with self.bind():
            gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
            gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
            gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR if self.bilinear else gl.GL_NEAREST)
            gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR_MIPMAP_LINEAR if self.mipmap else gl.GL_NEAREST)
        self.update(image)

    def delete(self):
        """Release the GL texture; safe to call more than once."""
        if self.gl_id is not None:
            gl.glDeleteTextures([self.gl_id])
            self.gl_id = None

    def __del__(self):
        # Best-effort cleanup: failures during finalisation are ignored.
        try:
            self.delete()
        except Exception:
            pass

    @contextlib.contextmanager
    def bind(self):
        """Bind this texture for the duration of the `with` block.

        The previously bound texture is restored on exit, including when
        the block raises.
        """
        prev_id = gl.glGetInteger(gl.GL_TEXTURE_BINDING_2D)
        gl.glBindTexture(gl.GL_TEXTURE_2D, self.gl_id)
        try:
            yield
        finally:
            gl.glBindTexture(gl.GL_TEXTURE_2D, prev_id)

    def update(self, image):
        """Upload `image` into the texture.

        `image` must match the texture's size, channels and dtype. Passing
        None calls glTexImage2D without pixel data.
        """
        if image is not None:
            image = prepare_texture_data(image)
            assert self.is_compatible(image=image)
        with self.bind():
            fmt = get_texture_format(self.dtype, self.channels)
            gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
            gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
            gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, fmt.internalformat, self.width, self.height, 0, fmt.format, fmt.type, image)
            if self.mipmap:
                gl.glGenerateMipmap(gl.GL_TEXTURE_2D)
            gl.glPopClientAttrib()

    def draw(self, *, pos=0, zoom=1, align=0, rint=False, color=1, alpha=1, rounding=0):
        """Draw the texture as a rectangle of size `zoom * (width, height)`."""
        zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2])
        size = zoom * [self.width, self.height]
        with self.bind():
            gl.glPushAttrib(gl.GL_ENABLE_BIT)
            gl.glEnable(gl.GL_TEXTURE_2D)
            draw_rect(pos=pos, size=size, align=align, rint=rint, color=color, alpha=alpha, rounding=rounding)
            gl.glPopAttrib()

    def is_compatible(self, *, image=None, width=None, height=None, channels=None, dtype=None): # pylint: disable=too-many-return-statements
        """Return True if every given property matches this texture.

        Arguments left as None are not checked. An `image` must be 3-D and
        match in width, height, channels and dtype.
        """
        if image is not None:
            if image.ndim != 3:
                return False
            ih, iw, ic = image.shape
            if not self.is_compatible(width=iw, height=ih, channels=ic, dtype=image.dtype):
                return False
        if width is not None and self.width != width:
            return False
        if height is not None and self.height != height:
            return False
        if channels is not None and self.channels != channels:
            return False
        if dtype is not None and self.dtype != dtype:
            return False
        return True
#----------------------------------------------------------------------------
class Framebuffer:
    """Wrapper around a GL framebuffer object with a depth/stencil buffer.

    The colour attachment is either an existing `Texture` (size, channels
    and dtype are taken from it; `msaa` must be 0) or a renderbuffer
    allocated from `width`/`height` with optional `channels` (default 4)
    and `dtype` (default float32).
    """

    def __init__(self, *, texture=None, width=None, height=None, channels=None, dtype=None, msaa=0):
        self.texture = texture
        self.gl_id = None
        self.gl_color = None
        self.gl_depth_stencil = None
        self.msaa = msaa

        # Determine size and dtype.
        if texture is not None:
            assert isinstance(self.texture, Texture)
            self.width = texture.width
            self.height = texture.height
            self.channels = texture.channels
            self.dtype = texture.dtype
        else:
            assert width is not None and height is not None
            self.width = width
            self.height = height
            self.channels = channels if channels is not None else 4
            # Always store an np.dtype instance rather than a scalar type.
            self.dtype = np.dtype(dtype) if dtype is not None else np.dtype(np.float32)

        # Validate size and dtype.
        assert isinstance(self.width, int) and self.width >= 0
        assert isinstance(self.height, int) and self.height >= 0
        assert isinstance(self.channels, int) and self.channels >= 1
        # Any explicitly passed argument must agree with what was derived.
        assert width is None or width == self.width
        assert height is None or height == self.height
        assert channels is None or channels == self.channels
        assert dtype is None or dtype == self.dtype

        # Create framebuffer object.
        self.gl_id = gl.glGenFramebuffers(1)
        with self.bind():

            # Setup color buffer.
            if self.texture is not None:
                assert self.msaa == 0
                gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_TEXTURE_2D, self.texture.gl_id, 0)
            else:
                fmt = get_texture_format(self.dtype, self.channels)
                self.gl_color = gl.glGenRenderbuffers(1)
                gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_color)
                gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, fmt.internalformat, self.width, self.height)
                gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_RENDERBUFFER, self.gl_color)

            # Setup depth/stencil buffer.
            self.gl_depth_stencil = gl.glGenRenderbuffers(1)
            gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_depth_stencil)
            gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, gl.GL_DEPTH24_STENCIL8, self.width, self.height)
            gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_DEPTH_STENCIL_ATTACHMENT, gl.GL_RENDERBUFFER, self.gl_depth_stencil)

    def delete(self):
        """Release the framebuffer and its renderbuffers; safe to call repeatedly."""
        if self.gl_id is not None:
            gl.glDeleteFramebuffers([self.gl_id])
            self.gl_id = None
        if self.gl_color is not None:
            gl.glDeleteRenderbuffers(1, [self.gl_color])
            self.gl_color = None
        if self.gl_depth_stencil is not None:
            gl.glDeleteRenderbuffers(1, [self.gl_depth_stencil])
            self.gl_depth_stencil = None

    def __del__(self):
        # Best-effort cleanup: failures during finalisation are ignored.
        try:
            self.delete()
        except Exception:
            pass

    @contextlib.contextmanager
    def bind(self):
        """Bind this framebuffer for the duration of the `with` block.

        The viewport is set to the framebuffer size on entry. The previous
        framebuffer and renderbuffer bindings are restored on exit,
        including when the block raises. The viewport is not restored.
        """
        prev_fbo = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING)
        prev_rbo = gl.glGetInteger(gl.GL_RENDERBUFFER_BINDING)
        gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.gl_id)
        if self.width is not None and self.height is not None:
            gl.glViewport(0, 0, self.width, self.height)
        try:
            yield
        finally:
            gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, prev_fbo)
            gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, prev_rbo)

    def blit(self, dst=None):
        """Copy the colour buffer into `dst`, or into framebuffer 0 if `dst` is None.

        The copy covers this framebuffer's full width and height and uses
        nearest filtering.
        """
        assert dst is None or isinstance(dst, Framebuffer)
        with self.bind():
            # The GL object name is stored in `gl_id`; Framebuffer has no `fbo` attribute.
            gl.glBindFramebuffer(gl.GL_DRAW_FRAMEBUFFER, 0 if dst is None else dst.gl_id)
            gl.glBlitFramebuffer(0, 0, self.width, self.height, 0, 0, self.width, self.height, gl.GL_COLOR_BUFFER_BIT, gl.GL_NEAREST)
#----------------------------------------------------------------------------
def draw_shape(vertices, *, mode=gl.GL_TRIANGLE_FAN, pos=0, size=1, color=1, alpha=1):
    """Draw a 2D vertex array, translated by `pos` and scaled by `size`.

    Args:
        vertices: array of shape [N, 2]; also used as texture coordinates.
        mode: GL primitive mode passed to glDrawArrays.
        pos: scalar or 2-vector translation.
        size: scalar or 2-vector scale.
        color: scalar or RGB triple.
        alpha: scalar opacity, clipped to [0, 1].
    """
    assert vertices.ndim == 2 and vertices.shape[1] == 2
    pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
    size = np.broadcast_to(np.asarray(size, dtype='float32'), [2])
    color = np.broadcast_to(np.asarray(color, dtype='float32'), [3])
    alpha = np.clip(np.broadcast_to(np.asarray(alpha, dtype='float32'), []), 0, 1)

    # Save client state, current attributes and the modelview matrix.
    gl.glPushClientAttrib(gl.GL_CLIENT_VERTEX_ARRAY_BIT)
    gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_TRANSFORM_BIT)
    gl.glMatrixMode(gl.GL_MODELVIEW)
    gl.glPushMatrix()

    gl.glEnableClientState(gl.GL_VERTEX_ARRAY)
    gl.glEnableClientState(gl.GL_TEXTURE_COORD_ARRAY)
    # The same array supplies both positions and texture coordinates.
    gl.glVertexPointer(2, gl.GL_FLOAT, 0, vertices)
    gl.glTexCoordPointer(2, gl.GL_FLOAT, 0, vertices)
    gl.glTranslate(pos[0], pos[1], 0)
    gl.glScale(size[0], size[1], 1)
    # RGB is multiplied by alpha before being passed to GL.
    gl.glColor4f(color[0] * alpha, color[1] * alpha, color[2] * alpha, alpha)
    gl.glDrawArrays(mode, 0, vertices.shape[0])

    # Restore state in reverse order.
    gl.glPopMatrix()
    gl.glPopAttrib()
    gl.glPopClientAttrib()
#----------------------------------------------------------------------------
def draw_arrow(x1, y1, x2, y2, l=10, width=1.0):
    """Draw a line arrow from (x1, y1) to (x2, y2).

    Args:
        x1, y1: tail position.
        x2, y2: tip position.
        l: head size; the two head strokes end at (length - 2*l, +/-l)
            in the arrow's own frame. Nothing is drawn if the arrow is
            shorter than `l`.
        width: line width passed to glLineWidth (not restored afterwards).
    """
    # Compute the length and angle of the arrow
    dx = x2 - x1
    dy = y2 - y1
    length = math.sqrt(dx**2 + dy**2)
    if length < l:
        return
    angle = math.atan2(dy, dx)

    # Save the current modelview matrix
    gl.glPushMatrix()

    # Translate and rotate the coordinate system
    # After this the arrow lies along the +X axis, starting at the origin.
    gl.glTranslatef(x1, y1, 0.0)
    gl.glRotatef(angle * 180.0 / math.pi, 0.0, 0.0, 1.0)

    # Set the line width
    gl.glLineWidth(width)
    # gl.glColor3f(0.75, 0.75, 0.75)

    # Begin drawing lines
    gl.glBegin(gl.GL_LINES)

    # Draw the shaft of the arrow
    gl.glVertex2f(0.0, 0.0)
    gl.glVertex2f(length, 0.0)

    # Draw the head of the arrow
    gl.glVertex2f(length, 0.0)
    gl.glVertex2f(length - 2 * l, l)
    gl.glVertex2f(length, 0.0)
    gl.glVertex2f(length - 2 * l, -l)

    # End drawing lines
    gl.glEnd()

    # Restore the modelview matrix
    gl.glPopMatrix()
#----------------------------------------------------------------------------
def draw_rect(*, pos=0, pos2=None, size=None, align=0, rint=False, color=1, alpha=1, rounding=0):
    """Draw a rectangle, optionally with rounded corners.

    The extent comes from `size`, or from `pos2 - pos`, or defaults to 1x1;
    `pos2` and `size` are mutually exclusive. The origin is shifted back by
    `size * align` and rounded to integers when `rint` is set. `rounding`
    is a corner radius, converted to a fraction of the size and capped at 0.5.
    """
    assert pos2 is None or size is None

    def _vec2(value):
        # Broadcast a scalar or pair to a float32 2-vector.
        return np.broadcast_to(np.asarray(value, dtype='float32'), [2])

    origin = _vec2(pos)
    if size is not None:
        extent = _vec2(size)
    elif pos2 is not None:
        extent = _vec2(pos2) - origin
    else:
        extent = np.array([1, 1], dtype='float32')

    origin = origin - extent * align
    if rint:
        origin = np.rint(origin)

    corner = _vec2(rounding)
    corner = np.minimum(np.abs(corner) / np.maximum(np.abs(extent), 1e-8), 0.5)
    # If either axis has no rounding, disable rounding on both.
    if np.min(corner) == 0:
        corner = corner * 0

    outline = _setup_rect(float(corner[0]), float(corner[1]))
    draw_shape(outline, mode=gl.GL_TRIANGLE_FAN, pos=origin, size=extent, color=color, alpha=alpha)
@functools.lru_cache(maxsize=10000)
def _setup_rect(rx, ry):
t = np.linspace(0, np.pi / 2, 1 if max(rx, ry) == 0 else 64)
s = 1 - np.sin(t); c = 1 - np.cos(t)
x = [c * rx, 1 - s * rx, 1 - c * rx, s * rx]
y = [s * ry, c * ry, 1 - s * ry, 1 - c * ry]
v = np.stack([x, y], axis=-1).reshape(-1, 2)
return v.astype('float32')
#----------------------------------------------------------------------------
def draw_circle(*, center=0, radius=100, hole=0, color=1, alpha=1):
    """Draw a disc or ring centred at `center`.

    `hole` is the inner radius as a fraction of `radius`; it must be a
    scalar. The geometry is drawn as a triangle strip.
    """
    hole_fraction = float(np.broadcast_to(np.asarray(hole, dtype='float32'), []))
    ring = _setup_circle(hole_fraction)
    draw_shape(ring, mode=gl.GL_TRIANGLE_STRIP, pos=center, size=radius, color=color, alpha=alpha)
@functools.lru_cache(maxsize=10000)
def _setup_circle(hole):
t = np.linspace(0, np.pi * 2, 128)
s = np.sin(t); c = np.cos(t)
v = np.stack([c, s, c * hole, s * hole], axis=-1).reshape(-1, 2)
return v.astype('float32')
#----------------------------------------------------------------------------
================================================
FILE: gui_utils/glfw_window.py
================================================
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import time
import glfw
import OpenGL.GL as gl
from . import gl_utils
#----------------------------------------------------------------------------
class GlfwWindow: # pylint: disable=too-many-public-methods
    """GLFW window with per-frame setup, FPS limiting and frame capture.

    Subclasses are expected to override `draw_frame` and wrap their
    rendering code between `begin_frame()` and `end_frame()`.
    """

    def __init__(self, *, title='GlfwWindow', window_width=1920, window_height=1080, deferred_show=True, close_on_esc=True):
        """Create a hidden window and make its GL context current.

        Args:
            title: window title.
            window_width, window_height: requested outer size, clamped to
                the primary monitor's work area.
            deferred_show: if True the window is first shown at the end of
                the first frame that is not skipped, otherwise immediately.
            close_on_esc: make `should_close()` return True after ESC.
        """
        self._glfw_window = None
        self._drawing_frame = False
        self._frame_start_time = None
        self._frame_delta = 0
        self._fps_limit = None
        self._vsync = None
        self._skip_frames = 0
        self._deferred_show = deferred_show
        self._close_on_esc = close_on_esc
        self._esc_pressed = False
        self._drag_and_drop_paths = None
        self._capture_next_frame = False
        self._captured_frame = None

        # Create window.
        glfw.init()
        glfw.window_hint(glfw.VISIBLE, False)
        self._glfw_window = glfw.create_window(width=window_width, height=window_height, title=title, monitor=None, share=None)
        self._attach_glfw_callbacks()
        self.make_context_current()

        # Adjust window.
        self.set_vsync(False)
        self.set_window_size(window_width, window_height)
        if not self._deferred_show:
            glfw.show_window(self._glfw_window)

    def close(self):
        """Finish any frame in progress and destroy the window."""
        if self._drawing_frame:
            self.end_frame()
        if self._glfw_window is not None:
            glfw.destroy_window(self._glfw_window)
            self._glfw_window = None
        #glfw.terminate() # Commented out to play it nice with other glfw clients.

    def __del__(self):
        # Best-effort cleanup: failures during finalisation are ignored.
        try:
            self.close()
        except Exception:
            pass

    @property
    def window_width(self):
        return self.content_width

    @property
    def window_height(self):
        # Outer height = content area plus the title bar.
        return self.content_height + self.title_bar_height

    @property
    def content_width(self):
        width, _height = glfw.get_window_size(self._glfw_window)
        return width

    @property
    def content_height(self):
        _width, height = glfw.get_window_size(self._glfw_window)
        return height

    @property
    def title_bar_height(self):
        _left, top, _right, _bottom = glfw.get_window_frame_size(self._glfw_window)
        return top

    @property
    def monitor_width(self):
        _, _, width, _height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
        return width

    @property
    def monitor_height(self):
        _, _, _width, height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
        return height

    @property
    def frame_delta(self):
        """Seconds between the starts of the two most recent frames."""
        return self._frame_delta

    def set_title(self, title):
        glfw.set_window_title(self._glfw_window, title)

    def set_window_size(self, width, height):
        """Set the outer window size, clamped to the monitor work area.

        The window is maximized when the clamped size equals the work area.
        """
        width = min(width, self.monitor_width)
        height = min(height, self.monitor_height)
        glfw.set_window_size(self._glfw_window, width, max(height - self.title_bar_height, 0))
        if width == self.monitor_width and height == self.monitor_height:
            self.maximize()

    def set_content_size(self, width, height):
        self.set_window_size(width, height + self.title_bar_height)

    def maximize(self):
        glfw.maximize_window(self._glfw_window)

    def set_position(self, x, y):
        # The title bar height is added so that (x, y) refers to the outer frame.
        glfw.set_window_pos(self._glfw_window, x, y + self.title_bar_height)

    def center(self):
        self.set_position((self.monitor_width - self.window_width) // 2, (self.monitor_height - self.window_height) // 2)

    def set_vsync(self, vsync):
        """Enable or disable vsync; the swap interval is only touched on change."""
        vsync = bool(vsync)
        if vsync != self._vsync:
            glfw.swap_interval(1 if vsync else 0)
            self._vsync = vsync

    def set_fps_limit(self, fps_limit):
        """Set the frame rate cap in frames per second; None removes the cap."""
        self._fps_limit = None if fps_limit is None else int(fps_limit)

    def should_close(self):
        return glfw.window_should_close(self._glfw_window) or (self._close_on_esc and self._esc_pressed)

    def skip_frame(self):
        self.skip_frames(1)

    def skip_frames(self, num): # Do not update window for the next N frames.
        # Never lowers a larger pending skip count.
        self._skip_frames = max(self._skip_frames, int(num))

    def is_skipping_frames(self):
        return self._skip_frames > 0

    def capture_next_frame(self):
        """Request that the next presented frame is read back into memory."""
        self._capture_next_frame = True

    def pop_captured_frame(self):
        """Return the captured frame (or None) and clear it."""
        frame = self._captured_frame
        self._captured_frame = None
        return frame

    def pop_drag_and_drop_paths(self):
        """Return the most recently dropped paths (or None) and clear them."""
        paths = self._drag_and_drop_paths
        self._drag_and_drop_paths = None
        return paths

    def draw_frame(self): # To be overridden by subclass.
        self.begin_frame()
        # Rendering code goes here.
        self.end_frame()

    def make_context_current(self):
        if self._glfw_window is not None:
            glfw.make_context_current(self._glfw_window)

    def begin_frame(self):
        """Start a new frame: throttle, poll events and reset GL state."""
        # End previous frame.
        if self._drawing_frame:
            self.end_frame()

        # Apply FPS limit.
        if self._frame_start_time is not None and self._fps_limit is not None:
            delay = self._frame_start_time - time.perf_counter() + 1 / self._fps_limit
            if delay > 0:
                time.sleep(delay)
        cur_time = time.perf_counter()
        if self._frame_start_time is not None:
            self._frame_delta = cur_time - self._frame_start_time
        self._frame_start_time = cur_time

        # Process events.
        glfw.poll_events()

        # Begin frame.
        self._drawing_frame = True
        self.make_context_current()

        # Initialize GL state.
        # The projection maps pixel coordinates to clip space with the
        # origin at the top-left corner and Y pointing down.
        gl.glViewport(0, 0, self.content_width, self.content_height)
        gl.glMatrixMode(gl.GL_PROJECTION)
        gl.glLoadIdentity()
        gl.glTranslate(-1, 1, 0)
        gl.glScale(2 / max(self.content_width, 1), -2 / max(self.content_height, 1), 1)
        gl.glMatrixMode(gl.GL_MODELVIEW)
        gl.glLoadIdentity()
        gl.glEnable(gl.GL_BLEND)
        gl.glBlendFunc(gl.GL_ONE, gl.GL_ONE_MINUS_SRC_ALPHA) # Pre-multiplied alpha.

        # Clear.
        gl.glClearColor(0, 0, 0, 1)
        gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)

    def end_frame(self):
        """Finish the current frame and present it unless it is being skipped."""
        assert self._drawing_frame
        self._drawing_frame = False

        # Skip frames if requested.
        if self._skip_frames > 0:
            self._skip_frames -= 1
            return

        # Capture frame if requested.
        if self._capture_next_frame:
            self._captured_frame = gl_utils.read_pixels(self.content_width, self.content_height)
            self._capture_next_frame = False

        # Update window.
        if self._deferred_show:
            glfw.show_window(self._glfw_window)
            self._deferred_show = False
        glfw.swap_buffers(self._glfw_window)

    def _attach_glfw_callbacks(self):
        glfw.set_key_callback(self._glfw_window, self._glfw_key_callback)
        glfw.set_drop_callback(self._glfw_window, self._glfw_drop_callback)

    def _glfw_key_callback(self, _window, key, _scancode, action, _mods):
        # Only records that ESC was pressed; should_close() acts on it.
        if action == glfw.PRESS and key == glfw.KEY_ESCAPE:
            self._esc_pressed = True

    def _glfw_drop_callback(self, _window, paths):
        self._drag_and_drop_paths = paths
#----------------------------------------------------------------------------
================================================
FILE: gui_utils/imgui_utils.py
================================================
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import contextlib
import imgui
#----------------------------------------------------------------------------
def set_default_style(color_scheme='dark', spacing=9, indent=23, scrollbar=27):
    """Apply the default metrics and color scheme to the current ImGui style.

    Args:
        color_scheme: Suffix of the `imgui.style_colors_<name>` function to apply.
        spacing: Value used for window padding, item spacing, inner item
            spacing (both axes) and the minimum column spacing.
        indent: Indentation spacing.
        scrollbar: Scrollbar size.
    """
    style = imgui.get_style()

    # Metrics, borders and rounding, applied in a fixed order.
    settings = (
        ('window_padding', [spacing, spacing]),
        ('item_spacing', [spacing, spacing]),
        ('item_inner_spacing', [spacing, spacing]),
        ('columns_min_spacing', spacing),
        ('indent_spacing', indent),
        ('scrollbar_size', scrollbar),
        ('frame_padding', [4, 3]),
        ('window_border_size', 1),
        ('child_border_size', 1),
        ('popup_border_size', 1),
        ('frame_border_size', 1),
        ('window_rounding', 0),
        ('child_rounding', 0),
        ('popup_rounding', 3),
        ('frame_rounding', 3),
        ('scrollbar_rounding', 3),
        ('grab_rounding', 3),
    )
    for attr_name, attr_value in settings:
        setattr(style, attr_name, attr_value)

    # Load the requested built-in color scheme.
    apply_scheme = getattr(imgui, f'style_colors_{color_scheme}')
    apply_scheme(style)

    # Popup background: 70% menu-bar color + 30% frame color, forced opaque.
    menubar_bg = style.colors[imgui.COLOR_MENUBAR_BACKGROUND]
    frame_bg = style.colors[imgui.COLOR_FRAME_BACKGROUND]
    blended = [m * 0.7 + f * 0.3 for m, f in zip(menubar_bg, frame_bg)]
    style.colors[imgui.COLOR_POPUP_BACKGROUND] = blended[:3] + [1]
#----------------------------------------------------------------------------
@contextlib.contextmanager
def grayed_out(cond=True):
    """Context manager that draws the enclosed widgets with a disabled look.

    When `cond` is true, style-color overrides are pushed on entry (text uses
    the disabled-text color, check marks and slider grabs use the scrollbar
    grab color, frame/button/header/popup backgrounds use the menu-bar
    background) and popped again on exit. When `cond` is false this is a
    no-op.

    The overrides are popped even if the enclosed code raises, so an
    exception inside the `with` body does not leave the style stack
    unbalanced.
    """
    if not cond:
        yield
        return

    style = imgui.get_style()
    text = style.colors[imgui.COLOR_TEXT_DISABLED]
    grab = style.colors[imgui.COLOR_SCROLLBAR_GRAB]
    back = style.colors[imgui.COLOR_MENUBAR_BACKGROUND]
    overrides = (
        (imgui.COLOR_TEXT, text),
        (imgui.COLOR_CHECK_MARK, grab),
        (imgui.COLOR_SLIDER_GRAB, grab),
        (imgui.COLOR_SLIDER_GRAB_ACTIVE, grab),
        (imgui.COLOR_FRAME_BACKGROUND, back),
        (imgui.COLOR_FRAME_BACKGROUND_HOVERED, back),
        (imgui.COLOR_FRAME_BACKGROUND_ACTIVE, back),
        (imgui.COLOR_BUTTON, back),
        (imgui.COLOR_BUTTON_HOVERED, back),
        (imgui.COLOR_BUTTON_ACTIVE, back),
        (imgui.COLOR_HEADER, back),
        (imgui.COLOR_HEADER_HOVERED, back),
        (imgui.COLOR_HEADER_ACTIVE, back),
        (imgui.COLOR_POPUP_BACKGROUND, back),
    )
    for slot, color in overrides:
        imgui.push_style_color(slot, *color)
    try:
        yield
    finally:
        # Pop exactly as many colors as were pushed above.
        imgui.pop_style_color(len(overrides))
#----------------------------------------------------------------------------
@contextlib.contextmanager
def item_width(width=None):
    """Context manager that sets the ImGui item width for the enclosed widgets.

    With `width=None` nothing is pushed and the block runs unchanged.
    Otherwise the width is pushed on entry and popped on exit, including when
    the enclosed code raises, so the item-width stack stays balanced.
    """
    if width is None:
        yield
        return

    imgui.push_item_width(width)
    try:
        yield
    finally:
        imgui.pop_item_width()
#----------------------------------------------------------------------------
def scoped_by_object_id(method):
    """Decorate a method so it runs inside an ImGui ID scope unique to `self`.

    The wrapper pushes `str(id(self))` as an ImGui ID before calling the
    method and pops it afterwards, returning the method's result. The ID is
    popped even if the method raises, and the wrapper keeps the wrapped
    method's name and docstring.
    """
    import functools  # Local import: this module does not import functools at file level.

    @functools.wraps(method)
    def decorator(self, *args, **kwargs):
        imgui.push_id(str(id(self)))
        try:
            return method(self, *args, **kwargs)
        finally:
            imgui.pop_id()
    return decorator
#----------------------------------------------------------------------------
def button(label, width=0, enabled=True):
    """Draw a button that is grayed out and never reports a click when disabled.

    Returns the click result of `imgui.button` combined with `enabled`
    (`pressed and enabled`).
    """
    with grayed_out(not enabled):
        pressed = imgui.button(label, width=width)
    return pressed and enabled
#----------------------------------------------------------------------------
def collapsing_header(text, visible=None, flags=0, default=False, enabled=True, show=True):
    """Draw a collapsing header with optional disabled / hidden states.

    Args:
        text: Header label.
        visible: Passed through to `imgui.collapsing_header`; returned
            unchanged when `show` is false.
        flags: Base tree-node flags.
        default: Adds `TREE_NODE_DEFAULT_OPEN` when true.
        enabled: When false the header is grayed out, gets `TREE_NODE_LEAF`,
            and is never reported as expanded.
        show: When false nothing is drawn.

    Returns:
        `(expanded, visible)`.
    """
    if not show:
        return False, visible

    node_flags = flags
    if default:
        node_flags = node_flags | imgui.TREE_NODE_DEFAULT_OPEN
    if not enabled:
        node_flags = node_flags | imgui.TREE_NODE_LEAF

    with grayed_out(not enabled):
        is_open, visible = imgui.collapsing_header(text, visible=visible, flags=node_flags)
    return is_open and enabled, visible
#----------------------------------------------------------------------------
def popup_button(label, width=0, enabled=True):
    """Draw a button that opens the popup named `label` when clicked.

    Returns the result of `imgui.begin_popup(label)`.
    """
    was_clicked = button(label, width, enabled)
    if was_clicked:
        imgui.open_popup(label)
    return imgui.begin_popup(label)
#----------------------------------------------------------------------------
def input_text(label, value, buffer_length, flags, width=None, help_text=''):
    """Text input that shows dimmed `help_text` while the value is empty.

    Returns `(changed, value)`. Unless `INPUT_TEXT_ENTER_RETURNS_TRUE` is set
    in `flags`, `changed` is recomputed as "the returned value differs from
    the value passed in". Text equal to `help_text` is reported as ''.
    """
    previous = value
    is_empty = (value == '')

    # Halve the text alpha when the field only shows the placeholder.
    text_color = list(imgui.get_style().colors[imgui.COLOR_TEXT])
    if is_empty:
        text_color[-1] *= 0.5

    shown = help_text if is_empty else value
    with item_width(width):
        imgui.push_style_color(imgui.COLOR_TEXT, *text_color)
        changed, edited = imgui.input_text(label, shown, buffer_length, flags)
        if edited == help_text:
            edited = ''
        imgui.pop_style_color(1)

    if not flags & imgui.INPUT_TEXT_ENTER_RETURNS_TRUE:
        changed = (edited != previous)
    return changed, edited
#----------------------------------------------------------------------------
def drag_previous_control(enabled=True):
    """Report mouse-drag movement via ImGui's drag-drop source mechanism.

    Returns `(dragging, dx, dy)`. When a drag source is active and `enabled`
    is true, `dx, dy` is the current mouse drag delta, which is then reset;
    otherwise the result is `(False, 0, 0)`.
    """
    state = (False, 0, 0)
    if not imgui.begin_drag_drop_source(imgui.DRAG_DROP_SOURCE_NO_PREVIEW_TOOLTIP):
        return state

    if enabled:
        delta_x, delta_y = imgui.get_mouse_drag_delta()
        imgui.reset_mouse_drag_delta()
        state = (True, delta_x, delta_y)
    imgui.end_drag_drop_source()
    return state
#----------------------------------------------------------------------------
def drag_button(label, width=0, enabled=True):
    """Draw a button and report both click and drag state for it.

    Returns `(clicked, dragging, dx, dy)`.
    """
    clicked = button(label, width=width, enabled=enabled)
    dragging, delta_x, delta_y = drag_previous_control(enabled=enabled)
    return (clicked, dragging, delta_x, delta_y)
#----------------------------------------------------------------------------
def drag_hidden_window(label, x, y, width, height, enabled=True):
    """Place a transparent, fixed ImGui window and report drag state from it.

    The window is positioned at `(x, y)` with size `(width, height)`, has no
    title bar, and cannot be resized or moved. Returns the
    `(dragging, dx, dy)` result of `drag_previous_control`.
    """
    # Fully transparent background and border make the window invisible.
    invisible = (0, 0, 0, 0)
    for slot in (imgui.COLOR_WINDOW_BACKGROUND, imgui.COLOR_BORDER):
        imgui.push_style_color(slot, *invisible)

    imgui.set_next_window_position(x, y)
    imgui.set_next_window_size(width, height)
    window_flags = imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE
    imgui.begin(label, closable=False, flags=window_flags)
    drag_state = drag_previous_control(enabled=enabled)
    imgui.end()
    imgui.pop_style_color(2)

    dragging, delta_x, delta_y = drag_state
    return dragging, delta_x, delta_y
#----------------------------------------------------------------------------
def click_hidden_window(label, x, y, width, height, img_w, img_h, enabled=True):
    """Place a transparent ImGui window and map mouse presses to image pixels.

    The window is positioned at `(x, y)` with size `(width, height)`. While
    the mouse button is down inside that rectangle, the mouse position is
    mapped linearly onto an `img_w` x `img_h` pixel grid.

    Args:
        label: ImGui window name.
        x, y, width, height: Window rectangle in screen coordinates.
        img_w, img_h: Size of the image the rectangle represents.
        enabled: Accepted for signature parity with the other helpers;
            currently unused.

    Returns:
        `(clicked, down, img_x, img_y)`. `down` is true while the button is
        held inside the rectangle, `clicked` is true on the frame the press
        started, and `img_x, img_y` are the mapped pixel coordinates
        (`0, 0` when the mouse is not down inside the rectangle).
    """
    def to_pixel(offset, extent, size):
        # Map offset in [0, extent) to a pixel index in [0, size - 1].
        # A 1-pixel extent would divide by zero, so fall back to a divisor of 1.
        span = extent - 1 if extent > 1 else 1
        pixel = round(offset / span * (size - 1))
        # Fractional mouse positions can land past the last pixel; clamp them.
        return min(max(pixel, 0), max(size - 1, 0))

    imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0)
    imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0)
    imgui.set_next_window_position(x, y)
    imgui.set_next_window_size(width, height)
    imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE))

    clicked, down = False, False
    img_x, img_y = 0, 0
    if imgui.is_mouse_down():
        posx, posy = imgui.get_mouse_pos()
        if x <= posx < x + width and y <= posy < y + height:
            if imgui.is_mouse_clicked():
                clicked = True
            down = True
            img_x = to_pixel(posx - x, width, img_w)
            img_y = to_pixel(posy - y, height, img_h)

    imgui.end()
    imgui.pop_style_color(2)
    return clicked, down, img_x, img_y
#----------------------------------------------------------------------------
================================================
FILE: gui_utils/imgui_window.py
================================================
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import os
import imgui
import imgui.integrations.glfw
from . import glfw_window
from . import imgui_utils
from . import text_utils
#----------------------------------------------------------------------------
class ImguiWindow(glfw_window.GlfwWindow):
    """GLFW window that owns an ImGui context, renderer, and a set of font sizes.

    One ImGui font is loaded per requested size; `set_font_size()` selects
    among them and the choice is applied from the next `begin_frame()`.
    """

    def __init__(self, *, title='ImguiWindow', font=None, font_sizes=range(14,24), **glfw_kwargs):
        """Create the window, the ImGui context/renderer, and load the fonts.

        Args:
            title: Window title, forwarded to the base class.
            font: Font file passed to `add_font_from_file_ttf`; defaults to
                `text_utils.get_default_font()`.
            font_sizes: Iterable of sizes to load; values are converted to
                int and de-duplicated.
            **glfw_kwargs: Forwarded to the base-class constructor.
        """
        if font is None:
            font = text_utils.get_default_font()
        font_sizes = {int(size) for size in font_sizes}
        super().__init__(title=title, **glfw_kwargs)

        # Init fields.
        self._imgui_context = None
        self._imgui_renderer = None
        self._imgui_fonts = None
        self._cur_font_size = max(font_sizes)  # Start with the largest loaded size.

        # Delete leftover imgui.ini to avoid unexpected behavior.
        if os.path.isfile('imgui.ini'):
            os.remove('imgui.ini')

        # Init ImGui.
        self._imgui_context = imgui.create_context()
        self._imgui_renderer = _GlfwRenderer(self._glfw_window)
        # NOTE(review): presumably re-attached because the renderer installs
        # its own GLFW callbacks on construction -- confirm.
        self._attach_glfw_callbacks()
        imgui.get_io().ini_saving_rate = 0 # Disable creating imgui.ini at runtime.
        # NOTE(review): no drag_custom() is visible in imgui_utils; this
        # presumably refers to its drag_* helpers.
        imgui.get_io().mouse_drag_threshold = 0 # Improve behavior with imgui_utils.drag_custom().
        # Map each requested size to its loaded ImGui font.
        self._imgui_fonts = {size: imgui.get_io().fonts.add_font_from_file_ttf(font, size) for size in font_sizes}
        self._imgui_renderer.refresh_font_texture()

    def close(self):
        """Shut down the ImGui renderer, drop ImGui references, and close the window."""
        self.make_context_current()
        self._imgui_fonts = None
        if self._imgui_renderer is not None:
            self._imgui_renderer.shutdown()
            self._imgui_renderer = None
        if self._imgui_context is not None:
            #imgui.destroy_context(self._imgui_context) # Commented out to avoid creating imgui.ini at the end.
            self._imgui_context = None
        super().close()

    def _glfw_key_callback(self, *args):
        """Handle a key event in the base class, then forward it to the ImGui renderer."""
        super()._glfw_key_callback(*args)
        self._imgui_renderer.keyboard_callback(*args)

    @property
    def font_size(self):
        """Currently selected font size (one of the loaded sizes)."""
        return self._cur_font_size

    @property
    def spacing(self):
        """Style spacing derived from the font size: `round(font_size * 0.4)`."""
        return round(self._cur_font_size * 0.4)

    def set_font_size(self, target): # Applied on next frame.
        """Select the loaded font size closest to `target` (the smaller one on ties)."""
        self._cur_font_size = min((abs(key - target), key) for key in self._imgui_fonts.keys())[1]

    def begin_frame(self):
        """Begin the GLFW frame, feed input to ImGui, and start an ImGui frame."""
        # Begin glfw frame.
        super().begin_frame()

        # Process imgui events.
        # Wheel scrolling is scaled with the font size.
        self._imgui_renderer.mouse_wheel_multiplier = self._cur_font_size / 10
        # Inputs are only processed while the content area is non-empty.
        if self.content_width > 0 and self.content_height > 0:
            self._imgui_renderer.process_inputs()

        # Begin imgui frame.
        imgui.new_frame()
        imgui.push_font(self._imgui_fonts[self._cur_font_size])
        # Re-apply the style every frame so metrics follow the current font size.
        imgui_utils.set_default_style(spacing=self.spacing, indent=self.font_size, scrollbar=self.font_size+4)

    def end_frame(self):
        """Finish the ImGui frame, render its draw data, then end the GLFW frame."""
        imgui.pop_font()
        imgui.render()
        imgui.end_frame()
        self._imgui_renderer.render(imgui.get_draw_data())
        super().end_frame()
#----------------------------------------------------------------------------
# Wrapper class for GlfwRenderer to fix a mouse wheel bug on Linux.
class _GlfwRenderer(imgui.integrations.glfw.GlfwRenderer):
    """GlfwRenderer whose scroll handler scales vertical wheel input."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Factor applied to every vertical scroll offset; 1 leaves it unscaled.
        self.mouse_wheel_multiplier = 1

    def scroll_callback(self, window, x_offset, y_offset):
        """Accumulate the scaled vertical offset; `window` and `x_offset` are ignored."""
        scaled_offset = y_offset * self.mouse_wheel_multiplier
        self.io.mouse_wheel += scaled_offset
#----------------------------------------------------------------------------
================================================
FILE: gui_utils/text_utils.py
================================================
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import functools
from typing import Optional
import dnnlib
import numpy as np
import PIL.Image
import PIL.ImageFont
import scipy.ndimage
from . import gl_utils
#----------------------------------------------------------------------------
def get_default_font():
    """Return a local filename for the default font (Open Sans regular).

    The font URL is resolved through `dnnlib.util.open_url` with
    `return_filename=True`.
    """
    font_url = 'http://fonts.gstatic.com/s/opensans/v17/mem8YaGs126MiZpBA-U1UpcaXcl0Aw.ttf'
    return dnnlib.util.open_url(font_url, return_filename=True)
#----------------------------------------------------------------------------
@functools.lru_cache(maxsize=None)
def get_pil_font(font=None, size=32):
    """Return a PIL TrueType font for `font` at `size`, memoized per argument pair.

    `font=None` selects the file returned by `get_default_font()`.
    """
    font_file = get_default_font() if font is None else font
    return PIL.ImageFont.truetype(font=font_file, size=size)
#----------------------------------------------------------------------------
def get_array(string, *, dropshadow_radius: int=None, **kwargs):
    """Render `string` to an array, optionally with a drop shadow.

    When `dropshadow_radius` is given, the shadow offset in both axes is
    `ceil(dropshadow_radius * 2 / 3)`. All other keyword arguments are
    forwarded to `_get_array_priv`.
    """
    if dropshadow_radius is None:
        return _get_array_priv(string, **kwargs)

    shadow_offset = int(np.ceil(dropshadow_radius*2/3))
    return _get_array_priv(
        string,
        dropshadow_radius=dropshadow_radius,
        offset_x=shadow_offset,
        offset_y=shadow_offset,
        **kwargs,
    )
@functools.lru_cache(maxsize=10000)
def _get_array_priv(
    string: str, *,
    size: int = 32,
    max_width: Optional[int]=None,
    max_height: Optional[int]=None,
    min_size=10,
    shrink_coef=0.8,
    dropshadow_radius: int=None,
    offset_x: int=None,
    offset_y: int=None,
    **kwargs
):
    """Render `string`, shrinking the font until it fits the size limits.

    Rendering starts at `size`. While the result is wider than `max_width`
    or taller than `max_height` (a limit of `None` is not checked), the font
    size is multiplied by `shrink_coef` (truncated to int, never below
    `min_size`) and the text is rendered again.

    The loop stops as soon as the text fits, the size has reached `min_size`,
    or a shrink step would not make the font smaller (e.g.
    `shrink_coef >= 1`), which would otherwise loop forever.

    Returns:
        The array from the last render, of shape `(height, width, channels)`.
    """
    cur_size = size
    array = None
    while True:
        if dropshadow_radius is not None:
            # Separate implementation for drop-shadow text rendering.
            array = _get_array_impl_dropshadow(string, size=cur_size, radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs)
        else:
            array = _get_array_impl(string, size=cur_size, **kwargs)

        height, width, _ = array.shape
        fits_width = max_width is None or width <= max_width
        fits_height = max_height is None or height <= max_height
        if (fits_width and fits_height) or cur_size <= min_size:
            break

        next_size = max(int(cur_size * shrink_coef), min_size)
        if next_size >= cur_size:
            # No progress is possible; keep the current render.
            break
        cur_size = next_size
    return array
#----------------------------------------------------------------------------
@functools.lru_cache(maxsize=10000)
def _get_array_impl(string, *, font=None, size=32, outline=0, outline_pad=3, outline_coef=3, outline_exp=2, line_pad: int=None):
    """Rasterize (possibly multi-line) text into a 2-channel uint8 array.

    Channel 0 is the glyph mask; channel 1 is the alpha. Without an outline
    (`outline <= 0`) alpha equals the mask. With an outline, the mask is
    padded by `ceil(outline * outline_pad)` pixels and alpha is a
    Gaussian-blurred, boosted copy of it, never smaller than the mask.

    Lines are split on '\\n', right-padded to the widest line, and separated
    by `line_pad` rows (default `size // 2`).
    """
    pil_font = get_pil_font(font=font, size=size)

    # One uint8 (height, width) array per text line.
    rows = []
    for text_line in string.split('\n'):
        glyphs = pil_font.getmask(text_line, 'L')
        glyph_w, glyph_h = glyphs.size
        rows.append(np.array(glyphs, dtype=np.uint8).reshape([glyph_h, glyph_w]))

    # Pad every row to the common width; all rows but the last get a gap below.
    full_width = max(row.shape[1] for row in rows)
    gap = size // 2 if line_pad is None else line_pad
    last_index = len(rows) - 1
    padded = []
    for index, row in enumerate(rows):
        below = gap if index < last_index else 0
        padded.append(np.pad(row, ((0, below), (0, full_width - row.shape[1])), mode='constant'))
    mask = np.concatenate(padded, axis=0)

    if not outline > 0:
        return np.stack([mask, mask], axis=-1)

    border = int(np.ceil(outline * outline_pad))
    mask = np.pad(mask, border, mode='constant', constant_values=0)
    glow = mask.astype(np.float32) / 255
    glow = scipy.ndimage.gaussian_filter(glow, outline)
    glow = 1 - np.maximum(1 - glow * outline_coef, 0) ** outline_exp
    glow = (glow * 255 + 0.5).clip(0, 255).astype(np.uint8)
    alpha = np.maximum(glow, mask)
    return np.stack([mask, alpha], axis=-1)
#----------------------------------------------------------------------------
@functools.lru_cache(maxsize=10000)
def _get_array_impl_dropshadow(string, *, font=None, size=32, radius: int, offset_x: int, offset_y: int, line_pad: int=None, **kwargs):
    """Rasterize text into a 2-channel uint8 array with a drop shadow in alpha.

    Channel 0 is the glyph mask, padded by `2*radius + max(|offset_x|,
    |offset_y|)` pixels. Channel 1 is a Gaussian-blurred, boosted copy of the
    mask shifted right/down by `(offset_x, offset_y)`, never smaller than the
    mask itself. Both offsets must be positive. Extra `**kwargs` are accepted
    and ignored.
    """
    assert (offset_x > 0) and (offset_y > 0)
    pil_font = get_pil_font(font=font, size=size)

    # One uint8 (height, width) array per text line.
    rows = []
    for text_line in string.split('\n'):
        glyphs = pil_font.getmask(text_line, 'L')
        glyph_w, glyph_h = glyphs.size
        rows.append(np.array(glyphs, dtype=np.uint8).reshape([glyph_h, glyph_w]))

    # Pad every row to the common width; all rows but the last get a gap below.
    full_width = max(row.shape[1] for row in rows)
    gap = size // 2 if line_pad is None else line_pad
    last_index = len(rows) - 1
    padded = []
    for index, row in enumerate(rows):
        below = gap if index < last_index else 0
        padded.append(np.pad(row, ((0, below), (0, full_width - row.shape[1])), mode='constant'))
    mask = np.concatenate(padded, axis=0)

    border = 2*radius + max(abs(offset_x), abs(offset_y))
    mask = np.pad(mask, border, mode='constant', constant_values=0)
    shadow = mask.astype(np.float32) / 255
    shadow = scipy.ndimage.gaussian_filter(shadow, radius)
    shadow = 1 - np.maximum(1 - shadow * 1.5, 0) ** 1.4
    shadow = (shadow * 255 + 0.5).clip(0, 255).astype(np.uint8)

    # Shift the shadow: pad on the top/left, then crop the same amount off the bottom/right.
    shifted = np.pad(shadow, [(offset_y, 0), (offset_x, 0)], mode='constant')
    shadow = shifted[:-offset_y, :-offset_x]

    alpha = np.maximum(shadow, mask)
    return np.stack([mask, alpha], axis=-1)
#----------------------------------------------------------------------------
@functools.lru_cache(maxsize=10000)
def get_texture(string, bilinear=True, mipmap=True, **kwargs):
    """Return a `gl_utils.Texture` of the rendered `string`, memoized per argument set.

    Extra keyword arguments are forwarded to `get_array`.
    """
    pixels = get_array(string, **kwargs)
    return gl_utils.Texture(image=pixels, bilinear=bilinear, mipmap=mipmap)
#----------------------------------------------------------------------------
================================================
FILE: legacy.py
================================================
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Converting legacy network pickle into the new format."""
import click
import pickle
import re
import copy
import numpy as np
import torch
import dnnlib
from torch_utils import misc
#----------------------------------------------------------------------------
def load_network_pkl(f, force_fp16=False):
    """Load a network pickle from file object `f`, converting legacy formats.

    A pickle holding a 3-tuple of TensorFlow network stubs is converted to a
    dict with 'G', 'D' and 'G_ema' PyTorch modules. Missing
    'training_set_kwargs' / 'augment_pipe' entries are filled with None and
    the field types are checked with assertions.

    With `force_fp16=True`, each network whose init kwargs change after
    setting `num_fp16_res = 4` and `conv_clamp = 256` is rebuilt with the new
    kwargs and receives a copy of the old parameters and buffers.

    NOTE(security): this unpickles `f`; unpickling can execute arbitrary
    code, so only load pickles from trusted sources.
    """
    data = _LegacyUnpickler(f).load()

    # Legacy TensorFlow pickle => convert.
    is_tf_triple = (
        isinstance(data, tuple)
        and len(data) == 3
        and all(isinstance(net, _TFNetworkStub) for net in data)
    )
    if is_tf_triple:
        tf_G, tf_D, tf_Gs = data
        data = dict(
            G=convert_tf_generator(tf_G),
            D=convert_tf_discriminator(tf_D),
            G_ema=convert_tf_generator(tf_Gs),
        )

    # Add missing fields.
    for optional_field in ('training_set_kwargs', 'augment_pipe'):
        if optional_field not in data:
            data[optional_field] = None

    # Validate contents.
    for net_key in ('G', 'D', 'G_ema'):
        assert isinstance(data[net_key], torch.nn.Module)
    assert isinstance(data['training_set_kwargs'], (dict, type(None)))
    assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))

    if not force_fp16:
        return data

    # Force FP16.
    for net_key in ['G', 'D', 'G_ema']:
        original = data[net_key]
        new_kwargs = copy.deepcopy(original.init_kwargs)
        # FP16 settings live under 'synthesis_kwargs' when present, else at the top level.
        fp16_target = new_kwargs.get('synthesis_kwargs', new_kwargs)
        fp16_target.num_fp16_res = 4
        fp16_target.conv_clamp = 256
        if new_kwargs != original.init_kwargs:
            rebuilt = type(original)(**new_kwargs).eval().requires_grad_(False)
            misc.copy_params_and_buffers(original, rebuilt, require_all=True)
            data[net_key] = rebuilt
    return data
#----------------------------------------------------------------------------
class _TFNetworkStub(dnnlib.EasyDict):
    """Placeholder type that stands in for `dnnlib.tflib.network.Network` objects found in legacy pickles."""
class _LegacyUnpickler(pickle.Unpickler):
    """Unpickler that maps `dnnlib.tflib.network.Network` to `_TFNetworkStub`.

    Every other class reference is resolved by the standard unpickler.
    """

    def find_class(self, module, name):
        wants_tf_network = (module, name) == ('dnnlib.tflib.network', 'Network')
        if wants_tf_network:
            return _TFNetworkStub
        return super().find_class(module, name)
#----------------------------------------------------------------------------
def _collect_tf_params(tf_net):
    """Flatten a TF network stub's variables into one `{path: value}` dict.

    Variables of `tf_net` are stored under their own names; variables of
    nested components are stored under `'<component>/<name>'`, recursively.
    Networks are visited depth-first in component order.
    """
    # pylint: disable=protected-access
    collected = dict()
    pending = [('', tf_net)]
    while pending:
        prefix, net = pending.pop()
        for name, value in net.variables:
            collected[prefix + name] = value
        children = [(prefix + name + '/', comp) for name, comp in net.components.items()]
        # Reversed so the first component is popped (visited) first.
        pending.extend(reversed(children))
    return collected
#----------------------------------------------------------------------------
def _populate_module_params(module, *patterns):
    """Fill a module's parameters and buffers from pattern-matched value sources.

    `patterns` is a flat sequence of `(regex, value_fn)` pairs. For each
    parameter/buffer name, the first regex that fully matches decides the
    source: `value_fn` is called with the regex groups and its result is
    copied into the tensor; a `value_fn` of None leaves the tensor untouched.

    A name matched by no pattern fails an assertion. If the check or the
    copy fails, the name and tensor shape are printed before re-raising.
    """
    rules = list(zip(patterns[0::2], patterns[1::2]))
    for name, tensor in misc.named_params_and_buffers(module):
        matched = False
        value = None
        for pattern, value_fn in rules:
            match = re.fullmatch(pattern, name)
            if match is None:
                continue
            matched = True
            if value_fn is not None:
                value = value_fn(*match.groups())
            break
        try:
            assert matched
            if value is not None:
                tensor.copy_(torch.from_numpy(np.array(value)))
        except BaseException:
            # Report which tensor failed, then propagate unchanged.
            print(name, list(tensor.shape))
            raise
#----------------------------------------------------------------------------
def convert_tf_generator(tf_G):
    """Build a PyTorch `networks_stylegan2.Generator` from a TensorFlow generator stub.

    The stub's `static_kwargs` are translated into constructor kwargs, its
    variables are collected into a flat dict, and each PyTorch
    parameter/buffer is populated from the matching TensorFlow variable
    (transposed/flipped as the table below specifies).

    Args:
        tf_G: Unpickled TF network stub providing `version`, `static_kwargs`,
            `variables` and `components`.

    Returns:
        The populated generator, in eval mode with gradients disabled.

    Raises:
        ValueError: If `tf_G.version < 4` or `static_kwargs` contains a key
            this converter does not know.
    """
    if tf_G.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs.
    tf_kwargs = tf_G.static_kwargs
    known_kwargs = set()
    def kwarg(tf_name, default=None, none=None):
        # Look up a TF kwarg, remember that it is known, and map None to `none`.
        known_kwargs.add(tf_name)
        val = tf_kwargs.get(tf_name, default)
        return val if val is not None else none

    # Convert kwargs.
    from training import networks_stylegan2
    network_class = networks_stylegan2.Generator
    kwargs = dnnlib.EasyDict(
        z_dim = kwarg('latent_size', 512),
        c_dim = kwarg('label_size', 0),
        w_dim = kwarg('dlatent_size', 512),
        img_resolution = kwarg('resolution', 1024),
        img_channels = kwarg('num_channels', 3),
        channel_base = kwarg('fmap_base', 16384) * 2,
        channel_max = kwarg('fmap_max', 512),
        num_fp16_res = kwarg('num_fp16_res', 0),
        conv_clamp = kwarg('conv_clamp', None),
        architecture = kwarg('architecture', 'skip'),
        resample_filter = kwarg('resample_kernel', [1,3,3,1]),
        use_noise = kwarg('use_noise', True),
        activation = kwarg('nonlinearity', 'lrelu'),
        mapping_kwargs = dnnlib.EasyDict(
            num_layers = kwarg('mapping_layers', 8),
            embed_features = kwarg('label_fmaps', None),
            layer_features = kwarg('mapping_fmaps', None),
            activation = kwarg('mapping_nonlinearity', 'lrelu'),
            lr_multiplier = kwarg('mapping_lrmul', 0.01),
            w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
        ),
    )

    # Check for unknown kwargs.
    # These are recognized but intentionally not forwarded to the PyTorch network.
    kwarg('truncation_psi')
    kwarg('truncation_cutoff')
    kwarg('style_mixing_prob')
    kwarg('structure')
    kwarg('conditioning')
    kwarg('fused_modconv')
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])

    # Collect params.
    tf_params = _collect_tf_params(tf_G)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
        if match:
            # Per-LOD ToRGB layers are re-keyed by resolution and imply the 'orig' architecture.
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
            # `architecture` is a top-level key of `kwargs` (there is no nested
            # `synthesis` entry), so it must be overridden here directly.
            kwargs.architecture = 'orig'
    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    G = network_class(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    # pylint: disable=f-string-without-interpolation
    _populate_module_params(G,
        r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
        r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
        r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
        r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
        r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
        r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
        r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
        r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
        r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
        r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
        r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
        r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
        r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
        r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
        r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'.*\.resample_filter', None,
        r'.*\.act_filter', None,
    )
    return G
#----------------------------------------------------------------------------
def convert_tf_discriminator(tf_D):
    """Build a PyTorch `networks_stylegan2.Discriminator` from a TensorFlow discriminator stub.

    The stub's `static_kwargs` are translated into constructor kwargs, its
    variables are collected into a flat dict, and each PyTorch
    parameter/buffer is populated from the matching TensorFlow variable.

    Returns the populated discriminator, in eval mode with gradients disabled.

    Raises:
        ValueError: If `tf_D.version < 4` or `static_kwargs` contains a key
            this converter does not know.
    """
    if tf_D.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs.
    tf_kwargs = tf_D.static_kwargs
    known_kwargs = set()
    def kwarg(tf_name, default=None):
        # Look up a TF kwarg and remember that it is known.
        known_kwargs.add(tf_name)
        return tf_kwargs.get(tf_name, default)

    # Convert kwargs.
    kwargs = dnnlib.EasyDict(
        c_dim = kwarg('label_size', 0),
        img_resolution = kwarg('resolution', 1024),
        img_channels = kwarg('num_channels', 3),
        architecture = kwarg('architecture', 'resnet'),
        channel_base = kwarg('fmap_base', 16384) * 2,
        channel_max = kwarg('fmap_max', 512),
        num_fp16_res = kwarg('num_fp16_res', 0),
        conv_clamp = kwarg('conv_clamp', None),
        cmap_dim = kwarg('mapping_fmaps', None),
        block_kwargs = dnnlib.EasyDict(
            activation = kwarg('nonlinearity', 'lrelu'),
            resample_filter = kwarg('resample_kernel', [1,3,3,1]),
            freeze_layers = kwarg('freeze_layers', 0),
        ),
        mapping_kwargs = dnnlib.EasyDict(
            num_layers = kwarg('mapping_layers', 0),
            embed_features = kwarg('mapping_fmaps', None),
            layer_features = kwarg('mapping_fmaps', None),
            activation = kwarg('nonlinearity', 'lrelu'),
            lr_multiplier = kwarg('mapping_lrmul', 0.1),
        ),
        epilogue_kwargs = dnnlib.EasyDict(
            mbstd_group_size = kwarg('mbstd_group_size', None),
            mbstd_num_channels = kwarg('mbstd_num_features', 1),
            activation = kwarg('nonlinearity', 'lrelu'),
        ),
    )

    # Check for unknown kwargs.
    # These are recognized but not forwarded to the PyTorch network.
    kwarg('structure')
    kwarg('conditioning')
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])

    # Collect params.
    tf_params = _collect_tf_params(tf_D)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
        if match:
            # Per-LOD FromRGB layers are re-keyed by resolution and imply the 'orig' architecture.
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
            kwargs.architecture = 'orig'
    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    from training import networks_stylegan2
    D = networks_stylegan2.Discriminator(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    # pylint: disable=f-string-without-interpolation
    # Pairs of (PyTorch name regex, function producing the value from tf_params);
    # a None function means the tensor is left as constructed.
    _populate_module_params(D,
        r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
        r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
        r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
        r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
        r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
        r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
        r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
        r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
        r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
        r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
        r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
        r'.*\.resample_filter', None,
    )
    return D
#----------------------------------------------------------------------------
@click.command()
@click.option('--source', help='Input pickle', required=True, metavar='PATH')
@click.option('--dest', help='Output pickle', required=True, metavar='PATH')
@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
def convert_network_pickle(source, dest, force_fp16):
    """Convert legacy network pickle into the native PyTorch format.

    The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
    It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.

    Example:

    \b
    python legacy.py \\
        --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
        --dest=stylegan2-cat-config-f.pkl
    """
    # Load (and convert if needed) the source pickle; `source` is opened via dnnlib.util.open_url.
    print(f'Loading "{source}"...')
    with dnnlib.util.open_url(source) as f:
        data = load_network_pkl(f, force_fp16=force_fp16)
    # Write the resulting dict back out as a regular pickle.
    print(f'Saving "{dest}"...')
    with open(dest, 'wb') as f:
        pickle.dump(data, f)
    print('Done.')
#----------------------------------------------------------------------------
# Run the conversion CLI only when this file is executed as a script;
# click supplies the function's arguments from the command line.
if __name__ == "__main__":
    convert_network_pickle() # pylint: disable=no-value-for-parameter
#----------------------------------------------------------------------------
================================================
FILE: requirements.txt
================================================
torch>=2.0.0
scipy>=1.11.1
Ninja==1.10.2
gradio>=3.35.2
imageio-ffmpeg>=0.4.3
huggingface_hub
hf_transfer
pyopengl
imgui
glfw==2.6.1
pillow>=9.4.0
torchvision>=0.15.2
imageio>=2.9.0
================================================
FILE: stylegan_human/.gitignore
================================================
.DS_Store
__pycache__
*.pt
*.pth
*.pdparams
*.pdiparams
*.pdmodel
*.pkl
*.info
*.yaml
================================================
FILE: stylegan_human/PP_HumanSeg/deploy/infer.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import codecs
import os
import time
import yaml
import numpy as np
import cv2
import paddle
import paddleseg.transforms as T
from paddle.inference import create_predictor, PrecisionType
from paddle.inference import Config as PredictConfig
from paddleseg.core.infer import reverse_transform
from paddleseg.cvlibs import manager
from paddleseg.utils import TimeAverager
from ..scripts.optic_flow_process import optic_flow_process
class DeployConfig:
    """Deployment configuration loaded from a YAML file.

    The parsed YAML is kept in `self.dic`. Its 'Deploy' section supplies the
    transform specs plus the model and params file names, which are resolved
    relative to the directory of the YAML file.
    """

    def __init__(self, path):
        # NOTE(security): yaml.FullLoader is not meant for untrusted input;
        # only load deployment configs from trusted sources.
        with codecs.open(path, 'r', 'utf-8') as file:
            self.dic = yaml.load(file, Loader=yaml.FullLoader)

        self._transforms = self._load_transforms(self.dic['Deploy'][
            'transforms'])
        self._dir = os.path.dirname(path)

    @property
    def transforms(self):
        """Transform objects built from the 'Deploy' -> 'transforms' specs."""
        return self._transforms

    @property
    def model(self):
        """Path of the model file, joined onto the config file's directory."""
        return os.path.join(self._dir, self.dic['Deploy']['model'])

    @property
    def params(self):
        """Path of the params file, joined onto the config file's directory."""
        return os.path.join(self._dir, self.dic['Deploy']['params'])

    def _load_transforms(self, t_list):
        """Instantiate one transform per spec dict in `t_list`.

        Each spec's 'type' entry selects the class from `manager.TRANSFORMS`;
        the remaining entries are passed as keyword arguments. Specs are
        copied before 'type' is removed, so the dicts in `t_list` (and hence
        `self.dic`) are left unmodified.
        """
        com = manager.TRANSFORMS
        transforms = []
        for spec in t_list:
            options = dict(spec)
            ctype = options.pop('type')
            transforms.append(com[ctype](**options))
        return transforms
class Predictor:
    """Runs an exported segmentation model through Paddle Inference and
    blends the input image with a replacement background."""

    def __init__(self, args):
        """Build the preprocessing pipeline and the inference predictor.

        Args:
            args: namespace-like object. This constructor reads ``cfg``
                (path of the deploy YAML), ``input_shape`` (unpacked as
                ``(height, width)``), ``use_gpu`` and ``test_speed``.
                ``postprocess`` additionally reads ``save_dir``,
                ``soft_predict``, ``use_optic_flow`` and ``add_argmax``.
        """
        self.cfg = DeployConfig(args.cfg)
        self.args = args
        self.compose = T.Compose(self.cfg.transforms)
        resize_h, resize_w = args.input_shape

        # State carried from one frame to the next by the optical-flow
        # branch of postprocess().
        self.disflow = cv2.DISOpticalFlow_create(
            cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
        self.prev_gray = np.zeros((resize_h, resize_w), np.uint8)
        self.prev_cfd = np.zeros((resize_h, resize_w), np.float32)
        self.is_init = True

        pred_cfg = PredictConfig(self.cfg.model, self.cfg.params)
        pred_cfg.disable_glog_info()
        if self.args.use_gpu:
            # NOTE(review): presumably (initial GPU memory pool in MB,
            # device id) -- confirm against the Paddle Inference docs.
            pred_cfg.enable_use_gpu(100, 0)

        self.predictor = create_predictor(pred_cfg)
        if self.args.test_speed:
            self.cost_averager = TimeAverager()

    def preprocess(self, img):
        """Apply the configured transforms to a single image.

        Returns:
            ``(processed_imgs, ori_shapes)``: two one-element lists holding
            the first output of the transform pipeline and ``img.shape``.
        """
        processed_img = self.compose(img)[0]
        return [processed_img], [img.shape]

    def run(self, img, bg):
        """Segment ``img`` and composite it over ``bg``.

        Args:
            img: input image; ``postprocess`` unpacks its shape as
                ``(h, w, channels)``.
            bg: background image, or ``None`` for a plain white background.

        Returns:
            The ``(comb, alpha, bg, img)`` tuple produced by ``postprocess``.
        """
        input_names = self.predictor.get_input_names()
        input_handle = self.predictor.get_input_handle(input_names[0])
        processed_imgs, ori_shapes = self.preprocess(img)

        data = np.array(processed_imgs)
        input_handle.reshape(data.shape)
        input_handle.copy_from_cpu(data)

        if self.args.test_speed:
            start = time.time()

        self.predictor.run()

        if self.args.test_speed:
            self.cost_averager.record(time.time() - start)

        output_names = self.predictor.get_output_names()
        output_handle = self.predictor.get_output_handle(output_names[0])
        output = output_handle.copy_to_cpu()
        return self.postprocess(output, img, ori_shapes[0], bg)

    def postprocess(self, pred, img, ori_shape, bg):
        """Turn the raw network output into an alpha map and blend with ``bg``.

        Args:
            pred: network output copied to the CPU.
            img: the original input image.
            ori_shape: shape of ``img`` before preprocessing, handed to
                ``reverse_transform``.
            bg: background image or ``None`` (white background).

        Returns:
            ``(comb, alpha, bg, img)`` where ``comb`` is the uint8 blend
            ``alpha * img + (1 - alpha) * bg``.
        """
        # exist_ok avoids the race between checking for the directory and
        # creating it.
        os.makedirs(self.args.save_dir, exist_ok=True)
        resize_w = pred.shape[-1]
        resize_h = pred.shape[-2]
        if self.args.soft_predict:
            if self.args.use_optic_flow:
                score_map = pred[:, 1, :, :].squeeze(0)
                score_map = 255 * score_map
                cur_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
                cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
                optflow_map = optic_flow_process(cur_gray, score_map, self.prev_gray, self.prev_cfd, \
                        self.disflow, self.is_init)
                # Remember this frame for the next call.
                self.prev_gray = cur_gray.copy()
                self.prev_cfd = optflow_map.copy()
                self.is_init = False

                score_map = np.repeat(optflow_map[:, :, np.newaxis], 3, axis=2)
                score_map = np.transpose(score_map, [2, 0, 1])[np.newaxis, ...]
                score_map = reverse_transform(
                    paddle.to_tensor(score_map),
                    ori_shape,
                    self.cfg.transforms,
                    mode='bilinear')
                # The score map was scaled by 255 above; undo that here.
                alpha = np.transpose(score_map.numpy().squeeze(0),
                                     [1, 2, 0]) / 255
            else:
                score_map = pred[:, 1, :, :]
                score_map = score_map[np.newaxis, ...]
                score_map = reverse_transform(
                    paddle.to_tensor(score_map),
                    ori_shape,
                    self.cfg.transforms,
                    mode='bilinear')
                alpha = np.transpose(score_map.numpy().squeeze(0), [1, 2, 0])
        else:
            if pred.ndim == 3:
                pred = pred[:, np.newaxis, ...]
            result = reverse_transform(
                paddle.to_tensor(
                    pred, dtype='float32'),
                ori_shape,
                self.cfg.transforms,
                mode='bilinear')

            result = np.array(result)
            if self.args.add_argmax:
                result = np.argmax(result, axis=1)
            else:
                result = result.squeeze(1)
            alpha = np.transpose(result, [1, 2, 0])

        # background replace
        h, w, _ = img.shape
        if bg is None:
            bg = np.ones_like(img)*255
        else:
            bg = cv2.resize(bg, (w, h))
            if bg.ndim == 2:
                bg = bg[..., np.newaxis]

        comb = (alpha * img + (1 - alpha) * bg).astype(np.uint8)
        return comb, alpha, bg, img
================================================
FILE: stylegan_human/PP_HumanSeg/export_model/download_export_model.py
================================================
# coding: utf8
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os

# Directory that contains this script; it is passed to the downloader as
# both ``savepath`` and ``extrapath``.
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
# A "test" directory three levels above this one is added to the module
# search path.
TEST_PATH = os.path.join(LOCAL_PATH, "../../../", "test")
sys.path.append(TEST_PATH)

from paddleseg.utils.download import download_file_and_uncompress

# Name handed to the downloader as ``extraname`` -> URL of the exported
# (with-softmax) model archive.
model_urls = {
    "pphumanseg_lite_portrait_398x224_with_softmax":
    "https://paddleseg.bj.bcebos.com/dygraph/ppseg/ppseg_lite_portrait_398x224_with_softmax.tar.gz",
    "deeplabv3p_resnet50_os8_humanseg_512x512_100k_with_softmax":
    "https://paddleseg.bj.bcebos.com/dygraph/humanseg/export/deeplabv3p_resnet50_os8_humanseg_512x512_100k_with_softmax.zip",
    "fcn_hrnetw18_small_v1_humanseg_192x192_with_softmax":
    "https://paddleseg.bj.bcebos.com/dygraph/humanseg/export/fcn_hrnetw18_small_v1_humanseg_192x192_with_softmax.zip",
    "pphumanseg_lite_generic_humanseg_192x192_with_softmax":
    "https://paddleseg.bj.bcebos.com/dygraph/humanseg/export/pphumanseg_lite_generic_192x192_with_softmax.zip",
}

if __name__ == "__main__":
    # Fetch every archive listed above, one after another.
    for model_name in model_urls:
        download_file_and_uncompress(
            url=model_urls[model_name],
            savepath=LOCAL_PATH,
            extrapath=LOCAL_PATH,
            extraname=model_name)

    print("Export model download success!")
================================================
FILE: stylegan_human/PP_HumanSeg/pretrained_model/download_pretrained_model.py
================================================
# coding: utf8
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os

# Directory that contains this script; it is passed to the downloader as
# both ``savepath`` and ``extrapath``.
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
# A "test" directory three levels above this one is added to the module
# search path.
TEST_PATH = os.path.join(LOCAL_PATH, "../../../", "test")
sys.path.append(TEST_PATH)

from paddleseg.utils.download import download_file_and_uncompress

# Name handed to the downloader as ``extraname`` -> URL of the pretrained
# model archive.
model_urls = {
    "pphumanseg_lite_portrait_398x224":
    "https://paddleseg.bj.bcebos.com/dygraph/ppseg/ppseg_lite_portrait_398x224.tar.gz",
    "deeplabv3p_resnet50_os8_humanseg_512x512_100k":
    "https://paddleseg.bj.bcebos.com/dygraph/humanseg/train/deeplabv3p_resnet50_os8_humanseg_512x512_100k.zip",
    "fcn_hrnetw18_small_v1_humanseg_192x192":
    "https://paddleseg.bj.bcebos.com/dygraph/humanseg/train/fcn_hrnetw18_small_v1_humanseg_192x192.zip",
    "pphumanseg_lite_generic_human_192x192":
    "https://paddleseg.bj.bcebos.com/dygraph/humanseg/train/pphumanseg_lite_generic_192x192.zip",
}

if __name__ == "__main__":
    # Fetch every archive listed above, one after another.
    for model_name in model_urls:
        download_file_and_uncompress(
            url=model_urls[model_name],
            savepath=LOCAL_PATH,
            extrapath=LOCAL_PATH,
            extraname=model_name)

    print("Pretrained model download success!")
================================================
FILE: stylegan_human/README.md
================================================
# StyleGAN-Human: A Data-Centric Odyssey of Human Generation
>
>
> **Abstract:** *Unconditional human image generation is an important task in vision and graphics, which enables various applications in the creative industry. Existing studies in this field mainly focus on "network engineering" such as designing new components and objective functions. This work takes a data-centric perspective and investigates multiple critical aspects in "data engineering", which we believe would complement the current practice. To facilitate a comprehensive study, we collect and annotate a large-scale human image dataset with over 230K samples capturing diverse poses and textures. Equipped with this large dataset, we rigorously investigate three essential factors in data engineering for StyleGAN-based human generation, namely data size, data distribution, and data alignment. Extensive experiments reveal several valuable observations w.r.t. these aspects: 1) Large-scale data, more than 40K images, are needed to train a high-fidelity unconditional human generation model with vanilla StyleGAN. 2) A balanced training set helps improve the generation quality with rare face poses compared to the long-tailed counterpart, whereas simply balancing the clothing texture distribution does not effectively bring an improvement. 3) Human GAN models with body centers for alignment outperform models trained using face centers or pelvis points as alignment anchors. In addition, a model zoo and human editing applications are demonstrated to facilitate future research in the community.*
**Keyword:** Human Image Generation, Data-Centric, StyleGAN
[Jianglin Fu](mailto:fujianglin@sensetime.com), [Shikai Li](mailto:lishikai@sensetime.com), [Yuming Jiang](https://yumingj.github.io/), [Kwan-Yee Lin](https://kwanyeelin.github.io/), [Chen Qian](https://scholar.google.com/citations?user=AerkT0YAAAAJ&hl=zh-CN), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/), [Wayne Wu](https://wywu.github.io/), and [Ziwei Liu](https://liuziwei7.github.io/)
**[[Demo Video]](https://youtu.be/nIrb9hwsdcI)** | **[[Project Page]](https://stylegan-human.github.io/)** | **[[Paper]](https://arxiv.org/pdf/2204.11823.pdf)**
## Updates
- [20/07/2022] [SHHQ-1.0](./docs/Dataset.md) dataset with 40K images is released! :sparkles:
- [15/06/2022] Data alignment and real-image inversion scripts are released.
- [26/04/2022] Technical report released!
- [22/04/2022] Technical report will be released before May.
- [21/04/2022] The codebase and project page are created.
## Data Download
The first version, SHHQ-1.0, with 40K images, is released. To download and use the dataset, please read the instructions in [Dataset.md](./docs/Dataset.md).
(We are currently facing large incoming applications, and we need to carefully verify all the applicants, please be patient, and we will reply to you as soon as possible.)
## Model Zoo
| Structure | 1024x512 | Metric | Scores | 512x256 | Metric | Scores |
| --------- |:----------:| :----------:| :----------:| :-----: | :-----: | :-----: |
| StyleGAN1 |[stylegan_human_v1_1024.pkl](https://drive.google.com/file/d/1h-R-IV-INGdPEzj4P9ml6JTEvihuNgLX/view?usp=sharing)| fid50k | 3.79 | to be released | - | - |
| StyleGAN2 |[stylegan_human_v2_1024.pkl](https://drive.google.com/file/d/1FlAb1rYa0r_--Zj_ML8e6shmaF28hQb5/view?usp=sharing)| fid50k_full | 1.57 |[stylegan_human_v2_512.pkl](https://drive.google.com/file/d/1dlFEHbu-WzQWJl7nBBZYcTyo000H9hVm/view?usp=sharing) | fid50k_full | 1.97 |
| StyleGAN3 |to be released | - | - | [stylegan_human_v3_512.pkl](https://drive.google.com/file/d/1_274jk_N6WSCkKWeu7hjHycqGvbuOFf5/view?usp=sharing) | fid50k_full | 2.54 |
## Web Demo
Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo for generation: [StyleGAN-Human](https://huggingface.co/spaces/hysts/StyleGAN-Human) and interpolation: [StyleGAN-Human-Interpolation](https://huggingface.co/spaces/hysts/StyleGAN-Human-Interpolation)
We prepare a Colab demo to allow you to synthesize images with the provided models, as well as visualize the performance of style-mixing, interpolation, and attributes editing.
The notebook will guide you to install the necessary environment and download pretrained models. The output images can be found in `./StyleGAN-Human/outputs/`.
Hope you enjoy!
## Usage
### System requirements
* The original code bases are [stylegan (tensorflow)](https://github.com/NVlabs/stylegan), [stylegan2-ada (pytorch)](https://github.com/NVlabs/stylegan2-ada-pytorch), [stylegan3 (pytorch)](https://github.com/NVlabs/stylegan3), released by NVidia
* We tested in Python 3.8.5 and PyTorch 1.9.1 with CUDA 11.1. (See https://pytorch.org for PyTorch install instructions.)
### Installation
To work with this project on your own machine, you need to install the environment as follows:
```
conda env create -f environment.yml
conda activate stylehuman
# [Optional: tensorflow 1.x is required for StyleGAN1. ]
pip install nvidia-pyindex
pip install nvidia-tensorflow[horovod]
pip install nvidia-tensorboard==1.15
```
Extra notes:
1. If you run into conflicts with the CUDA version that gets picked up, try emptying LD_LIBRARY_PATH. For example:
```
LD_LIBRARY_PATH=; python generate.py --outdir=out/stylegan_human_v2_1024 --trunc=1 --seeds=1,3,5,7
--network=pretrained_models/stylegan_human_v2_1024.pkl --version 2
```
2. We found the following troubleshooting links might be helpful: [1.](https://github.com/NVlabs/stylegan3), [2.](https://github.com/NVlabs/stylegan3/blob/main/docs/troubleshooting.md)
### Train
The training scripts are based on the original [stylegan1](https://github.com/NVlabs/stylegan), [stylegan2-ada](https://github.com/NVlabs/stylegan2-ada-pytorch), and [stylegan3](https://github.com/NVlabs/stylegan3) with minor changes. Here we only provide the scripts with modifications for SG2 and SG3. You can replace the old files with the provided scripts to train. (assume SHHQ-1.0 is placed under data/)
#### Train Stylegan2-ada-pytorch with SHHQ-1.0
```
python train.py --outdir=training_results/sg2/ --data=data/SHHQ-1.0/ \
--gpus=8 --aug=noaug --mirror=1 --snap=250 --cfg=shhq --square=False
```
#### Train Stylegan3 with SHHQ-1.0
```
python train.py --outdir=training_results/sg3/ --cfg=stylegan3-r --gpus=8 --batch=32 --gamma=12.4 \
--mirror=1 --aug=noaug --data=data/SHHQ-1.0/ --square=False --snap=250
```
### Pretrained models
Please put the downloaded pretrained models [from above link](#Model-Zoo) under the folder 'pretrained_models'.
### Generate full-body human images using our pretrained model
```
# Generate human full-body images without truncation
python generate.py --outdir=outputs/generate/stylegan_human_v2_1024 --trunc=1 --seeds=1,3,5,7 --network=pretrained_models/stylegan_human_v2_1024.pkl --version 2
# Generate human full-body images with truncation
python generate.py --outdir=outputs/generate/stylegan_human_v2_1024 --trunc=0.8 --seeds=0-10 --network=pretrained_models/stylegan_human_v2_1024.pkl --version 2
# Generate human full-body images using stylegan V1
python generate.py --outdir=outputs/generate/stylegan_human_v1_1024 --network=pretrained_models/stylegan_human_v1_1024.pkl --version 1 --seeds=1,3,5
# Generate human full-body images using stylegan V3
python generate.py --outdir=outputs/generate/stylegan_human_v3_512 --network=pretrained_models/stylegan_human_v3_512.pkl --version 3 --seeds=1,3,5
```
#### Note: The following demos are generated based on models related to StyleGAN V2 (stylegan_human_v2_512.pkl and stylegan_human_v2_1024.pkl). If you want to see results for V1 or V3, you need to change the loading method of the corresponding models.
### Interpolation
```
python interpolation.py --network=pretrained_models/stylegan_human_v2_1024.pkl --seeds=85,100 --outdir=outputs/inter_gifs
```
### Style-mixing **image** using stylegan2
```
python style_mixing.py --network=pretrained_models/stylegan_human_v2_1024.pkl --rows=85,100,75,458,1500 \
--cols=55,821,1789,293 --styles=0-3 --outdir=outputs/stylemixing
```
### Style-mixing **video** using stylegan2
```
python stylemixing_video.py --network=pretrained_models/stylegan_human_v2_1024.pkl --row-seed=3859 \
--col-seeds=3098,31759,3791 --col-styles=8-12 --trunc=0.8 --outdir=outputs/stylemixing_video
```
### Aligned raw images
For alignment, we use [openpose-pytorch](https://github.com/Hzzone/pytorch-openpose) for body-keypoints detection and [PaddlePaddle](https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.5/contrib/PP-HumanSeg) for human segmentation.
Before running the alignment script, a few models need to be installed:
1. download [body_pose_model.pth](https://drive.google.com/drive/folders/1JsvI4M4ZTg98fmnCZLFM-3TeovnCRElG?usp=sharing) and place it into openpose/model/.
2. download and extract [deeplabv3p_resnet50_os8_humanseg_512x512_100k_with_softmax](https://paddleseg.bj.bcebos.com/dygraph/humanseg/export/deeplabv3p_resnet50_os8_humanseg_512x512_100k_with_softmax.zip) into PP_HumanSeg/export_model/deeplabv3p_resnet50_os8_humanseg_512x512_100k_with_softmax.
3. download and extract [deeplabv3p_resnet50_os8_humanseg_512x512_100k](https://paddleseg.bj.bcebos.com/dygraph/humanseg/train/deeplabv3p_resnet50_os8_humanseg_512x512_100k.zip) into PP_HumanSeg/pretrained_model/deeplabv3p_resnet50_os8_humanseg_512x512_100k.
4. install PaddleSeg: ``` pip install paddleseg ```
Then you can start alignment:
```
python alignment.py --image-folder img/test/ --output-folder aligned_image/
```
### Invert real image with [PTI](https://github.com/danielroich/PTI)
Before inversion, please download our PTI weights: [e4e_w+.pt](https://drive.google.com/file/d/1NUfSJqLhsrU7c9PwAtlZ9xtrxhzS_6tu/view?usp=sharing) into /pti/.
A few parameters you can change:
- /pti/pti_configs/hyperparameters.py:
- first_inv_type = 'w+' -> Use pretrained e4e encoder
- first_inv_type = 'w' -> Use projection and optimization
- /pti/pti_configs/paths_config.py:
- input_data_path: path of real images
- e4e: path of e4e_w+.pt
- stylegan2_ada_shhq: pretrained stylegan2-ada model for SHHQ
```
python run_pti.py
```
Note: we used the test image under 'aligned_image/' (the output of alignment.py), the inverted latent code and fine-tuned generator will be saved in 'outputs/pti/'
### Editing with InterfaceGAN, StyleSpace, and Sefa
```
python edit.py --network pretrained_models/stylegan_human_v2_1024.pkl --attr_name upper_length \
--seeds 61531,61570,61571,61610 --outdir outputs/edit_results
```
### Editing using inverted latent code
```
python edit.py --network outputs/pti/checkpoints/model_test.pkl --attr_name upper_length \
--outdir outputs/edit_results --real True --real_w_path outputs/pti/embeddings/test/PTI/test/0.pt --real_img_path aligned_image/test.png
```
Note:
1. ''upper_length'' and ''bottom_length'' of ''attr_name'' are available for demo.
2. Layers to control and editing strength are set in edit/edit_config.py.
### Demo for [InsetGAN](https://arxiv.org/abs/2203.07293)
We implement a quick demo using the key idea from InsetGAN: combining the face generated by FFHQ with the human-body generated by our pretrained model, optimizing both face and body latent codes to get a coherent full-body image.
Before running the script, you need to download the [FFHQ face model]( https://docs.google.com/uc?export=download&confirm=t&id=125OG7SMkXI-Kf2aqiwLLHyCvSW-gZk3M), or you can use your own face model, as well as [pretrained face landmark](https://docs.google.com/uc?export=download&confirm=&id=1A82DnJBJzt8wI2J8ZrCK5fgHcQ2-tcWM) and [pretrained CNN face detection model for dlib](https://docs.google.com/uc?export=download&confirm=&id=1MduBgju5KFNrQfDLoQXJ_1_h5MnctCIG)
```
python insetgan.py --body_network=pretrained_models/stylegan_human_v2_1024.pkl --face_network=pretrained_models/ffhq.pkl \
--body_seed=82 --face_seed=43 --trunc=0.6 --outdir=outputs/insetgan/ --video 1
```
## Results
### Editing with inverted real image
(from left to right: real image | inverted image | InterFaceGAN result | StyleSpace result | SeFa result)
https://user-images.githubusercontent.com/98547009/173773800-bb7fe54a-84d3-4b30-9864-a6b7b311f8ff.mp4
### For more demo, please visit our [**web page**](https://stylegan-human.github.io/) .
## TODO List
- [ ] Release 1024x512 version of StyleGAN-Human based on StyleGAN3
- [ ] Release 512x256 version of StyleGAN-Human based on StyleGAN1
- [ ] Extension of downstream application (InsetGAN): Add face inversion interface to support fusing user face image and StyleGAN-Human body image
- [x] Add Inversion Script into the provided editing pipeline
- [ ] Release Dataset
## Related Works
* (SIGGRAPH 2022) **Text2Human: Text-Driven Controllable Human Image Generation**, Yuming Jiang et al. [[Paper](https://arxiv.org/pdf/2205.15996.pdf)], [[Code](https://github.com/yumingj/Text2Human)], [[Project Page](https://yumingj.github.io/projects/Text2Human.html)], [[Dataset](https://github.com/yumingj/DeepFashion-MultiModal)]
* (ICCV 2021) **Talk-to-Edit: Fine-Grained Facial Editing via Dialog**, Yuming Jiang et al. [[Paper](https://arxiv.org/abs/2109.04425)], [[Code](https://github.com/yumingj/Talk-to-Edit)], [[Project Page](https://www.mmlab-ntu.com/project/talkedit/)], [[Dataset](https://mmlab.ie.cuhk.edu.hk/projects/CelebA/CelebA_Dialog.html)]
* (Technical Report 2022) **Generalizable Neural Performer: Learning Robust Radiance Fields for Human Novel View Synthesis**, Wei Cheng et al. [[Paper](https://arxiv.org/pdf/2204.11798.pdf)], [[Code](https://github.com/generalizable-neural-performer/gnr)], [[Project Page](https://generalizable-neural-performer.github.io/)], [[Dataset](https://generalizable-neural-performer.github.io/genebody.html)]
## Citation
If you find this work useful for your research, please consider citing our paper:
```bibtex
@article{fu2022styleganhuman,
title={StyleGAN-Human: A Data-Centric Odyssey of Human Generation},
author={Fu, Jianglin and Li, Shikai and Jiang, Yuming and Lin, Kwan-Yee and Qian, Chen and Loy, Chen-Change and Wu, Wayne and Liu, Ziwei},
journal = {arXiv preprint},
volume = {arXiv:2204.11823},
year = {2022}
}
```
## Acknowledgement
Part of the code is borrowed from [stylegan (tensorflow)](https://github.com/NVlabs/stylegan), [stylegan2-ada (pytorch)](https://github.com/NVlabs/stylegan2-ada-pytorch), [stylegan3 (pytorch)](https://github.com/NVlabs/stylegan3).
================================================
FILE: stylegan_human/__init__.py
================================================
================================================
FILE: stylegan_human/alignment.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
import os
import argparse
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from utils.ImagesDataset import ImagesDataset
import cv2
import time
import copy
import imutils
# for openpose body keypoint detector : # (src:https://github.com/Hzzone/pytorch-openpose)
from openpose.src import util
from openpose.src.body import Body
# for paddlepaddle human segmentation : #(src: https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/contrib/PP-HumanSeg/)
from PP_HumanSeg.deploy.infer import Predictor as PP_HumenSeg_Predictor
import math
def angle_between_points(p0, p1, p2):
    """Return the angle at ``p1`` between the rays to ``p0`` and ``p2``, in degrees.

    The angle is computed with the law of cosines on the squared side
    lengths of the triangle p0-p1-p2.

    Args:
        p0, p1, p2: indexable points; element 0 is x and element 1 is y.
            A y value of -1 marks a point as missing.

    Returns:
        The angle in degrees (0..180), or -1 when any point is missing or
        ``p1`` coincides with ``p0`` or ``p2``.
    """
    if p0[1] == -1 or p1[1] == -1 or p2[1] == -1:
        return -1
    a = (p1[0] - p0[0]) ** 2 + (p1[1] - p0[1]) ** 2
    b = (p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2
    c = (p2[0] - p0[0]) ** 2 + (p2[1] - p0[1]) ** 2
    if a * b == 0:
        return -1
    cos_angle = (a + b - c) / math.sqrt(4 * a * b)
    # Rounding can leave the cosine a hair outside [-1, 1] for (nearly)
    # collinear points, which would make math.acos raise ValueError.
    cos_angle = max(-1.0, min(1.0, cos_angle))
    return math.acos(cos_angle) * 180 / math.pi
def crop_img_with_padding(img, keypoints, rect):
    """Crop the person from ``img`` into a 1:2 (w:h) frame and resize it.

    The horizontal extent is first made symmetric around the body centre
    (using whichever side of the person reaches further), then the box is
    widened or heightened to a 1:2 aspect ratio. Parts of the box that fall
    outside the image are filled by replicating the border pixels.

    Args:
        img: image array whose shape unpacks as (height, width, channels).
        keypoints: body keypoints; entry 1 is used as the shoulder midpoint
            and entries 8 and 11 as the two hips (element 0 of each is x).
        rect: ``[person_xmin, person_xmax, ymin, ymax]`` person bounds.

    Returns:
        The padded crop resized with ``cv2.resize`` to ``(512, 1024)``.
    """
    body_left, body_right, top, bottom = rect
    frame_h, frame_w, _ = img.shape

    # Body centre: midway between the shoulder point and the hip midpoint.
    shoulder_x = keypoints[1][0]
    hip_x = (keypoints[8][0] + keypoints[11][0]) // 2
    center_x = (hip_x + shoulder_x) // 2
    center_y = (top + bottom) // 2

    # Mirror the side that lies further from the centre onto the other side.
    if abs(center_x - body_left) > abs(body_right - center_x):
        left, right = body_left, center_x + (center_x - body_left)
    else:
        # ``left`` may come out negative here. In that case the script will
        # not output an image; this is left as is because the body itself
        # should not be padded.
        left, right = center_x - (body_right - center_x), body_right

    box_w = right - left
    box_h = bottom - top

    if box_h / box_w >= 2:
        # Too tall for 1:2 -> widen the box around the centre.
        target_w = box_h // 2
        wanted_left = int(center_x - target_w / 2)
        wanted_right = int(center_x + target_w / 2)

        if wanted_left < 0:
            pad_left, left = abs(wanted_left), 0
        else:
            pad_left, left = 0, wanted_left

        if wanted_right > frame_w:
            pad_right, right = wanted_right - frame_w, frame_w
        else:
            pad_right, right = 0, wanted_right

        crop = img[int(top):int(bottom), int(left):int(right)]
        padded = cv2.copyMakeBorder(crop, 0, 0, int(pad_left), int(pad_right), cv2.BORDER_REPLICATE)
    else:
        # Too wide for 1:2 -> heighten the box around the vertical centre.
        target_h = box_w * 2
        wanted_top = center_y - (target_h / 2)
        wanted_bottom = center_y + (target_h / 2)

        if wanted_top < 0:
            pad_up, top = abs(wanted_top), 0
        else:
            pad_up, top = 0, wanted_top

        if wanted_bottom > frame_h:
            pad_down, bottom = wanted_bottom - frame_h, frame_h
        else:
            pad_down, bottom = 0, wanted_bottom

        print(top, bottom, left, right, img.shape)

        crop = img[int(top):int(bottom), int(left):int(right)]
        padded = cv2.copyMakeBorder(crop, int(pad_up), int(pad_down), 0,
                                    0, cv2.BORDER_REPLICATE)

    return cv2.resize(padded, (512, 1024), interpolation=cv2.INTER_AREA)
def run(args):
    """Align every image in ``args.image_folder`` and save it as a PNG.

    For each image the person is segmented onto a white background, the
    result is padded vertically, body keypoints are detected, and the person
    is cropped to a 1:2 frame by ``crop_img_with_padding``. Images that
    cannot be aligned (no segmentation contour, body touching the image
    boundary, or not exactly one confidently detected person) are skipped
    with a message so that the remaining images are still processed.

    Args:
        args: namespace with ``image_folder`` (input directory) and
            ``output_folder`` (created if missing; results are written as
            ``<output_folder>/<fname>.png``).
    """
    os.makedirs(args.output_folder, exist_ok=True)
    dataset = ImagesDataset(args.image_folder, transforms.Compose([transforms.ToTensor()]))
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    body_estimation = Body('openpose/model/body_pose_model.pth')

    total = len(dataloader)
    print('Num of dataloader : ', total)

    ## initialize HumanSeg
    human_seg_args = {}
    human_seg_args['cfg'] = 'PP_HumanSeg/export_model/deeplabv3p_resnet50_os8_humanseg_512x512_100k_with_softmax/deploy.yaml'
    human_seg_args['input_shape'] = [1024,512]
    human_seg_args['save_dir'] = args.output_folder
    human_seg_args['soft_predict'] = False
    human_seg_args['use_gpu'] = True
    human_seg_args['test_speed'] = False
    human_seg_args['use_optic_flow'] = False
    human_seg_args['add_argmax'] = True
    human_seg_args = argparse.Namespace(**human_seg_args)
    human_seg = PP_HumenSeg_Predictor(human_seg_args)

    from tqdm import tqdm
    for fname, image in tqdm(dataloader):
        ## tensor to numpy image
        fname = fname[0]
        print(f'Processing \'{fname}\'.')

        image = (image.permute(0, 2, 3, 1) * 255).clamp(0, 255)
        image = image.squeeze(0).numpy()  # --> tensor to numpy, (H,W,C)
        # avoid super high res img: shrink anything 2000+ px tall to 1200 px
        if image.shape[0] >= 2000:
            ratio = image.shape[0]/1200
            dim = (int(image.shape[1]/ratio),1200)  # (width, height)
            image = cv2.resize(image, dim, interpolation = cv2.INTER_AREA)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        ## create segmentation (bg=None -> person composited on white)
        comb, segmentation, bg, ori_img = human_seg.run(image,None)

        masks_np = (segmentation* 255)
        mask0_np = masks_np[:,:,0].astype(np.uint8)
        contours = cv2.findContours(mask0_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cnts = imutils.grab_contours(contours)
        if not cnts:
            # max() below would raise ValueError on an empty mask.
            print(f"Processing '{fname}'. No person found in the segmentation mask.")
            continue
        # Use the largest contour as the person.
        c = max(cnts, key=cv2.contourArea)
        extTop = list(c[c[:, :, 1].argmin()][0])
        extBot = list(c[c[:, :, 1].argmax()][0])

        # Vertical margin: 5% of the person's height (plus 5 px when padding).
        pad_range = int((extBot[1]-extTop[1])*0.05)
        if (int(extTop[1])<=5 and int(extTop[1])>0) and (comb.shape[0]>int(extBot[1]) and int(extBot[1])>=comb.shape[0]-5): #seg mask already reaches to the edge
            # pad top and bottom with pure white
            comb= cv2.copyMakeBorder(comb,pad_range+5,pad_range+5,0,0,cv2.BORDER_CONSTANT,value=[255,255,255])
        elif int(extTop[1])<=0 or int(extBot[1])>=comb.shape[0]:
            print('PAD: body out of boundary', fname) #should not happened
            # Skip only this image instead of aborting the whole batch.
            continue
        else:
            comb = cv2.copyMakeBorder(comb, pad_range+5, pad_range+5, 0, 0, cv2.BORDER_REPLICATE) # +5: give some extra space
        # Both padding branches shift the image content down by the same amount.
        extBot[1] = extBot[1] + pad_range+5
        extTop[1] = extTop[1] + pad_range+5

        extLeft = list(c[c[:, :, 0].argmin()][0])
        extRight = list(c[c[:, :, 0].argmax()][0])

        person_ymin = int(extTop[1])-pad_range
        person_ymax = int(extBot[1])+pad_range
        if person_ymin<0 or person_ymax>comb.shape[0]: # out of range
            continue
        person_xmin = int(extLeft[0])
        person_xmax = int(extRight[0])
        rect = [person_xmin,person_xmax,person_ymin, person_ymax]

        ## detect keypoints
        keypoints, subset = body_estimation(comb)
        # Require exactly one detection whose last subset entry is at least 15.
        if len(subset) != 1 or (len(subset)==1 and subset[0][-1]<15):
            print(f'Processing \'{fname}\'. Please import image contains one person only. Also can check segmentation mask. ')
            continue

        comb = crop_img_with_padding(comb, keypoints, rect)

        cv2.imwrite(f'{args.output_folder}/{fname}.png', comb)
        print(f' -- Finished processing \'{fname}\'. --')
if __name__ == '__main__':
    # cuDNN benchmark mode on, deterministic mode off.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

    started_at = time.time()

    cli = argparse.ArgumentParser(
        description='StyleGAN-Human data process',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    cli.add_argument('--image-folder', dest='image_folder', type=str)
    cli.add_argument('--output-folder', dest='output_folder', type=str, default='results')

    print('parsing arguments')
    run(cli.parse_args())

    # Wall-clock time for the whole run, in seconds.
    print('total time elapsed: ', str(time.time() - started_at))
================================================
FILE: stylegan_human/bg_white.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
import os
import click
import cv2
import numpy as np
def bg_white(seg, raw, blur_level=3, gaussian=81):
    """Blend ``raw`` with a light background derived from the mask ``seg``.

    Args:
        seg: segmentation mask image (converted with ``COLOR_BGR2GRAY``
            below, so a 3-channel image is expected).
        raw: the raw image to be masked.
        blur_level: kernel size of the box blur applied to ``seg``.
        gaussian: kernel size of the Gaussian blur applied to the inverted mask.

    Returns:
        The blended frame converted with ``np.uint8``.
    """
    soft_seg = cv2.blur(seg, (blur_level, blur_level))

    # Inverted mask, smoothed with a large Gaussian kernel.
    inverse = (np.ones_like(soft_seg) - soft_seg) * 255
    inverse = cv2.GaussianBlur(inverse, (gaussian, gaussian), 0)

    gray_seg = cv2.cvtColor(soft_seg, cv2.COLOR_BGR2GRAY)
    bg_mask = cv2.cvtColor(255 - gray_seg, cv2.COLOR_GRAY2BGR)

    # Scale everything by 1/255 before multiplying image and mask.
    foreground = (raw * (1 / 255)) * (soft_seg * (1 / 255))
    background = (inverse * (1 / 255)) * (bg_mask * (1 / 255))

    return np.uint8(cv2.add(background, foreground) * 255)
@click.command()
@click.pass_context
@click.option('--raw_img_dir', default="./SHHQ-1.0/no_segment/", help='folder of raw image', required=True)
@click.option('--raw_seg_dir', default='./SHHQ-1.0/segments/', help='folder of segmentation masks', required=True)
@click.option('--outdir', help='Where to save the output images', default= "./SHHQ-1.0/bg_white/" , type=str, required=True, metavar='DIR')
def main(
    ctx: click.Context,
    raw_img_dir: str,
    raw_seg_dir: str,
    outdir: str):
    """
    To turn background into white.

    Examples:

    \b
    python bg_white.py --raw_img_dir=./SHHQ-1.0/no_segment/ --raw_seg_dir=./SHHQ-1.0/segments/ \\
    --outdir=./SHHQ-1.0/bg_white/
    """
    # The text above is the command's --help output; it used to be a stray
    # module-level string that click never saw.
    os.makedirs(outdir, exist_ok=True)
    files = os.listdir(raw_img_dir)
    for file in files:
        print(file)
        # The mask is looked up under the same file name as the raw image.
        raw = cv2.imread(os.path.join(raw_img_dir, file))
        seg = cv2.imread(os.path.join(raw_seg_dir, file))
        # cv2.imread returns None when a file is missing or unreadable.
        assert raw is not None, f'cannot read raw image: {file}'
        assert seg is not None, f'cannot read segmentation mask: {file}'
        white_frame = bg_white(seg, raw)
        cv2.imwrite(os.path.join(outdir,file), white_frame)


if __name__ == "__main__":
    main()
================================================
FILE: stylegan_human/dnnlib/__init__.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from .util import EasyDict, make_cache_dir_path
================================================
FILE: stylegan_human/dnnlib/tflib/__init__.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html
from . import autosummary
from . import network
from . import optimizer
from . import tfutil
from . import custom_ops
from .tfutil import *
from .network import Network
from .optimizer import Optimizer
from .custom_ops import get_plugin
================================================
FILE: stylegan_human/dnnlib/tflib/autosummary.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html
"""Helper for adding automatically tracked values to Tensorboard.
Autosummary creates an identity op that internally keeps track of the input
values and automatically shows up in TensorBoard. The reported value
represents an average over input components. The average is accumulated
constantly over time and flushed when save_summaries() is called.
Notes:
- The output tensor must be used as an input for something else in the
graph. Otherwise, the autosummary op will not get executed, and the average
value will not get accumulated.
- It is perfectly fine to include autosummaries with the same name in
several places throughout the graph, even if they are executed concurrently.
- It is ok to also pass in a python scalar or numpy array. In this case, it
is added to the average immediately.
"""
from collections import OrderedDict
import numpy as np
import tensorflow as tf
from tensorboard import summary as summary_lib
from tensorboard.plugins.custom_scalar import layout_pb2
from . import tfutil
from .tfutil import TfExpression
from .tfutil import TfExpressionEx
# Enable "Custom scalars" tab in TensorBoard for advanced formatting.
# Disabled by default to reduce tfevents file size.
enable_custom_scalars = False

# Module-level state shared by all autosummary helpers in this file.
_dtype = tf.float64  # dtype used for every accumulator variable and update value
_vars = OrderedDict()  # name => [var, ...]; one accumulator per _create_var() call
_immediate = OrderedDict()  # name => update_op, update_value; ops used for python/numpy values
_finalized = False  # set to True by finalize_autosummaries(); _create_var() asserts it is False
_merge_op = None  # merged summary op, created lazily on the first save_summaries() call
def _create_var(name: str, value_expr: TfExpression) -> TfExpression:
    """Create one accumulator variable for `name` and return the op that updates it.

    The accumulator holds three running sums, [sum(1), sum(x), sum(x**2)], over
    the elements of `value_expr`. Updates whose element sum is not finite are
    replaced by zeros so they do not contribute. The new variable is appended
    to the module-level `_vars[name]` list.

    Args:
        name: Autosummary name; "/" is replaced by "_" to form the scope name.
        value_expr: Expression whose elements are accumulated.

    Returns:
        The op that adds the current moments of `value_expr` to the accumulator.
    """
    assert not _finalized
    scope_suffix = name.replace("/", "_")
    value = tf.cast(value_expr, _dtype)

    # Element count: a constant when the shape is static, otherwise computed in-graph.
    if value.shape.is_fully_defined():
        num_elems = np.prod(value.shape.as_list())
        num_elems_expr = tf.constant(num_elems, dtype=_dtype)
    else:
        num_elems = None
        num_elems_expr = tf.reduce_prod(tf.cast(tf.shape(value), _dtype))

    # Single-element values skip the reductions.
    if num_elems == 1:
        if value.shape.ndims != 0:
            value = tf.reshape(value, [])
        moments = [num_elems_expr, value, tf.square(value)]
    else:
        moments = [num_elems_expr, tf.reduce_sum(value), tf.reduce_sum(tf.square(value))]

    # Contribute nothing when the element sum is NaN/Inf.
    delta = tf.cond(tf.is_finite(moments[1]), lambda: tf.stack(moments), lambda: tf.zeros(3, dtype=_dtype))

    with tfutil.absolute_name_scope("Autosummary/" + scope_suffix), tf.control_dependencies(None):
        accum = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False)  # [sum(1), sum(x), sum(x**2)]

    # Assign on first use (variable not yet initialized), accumulate afterwards.
    update_op = tf.cond(tf.is_variable_initialized(accum), lambda: tf.assign_add(accum, delta), lambda: tf.assign(accum, delta))

    _vars.setdefault(name, []).append(accum)
    return update_op
def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None, condition: TfExpressionEx = True) -> TfExpressionEx:
    """Create a new autosummary.

    Args:
        name:      Name to use in TensorBoard.
        value:     TensorFlow expression or python value to track.
        passthru:  Optionally return this TF node without modifications but tack an autosummary update side-effect to this node.
        condition: When false, the accumulator update is skipped. Must be a
                   python value when `value` is a python value.

    Returns:
        `passthru` if given, otherwise `value`. For TensorFlow inputs the result
        is an identity op that depends on the accumulator update.

    Example use of the passthru mechanism:

    n = autosummary('l2loss', loss, passthru=n)

    This is a shorthand for the following code:

    with tf.control_dependencies([autosummary('l2loss', loss)]):
        n = tf.identity(n)
    """
    tfutil.assert_tf_initialized()
    scope_suffix = name.replace("/", "_")

    # Python scalar or numpy array: feed the value into a per-name placeholder right away.
    if not tfutil.is_tf_expression(value):
        assert not tfutil.is_tf_expression(passthru)
        assert not tfutil.is_tf_expression(condition)
        if condition:
            if name not in _immediate:
                with tfutil.absolute_name_scope("Autosummary/" + scope_suffix), tf.device(None), tf.control_dependencies(None):
                    feed_value = tf.placeholder(_dtype)
                    feed_op = _create_var(name, feed_value)
                    _immediate[name] = feed_op, feed_value
            feed_op, feed_value = _immediate[name]
            tfutil.run(feed_op, {feed_value: value})
        return value if passthru is None else passthru

    # TensorFlow expression: attach the update as a control dependency of the returned identity.
    with tf.name_scope("summary_" + scope_suffix), tf.device(value.device):
        condition = tf.convert_to_tensor(condition, name='condition')
        update_op = tf.cond(condition, lambda: tf.group(_create_var(name, value)), tf.no_op)
        with tf.control_dependencies([update_op]):
            result = passthru if passthru is not None else value
            return tf.identity(result)
def finalize_autosummaries() -> None:
    """Create the necessary ops to include autosummaries in TensorBoard report.

    Note: This should be done only once per graph. Repeated calls return None
    without doing anything.

    Returns:
        The custom-scalar layout summary when `enable_custom_scalars` is set,
        otherwise None.
    """
    global _finalized
    tfutil.assert_tf_initialized()

    if _finalized:
        return None

    _finalized = True
    tfutil.init_uninitialized_vars([accum for accum_list in _vars.values() for accum in accum_list])

    # Create summary ops.
    with tf.device(None), tf.control_dependencies(None):
        for name, accum_list in _vars.items():
            scope_suffix = name.replace("/", "_")
            with tfutil.absolute_name_scope("Autosummary/" + scope_suffix):
                # Sum all accumulators for this name and normalize by the element count.
                summed = tf.add_n(accum_list)
                moments = summed / summed[0]
                with tf.control_dependencies([moments]):  # read before resetting
                    reset_ops = [tf.assign(accum, tf.zeros(3, dtype=_dtype)) for accum in accum_list]
                    with tf.name_scope(None), tf.control_dependencies(reset_ops):  # reset before reporting
                        mean = moments[1]
                        std = tf.sqrt(moments[2] - tf.square(moments[1]))
                        tf.summary.scalar(name, mean)
                        if enable_custom_scalars:
                            tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std)
                            tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std)

    # Setup layout for custom scalars.
    if not enable_custom_scalars:
        return None

    # Group series as category -> chart -> [series names], based on the "/"-separated name.
    grouped = OrderedDict()
    for series_name in sorted(_vars.keys()):
        parts = series_name.split("/")
        category = parts[0] if len(parts) >= 2 else ""
        chart = "/".join(parts[1:-1]) if len(parts) >= 3 else parts[-1]
        grouped.setdefault(category, OrderedDict()).setdefault(chart, []).append(series_name)

    categories = []
    for category_title, charts_by_title in grouped.items():
        charts = []
        for chart_title, member_names in charts_by_title.items():
            series = [
                layout_pb2.MarginChartContent.Series(
                    value=member,
                    lower="xCustomScalars/" + member + "/margin_lo",
                    upper="xCustomScalars/" + member + "/margin_hi")
                for member in member_names
            ]
            margin = layout_pb2.MarginChartContent(series=series)
            charts.append(layout_pb2.Chart(title=chart_title, margin=margin))
        categories.append(layout_pb2.Category(title=category_title, chart=charts))
    return summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories))
def save_summaries(file_writer, global_step=None):
    """Write all summaries in the default graph via FileWriter.add_summary().

    On the first call the autosummaries are finalized, the custom-scalar layout
    (if any) is written, and the merged summary op is created and cached in the
    module-level `_merge_op`.

    Args:
        file_writer: Object providing `add_summary()`.
        global_step: Step value passed through to `add_summary()`.
    """
    global _merge_op
    tfutil.assert_tf_initialized()

    if _merge_op is None:
        custom_layout = finalize_autosummaries()
        if custom_layout is not None:
            file_writer.add_summary(custom_layout)
        with tf.device(None), tf.control_dependencies(None):
            _merge_op = tf.summary.merge_all()

    merged_summary = _merge_op.eval()
    file_writer.add_summary(merged_summary, global_step)
================================================
FILE: stylegan_human/dnnlib/tflib/custom_ops.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html
"""TensorFlow custom ops builder.
"""
import os
import re
import uuid
import hashlib
import tempfile
import shutil
import tensorflow as tf
from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module
#----------------------------------------------------------------------------
# Global options.

# Directory where compiled plugin binaries are stored, next to this file.
cuda_cache_path = os.path.join(os.path.dirname(__file__), '_cudacache')
# Mixed into the cache hash; changing it makes previously cached binaries unused.
cuda_cache_version_tag = 'v1'
do_not_hash_included_headers = False # Speed up compilation by assuming that headers included by the CUDA code never change. Unsafe!
verbose = True # Print status messages to stdout.

# Candidate host-compiler directories, tried in order; the first existing one
# is passed to nvcc as --compiler-bindir.
compiler_bindir_search_path = [
    'C:/Program Files (x86)/Microsoft Visual Studio/2017/Community/VC/Tools/MSVC/14.14.26428/bin/Hostx64/x64',
    'C:/Program Files (x86)/Microsoft Visual Studio/2019/Community/VC/Tools/MSVC/14.23.28105/bin/Hostx64/x64',
    'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin',
]
#----------------------------------------------------------------------------
# Internal helper funcs.
def _find_compiler_bindir():
    """Return the first existing directory in `compiler_bindir_search_path`, or None."""
    existing_dirs = (path for path in compiler_bindir_search_path if os.path.isdir(path))
    return next(existing_dirs, None)
def _get_compute_cap(device):
    """Extract the CUDA compute capability from a device description.

    Args:
        device: Object whose `physical_device_desc` string contains
            "compute capability: <major>.<minor>".

    Returns:
        Tuple `(major, minor)` of digit strings.

    Raises:
        RuntimeError: If the description contains no compute capability.
    """
    caps_str = device.physical_device_desc
    # The dot is escaped so that only a literal "." separates major and minor.
    m = re.search(r'compute capability: (\d+)\.(\d+)', caps_str)
    if m is None:
        raise RuntimeError('Could not parse compute capability from device description: "%s"' % caps_str)
    return (m.group(1), m.group(2))
def _get_cuda_gpu_arch_string():
    """Return the nvcc architecture string ("sm_<major><minor>") of the first local GPU.

    Raises:
        RuntimeError: If no local device of type 'GPU' is listed.
    """
    local_gpus = [dev for dev in device_lib.list_local_devices() if dev.device_type == 'GPU']
    if not local_gpus:
        raise RuntimeError('No GPU devices found')
    cap_major, cap_minor = _get_compute_cap(local_gpus[0])
    return 'sm_%s%s' % (cap_major, cap_minor)
def _run_cmd(cmd):
    """Run `cmd` through the shell and raise if it reports a failure.

    Args:
        cmd: Complete command line.

    Raises:
        RuntimeError: If closing the pipe yields a non-None status; the message
            carries the command line and everything the command printed.
    """
    with os.popen(cmd) as proc_pipe:
        captured = proc_pipe.read()
        exit_status = proc_pipe.close()  # None means the command succeeded
    if exit_status is None:
        return
    raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, captured))
def _prepare_nvcc_cli(opts):
    """Assemble a complete nvcc command line around the given options.

    Adds warning suppression, the TensorFlow include directories, the host
    compiler directory when one is found, and redirects stderr to stdout.

    Args:
        opts: nvcc options; surrounding whitespace is stripped.

    Returns:
        The command line as a single string.

    Raises:
        RuntimeError: On Windows when no host compiler directory is found.
    """
    tf_include = tf.sysconfig.get_include()
    include_dirs = [
        tf_include,
        os.path.join(tf_include, 'external', 'protobuf_archive', 'src'),
        os.path.join(tf_include, 'external', 'com_google_absl'),
        os.path.join(tf_include, 'external', 'eigen_archive'),
    ]

    parts = ['nvcc ' + opts.strip(), ' --disable-warnings']
    parts.extend(' --include-path "%s"' % include_dir for include_dir in include_dirs)

    bindir = _find_compiler_bindir()
    if bindir is not None:
        parts.append(' --compiler-bindir "%s"' % bindir)
    elif os.name == 'nt':
        # Require that _find_compiler_bindir succeeds on Windows. Allow
        # nvcc to use whatever is the default on Linux.
        raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' % __file__)

    parts.append(' 2>&1')
    return ''.join(parts)
#----------------------------------------------------------------------------
# Main entry point.
_plugin_cache = dict()  # cuda_file => plugin returned by tf.load_op_library()

def get_plugin(cuda_file):
    """Compile (if necessary) and load the TensorFlow op library for a CUDA source file.

    The binary is cached on disk under `cuda_cache_path`, keyed by an MD5 hash
    of the CUDA source, its preprocessed output (unless
    `do_not_hash_included_headers` is set), the nvcc command line, the
    TensorFlow version and `cuda_cache_version_tag`. Loaded plugins are also
    cached in memory per `cuda_file`.

    Args:
        cuda_file: Path to the CUDA source file.

    Returns:
        The object returned by `tf.load_op_library()` for the compiled binary.

    Raises:
        RuntimeError: If nvcc fails, or if the operating system is neither
            Windows ('nt') nor POSIX.
    """
    cuda_file_base = os.path.basename(cuda_file)
    cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base)

    # Already in cache?
    if cuda_file in _plugin_cache:
        return _plugin_cache[cuda_file]

    # Setup plugin.
    if verbose:
        print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True)
    try:
        # Hash CUDA source. MD5 serves only as a cache key here, not for security.
        md5 = hashlib.md5()
        with open(cuda_file, 'rb') as f:
            md5.update(f.read())
        md5.update(b'\n')

        # Hash headers included by the CUDA code by running it through the preprocessor.
        if not do_not_hash_included_headers:
            if verbose:
                print('Preprocessing... ', end='', flush=True)
            with tempfile.TemporaryDirectory() as tmp_dir:
                tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext)
                _run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)))
                with open(tmp_file, 'rb') as f:
                    # Replace the full source path by its base name so the hash
                    # does not depend on where the file is located.
                    bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8') # __FILE__ in error check macros
                    good_file_str = ('"' + cuda_file_base + '"').encode('utf-8')
                    for ln in f:
                        if not ln.startswith(b'# ') and not ln.startswith(b'#line '): # ignore line number pragmas
                            ln = ln.replace(bad_file_str, good_file_str)
                            md5.update(ln)
                    md5.update(b'\n')

        # Select compiler options.
        compile_opts = ''
        if os.name == 'nt':
            compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.lib')
        elif os.name == 'posix':
            compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.so')
            compile_opts += ' --compiler-options \'-fPIC -D_GLIBCXX_USE_CXX11_ABI=0\''
        else:
            # Raise explicitly: an assert would be stripped under "python -O".
            raise RuntimeError('Unsupported operating system "%s": expected Windows or POSIX.' % os.name)
        compile_opts += ' --gpu-architecture=%s' % _get_cuda_gpu_arch_string()
        compile_opts += ' --use_fast_math'
        nvcc_cmd = _prepare_nvcc_cli(compile_opts)

        # Hash build configuration.
        md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n')
        md5.update(('tf.VERSION: ' + tf.VERSION).encode('utf-8') + b'\n')
        md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n')

        # Compile if not already compiled.
        bin_file_ext = '.dll' if os.name == 'nt' else '.so'
        bin_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + md5.hexdigest() + bin_file_ext)
        if not os.path.isfile(bin_file):
            if verbose:
                print('Compiling... ', end='', flush=True)
            with tempfile.TemporaryDirectory() as tmp_dir:
                tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext)
                _run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))
                os.makedirs(cuda_cache_path, exist_ok=True)
                # Copy under a unique temporary name first, then publish in one step.
                intermediate_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext)
                shutil.copyfile(tmp_file, intermediate_file)
                # os.replace overwrites atomically on every platform; os.rename
                # raises on Windows when another process already published bin_file.
                os.replace(intermediate_file, bin_file) # atomic

        # Load.
        if verbose:
            print('Loading... ', end='', flush=True)
        plugin = tf.load_op_library(bin_file)

        # Add to cache.
        _plugin_cache[cuda_file] = plugin
        if verbose:
            print('Done.', flush=True)
        return plugin

    except:
        # Report the failure on the status line, then propagate unchanged.
        if verbose:
            print('Failed!', flush=True)
        raise
#----------------------------------------------------------------------------
================================================
FILE: stylegan_human/dnnlib/tflib/network.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html
"""Helper for managing networks."""
import types
import inspect
import re
import uuid
import sys
import numpy as np
import tensorflow as tf
from collections import OrderedDict
from typing import Any, List, Tuple, Union
from . import tfutil
from .. import util
from .tfutil import TfExpression, TfExpressionEx
# Registry of custom import handlers; each one is applied to the pickled state
# during pickle import to deal with legacy data.
_import_handlers = []

# Source code for temporary modules created during pickle import.
_import_module_src = {}

def import_handler(handler_func):
    """Decorator that registers `handler_func` as a custom import handler.

    The function is appended to the module-level handler list and returned
    unchanged, so the decorated name stays usable.
    """
    _import_handlers.append(handler_func)
    return handler_func
class Network:
"""Generic network abstraction.
Acts as a convenience wrapper for a parameterized network construction
function, providing several utility methods and convenient access to
the inputs/outputs/weights.
Network objects can be safely pickled and unpickled for long-term
archival purposes. The pickling works reliably as long as the underlying
network construction function is defined in a standalone Python module
that has no side effects or application-specific imports.
Args:
name: Network name. Used to select TensorFlow name and variable scopes.
func_name: Fully qualified name of the underlying network construction function, or a top-level function object.
static_kwargs: Keyword arguments to be passed in to the network construction function.
Attributes:
name: User-specified name, defaults to build func name if None.
scope: Unique TensorFlow scope containing template graph and variables, derived from the user-specified name.
static_kwargs: Arguments passed to the user-supplied build func.
components: Container for sub-networks. Passed to the build func, and retained between calls.
num_inputs: Number of input tensors.
num_outputs: Number of output tensors.
input_shapes: Input tensor shapes (NC or NCHW), including minibatch dimension.
output_shapes: Output tensor shapes (NC or NCHW), including minibatch dimension.
input_shape: Short-hand for input_shapes[0].
output_shape: Short-hand for output_shapes[0].
input_templates: Input placeholders in the template graph.
output_templates: Output tensors in the template graph.
input_names: Name string for each input.
output_names: Name string for each output.
own_vars: Variables defined by this network (local_name => var), excluding sub-networks.
vars: All variables (local_name => var).
trainables: All trainable variables (local_name => var).
var_global_to_local: Mapping from variable global names to local names.
"""
def __init__(self, name: str = None, func_name: Any = None, **static_kwargs):
    """Create a network by building the template graph of its build function.

    Args:
        name: Network name, or None to use the build function's name.
        func_name: Fully qualified name of the build function, or the
            top-level function object itself.
        static_kwargs: Pickleable keyword arguments that are stored on the
            network and passed to the build function.
    """
    tfutil.assert_tf_initialized()
    assert name is None or isinstance(name, str)
    assert func_name is not None
    assert isinstance(func_name, str) or util.is_top_level_function(func_name)
    assert util.is_pickleable(static_kwargs)

    self._init_fields()
    self.name = name
    self.static_kwargs = util.EasyDict(static_kwargs)

    # Locate the user-specified network build function.
    qualified_name = func_name
    if util.is_top_level_function(qualified_name):
        qualified_name = util.get_top_level_function_name(qualified_name)
    build_module, build_func_name = util.get_module_from_obj_name(qualified_name)
    self._build_func_name = build_func_name
    self._build_func = util.get_obj_from_module(build_module, build_func_name)
    assert callable(self._build_func)

    # Prefer source recorded during a pickle import; otherwise read it from the module.
    module_src = _import_module_src.get(build_module)
    if module_src is None:
        module_src = inspect.getsource(build_module)
    self._build_module_src = module_src

    # Init TensorFlow graph.
    self._init_graph()
    self.reset_own_vars()
def _init_fields(self) -> None:
    """Reset every attribute of the network to its empty default value."""
    # Identity and configuration.
    self.name = self.scope = None
    self.static_kwargs = util.EasyDict()
    self.components = util.EasyDict()

    # Inputs and outputs of the template graph.
    self.num_inputs = self.num_outputs = 0
    self.input_shapes = [[]]
    self.output_shapes = [[]]
    self.input_shape = []
    self.output_shape = []
    self.input_templates = []
    self.output_templates = []
    self.input_names = []
    self.output_names = []

    # Variable lookup tables.
    self.own_vars = OrderedDict()
    self.vars = OrderedDict()
    self.trainables = OrderedDict()
    self.var_global_to_local = OrderedDict()

    # Build function bookkeeping.
    self._build_func = None  # User-supplied build function that constructs the network.
    self._build_func_name = None  # Name of the build function.
    self._build_module_src = None  # Full source code of the module containing the build function.
    self._run_cache = dict()  # Cached graph data for Network.run().
def _init_graph(self) -> None:
    """Build the template graph and populate input, output and variable fields.

    Raises:
        ValueError: If an input or output has undefined rank, if a component is
            not a Network, or if component names are not unique.
    """
    # Collect inputs: positional build-function parameters without a default value.
    build_params = inspect.signature(self._build_func).parameters.values()
    self.input_names = [
        param.name for param in build_params
        if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty
    ]
    self.num_inputs = len(self.input_names)
    assert self.num_inputs >= 1

    # Choose name and scope.
    if self.name is None:
        self.name = self._build_func_name
    assert re.match("^[A-Za-z0-9_.\\-]*$", self.name)
    with tf.name_scope(None):
        self.scope = tf.get_default_graph().unique_name(self.name, mark_as_used=True)

    # Finalize build func kwargs.
    build_kwargs = dict(self.static_kwargs, is_template_graph=True, components=self.components)

    # Build template graph.
    with tfutil.absolute_variable_scope(self.scope, reuse=False), tfutil.absolute_name_scope(self.scope):  # ignore surrounding scopes
        assert tf.get_variable_scope().name == self.scope
        assert tf.get_default_graph().get_name_scope() == self.scope
        with tf.control_dependencies(None):  # ignore surrounding control dependencies
            self.input_templates = [tf.placeholder(tf.float32, name=input_name) for input_name in self.input_names]
            out_expr = self._build_func(*self.input_templates, **build_kwargs)

    # Collect outputs.
    assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
    if tfutil.is_tf_expression(out_expr):
        self.output_templates = [out_expr]
    else:
        self.output_templates = list(out_expr)
    self.num_outputs = len(self.output_templates)
    assert self.num_outputs >= 1
    assert all(tfutil.is_tf_expression(t) for t in self.output_templates)

    # Perform sanity checks.
    if any(t.shape.ndims is None for t in self.input_templates):
        raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.")
    if any(t.shape.ndims is None for t in self.output_templates):
        raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.")
    if not all(isinstance(comp, Network) for comp in self.components.values()):
        raise ValueError("Components of a Network must be Networks themselves.")
    if len(self.components) != len({comp.name for comp in self.components.values()}):
        raise ValueError("Components of a Network must have unique names.")

    # List inputs and outputs.
    self.input_shapes = [t.shape.as_list() for t in self.input_templates]
    self.output_shapes = [t.shape.as_list() for t in self.output_templates]
    self.input_shape = self.input_shapes[0]
    self.output_shape = self.output_shapes[0]
    self.output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates]

    # List variables. Local names drop the "<scope>/" prefix and the ":<index>" suffix.
    prefix_len = len(self.scope) + 1
    self.own_vars = OrderedDict(
        (var.name[prefix_len:].split(":")[0], var)
        for var in tf.global_variables(self.scope + "/"))
    self.vars = OrderedDict(self.own_vars)
    for comp in self.components.values():
        for local_name, var in comp.vars.items():
            self.vars[comp.name + "/" + local_name] = var
    self.trainables = OrderedDict((local_name, var) for local_name, var in self.vars.items() if var.trainable)
    self.var_global_to_local = OrderedDict((var.name.split(":")[0], local_name) for local_name, var in self.vars.items())
def reset_own_vars(self) -> None:
    """Run the initializers of this network's own variables (sub-networks excluded)."""
    initializers = [var.initializer for var in self.own_vars.values()]
    tfutil.run(initializers)
def reset_vars(self) -> None:
    """Run the initializers of all variables of this network (sub-networks included)."""
    initializers = [var.initializer for var in self.vars.values()]
    tfutil.run(initializers)
def reset_trainables(self) -> None:
    """Run the initializers of all trainable variables of this network (sub-networks included)."""
    initializers = [var.initializer for var in self.trainables.values()]
    tfutil.run(initializers)
def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:
    """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s).

    Args:
        in_expr: One expression per network input. An entry may be None, in
            which case zeros are substituted, with the leading dimension taken
            from the first non-None input. At least one entry must be given.
        return_as_list: If True, always return a list of outputs.
        dynamic_kwargs: Extra keyword arguments for the build function; they
            override the static ones.

    Returns:
        Whatever the build function returned, converted to a list when
        `return_as_list` is set.
    """
    assert len(in_expr) == self.num_inputs
    assert any(expr is not None for expr in in_expr)

    # Finalize build func kwargs.
    build_kwargs = dict(self.static_kwargs)
    build_kwargs.update(dynamic_kwargs)
    build_kwargs["is_template_graph"] = False
    build_kwargs["components"] = self.components

    # Build TensorFlow graph to evaluate the network.
    with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name):
        assert tf.get_variable_scope().name == self.scope
        given_inputs = [expr for expr in in_expr if expr is not None]
        final_inputs = []
        for expr, input_name, input_shape in zip(in_expr, self.input_names, self.input_shapes):
            if expr is None:
                resolved = tf.zeros([tf.shape(given_inputs[0])[0]] + input_shape[1:], name=input_name)
            else:
                resolved = tf.identity(expr, name=input_name)
            final_inputs.append(resolved)
        out_expr = self._build_func(*final_inputs, **build_kwargs)

    # Propagate input shapes back to the user-specified expressions.
    for given, resolved in zip(in_expr, final_inputs):
        if isinstance(given, tf.Tensor):
            given.set_shape(resolved.shape)

    # Express outputs in the desired format.
    assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
    if not return_as_list:
        return out_expr
    return [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:
    """Get the local name of a given variable, without any surrounding name scopes.

    Args:
        var_or_global_name: A TensorFlow expression, or its global name as a
            string. A trailing ":<index>" output suffix is ignored.

    Returns:
        The local name registered in `self.var_global_to_local`.

    Raises:
        KeyError: If the name does not belong to this network.
    """
    assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str)
    global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name
    # The lookup table is keyed by names without the ":<index>" suffix, whereas
    # the .name of a variable includes it (e.g. "scope/weight:0"). Strip it so
    # that passing the variable itself works.
    return self.var_global_to_local[global_name.split(":")[0]]
def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:
    """Resolve a variable given either the variable itself or its local name.

    A string is looked up in `self.vars`; a TensorFlow expression is returned
    unchanged.
    """
    assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)
    if isinstance(var_or_local_name, str):
        return self.vars[var_or_local_name]
    return var_or_local_name
def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray:
    """Get the value of a given variable as NumPy array.

    Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible.
    """
    variable = self.find_var(var_or_local_name)
    return variable.eval()
def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None:
    """Set the value of a given variable based on the given NumPy array.

    Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible.
    """
    target_var = self.find_var(var_or_local_name)
    tfutil.set_vars({target_var: new_value})
def __getstate__(self) -> dict:
    """Pickle export.

    Returns a dict with the format version, the network name, static kwargs,
    components, the build module source and function name, and the current
    values of the network's own variables as (local_name, value) pairs.
    """
    own_var_values = tfutil.run(list(self.own_vars.values()))
    return {
        "version": 4,
        "name": self.name,
        "static_kwargs": dict(self.static_kwargs),
        "components": dict(self.components),
        "build_module_src": self._build_module_src,
        "build_func_name": self._build_func_name,
        "variables": list(zip(self.own_vars.keys(), own_var_values)),
    }
def __setstate__(self, state: dict) -> None:
    """Pickle import.

    Rebuilds the network from `state`: runs the registered import handlers,
    restores the basic fields, re-creates the build module from its stored
    source, rebuilds the template graph and loads the stored variable values.
    """
    # pylint: disable=attribute-defined-outside-init
    tfutil.assert_tf_initialized()
    self._init_fields()

    # Execute custom import handlers.
    for handler in _import_handlers:
        state = handler(state)

    # Set basic fields.
    assert state["version"] in (2, 3, 4)
    self.name = state["name"]
    self.static_kwargs = util.EasyDict(state["static_kwargs"])
    self.components = util.EasyDict(state.get("components", {}))
    self._build_module_src = state["build_module_src"]
    self._build_func_name = state["build_func_name"]

    # Create temporary module from the imported source code.
    # NOTE(review): the source executed below comes straight from the pickled
    # state, so only load pickles from trusted origins.
    tmp_module_name = "_tflib_network_import_" + uuid.uuid4().hex
    tmp_module = types.ModuleType(tmp_module_name)
    sys.modules[tmp_module_name] = tmp_module
    _import_module_src[tmp_module] = self._build_module_src
    exec(self._build_module_src, vars(tmp_module))  # pylint: disable=exec-used

    # Locate network build function in the temporary module.
    self._build_func = util.get_obj_from_module(tmp_module, self._build_func_name)
    assert callable(self._build_func)

    # Init TensorFlow graph.
    self._init_graph()
    self.reset_own_vars()
    stored_values = {self.find_var(local_name): value for local_name, value in state["variables"]}
    tfutil.set_vars(stored_values)
def clone(self, name: str = None, **new_static_kwargs) -> "Network":
    """Create a clone of this network with its own copy of the variables.

    Args:
        name: Name of the clone; defaults to this network's name.
        new_static_kwargs: Static kwargs that override those of this network.
    """
    # pylint: disable=protected-access
    twin = object.__new__(Network)  # bypass __init__; fields are filled in by hand below
    twin._init_fields()
    twin.name = self.name if name is None else name
    twin.static_kwargs = util.EasyDict(self.static_kwargs)
    twin.static_kwargs.update(new_static_kwargs)
    twin._build_module_src = self._build_module_src
    twin._build_func_name = self._build_func_name
    twin._build_func = self._build_func
    twin._init_graph()
    twin.copy_vars_from(self)
    return twin
def copy_own_vars_from(self, src_net: "Network") -> None:
    """Copy the values of all variables from the given network, excluding sub-networks.

    Only variables whose local name exists in both networks' own variables are copied.
    """
    shared_names = [name for name in self.own_vars if name in src_net.own_vars]
    src_values = tfutil.run({self.vars[name]: src_net.vars[name] for name in shared_names})
    tfutil.set_vars(src_values)
def copy_vars_from(self, src_net: "Network") -> None:
    """Copy the values of all variables from the given network, including sub-networks.

    Only variables whose local name exists in both networks are copied.
    """
    shared_names = [name for name in self.vars if name in src_net.vars]
    src_values = tfutil.run({self.vars[name]: src_net.vars[name] for name in shared_names})
    tfutil.set_vars(src_values)
def copy_trainables_from(self, src_net: "Network") -> None:
    """Copy the values of all trainable variables from the given network, including sub-networks.

    Only variables whose local name is trainable in both networks are copied.
    """
    shared_names = [name for name in self.trainables if name in src_net.trainables]
    src_values = tfutil.run({self.vars[name]: src_net.vars[name] for name in shared_names})
    tfutil.set_vars(src_values)
def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network":
    """Create new network with the given parameters, and copy all variables from this network.

    Args:
        new_func_name: Build function of the new network.
        new_name: Name of the new network; defaults to this network's name.
        new_static_kwargs: Static kwargs that override those of this network.
    """
    target_name = self.name if new_name is None else new_name
    merged_kwargs = dict(self.static_kwargs)
    merged_kwargs.update(new_static_kwargs)
    converted = Network(name=target_name, func_name=new_func_name, **merged_kwargs)
    converted.copy_vars_from(self)
    return converted
def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation:
    """Construct a TensorFlow op that updates the variables of this network
    to be slightly closer to those of the given network.

    Args:
        src_net: Network whose variables are tracked.
        beta: Interpolation factor passed to `tfutil.lerp` for trainable variables.
        beta_nontrainable: Interpolation factor for all other variables.

    Returns:
        A single op grouping the assignments of every variable that also exists in `src_net`.
    """
    with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"):
        update_ops = []
        for name, var in self.vars.items():
            if name not in src_net.vars:
                continue  # no counterpart in src_net; leave this variable untouched
            cur_beta = beta if name in self.trainables else beta_nontrainable
            blended = tfutil.lerp(src_net.vars[name], var, cur_beta)
            update_ops.append(var.assign(blended))
        return tf.group(*update_ops)
def run(self,
*in_arrays: Tuple[Union[np.ndarray, None], ...],
input_transform: dict = None,
output_transform: dict = None,
return_as_list: bool = False,
print_progress: bool = False,
minibatch_size: int = None,
num_gpus: int = 1,
assume_frozen: bool = False,
**dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
"""Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).
Args:
input_transform: A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
The dict must contain a 'func' field that points to a top-level function. The function is called with the input
TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
output_transform: A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
The dict must contain a 'func' field that points to a top-level function. The function is called with the output
TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
return_as_list: True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
print_progress: Print progress to the console? Useful for very large input arrays.
minibatch_size: Maximum minibatch size to use, None = disable batching.
num_gpus: Number of GPUs to use.
assume_frozen: Improve multi-GPU performance by assuming that the trainable parameters will remain unchanged between calls.
dynamic_kwargs: Additional keyword arguments to be passed into the network build function.
"""
assert len(in_arrays) == self.num_inputs
assert not all(arr is None for arr in in_arrays)
assert input_transform is None or util.is_top_level_function(input_transform["func"])
assert output_transform is None or util.is_top_level_function(output_transform["func"])
output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)
num_items = in_arrays[0].shape[0]
if minibatch_size is None:
minibatch_size = num_items
# Construct unique hash key from all arguments that affect the TensorFlow graph.
key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs)
def unwind_key(obj):
if isinstance(obj, dict):
return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
if callable(obj):
return util.get_top_level_function_name(obj)
return obj
key = repr(unwind_key(key))
# Build graph.
if key not in self._run_cache:
with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
with tf.device("/cpu:0"):
in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
out_split = []
for gpu in range(num_gpus):
with tf.device("/gpu:%d" % gpu):
net_gpu = self.clone() if assume_frozen else self
in_gpu = in_split[gpu]
if input_transform is not None:
in_kwargs = dict(input_transform)
in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs)
in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)
assert len(in_gpu) == self.num_inputs
out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)
if output_transform is not None:
out_kwargs = dict(output_transform)
out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs)
out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)
assert len(out_gpu) == self.num_outputs
out_split.append(out_gpu)
with tf.device("/cpu:0"):
out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
self._run_cache[key] = in_expr, out_expr
# Run minibatches.
in_expr, out_expr = self._run_cache[key]
out_arrays = [np.empty([num_items] + expr.shape.as_list()[1:], expr.dtype.name) for expr in out_expr]
for mb_begin in range(0, num_items, minibatch_size):
if print_progress:
print("\r%d / %d" % (mb_begin, num_items), end="")
mb_end = min(mb_begin + minibatch_size, num_items)
mb_num = mb_end - mb_begin
mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))
for dst, src in zip(out_arrays, mb_out):
dst[mb_begin: mb_end] = src
# Done.
if print_progress:
print("\r%d / %d" % (num_items, num_items))
if not return_as_list:
out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)
return out_arrays
def list_ops(self) -> List[TfExpression]:
    """Return the ops of the default graph that belong to this network.

    An op is kept when its name starts with ``self.scope + "/"`` and is
    dropped when the name continues with an underscore right after that
    prefix (i.e. it starts with ``self.scope + "/_"``).
    """
    scope_prefix = self.scope + "/"
    hidden_prefix = scope_prefix + "_"
    return [
        op
        for op in tf.get_default_graph().get_operations()
        if op.name.startswith(scope_prefix) and not op.name.startswith(hidden_prefix)
    ]
def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]:
    """Group the network's ops and variables into layers, mainly for reporting.

    Returns:
        A list of ``(layer_name, output_expr, trainable_vars)`` tuples. The
        layer name is the scope path relative to ``self.scope``; the output is
        the first output of the last op found in that scope (or the last
        variable when the scope holds no ops).
    """
    skipped_patterns = ("/Shape", "/strided_slice", "/Cast", "/concat", "/Assign")
    passthrough_types = ("Identity", "Cast", "Transpose")
    root_len = len(self.scope) + 1
    collected = []

    def visit(scope, parent_ops, parent_vars, depth):
        # Scopes whose path contains one of these fragments are never reported.
        for pattern in skipped_patterns:
            if pattern in scope:
                return

        # Narrow the parent's ops (absolute names) and vars (names relative
        # to the network scope) down to the current scope.
        abs_prefix = scope + "/"
        rel_prefix = abs_prefix[root_len:]
        scope_ops = [
            op for op in parent_ops
            if op.name == scope or op.name.startswith(abs_prefix)
        ]
        scope_vars = [
            (name, var) for name, var in parent_vars
            if name == rel_prefix[:-1] or name.startswith(rel_prefix)
        ]
        if not (scope_ops or scope_vars):
            return

        # Discard every op that lives underneath a variable op.
        var_prefixes = tuple(op.name + "/" for op in scope_ops if op.type.startswith("Variable"))
        scope_ops = [op for op in scope_ops if not op.name.startswith(var_prefixes)]

        # A scope without "real" ops as immediate children is not a layer by
        # itself: descend into each distinct child scope instead.
        has_direct_ops = False
        for op in scope_ops:
            if "/" not in op.name[len(abs_prefix):] and op.type not in passthrough_types:
                has_direct_ops = True
                break
        if (depth == 0 or not has_direct_ops) and len(scope_ops) + len(scope_vars) > 1:
            rel_names = [op.name[len(abs_prefix):] for op in scope_ops]
            rel_names += [name[len(rel_prefix):] for name, _ in scope_vars]
            seen = set()
            for rel_name in rel_names:
                child = rel_name.split("/")[0]
                if child in seen:
                    continue
                visit(abs_prefix + child, scope_ops, scope_vars, depth + 1)
                seen.add(child)
            return

        # Leaf scope: report it as a single layer.
        if scope_ops:
            output = scope_ops[-1].outputs[0]
        else:
            output = scope_vars[-1][1]
        trainables = [var for _, var in scope_vars if var.trainable]
        collected.append((scope[root_len:], output, trainables))

    visit(self.scope, self.list_ops(), list(self.vars.items()), 0)
    return collected
def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None:
    """Print a table with one row per layer reported by `list_layers()`.

    Args:
        title: Text for the first header cell; `self.name` is used when None.
        hide_layers_with_no_params: Leave out rows for layers that have no
            trainable parameters (they still exist, they just are not listed).
    """
    header = [self.name if title is None else title, "Params", "OutputShape", "WeightShape"]
    table = [header, ["---"] * 4]
    grand_total = 0

    for layer_name, layer_output, layer_trainables in self.list_layers():
        param_count = 0
        for var in layer_trainables:
            param_count += int(np.prod(var.shape.as_list()))
        grand_total += param_count

        # Prefer the variable literally named ".../weight:0" with the shortest
        # name; fall back to the sole trainable when there is exactly one.
        weights = sorted(
            (var for var in layer_trainables if var.name.endswith("/weight:0")),
            key=lambda var: len(var.name),
        )
        if not weights and len(layer_trainables) == 1:
            weights = layer_trainables

        if hide_layers_with_no_params and param_count == 0:
            continue
        table.append([
            layer_name,
            str(param_count) if param_count > 0 else "-",
            str(layer_output.shape),
            str(weights[0].shape) if weights else "-",
        ])

    table.append(["---"] * 4)
    table.append(["Total", str(grand_total), "", ""])

    # Pad every cell to the widest entry of its column.
    widths = [max(len(cell) for cell in column) for column in zip(*table)]
    print()
    for row in table:
        print(" ".join(cell.ljust(width) for cell, width in zip(row, widths)))
    print()
def setup_weight_histograms(self, title: str = None) -> None:
    """Create a `tf.summary.histogram` op for every entry of `self.trainables`.

    Args:
        title: Prefix for the summary names; `self.name` is used when None.

    A variable named ``a/b/weight`` is summarized as ``<title>_weight/a_b``;
    a variable without any ``/`` in its name goes under ``<title>_toplevel/``.
    """
    prefix = self.name if title is None else title
    with tf.name_scope(None), tf.device(None), tf.control_dependencies(None):
        for local_name, var in self.trainables.items():
            parts = local_name.split("/")
            if len(parts) > 1:
                summary_name = prefix + "_" + parts[-1] + "/" + "_".join(parts[:-1])
            else:
                summary_name = prefix + "_toplevel/" + local_name
            tf.summary.histogram(summary_name, var)
#----------------------------------------------------------------------------
# Backwards-compatible emulation of legacy output transformation in Network.run().
_print_legacy_warning = True
def _handle_legacy_output_transforms(output_transform, dynamic_kwargs):
global _print_legacy_warning
legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"]
if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs):
return output_transform, dynamic_kwargs
if _print_legacy_warning:
_print_legacy_warning = False
print()
print("WARNING: Old-style output transformations in Network.run() are deprecated.")
print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'")
print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.")
print()
assert output_transform is None
new_kwargs = dict(dynamic_kwargs)
new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs}
new_transform["func"] = _legacy_output_transform_func
return new_transform, new_kwargs
def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None):
if out_mul != 1.0:
expr = [x * out_mul for x in expr]
if out_add != 0.0:
expr = [x + out_add for x in expr]
if out_shrink > 1:
ksize = [1, 1, out_shrink, out_shrink]
expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr]
if out_dtype is not None:
if tf.as_dtype(out_dtype).is_integer:
expr = [tf.round(x) for x in expr]
expr = [tf.saturate_cast(x, out_dtype) for x in expr]
return expr
================================================
FILE: stylegan_human/dnnlib/tflib/ops/__init__.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html
# empty
================================================
FILE: stylegan_human/dnnlib/tflib/ops/fused_bias_act.cu
================================================
// Copyright (c) SenseTime Research. All rights reserved.
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#define EIGEN_USE_GPU
#define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include <stdio.h>

using namespace tensorflow;
using namespace tensorflow::shape_inference;

// Evaluates CUDA_CALL once and, unless it returned cudaSuccess, fails the op
// through OP_REQUIRES with an Internal error carrying the CUDA error name.
#define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)
//------------------------------------------------------------------------
// CUDA kernel.
// Parameter block passed by value to FusedBiasActKernel.
// T is the element type of the input, bias, reference and output buffers.
template <class T>
struct FusedBiasActKernelParams
{
    const T*    x;      // [sizeX]
    const T*    b;      // [sizeB] or NULL
    const T*    ref;    // [sizeX] or NULL
    T*          y;      // [sizeX]

    int         grad;   // Selects forward (0), first (1) or second (2) derivative branch.
    int         axis;   // Dimension of x that the bias applies to.
    int         act;    // Activation index; combined with grad as act * 10 + grad.
    float       alpha;  // Negative-side slope used by the lrelu branches.
    float       gain;   // Output scale.

    int         sizeX;  // Number of elements in x / y.
    int         sizeB;  // Number of elements in b.
    int         stepB;  // Product of the x dimensions after `axis`.
    int         loopX;  // Elements processed per thread.
};
// Element-wise kernel: y = gain * act(x + b), or one of its derivative forms.
//
// Each thread processes up to p.loopX elements, advancing by blockDim.x
// between them. The branch is chosen by `p.act * 10 + p.grad`; any selector
// without a matching case falls through to the linear forward branch.
template <class T>
static __global__ void FusedBiasActKernel(const FusedBiasActKernelParams<T> p)
{
    const float expRange        = 80.0f;
    const float halfExpRange    = 40.0f;
    const float seluScale       = 1.0507009873554804934193349852946f;
    const float seluAlpha       = 1.6732632423543772848170429916717f;

    // Loop over elements.
    int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
    for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
    {
        // Load and apply bias.
        float x = (float)p.x[xi];
        if (p.b)
            x += (float)p.b[(xi / p.stepB) % p.sizeB];
        float ref = (p.ref) ? (float)p.ref[xi] : 0.0f;
        // Undo the output gain on the reference value, except for act 9 (swish).
        if (p.gain != 0.0f & p.act != 9)
            ref /= p.gain;

        // Evaluate activation func.
        float y;
        switch (p.act * 10 + p.grad)
        {
            // linear
            default:
            case 10: y = x; break;
            case 11: y = x; break;
            case 12: y = 0.0f; break;

            // relu
            case 20: y = (x > 0.0f) ? x : 0.0f; break;
            case 21: y = (ref > 0.0f) ? x : 0.0f; break;
            case 22: y = 0.0f; break;

            // lrelu
            case 30: y = (x > 0.0f) ? x : x * p.alpha; break;
            case 31: y = (ref > 0.0f) ? x : x * p.alpha; break;
            case 32: y = 0.0f; break;

            // tanh
            case 40: { float c = expf(x); float d = 1.0f / c; y = (x < -expRange) ? -1.0f : (x > expRange) ? 1.0f : (c - d) / (c + d); } break;
            case 41: y = x * (1.0f - ref * ref); break;
            case 42: y = x * (1.0f - ref * ref) * (-2.0f * ref); break;

            // sigmoid
            case 50: y = (x < -expRange) ? 0.0f : 1.0f / (expf(-x) + 1.0f); break;
            case 51: y = x * ref * (1.0f - ref); break;
            case 52: y = x * ref * (1.0f - ref) * (1.0f - 2.0f * ref); break;

            // elu
            case 60: y = (x >= 0.0f) ? x : expf(x) - 1.0f; break;
            case 61: y = (ref >= 0.0f) ? x : x * (ref + 1.0f); break;
            case 62: y = (ref >= 0.0f) ? 0.0f : x * (ref + 1.0f); break;

            // selu
            case 70: y = (x >= 0.0f) ? seluScale * x : (seluScale * seluAlpha) * (expf(x) - 1.0f); break;
            case 71: y = (ref >= 0.0f) ? x * seluScale : x * (ref + seluScale * seluAlpha); break;
            case 72: y = (ref >= 0.0f) ? 0.0f : x * (ref + seluScale * seluAlpha); break;

            // softplus
            case 80: y = (x > expRange) ? x : logf(expf(x) + 1.0f); break;
            case 81: y = x * (1.0f - expf(-ref)); break;
            case 82: { float c = expf(-ref); y = x * c * (1.0f - c); } break;

            // swish
            case 90: y = (x < -expRange) ? 0.0f : x / (expf(-x) + 1.0f); break;
            case 91: { float c = expf(ref); float d = c + 1.0f; y = (ref > halfExpRange) ? x : x * c * (ref + d) / (d * d); } break;
            case 92: { float c = expf(ref); float d = c + 1.0f; y = (ref > halfExpRange) ? 0.0f : x * c * (ref * (2.0f - d) + 2.0f * d) / (d * d * d); } break;
        }

        // Apply gain and store.
        p.y[xi] = (T)(y * p.gain);
    }
}
//------------------------------------------------------------------------
// TensorFlow op.
// OpKernel wrapper that validates inputs and launches FusedBiasActKernel<T>.
// Attributes are read once at construction; tensor-dependent fields are filled
// in per call on a copy of the attribute block.
template <class T>
struct FusedBiasActOp : public OpKernel
{
    FusedBiasActKernelParams<T> m_attribs;

    FusedBiasActOp(OpKernelConstruction* ctx) : OpKernel(ctx)
    {
        memset(&m_attribs, 0, sizeof(m_attribs));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("grad", &m_attribs.grad));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &m_attribs.axis));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("act", &m_attribs.act));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &m_attribs.alpha));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("gain", &m_attribs.gain));
        OP_REQUIRES(ctx, m_attribs.grad >= 0, errors::InvalidArgument("grad must be non-negative"));
        OP_REQUIRES(ctx, m_attribs.axis >= 0, errors::InvalidArgument("axis must be non-negative"));
        OP_REQUIRES(ctx, m_attribs.act >= 0, errors::InvalidArgument("act must be non-negative"));
    }

    void Compute(OpKernelContext* ctx)
    {
        FusedBiasActKernelParams<T> p = m_attribs;
        cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();

        const Tensor& x     = ctx->input(0); // [...]
        const Tensor& b     = ctx->input(1); // [sizeB] or [0]
        const Tensor& ref   = ctx->input(2); // x.shape or [0]
        p.x = x.flat<T>().data();
        // Empty b / ref tensors are passed to the kernel as NULL pointers.
        p.b = (b.NumElements()) ? b.flat<T>().data() : NULL;
        p.ref = (ref.NumElements()) ? ref.flat<T>().data() : NULL;
        OP_REQUIRES(ctx, b.NumElements() == 0 || m_attribs.axis < x.dims(), errors::InvalidArgument("axis out of bounds"));
        OP_REQUIRES(ctx, b.dims() == 1, errors::InvalidArgument("b must have rank 1"));
        OP_REQUIRES(ctx, b.NumElements() == 0 || b.NumElements() == x.dim_size(m_attribs.axis), errors::InvalidArgument("b has wrong number of elements"));
        OP_REQUIRES(ctx, ref.NumElements() == ((p.grad == 0) ? 0 : x.NumElements()), errors::InvalidArgument("ref has wrong number of elements"));
        OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("x is too large"));

        p.sizeX = (int)x.NumElements();
        p.sizeB = (int)b.NumElements();
        // stepB = number of consecutive elements that share one bias entry.
        p.stepB = 1;
        for (int i = m_attribs.axis + 1; i < x.dims(); i++)
            p.stepB *= (int)x.dim_size(i);

        Tensor* y = NULL; // x.shape
        OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y));
        p.y = y->flat<T>().data();

        // One block covers loopX * blockSize elements; round the grid size up.
        p.loopX = 4;
        int blockSize = 4 * 32;
        int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
        void* args[] = {&p};
        OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)FusedBiasActKernel<T>, gridSize, blockSize, args, 0, stream));
    }
};
// Op signature: three inputs of type T (x, b, ref), one output y of type T.
REGISTER_OP("FusedBiasAct")
    .Input      ("x: T")
    .Input      ("b: T")
    .Input      ("ref: T")
    .Output     ("y: T")
    .Attr       ("T: {float, half}")
    .Attr       ("grad: int = 0")
    .Attr       ("axis: int = 1")
    .Attr       ("act: int = 0")
    .Attr       ("alpha: float = 0.0")
    .Attr       ("gain: float = 1.0");

// One GPU kernel per element type allowed by the "T" attribute above.
REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<float>("T"), FusedBiasActOp<float>);
REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), FusedBiasActOp<Eigen::half>);
//------------------------------------------------------------------------
================================================
FILE: stylegan_human/dnnlib/tflib/ops/fused_bias_act.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html
"""Custom TensorFlow ops for efficient bias and activation."""
import os
import numpy as np
import tensorflow as tf
from .. import custom_ops
from ...util import EasyDict
def _get_plugin():
    """Return `custom_ops.get_plugin()` for the `.cu` file sharing this module's base name."""
    module_stem, _extension = os.path.splitext(__file__)
    return custom_ops.get_plugin(module_stem + '.cu')
#----------------------------------------------------------------------------
# Table of supported activation functions, keyed by the name passed as `act`.
# Fields of each entry, as used by the implementations in this module:
#   func          - Reference implementation; called as func(x, alpha=alpha).
#   def_alpha     - Value substituted when the caller passes alpha=None.
#   def_gain      - Value substituted when the caller passes gain=None.
#   cuda_idx      - Passed to the CUDA op as its `act` attribute; the CUDA path
#                   falls back to the reference path when this is None.
#   ref           - 'x' or 'y': which tensor the gradient kernels receive as `ref`.
#   zero_2nd_grad - True selects the gradient wrapper that omits the second-order
#                   gradient with respect to x.
activation_funcs = {
    'linear': EasyDict(func=lambda x, **_: x, def_alpha=None, def_gain=1.0, cuda_idx=1, ref='y', zero_2nd_grad=True),
    'relu': EasyDict(func=lambda x, **_: tf.nn.relu(x), def_alpha=None, def_gain=np.sqrt(2), cuda_idx=2, ref='y', zero_2nd_grad=True),
    'lrelu': EasyDict(func=lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', zero_2nd_grad=True),
    'tanh': EasyDict(func=lambda x, **_: tf.nn.tanh(x), def_alpha=None, def_gain=1.0, cuda_idx=4, ref='y', zero_2nd_grad=False),
    'sigmoid': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x), def_alpha=None, def_gain=1.0, cuda_idx=5, ref='y', zero_2nd_grad=False),
    'elu': EasyDict(func=lambda x, **_: tf.nn.elu(x), def_alpha=None, def_gain=1.0, cuda_idx=6, ref='y', zero_2nd_grad=False),
    'selu': EasyDict(func=lambda x, **_: tf.nn.selu(x), def_alpha=None, def_gain=1.0, cuda_idx=7, ref='y', zero_2nd_grad=False),
    'softplus': EasyDict(func=lambda x, **_: tf.nn.softplus(x), def_alpha=None, def_gain=1.0, cuda_idx=8, ref='y', zero_2nd_grad=False),
    'swish': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x) * x, def_alpha=None, def_gain=np.sqrt(2), cuda_idx=9, ref='x', zero_2nd_grad=False),
}
#----------------------------------------------------------------------------
def fused_bias_act(x, b=None, axis=1, act='linear', alpha=None, gain=None, impl='cuda'):
    r"""Fused bias and activation function.

    Computes, in order: add the bias `b` to `x`, apply the activation `act`,
    multiply by `gain`. Every step is optional. The fused op is usually much
    faster than the equivalent chain of standard TensorFlow ops. First and
    second order gradients are supported; third order gradients are not.

    Args:
        x:      Input activation tensor. Any shape is accepted, but when `b` is
                given the rank and the size of dimension `axis` must be known.
        b:      Bias vector, or `None` to disable. Must be a 1D tensor with the
                same dtype as `x`, a known shape, and a length equal to the
                size of `x` along `axis`.
        axis:   Dimension of `x` that the elements of `b` correspond to.
                Ignored when `b` is not specified.
        act:    Name of the activation function, or `"linear"` to disable.
                E.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`.
                See `activation_funcs` for the full list. `None` is not allowed.
        alpha:  Shape parameter of the activation function, or `None` for the default.
        gain:   Scaling factor for the output, or `None` for the default of the
                chosen activation (see `activation_funcs`).
                If unsure, consider specifying `1.0`.
        impl:   Implementation to use: `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the same shape and datatype as `x`.
    """
    implementations = {'ref': _fused_bias_act_ref, 'cuda': _fused_bias_act_cuda}
    selected = implementations[impl]
    return selected(x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain)
#----------------------------------------------------------------------------
def _fused_bias_act_ref(x, b, axis, act, alpha, gain):
    """Reference implementation of `fused_bias_act()` built from standard TensorFlow ops (slow)."""

    # Normalize inputs; a missing bias becomes an empty 1D tensor.
    x = tf.convert_to_tensor(x)
    if b is None:
        b = tf.constant([], dtype=x.dtype)
    else:
        b = tf.convert_to_tensor(b)
    act_spec = activation_funcs[act]
    assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis])
    assert b.shape[0] == 0 or 0 <= axis < x.shape.rank

    # Fall back to the per-activation defaults.
    alpha = act_spec.def_alpha if alpha is None else alpha
    gain = act_spec.def_gain if gain is None else gain

    # Bias: reshape so that it broadcasts along `axis` only.
    if b.shape[0] != 0:
        bias_shape = [-1 if dim == axis else 1 for dim in range(x.shape.rank)]
        x = x + tf.reshape(b, bias_shape)

    # Activation.
    x = act_spec.func(x, alpha=alpha)

    # Gain.
    if gain != 1:
        x = x * gain
    return x
#----------------------------------------------------------------------------
def _fused_bias_act_cuda(x, b, axis, act, alpha, gain):
    """Fast CUDA implementation of `fused_bias_act()` using custom ops.

    Takes the same arguments as `_fused_bias_act_ref()`. Returns `x` itself
    (as a tensor) when there is nothing to compute, delegates to the reference
    implementation when the activation has no CUDA index, and otherwise wraps
    the custom op in `tf.custom_gradient` functions that provide first and
    second order gradients.
    """

    # Validate arguments.
    x = tf.convert_to_tensor(x)
    empty_tensor = tf.constant([], dtype=x.dtype)
    # Record this before `b` is replaced by a tensor: afterwards `b is None`
    # can never be true, which previously made the no-op shortcut unreachable.
    has_bias = b is not None
    b = tf.convert_to_tensor(b) if has_bias else empty_tensor
    act_spec = activation_funcs[act]
    assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis])
    assert b.shape[0] == 0 or 0 <= axis < x.shape.rank
    if alpha is None:
        alpha = act_spec.def_alpha
    if gain is None:
        gain = act_spec.def_gain

    # Special cases.
    if act == 'linear' and not has_bias and gain == 1.0:
        # Identity: no bias, no activation, no scaling. Skips loading the plugin.
        return x
    if act_spec.cuda_idx is None:
        return _fused_bias_act_ref(x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain)

    # CUDA kernel.
    cuda_kernel = _get_plugin().fused_bias_act
    cuda_kwargs = dict(axis=axis, act=act_spec.cuda_idx, alpha=alpha, gain=gain)

    # Forward pass: y = func(x, b).
    def func_y(x, b):
        y = cuda_kernel(x=x, b=b, ref=empty_tensor, grad=0, **cuda_kwargs)
        y.set_shape(x.shape)
        return y

    # Backward pass: dx, db = grad(dy, x, y)
    def grad_dx(dy, x, y):
        # The activation spec decides whether the kernel differentiates
        # with respect to the input or the output of the forward pass.
        ref = {'x': x, 'y': y}[act_spec.ref]
        dx = cuda_kernel(x=dy, b=empty_tensor, ref=ref, grad=1, **cuda_kwargs)
        dx.set_shape(x.shape)
        return dx

    def grad_db(dx):
        if b.shape[0] == 0:
            return empty_tensor
        # Sum dx over every dimension except `axis`.
        db = dx
        if axis < x.shape.rank - 1:
            db = tf.reduce_sum(db, list(range(axis + 1, x.shape.rank)))
        if axis > 0:
            db = tf.reduce_sum(db, list(range(axis)))
        db.set_shape(b.shape)
        return db

    # Second order gradients: d_dy, d_x = grad2(d_dx, d_db, x, y)
    def grad2_d_dy(d_dx, d_db, x, y):
        ref = {'x': x, 'y': y}[act_spec.ref]
        d_dy = cuda_kernel(x=d_dx, b=d_db, ref=ref, grad=1, **cuda_kwargs)
        d_dy.set_shape(x.shape)
        return d_dy

    def grad2_d_x(d_dx, d_db, x, y):
        ref = {'x': x, 'y': y}[act_spec.ref]
        d_x = cuda_kernel(x=d_dx, b=d_db, ref=ref, grad=2, **cuda_kwargs)
        d_x.set_shape(x.shape)
        return d_x

    # Fast version for piecewise-linear activation funcs.
    @tf.custom_gradient
    def func_zero_2nd_grad(x, b):
        y = func_y(x, b)

        @tf.custom_gradient
        def grad(dy):
            dx = grad_dx(dy, x, y)
            db = grad_db(dx)

            def grad2(d_dx, d_db):
                d_dy = grad2_d_dy(d_dx, d_db, x, y)
                return d_dy

            return (dx, db), grad2

        return y, grad

    # Slow version for general activation funcs.
    @tf.custom_gradient
    def func_nonzero_2nd_grad(x, b):
        y = func_y(x, b)

        def grad_wrap(dy):
            # `x` is an explicit input here so that grad2 can return a gradient for it.
            @tf.custom_gradient
            def grad_impl(dy, x):
                dx = grad_dx(dy, x, y)
                db = grad_db(dx)

                def grad2(d_dx, d_db):
                    d_dy = grad2_d_dy(d_dx, d_db, x, y)
                    d_x = grad2_d_x(d_dx, d_db, x, y)
                    return d_dy, d_x

                return (dx, db), grad2

            return grad_impl(dy, x)

        return y, grad_wrap

    # Which version to use?
    if act_spec.zero_2nd_grad:
        return func_zero_2nd_grad(x, b)
    return func_nonzero_2nd_grad(x, b)
#----------------------------------------------------------------------------
================================================
FILE: stylegan_human/dnnlib/tflib/ops/upfirdn_2d.cu
================================================
// Copyright (c) SenseTime Research. All rights reserved.
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#define EIGEN_USE_GPU
#define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include <stdio.h>

using namespace tensorflow;
using namespace tensorflow::shape_inference;

//------------------------------------------------------------------------
// Helpers.

// Evaluates CUDA_CALL once and, unless it returned cudaSuccess, fails the op
// through OP_REQUIRES with an Internal error carrying the CUDA error name.
#define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)
// Integer division that steps the truncated quotient down by one whenever
// quotient * b overshoots a (e.g. a negative, non-divisible a with positive b).
static __host__ __device__ __forceinline__ int floorDiv(int a, int b)
{
    const int quotient = a / b;
    return (quotient * b > a) ? quotient - 1 : quotient;
}
//------------------------------------------------------------------------
// CUDA kernel params.
// Parameter block passed by value to the UpFirDn2D kernels.
// T is the element type of the input, filter and output buffers.
template <class T>
struct UpFirDn2DKernelParams
{
    const T*    x;          // [majorDim, inH, inW, minorDim]
    const T*    k;          // [kernelH, kernelW]
    T*          y;          // [majorDim, outH, outW, minorDim]

    // Upsampling / downsampling factors and padding, per axis.
    int         upx;
    int         upy;
    int         downx;
    int         downy;
    int         padx0;
    int         padx1;
    int         pady0;
    int         pady1;

    // Tensor extents.
    int         majorDim;
    int         inH;
    int         inW;
    int         minorDim;
    int         kernelH;
    int         kernelW;
    int         outH;
    int         outW;

    // Per-thread / per-block loop counts over majorDim and output X.
    int         loopMajor;
    int         loopX;
};
//------------------------------------------------------------------------
// General CUDA implementation for large filter kernels.
// Generic kernel: every thread computes output pixels directly from the input
// and filter buffers, for any filter size and any up/down factors.
// Thread mapping: x indexes (outY, minorIdx), y indexes outX, z indexes majorDim.
template <class T>
static __global__ void UpFirDn2DKernel_large(const UpFirDn2DKernelParams<T> p)
{
    // Calculate thread index.
    int minorIdx = blockIdx.x * blockDim.x + threadIdx.x;
    int outY = minorIdx / p.minorDim;
    minorIdx -= outY * p.minorDim;
    int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
    int majorIdxBase = blockIdx.z * p.loopMajor;
    if (outXBase >= p.outW || outY >= p.outH || majorIdxBase >= p.majorDim)
        return;

    // Setup Y receptive field.
    int midY = outY * p.downy + p.upy - 1 - p.pady0;
    int inY = min(max(floorDiv(midY, p.upy), 0), p.inH);
    int h = min(max(floorDiv(midY + p.kernelH, p.upy), 0), p.inH) - inY;
    int kernelY = midY + p.kernelH - (inY + 1) * p.upy;

    // Loop over majorDim and outX.
    for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor && majorIdx < p.majorDim; loopMajor++, majorIdx++)
    for (int loopX = 0, outX = outXBase; loopX < p.loopX && outX < p.outW; loopX++, outX += blockDim.y)
    {
        // Setup X receptive field.
        int midX = outX * p.downx + p.upx - 1 - p.padx0;
        int inX = min(max(floorDiv(midX, p.upx), 0), p.inW);
        int w = min(max(floorDiv(midX + p.kernelW, p.upx), 0), p.inW) - inX;
        int kernelX = midX + p.kernelW - (inX + 1) * p.upx;

        // Initialize pointers.
        // The input pointer walks forward while the filter pointer walks
        // backward (negative strides kpx / kpy).
        const T* xp = &p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
        const T* kp = &p.k[kernelY * p.kernelW + kernelX];
        int xpx = p.minorDim;
        int kpx = -p.upx;
        int xpy = p.inW * p.minorDim;
        int kpy = -p.upy * p.kernelW;

        // Inner loop.
        float v = 0.0f;
        for (int y = 0; y < h; y++)
        {
            for (int x = 0; x < w; x++)
            {
                v += (float)(*xp) * (float)(*kp);
                xp += xpx;
                kp += kpx;
            }
            // Rewind the row and advance to the next one.
            xp += xpy - w * xpx;
            kp += kpy - w * kpx;
        }

        // Store result.
        p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
    }
}
//------------------------------------------------------------------------
// Specialized CUDA implementation for small filter kernels.
// Tiled kernel for small filters. The up/down factors, the (maximum) filter
// size and the output tile size are compile-time constants, which lets the
// flipped filter and one input tile be staged in __shared__ arrays and the
// inner loops be unrolled. A runtime filter smaller than kernelW x kernelH is
// zero-padded when it is loaded.
template <class T, int upx, int upy, int downx, int downy, int kernelW, int kernelH, int tileOutW, int tileOutH>
static __global__ void UpFirDn2DKernel_small(const UpFirDn2DKernelParams<T> p)
{
    //assert(kernelW % upx == 0);
    //assert(kernelH % upy == 0);
    const int tileInW = ((tileOutW - 1) * downx + kernelW - 1) / upx + 1;
    const int tileInH = ((tileOutH - 1) * downy + kernelH - 1) / upy + 1;
    __shared__ volatile float sk[kernelH][kernelW];
    __shared__ volatile float sx[tileInH][tileInW];

    // Calculate tile index.
    int minorIdx = blockIdx.x;
    int tileOutY = minorIdx / p.minorDim;
    minorIdx -= tileOutY * p.minorDim;
    tileOutY *= tileOutH;
    int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
    int majorIdxBase = blockIdx.z * p.loopMajor;
    if (tileOutXBase >= p.outW | tileOutY >= p.outH | majorIdxBase >= p.majorDim)
        return;

    // Load filter kernel (flipped).
    for (int tapIdx = threadIdx.x; tapIdx < kernelH * kernelW; tapIdx += blockDim.x)
    {
        int ky = tapIdx / kernelW;
        int kx = tapIdx - ky * kernelW;
        float v = 0.0f;
        if (kx < p.kernelW & ky < p.kernelH)
            v = (float)p.k[(p.kernelH - 1 - ky) * p.kernelW + (p.kernelW - 1 - kx)];
        sk[ky][kx] = v;
    }

    // Loop over majorDim and outX.
    for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor & majorIdx < p.majorDim; loopMajor++, majorIdx++)
    for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outW; loopX++, tileOutX += tileOutW)
    {
        // Load input pixels.
        int tileMidX = tileOutX * downx + upx - 1 - p.padx0;
        int tileMidY = tileOutY * downy + upy - 1 - p.pady0;
        int tileInX = floorDiv(tileMidX, upx);
        int tileInY = floorDiv(tileMidY, upy);

        // Barrier before sx is overwritten for this tile.
        __syncthreads();
        for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW; inIdx += blockDim.x)
        {
            int relInY = inIdx / tileInW;
            int relInX = inIdx - relInY * tileInW;
            int inX = relInX + tileInX;
            int inY = relInY + tileInY;
            // Positions outside the input read as zero.
            float v = 0.0f;
            if (inX >= 0 & inY >= 0 & inX < p.inW & inY < p.inH)
                v = (float)p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
            sx[relInY][relInX] = v;
        }

        // Loop over output pixels.
        __syncthreads();
        for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW; outIdx += blockDim.x)
        {
            int relOutY = outIdx / tileOutW;
            int relOutX = outIdx - relOutY * tileOutW;
            int outX = relOutX + tileOutX;
            int outY = relOutY + tileOutY;

            // Setup receptive field.
            int midX = tileMidX + relOutX * downx;
            int midY = tileMidY + relOutY * downy;
            int inX = floorDiv(midX, upx);
            int inY = floorDiv(midY, upy);
            int relInX = inX - tileInX;
            int relInY = inY - tileInY;
            int kernelX = (inX + 1) * upx - midX - 1; // flipped
            int kernelY = (inY + 1) * upy - midY - 1; // flipped

            // Inner loop.
            float v = 0.0f;
            #pragma unroll
            for (int y = 0; y < kernelH / upy; y++)
                #pragma unroll
                for (int x = 0; x < kernelW / upx; x++)
                    v += sx[relInY + y][relInX + x] * sk[kernelY + y * upy][kernelX + x * upx];

            // Store result.
            if (outX < p.outW & outY < p.outH)
                p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
        }
    }
}
//------------------------------------------------------------------------
// TensorFlow op.
template
struct UpFirDn2DOp : public OpKernel
{
UpFirDn2DKernelParams m_attribs;
UpFirDn2DOp(OpKernelConstruction* ctx) : OpKernel(ctx)
{
memset(&m_attribs, 0, sizeof(m_attribs));
OP_REQUIRES_OK(ctx, ctx->GetAttr("upx", &m_attribs.upx));
OP_REQUIRES_OK(ctx, ctx->GetAttr("upy", &m_attribs.upy));
OP_REQUIRES_OK(ctx, ctx->GetAttr("downx", &m_attribs.downx));
OP_REQUIRES_OK(ctx, ctx->GetAttr("downy", &m_attribs.downy));
OP_REQUIRES_OK(ctx, ctx->GetAttr("padx0", &m_attribs.padx0));
OP_REQUIRES_OK(ctx, ctx->GetAttr("padx1", &m_attribs.padx1));
OP_REQUIRES_OK(ctx, ctx->GetAttr("pady0", &m_attribs.pady0));
OP_REQUIRES_OK(ctx, ctx->GetAttr("pady1", &m_attribs.pady1));
OP_REQUIRES(ctx, m_attribs.upx >= 1 && m_attribs.upy >= 1, errors::InvalidArgument("upx and upy must be at least 1x1"));
OP_REQUIRES(ctx, m_attribs.downx >= 1 && m_attribs.downy >= 1, errors::InvalidArgument("downx and downy must be at least 1x1"));
}
// Executes the op for one invocation: validates the two inputs, derives the
// output shape, allocates the output tensor, selects a CUDA kernel together
// with its launch geometry, and enqueues the launch on the device stream.
void Compute(OpKernelContext* ctx)
{
    // Start from the attribute values stored in m_attribs; the per-call
    // fields (pointers, sizes, loop counts) are filled in below.
    UpFirDn2DKernelParams p = m_attribs;
    cudaStream_t stream = ctx->eigen_device().stream();

    // Fetch inputs and record their raw data pointers.
    // NOTE(review): the pointers are taken before the rank checks below run.
    const Tensor& x = ctx->input(0); // [majorDim, inH, inW, minorDim]
    const Tensor& k = ctx->input(1); // [kernelH, kernelW]
    p.x = x.flat().data();
    p.k = k.flat().data();

    // Validate ranks, and make sure element counts fit in the int fields of p.
    OP_REQUIRES(ctx, x.dims() == 4, errors::InvalidArgument("input must have rank 4"));
    OP_REQUIRES(ctx, k.dims() == 2, errors::InvalidArgument("kernel must have rank 2"));
    OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("input too large"));
    OP_REQUIRES(ctx, k.NumElements() <= kint32max, errors::InvalidArgument("kernel too large"));

    // Copy the input and filter dimensions into the kernel params.
    p.majorDim = (int)x.dim_size(0);
    p.inH = (int)x.dim_size(1);
    p.inW = (int)x.dim_size(2);
    p.minorDim = (int)x.dim_size(3);
    p.kernelH = (int)k.dim_size(0);
    p.kernelW = (int)k.dim_size(1);
    OP_REQUIRES(ctx, p.kernelW >= 1 && p.kernelH >= 1, errors::InvalidArgument("kernel must be at least 1x1"));

    // Output size: upsampled + padded extent, minus the filter footprint,
    // divided by the downsampling factor (adding `down` before the integer
    // division accounts for the "+ 1" of a valid convolution).
    p.outW = (p.inW * p.upx + p.padx0 + p.padx1 - p.kernelW + p.downx) / p.downx;
    p.outH = (p.inH * p.upy + p.pady0 + p.pady1 - p.kernelH + p.downy) / p.downy;
    OP_REQUIRES(ctx, p.outW >= 1 && p.outH >= 1, errors::InvalidArgument("output must be at least 1x1"));

    // Allocate the output tensor with the derived shape.
    // NOTE(review): the "output too large" check runs after allocation.
    Tensor* y = NULL; // [majorDim, outH, outW, minorDim]
    TensorShape ys;
    ys.AddDim(p.majorDim);
    ys.AddDim(p.outH);
    ys.AddDim(p.outW);
    ys.AddDim(p.minorDim);
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, ys, &y));
    p.y = y->flat().data();
    OP_REQUIRES(ctx, y->NumElements() <= kint32max, errors::InvalidArgument("output too large"));

    // Choose CUDA kernel to use.
    // The "large" kernel is the default. Each `if` below that matches overrides
    // the previous choice, so the last matching line wins; tileOutW/tileOutH
    // stay at -1 unless one of the "small" cases matched.
    // NOTE(review): as visible here every line assigns the same symbol
    // UpFirDn2DKernel_small; presumably the lines differ by template
    // arguments that are not visible in this view -- confirm against the
    // kernel declaration.
    void* cudaKernel = (void*)UpFirDn2DKernel_large;
    int tileOutW = -1;
    int tileOutH = -1;
    if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 7 && p.kernelH <= 7) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; }
    if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; }
    if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 5 && p.kernelH <= 5) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; }
    if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; }
    if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 3 && p.kernelH <= 3) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; }
    if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8 && p.kernelH <= 8) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; }
    if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; }
    if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; }
    if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 2 && p.kernelH <= 2) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; }
    if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 8 && p.kernelH <= 8) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 8; }
    if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 8; }
    if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 8; }
    if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 2 && p.kernelH <= 2) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 8; }

    // Choose launch params.
    // In both branches loopMajor = ceil(majorDim / 16384), which keeps the
    // z grid dimension computed from it at or below 16384.
    dim3 blockSize;
    dim3 gridSize;
    if (tileOutW > 0 && tileOutH > 0) // small
    {
        p.loopMajor = (p.majorDim - 1) / 16384 + 1;
        p.loopX = 1;
        blockSize = dim3(32 * 8, 1, 1);
        // grid.x: one entry per (row tile, minor index); grid.y: column tiles; grid.z: major chunks.
        gridSize = dim3(((p.outH - 1) / tileOutH + 1) * p.minorDim, (p.outW - 1) / (p.loopX * tileOutW) + 1, (p.majorDim - 1) / p.loopMajor + 1);
    }
    else // large
    {
        p.loopMajor = (p.majorDim - 1) / 16384 + 1;
        p.loopX = 4;
        blockSize = dim3(4, 32, 1);
        // grid.x covers outH * minorDim in steps of blockSize.x; grid.y covers outW in steps of loopX * blockSize.y.
        gridSize = dim3((p.outH * p.minorDim - 1) / blockSize.x + 1, (p.outW - 1) / (p.loopX * blockSize.y) + 1, (p.majorDim - 1) / p.loopMajor + 1);
    }

    // Launch CUDA kernel.
    // The params struct is passed as the single kernel argument; no dynamic
    // shared memory is requested (0) and the launch goes to `stream`.
    void* args[] = {&p};
    OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(cudaKernel, gridSize, blockSize, args, 0, stream));
}
};
// Op interface for "UpFirDn2D": inputs x and k and output y all share the
// type T, which may be float or half. The integer attributes default to
// up/down factors of 1 and paddings of 0.
REGISTER_OP("UpFirDn2D")
    .Input ("x: T")
    .Input ("k: T")
    .Output ("y: T")
    .Attr ("T: {float, half}")
    .Attr ("upx: int = 1")
    .Attr ("upy: int = 1")
    .Attr ("downx: int = 1")
    .Attr ("downy: int = 1")
    .Attr ("padx0: int = 0")
    .Attr ("padx1: int = 0")
    .Attr ("pady0: int = 0")
    .Attr ("pady1: int = 0");

// GPU kernel registrations for the op.
// NOTE(review): the two lines read identically in this view; presumably they
// register one kernel per allowed T (float and half) via template arguments
// that are not visible here -- confirm against the original file.
REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint("T"), UpFirDn2DOp);
REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint("T"), UpFirDn2DOp);
//------------------------------------------------------------------------
================================================
FILE: stylegan_human/dnnlib/tflib/ops/upfirdn_2d.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html
"""Custom TensorFlow ops for efficient resampling of 2D images."""
import os
import numpy as np
import tensorflow as tf
from .. import custom_ops
def _get_plugin():
    """Return the plugin that `custom_ops.get_plugin()` yields for this module's `.cu` file.

    The CUDA source path is this module's own path with its extension
    replaced by `.cu`.
    """
    module_stem, _extension = os.path.splitext(__file__)
    cuda_source = module_stem + '.cu'
    return custom_ops.get_plugin(cuda_source)
#----------------------------------------------------------------------------
def upfirdn_2d(x, k, upx=1, upy=1, downx=1, downy=1, padx0=0, padx1=0, pady0=0, pady1=0, impl='cuda'):
    r"""Upsample, pad, FIR filter, and downsample a batch of 2D images.

    Takes a batch of 2D images of the shape `[majorDim, inH, inW, minorDim]`
    and dispatches to the implementation selected by `impl`. Following the
    reference implementation, each image (batched across `majorDim` and
    `minorDim`) goes through these steps:

    1. Upsample by inserting `upx - 1` / `upy - 1` zeros after each pixel.

    2. Pad with zeros by `padx0`, `padx1`, `pady0`, `pady1` pixels, measured
       on the upsampled image. A negative value crops instead of padding.

    3. Convolve with the 2D FIR filter `k`, keeping only the output pixels
       whose footprint lies entirely inside the padded image.

    4. Downsample by keeping every `downx`-th / `downy`-th pixel.

    Args:
        x:      Input tensor of the shape `[majorDim, inH, inW, minorDim]`.
        k:      2D FIR filter of the shape `[firH, firW]`.
        upx:    Integer upsampling factor along the X-axis (default: 1).
        upy:    Integer upsampling factor along the Y-axis (default: 1).
        downx:  Integer downsampling factor along the X-axis (default: 1).
        downy:  Integer downsampling factor along the Y-axis (default: 1).
        padx0:  Number of pixels to pad on the left side (default: 0).
        padx1:  Number of pixels to pad on the right side (default: 0).
        pady0:  Number of pixels to pad on the top side (default: 0).
        pady1:  Number of pixels to pad on the bottom side (default: 0).
        impl:   Implementation to use: `"ref"` or `"cuda"` (default).
                Any other value raises `KeyError`.

    Returns:
        Tensor of the shape `[majorDim, outH, outW, minorDim]`.
    """
    backends = dict(ref=_upfirdn_2d_ref, cuda=_upfirdn_2d_cuda)
    backend = backends[impl]
    return backend(
        x=x, k=k,
        upx=upx, upy=upy,
        downx=downx, downy=downy,
        padx0=padx0, padx1=padx1,
        pady0=pady0, pady1=pady1,
    )
#----------------------------------------------------------------------------
def _upfirdn_2d_ref(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
    """Reference version of `upfirdn_2d()` built only from standard TensorFlow ops.

    Requires the input height and width to be statically known; the minor
    (last) dimension may be dynamic.
    """
    x = tf.convert_to_tensor(x)
    k = np.asarray(k, dtype=np.float32)
    assert x.shape.rank == 4
    in_h = x.shape[1].value
    in_w = x.shape[2].value
    minor = _shape(x, 3)
    fir_h, fir_w = k.shape
    assert in_w >= 1 and in_h >= 1
    assert fir_w >= 1 and fir_h >= 1
    assert all(isinstance(value, int) for value in (upx, upy))
    assert all(isinstance(value, int) for value in (downx, downy))
    assert all(isinstance(value, int) for value in (padx0, padx1))
    assert all(isinstance(value, int) for value in (pady0, pady1))

    up_h = in_h * upy
    up_w = in_w * upx
    padded_h = up_h + pady0 + pady1
    padded_w = up_w + padx0 + padx1

    # Zero-insertion upsampling: give every pixel its own unit axis, pad that
    # axis with (factor - 1) trailing zeros, then fold it back in.
    x = tf.reshape(x, [-1, in_h, 1, in_w, 1, minor])
    zero_stuffing = [[0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1], [0, 0]]
    x = tf.pad(x, zero_stuffing)
    x = tf.reshape(x, [-1, up_h, up_w, minor])

    # Positive padding is added with tf.pad; negative padding is realized as
    # a crop by slicing afterwards.
    grow = [[0, 0], [max(pady0, 0), max(pady1, 0)], [max(padx0, 0), max(padx1, 0)], [0, 0]]
    x = tf.pad(x, grow)
    crop_top, crop_bottom = max(-pady0, 0), max(-pady1, 0)
    crop_left, crop_right = max(-padx0, 0), max(-padx1, 0)
    x = x[:, crop_top : x.shape[1].value - crop_bottom, crop_left : x.shape[2].value - crop_right, :]

    # Filter each channel independently: move channels next to the batch axis,
    # fold them into it, and run a single-channel VALID convolution with the
    # filter flipped along both axes.
    x = tf.transpose(x, [0, 3, 1, 2])
    x = tf.reshape(x, [-1, 1, padded_h, padded_w])
    flipped = k[::-1, ::-1, np.newaxis, np.newaxis]
    w = tf.constant(flipped, dtype=x.dtype)
    x = tf.nn.conv2d(x, w, strides=[1,1,1,1], padding='VALID', data_format='NCHW')
    x = tf.reshape(x, [-1, minor, padded_h - fir_h + 1, padded_w - fir_w + 1])
    x = tf.transpose(x, [0, 2, 3, 1])

    # Downsample by strided slicing.
    return x[:, ::downy, ::downx, :]
#----------------------------------------------------------------------------
def _upfirdn_2d_cuda(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
    """Version of `upfirdn_2d()` that runs through the custom-op plugin.

    The input shape is read statically via `x.shape.as_list()`. The gradient
    is expressed as another call to the same op with the up/down factors
    swapped, the filter flipped, and the paddings recomputed; the gradient of
    that gradient is `func` itself.
    """
    x = tf.convert_to_tensor(x)
    k = np.asarray(k, dtype=np.float32)
    majorDim, inH, inW, minorDim = x.shape.as_list()
    kernelH, kernelW = k.shape
    assert inW >= 1 and inH >= 1
    assert kernelW >= 1 and kernelH >= 1
    assert all(isinstance(value, int) for value in (upx, upy))
    assert all(isinstance(value, int) for value in (downx, downy))
    assert all(isinstance(value, int) for value in (padx0, padx1))
    assert all(isinstance(value, int) for value in (pady0, pady1))

    # Static output size of the forward pass.
    outW = (inW * upx + padx0 + padx1 - kernelW) // downx + 1
    outH = (inH * upy + pady0 + pady1 - kernelH) // downy + 1
    assert outW >= 1 and outH >= 1

    # Filter constants: as given for the forward op, flipped for the gradient op.
    forward_filter = tf.constant(k, dtype=x.dtype)
    backward_filter = tf.constant(k[::-1, ::-1], dtype=x.dtype)

    # Paddings passed to the gradient op.
    grad_padx0 = kernelW - padx0 - 1
    grad_pady0 = kernelH - pady0 - 1
    grad_padx1 = inW * upx - outW * downx + padx0 - upx + 1
    grad_pady1 = inH * upy - outH * downy + pady0 - upy + 1

    @tf.custom_gradient
    def func(x):
        y = _get_plugin().up_fir_dn2d(
            x=x, k=forward_filter,
            upx=upx, upy=upy, downx=downx, downy=downy,
            padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1)
        y.set_shape([majorDim, outH, outW, minorDim])

        @tf.custom_gradient
        def grad(dy):
            dx = _get_plugin().up_fir_dn2d(
                x=dy, k=backward_filter,
                upx=downx, upy=downy, downx=upx, downy=upy,
                padx0=grad_padx0, padx1=grad_padx1, pady0=grad_pady0, pady1=grad_pady1)
            dx.set_shape([majorDim, inH, inW, minorDim])
            return dx, func

        return y, grad

    return func(x)
#----------------------------------------------------------------------------
def filter_2d(x, k, gain=1, data_format='NCHW', impl='cuda'):
    r"""Apply a FIR filter to a batch of 2D images without resampling.

    The filter is normalized to unit sum and then multiplied by `gain`. The
    total padding equals the filter size minus one, so the spatial size of
    the output matches the input.

    Args:
        x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
        gain:         Scaling factor for signal magnitude (default: 1.0).
        data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
        impl:         Implementation to use: `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the same shape as `x`.
    """
    fir = _setup_kernel(k) * gain
    total_pad = fir.shape[0] - 1
    pad_before = (total_pad + 1) // 2
    pad_after = total_pad // 2
    return _simple_upfirdn_2d(x, fir, pad0=pad_before, pad1=pad_after, data_format=data_format, impl=impl)
#----------------------------------------------------------------------------
def upsample_2d(x, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
    r"""Upsample a batch of 2D images by an integer factor using a FIR filter.

    The filter is normalized to unit sum and then multiplied by
    `gain * factor**2`, which compensates for the zeros inserted during
    upsampling. Padding is chosen from the filter size and `factor`.

    Args:
        x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
                      Defaults to `[1] * factor`.
        factor:       Integer upsampling factor (default: 2).
        gain:         Scaling factor for signal magnitude (default: 1.0).
        data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
        impl:         Implementation to use: `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the shape `[N, C, H * factor, W * factor]` or
        `[N, H * factor, W * factor, C]`.
    """
    assert isinstance(factor, int) and factor >= 1
    taps = [1] * factor if k is None else k
    fir = _setup_kernel(taps) * (gain * (factor ** 2))
    excess = fir.shape[0] - factor
    pad_before = (excess + 1) // 2 + factor - 1
    pad_after = excess // 2
    return _simple_upfirdn_2d(x, fir, up=factor, pad0=pad_before, pad1=pad_after, data_format=data_format, impl=impl)
#----------------------------------------------------------------------------
def downsample_2d(x, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
    r"""Downsample a batch of 2D images by an integer factor using a FIR filter.

    The filter is normalized to unit sum and then multiplied by `gain`.
    Padding is chosen from the filter size and `factor`.

    Args:
        x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
                      Defaults to `[1] * factor`.
        factor:       Integer downsampling factor (default: 2).
        gain:         Scaling factor for signal magnitude (default: 1.0).
        data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
        impl:         Implementation to use: `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the shape `[N, C, H // factor, W // factor]` or
        `[N, H // factor, W // factor, C]`.
    """
    assert isinstance(factor, int) and factor >= 1
    taps = [1] * factor if k is None else k
    fir = _setup_kernel(taps) * gain
    excess = fir.shape[0] - factor
    pad_before = (excess + 1) // 2
    pad_after = excess // 2
    return _simple_upfirdn_2d(x, fir, down=factor, pad0=pad_before, pad1=pad_after, data_format=data_format, impl=impl)
#----------------------------------------------------------------------------
def upsample_conv_2d(x, w, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
    r"""Upsampling and convolution in one step: a strided transposed convolution, then a FIR filter.

    The weights are flipped and regrouped, applied with
    `tf.nn.conv2d_transpose()` at stride `factor`, and the result is passed
    through the FIR filter with padding that accounts for both the filter
    size and the convolution size.

    Args:
        x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        w:            Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`.
                      Must be square (`filterH == filterW`). The number of groups is
                      the channel count of `x` divided by `inChannels`.
        k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
                      Defaults to `[1] * factor`.
        factor:       Integer upsampling factor (default: 2).
        gain:         Scaling factor for signal magnitude (default: 1.0).
        data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`); any value other
                      than `'NCHW'` takes the channels-last path here.
        impl:         Implementation to use: `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the shape `[N, C, H * factor, W * factor]` or
        `[N, H * factor, W * factor, C]`.
    """
    assert isinstance(factor, int) and factor >= 1

    # Inspect the weights. Height and width must be static; channel counts may be dynamic.
    w = tf.convert_to_tensor(w)
    assert w.shape.rank == 4
    convH = w.shape[0].value
    convW = w.shape[1].value
    inC = _shape(w, 2)
    outC = _shape(w, 3)
    assert convW == convH

    # Build the FIR filter, scaled to compensate for zero insertion.
    taps = [1] * factor if k is None else k
    fir = _setup_kernel(taps) * (gain * (factor ** 2))
    excess = (fir.shape[0] - factor) - (convW - 1)

    # Layout-dependent stride, transposed-convolution output shape, and group count.
    if data_format == 'NCHW':
        stride = [1, 1, factor, factor]
        batch = _shape(x, 0)
        out_h = (_shape(x, 2) - 1) * factor + convH
        out_w = (_shape(x, 3) - 1) * factor + convW
        output_shape = [batch, outC, out_h, out_w]
        num_groups = _shape(x, 1) // inC
    else:
        stride = [1, factor, factor, 1]
        batch = _shape(x, 0)
        out_h = (_shape(x, 1) - 1) * factor + convH
        out_w = (_shape(x, 2) - 1) * factor + convW
        output_shape = [batch, out_h, out_w, outC]
        num_groups = _shape(x, 3) // inC

    # Flip the weights spatially and swap the in/out channel roles per group,
    # as needed by conv2d_transpose.
    w = tf.reshape(w, [convH, convW, inC, num_groups, -1])
    flipped = w[::-1, ::-1]
    w = tf.transpose(flipped, [0, 1, 4, 3, 2])
    w = tf.reshape(w, [convH, convW, -1, num_groups * inC])

    # Transposed convolution followed by the FIR filter.
    x = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=stride, padding='VALID', data_format=data_format)
    pad_before = (excess + 1) // 2 + factor - 1
    pad_after = excess // 2 + 1
    return _simple_upfirdn_2d(x, fir, pad0=pad_before, pad1=pad_after, data_format=data_format, impl=impl)
#----------------------------------------------------------------------------
def conv_downsample_2d(x, w, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
    r"""Convolution and downsampling in one step: a FIR filter, then a strided convolution.

    The input is FIR-filtered with padding that accounts for both the filter
    size and the convolution size, and then convolved with `w` at stride
    `factor` using VALID padding.

    Args:
        x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        w:            Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`.
                      Must be square (`filterH == filterW`).
        k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
                      Defaults to `[1] * factor`.
        factor:       Integer downsampling factor (default: 2).
        gain:         Scaling factor for signal magnitude (default: 1.0).
        data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
        impl:         Implementation to use: `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the shape `[N, C, H // factor, W // factor]` or
        `[N, H // factor, W // factor, C]`.
    """
    assert isinstance(factor, int) and factor >= 1
    w = tf.convert_to_tensor(w)
    convH, convW, _inC, _outC = w.shape.as_list()
    assert convW == convH

    taps = [1] * factor if k is None else k
    fir = _setup_kernel(taps) * gain
    excess = (fir.shape[0] - factor) + (convW - 1)

    # Stride applies to the two spatial axes, whose position depends on the layout.
    strides = [1, 1, factor, factor] if data_format == 'NCHW' else [1, factor, factor, 1]

    pad_before = (excess + 1) // 2
    pad_after = excess // 2
    x = _simple_upfirdn_2d(x, fir, pad0=pad_before, pad1=pad_after, data_format=data_format, impl=impl)
    return tf.nn.conv2d(x, w, strides=strides, padding='VALID', data_format=data_format)
#----------------------------------------------------------------------------
# Internal helper funcs.
def _shape(tf_expr, dim_idx):
    """Return the size of dimension `dim_idx` of `tf_expr`.

    Gives the static size when the rank and that dimension are both known;
    otherwise falls back to indexing the dynamic `tf.shape()` of the expression.
    """
    static_shape = tf_expr.shape
    if static_shape.rank is not None:
        static_dim = static_shape[dim_idx].value
        if static_dim is not None:
            return static_dim
    dynamic_shape = tf.shape(tf_expr)
    return dynamic_shape[dim_idx]
def _setup_kernel(k):
    """Convert a FIR filter specification into a normalized square 2D float32 array.

    Args:
        k: Filter taps, either 1D (treated as separable and expanded with an
           outer product) or already 2D and square.

    Returns:
        A new 2D `np.float32` array whose elements sum to 1. The caller's
        input is never modified.

    Raises:
        AssertionError: If the filter is not 2D (after expansion) or not square.
    """
    k = np.asarray(k, dtype=np.float32)
    if k.ndim == 1:
        k = np.outer(k, k)
    # Validate the shape before doing any arithmetic on the filter.
    assert k.ndim == 2
    assert k.shape[0] == k.shape[1]
    # Divide out of place: np.asarray() hands back the caller's own array when
    # it is already a float32 ndarray, so an in-place `k /= ...` would silently
    # rescale the caller's data.
    return k / np.sum(k)
def _simple_upfirdn_2d(x, k, up=1, down=1, pad0=0, pad1=0, data_format='NCHW', impl='cuda'):
    """Call `upfirdn_2d()` with the same factors and paddings on both axes.

    For `'NCHW'` input, the batch and channel axes are folded into the major
    axis (with a minor axis of size 1) before the call and unfolded again
    afterwards. `'NHWC'` input is passed through unchanged.
    """
    assert data_format in ['NCHW', 'NHWC']
    assert x.shape.rank == 4
    channels_first = (data_format == 'NCHW')

    y = x
    if channels_first:
        y = tf.reshape(y, [-1, _shape(y, 2), _shape(y, 3), 1])
    y = upfirdn_2d(
        y, k,
        upx=up, upy=up,
        downx=down, downy=down,
        padx0=pad0, padx1=pad1,
        pady0=pad0, pady1=pad1,
        impl=impl)
    if channels_first:
        y = tf.reshape(y, [-1, _shape(x, 1), _shape(y, 1), _shape(y, 2)])
    return y
#----------------------------------------------------------------------------
================================================
FILE: stylegan_human/dnnlib/tflib/optimizer.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html
"""Helper wrapper for a Tensorflow optimizer."""
import numpy as np
import tensorflow as tf
from collections import OrderedDict
from typing import List, Union
from . import autosummary
from . import tfutil
from .. import util
from .tfutil import TfExpression, TfExpressionEx
try:
# TensorFlow 1.13
from tensorflow.python.ops import nccl_ops
except:
# Older TensorFlow versions
import tensorflow.contrib.nccl as nccl_ops
class Optimizer:
    """A Wrapper for tf.train.Optimizer.

    Automatically takes care of:
    - Gradient averaging for multi-GPU training.
    - Gradient accumulation for arbitrarily large minibatches.
    - Dynamic loss scaling and typecasts for FP16 training.
    - Ignoring corrupted gradients that contain NaNs/Infs.
    - Reporting statistics.
    - Well-chosen default settings.

    Typical call sequence, as enforced by the assertions below: call
    `register_gradients()` one or more times, then `apply_updates()` exactly
    once; no gradients can be registered after updates have been applied.
    """

    def __init__(self,
        name: str = "Train", # Name string that will appear in TensorFlow graph.
        tf_optimizer: str = "tf.train.AdamOptimizer", # Underlying optimizer class.
        learning_rate: TfExpressionEx = 0.001, # Learning rate. Can vary over time.
        minibatch_multiplier: TfExpressionEx = None, # Treat N consecutive minibatches as one by accumulating gradients.
        share: "Optimizer" = None, # Share internal state with a previously created optimizer?
        use_loss_scaling: bool = False, # Enable dynamic loss scaling for robust mixed-precision training?
        loss_scaling_init: float = 64.0, # Log2 of initial loss scaling factor.
        loss_scaling_inc: float = 0.0005, # Log2 of per-minibatch loss scaling increment when there is no overflow.
        loss_scaling_dec: float = 1.0, # Log2 of per-minibatch loss scaling decrement when there is an overflow.
        report_mem_usage: bool = False, # Report fine-grained memory usage statistics in TensorBoard?
        **kwargs):
        """Store the configuration; no per-device TensorFlow objects are created here.

        Any extra keyword arguments are forwarded to the underlying optimizer
        class when it is instantiated in `_get_device()`.
        """

        # Public fields.
        self.name = name
        self.learning_rate = learning_rate
        self.minibatch_multiplier = minibatch_multiplier
        self.id = self.name.replace("/", ".")
        self.scope = tf.get_default_graph().unique_name(self.id)
        self.optimizer_class = util.get_obj_by_name(tf_optimizer) # Resolve the class from its dotted name.
        self.optimizer_kwargs = dict(kwargs)
        self.use_loss_scaling = use_loss_scaling
        self.loss_scaling_init = loss_scaling_init
        self.loss_scaling_inc = loss_scaling_inc
        self.loss_scaling_dec = loss_scaling_dec

        # Private fields.
        self._updates_applied = False
        self._devices = OrderedDict() # device_name => EasyDict()
        self._shared_optimizers = OrderedDict() # device_name => optimizer_class
        self._gradient_shapes = None # [shape, ...]
        self._report_mem_usage = report_mem_usage

        # Validate arguments.
        assert callable(self.optimizer_class)

        # Share internal state if requested.
        # The per-device optimizer instances are reused from `share`, which is
        # only allowed when class, learning rate, and kwargs all match.
        if share is not None:
            assert isinstance(share, Optimizer)
            assert self.optimizer_class is share.optimizer_class
            assert self.learning_rate is share.learning_rate
            assert self.optimizer_kwargs == share.optimizer_kwargs
            self._shared_optimizers = share._shared_optimizers # pylint: disable=protected-access

    def _get_device(self, device_name: str):
        """Get internal state for the given TensorFlow device.

        The state is created on first use and cached in `self._devices`, so
        repeated calls with the same name return the same object.
        """
        tfutil.assert_tf_initialized()
        if device_name in self._devices:
            return self._devices[device_name]

        # Initialize fields.
        device = util.EasyDict()
        device.name = device_name
        device.optimizer = None # Underlying optimizer: optimizer_class
        device.loss_scaling_var = None # Log2 of loss scaling: tf.Variable
        device.grad_raw = OrderedDict() # Raw gradients: var => [grad, ...]
        device.grad_clean = OrderedDict() # Clean gradients: var => grad
        device.grad_acc_vars = OrderedDict() # Accumulation sums: var => tf.Variable
        device.grad_acc_count = None # Accumulation counter: tf.Variable
        device.grad_acc = OrderedDict() # Accumulated gradients: var => grad

        # Setup TensorFlow objects.
        with tfutil.absolute_name_scope(self.scope + "/Devices"), tf.device(device_name), tf.control_dependencies(None):
            # Create the underlying optimizer only if none exists yet for this
            # device (it may already exist when state is shared).
            if device_name not in self._shared_optimizers:
                optimizer_name = self.scope.replace("/", "_") + "_opt%d" % len(self._shared_optimizers)
                self._shared_optimizers[device_name] = self.optimizer_class(name=optimizer_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
            device.optimizer = self._shared_optimizers[device_name]
            if self.use_loss_scaling:
                device.loss_scaling_var = tf.Variable(np.float32(self.loss_scaling_init), trainable=False, name="loss_scaling_var")

        # Register device.
        self._devices[device_name] = device
        return device

    def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
        """Register the gradients of the given loss function with respect to the given variables.
        Intended to be called once per GPU.

        The device is taken from `loss.device`, and every variable must live
        on that same device. All calls must pass variables with the same
        shapes, in the same order, as the first call.
        """
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        device = self._get_device(loss.device)

        # Validate trainables.
        if isinstance(trainable_vars, dict):
            trainable_vars = list(trainable_vars.values()) # allow passing in Network.trainables as vars
        assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
        assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])
        assert all(var.device == device.name for var in trainable_vars)

        # Validate shapes.
        # The first call records the shapes; later calls are checked against them.
        if self._gradient_shapes is None:
            self._gradient_shapes = [var.shape.as_list() for var in trainable_vars]
        assert len(trainable_vars) == len(self._gradient_shapes)
        assert all(var.shape.as_list() == var_shape for var, var_shape in zip(trainable_vars, self._gradient_shapes))

        # Report memory usage if requested.
        # The flag is cleared right away, so only the first call that reaches
        # this point adds the summary. A NotFoundError is deliberately ignored.
        deps = []
        if self._report_mem_usage:
            self._report_mem_usage = False
            try:
                with tf.name_scope(self.id + '_mem'), tf.device(device.name), tf.control_dependencies([loss]):
                    deps.append(autosummary.autosummary(self.id + "/mem_usage_gb", tf.contrib.memory_stats.BytesInUse() / 2**30))
            except tf.errors.NotFoundError:
                pass

        # Compute gradients.
        # The loss is cast to float32 and multiplied by the loss scale (if enabled) first.
        with tf.name_scope(self.id + "_grad"), tf.device(device.name), tf.control_dependencies(deps):
            loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
            gate = tf.train.Optimizer.GATE_NONE # disable gating to reduce memory usage
            grad_list = device.optimizer.compute_gradients(loss=loss, var_list=trainable_vars, gate_gradients=gate)

        # Register gradients.
        # Gradients from multiple calls for the same variable are collected in a list.
        for grad, var in grad_list:
            if var not in device.grad_raw:
                device.grad_raw[var] = []
            device.grad_raw[var].append(grad)

    def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
        """Construct training op to update the registered variables based on their gradients.

        May be called only once. With `allow_no_op=True` and no registered
        devices, a plain `tf.no_op` named 'TrainingOp' is returned instead.
        """
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        self._updates_applied = True
        all_ops = []

        # Check for no-op.
        if allow_no_op and len(self._devices) == 0:
            with tfutil.absolute_name_scope(self.scope):
                return tf.no_op(name='TrainingOp')

        # Clean up gradients.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Clean%d" % device_idx), tf.device(device.name):
                for var, grad in device.grad_raw.items():

                    # Filter out disconnected gradients and convert to float32.
                    grad = [g for g in grad if g is not None]
                    grad = [tf.cast(g, tf.float32) for g in grad]

                    # Sum within the device.
                    if len(grad) == 0:
                        grad = tf.zeros(var.shape) # No gradients => zero.
                    elif len(grad) == 1:
                        grad = grad[0] # Single gradient => use as is.
                    else:
                        grad = tf.add_n(grad) # Multiple gradients => sum.

                    # Scale as needed.
                    # Divides by the number of gradients registered for this
                    # variable (including disconnected ones), by the number of
                    # devices, and by the minibatch multiplier if set; then
                    # removes the loss scaling.
                    scale = 1.0 / len(device.grad_raw[var]) / len(self._devices)
                    scale = tf.constant(scale, dtype=tf.float32, name="scale")
                    if self.minibatch_multiplier is not None:
                        scale /= tf.cast(self.minibatch_multiplier, tf.float32)
                    scale = self.undo_loss_scaling(scale)
                    device.grad_clean[var] = grad * scale

        # Sum gradients across devices.
        # Variables are matched across devices by their position in grad_clean.
        if len(self._devices) > 1:
            with tfutil.absolute_name_scope(self.scope + "/Broadcast"), tf.device(None):
                for all_vars in zip(*[device.grad_clean.keys() for device in self._devices.values()]):
                    if len(all_vars) > 0 and all(dim > 0 for dim in all_vars[0].shape.as_list()): # NCCL does not support zero-sized tensors.
                        all_grads = [device.grad_clean[var] for device, var in zip(self._devices.values(), all_vars)]
                        all_grads = nccl_ops.all_sum(all_grads)
                        for device, var, grad in zip(self._devices.values(), all_vars, all_grads):
                            device.grad_clean[var] = grad

        # Apply updates separately on each device.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Apply%d" % device_idx), tf.device(device.name):
                # pylint: disable=cell-var-from-loop
                # (The lambdas below are handed to tf.cond within the same
                # iteration that defines them.)

                # Accumulate gradients over time.
                if self.minibatch_multiplier is None:
                    # No accumulation: every step is a full step.
                    acc_ok = tf.constant(True, name='acc_ok')
                    device.grad_acc = OrderedDict(device.grad_clean)
                else:
                    # Create variables.
                    with tf.control_dependencies(None):
                        for var in device.grad_clean.keys():
                            device.grad_acc_vars[var] = tf.Variable(tf.zeros(var.shape), trainable=False, name="grad_acc_var")
                        device.grad_acc_count = tf.Variable(tf.zeros([]), trainable=False, name="grad_acc_count")

                    # Track counter.
                    # acc_ok becomes true once the counter reaches the
                    # multiplier; the counter is then reset to zero, otherwise
                    # it is incremented.
                    count_cur = device.grad_acc_count + 1.0
                    count_inc_op = lambda: tf.assign(device.grad_acc_count, count_cur)
                    count_reset_op = lambda: tf.assign(device.grad_acc_count, tf.zeros([]))
                    acc_ok = (count_cur >= tf.cast(self.minibatch_multiplier, tf.float32))
                    all_ops.append(tf.cond(acc_ok, count_reset_op, count_inc_op))

                    # Track gradients.
                    # grad_acc holds the running sum including the current
                    # gradient; the stored sum is reset or updated under the
                    # same acc_ok condition as the counter.
                    for var, grad in device.grad_clean.items():
                        acc_var = device.grad_acc_vars[var]
                        acc_cur = acc_var + grad
                        device.grad_acc[var] = acc_cur
                        with tf.control_dependencies([acc_cur]):
                            acc_inc_op = lambda: tf.assign(acc_var, acc_cur)
                            acc_reset_op = lambda: tf.assign(acc_var, tf.zeros(var.shape))
                            all_ops.append(tf.cond(acc_ok, acc_reset_op, acc_inc_op))

                # No overflow => apply gradients.
                # all_ok requires a completed accumulation cycle and finite
                # values in every accumulated gradient.
                all_ok = tf.reduce_all(tf.stack([acc_ok] + [tf.reduce_all(tf.is_finite(g)) for g in device.grad_acc.values()]))
                apply_op = lambda: device.optimizer.apply_gradients([(tf.cast(grad, var.dtype), var) for var, grad in device.grad_acc.items()])
                all_ops.append(tf.cond(all_ok, apply_op, tf.no_op))

                # Adjust loss scaling.
                # Only touched on completed accumulation cycles: increased
                # when all gradients were finite, decreased otherwise.
                if self.use_loss_scaling:
                    ls_inc_op = lambda: tf.assign_add(device.loss_scaling_var, self.loss_scaling_inc)
                    ls_dec_op = lambda: tf.assign_sub(device.loss_scaling_var, self.loss_scaling_dec)
                    ls_update_op = lambda: tf.group(tf.cond(all_ok, ls_inc_op, ls_dec_op))
                    all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op))

                # Last device => report statistics.
                if device_idx == len(self._devices) - 1:
                    all_ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate))
                    all_ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(all_ok, 0, 1), condition=acc_ok))
                    if self.use_loss_scaling:
                        all_ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", device.loss_scaling_var))

        # Initialize variables.
        self.reset_optimizer_state()
        if self.use_loss_scaling:
            tfutil.init_uninitialized_vars([device.loss_scaling_var for device in self._devices.values()])
        if self.minibatch_multiplier is not None:
            tfutil.run([var.initializer for device in self._devices.values() for var in list(device.grad_acc_vars.values()) + [device.grad_acc_count]])

        # Group everything into a single op.
        with tfutil.absolute_name_scope(self.scope):
            return tf.group(*all_ops, name="TrainingOp")

    def reset_optimizer_state(self) -> None:
        """Reset internal state of the underlying optimizer.

        Runs the initializer of every variable reported by each device's
        optimizer via `variables()`.
        """
        tfutil.assert_tf_initialized()
        tfutil.run([var.initializer for device in self._devices.values() for var in device.optimizer.variables()])

    def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
        """Get or create variable representing log2 of the current dynamic loss scaling factor.

        Returns None when loss scaling is disabled, since the variable is
        only created if `use_loss_scaling` is set.
        """
        return self._get_device(device).loss_scaling_var

    def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Apply dynamic loss scaling for the given expression.

        Multiplies by 2**loss_scaling_var of the expression's device; returns
        the value unchanged when loss scaling is disabled.
        """
        assert tfutil.is_tf_expression(value)
        if not self.use_loss_scaling:
            return value
        return value * tfutil.exp2(self.get_loss_scaling_var(value.device))

    def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Undo the effect of dynamic loss scaling for the given expression.

        Multiplies by 2**(-loss_scaling_var) of the expression's device;
        returns the value unchanged when loss scaling is disabled.
        """
        assert tfutil.is_tf_expression(value)
        if not self.use_loss_scaling:
            return value
        return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type
class SimpleAdam:
    """Minimal Adam optimizer exposing the subset of the tf.train.Optimizer
    interface that dnnlib.tflib.Optimizer calls: `variables()`,
    `compute_gradients()`, and `apply_gradients()`."""

    def __init__(self, name="Adam", learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
        """Store the hyperparameters; state variables are created lazily in `apply_gradients()`."""
        self.name = name
        self.learning_rate = learning_rate
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.all_state_vars = []

    def variables(self):
        """Return the list of all state variables created so far."""
        return self.all_state_vars

    def compute_gradients(self, loss, var_list, gate_gradients=tf.train.Optimizer.GATE_NONE):
        """Return `(gradient, variable)` pairs for `loss`; only ungated gradients are supported."""
        assert gate_gradients == tf.train.Optimizer.GATE_NONE
        gradients = tf.gradients(loss, var_list)
        return list(zip(gradients, var_list))

    def apply_gradients(self, grads_and_vars):
        """Build and return a grouped op that performs one Adam step for the given pairs.

        Every call creates a fresh set of state variables (beta powers plus
        first and second moments per variable) and appends them to
        `all_state_vars`.
        """
        with tf.name_scope(self.name):
            created_state = []
            step_ops = []

            # Running powers of beta1/beta2, used to correct the step size for
            # the zero-initialized moment estimates.
            with tf.control_dependencies(None):
                beta1_power = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
                beta2_power = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
                created_state.extend([beta1_power, beta2_power])
            beta1_power_next = beta1_power * self.beta1
            beta2_power_next = beta2_power * self.beta2
            step_ops.append(tf.assign(beta1_power, beta1_power_next))
            step_ops.append(tf.assign(beta2_power, beta2_power_next))
            step_size = self.learning_rate * tf.sqrt(1 - beta2_power_next) / (1 - beta1_power_next)

            # Per-variable moment estimates and the resulting update.
            for grad, var in grads_and_vars:
                with tf.control_dependencies(None):
                    first_moment = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
                    second_moment = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
                    created_state.extend([first_moment, second_moment])
                first_moment_next = self.beta1 * first_moment + (1 - self.beta1) * grad
                second_moment_next = self.beta2 * second_moment + (1 - self.beta2) * tf.square(grad)
                delta = step_size * first_moment_next / (tf.sqrt(second_moment_next) + self.epsilon)
                step_ops.append(tf.assign(first_moment, first_moment_next))
                step_ops.append(tf.assign(second_moment, second_moment_next))
                step_ops.append(tf.assign_sub(var, delta))

            # Publish the new state variables and bundle all update ops.
            self.all_state_vars.extend(created_state)
            return tf.group(*step_ops)
================================================
FILE: stylegan_human/dnnlib/tflib/tfutil.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html
"""Miscellaneous helper utils for Tensorflow."""
import os
import numpy as np
import tensorflow as tf
# Silence deprecation warnings from TensorFlow 1.13 onwards
import logging
logging.getLogger('tensorflow').setLevel(logging.ERROR)
import tensorflow.contrib # requires TensorFlow 1.x!
tf.contrib = tensorflow.contrib
from typing import Any, Iterable, List, Union
TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
"""A type that represents a valid Tensorflow expression."""
TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
"""A type that can be converted to a valid Tensorflow expression."""
def run(*args, **kwargs) -> Any:
    """Forward all arguments to `run()` of the default session and return its result.

    Raises RuntimeError (via assert_tf_initialized) when no default session exists.
    """
    assert_tf_initialized()
    session = tf.get_default_session()
    return session.run(*args, **kwargs)
def is_tf_expression(x: Any) -> bool:
    """Return True when `x` is a tf.Tensor, tf.Variable, or tf.Operation."""
    expression_types = (tf.Tensor, tf.Variable, tf.Operation)
    return isinstance(x, expression_types)
def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]:
    """Collect the `.value` of every dimension of a Tensorflow shape into a list.

    Retained for backwards compatibility -- use TensorShape.as_list() in new code.
    """
    dims = []
    for dim in shape:
        dims.append(dim.value)
    return dims
def flatten(x: TfExpressionEx) -> TfExpression:
    """Reshape `x` into a 1-D tensor inside the "Flatten" name scope."""
    with tf.name_scope("Flatten"):
        flat = tf.reshape(x, [-1])
    return flat
def log2(x: TfExpressionEx) -> TfExpression:
    """Base-2 logarithm, computed as ln(x) * (1 / ln 2) inside the "Log2" name scope."""
    inv_ln2 = np.float32(1.0 / np.log(2.0))
    with tf.name_scope("Log2"):
        result = tf.log(x) * inv_ln2
    return result
def exp2(x: TfExpressionEx) -> TfExpression:
    """Base-2 exponential, computed as exp(x * ln 2) inside the "Exp2" name scope."""
    ln2 = np.float32(np.log(2.0))
    with tf.name_scope("Exp2"):
        result = tf.exp(x * ln2)
    return result
def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
    """Linearly interpolate from `a` (t == 0) to `b` (t == 1); `t` is not clamped."""
    with tf.name_scope("Lerp"):
        delta = b - a
        return a + delta * t
def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
    """Linearly interpolate from `a` to `b` with `t` clamped to [0, 1]."""
    with tf.name_scope("LerpClip"):
        delta = b - a
        t_clamped = tf.clip_by_value(t, 0.0, 1.0)
        return a + delta * t_clamped
def absolute_name_scope(scope: str) -> tf.name_scope:
    """Return a name scope for `scope` with a trailing "/" appended.

    The trailing slash is what makes the scope absolute, i.e. independent of any
    surrounding name scopes.
    """
    absolute_name = scope + "/"
    return tf.name_scope(absolute_name)
def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:
    """Return a variable scope built from a fresh tf.VariableScope named `scope`.

    Extra keyword arguments are forwarded to tf.VariableScope. No auxiliary name
    scope is opened (auxiliary_name_scope=False).
    """
    scope_obj = tf.VariableScope(name=scope, **kwargs)
    return tf.variable_scope(scope_obj, auxiliary_name_scope=False)
def _sanitize_tf_config(config_dict: dict = None) -> dict:
# Defaults.
cfg = dict()
cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is.
cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.
# Remove defaults for environment variables that are already set.
for key in list(cfg):
fields = key.split(".")
if fields[0] == "env":
assert len(fields) == 2
if fields[1] in os.environ:
del cfg[key]
# User overrides.
if config_dict is not None:
cfg.update(config_dict)
return cfg
def init_tf(config_dict: dict = None) -> None:
    """Create the default TensorFlow session, seeding NumPy/TensorFlow and exporting env vars first.

    Does nothing when a default session already exists.
    """
    if tf.get_default_session() is not None:
        return

    cfg = _sanitize_tf_config(config_dict)

    # Random seeds: NumPy first, so that an 'auto' TensorFlow seed is drawn from the seeded state.
    numpy_seed = cfg["rnd.np_random_seed"]
    if numpy_seed is not None:
        np.random.seed(numpy_seed)
    graph_seed = cfg["rnd.tf_random_seed"]
    if graph_seed == "auto":
        graph_seed = np.random.randint(1 << 31)
    if graph_seed is not None:
        tf.set_random_seed(graph_seed)

    # Export every "env.NAME" entry as an environment variable.
    for key, value in cfg.items():
        section, *rest = key.split(".")
        if section != "env":
            continue
        assert len(rest) == 1
        os.environ[rest[0]] = str(value)

    # Create the session and install it as the default.
    create_session(cfg, force_as_default=True)
def assert_tf_initialized():
    """Raise RuntimeError unless a default TensorFlow session is available."""
    session = tf.get_default_session()
    if session is None:
        raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")
def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session:
    """Build a tf.Session whose ConfigProto is filled in from the dotted keys of the config dict.

    Keys whose first component is "rnd" or "env" are not part of the proto and are
    skipped. With `force_as_default`, the session is entered as the default session
    (with nesting checks disabled) and stays entered.
    """
    cfg = _sanitize_tf_config(config_dict)

    # Walk each dotted key down the proto and set the leaf attribute.
    config_proto = tf.ConfigProto()
    for key, value in cfg.items():
        fields = key.split(".")
        if fields[0] in ["rnd", "env"]:
            continue
        target = config_proto
        for attr in fields[:-1]:
            target = getattr(target, attr)
        setattr(target, fields[-1], value)

    session = tf.Session(config=config_proto)
    if force_as_default:
        # pylint: disable=protected-access
        default_ctx = session.as_default()
        session._default_session = default_ctx
        session._default_session.enforce_nesting = False
        session._default_session.__enter__()
    return session
def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
    """Run the initializer of every given variable that is not initialized yet.

    Defaults to all global variables. Equivalent to the following, but more
    efficient and does not bloat the tf graph:
    tf.variables_initializer(tf.report_uninitialized_variables()).run()
    """
    assert_tf_initialized()
    candidates = tf.global_variables() if target_vars is None else target_vars

    pending_vars = []
    pending_checks = []
    with tf.control_dependencies(None):  # ignore surrounding control_dependencies
        for var in candidates:
            assert is_tf_expression(var)
            check_name = var.name.replace(":0", "/IsVariableInitialized:0")
            try:
                tf.get_default_graph().get_tensor_by_name(check_name)
            except KeyError:
                # Op does not exist => variable may be uninitialized.
                pending_vars.append(var)
                with absolute_name_scope(var.name.split(":")[0]):
                    pending_checks.append(tf.is_variable_initialized(var))

    flags = run(pending_checks)
    initializers = [var.initializer for var, inited in zip(pending_vars, flags) if not inited]
    run(initializers)
def set_vars(var_to_value_dict: dict) -> None:
    """Assign new values to the given tf.Variables in a single session run.

    A "<var>/setter" assign op fed by a placeholder is created once per variable
    and reused on later calls. Equivalent to the following, but more efficient
    and does not bloat the tf graph:
    tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()]
    """
    assert_tf_initialized()
    setters = []
    feeds = {}
    for var, value in var_to_value_dict.items():
        assert is_tf_expression(var)
        setter_name = var.name.replace(":0", "/setter:0")
        try:
            setter = tf.get_default_graph().get_tensor_by_name(setter_name)  # look for existing op
        except KeyError:
            scope_name = var.name.split(":")[0]
            # Create the setter in the variable's own scope, ignoring surrounding control_dependencies.
            with absolute_name_scope(scope_name), tf.control_dependencies(None):
                new_value = tf.placeholder(var.dtype, var.shape, "new_value")
                setter = tf.assign(var, new_value, name="setter")
        setters.append(setter)
        feeds[setter.op.inputs[1]] = value  # second input of the assign op is the value being assigned
    run(setters, feeds)
def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
    """Create a tf.Variable holding `initial_value` without embedding the array in the graph.

    The variable is created from zeros of matching shape/dtype and the real data
    is then written with set_vars(). Extra arguments go to tf.Variable.
    """
    assert_tf_initialized()
    assert isinstance(initial_value, np.ndarray)
    placeholder_init = tf.zeros(initial_value.shape, initial_value.dtype)
    variable = tf.Variable(placeholder_init, *args, **kwargs)
    set_vars({variable: initial_value})
    return variable
def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
    """Cast images to float32 and map the range [0, 255] linearly onto `drange`.

    With `nhwc_to_nchw`, axes are permuted as [0, 3, 1, 2] before scaling.
    Can be used as an input transformation for Network.run().
    """
    converted = tf.cast(images, tf.float32)
    if nhwc_to_nchw:
        converted = tf.transpose(converted, [0, 3, 1, 2])
    low = drange[0]
    scale = (drange[1] - low) / 255
    return converted * scale + low
def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1):
    """Map images from `drange` onto [0, 255] and saturate-cast them to uint8.

    With `shrink` > 1 the images are first average-pooled by that factor (NCHW
    layout); with `nchw_to_nhwc` the axes are then permuted as [0, 2, 3, 1].
    Can be used as an output transformation for Network.run().
    """
    converted = tf.cast(images, tf.float32)
    if shrink > 1:
        window = [1, 1, shrink, shrink]
        converted = tf.nn.avg_pool(converted, ksize=window, strides=window, padding="VALID", data_format="NCHW")
    if nchw_to_nhwc:
        converted = tf.transpose(converted, [0, 2, 3, 1])
    scale = 255 / (drange[1] - drange[0])
    offset = 0.5 - drange[0] * scale  # the extra 0.5 is added before the saturating cast
    converted = converted * scale + offset
    return tf.saturate_cast(converted, tf.uint8)
================================================
FILE: stylegan_human/dnnlib/util.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Miscellaneous utility classes and functions."""
import ctypes
import fnmatch
import importlib
import inspect
import numpy as np
import os
import shutil
import sys
import types
import io
import pickle
import re
import requests
import html
import hashlib
import glob
import tempfile
import urllib
import urllib.request
import uuid
from distutils.util import strtobool
from typing import Any, List, Tuple, Union
# Util classes
# ------------------------------------------------------------------------------------------
class EasyDict(dict):
    """A dict whose items can also be read, written and deleted as attributes."""

    def __getattr__(self, name: str) -> Any:
        # Only called when normal attribute lookup fails; map a missing key to AttributeError.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self.__setitem__(name, value)

    def __delattr__(self, name: str) -> None:
        self.__delitem__(name)
class Logger(object):
    """Replace sys.stdout and sys.stderr with this object, which forwards text to the original stdout and optionally to a file."""

    def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
        self.file = open(file_name, file_mode) if file_name is not None else None
        self.should_flush = should_flush

        # Remember the current streams so that close() can restore them.
        self.stdout, self.stderr = sys.stdout, sys.stderr
        sys.stdout = sys.stderr = self

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def write(self, text: Union[str, bytes]) -> None:
        """Write text (bytes are decoded first) to the file, if any, and to the original stdout."""
        payload = text.decode() if isinstance(text, bytes) else text
        if len(payload) == 0:  # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
            return

        if self.file is not None:
            self.file.write(payload)
        self.stdout.write(payload)

        if self.should_flush:
            self.flush()

    def flush(self) -> None:
        """Flush the file, if open, and the original stdout."""
        if self.file is not None:
            self.file.flush()
        self.stdout.flush()

    def close(self) -> None:
        """Flush, restore sys.stdout/sys.stderr if they still point here, and close the file."""
        self.flush()

        # if using multiple loggers, prevent closing in wrong order
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr

        if self.file is not None:
            self.file.close()
            self.file = None
# Cache directories
# ------------------------------------------------------------------------------------------
# Explicit cache root; takes precedence over every environment variable when set.
_dnnlib_cache_dir = None


def set_cache_dir(path: str) -> None:
    """Set (or, with None, clear) the explicit cache root used by make_cache_dir_path()."""
    global _dnnlib_cache_dir
    _dnnlib_cache_dir = path


def make_cache_dir_path(*paths: str) -> str:
    """Join `paths` onto the cache root.

    The root is the first available of: the directory given to set_cache_dir(),
    $DNNLIB_CACHE_DIR, $HOME/.cache/dnnlib, $USERPROFILE/.cache/dnnlib, and
    <system temp dir>/.cache/dnnlib.
    """
    if _dnnlib_cache_dir is not None:
        return os.path.join(_dnnlib_cache_dir, *paths)
    if 'DNNLIB_CACHE_DIR' in os.environ:
        return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
    for home_var in ('HOME', 'USERPROFILE'):
        if home_var in os.environ:
            return os.path.join(os.environ[home_var], '.cache', 'dnnlib', *paths)
    return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
# Small util functions
# ------------------------------------------------------------------------------------------
def format_time(seconds: Union[int, float]) -> str:
    """Format a duration, rounded to whole seconds, using its two or three largest units.

    Below a minute: "Ns"; below an hour: "Mm SSs"; below a day: "Hh MMm SSs";
    otherwise "Dd HHh MMm" (seconds are dropped).
    """
    total = int(np.rint(seconds))
    if total < 60:
        return "{0}s".format(total)

    minutes, secs = divmod(total, 60)
    if total < 60 * 60:
        return "{0}m {1:02}s".format(minutes, secs)

    hours, mins = divmod(minutes, 60)
    if total < 24 * 60 * 60:
        return "{0}h {1:02}m {2:02}s".format(hours, mins, secs)

    days, hrs = divmod(hours, 24)
    return "{0}d {1:02}h {2:02}m".format(days, hrs, mins)
def ask_yes_no(question: str) -> bool:
    """Print the question followed by " [y/n]" and read answers until strtobool() accepts one."""
    prompt = "{0} [y/n]".format(question)
    while True:
        try:
            print(prompt)
            answer = input().lower()
            return strtobool(answer)
        except ValueError:
            # Invalid answer: ask again.
            continue
def tuple_product(t: Tuple) -> Any:
    """Multiply all elements of `t` together, starting from 1 (so an empty tuple yields 1)."""
    product = 1
    for factor in t:
        product *= factor
    return product
# Type name -> ctypes type of the same width.
_str_to_ctype = dict(
    uint8=ctypes.c_ubyte,
    uint16=ctypes.c_uint16,
    uint32=ctypes.c_uint32,
    uint64=ctypes.c_uint64,
    int8=ctypes.c_byte,
    int16=ctypes.c_int16,
    int32=ctypes.c_int32,
    int64=ctypes.c_int64,
    float32=ctypes.c_float,
    float64=ctypes.c_double,
)


def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
    """Return the (numpy dtype, ctypes type) pair matching a type name.

    `type_obj` is either the name itself or an object whose `__name__` (checked
    first) or `name` attribute provides it. Raises RuntimeError when no name can
    be derived; an unknown name or a size mismatch trips an assert.
    """
    if isinstance(type_obj, str):
        type_name = type_obj
    else:
        for attr in ("__name__", "name"):
            if hasattr(type_obj, attr):
                type_name = getattr(type_obj, attr)
                break
        else:
            raise RuntimeError("Cannot infer type name from input")

    assert type_name in _str_to_ctype

    matched_dtype = np.dtype(type_name)
    matched_ctype = _str_to_ctype[type_name]
    assert matched_dtype.itemsize == ctypes.sizeof(matched_ctype)
    return matched_dtype, matched_ctype
def is_pickleable(obj: Any) -> bool:
    """Return True when `obj` can be serialized with pickle, False otherwise.

    The object is dumped into an in-memory buffer that is discarded afterwards.
    """
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
        return True
    except Exception:
        # Any ordinary failure means "not pickleable". KeyboardInterrupt and
        # SystemExit are deliberately not caught so the process can still be stopped.
        return False
# Functionality to import modules/objects by name, and call functions by name
# ------------------------------------------------------------------------------------------
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Find the module that provides the object with the given dotted name.

    Every split of the name into "<module>.<attribute path>" is tried, longest
    module name first.

    Returns:
        The module and the object name with the module part removed.

    Raises:
        ImportError: when no split works, or when importing one of the candidate
            modules fails for a reason other than the module not existing.
        AttributeError: when a candidate module imports but lacks the attribute.
    """
    # allow convenience shorthands, substitute them by full names
    # The dot is escaped on purpose: unescaped it matches any character, which
    # would rewrite e.g. "tfutil.run" into "tensorflow.til.run".
    obj_name = re.sub(r"^np\.", "numpy.", obj_name)
    obj_name = re.sub(r"^tf\.", "tensorflow.", obj_name)

    # list alternatives for (module_name, local_obj_name)
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]

    # try each alternative in turn
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name)  # may raise ImportError
            get_obj_from_module(module, local_obj_name)  # may raise AttributeError
            return module, local_obj_name
        except Exception:
            # Best effort: fall through to the next candidate. KeyboardInterrupt
            # and SystemExit are not swallowed.
            pass

    # maybe some of the modules themselves contain errors?
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name)  # may raise ImportError
        except ImportError as err:
            # A plain "module does not exist" is expected for most candidates;
            # anything else is a real error inside the module and is re-raised.
            if not str(err).startswith("No module named '" + module_name + "'"):
                raise

    # maybe the requested attribute is missing?
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name)  # may raise ImportError
            get_obj_from_module(module, local_obj_name)  # may raise AttributeError
        except ImportError:
            pass

    # we are out of luck, but we have no idea why
    raise ImportError(obj_name)


def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Follow the dotted attribute path `obj_name` starting at `module` and return the final object.

    An empty path returns the module itself. Raises AttributeError when a
    component of the path does not exist.
    """
    if obj_name == '':
        return module
    obj = module
    for part in obj_name.split("."):
        obj = getattr(obj, part)
    return obj
def get_obj_by_name(name: str) -> Any:
    """Resolve a fully-qualified dotted name to the Python object it refers to."""
    owner_module, attr_path = get_module_from_obj_name(name)
    resolved = get_obj_from_module(owner_module, attr_path)
    return resolved
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Resolve `func_name` to a callable and invoke it with the remaining arguments.

    `func_name` is keyword-only and required; it and the callability of the
    resolved object are checked with asserts.
    """
    assert func_name is not None
    target = get_obj_by_name(func_name)
    assert callable(target)
    return target(*args, **kwargs)
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
    """Resolve `class_name` and instantiate it with the given arguments.

    Delegates to call_func_by_name(), passing `class_name` as `func_name`.
    """
    instance = call_func_by_name(*args, func_name=class_name, **kwargs)
    return instance
def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Return the directory of the source file of the module that provides `obj_name`."""
    owner_module, _unused_attr_path = get_module_from_obj_name(obj_name)
    module_file = inspect.getfile(owner_module)
    return os.path.dirname(module_file)
def is_top_level_function(obj: Any) -> bool:
    """Return True when `obj` is callable and its `__name__` is a global of the module named by its `__module__`."""
    if not callable(obj):
        return False
    own_name = obj.__name__
    module_globals = sys.modules[obj.__module__].__dict__
    return own_name in module_globals
def get_top_level_function_name(obj: Any) -> str:
    """Return "<module>.<name>" for a top-level function.

    For functions defined in `__main__`, the module part is the base name of the
    main script without its extension.
    """
    assert is_top_level_function(obj)
    module_name = obj.__module__
    if module_name == '__main__':
        script_path = sys.modules[module_name].__file__
        module_name = os.path.splitext(os.path.basename(script_path))[0]
    return module_name + "." + obj.__name__
# File system helpers
# ------------------------------------------------------------------------------------------
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """Walk `dir_path` and list its files, skipping names that match any fnmatch pattern in `ignores`.

    Matching directories are pruned (not descended into); matching files are
    omitted. Returns (path under dir_path, path relative to dir_path) tuples;
    with `add_base_to_relative` the relative path is prefixed with the base name
    of `dir_path`.
    """
    assert os.path.isdir(dir_path)
    base_name = os.path.basename(os.path.normpath(dir_path))
    patterns = [] if ignores is None else ignores

    def _is_ignored(name: str) -> bool:
        return any(fnmatch.fnmatch(name, pattern) for pattern in patterns)

    result = []
    for root, dirs, files in os.walk(dir_path, topdown=True):
        # dirs need to be edited in-place so that os.walk skips the pruned ones
        dirs[:] = [d for d in dirs if not _is_ignored(d)]

        for file_name in files:
            if _is_ignored(file_name):
                continue
            absolute_path = os.path.join(root, file_name)
            relative_path = os.path.relpath(absolute_path, dir_path)
            if add_base_to_relative:
                relative_path = os.path.join(base_name, relative_path)
            result.append((absolute_path, relative_path))
    return result
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Copy every (src, dst) pair, creating the destination directories as needed.

    Args:
        files: Sequence of (source path, destination path) entries.
    """
    for entry in files:
        src_path, dst_path = entry[0], entry[1]
        target_dir_name = os.path.dirname(dst_path)

        # A bare file name has an empty dirname; os.makedirs('') would raise, and
        # there is nothing to create in that case. exist_ok avoids failing when
        # the directory appears between a check and the creation.
        if target_dir_name:
            # will create all intermediate-level directories
            os.makedirs(target_dir_name, exist_ok=True)

        shutil.copyfile(src_path, dst_path)
# URL helpers
# ------------------------------------------------------------------------------------------
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
    """Determine whether the given object is a valid URL string.

    A value qualifies when it is a string containing "://" and, once parsed, has
    both a scheme and a network location containing a dot. The same check is
    repeated on the URL joined with "/". "file://" URLs are accepted without
    parsing, but only when `allow_file_urls` is set.
    """
    if not isinstance(obj, str) or "://" not in obj:
        return False
    if allow_file_urls and obj.startswith('file://'):
        return True
    try:
        res = requests.compat.urlparse(obj)
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
    except Exception:
        # A string that cannot be parsed is simply not a URL. KeyboardInterrupt
        # and SystemExit are no longer swallowed here.
        return False
    return True
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data.

    Args:
        url: A URL, a "file://" URL, or a plain local file name.
        cache_dir: Directory for cached downloads; defaults to
            make_cache_dir_path('downloads').
        num_attempts: Maximum number of download attempts (must be >= 1).
        verbose: Print progress messages to stdout.
        return_filename: Return a file name instead of an open file object.
            Requires `cache` to be enabled.
        cache: Look the URL up in the cache first and store the download there.

    Returns:
        A file name when `return_filename` is set; otherwise a file object opened
        in "rb" mode (local file or cache hit) or an io.BytesIO holding the
        downloaded bytes.

    Raises:
        Whatever the last download attempt raised, once all attempts are used up.
    """
    assert num_attempts >= 1
    assert not (return_filename and (not cache))

    # Doesn't look like an URL scheme so interpret it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs. This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid. Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname(),
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    assert is_url(url)

    # Lookup from cache.
    if cache_dir is None:
        cache_dir = make_cache_dir_path('downloads')

    # The cache key is the MD5 of the URL as passed in; it is computed before the
    # download loop, which may replace `url` with a follow-up link.
    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        # Only an unambiguous single match counts as a cache hit.
        if len(cache_files) == 1:
            filename = cache_files[0]
            return filename if return_filename else open(filename, "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    # Small responses may be Google Drive interstitial pages rather than the file.
                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                # Retry with the confirmation link; raising sends control to the retry handler below.
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    # Prefer the server-provided file name; fall back to the URL itself.
                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except KeyboardInterrupt:
                raise
            except:
                # Any other failure: retry, re-raising only after the last attempt.
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache.
    if cache:
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        # Write to a uniquely named temporary file first, then move it into place.
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file) # atomic
        if return_filename:
            return cache_file

    # Return data as file object.
    assert not return_filename
    return io.BytesIO(url_data)
================================================
FILE: stylegan_human/docs/Dataset.md
================================================
# SHHQ Dataset
## Overview
SHHQ is a dataset with high-quality full-body human images in a resolution of 1024 × 512.
Because all data must pass a rigorous legal review at our institute, we cannot release all of it at once.
For now, SHHQ-1.0 with 40K images is released! More data will be released in the later versions.
## Data Sources
Images are collected in two main ways:
1) From the Internet.
We developed a crawler tool based on official APIs, mainly downloading images from Flickr, Pixabay and Pexels. You therefore need to comply with all of the following licenses when using the dataset: CC0, [Pixabay License](https://pixabay.com/service/license/), and [Pexels Licenses](https://www.pexels.com/license/).
2) From the data providers.
We purchased images from databases of individual photographers, modeling agencies and other suppliers.
Images were reviewed by our legal team prior to purchase to ensure permission for use in research.
### Note:
The composition of SHHQ-1.0:
1) Images obtained from the above sources.
2) Processed 9991 DeepFashion [[1]](#1) images (retain only full body images).
3) 1940 African images from the InFashAI [[2]](#2) dataset to increase data diversity.
## Data License
We are aware of privacy concerns and take license and privacy issues seriously. All released data is guaranteed to be under the CC0 license and free for research use. In addition, persons in the dataset are anonymised, and no additional private or sensitive metadata is included.
## Agreement
The SHHQ is available for non-commercial research purposes only.
You agree not to reproduce, duplicate, copy, sell, trade, resell or exploit any portion of the images and any portion of the derived data for commercial purposes.
You agree NOT to further copy, publish or distribute any portion of SHHQ to any third party for any purpose. The only exception is internal use at a single site within the same organization, for which making copies of the dataset is allowed.
Shanghai AI Lab reserves the right to terminate your access to the SHHQ at any time.
## Dataset Preview
For those interested in our dataset, we provide a preview version with 100 images randomly sampled from SHHQ-1.0: [SHHQ-1.0_samples](https://drive.google.com/file/d/1tnNFfmFtzRbYL3qEnNXQ_ShaN9YV5tI5/view?usp=sharing).
In SHHQ-1.0, we provide aligned raw images along with machine-calculated segmentation masks. We are planning to release a manually annotated human-parsing version of these 40,000 images later. Please stay tuned.
> We also provide script [bg_white.py](../bg_white.py) to whiten the background of the raw image using its segmentation mask.
If you want to access the full SHHQ-1.0, please read the following instructions.
## Model trained using SHHQ-1.0
| Structure | 1024x512 | Metric | Scores | 512x256 | Metric | Scores |
| --------- |:----------:| :----------:| :----------:| :-----: | :-----: | :-----: |
| StyleGAN1 | to be released | - | - | to be released | - | - |
| StyleGAN2 | [SHHQ-1.0_sg2_1024.pkl](https://drive.google.com/file/d/1PuvE72xpc69Zq4y58dohuKbG9dFnnjEX/view?usp=sharing) | fid50k_full | 3.56 | [SHHQ-1.0_sg2_512.pkl](https://drive.google.com/file/d/170t2FRWxR8_TG3_y0nVtDBogLPOClnyf/view?usp=sharing) | fid50k_full | 3.68 |
| StyleGAN3 | to be released | - | - |to be released | - | - |
## Download Instructions
Please download the SHHQ Dataset Release Agreement from [link](./SHHQ_Dataset_Release_Agreement.pdf).
Read it carefully, complete and sign it appropriately.
Please send the completed form to Jianglin Fu (arlenefu@outlook.com) and Shikai Li (lishikai@pjlab.org.cn), and cc Wayne Wu (wuwenyan0503@gmail.com), using your institutional email address. The email subject should be "SHHQ Dataset Release Agreement". We will verify your request and contact you with the dataset link and the password to unzip the image data.
Note:
1. We are currently receiving a large number of applications and need to verify every applicant carefully. Please be patient; we will reply to you as soon as possible.
2. The signature in the agreement should be hand-written.
## References
[1]
Liu, Ziwei and Luo, Ping and Qiu, Shi and Wang, Xiaogang and Tang, Xiaoou. DeepFashion: Powering Robust Clothes Recognition and Retrieval with Rich Annotations. CVPR (2016)
[2]
Hacheme, Gilles and Sayouti, Noureini. Neural fashion image captioning: Accounting for data diversity. arXiv preprint arXiv:2106.12154 (2021)
================================================
FILE: stylegan_human/edit/__init__.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# empty
================================================
FILE: stylegan_human/edit/edit_config.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Editing parameters per method and attribute.
attr_dict = {
    # interface_gan: [strength]
    'interface_gan': {
        'upper_length': [-1],  # strength: negative for shorter, positive for longer
        'bottom_length': [1],
    },
    # stylespace: [layer, strength, threshold]
    'stylespace': {
        'upper_length': [5, -5, 0.0028],  # strength: negative for shorter, positive for longer
        'bottom_length': [3, 5, 0.003],
    },
    # sefa: [layers, strength]
    'sefa': {
        'upper_length': [[4, 5, 6, 7], 5],  # -5 # strength: negative for longer, positive for shorter
        'bottom_length': [[4, 5, 6, 7], 5],
    },
}
================================================
FILE: stylegan_human/edit/edit_helper.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
from legacy import save_obj, load_pkl
import torch
from torch.nn import functional as F
import pandas as pd
from .edit_config import attr_dict
import os
def conv_warper(layer, input, style, noise):
    """Apply the modulated convolution of `layer` using an externally supplied style vector.

    The layer's own weights (`layer.conv`) are scaled per sample by `style`,
    optionally demodulated, and applied as one grouped convolution over the
    whole batch; the layer's noise injection and activation follow.

    Args:
        layer: Object exposing `conv`, `noise` and `activate`.
        input: Feature map of shape (batch, in_channel, height, width).
        style: Per-sample modulation; must be viewable as (batch, 1, in_channel, 1, 1).
        noise: Passed through to `layer.noise`.
    """
    mod_conv = layer.conv
    n, c_in, h, w = input.shape
    c_out = mod_conv.out_channel
    k = mod_conv.kernel_size

    # Modulate the shared kernel with the per-sample style.
    kernel = mod_conv.scale * mod_conv.weight * style.view(n, 1, c_in, 1, 1)
    if mod_conv.demodulate:
        inv_norm = torch.rsqrt(kernel.pow(2).sum([2, 3, 4]) + 1e-8)
        kernel = kernel * inv_norm.view(n, c_out, 1, 1, 1)

    # Fold the batch into the channel axis so one grouped conv handles all samples.
    kernel = kernel.view(n * c_out, c_in, k, k)

    if mod_conv.upsample:
        grouped = input.view(1, n * c_in, h, w)
        kernel = kernel.view(n, c_out, c_in, k, k)
        kernel = kernel.transpose(1, 2).reshape(n * c_in, c_out, k, k)
        out = F.conv_transpose2d(grouped, kernel, padding=0, stride=2, groups=n)
        out = out.view(n, c_out, *out.shape[-2:])
        out = mod_conv.blur(out)
    elif mod_conv.downsample:
        blurred = mod_conv.blur(input)
        grouped = blurred.view(1, n * c_in, *blurred.shape[-2:])
        out = F.conv2d(grouped, kernel, padding=0, stride=2, groups=n)
        out = out.view(n, c_out, *out.shape[-2:])
    else:
        grouped = input.view(1, n * c_in, h, w)
        out = F.conv2d(grouped, kernel, padding=mod_conv.padding, groups=n)
        out = out.view(n, c_out, *out.shape[-2:])

    out = layer.noise(out, noise=noise)
    return layer.activate(out)
def decoder(G, style_space, latent, noise):
    """Run the synthesis layers of `G` with externally supplied per-layer style vectors.

    Each convolution is evaluated through conv_warper() with its entry of
    `style_space`; the to-RGB layers are driven by slices of `latent`. Returns
    the accumulated RGB output of the last stage.
    """
    features = G.input(latent)
    features = conv_warper(G.conv1, features, style_space[0], noise[0])
    rgb = G.to_rgb1(features, latent[:, 1])

    stages = zip(G.convs[::2], G.convs[1::2], noise[1::2], noise[2::2], G.to_rgbs)
    for stage, (conv_a, conv_b, noise_a, noise_b, to_rgb) in enumerate(stages):
        idx = 2 * stage + 1  # index of this stage's first conv in style_space
        features = conv_warper(conv_a, features, style_space[idx], noise=noise_a)
        features = conv_warper(conv_b, features, style_space[idx + 1], noise=noise_b)
        rgb = to_rgb(features, latent[:, idx + 2], rgb)

    return rgb
def encoder_ifg(G, noise, attr_name, truncation=1, truncation_latent=None,
                latent_dir='latent_direction/ss/',
                step=0, total=0, real=False):
    """Compute an InterfaceGAN-edited style space for attribute ``attr_name``.

    Args:
        G: generator exposing ``style``, ``noises``, ``num_layers``,
            ``n_latent``, ``conv1`` and ``convs``.
        noise: input code. With ``real=False`` it is passed through
            ``G.style``; with ``real=True`` it is used directly as a
            per-layer latent indexed as ``[:, layer]``.
        attr_name: key into ``attr_dict['interface_gan']`` and sub-folder
            name under ``latent_dir``.
        truncation: truncation factor; only applied when ``< 1``.
        truncation_latent: latent the code is pulled towards when truncating.
        latent_dir: directory holding ``<attr_name>/style_vect_mean_<i>.pkl``.
        step, total: when ``total != 0`` the edit strength is scaled by
            ``step / total`` (``step=0`` therefore means "no edit").
        real: whether ``noise`` is already a per-layer latent.

    Returns:
        ``(style_space, latent, noise)`` where ``noise`` is the list of the
        generator's stored noise buffers.
    """
    if real:
        if truncation < 1:
            # Tile the truncation latent over the layer axis of ``noise``
            # (e.g. (1, 512) -> (1, 18, 512)).
            mean_latent = truncation_latent.repeat(noise.shape[1], 1).unsqueeze(0)
            latent = torch.add(
                mean_latent, torch.mul(torch.sub(noise, mean_latent), truncation)
            )
        else:
            # No truncation requested: the supplied latent is used as is.
            latent = noise
    else:
        style = G.style(noise)
        if truncation < 1:
            style = truncation_latent + truncation * (style - truncation_latent)
        latent = style.unsqueeze(1).repeat(1, G.n_latent, 1)

    noise = [getattr(G.noises, 'noise_{}'.format(i)) for i in range(G.num_layers)]

    # Collect the modulation vector of every styled conv.
    style_space = [G.conv1.conv.modulation(latent[:, 0])]
    i = 1
    for conv1, conv2, _noise1, _noise2, _to_rgb in zip(
        G.convs[::2], G.convs[1::2], noise[1::2], noise[2::2], G.to_rgbs
    ):
        style_space.append(conv1.conv.modulation(latent[:, i]))
        style_space.append(conv2.conv.modulation(latent[:, i + 1]))
        i += 2

    # get layer, strength by dict
    strength = attr_dict['interface_gan'][attr_name][0]
    if total != 0:
        # step == 0 must give zero strength, otherwise the first frame of an
        # interpolation would show the full edit.
        strength = step / total * strength

    for i in range(15):
        style_vect = load_pkl(
            os.path.join(latent_dir, '{}/style_vect_mean_{}.pkl'.format(attr_name, i))
        )
        style_vect = torch.from_numpy(style_vect).to(latent.device).float()
        style_space[i] += style_vect * strength

    return style_space, latent, noise
def encoder_ss(G, noise, attr_name, truncation=1, truncation_latent=None,
               statics_dir="latent_direction/ss_statics",
               latent_dir="latent_direction/ss/",
               step=0, total=0,real=False):
    """Compute a StyleSpace-edited style space for attribute ``attr_name``.

    Only one layer is edited, and only the channels whose absolute
    ``strength`` statistic exceeds the configured threshold.

    Args:
        G: generator exposing ``style``, ``noises``, ``num_layers``,
            ``n_latent``, ``conv1`` and ``convs``.
        noise: input code. With ``real=False`` it is passed through
            ``G.style``; with ``real=True`` it is used directly as a
            per-layer latent indexed as ``[:, layer]``.
        attr_name: key into ``attr_dict['stylespace']``; also selects the
            statistics and direction files.
        truncation: truncation factor; only applied when ``< 1``.
        truncation_latent: latent the code is pulled towards when truncating.
        statics_dir: root of ``<attr_name>_statis/<layer>/statis.csv``.
        latent_dir: directory holding ``<attr_name>/style_vect_mean_<layer>.pkl``.
        step, total: when ``total != 0`` the edit strength is scaled by
            ``step / total`` (``step=0`` therefore means "no edit").
        real: whether ``noise`` is already a per-layer latent.

    Returns:
        ``(style_space, latent, noise)`` where ``noise`` is the list of the
        generator's stored noise buffers.
    """
    if real:
        if truncation < 1:
            # Tile the truncation latent over the layer axis of ``noise``
            # (e.g. (1, 512) -> (1, 18, 512)).
            mean_latent = truncation_latent.repeat(noise.shape[1], 1).unsqueeze(0)
            latent = torch.add(
                mean_latent, torch.mul(torch.sub(noise, mean_latent), truncation)
            )
        else:
            # No truncation requested: the supplied latent is used as is.
            latent = noise
    else:
        style = G.style(noise)
        if truncation < 1:
            style = truncation_latent + truncation * (style - truncation_latent)
        latent = style.unsqueeze(1).repeat(1, G.n_latent, 1)

    noise = [getattr(G.noises, 'noise_{}'.format(i)) for i in range(G.num_layers)]

    # Collect the modulation vector of every styled conv.
    style_space = [G.conv1.conv.modulation(latent[:, 0])]
    i = 1
    for conv1, conv2, _noise1, _noise2, _to_rgb in zip(
        G.convs[::2], G.convs[1::2], noise[1::2], noise[2::2], G.to_rgbs
    ):
        style_space.append(conv1.conv.modulation(latent[:, i]))
        style_space.append(conv2.conv.modulation(latent[:, i + 1]))
        i += 2

    # get threshold, layer, strength by dict
    layer, strength, threshold = attr_dict['stylespace'][attr_name]

    # Channel mask: keep channels whose |strength| statistic passes the threshold.
    statis_dir = os.path.join(statics_dir, "{}_statis/{}".format(attr_name, layer))
    statis_df = pd.read_csv(os.path.join(statis_dir, "statis.csv"))
    statis_df = statis_df.sort_values(by='channel', ascending=True)
    ch_mask = torch.from_numpy(statis_df['strength'].values).to(latent.device).float()
    ch_mask = (ch_mask.abs()>threshold).float()

    style_vect = load_pkl(
        os.path.join(latent_dir, '{}/style_vect_mean_{}.pkl'.format(attr_name, layer))
    )
    style_vect = torch.from_numpy(style_vect).to(latent.device).float()
    style_vect = style_vect * ch_mask

    if total != 0:
        # step == 0 must give zero strength, otherwise the first frame of an
        # interpolation would show the full edit.
        strength = step / total * strength

    style_space[layer] += style_vect * strength

    return style_space, latent, noise
def encoder_sefa(G, noise, attr_name, truncation=1, truncation_latent=None,
                 latent_dir='latent_direction/sefa/',
                 step=0, total=0, real=False):
    """Shift a latent along a SeFa direction for attribute ``attr_name``.

    Args:
        G: generator exposing ``style``, ``noises``, ``num_layers`` and
            ``n_latent``.
        noise: input code. With ``real=False`` it is passed through
            ``G.style``; with ``real=True`` it is used directly as a
            per-layer latent indexed as ``[:, layer]``.
        attr_name: key into ``attr_dict['sefa']`` and stem of the ``.pt``
            direction file under ``latent_dir``.
        truncation: truncation factor; only applied when ``< 1``.
        truncation_latent: latent the code is pulled towards when truncating.
        latent_dir: directory holding ``<attr_name>.pt``.
        step, total: when ``total != 0`` the edit strength is scaled by
            ``step / total`` (``step=0`` therefore means "no edit").
        real: whether ``noise`` is already a per-layer latent.

    Returns:
        ``(latent, noise)`` where ``noise`` is the list of the generator's
        stored noise buffers.
    """
    if real:
        if truncation < 1:
            # Tile the truncation latent over the layer axis of ``noise``
            # (e.g. (1, 512) -> (1, 18, 512)).
            mean_latent = truncation_latent.repeat(noise.shape[1], 1).unsqueeze(0)
            latent = torch.add(
                mean_latent, torch.mul(torch.sub(noise, mean_latent), truncation)
            )
        else:
            # No truncation requested. Clone because the latent is edited in
            # place below and the caller's tensor must stay untouched.
            latent = noise.clone()
    else:
        style = G.style(noise)
        if truncation < 1:
            style = truncation_latent + truncation * (style - truncation_latent)
        latent = style.unsqueeze(1).repeat(1, G.n_latent, 1)

    noise = [getattr(G.noises, 'noise_{}'.format(i)) for i in range(G.num_layers)]

    layer, strength = attr_dict['sefa'][attr_name]

    sefa_vect = torch.load(
        os.path.join(latent_dir, '{}.pt'.format(attr_name))
    ).to(latent.device).float()

    if total != 0:
        # step == 0 must give zero strength, otherwise the first frame of an
        # interpolation would show the full edit.
        strength = step / total * strength

    for layer_idx in layer:
        latent[:, layer_idx, :] += (sefa_vect * strength * 2)

    return latent, noise
================================================
FILE: stylegan_human/edit.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
import os
import sys
import torch
import numpy as np
sys.path.append(".")
from torch_utils.models import Generator
import click
import cv2
from typing import List, Optional
import subprocess
import legacy
from edit.edit_helper import conv_warper, decoder, encoder_ifg, encoder_ss, encoder_sefa
"""
Edit generated images with different SOTA methods.
Notes:
1. We provide some latent directions in the folder, you can play around with them.
2. ''upper_length'' and ''bottom_length'' of ''attr_name'' are available for demo.
3. Layers to control and editing strength are set in edit/edit_config.py.
Examples:
\b
# Editing with InterfaceGAN, StyleSpace, and Sefa
python edit.py --network pretrained_models/stylegan_human_v2_1024.pkl --attr_name upper_length \\
--seeds 61531,61570,61571,61610 --outdir outputs/edit_results
# Editing using inverted latent code
python edit.py --network outputs/pti/checkpoints/model_test.pkl --attr_name upper_length \\
--outdir outputs/edit_results --real True --real_w_path outputs/pti/embeddings/test/PTI/test/0.pt --real_img_path aligned_image/test.png
"""
def _frames_to_video(frames_dir, video_path):
    """Encode ``frames_dir/%05d.jpg`` into ``video_path`` with ffmpeg at 30 fps."""
    # An argument list (no shell) keeps paths with spaces or shell
    # metacharacters intact.
    subprocess.call([
        "ffmpeg", "-hide_banner", "-loglevel", "error", "-y", "-r", "30",
        "-i", os.path.join(frames_dir, "%05d.jpg"),
        "-vcodec", "libx264", "-pix_fmt", "yuv420p", video_path,
    ])


@click.command()
@click.pass_context
@click.option('--network', 'ckpt_path', help='Network pickle filename', required=True)
@click.option('--attr_name', help='choose one of the attr: upper_length or bottom_length', type=str, required=True)
@click.option('--trunc', 'truncation', type=float, help='Truncation psi', default=0.8, show_default=True)
@click.option('--gen_video', type=bool, default=True, help='If want to generate video')
@click.option('--combine', type=bool, default=True, help='If want to combine different editing results in the same frame')
@click.option('--seeds', type=legacy.num_range, help='List of random seeds')
@click.option('--outdir', help='Where to save the output images', type=str, required=True, default='outputs/editing', metavar='DIR')
@click.option('--real', type=bool, help='True for editing real image', default=False)
@click.option('--real_w_path', help='Path of latent code for real image')
@click.option('--real_img_path', help='Path of real image, this just concat real image with inverted and edited results together')
def main(
    ctx: click.Context,
    ckpt_path: str,
    attr_name: str,
    truncation: float,
    gen_video: bool,
    combine: bool,
    seeds: Optional[List[int]],
    outdir: str,
    real: bool,
    real_w_path: str,
    real_img_path: str
):
    """Edit generated or inverted images with InterfaceGAN, StyleSpace and SeFa.

    For every seed (or for the single latent at --real_w_path when --real is
    True) the unedited output and the three edited results are concatenated
    side by side; all rows are stacked into <outdir>/final.jpg. With
    --gen_video, 90 interpolation frames are rendered and encoded with ffmpeg.
    """
    # Validate inputs before the expensive checkpoint conversion and loading.
    if real:
        if not real_w_path:
            ctx.fail('--real_w_path is required when --real is True.')
        seeds = [0]  # now assume process single real image only
    elif not seeds:
        ctx.fail('--seeds option is required unless --real is True.')

    # The reference image is only loaded (and concatenated) in real mode.
    use_real_image = bool(real and real_img_path)

    ## convert pkl to pth
    # if not os.path.exists(ckpt_path.replace('.pkl','.pth')):
    legacy.convert(ckpt_path, ckpt_path.replace('.pkl','.pth'), G_only=real)
    ckpt_path = ckpt_path.replace('.pkl','.pth')
    print("start...", flush=True)
    config = {"latent" : 512, "n_mlp" : 8, "channel_multiplier": 2}
    generator = Generator(
        size = 1024,
        style_dim=config["latent"],
        n_mlp=config["n_mlp"],
        channel_multiplier=config["channel_multiplier"]
    )

    generator.load_state_dict(torch.load(ckpt_path)['g_ema'])
    generator.eval().cuda()

    os.makedirs(outdir, exist_ok=True)

    def run_edits(test_input, step=0, total=0):
        """Return the (InterfaceGAN, StyleSpace, SeFa) edited images."""
        style_space, latent, noise = encoder_ifg(
            generator, test_input, attr_name, truncation, mean_latent,
            step=step, total=total, real=real)
        image_ifg = decoder(generator, style_space, latent, noise)
        style_space, latent, noise = encoder_ss(
            generator, test_input, attr_name, truncation, mean_latent,
            step=step, total=total, real=real)
        image_ss = decoder(generator, style_space, latent, noise)
        latent, noise = encoder_sefa(
            generator, test_input, attr_name, truncation, mean_latent,
            step=step, total=total, real=real)
        image_sefa, _ = generator([latent], noise=noise, input_is_latent=True)
        return image_ifg, image_ss, image_sefa

    with torch.no_grad():
        # The mean latent is cached on disk after the first run.
        mean_path = os.path.join('edit','mean_latent.pkl')
        if not os.path.exists(mean_path):
            mean_n = 3000
            mean_latent = generator.mean_latent(mean_n).detach()
            legacy.save_obj(mean_latent, mean_path)
        else:
            mean_latent = legacy.load_pkl(mean_path).cuda()
        finals = []

        ## -- selected sample seeds -- ##
        # seeds = [60948,60965,61174,61210,61511,61598,61610] #bottom -> long
        #         [60941,61064,61103,61313,61531,61570,61571] # bottom -> short
        #         [60941,60965,61064,61103,61174,61210,61531,61570,61571,61610] # upper --> long
        #         [60948,61313,61511,61598] # upper --> short

        for t in seeds:
            if real:
                if use_real_image:
                    real_image = cv2.imread(real_img_path)
                    real_image = cv2.cvtColor(real_image, cv2.COLOR_BGR2RGB)
                    import torchvision.transforms as transforms
                    transform = transforms.Compose( # normalize to (-1, 1)
                        [transforms.ToTensor(),
                         transforms.Normalize(mean=(.5,.5,.5), std=(.5,.5,.5))]
                    )
                    real_image = transform(real_image).unsqueeze(0).cuda()

                test_input = torch.load(real_w_path)
                output, _ = generator(test_input, False, truncation=1,input_is_latent=True, real=True)
            else: # generate image from random seeds
                test_input = torch.from_numpy(np.random.RandomState(t).randn(1, 512)).float().cuda()  # torch.Size([1, 512])
                output, _ = generator([test_input], False, truncation=truncation, truncation_latent=mean_latent, real=real)

            # interfacegan, stylespace, sefa
            image1, image2, image3 = run_edits(test_input)
            panels = (output, image1, image2, image3)
            if use_real_image:
                panels = (real_image,) + panels
            final = torch.cat(panels, 3)

            # legacy.visual(output, f'{outdir}/{attr_name}_{t:05d}_raw.jpg')
            # legacy.visual(image1, f'{outdir}/{attr_name}_{t:05d}_ifg.jpg')
            # legacy.visual(image2, f'{outdir}/{attr_name}_{t:05d}_ss.jpg')
            # legacy.visual(image3, f'{outdir}/{attr_name}_{t:05d}_sefa.jpg')

            if gen_video:
                total_step = 90
                # Real inputs are tagged by the folder holding the latent,
                # generated ones by their seed.
                if real:
                    tag = os.path.basename(os.path.dirname(real_w_path))
                else:
                    tag = f"{t:05d}"
                video_ifg_path = f"{outdir}/video/ifg_{attr_name}_{tag}/"
                video_ss_path = f"{outdir}/video/ss_{attr_name}_{tag}/"
                video_sefa_path = f"{outdir}/video/sefa_{attr_name}_{tag}/"
                video_comb_path = f"{outdir}/video/tmp"

                if combine:
                    os.makedirs(video_comb_path, exist_ok=True)
                else:
                    os.makedirs(video_ifg_path, exist_ok=True)
                    os.makedirs(video_ss_path, exist_ok=True)
                    os.makedirs(video_sefa_path, exist_ok=True)

                for i in range(total_step):
                    image1, image2, image3 = run_edits(test_input, step=i, total=total_step)
                    if combine:
                        frame_panels = (output, image1, image2, image3)
                        if use_real_image:
                            frame_panels = (real_image,) + frame_panels
                        comb_img = torch.cat(frame_panels, 3)
                        legacy.visual(comb_img, os.path.join(video_comb_path, f'{i:05d}.jpg'))
                    else:
                        legacy.visual(image1, os.path.join(video_ifg_path, f'{i:05d}.jpg'))
                        legacy.visual(image2, os.path.join(video_ss_path, f'{i:05d}.jpg'))
                        legacy.visual(image3, os.path.join(video_sefa_path, f'{i:05d}.jpg'))

                if combine:
                    _frames_to_video(video_comb_path, f"{outdir}/video/{attr_name}_{tag}.mp4")
                else:
                    _frames_to_video(video_ifg_path, video_ifg_path[:-1] + '.mp4')
                    _frames_to_video(video_ss_path, video_ss_path[:-1] + '.mp4')
                    _frames_to_video(video_sefa_path, video_sefa_path[:-1] + '.mp4')

            finals.append(final)

        final = torch.cat(finals, 2)
        legacy.visual(final, os.path.join(outdir,'final.jpg'))


if __name__ == "__main__":
    main()
================================================
FILE: stylegan_human/environment.yml
================================================
name: stylehuman
channels:
- pytorch
- nvidia
dependencies:
- python == 3.8
- pip
- numpy>=1.20
- click>=8.0
- pillow=8.3.1
- scipy=1.7.1
- pytorch=1.9.1
- cudatoolkit=11.1
- requests=2.26.0
- tqdm=4.62.2
- ninja=1.10.2
- matplotlib=3.4.2
- imageio=2.9.0
- pip:
- imgui==1.3.0
- glfw==2.2.0
- pyopengl==3.1.5
- imageio-ffmpeg==0.4.3
- lpips==0.1.4
- pyspng
- dlib
- opencv-python
- pandas
- moviepy
- imutils
================================================
FILE: stylegan_human/generate.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html
## this script is for generating images from pre-trained network based on StyleGAN1 (TensorFlow) and StyleGAN2-ada (PyTorch) ##
import os
import click
import dnnlib
import numpy as np
import PIL.Image
import legacy
from typing import List, Optional
"""
Generate images using pretrained network pickle.
Examples:
\b
# Generate human full-body images without truncation
python generate.py --outdir=outputs/generate/stylegan_human_v2_1024 --trunc=1 --seeds=1,3,5,7 \\
--network=pretrained_models/stylegan_human_v2_1024.pkl --version 2
\b
# Generate human full-body images with truncation
python generate.py --outdir=outputs/generate/stylegan_human_v2_1024 --trunc=0.8 --seeds=0-100\\
--network=pretrained_models/stylegan_human_v2_1024.pkl --version 2
# \b
# Generate human full-body images using stylegan V1
# python generate.py --outdir=outputs/generate/stylegan_human_v1_1024 \\
# --network=pretrained_models/stylegan_human_v1_1024.pkl --version 1
"""
@click.command()
@click.pass_context
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--seeds', type=legacy.num_range, help='List of random seeds')
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
@click.option('--outdir', help='Where to save the output images', default= 'outputs/generate/' , type=str, required=True, metavar='DIR')
@click.option('--version', help="stylegan version, 1, 2 or 3", type=int, default=2)
def generate_images(
    ctx: click.Context,
    network_pkl: str,
    seeds: Optional[List[int]],
    truncation_psi: float,
    noise_mode: str,
    outdir: str,
    version: int
):
    """Generate one image per seed from a pretrained network pickle.

    Each image is written to <outdir>/seed<NNNN>.png. --version 1 uses the
    TensorFlow code path (dnnlib.tflib); any other value uses PyTorch.
    """
    # Fail before the (slow) network load when there is nothing to generate.
    if seeds is None:
        ctx.fail('--seeds option is required.')

    print('Loading networks from "%s"...' % network_pkl)
    if version == 1:
        import dnnlib.tflib as tflib
        tflib.init_tf()
        G, D, Gs = legacy.load_pkl(network_pkl)
    else:
        import torch
        # ``torch.backends.mps`` does not exist on older PyTorch releases, so
        # look it up defensively instead of assuming the attribute is there.
        mps_backend = getattr(torch.backends, 'mps', None)
        if torch.cuda.is_available():
            device = torch.device('cuda')
        elif mps_backend is not None and mps_backend.is_available():
            device = torch.device('mps')
        else:
            device = torch.device('cpu')
        dtype = torch.float32 if device.type == 'mps' else torch.float64
        with dnnlib.util.open_url(network_pkl) as f:
            G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype) # type: ignore

    os.makedirs(outdir, exist_ok=True)

    # Generate images.
    for seed_idx, seed in enumerate(seeds):
        if seed % 5000 == 0:
            print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))

        if version == 1: ## stylegan v1
            z = np.random.RandomState(seed).randn(1, Gs.input_shape[1])
            # Generate image.
            fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
            randomize_noise = noise_mode != 'const'
            images = Gs.run(z, None, truncation_psi=truncation_psi, randomize_noise=randomize_noise, output_transform=fmt)
            PIL.Image.fromarray(images[0], 'RGB').save(f'{outdir}/seed{seed:04d}.png')
        else: ## stylegan v2/v3
            label = torch.zeros([1, G.c_dim], device=device)
            z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device, dtype=dtype)
            w = G.mapping(z, label,truncation_psi=truncation_psi)
            img = G.synthesis(w, noise_mode=noise_mode,force_fp32 = True)
            # Map the synthesis output to uint8 pixel values, channels last.
            img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')


#----------------------------------------------------------------------------

if __name__ == "__main__":
    generate_images()
#----------------------------------------------------------------------------
================================================
FILE: stylegan_human/insetgan.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
import torch
import torch.nn.functional as F
from tqdm import tqdm
from lpips import LPIPS
import numpy as np
from torch_utils.models import Generator as bodyGAN
from torch_utils.models_face import Generator as FaceGAN
import dlib
from utils.face_alignment import align_face_for_insetgan
from utils.util import visual,tensor_to_numpy, numpy_to_tensor
import legacy
import os
import click
class InsetGAN(torch.nn.Module):
    """Joint optimisation of a face generator and a full-body generator.

    Holds two frozen generators (body and face), dlib face detection used to
    locate the face crop inside a generated body, and the L1/LPIPS criteria
    used by the loss terms. ``dual_optimizer`` runs the actual optimisation.
    """

    def __init__(self, stylebody_ckpt, styleface_ckpt):
        """Load both generators, converting ``.pkl`` checkpoints to ``.pth`` if needed.

        Args:
            stylebody_ckpt: path of the body generator checkpoint.
            styleface_ckpt: path of the face generator checkpoint.
        """
        super().__init__()

        ## convert pkl to pth
        if not os.path.exists(stylebody_ckpt.replace('.pkl','.pth')):
            legacy.convert(stylebody_ckpt, stylebody_ckpt.replace('.pkl','.pth'))
        stylebody_ckpt = stylebody_ckpt.replace('.pkl','.pth')

        if not os.path.exists(styleface_ckpt.replace('.pkl','.pth')):
            legacy.convert(styleface_ckpt, styleface_ckpt.replace('.pkl','.pth'))
        styleface_ckpt = styleface_ckpt.replace('.pkl','.pth')

        # dual generator
        config = {"latent" : 512, "n_mlp" : 8, "channel_multiplier": 2}
        self.body_generator = bodyGAN(
            size = 1024,
            style_dim=config["latent"],
            n_mlp=config["n_mlp"],
            channel_multiplier=config["channel_multiplier"]
        )
        self.body_generator.load_state_dict(torch.load(stylebody_ckpt)['g_ema'])
        self.body_generator.eval().requires_grad_(False).cuda()

        self.face_generator = FaceGAN(
            size = 1024,
            style_dim=config["latent"],
            n_mlp=config["n_mlp"],
            channel_multiplier=config["channel_multiplier"]
        )
        self.face_generator.load_state_dict(torch.load(styleface_ckpt)['g_ema'])
        self.face_generator.eval().requires_grad_(False).cuda()

        # crop function
        self.dlib_predictor = dlib.shape_predictor('./pretrained_models/shape_predictor_68_face_landmarks.dat')
        self.dlib_cnn_face_detector = dlib.cnn_face_detection_model_v1("pretrained_models/mmod_human_face_detector.dat")

        # criterion
        self.lpips_loss = LPIPS(net='alex').cuda().eval()
        self.l1_loss = torch.nn.L1Loss(reduction='mean')

    def loss_coarse(self, A_face, B, p1=500, p2=0.05):
        """Weighted L1 + LPIPS between both inputs downsampled to 64x64."""
        A_face = F.interpolate(A_face, size=(64, 64), mode='area')
        B = F.interpolate(B, size=(64, 64), mode='area')
        loss_l1 = p1 * self.l1_loss(A_face, B)
        loss_lpips = p2 * self.lpips_loss(A_face, B)
        return loss_l1 + loss_lpips

    @staticmethod
    def get_border_mask(A, x, spec):
        """Return a mask shaped like ``A`` that is 1 on an ``x``-pixel frame, else 0.

        ``spec`` is accepted for interface compatibility and is not used.
        """
        mask = torch.zeros_like(A)
        mask[:, :, :x, ] = 1
        mask[:, :, -x:, ] = 1
        mask[:, :, :, :x ] = 1
        mask[:, :, :, -x:] = 1
        return mask

    @staticmethod
    def get_body_mask(A, crop, padding=4):
        """Return a mask shaped like ``A`` that is 0 inside the padded crop, else 1.

        ``crop`` is indexed as (x0, y0, x1, y1).
        """
        mask = torch.ones_like(A)
        mask[:, :, crop[1]-padding:crop[3]+padding, crop[0]-padding:crop[2]+padding] = 0
        return mask

    def loss_border(self, A_face, B, p1=10000, p2=2, spec=None):
        """Weighted L1 + LPIPS restricted to an 8-pixel border of the inputs."""
        mask = self.get_border_mask(A_face, 8, spec)
        loss_l1 = p1 * self.l1_loss(A_face*mask, B*mask)
        loss_lpips = p2 * self.lpips_loss(A_face*mask, B*mask)
        return loss_l1 + loss_lpips

    def loss_body(self, A, B, crop, p1=9000, p2=0.1):
        """Weighted L1 + LPIPS outside the face crop (padded by 1/20 of its height)."""
        padding = int((crop[3] - crop[1]) / 20)
        mask = self.get_body_mask(A, crop, padding)
        loss_l1 = p1 * self.l1_loss(A*mask, B*mask)
        loss_lpips = p2 * self.lpips_loss(A*mask, B*mask)
        return loss_l1+loss_lpips

    def loss_face(self, A, B, crop, p1=5000, p2=1.75):
        """Weighted L1 + LPIPS inside the (default-padded) crop region."""
        mask = 1 - self.get_body_mask(A, crop)
        loss_l1 = p1 * self.l1_loss(A*mask, B*mask)
        loss_lpips = p2 * self.lpips_loss(A*mask, B*mask)
        return loss_l1+loss_lpips

    def loss_reg(self, w, w_mean, p1, w_plus_delta=None, p2=None):
        """Regularise ``w`` towards ``w_mean`` and, optionally, the delta towards 0.

        The delta term is only added when both ``w_plus_delta`` and ``p2``
        are given, so the declared ``None`` defaults are usable.
        """
        loss = p1 * torch.mean(((w - w_mean) ** 2))
        if w_plus_delta is not None and p2 is not None:
            loss = loss + p2 * torch.mean(w_plus_delta ** 2)
        return loss

    # FFHQ type
    def detect_face_dlib(self, img):
        """Detect and align the face in ``img``.

        Returns:
            ``(aligned_image, crop, rect)`` as produced by
            ``align_face_for_insetgan``, with the aligned image converted
            back to a tensor.
        """
        # tensor to numpy array rgb uint8
        img = tensor_to_numpy(img)
        aligned_image, crop, rect = align_face_for_insetgan(img=img,
                                                            detector=self.dlib_cnn_face_detector,
                                                            predictor=self.dlib_predictor,
                                                            output_size=256)

        aligned_image = np.array(aligned_image)
        aligned_image = numpy_to_tensor(aligned_image)
        return aligned_image, crop, rect

    # joint optimization
    def dual_optimizer(self,
                       face_w,
                       body_w,
                       joint_steps=500,
                       face_initial_learning_rate=0.02,
                       body_initial_learning_rate=0.05,
                       lr_rampdown_length=0.25,
                       lr_rampup_length=0.05,
                       seed=None,
                       output_path=None,
                       video=0):
        '''
        Given a face_w, optimize a body_w with suitable body pose & shape for face_w

        Runs three stages: 25 face-only steps, 150 body-only steps, then
        ``joint_steps`` steps alternating between face and body every 50
        steps. When ``video`` is truthy a frame is saved per step under
        ``output_path`` (joined with ``seed`` when given).

        Returns:
            ``(face_ws, body_ws, body_crop)``: the optimised per-layer face
            and body latents and the last body crop used.
        '''
        def visual_(path, synth_body, synth_face, body_crop, step, both=False, init_body_with_face=None):
            # Paste the face into a copy of the body and save it as a frame.
            tmp = synth_body.clone().detach()
            tmp[:, :, body_crop[1]:body_crop[3], body_crop[0]:body_crop[2]] = synth_face
            if both:
                tmp = torch.cat([synth_body, tmp], dim=3)
            save_path = os.path.join(path, f"{step:04d}.jpg")
            visual(tmp, save_path)

        def forward(face_w_opt,
                    body_w_opt,
                    face_w_delta,
                    body_w_delta,
                    body_crop,
                    update_crop=False
                    ):
            # Synthesize both images from the current latents (+ deltas).
            if face_w_opt.shape[1] != 18:
                face_ws = (face_w_opt).repeat([1, 18, 1])
            else:
                face_ws = face_w_opt.clone()
            face_ws = face_ws + face_w_delta
            synth_face, _ = self.face_generator([face_ws], input_is_latent=True, randomize_noise=False)

            body_ws = (body_w_opt).repeat([1, 18, 1])
            body_ws = body_ws + body_w_delta
            synth_body, _ = self.body_generator([body_ws], input_is_latent=True, randomize_noise=False)

            if update_crop:
                # Re-detect the face but keep the previous crop size; only
                # the centre follows the new detection.
                old_r = (body_crop[3]-body_crop[1]) // 2, (body_crop[2]-body_crop[0]) // 2
                _, body_crop, _ = self.detect_face_dlib(synth_body)
                center = (body_crop[1] + body_crop[3]) // 2, (body_crop[0] + body_crop[2]) // 2
                body_crop = (center[1] - old_r[1], center[0] - old_r[0], center[1] + old_r[1], center[0] + old_r[0])

            synth_body_face = synth_body[:, :, body_crop[1]:body_crop[3], body_crop[0]:body_crop[2]]

            # The losses compare the face against the body crop, so the face
            # must always be brought to the crop size (not only when it is
            # taller than the crop).
            crop_hw = (body_crop[3]-body_crop[1], body_crop[2]-body_crop[0])
            if tuple(synth_face.shape[2:]) != crop_hw:
                synth_face_resize = F.interpolate(synth_face, size=crop_hw, mode='area')
            else:
                synth_face_resize = synth_face

            return synth_body, synth_body_face, synth_face, synth_face_resize, body_crop

        def update_lr(init_lr, step, num_steps, lr_rampdown_length, lr_rampup_length):
            # Cosine ramp-down at the end, linear ramp-up at the start.
            t = step / num_steps
            lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
            lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
            lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
            lr = init_lr * lr_ramp
            return lr

        # update output_path
        if output_path is not None:
            if seed is not None:
                output_path = os.path.join(output_path, seed)
            os.makedirs(output_path, exist_ok=True)
        elif video:
            raise ValueError("output_path is required when video is enabled")

        # define optimized params
        body_w_mean = self.body_generator.mean_latent(10000).detach()
        face_w_opt = face_w.clone().detach().requires_grad_(True)
        body_w_opt = body_w.clone().detach().requires_grad_(True)
        face_w_delta = torch.zeros_like(face_w.repeat([1, 18, 1])).requires_grad_(True)
        body_w_delta = torch.zeros_like(body_w.repeat([1, 18, 1])).requires_grad_(True)
        # generate ref face & body
        ref_body, _ = self.body_generator([body_w.repeat([1, 18, 1])], input_is_latent=True, randomize_noise=False)
        # for inversion
        ref_face, _ = self.face_generator([face_w.repeat([1, 18, 1])], input_is_latent=True, randomize_noise=False)
        # get initialized crop
        _, body_crop, _ = self.detect_face_dlib(ref_body)
        _, _, face_crop = self.detect_face_dlib(ref_face) # NOTE: this is face rect only. no FFHQ type.
        # create optimizer
        face_optimizer = torch.optim.Adam([face_w_opt, face_w_delta], betas=(0.9, 0.999), lr=face_initial_learning_rate)
        body_optimizer = torch.optim.Adam([body_w_opt, body_w_delta], betas=(0.9, 0.999), lr=body_initial_learning_rate)

        global_step = 0
        # Stage1: remove background of face image
        face_steps = 25
        pbar = tqdm(range(face_steps))
        for step in pbar:
            face_lr = update_lr(face_initial_learning_rate / 2, step, face_steps, lr_rampdown_length, lr_rampup_length)
            for param_group in face_optimizer.param_groups:
                param_group['lr'] =face_lr
            synth_body, synth_body_face, synth_face_raw, synth_face, body_crop = forward(face_w_opt,
                                                                                         body_w_opt,
                                                                                         face_w_delta,
                                                                                         body_w_delta,
                                                                                         body_crop)
            loss_face = self.loss_face(synth_face_raw, ref_face, face_crop, 5000, 1.75)
            loss_coarse = self.loss_coarse(synth_face, synth_body_face, 50, 0.05)
            loss_border = self.loss_border(synth_face, synth_body_face, 1000, 0.1)
            loss = loss_coarse + loss_border + loss_face
            face_optimizer.zero_grad()
            loss.backward()
            face_optimizer.step()
            # visualization
            if video:
                visual_(output_path, synth_body, synth_face, body_crop, global_step)
            pbar.set_description(
                (
                    f"face: {step:.4f}, lr: {face_lr}, loss: {loss.item():.2f}, loss_coarse: {loss_coarse.item():.2f};"
                    f"loss_border: {loss_border.item():.2f}, loss_face: {loss_face.item():.2f};"
                )
            )
            global_step += 1

        # Stage2: find a suitable body
        body_steps = 150
        pbar = tqdm(range(body_steps))
        for step in pbar:
            body_lr = update_lr(body_initial_learning_rate, step, body_steps, lr_rampdown_length, lr_rampup_length)
            # Re-detect the face position every 50 steps.
            update_crop = step % 50 == 0
            # update_crop = False
            for param_group in body_optimizer.param_groups:
                param_group['lr'] =body_lr
            synth_body, synth_body_face, synth_face_raw, synth_face, body_crop = forward(face_w_opt,
                                                                                         body_w_opt,
                                                                                         face_w_delta,
                                                                                         body_w_delta,
                                                                                         body_crop,
                                                                                         update_crop=update_crop)
            loss_coarse = self.loss_coarse(synth_face, synth_body_face, 500, 0.05)
            loss_border = self.loss_border(synth_face, synth_body_face, 2500, 0)
            loss_body = self.loss_body(synth_body, ref_body, body_crop, 9000, 0.1)
            loss_reg = self.loss_reg(body_w_opt, body_w_mean, 15000, body_w_delta, 0)
            loss = loss_coarse + loss_border + loss_body + loss_reg
            body_optimizer.zero_grad()
            loss.backward()
            body_optimizer.step()

            # visualization
            if video:
                visual_(output_path, synth_body, synth_face, body_crop, global_step)
            pbar.set_description(
                (
                    f"body: {step:.4f}, lr: {body_lr}, loss: {loss.item():.2f}, loss_coarse: {loss_coarse.item():.2f};"
                    f"loss_border: {loss_border.item():.2f}, loss_body: {loss_body.item():.2f}, loss_reg: {loss_reg:.2f}"
                )
            )
            global_step += 1

        # Stage3: joint optimization
        interval = 50
        # At least 1 so the learning-rate schedule never divides by zero.
        joint_face_steps = max(1, joint_steps // 2)
        joint_body_steps = max(1, joint_steps // 2)
        face_step = 0
        body_step = 0
        pbar = tqdm(range(joint_steps))
        flag = -1
        for step in pbar:
            # Switch between face and body optimisation every ``interval`` steps.
            if step % interval == 0: flag += 1
            text_flag = 'optimize_face' if flag % 2 == 0 else 'optimize_body'
            synth_body, synth_body_face, synth_face_raw, synth_face, body_crop = forward(face_w_opt,
                                                                                         body_w_opt,
                                                                                         face_w_delta,
                                                                                         body_w_delta,
                                                                                         body_crop)
            if text_flag == 'optimize_face':
                face_lr = update_lr(face_initial_learning_rate, face_step, joint_face_steps, lr_rampdown_length, lr_rampup_length)
                for param_group in face_optimizer.param_groups:
                    param_group['lr'] =face_lr
                loss_face = self.loss_face(synth_face_raw, ref_face, face_crop, 5000, 1.75)
                loss_coarse = self.loss_coarse(synth_face, synth_body_face, 500, 0.05)
                loss_border = self.loss_border(synth_face, synth_body_face, 25000, 0)
                loss = loss_coarse + loss_border + loss_face
                face_optimizer.zero_grad()
                loss.backward()
                face_optimizer.step()
                pbar.set_description(
                    (
                        f"face: {step}, lr: {face_lr:.4f}, loss: {loss.item():.2f}, loss_coarse: {loss_coarse.item():.2f};"
                        f"loss_border: {loss_border.item():.2f}, loss_face: {loss_face.item():.2f};"
                    )
                )
                face_step += 1
            else:
                body_lr = update_lr(body_initial_learning_rate, body_step, joint_body_steps, lr_rampdown_length, lr_rampup_length)
                for param_group in body_optimizer.param_groups:
                    param_group['lr'] =body_lr
                loss_coarse = self.loss_coarse(synth_face, synth_body_face, 500, 0.05)
                loss_border = self.loss_border(synth_face, synth_body_face, 2500, 0)
                loss_body = self.loss_body(synth_body, ref_body, body_crop, 9000, 0.1)
                loss_reg = self.loss_reg(body_w_opt, body_w_mean, 25000, body_w_delta, 0)
                loss = loss_coarse + loss_border + loss_body + loss_reg
                body_optimizer.zero_grad()
                loss.backward()
                body_optimizer.step()
                pbar.set_description(
                    (
                        f"body: {step}, lr: {body_lr:.4f}, loss: {loss.item():.2f}, loss_coarse: {loss_coarse.item():.2f};"
                        f"loss_border: {loss_border.item():.2f}, loss_body: {loss_body.item():.2f}, loss_reg: {loss_reg:.2f}"
                    )
                )
                body_step += 1
            if video:
                visual_(output_path, synth_body, synth_face, body_crop, global_step)
            global_step += 1
        return face_w_opt.repeat([1, 18, 1])+face_w_delta, body_w_opt.repeat([1, 18, 1])+body_w_delta, body_crop
"""
Jointly combine and optimize generated faces and bodies.
Examples:
\b
# Combine the generated human full-body image from the provided StyleGAN-Human pre-trained model
# with the generated face image from the FFHQ model, and optimize both latent codes to produce a coherent face-body image
python insetgan.py --body_network=pretrained_models/stylegan_human_v2_1024.pkl --face_network=pretrained_models/ffhq.pkl \\
--body_seed=82 --face_seed=43 --trunc=0.6 --outdir=outputs/insetgan/ --video 1
"""
@click.command()
@click.pass_context
@click.option('--face_network', default="./pretrained_models/ffhq.pkl", help='Network pickle filename', required=True)
@click.option('--body_network', default='./pretrained_models/stylegan2_1024.pkl', help='Network pickle filename', required=True)
@click.option('--face_seed', type=int, default=82, help='selected random seed')
@click.option('--body_seed', type=int, default=43, help='selected random seed')
@click.option('--joint_steps', type=int, default=500, help='num steps for joint optimization')
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=0.6, show_default=True)
@click.option('--outdir', help='Where to save the output images', default= "outputs/insetgan/" , type=str, required=True, metavar='DIR')
@click.option('--video', help="set to 1 if want to save video", type=int, default=0)
def main(
    ctx: click.Context,
    face_network: str,
    body_network: str,
    face_seed: int,
    body_seed: int,
    joint_steps: int,
    truncation_psi: float,
    outdir: str,
    video: int):
    """Generate a face and a body, jointly optimise them and save the result.

    Writes <outdir>/<face_seed>_<body_seed>.png showing the naive paste next
    to the optimised composite; with --video 1 the per-step frames are also
    encoded to <outdir>/<face_seed>_<body_seed>.mp4 with ffmpeg.
    """
    # Local import: this module does not import subprocess at file level.
    import subprocess

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    insgan = InsetGAN(body_network, face_network)
    os.makedirs(outdir, exist_ok=True)
    run_name = f'{face_seed:04d}_{body_seed:04d}'

    # Sample and truncate the face latent.
    face_z = np.random.RandomState(face_seed).randn(1, 512).astype(np.float32)
    face_mean = insgan.face_generator.mean_latent(3000)
    face_w = insgan.face_generator.get_latent(torch.from_numpy(face_z).to(device)) # [N, L, C]
    face_w = truncation_psi * face_w + (1-truncation_psi) * face_mean
    face_img, _ = insgan.face_generator([face_w], input_is_latent=True)

    # Sample and truncate the body latent.
    body_z = np.random.RandomState(body_seed).randn(1, 512).astype(np.float32)
    body_mean = insgan.body_generator.mean_latent(3000)
    body_w = insgan.body_generator.get_latent(torch.from_numpy(body_z).to(device)) # [N, L, C]
    body_w = truncation_psi * body_w + (1-truncation_psi) * body_mean
    body_img, _ = insgan.body_generator([body_w], input_is_latent=True)

    # Baseline: paste the resized face straight into the body's face crop.
    _, body_crop, _ = insgan.detect_face_dlib(body_img)
    face_img = F.interpolate(face_img, size=(body_crop[3]-body_crop[1], body_crop[2]-body_crop[0]), mode='area')
    cp_body = body_img.clone()
    cp_body[:, :, body_crop[1]:body_crop[3], body_crop[0]:body_crop[2]] = face_img

    optim_face_w, optim_body_w, crop = insgan.dual_optimizer(
        face_w,
        body_w,
        joint_steps=joint_steps,
        seed=run_name,
        output_path=outdir,
        video=video
    )

    if video:
        # Paths are joined onto ``outdir`` directly so that absolute output
        # directories work, and ffmpeg is called without a shell.
        subprocess.call([
            "ffmpeg", "-hide_banner", "-loglevel", "error",
            "-i", os.path.join(outdir, run_name, "%04d.jpg"),
            "-c:v", "libx264", "-vf", "fps=30", "-pix_fmt", "yuv420p",
            os.path.join(outdir, f"{run_name}.mp4"),
        ])

    # Render the optimised pair and paste the face into the final crop.
    new_face_img, _ = insgan.face_generator([optim_face_w], input_is_latent=True)
    new_shape = crop[3] - crop[1], crop[2] - crop[0]
    new_face_img_crop = F.interpolate(new_face_img, size=new_shape, mode='area')
    seamless_body, _ = insgan.body_generator([optim_body_w], input_is_latent=True)
    seamless_body[:, :, crop[1]:crop[3], crop[0]:crop[2]] = new_face_img_crop
    temp = torch.cat([cp_body, seamless_body], dim=3)
    visual(temp, os.path.join(outdir, f"{run_name}.png"))


if __name__ == "__main__":
    main()
================================================
FILE: stylegan_human/interpolation.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
## interpolate between two z code
## score all middle latent code
# https://www.aiuai.cn/aifarm1929.html
import os
import re
from typing import List
from tqdm import tqdm
import click
import dnnlib
import numpy as np
import PIL.Image
import torch
import click
import legacy
import random
from typing import List, Optional
def lerp(code1, code2, alpha):
    """Blend two latent codes linearly.

    ``alpha`` is the weight applied to ``code1`` and ``1 - alpha`` the weight
    applied to ``code2``. Consequently ``alpha=0`` returns ``code2`` and
    ``alpha=1`` returns ``code1``.
    """
    weight_first = alpha
    weight_second = 1 - alpha
    return code1 * weight_first + code2 * weight_second
# Taken and adapted from wikipedia's slerp article
# https://en.wikipedia.org/wiki/Slerp
def slerp(code1, code2, alpha, DOT_THRESHOLD=0.9995):  # Spherical linear interpolation
    """Spherically interpolate from ``code1`` (alpha=0) to ``code2`` (alpha=1).

    The angle between the two codes is measured on their normalized copies,
    while the returned blend is built from the original, unnormalized inputs.

    Args:
        code1: Start vector (anything NumPy can convert to an array).
        code2: End vector with the same shape as ``code1``.
        alpha: Interpolation position; 0 returns ``code1``, 1 returns ``code2``.
        DOT_THRESHOLD: When the absolute cosine between the codes exceeds this
            value the vectors are treated as colinear and a plain linear blend
            is used, which avoids dividing by a vanishing ``sin(theta)``.

    Returns:
        The interpolated code as a NumPy array.
    """
    start = np.asarray(code1)
    end = np.asarray(code2)

    # The angle is computed on unit vectors; the inputs themselves are kept
    # untouched for the final weighted sum.
    unit_start = start / np.linalg.norm(start)
    unit_end = end / np.linalg.norm(end)
    dot = np.sum(unit_start * unit_end)

    if np.abs(dot) > DOT_THRESHOLD:
        # Nearly colinear: fall back to linear interpolation. The weights are
        # spelled out here so that this branch runs in the same direction as
        # the spherical branch below (alpha=0 -> code1, alpha=1 -> code2).
        return (1 - alpha) * start + alpha * end

    # Angle between the two codes and the angle reached at position alpha.
    theta_0 = np.arccos(dot)
    sin_theta_0 = np.sin(theta_0)
    theta_t = theta_0 * alpha

    s0 = np.sin(theta_0 - theta_t) / sin_theta_0
    s1 = np.sin(theta_t) / sin_theta_0
    return s0 * start + s1 * end
def generate_image_from_z(G, z, noise_mode, truncation_psi, device):
    """Synthesize a single RGB PIL image from latent ``z`` with generator ``G``.

    ``z`` is mapped through ``G.mapping`` with an all-zero label of width
    ``G.c_dim`` and the given truncation, then rendered by ``G.synthesis``
    with ``force_fp32=True``. Only the first image of the batch is returned.
    """
    class_label = torch.zeros([1, G.c_dim], device=device)
    latent_w = G.mapping(z, class_label, truncation_psi=truncation_psi)
    raw_output = G.synthesis(latent_w, noise_mode=noise_mode, force_fp32=True)

    # Move axis 1 to the end (presumably NCHW -> NHWC; confirm against G) and
    # rescale to the 0..255 byte range expected by PIL.
    channels_last = raw_output.permute(0, 2, 3, 1)
    pixels = (channels_last * 127.5 + 128).clamp(0, 255).to(torch.uint8)

    first_image = pixels[0].cpu().numpy()
    return PIL.Image.fromarray(first_image, 'RGB')
def get_concat_h(im1, im2):
    """Return a new RGB image with ``im1`` on the left and ``im2`` on the right.

    The canvas is as wide as both images together and takes its height from
    ``im1`` only.
    """
    left_width = im1.width
    canvas_size = (left_width + im2.width, im1.height)
    canvas = PIL.Image.new('RGB', canvas_size)
    # Paste each picture at its horizontal offset along the top edge.
    for picture, x_offset in ((im1, 0), (im2, left_width)):
        canvas.paste(picture, (x_offset, 0))
    return canvas
def make_latent_interp_animation(G, code1, code2, img1, img2, num_interps, noise_mode, save_mid_image, truncation_psi,device, outdir,fps):
    """Render a GIF that traverses the latent space between two codes.

    For ``num_interps`` evenly spaced ``alpha`` values in ``[0, 1)`` the codes
    are blended with ``lerp`` (so the traversal starts at ``code2`` and moves
    towards ``code1`` without reaching it), an image is generated and resized
    to 512x1024, and a frame ``img2 | interpolated | img1`` is assembled.

    Args:
        G: Generator passed through to ``generate_image_from_z``.
        code1, code2: Latent codes to interpolate between.
        img1, img2: Images shown on the right and left edge of every frame.
        num_interps: Number of frames; must be at least 1.
        noise_mode: Forwarded to ``generate_image_from_z``.
        save_mid_image: When true, every interpolated image is also written to
            ``<outdir>/img/seedNNNN.png``.
        truncation_psi: Forwarded to ``generate_image_from_z``.
        device: Forwarded to ``generate_image_from_z``.
        outdir: Directory receiving ``latent_space_traversal.gif``.
        fps: Frames per second of the GIF.

    Raises:
        ValueError: If ``num_interps`` is smaller than 1.
    """
    if num_interps < 1:
        raise ValueError(f'num_interps must be a positive integer, got {num_interps}')

    # Loop-invariant: create the directory for intermediate images once.
    os.makedirs(os.path.join(outdir, 'img'), exist_ok=True)

    # linspace guarantees exactly num_interps samples; arange with a float
    # step may produce one extra element because of rounding.
    amounts = np.linspace(0.0, 1.0, num=num_interps, endpoint=False)

    all_imgs = []
    for seed_idx, alpha in enumerate(tqdm(amounts)):
        interpolated_latent_code = lerp(code1, code2, alpha)
        image = generate_image_from_z(G, interpolated_latent_code, noise_mode, truncation_psi, device)
        interp_latent_image = image.resize((512, 1024))
        if save_mid_image:
            interp_latent_image.save(f'{outdir}/img/seed{seed_idx:04d}.png')
        # Frame layout: img2 on the left, the interpolation in the middle,
        # img1 on the right.
        frame = get_concat_h(img2, interp_latent_image)
        frame = get_concat_h(frame, img1)
        all_imgs.append(frame)

    save_name = os.path.join(outdir,'latent_space_traversal.gif')
    # PIL expects the per-frame duration in milliseconds.
    all_imgs[0].save(save_name, save_all=True, append_images=all_imgs[1:], duration=1000/fps, loop=0)
@click.command()
@click.pass_context
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--seeds', type=legacy.num_range, help='List of 2 random seeds, e.g. 1,2')
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=0.8, show_default=True)
@click.option('--noise-mode', 'noise_mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
@click.option('--outdir', default= 'outputs/inter_gifs', help='Where to save the output images', type=str, required=True, metavar='DIR')
@click.option('--save_mid_image', default=True, type=bool, help='select True if you want to save all interpolated images')
@click.option('--fps', default= 15, help='FPS for GIF', type=int)
@click.option('--num_interps', default= 100, help='Number of interpolation images', type=int)
def main(
    ctx: click.Context,
    network_pkl: str,
    seeds: Optional[List[int]],
    truncation_psi: float,
    noise_mode: str,
    outdir: str,
    save_mid_image: bool,
    fps:int,
    num_interps:int
):
    """
    Create interpolated images between two given seeds using pretrained network pickle.

    Examples:

    \b
    python interpolation.py --network=pretrained_models/stylegan_human_v2_1024.pkl --seeds=85,100 --outdir=outputs/inter_gifs
    """
    # NOTE: the docstring above is the text Click prints for --help, so the
    # implementation notes live in comments instead.

    # --seeds is optional for Click, so it may arrive as None; report a usage
    # error up front instead of crashing on len(None) after loading the model.
    if not seeds:
        ctx.fail('--seeds is required, e.g. --seeds=85,100')

    # Prefer CUDA, then MPS, then CPU. float32 is selected on MPS and float64
    # everywhere else (presumably because MPS lacks float64 support).
    device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
    dtype = torch.float32 if device.type == 'mps' else torch.float64
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype) # type: ignore

    # Creates outdir as well as its 'img' subdirectory; a no-op if both exist.
    os.makedirs(os.path.join(outdir, 'img'), exist_ok=True)

    # Exactly two seeds are needed: truncate a longer list, and complete a
    # single seed with a randomly drawn second one.
    if len(seeds) > 2:
        print("Receiving more than two seeds, only use the first two.")
        seeds = seeds[0:2]
    elif len(seeds) == 1:
        print('Require two seeds, randomly generate two now.')
        seeds = [seeds[0],random.randint(0,10000)]

    # Draw one z vector per seed and render the two endpoint images.
    z1 = torch.from_numpy(np.random.RandomState(seeds[0]).randn(1, G.z_dim)).to(device, dtype=dtype)
    z2 = torch.from_numpy(np.random.RandomState(seeds[1]).randn(1, G.z_dim)).to(device, dtype=dtype)
    img1 = generate_image_from_z(G, z1, noise_mode, truncation_psi, device)
    img2 = generate_image_from_z(G, z2, noise_mode, truncation_psi, device)
    img1.save(f'{outdir}/seed{seeds[0]:04d}.png')
    img2.save(f'{outdir}/seed{seeds[1]:04d}.png')

    make_latent_interp_animation(G, z1, z2, img1, img2, num_interps, noise_mode, save_mid_image, truncation_psi, device, outdir, fps)
# Run the Click command only when this file is executed as a script; importing
# the module must not start an interpolation run.
if __name__ == "__main__":
    main()
================================================
FILE: stylegan_human/latent_direction/ss_statics/bottom_length_statis/3/statis.csv
================================================
,channel,strength
401,401,0.0051189866
79,79,0.004417926
499,499,0.0042351373
272,272,0.0033855115
2,2,0.003143758
267,267,0.0025972966
510,510,0.0025229468
130,130,0.0022487796
228,228,0.0021741684
101,101,0.001948409
418,418,0.0018696061
481,481,0.0017976156
88,88,0.0017784507
58,58,0.0017771542
116,116,0.0017733901
282,282,0.0017370607
207,207,0.0017357969
429,429,0.0016839057
284,284,0.0016397897
139,139,0.0016203154
242,242,0.0016087457
319,319,0.0015855831
237,237,0.0015757639
475,475,0.001572773
427,427,0.001526569
276,276,0.0015258243
163,163,0.0014731362
460,460,0.0014698224
268,268,0.0014479166
87,87,0.0013900393
486,486,0.0013423131
367,367,0.0013412643
129,129,0.0013402662
448,448,0.0013169902
438,438,0.0012944449
463,463,0.001292959
109,109,0.0012898487
197,197,0.0012845552
215,215,0.0012809597
419,419,0.0012448858
170,170,0.0012249282
46,46,0.0012235934
191,191,0.0012160796
6,6,0.0012132789
292,292,0.0012097002
174,174,0.0011935516
198,198,0.0011886621
450,450,0.0011825112
334,334,0.0011808838
134,134,0.0011740165
297,297,0.0011682409
388,388,0.0011680947
4,4,0.0011665267
96,96,0.001155285
144,144,0.0011436475
383,383,0.0011424082
472,472,0.0011330546
200,200,0.0011163615
126,126,0.0011128024
7,7,0.0011117855
149,149,0.0011032915
142,142,0.0010992303
108,108,0.0010948912
55,55,0.0010793228
35,35,0.0010671945
156,156,0.0010658612
75,75,0.0010604868
497,497,0.0010573129
333,333,0.0010556887
346,346,0.0010499252
259,259,0.0010397168
33,33,0.0010339752
196,196,0.0010080026
321,321,0.0010073334
169,169,0.001006006
187,187,0.0010003602
421,421,0.0009871207
347,347,0.0009822787
495,495,0.0009788647
235,235,0.00097607024
313,313,0.000972718
316,316,0.00096160895
32,32,0.0009501351
365,365,0.0009465426
50,50,0.0009324631
309,309,0.0009274587
461,461,0.0009274281
439,439,0.0009251979
140,140,0.00091394293
220,220,0.0009082752
482,482,0.0009080755
430,430,0.0009043302
218,218,0.0009004896
143,143,0.00089990336
99,99,0.0008916955
70,70,0.00089066115
168,168,0.0008892674
209,209,0.00088266545
391,391,0.0008787587
137,137,0.00087609346
369,369,0.00087306765
355,355,0.0008672569
354,354,0.0008614661
352,352,0.0008591456
359,359,0.00085570634
258,258,0.00084690325
385,385,0.0008433214
296,296,0.0008431721
153,153,0.0008382941
17,17,0.0008377292
186,186,0.00083413016
162,162,0.00083256833
473,473,0.00082412886
47,47,0.0008205253
89,89,0.0008157718
283,283,0.00080831826
351,351,0.0008064204
124,124,0.00080594653
457,457,0.0008049854
188,188,0.00078665616
154,154,0.00078267895
120,120,0.0007801328
190,190,0.0007771554
85,85,0.0007743246
22,22,0.00076702645
266,266,0.00076506665
227,227,0.0007623983
21,21,0.0007611779
295,295,0.00075573527
476,476,0.00074219145
115,115,0.0007414153
363,363,0.000735065
44,44,0.00073410827
83,83,0.0007331246
118,118,0.0007294682
511,511,0.00072840624
322,322,0.00072799233
483,483,0.0007264888
219,219,0.0007200153
274,274,0.0007194695
455,455,0.0007167979
509,509,0.0007164892
412,412,0.00071247795
239,239,0.00071194395
3,3,0.00070765684
420,420,0.0007068611
53,53,0.0007004219
173,173,0.00069374195
480,480,0.0006935094
189,189,0.0006923614
80,80,0.0006903429
141,141,0.00068950764
208,208,0.00068684825
474,474,0.00068534614
386,386,0.00068098377
107,107,0.00068088406
504,504,0.0006799584
328,328,0.00067885965
307,307,0.00067760365
64,64,0.0006768076
362,362,0.00067168014
86,86,0.000671517
279,279,0.0006705279
361,361,0.00066772755
175,175,0.0006668292
16,16,0.0006653719
345,345,0.00066359126
372,372,0.00066246424
380,380,0.0006578692
330,330,0.00065479317
470,470,0.00065403903
43,43,0.00065164984
205,205,0.0006473879
294,294,0.00063388806
357,357,0.00063241523
36,36,0.00063216814
68,68,0.00063114305
57,57,0.0006299799
213,213,0.00062879996
210,210,0.0006244437
49,49,0.00062352256
241,241,0.00062221487
487,487,0.0006212713
82,82,0.00062058726
466,466,0.0006202627
395,395,0.0006191936
72,72,0.0006174072
158,158,0.0006166083
437,437,0.00061549625
113,113,0.0006142644
277,277,0.0006109689
157,157,0.00060956663
291,291,0.0006088704
370,370,0.0006068144
104,104,0.00060625107
41,41,0.0006062162
94,94,0.00060481543
493,493,0.00060343795
247,247,0.00060251716
338,338,0.00060242455
204,204,0.00060240383
424,424,0.0005989864
344,344,0.0005951116
360,360,0.0005884679
151,151,0.0005868215
264,264,0.0005865409
293,293,0.00058519834
62,62,0.0005819824
300,300,0.00058131246
238,238,0.00057876547
407,407,0.0005770471
342,342,0.0005726651
5,5,0.00057154294
114,114,0.00057109277
240,240,0.00057103945
452,452,0.0005675462
91,91,0.0005672489
413,413,0.0005642817
119,119,0.00056401774
458,458,0.0005635342
180,180,0.0005617487
10,10,0.0005590609
181,181,0.00055804825
479,479,0.0005575437
29,29,0.00055724994
9,9,0.00055689306
102,102,0.0005544259
399,399,0.00055424945
97,97,0.0005539326
172,172,0.00054995815
31,31,0.0005499472
364,364,0.0005491284
492,492,0.00054688356
164,164,0.00054382335
371,371,0.0005417347
275,275,0.00053873524
308,308,0.00053855526
501,501,0.00053753663
92,92,0.00053742283
506,506,0.0005367725
167,167,0.00053528306
305,305,0.00053263427
485,485,0.00053174875
318,318,0.0005315904
177,177,0.00053113786
166,166,0.0005307643
193,193,0.0005297839
469,469,0.0005261937
25,25,0.000521339
48,48,0.0005209389
128,128,0.00052093086
498,498,0.00052030955
405,405,0.0005189927
201,201,0.00051636016
229,229,0.00051383715
24,24,0.00051297864
123,123,0.0005124072
477,477,0.0005120602
402,402,0.0005115426
377,377,0.00051048637
348,348,0.0005102354
23,23,0.00050849793
451,451,0.00050814415
406,406,0.0005045002
27,27,0.0004999815
350,350,0.0004998393
185,185,0.0004971214
390,390,0.00049634936
375,375,0.0004955703
431,431,0.00049411313
105,105,0.00049172394
411,411,0.0004917152
148,148,0.00049001497
250,250,0.0004884555
392,392,0.00048794085
374,374,0.00048640848
252,252,0.00048480998
269,269,0.00048419714
192,192,0.00048391512
217,217,0.00048157398
263,263,0.00048102875
415,415,0.00047999277
212,212,0.0004762149
417,417,0.00047523607
467,467,0.0004741602
340,340,0.00047381772
397,397,0.00047334703
433,433,0.00047333006
378,378,0.00047203893
206,206,0.0004719441
443,443,0.00047179937
484,484,0.00047088487
434,434,0.0004697333
396,396,0.00046903393
13,13,0.00046736852
379,379,0.00046703266
178,178,0.0004656918
202,202,0.0004656007
341,341,0.00046170983
456,456,0.0004603486
462,462,0.0004575785
67,67,0.00045551857
138,138,0.00045521912
459,459,0.00045479517
358,358,0.000450105
77,77,0.00044913465
146,146,0.00044637956
66,66,0.0004448067
98,98,0.00044425039
442,442,0.00044048973
0,0,0.00044048866
216,216,0.0004404604
18,18,0.0004400146
54,54,0.00043942602
20,20,0.00043839475
508,508,0.0004379366
285,285,0.0004373548
195,195,0.00043511056
155,155,0.00043351707
444,444,0.0004311831
257,257,0.00043021288
287,287,0.00042994966
449,449,0.0004278185
280,280,0.00042747098
255,255,0.000425165
56,56,0.00042424107
404,404,0.0004226035
488,488,0.00042232242
356,356,0.00042173947
244,244,0.0004209792
432,432,0.0004125784
214,214,0.0004114269
393,393,0.0004107277
270,270,0.00041058182
111,111,0.0004104286
324,324,0.00040866464
61,61,0.00040655467
366,366,0.00040608697
147,147,0.00040604445
311,311,0.00040550664
500,500,0.00040497814
211,211,0.00040463882
112,112,0.00040117715
100,100,0.00040099313
234,234,0.00040040378
132,132,0.00039979443
478,478,0.0003981258
221,221,0.0003965061
368,368,0.0003960585
336,336,0.00039551125
339,339,0.00039536165
19,19,0.0003951851
71,71,0.00039469448
490,490,0.00039201186
253,253,0.0003899332
332,332,0.00038865488
447,447,0.00038850866
223,223,0.0003865917
12,12,0.00038631624
256,256,0.00038542095
303,303,0.00038529842
335,335,0.0003850341
125,125,0.00038496146
52,52,0.00038445802
465,465,0.00037988674
14,14,0.00037270828
445,445,0.0003714122
51,51,0.0003710708
183,183,0.00036938934
435,435,0.00036892455
76,76,0.0003672379
203,203,0.0003666505
74,74,0.00036636737
464,464,0.00036338985
28,28,0.00036282517
376,376,0.00036232817
389,389,0.00036217785
394,394,0.00036191306
30,30,0.00036106672
327,327,0.000358803
73,73,0.0003566918
343,343,0.00035621642
384,384,0.0003560467
440,440,0.00035523146
251,251,0.00035423675
260,260,0.00035293537
265,265,0.00035225553
387,387,0.0003514995
298,298,0.00034635572
306,306,0.00034440306
110,110,0.00034401406
254,254,0.0003433519
505,505,0.00034252735
60,60,0.00034171442
302,302,0.000340328
171,171,0.00033906853
38,38,0.0003379222
59,59,0.00033599927
353,353,0.00033367483
317,317,0.0003322806
337,337,0.0003305284
135,135,0.0003302332
423,423,0.0003287331
310,310,0.00032717254
503,503,0.00032671163
69,69,0.00032358448
145,145,0.00032273383
160,160,0.00032157102
40,40,0.00032081635
400,400,0.00031983448
278,278,0.00031925194
489,489,0.0003178289
199,199,0.00031677147
133,133,0.00031545162
373,373,0.00031506256
331,331,0.0003125472
382,382,0.00031192778
11,11,0.00031145735
494,494,0.00031131672
426,426,0.00031126558
233,233,0.00030971633
290,290,0.000309349
232,232,0.00030930655
262,262,0.00030752877
231,231,0.00030624977
314,314,0.0003054031
502,502,0.00030359824
323,323,0.0003030356
222,222,0.0002989251
428,428,0.0002974012
496,496,0.00029618212
230,230,0.0002957415
121,121,0.00029490213
304,304,0.00029465224
179,179,0.00029258238
248,248,0.00029258014
436,436,0.0002923748
425,425,0.0002921299
236,236,0.0002915477
150,150,0.00029135606
414,414,0.00029009863
286,286,0.00028853436
320,320,0.00028726715
37,37,0.00028702882
131,131,0.0002847645
225,225,0.0002837026
441,441,0.00028333988
326,326,0.0002814045
422,422,0.00028055938
165,165,0.00027847502
471,471,0.00027833483
349,349,0.0002757503
409,409,0.00027481435
103,103,0.00027326684
95,95,0.00027135346
249,249,0.00027121097
90,90,0.0002710454
224,224,0.00027073256
34,34,0.0002699063
65,65,0.00026914835
184,184,0.00026874637
398,398,0.00026665113
301,301,0.00026612534
325,325,0.00026601987
261,261,0.00026447696
246,246,0.00026436363
122,122,0.00026221105
84,84,0.00026163872
78,78,0.00025123646
299,299,0.00024949652
408,408,0.00024725433
161,161,0.00024636975
288,288,0.00024635496
42,42,0.0002452217
106,106,0.00024499706
182,182,0.00024430823
507,507,0.00024347423
271,271,0.00024137288
136,136,0.00023734072
403,403,0.00023719446
453,453,0.0002368316
26,26,0.00023657693
468,468,0.00023344165
127,127,0.00023242869
117,117,0.0002311785
45,45,0.00022815086
1,1,0.0002279681
194,194,0.00022725234
312,312,0.00022653052
410,410,0.00022528754
491,491,0.00022254683
93,93,0.00022081727
63,63,0.00022056459
226,226,0.00021268189
381,381,0.00020935576
329,329,0.00020906379
446,446,0.00020802105
245,245,0.00020745523
15,15,0.0002072561
281,281,0.00020714967
315,315,0.00020467484
152,152,0.00020205516
8,8,0.00019883284
81,81,0.00019411556
273,273,0.00019290911
39,39,0.00019221198
243,243,0.0001919428
416,416,0.00018266037
289,289,0.00016792006
454,454,0.0001655061
176,176,0.00015807906
159,159,0.0001545051
================================================
FILE: stylegan_human/latent_direction/ss_statics/bottom_length_statis/4/statis.csv
================================================
,channel,strength
371,371,0.010705759
130,130,0.007931795
66,66,0.0069621187
411,411,0.0065370337
241,241,0.0061685536
178,178,0.0057360367
422,422,0.0055051707
59,59,0.0054199533
193,193,0.0052992324
405,405,0.00511423
202,202,0.00487821
414,414,0.004596879
347,347,0.0045533227
325,325,0.0042810338
479,479,0.0041018343
234,234,0.003791195
104,104,0.0037741603
437,437,0.0036558367
186,186,0.0036010114
214,214,0.0035913743
472,472,0.0035745492
99,99,0.003559262
13,13,0.003553579
302,302,0.0034689216
428,428,0.0034320198
43,43,0.0032388393
215,215,0.0031643964
346,346,0.0031622355
392,392,0.0031421043
469,469,0.0031391508
185,185,0.0031346607
110,110,0.0031338846
152,152,0.0031238336
255,255,0.0031061403
27,27,0.003093111
494,494,0.0030917446
238,238,0.0030462171
111,111,0.003043536
162,162,0.0030410155
125,125,0.0030364853
51,51,0.0030085724
231,231,0.0029884125
335,335,0.002956904
184,184,0.002923058
80,80,0.0029210225
253,253,0.0029075942
357,357,0.0028416363
180,180,0.0028330602
360,360,0.0027900473
105,105,0.0027881858
33,33,0.0027586774
475,475,0.0027457555
332,332,0.0027370767
220,220,0.0026984583
31,31,0.0026166395
53,53,0.0026120786
106,106,0.0025991872
412,412,0.00259617
382,382,0.0025869396
38,38,0.0025757526
316,316,0.0025433942
389,389,0.0025421656
435,435,0.002534331
225,225,0.002509905
354,354,0.00245047
243,243,0.0024502743
221,221,0.0024180906
218,218,0.0023629142
56,56,0.0023580198
230,230,0.0023460235
8,8,0.0022959905
344,344,0.0022738976
102,102,0.002244589
279,279,0.0022293774
77,77,0.0022287841
404,404,0.0022230886
200,200,0.0022142548
450,450,0.002204902
319,319,0.0021294882
117,117,0.0021246204
6,6,0.0021206948
247,247,0.002117081
297,297,0.0020922043
40,40,0.0020740507
352,352,0.0020036534
239,239,0.0019678425
402,402,0.0019613549
315,315,0.0019569998
195,195,0.0019549306
128,128,0.0019411439
207,207,0.0019369258
432,432,0.0019148714
365,365,0.0019143238
322,322,0.0018905443
24,24,0.0018891945
265,265,0.0018829522
417,417,0.0018819926
334,334,0.0018748969
0,0,0.0018602412
85,85,0.0018551659
126,126,0.0018536468
4,4,0.001846846
232,232,0.0018420395
376,376,0.0018346108
333,333,0.0018218933
250,250,0.0018113112
169,169,0.0018011285
361,361,0.0017686478
25,25,0.0017465113
427,427,0.0017453731
461,461,0.0017188549
182,182,0.0017154247
11,11,0.0016988176
197,197,0.0016953972
20,20,0.0016696001
246,246,0.0016686992
339,339,0.0016661945
205,205,0.0016548207
177,177,0.0016535291
153,153,0.0016440097
356,356,0.0016337134
456,456,0.0016318822
42,42,0.0016255479
378,378,0.0016247402
155,155,0.0016181484
401,401,0.0016147293
16,16,0.001612585
30,30,0.0016062072
377,377,0.0015920534
385,385,0.0015913579
79,79,0.0015898152
135,135,0.0015849494
384,384,0.0015830033
338,338,0.0015758197
98,98,0.0015708887
21,21,0.0015523953
35,35,0.0015421407
364,364,0.0015421154
270,270,0.0015403415
447,447,0.0015402873
485,485,0.0015397645
121,121,0.0015270623
408,408,0.0015214243
32,32,0.0015187038
336,336,0.0014896884
413,413,0.0014843026
499,499,0.0014718628
487,487,0.0014659785
488,488,0.0014608143
122,122,0.0014533442
491,491,0.0014338568
54,54,0.001432836
363,363,0.0014307784
151,151,0.001427959
91,91,0.0014276047
314,314,0.0014251476
161,161,0.0014244247
211,211,0.0014210346
362,362,0.0014053485
216,216,0.0013955731
159,159,0.0013930433
233,233,0.0013912213
449,449,0.0013888216
48,48,0.0013732343
248,248,0.0013702087
299,299,0.0013695534
503,503,0.0013617459
1,1,0.0013607475
237,237,0.0013588071
57,57,0.0013579503
409,409,0.001355633
483,483,0.0013551097
229,229,0.001353437
19,19,0.0013416156
293,293,0.0013304886
390,390,0.0013271623
168,168,0.0013254886
381,381,0.0013072129
366,366,0.0013022809
288,288,0.0012966645
451,451,0.0012919087
244,244,0.0012890152
292,292,0.0012853605
463,463,0.0012843801
470,470,0.0012821403
416,416,0.0012795452
157,157,0.0012787202
464,464,0.0012758262
329,329,0.0012730482
490,490,0.0012670642
74,74,0.0012638838
170,170,0.001260051
278,278,0.0012583392
22,22,0.0012556936
399,399,0.0012530655
100,100,0.0012508945
355,355,0.001250555
486,486,0.0012472017
506,506,0.0012441961
459,459,0.0012374625
309,309,0.0012292771
113,113,0.0012270019
138,138,0.0012170793
47,47,0.0012084226
65,65,0.0012071957
198,198,0.001206154
196,196,0.001205836
285,285,0.0012046142
49,49,0.0012002704
457,457,0.0011999249
425,425,0.0011967781
175,175,0.0011944658
148,148,0.0011917639
86,86,0.0011908459
94,94,0.0011882967
501,501,0.0011801487
476,476,0.0011663077
156,156,0.0011636167
387,387,0.0011613253
266,266,0.0011546519
496,496,0.0011480699
340,340,0.0011474552
343,343,0.0011469066
52,52,0.0011451634
369,369,0.0011436169
90,90,0.0011377904
386,386,0.0011357777
96,96,0.0011343259
124,124,0.0011335325
460,460,0.0011268322
321,321,0.0011234696
264,264,0.0011168371
287,287,0.0011126697
353,353,0.0011116265
462,462,0.0011107579
154,154,0.0011100196
388,388,0.0011071602
448,448,0.0011041508
187,187,0.0010949599
328,328,0.0010922498
454,454,0.0010910842
306,306,0.0010906332
320,320,0.0010779172
391,391,0.0010766807
318,318,0.0010763333
107,107,0.0010709754
505,505,0.0010685796
206,206,0.0010629565
34,34,0.0010543949
473,473,0.0010541352
173,173,0.0010531493
109,109,0.0010509333
424,424,0.0010466026
430,430,0.0010448382
150,150,0.0010423292
268,268,0.001038994
2,2,0.0010353344
29,29,0.0010323756
504,504,0.0010309685
119,119,0.0010272656
174,174,0.0010258228
260,260,0.0010184883
249,249,0.0010175364
36,36,0.0010174246
137,137,0.001017379
303,303,0.0010154104
163,163,0.0010028904
455,455,0.0010002597
510,510,0.0009994362
245,245,0.00099766
262,262,0.0009885678
281,281,0.0009873835
28,28,0.0009856436
380,380,0.0009821856
228,228,0.0009717986
367,367,0.000969772
286,286,0.00096976873
7,7,0.00096907315
146,146,0.00096689834
139,139,0.000962454
204,204,0.0009589661
3,3,0.000956985
280,280,0.0009566337
509,509,0.00094728253
181,181,0.0009440433
68,68,0.0009432923
269,269,0.0009399444
179,179,0.00093980297
274,274,0.0009315241
95,95,0.0009284373
263,263,0.00092792546
72,72,0.00092775293
277,277,0.0009271326
436,436,0.0009269056
500,500,0.0009265228
87,87,0.00092501455
310,310,0.00092400954
300,300,0.00092247425
144,144,0.0009218417
426,426,0.00091638684
81,81,0.00091634423
188,188,0.00090678554
289,289,0.0009011138
418,418,0.00089773914
397,397,0.00089350634
304,304,0.0008868619
482,482,0.0008865463
495,495,0.0008795051
312,312,0.00087571866
166,166,0.00087193854
183,183,0.00087024615
5,5,0.00086923986
446,446,0.0008691286
212,212,0.00086676405
46,46,0.00086484046
118,118,0.00086480496
254,254,0.00086135446
88,88,0.000861164
219,219,0.0008608658
467,467,0.0008600293
76,76,0.00085848826
331,331,0.00085327355
498,498,0.0008505213
39,39,0.000848981
93,93,0.00084537006
433,433,0.0008452787
410,410,0.00084377866
194,194,0.00083563104
478,478,0.00083321065
272,272,0.00083227345
223,223,0.00083198043
311,311,0.00083184306
431,431,0.0008309719
337,337,0.0008306422
189,189,0.0008290602
341,341,0.0008284594
394,394,0.00082656107
396,396,0.0008221022
92,92,0.0008204752
50,50,0.0008168993
84,84,0.0008081732
64,64,0.0008072594
123,123,0.00080666045
71,71,0.0008044883
140,140,0.00080113776
120,120,0.0007968432
14,14,0.00079405046
324,324,0.00079051324
115,115,0.0007894897
191,191,0.0007880761
439,439,0.0007866818
393,393,0.000783365
131,131,0.0007829777
327,327,0.0007810872
17,17,0.0007793945
97,97,0.00077890133
295,295,0.0007760637
423,423,0.00077443745
403,403,0.0007743149
497,497,0.0007740729
171,171,0.00077238533
276,276,0.0007655106
452,452,0.00076521165
136,136,0.0007644099
82,82,0.0007621963
142,142,0.000760782
374,374,0.0007555941
444,444,0.0007507936
282,282,0.0007499849
421,421,0.0007498477
375,375,0.00074311404
358,358,0.0007403847
217,217,0.00073654624
165,165,0.00073613157
420,420,0.0007335366
227,227,0.0007332281
330,330,0.00073107216
368,368,0.00072970067
37,37,0.00072445447
149,149,0.00072384684
477,477,0.000723286
407,407,0.00072004553
242,242,0.00071955
134,134,0.0007190485
172,172,0.0007161819
69,69,0.00071376155
372,372,0.00071212306
236,236,0.00071169576
349,349,0.000709231
484,484,0.00070886995
222,222,0.0007068514
9,9,0.00070441724
481,481,0.0007041735
373,373,0.0007033714
323,323,0.00070036115
434,434,0.0006995436
438,438,0.00069369253
359,359,0.0006884251
370,370,0.00068327464
308,308,0.000678813
445,445,0.0006780726
10,10,0.0006776827
127,127,0.0006771573
224,224,0.0006768041
296,296,0.00067331357
256,256,0.00066997507
493,493,0.0006680274
67,67,0.0006640576
116,116,0.0006630596
132,132,0.000661265
62,62,0.00066118705
508,508,0.00065791415
468,468,0.0006579114
440,440,0.00065191026
317,317,0.00065190904
160,160,0.00065158366
492,492,0.00065123267
458,458,0.0006491314
114,114,0.0006431901
58,58,0.00063985295
313,313,0.0006342638
18,18,0.0006290864
261,261,0.0006265827
383,383,0.00062478095
294,294,0.00062198006
143,143,0.0006198618
61,61,0.0006060728
103,103,0.0006045529
419,419,0.00060209975
466,466,0.0006016268
507,507,0.0005986943
273,273,0.0005929426
240,240,0.0005922067
350,350,0.00058880384
429,429,0.000588744
129,129,0.00058819435
267,267,0.00058599794
252,252,0.0005856361
23,23,0.00058518484
400,400,0.0005806418
283,283,0.0005790221
89,89,0.00057516847
12,12,0.0005741658
176,176,0.0005739618
192,192,0.00057157595
101,101,0.00056115567
442,442,0.0005588267
41,41,0.0005569018
63,63,0.0005514146
441,441,0.0005468517
398,398,0.00054553134
307,307,0.0005450811
298,298,0.00054445845
342,342,0.00054111326
502,502,0.0005392242
78,78,0.0005375705
44,44,0.0005357114
73,73,0.00053344626
209,209,0.00053105725
158,158,0.0005261206
70,70,0.0005217334
199,199,0.0005199517
471,471,0.0005012095
511,511,0.0004969943
259,259,0.000493192
235,235,0.00047622804
301,301,0.0004594244
275,275,0.0004556442
167,167,0.00045370075
133,133,0.00044329825
147,147,0.00043812627
348,348,0.00042484334
190,190,0.0004229568
406,406,0.00042175007
480,480,0.000417746
108,108,0.00041762542
395,395,0.000411962
305,305,0.0004101298
290,290,0.00040680933
489,489,0.0004044473
251,251,0.0004041571
164,164,0.00040222614
257,257,0.00039560342
379,379,0.000392651
326,326,0.00038673886
112,112,0.00036090027
83,83,0.000344462
351,351,0.00034133552
210,210,0.00033938422
141,141,0.0003355473
60,60,0.00033554618
226,226,0.00033128157
203,203,0.00032476717
15,15,0.00027671162
208,208,0.0002761039
291,291,0.000266872
213,213,0.0002634147
415,415,0.00026088266
474,474,0.000256801
271,271,0.00024554084
201,201,0.00023913333
443,443,0.00023458686
145,145,0.00022581422
284,284,0.00022399156
258,258,0.00021423028
465,465,0.00021232266
453,453,0.00021226592
75,75,0.00020900984
55,55,0.00020078795
26,26,0.00018182797
45,45,0.00016305555
345,345,0.00015414672
================================================
FILE: stylegan_human/latent_direction/ss_statics/bottom_length_statis/5/statis.csv
================================================
,channel,strength
242,242,0.01746412
134,134,0.011444086
71,71,0.01060778
395,395,0.0062382165
363,363,0.0058679837
175,175,0.004722381
53,53,0.0044826367
112,112,0.0042659384
457,457,0.003703029
288,288,0.003450465
328,328,0.00344062
414,414,0.0032427178
205,205,0.0032254586
321,321,0.003165059
32,32,0.003014796
9,9,0.0025584965
180,180,0.0025317036
452,452,0.0024022916
69,69,0.0023928392
210,210,0.0023827436
385,385,0.0023732155
98,98,0.0023490055
307,307,0.0023184065
418,418,0.002295881
470,470,0.0022638584
341,341,0.0021729823
308,308,0.0021633923
37,37,0.0021602109
440,440,0.0021246527
16,16,0.0020481
10,10,0.0020317773
486,486,0.002021162
150,150,0.0020176854
89,89,0.0019858822
278,278,0.001971183
430,430,0.0019591004
463,463,0.0019105887
434,434,0.0018847745
437,437,0.0017904
127,127,0.0017076981
310,310,0.0016927454
151,151,0.0016283162
224,224,0.0015725751
268,268,0.0015449473
402,402,0.00153654
190,190,0.0014757048
92,92,0.0014610167
117,117,0.0014568182
110,110,0.0014490561
423,423,0.0014475571
161,161,0.0014227595
291,291,0.001410095
225,225,0.0013883268
189,189,0.001364547
157,157,0.0013630674
499,499,0.00135625
274,274,0.0013522932
166,166,0.0013465862
475,475,0.0013449993
300,300,0.0013322539
368,368,0.0012959774
267,267,0.0012800089
36,36,0.0012684392
11,11,0.0012514348
184,184,0.0012406422
453,453,0.0012389477
173,173,0.0012284226
429,429,0.001227794
229,229,0.0011854741
212,212,0.0011837278
295,295,0.0011766667
318,318,0.0011758378
390,390,0.0011577767
67,67,0.0011576377
26,26,0.001157294
256,256,0.0011399041
287,287,0.0011214579
245,245,0.0011204005
118,118,0.001119347
379,379,0.0011182849
412,412,0.0011027583
169,169,0.0010857102
488,488,0.001084579
108,108,0.0010608159
155,155,0.0010607184
465,465,0.0010555563
80,80,0.0010430918
285,285,0.0010419621
191,191,0.001039581
320,320,0.0010336601
489,489,0.0009907437
46,46,0.00098818
359,359,0.0009863363
415,415,0.000984566
438,438,0.0009843353
2,2,0.0009840623
483,483,0.000978259
116,116,0.00096878014
279,279,0.00096704334
391,391,0.0009629472
75,75,0.0009625787
386,386,0.0009506957
213,213,0.00094954396
81,81,0.0009334278
170,170,0.0009309288
459,459,0.0009304561
25,25,0.0009263908
422,422,0.0009251744
316,316,0.0009241467
254,254,0.00092409964
294,294,0.00092359504
322,322,0.00092271896
493,493,0.00092130696
168,168,0.00091688987
361,361,0.0009016815
302,302,0.0008988095
199,199,0.0008969074
42,42,0.00089361327
275,275,0.00089227123
20,20,0.00089052174
197,197,0.00088976097
43,43,0.0008803114
370,370,0.00087934994
436,436,0.00087875564
28,28,0.00087209826
290,290,0.0008675793
330,330,0.00085718994
94,94,0.0008566909
511,511,0.0008561607
77,77,0.0008551629
484,484,0.0008420306
202,202,0.0008376041
78,78,0.00083523884
487,487,0.00083071465
44,44,0.0008302506
456,456,0.00082660967
343,343,0.00082623283
186,186,0.0008204403
428,428,0.0008122731
63,63,0.000809409
371,371,0.0007866993
367,367,0.0007859241
410,410,0.00078440236
129,129,0.00078421313
492,492,0.0007841307
219,219,0.00078324176
181,181,0.0007805426
192,192,0.000759081
348,348,0.00075573416
156,156,0.00075448444
149,149,0.0007483769
497,497,0.0007449612
97,97,0.0007447865
238,238,0.0007345617
427,427,0.0007344842
48,48,0.0007277395
496,496,0.0007251045
468,468,0.0007233464
351,351,0.00071131974
396,396,0.0007099477
240,240,0.00070780434
277,277,0.00070575473
397,397,0.00070399133
362,362,0.0007016971
122,122,0.0006994947
425,425,0.0006994096
347,347,0.0006979086
502,502,0.00069097895
377,377,0.00068525685
70,70,0.00068207615
481,481,0.0006812176
185,185,0.00067786046
90,90,0.00067573227
472,472,0.0006711317
339,339,0.0006699505
405,405,0.000669038
426,426,0.0006675291
204,204,0.0006652177
296,296,0.0006634654
235,235,0.000660739
141,141,0.0006510568
203,203,0.00064648956
293,293,0.0006462373
508,508,0.00064582145
121,121,0.00064454816
93,93,0.00064454804
406,406,0.00064240245
12,12,0.0006397705
344,344,0.00063504284
19,19,0.0006331822
332,332,0.0006323309
194,194,0.00062925677
313,313,0.0006285695
507,507,0.0006174056
82,82,0.0006161944
239,239,0.00060919294
490,490,0.0006089468
266,266,0.0006075872
128,128,0.0006045006
407,407,0.00060399977
451,451,0.00060152815
137,137,0.0005997921
454,454,0.00059671386
270,270,0.00059550925
404,404,0.00059543905
439,439,0.000591806
460,460,0.0005913006
23,23,0.0005911026
373,373,0.000590888
96,96,0.0005903696
55,55,0.00059018075
365,365,0.000585234
411,411,0.00058369443
374,374,0.00058170315
119,119,0.000581078
458,458,0.00058017287
420,420,0.00057635305
393,393,0.00057374657
261,261,0.0005727315
319,319,0.00057239935
283,283,0.00057194225
233,233,0.0005646504
403,403,0.00056403375
357,357,0.00056290894
241,241,0.00056050875
182,182,0.00056047173
257,257,0.000556821
364,364,0.00055118365
174,174,0.0005500873
111,111,0.00054910127
309,309,0.0005458186
193,193,0.00054326915
5,5,0.0005425164
24,24,0.0005411514
162,162,0.0005377205
432,432,0.0005299737
284,284,0.000529899
491,491,0.0005281161
352,352,0.000527957
39,39,0.0005278845
297,297,0.00052527175
312,312,0.0005249035
443,443,0.0005247692
143,143,0.0005237088
389,389,0.00052194
132,132,0.0005208663
57,57,0.00051607064
85,85,0.00051471085
482,482,0.00051036675
247,247,0.0005102306
477,477,0.00050744385
358,358,0.0005060787
125,125,0.000505312
466,466,0.00050286297
479,479,0.00050196645
323,323,0.0005004967
232,232,0.0004986481
114,114,0.0004957778
87,87,0.0004945035
158,158,0.0004930025
376,376,0.00049192616
384,384,0.00049149833
65,65,0.00049068604
144,144,0.0004868068
41,41,0.00048459516
33,33,0.00048199162
135,135,0.00047957737
72,72,0.00047893118
449,449,0.00047819986
442,442,0.0004762921
130,130,0.00047556465
244,244,0.00047063318
163,163,0.00047034535
292,292,0.00047004907
350,350,0.00046825878
304,304,0.0004671009
474,474,0.0004649635
208,208,0.00046426945
286,286,0.00046057467
178,178,0.00045859
450,450,0.00045825946
226,226,0.00045819176
8,8,0.00045815166
252,252,0.0004559861
154,154,0.00045457872
264,264,0.00045156496
146,146,0.00044669665
220,220,0.00044567848
501,501,0.00044542857
273,273,0.00044488933
171,171,0.0004427107
401,401,0.00044082777
381,381,0.00043972745
140,140,0.00043886708
356,356,0.00043867563
464,464,0.0004386433
353,353,0.00043817703
230,230,0.0004378558
249,249,0.00043769262
15,15,0.0004355795
345,345,0.0004354763
133,133,0.0004346199
120,120,0.00043198353
505,505,0.00043042895
95,95,0.00042860254
378,378,0.0004282037
455,455,0.00042507777
298,298,0.00042234876
145,145,0.0004205409
104,104,0.0004192717
394,394,0.00041797172
45,45,0.00041774593
366,366,0.00041443488
325,325,0.00041310437
62,62,0.00041287072
179,179,0.00041070042
317,317,0.0004105377
338,338,0.00040974608
79,79,0.00040827427
66,66,0.00040683098
139,139,0.0004056471
29,29,0.00040440765
152,152,0.00040363474
214,214,0.00040329635
435,435,0.0004028462
22,22,0.00040165885
346,346,0.0004001219
209,209,0.00039810632
392,392,0.00039726324
315,315,0.00039714077
433,433,0.00039478665
506,506,0.0003915952
105,105,0.0003900892
446,446,0.00038801826
0,0,0.00038777932
18,18,0.0003875171
424,424,0.00038726558
331,331,0.00038722638
51,51,0.00038692914
417,417,0.00038655673
14,14,0.0003852234
6,6,0.00038509126
281,281,0.0003836144
383,383,0.00038292323
216,216,0.00038181938
262,262,0.00037940088
74,74,0.00037926337
3,3,0.00037809636
387,387,0.00037650907
342,342,0.00037650426
398,398,0.0003733892
136,136,0.00037266538
243,243,0.00037207166
471,471,0.00037091138
73,73,0.00036918517
86,86,0.00036871075
68,68,0.00036807635
372,372,0.00036491535
200,200,0.0003633216
107,107,0.00036330617
480,480,0.00036279066
124,124,0.00036251516
131,131,0.00036231225
269,269,0.00036161783
447,447,0.00036158395
167,167,0.0003561097
38,38,0.00035443963
211,211,0.00035319992
369,369,0.00035081702
248,248,0.00035030014
101,101,0.00034903514
115,115,0.00034696967
413,413,0.0003440623
289,289,0.00034358836
56,56,0.00034278553
195,195,0.0003423341
388,388,0.00034183572
416,416,0.00034126177
206,206,0.0003408197
375,375,0.00033973216
4,4,0.00033931396
494,494,0.00033901844
99,99,0.00033845875
282,282,0.00033784402
172,172,0.00033696217
218,218,0.00033630853
165,165,0.0003358777
196,196,0.0003349965
419,419,0.0003341522
469,469,0.00033384332
327,327,0.0003337259
324,324,0.00033304066
260,260,0.00033025324
478,478,0.00032999786
148,148,0.0003298886
500,500,0.00032941083
83,83,0.00032738544
231,231,0.00032720037
280,280,0.00032718244
498,498,0.00032587873
467,467,0.00032520766
102,102,0.00032370255
84,84,0.00032345828
21,21,0.00032132064
334,334,0.00032061047
237,237,0.00032028827
441,441,0.0003187737
485,485,0.0003181529
159,159,0.00031660515
50,50,0.00031633597
258,258,0.00031522466
164,164,0.00031313743
59,59,0.00031286437
160,160,0.00031265983
355,355,0.00031167237
7,7,0.00031130642
76,76,0.00030938906
476,476,0.0003087692
263,263,0.0003072378
495,495,0.0003052068
27,27,0.00030507136
234,234,0.00030466946
299,299,0.00030342882
113,113,0.00030309713
276,276,0.00030135227
217,217,0.00030106696
61,61,0.00030104464
54,54,0.00030067936
349,349,0.0003001497
380,380,0.00029829858
503,503,0.00029654359
303,303,0.0002964005
88,88,0.00029631294
188,188,0.00029344021
444,444,0.00029314525
329,329,0.0002908486
400,400,0.0002895397
223,223,0.00028937528
123,123,0.0002876151
272,272,0.0002861553
91,91,0.0002857061
354,354,0.00028297797
409,409,0.0002819136
448,448,0.0002804029
509,509,0.00027843454
215,215,0.00027820378
183,183,0.0002780649
253,253,0.0002754507
461,461,0.00027407336
255,255,0.00027380435
251,251,0.00027299367
109,109,0.00027289733
246,246,0.00027286436
259,259,0.00027249
201,201,0.00026961832
207,207,0.00026913724
198,198,0.00026806045
236,236,0.00026771048
326,326,0.00026338472
49,49,0.00026074707
138,138,0.0002606831
147,147,0.00025864976
30,30,0.00025828253
408,408,0.0002574985
177,177,0.00025741325
153,153,0.00025581883
187,187,0.0002544332
34,34,0.0002536471
58,58,0.00025209208
473,473,0.00024838003
221,221,0.00024675552
126,126,0.00024573723
228,228,0.00024463152
306,306,0.00024407305
250,250,0.00024358112
421,421,0.00024149592
176,176,0.00024142586
64,64,0.00023718696
336,336,0.00023363969
40,40,0.00023206352
504,504,0.00023138674
399,399,0.0002291037
305,305,0.00022865152
60,60,0.0002266992
301,301,0.00022585278
222,222,0.00022387604
311,311,0.00022371336
17,17,0.0002226341
52,52,0.00022181227
142,142,0.00021962133
103,103,0.00021821138
382,382,0.00021803642
1,1,0.0002171288
337,337,0.00021468119
445,445,0.00021307352
314,314,0.00021049718
100,100,0.00020906334
360,360,0.00020833187
340,340,0.00020668289
106,106,0.0002023169
462,462,0.00019855957
271,271,0.00019850428
227,227,0.00019575101
35,35,0.00019263716
510,510,0.00018776124
265,265,0.0001874168
31,31,0.00018595619
333,333,0.00018071612
13,13,0.00017645532
431,431,0.00017538466
47,47,0.00017165719
335,335,0.00016702584
================================================
FILE: stylegan_human/latent_direction/ss_statics/upper_length_statis/5/statis.csv
================================================
,channel,strength
423,423,0.004408924
341,341,0.0032079767
379,379,0.0028457695
184,184,0.0028266357
368,368,0.0027224978
486,486,0.0025201144
313,313,0.0025118296
426,426,0.002334192
367,367,0.0022732853
213,213,0.0022067663
436,436,0.0021884514
308,308,0.0021619347
496,496,0.0020120202
393,393,0.001984407
422,422,0.0019403459
425,425,0.0018690284
511,511,0.0017944475
181,181,0.00178793
497,497,0.0016764585
361,361,0.0016183082
2,2,0.0015440682
267,267,0.0015391556
114,114,0.0015294778
170,170,0.001501581
432,432,0.001494758
385,385,0.0014706858
339,339,0.001420463
415,415,0.0013572118
116,116,0.0013073295
373,373,0.0013072236
453,453,0.0013039358
320,320,0.0012937128
256,256,0.0012766926
127,127,0.0012432956
75,75,0.0012031216
287,287,0.0011476731
309,309,0.001147382
456,456,0.0011466638
343,343,0.0011463353
457,457,0.0011426552
98,98,0.0011147967
437,437,0.0011112978
435,435,0.0010973605
182,182,0.0010894896
150,150,0.0010746964
279,279,0.0010726912
189,189,0.0010475953
9,9,0.001045001
371,371,0.001035062
53,53,0.0010211505
168,168,0.0010098461
146,146,0.0010040879
470,470,0.00097405957
69,69,0.00097038125
375,375,0.0009642161
134,134,0.00094965403
475,475,0.0009452465
125,125,0.0009439787
434,434,0.0009384095
288,288,0.0009309878
205,205,0.0009186115
128,128,0.0009111188
328,328,0.0008899651
319,319,0.0008791126
42,42,0.00087770226
458,458,0.0008629476
272,272,0.0008628456
414,414,0.00086089334
261,261,0.00085644424
304,304,0.00083918776
24,24,0.0008195494
161,161,0.0008154048
225,225,0.0008051095
67,67,0.0008038561
482,482,0.0007958309
430,430,0.0007924481
499,499,0.0007831592
390,390,0.00078276
11,11,0.00076811534
332,332,0.0007633267
197,197,0.0007563872
325,325,0.00074650464
322,322,0.00073741924
157,157,0.00071299827
93,93,0.00069571583
108,108,0.0006955461
185,185,0.00068194594
115,115,0.0006798201
80,80,0.0006779811
405,405,0.00067563757
450,450,0.00067155104
105,105,0.00065722043
57,57,0.00065079477
36,36,0.0006450097
454,454,0.0006430749
247,247,0.0006401138
8,8,0.00063832046
102,102,0.0006370239
316,316,0.0006331501
488,488,0.00062760746
364,364,0.0006171263
389,389,0.00060887367
172,172,0.0006087077
352,352,0.00060442963
463,463,0.0006023239
19,19,0.00060143415
210,210,0.00059817376
186,186,0.0005907412
471,471,0.0005869326
140,140,0.00058600784
94,94,0.000580931
81,81,0.000580287
270,270,0.0005789991
417,417,0.0005783043
196,196,0.0005754427
46,46,0.00057365885
464,464,0.00056969275
104,104,0.00056765747
171,171,0.00056728435
487,487,0.0005661972
220,220,0.00056612946
68,68,0.00056457426
33,33,0.0005640128
72,72,0.0005622605
229,229,0.00055803073
41,41,0.0005567271
148,148,0.0005566318
455,455,0.00055565406
349,349,0.00055537344
442,442,0.00055352686
162,162,0.0005533156
191,191,0.00054769375
202,202,0.0005464838
226,226,0.0005416205
466,466,0.00054103765
192,192,0.0005383168
479,479,0.00052559533
143,143,0.00052148907
199,199,0.0005204556
281,281,0.00051968545
410,410,0.0005187148
403,403,0.0005185161
38,38,0.00051575975
429,429,0.00051309256
381,381,0.0005028584
363,363,0.00050276244
26,26,0.0004997491
502,502,0.0004992033
208,208,0.000498523
109,109,0.0004985198
179,179,0.000497237
433,433,0.00049502135
190,190,0.0004945647
223,223,0.00049233536
50,50,0.00048974034
97,97,0.00048946124
37,37,0.0004884755
85,85,0.00048751346
6,6,0.00048617192
351,351,0.0004857896
212,212,0.00048565448
259,259,0.00048205393
294,294,0.00048184753
118,118,0.00048155375
117,117,0.00047967743
428,428,0.0004783297
110,110,0.0004755637
90,90,0.00047256536
55,55,0.00047060288
233,233,0.0004681879
264,264,0.00046594813
240,240,0.00046441547
310,310,0.00046434483
238,238,0.00046259057
145,145,0.00046210623
22,22,0.00046033954
395,395,0.00045831574
397,397,0.00045133353
18,18,0.00044636318
129,129,0.0004431245
241,241,0.0004430129
465,465,0.00044261583
63,63,0.00044257144
418,418,0.00044039512
507,507,0.00043912584
358,358,0.0004376243
365,365,0.00043710147
242,242,0.00043396762
411,411,0.00043118192
338,338,0.0004302841
122,122,0.00042847547
235,235,0.00042753594
476,476,0.0004272296
3,3,0.00042361638
176,176,0.00041432443
407,407,0.00041412332
391,391,0.00041366852
1,1,0.00041138142
73,73,0.0004112309
493,493,0.00041061096
421,421,0.000409816
459,459,0.00040876624
14,14,0.00040792229
284,284,0.00040753878
61,61,0.0004073927
293,293,0.00040734603
275,275,0.00040710357
494,494,0.0004049803
180,180,0.00040350194
301,301,0.00040267198
198,198,0.00040212894
193,193,0.00040090212
280,280,0.0004008506
396,396,0.00040081202
5,5,0.0004002457
76,76,0.00040004638
350,350,0.00039750861
283,283,0.0003969664
344,344,0.0003949259
230,230,0.00039482582
149,149,0.0003945113
244,244,0.00039450615
357,357,0.00039388
491,491,0.00039187729
45,45,0.00038878893
298,298,0.00038755446
56,56,0.00038721022
503,503,0.00038716424
217,217,0.0003867056
159,159,0.00038530145
468,468,0.00038489333
427,427,0.0003842091
291,291,0.00038256386
500,500,0.0003809668
290,290,0.00038067088
460,460,0.00038007705
331,331,0.00037731865
211,211,0.00037723288
440,440,0.00037687554
74,74,0.00037646503
119,119,0.0003755539
133,133,0.00037539483
353,353,0.00037147343
321,321,0.00037131374
32,32,0.0003699305
16,16,0.0003695323
71,71,0.00036733068
483,483,0.0003673293
131,131,0.00036619935
404,404,0.00036550473
218,218,0.00036407504
424,424,0.00036398123
23,23,0.00036384634
65,65,0.0003628074
111,111,0.00036226492
260,260,0.00036025967
258,258,0.0003602185
112,112,0.00035953286
492,492,0.00035905885
167,167,0.00035840494
399,399,0.00035528265
376,376,0.00035397
326,326,0.00035366518
481,481,0.00035335837
79,79,0.00035233115
101,101,0.00035053334
276,276,0.00034943986
296,296,0.00034848423
152,152,0.00034692476
10,10,0.00034476537
489,489,0.00034423766
438,438,0.0003438288
120,120,0.0003430017
239,239,0.00034234222
317,317,0.00034196727
347,347,0.00034175767
206,206,0.0003414347
84,84,0.00034051857
490,490,0.0003404467
107,107,0.0003396762
495,495,0.00033961574
333,333,0.00033911737
446,446,0.000338417
254,254,0.00033728743
386,386,0.00033677832
250,250,0.00033664028
306,306,0.0003353818
246,246,0.0003352706
372,372,0.00033452015
169,169,0.00033249875
451,451,0.0003304912
173,173,0.0003292484
302,302,0.00032853018
151,151,0.0003258597
263,263,0.0003249236
274,274,0.00032480215
156,156,0.00032411155
307,307,0.00032296553
88,88,0.00032126167
39,39,0.0003210367
91,91,0.00032037924
413,413,0.00032021984
232,232,0.0003186513
366,366,0.00031673085
480,480,0.00031612715
44,44,0.00031571824
462,462,0.00031546177
380,380,0.0003152556
83,83,0.00031272677
132,132,0.0003114493
209,209,0.00030984185
48,48,0.00030914877
382,382,0.00030887048
195,195,0.00030860244
154,154,0.00030410106
166,166,0.0003024194
245,245,0.00030229168
262,262,0.00030191787
237,237,0.0002995339
443,443,0.0002943032
467,467,0.00029363343
121,121,0.00029333794
416,416,0.00029272414
160,160,0.00029269484
4,4,0.00029229806
92,92,0.00029168854
77,77,0.00028903817
400,400,0.0002876192
278,278,0.00028760926
474,474,0.00028757288
402,402,0.00028493986
506,506,0.0002847649
234,234,0.00028450877
277,277,0.00028409314
447,447,0.0002835903
342,342,0.00028351165
285,285,0.00028341086
345,345,0.00028339337
348,348,0.0002823747
300,300,0.00028156798
383,383,0.00028049652
231,231,0.0002790845
203,203,0.00027895247
355,355,0.00027876275
204,204,0.00027841472
216,216,0.0002779351
508,508,0.00027720784
282,282,0.00027655836
297,297,0.00027502645
292,292,0.00027430354
327,327,0.0002727945
100,100,0.000269865
95,95,0.0002694548
187,187,0.0002689126
408,408,0.0002658863
477,477,0.00026576317
384,384,0.0002645117
54,54,0.00026404977
374,374,0.00026287523
420,420,0.00026245107
509,509,0.00026231605
28,28,0.00026166256
449,449,0.0002611203
336,336,0.0002604421
178,178,0.00026030626
299,299,0.00025961167
103,103,0.00025886018
388,388,0.00025811547
271,271,0.00025790904
207,207,0.00025755883
248,248,0.00025613784
249,249,0.00025567278
138,138,0.00025559164
78,78,0.00025549062
269,269,0.00025442135
273,273,0.00025399768
286,286,0.00025389006
12,12,0.0002534591
478,478,0.00025331602
452,452,0.00025299162
27,27,0.0002521485
82,82,0.00025113745
295,295,0.0002508786
201,201,0.0002506164
409,409,0.00025036978
359,359,0.00024992102
394,394,0.00024895562
330,330,0.00024881997
501,501,0.0002484243
64,64,0.00024702528
52,52,0.00024523196
106,106,0.00024233114
175,175,0.00024187689
135,135,0.00024020251
419,419,0.00023802932
139,139,0.00023566972
25,25,0.00023424833
312,312,0.00023372385
469,469,0.00023334593
7,7,0.00023190145
158,158,0.00023087433
60,60,0.00023052357
441,441,0.00022939079
165,165,0.00022774191
59,59,0.00022733606
147,147,0.00022624452
62,62,0.00022606985
144,144,0.00022451063
370,370,0.00022431195
21,21,0.0002239656
369,369,0.00022302035
314,314,0.00022240904
377,377,0.00022210156
406,406,0.00022187224
255,255,0.00022102552
356,356,0.00022071223
472,472,0.00021941346
484,484,0.00021900832
289,289,0.0002186439
137,137,0.00021708063
17,17,0.00021706238
51,51,0.00021587432
174,174,0.0002145808
124,124,0.00021406145
253,253,0.00021273737
251,251,0.00021224513
485,485,0.00021198599
214,214,0.00021052467
227,227,0.00020948028
126,126,0.00020917962
362,362,0.00020837369
473,473,0.00020753275
311,311,0.00020720006
346,346,0.00020569382
243,243,0.00020554132
439,439,0.00020543092
177,177,0.00020462318
86,86,0.00020450725
43,43,0.00020431104
354,354,0.0002021195
323,323,0.00020008704
378,378,0.00019674652
153,153,0.00019593372
324,324,0.00019575354
194,194,0.0001956694
360,360,0.00019427242
188,188,0.0001923151
265,265,0.00019114434
431,431,0.00019109932
219,219,0.0001904252
315,315,0.00019023697
224,224,0.00018825711
412,412,0.0001871387
89,89,0.00018677485
268,268,0.0001852995
257,257,0.00018430859
392,392,0.0001833855
35,35,0.00018329639
20,20,0.00018285586
222,222,0.00018160306
141,141,0.00018157126
398,398,0.00018143529
461,461,0.0001811381
29,29,0.00018102965
318,318,0.00017966877
448,448,0.00017874948
58,58,0.00017864846
329,329,0.00017672384
401,401,0.0001746514
183,183,0.00017442314
142,142,0.00017366462
498,498,0.00017296121
0,0,0.00017286446
504,504,0.00017194722
444,444,0.00017125413
155,155,0.00016847586
87,87,0.00016776803
96,96,0.00016757102
99,99,0.00016714378
340,340,0.00016512058
505,505,0.00016355969
266,266,0.00016327428
49,49,0.00016297943
221,221,0.00016184167
15,15,0.00016183533
303,303,0.00016075459
113,113,0.00016014857
130,130,0.00015612668
215,215,0.00015609166
228,228,0.0001551064
305,305,0.00015308998
335,335,0.00015264843
40,40,0.00015258256
66,66,0.00015006233
200,200,0.00015001767
387,387,0.00014855604
252,252,0.00014628381
164,164,0.00014496325
136,136,0.00014322113
445,445,0.00014278234
123,123,0.00014132661
236,236,0.00013684056
31,31,0.00013596549
34,34,0.00013567152
334,334,0.00013561502
337,337,0.00013048301
70,70,0.00012578799
510,510,0.00012369145
47,47,0.00012088309
13,13,0.00011961328
163,163,0.00010987895
30,30,9.594474e-05
================================================
FILE: stylegan_human/legacy.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
#
import pickle
import dnnlib
import re
from typing import List, Optional
import torch
import copy
import numpy as np
from torch_utils import misc
#----------------------------------------------------------------------------
## loading torch pkl
def load_network_pkl(f, force_fp16=False, G_only=False):
    """Unpickle a network snapshot, fill in missing fields and validate it.

    Side effect: the ``str()`` of every entry of the snapshot is appended to
    ``ori_model_Gonly.txt`` (when ``G_only``) or ``ori_model.txt`` in the
    current working directory.

    Args:
        f: Binary file-like object that the pickle is read from.
        force_fp16: When true, the selected networks are re-instantiated with
            ``num_fp16_res = 4`` and ``conv_clamp = 256`` (only if that
            actually changes their ``init_kwargs``) and the parameters and
            buffers are copied over.
        G_only: When true, only ``G_ema`` is validated and (optionally)
            converted; ``G`` and ``D`` are not required to be present.

    Returns:
        The unpickled dict, with ``training_set_kwargs`` and ``augment_pipe``
        set to ``None`` when they were absent.

    Raises:
        AssertionError: If a required entry has an unexpected type.
    """
    # NOTE(security): unpickling runs arbitrary code embedded in the file;
    # only load snapshots that come from a trusted source.
    data = _LegacyUnpickler(f).load()

    # Append a textual dump of the snapshot. The context manager guarantees
    # the dump file is closed even when str()/write() raises, and a dedicated
    # name avoids rebinding the ``f`` argument.
    dump_path = 'ori_model_Gonly.txt' if G_only else 'ori_model.txt'
    with open(dump_path, 'a+') as dump_file:
        for key in data.keys():
            dump_file.write(str(data[key]))

    ## We comment out this part, if you want to convert TF pickle, you can use the original script from StyleGAN2-ada-pytorch
    # # Legacy TensorFlow pickle => convert.
    # if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
    #     tf_G, tf_D, tf_Gs = data
    #     G = convert_tf_generator(tf_G)
    #     D = convert_tf_discriminator(tf_D)
    #     G_ema = convert_tf_generator(tf_Gs)
    #     data = dict(G=G, D=D, G_ema=G_ema)

    # Add missing fields.
    if 'training_set_kwargs' not in data:
        data['training_set_kwargs'] = None
    if 'augment_pipe' not in data:
        data['augment_pipe'] = None

    # Validate contents.
    assert isinstance(data['G_ema'], torch.nn.Module)
    if not G_only:
        assert isinstance(data['D'], torch.nn.Module)
        assert isinstance(data['G'], torch.nn.Module)
        assert isinstance(data['training_set_kwargs'], (dict, type(None)))
        assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))

    # Force FP16.
    if force_fp16:
        convert_list = ['G_ema'] if G_only else ['G', 'D', 'G_ema']
        for key in convert_list:
            old = data[key]
            kwargs = copy.deepcopy(old.init_kwargs)
            if key.startswith('G'):
                kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {}))
                kwargs.synthesis_kwargs.num_fp16_res = 4
                kwargs.synthesis_kwargs.conv_clamp = 256
            if key.startswith('D'):
                kwargs.num_fp16_res = 4
                kwargs.conv_clamp = 256
            # Rebuild only when the FP16 settings differ from the stored ones.
            if kwargs != old.init_kwargs:
                new = type(old)(**kwargs).eval().requires_grad_(False)
                misc.copy_params_and_buffers(old, new, require_all=True)
                data[key] = new
    return data
class _TFNetworkStub(dnnlib.EasyDict):
    """Placeholder type standing in for ``dnnlib.tflib.network.Network`` during unpickling."""
class _LegacyUnpickler(pickle.Unpickler):
    """Unpickler that maps the legacy TensorFlow ``Network`` class onto a stub."""

    def find_class(self, module, name):
        """Resolve ``module.name``, redirecting ``dnnlib.tflib.network.Network``."""
        is_tf_network = (module, name) == ('dnnlib.tflib.network', 'Network')
        if not is_tf_network:
            return super().find_class(module, name)
        return _TFNetworkStub
#----------------------------------------------------------------------------
def num_range(s: str) -> List[int]:
    """Parse ``'a-c'`` as the inclusive range a..c, otherwise a comma separated list ``'a,b,c'``.

    Returns:
        The parsed numbers as a list of ints.
    """
    span = re.match(r'^(\d+)-(\d+)$', s)
    if span is None:
        # Not a range: every comma separated field must itself be an int.
        return list(map(int, s.split(',')))
    first, last = (int(group) for group in span.groups())
    return list(range(first, last + 1))
#----------------------------------------------------------------------------
#### loading tf pkl
def load_pkl(file_or_url):
    """Unpickle the local file at ``file_or_url`` using ``latin1`` string decoding.

    Despite the parameter name, the value is passed straight to ``open``.
    """
    with open(file_or_url, 'rb') as handle:
        unpickler = pickle.Unpickler(handle, encoding='latin1')
        return unpickler.load()
#----------------------------------------------------------------------------
### For editing
def visual(output, out_path):
    """Write the first image of a batch to ``out_path`` with OpenCV.

    The values are mapped with ``(x + 1) / 2``, clamped to [0, 1] and scaled
    to uint8. A batch whose dimension 1 has size 1 is tiled to three channels.
    The channel axis is moved last and reversed before writing (presumably
    RGB -> BGR for OpenCV -- confirm against the callers).
    """
    import torch
    import cv2
    import numpy as np
    scaled = torch.clamp((output + 1)/2, 0, 1)
    if scaled.shape[1] == 1:
        scaled = torch.cat([scaled, scaled, scaled], 1)
    first_image = scaled[0].detach().cpu().permute(1,2,0).numpy()
    as_bytes = (first_image*255).astype(np.uint8)
    cv2.imwrite(out_path, as_bytes[:,:,::-1])
def save_obj(obj, path):
    """Pickle ``obj`` into the file at ``path`` using pickle protocol 4."""
    with open(path, 'wb+') as sink:
        pickle.Pickler(sink, protocol=4).dump(obj)
#----------------------------------------------------------------------------
## Converting pkl to pth, change dict info inside pickle
def convert_to_rgb(state_ros, state_nv, ros_name, nv_name):
    """Copy the ``torgb`` entries of ``nv_name`` into ``state_ros`` under ``ros_name``.

    The weight gains a leading axis and the bias gains one leading and two
    trailing axes. The affine (modulation) tensors are stored as-is.
    """
    source = f"{nv_name}.torgb"
    state_ros[f"{ros_name}.conv.weight"] = state_nv[f"{source}.weight"][None]
    state_ros[f"{ros_name}.bias"] = state_nv[f"{source}.bias"][None, ..., None, None]
    for part in ("weight", "bias"):
        state_ros[f"{ros_name}.conv.modulation.{part}"] = state_nv[f"{source}.affine.{part}"]
def convert_conv(state_ros, state_nv, ros_name, nv_name):
    """Copy one synthesis conv layer from ``state_nv`` into ``state_ros``.

    The conv weight and the noise strength each gain a leading axis. The
    bias and the affine (modulation) tensors are stored unchanged.
    """
    state_ros[f"{ros_name}.conv.weight"] = state_nv[f"{nv_name}.weight"][None]
    state_ros[f"{ros_name}.activate.bias"] = state_nv[f"{nv_name}.bias"]
    for part in ("weight", "bias"):
        state_ros[f"{ros_name}.conv.modulation.{part}"] = state_nv[f"{nv_name}.affine.{part}"]
    state_ros[f"{ros_name}.noise.weight"] = state_nv[f"{nv_name}.noise_strength"][None]
def convert_blur_kernel(state_ros, state_nv, level):
    """Store 4x the ``synthesis.b4`` resample filter as the blur/upsample kernels of ``level``.

    The same source filter is used for every level. The original author was
    unsure why the factor of 4 is needed.
    """
    targets = (
        f"convs.{2*level}.conv.blur.kernel",
        f"to_rgbs.{level}.upsample.kernel",
    )
    for target in targets:
        # Scaled separately per key so the two entries never share storage.
        state_ros[target] = 4*state_nv["synthesis.b4.resample_filter"]
def determine_config(state_nv):
    """Infer the mapping depth and synthesis depth from a generator state dict.

    Args:
        state_nv: Mapping whose keys contain names such as
            ``mapping.fc<i>...`` and ``synthesis.b<res>...``.

    Returns:
        ``(n_mapping, n_layers)`` where ``n_mapping`` is the largest
        ``mapping.fc`` index plus one and ``n_layers`` is
        ``log2(resolution / 2)`` (a float) for the largest ``synthesis.b``
        resolution found.

    Raises:
        ValueError: If no ``mapping.fc`` or no ``synthesis.b`` key exists.
    """
    # Raw string: "\d" in a plain literal is an invalid escape sequence.
    number_re = re.compile(r"(\d+)")

    def first_number(key):
        # First run of digits anywhere in the key.
        return int(number_re.findall(key)[0])

    n_mapping = max(first_number(key) for key in state_nv if "mapping.fc" in key) + 1
    resolution = max(first_number(key) for key in state_nv if "synthesis.b" in key)
    # log2 rather than log(x)/log(2): the quotient of two natural logs can be
    # one ulp away from the integer, and callers truncate it with int().
    n_layers = np.log2(resolution / 2)
    return n_mapping, n_layers
def convert(network_pkl, output_file, G_only=False):
    """Convert the ``G_ema`` generator of a network pickle into a ``.pth`` state dict.

    The generator's ``mapping.*`` / ``synthesis.*`` entries are renamed to the
    ``style.*`` / ``input`` / ``conv1`` / ``convs.*`` / ``to_rgbs.*`` /
    ``noises.*`` layout, and the result is saved with ``torch.save`` as
    ``{"g_ema": <converted state>, "latent_avg": <mapping.w_avg>}``.

    Args:
        network_pkl: Path or URL handed to ``dnnlib.util.open_url``.
        output_file: Destination passed to ``torch.save``.
        G_only: Forwarded to ``load_network_pkl``.
    """
    with dnnlib.util.open_url(network_pkl) as stream:
        generator = load_network_pkl(stream, G_only=G_only)['G_ema']
    state_nv = generator.state_dict()
    n_mapping, n_layers = determine_config(state_nv)

    state_ros = {}

    # Mapping network: fc<i> becomes style.<i+1>.
    for fc in range(n_mapping):
        for part in ("weight", "bias"):
            state_ros[f"style.{fc+1}.{part}"] = state_nv[f"mapping.fc{fc}.{part}"]

    # Synthesis network: one block per level, starting at resolution 4.
    for level in range(int(n_layers)):
        block = f"synthesis.b{4*(2**level)}"
        if level == 0:
            # The first block holds the learned constant and a single conv.
            state_ros["input.input"] = state_nv[f"{block}.const"].unsqueeze(0)
            convert_conv(state_ros, state_nv, "conv1", f"{block}.conv1")
            state_ros[f"noises.noise_{2*level}"] = state_nv[f"{block}.conv1.noise_const"].unsqueeze(0).unsqueeze(0)
            convert_to_rgb(state_ros, state_nv, "to_rgb1", block)
            continue
        for conv_level in range(2):
            conv_source = f"{block}.conv{conv_level}"
            convert_conv(state_ros, state_nv, f"convs.{2*level-2+conv_level}", conv_source)
            state_ros[f"noises.noise_{2*level-1+conv_level}"] = state_nv[f"{conv_source}.noise_const"].unsqueeze(0).unsqueeze(0)
        convert_to_rgb(state_ros, state_nv, f"to_rgbs.{level-1}", block)
        convert_blur_kernel(state_ros, state_nv, level-1)

    # https://github.com/yuval-alaluf/restyle-encoder/issues/1#issuecomment-828354736
    state_dict = {"g_ema": state_ros, "latent_avg": state_nv['mapping.w_avg']}
    torch.save(state_dict, output_file)
================================================
FILE: stylegan_human/openpose/model/.gitkeep
================================================
================================================
FILE: stylegan_human/openpose/src/__init__.py
================================================
================================================
FILE: stylegan_human/openpose/src/body.py
================================================
import cv2
import numpy as np
import math
import time
from scipy.ndimage.filters import gaussian_filter
import matplotlib.pyplot as plt
import matplotlib
import torch
from torchvision import transforms
from openpose.src import util
from openpose.src.model import bodypose_model
class Body(object):
    """Body keypoint estimator built on ``bodypose_model``.

    Calling an instance on an image returns the detected keypoint candidates
    together with their grouping into people.
    """

    def __init__(self, model_path):
        """Create the network, load its weights and switch it to eval mode.

        Args:
            model_path: Checkpoint path handed to ``torch.load``.
        """
        self.model = bodypose_model()
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        # NOTE(review): util.transfer presumably remaps checkpoint keys onto
        # this model's parameter names -- confirm in openpose/src/util.py.
        model_dict = util.transfer(self.model, torch.load(model_path))
        self.model.load_state_dict(model_dict)
        self.model.eval()

    def __call__(self, oriImg):
        """Detect body keypoints in ``oriImg`` and group them into people.

        Args:
            oriImg: Image array indexed as (row, column, channel). The
                ``__main__`` demo passes the B,G,R output of ``cv2.imread``.

        Returns:
            ``(candidate, subset)``:
            candidate -- one row per detected peak: x, y, score, id.
            subset -- n*20 array; columns 0-17 index into ``candidate``
            (-1 when the part is missing), column 18 is the total score and
            column 19 the number of parts.
        """
        # scale_search = [0.5, 1.0, 1.5, 2.0]
        scale_search = [0.5]
        boxsize = 368
        stride = 8
        padValue = 128
        thre1 = 0.1
        thre2 = 0.05
        multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
        heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
        paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))

        for scale in multiplier:
            imageToTest = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
            imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
            # Reorder to (1, channel, row, column) and rescale the pixel values.
            im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
            im = np.ascontiguousarray(im)

            data = torch.from_numpy(im).float()
            if torch.cuda.is_available():
                data = data.cuda()
            # data = data.permute([2, 0, 1]).unsqueeze(0).float()
            with torch.no_grad():
                Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
            Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
            Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()

            # extract outputs, resize, and remove padding
            # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0))  # output 1 is heatmaps
            heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0))  # output 1 is heatmaps
            heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
            heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
            heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)

            # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0))  # output 0 is PAFs
            paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0))  # output 0 is PAFs
            paf = cv2.resize(paf, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
            paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
            paf = cv2.resize(paf, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)

            # Running mean over the scales. The accumulator must only receive
            # this scale's share; adding the accumulator to itself would
            # double it on every scale after the first.
            heatmap_avg += heatmap / len(multiplier)
            paf_avg += paf / len(multiplier)

        all_peaks = []
        peak_counter = 0

        for part in range(18):
            map_ori = heatmap_avg[:, :, part]
            one_heatmap = gaussian_filter(map_ori, sigma=3)

            # Shifted copies of the smoothed map, used to test each pixel
            # against its four direct neighbours.
            map_left = np.zeros(one_heatmap.shape)
            map_left[1:, :] = one_heatmap[:-1, :]
            map_right = np.zeros(one_heatmap.shape)
            map_right[:-1, :] = one_heatmap[1:, :]
            map_up = np.zeros(one_heatmap.shape)
            map_up[:, 1:] = one_heatmap[:, :-1]
            map_down = np.zeros(one_heatmap.shape)
            map_down[:, :-1] = one_heatmap[:, 1:]

            peaks_binary = np.logical_and.reduce(
                (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1))
            peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0]))  # note reverse
            peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
            peak_id = range(peak_counter, peak_counter + len(peaks))
            peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))]

            all_peaks.append(peaks_with_score_and_id)
            peak_counter += len(peaks)

        # find connection in the specified sequence, center 29 is in the position 15
        limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
                   [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
                   [1, 16], [16, 18], [3, 17], [6, 18]]
        # the middle joints heatmap correspondence
        mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \
                  [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \
                  [55, 56], [37, 38], [45, 46]]

        connection_all = []
        special_k = []
        mid_num = 10

        for k in range(len(mapIdx)):
            score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
            candA = all_peaks[limbSeq[k][0] - 1]
            candB = all_peaks[limbSeq[k][1] - 1]
            nA = len(candA)
            nB = len(candB)
            if (nA != 0 and nB != 0):
                connection_candidate = []
                for i in range(nA):
                    for j in range(nB):
                        vec = np.subtract(candB[j][:2], candA[i][:2])
                        norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
                        norm = max(0.001, norm)  # avoid dividing by zero
                        vec = np.divide(vec, norm)

                        # mid_num sample points on the segment from A to B.
                        startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \
                                            np.linspace(candA[i][1], candB[j][1], num=mid_num)))

                        vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \
                                          for I in range(len(startend))])
                        vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \
                                          for I in range(len(startend))])

                        score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
                        score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
                            0.5 * oriImg.shape[0] / norm - 1, 0)
                        criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
                        criterion2 = score_with_dist_prior > 0
                        if criterion1 and criterion2:
                            connection_candidate.append(
                                [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]])

                # Best score first; each peak is used in at most one connection.
                connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
                connection = np.zeros((0, 5))
                for c in range(len(connection_candidate)):
                    i, j, s = connection_candidate[c][0:3]
                    if (i not in connection[:, 3] and j not in connection[:, 4]):
                        connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
                        if (len(connection) >= min(nA, nB)):
                            break

                connection_all.append(connection)
            else:
                special_k.append(k)
                connection_all.append([])

        # last number in each row is the total parts number of that person
        # the second last number in each row is the score of the overall configuration
        subset = -1 * np.ones((0, 20))
        candidate = np.array([item for sublist in all_peaks for item in sublist])

        for k in range(len(mapIdx)):
            if k not in special_k:
                partAs = connection_all[k][:, 0]
                partBs = connection_all[k][:, 1]
                indexA, indexB = np.array(limbSeq[k]) - 1

                for i in range(len(connection_all[k])):  # = 1:size(temp,1)
                    found = 0
                    subset_idx = [-1, -1]
                    for j in range(len(subset)):  # 1:size(subset,1):
                        if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
                            subset_idx[found] = j
                            found += 1

                    if found == 1:
                        j = subset_idx[0]
                        if subset[j][indexB] != partBs[i]:
                            subset[j][indexB] = partBs[i]
                            subset[j][-1] += 1
                            subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
                    elif found == 2:  # if found 2 and disjoint, merge them
                        j1, j2 = subset_idx
                        membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
                        if len(np.nonzero(membership == 2)[0]) == 0:  # merge
                            subset[j1][:-2] += (subset[j2][:-2] + 1)
                            subset[j1][-2:] += subset[j2][-2:]
                            subset[j1][-2] += connection_all[k][i][2]
                            subset = np.delete(subset, j2, 0)
                        else:  # as like found == 1
                            subset[j1][indexB] = partBs[i]
                            subset[j1][-1] += 1
                            subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]

                    # if find no partA in the subset, create a new subset
                    elif not found and k < 17:
                        row = -1 * np.ones(20)
                        row[indexA] = partAs[i]
                        row[indexB] = partBs[i]
                        row[-1] = 2
                        row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
                        subset = np.vstack([subset, row])

        # delete some rows of subset which has few parts occur
        deleteIdx = []
        for i in range(len(subset)):
            if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
                deleteIdx.append(i)
        subset = np.delete(subset, deleteIdx, axis=0)

        # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
        # candidate: x, y, score, id
        return candidate, subset
if __name__ == "__main__":
    # Manual demo: run the estimator on a sample image and display the overlay.
    estimator = Body('../model/body_pose_model.pth')
    image_bgr = cv2.imread('../images/ski.jpg')  # B,G,R order
    keypoints, people = estimator(image_bgr)
    overlay = util.draw_bodypose(image_bgr, keypoints, people)
    # Reverse the channel order before handing the image to matplotlib.
    plt.imshow(overlay[:, :, [2, 1, 0]])
    plt.show()
================================================
FILE: stylegan_human/openpose/src/model.py
================================================
import torch
from collections import OrderedDict
import torch
import torch.nn as nn
def make_layers(block, no_relu_layers):
    """Turn an ordered layer specification into an ``nn.Sequential``.

    Args:
        block: ordered mapping ``layer_name -> spec``. A name containing
            ``'pool'`` is built as ``nn.MaxPool2d`` from
            ``[kernel_size, stride, padding]``; any other name is built as
            ``nn.Conv2d`` from
            ``[in_channels, out_channels, kernel_size, stride, padding]``.
        no_relu_layers: names of conv layers that must NOT be followed by a
            ReLU.

    Returns:
        ``nn.Sequential`` whose sub-modules keep the given layer names; each
        inserted activation is registered as ``'relu_' + layer_name``.
    """
    named_modules = OrderedDict()
    for name, spec in block.items():
        if 'pool' in name:
            named_modules[name] = nn.MaxPool2d(kernel_size=spec[0],
                                               stride=spec[1],
                                               padding=spec[2])
            continue
        named_modules[name] = nn.Conv2d(in_channels=spec[0],
                                        out_channels=spec[1],
                                        kernel_size=spec[2],
                                        stride=spec[3],
                                        padding=spec[4])
        if name not in no_relu_layers:
            named_modules['relu_' + name] = nn.ReLU(inplace=True)
    return nn.Sequential(named_modules)
class bodypose_model(nn.Module):
    """Six-stage, two-branch convolutional network for body pose.

    A shared backbone (``model0``) produces 128-channel features. Every stage
    then runs two parallel branches on its input: ``L1`` (38 output channels)
    and ``L2`` (19 output channels). Stages 2-6 take the previous stage's two
    outputs concatenated with the backbone features (38 + 19 + 128 = 185
    channels).

    Sub-module names (``model0``, ``model<stage>_<branch>``) and the layer
    names inside them are kept stable so existing checkpoints still load.
    """

    def __init__(self):
        super(bodypose_model, self).__init__()

        # The last conv of every stage/branch emits the raw prediction and must
        # not be followed by a ReLU. The list is generated so that both
        # branches of every stage are covered (the hand-written list used to
        # name 'Mconv7_stage6_L1' twice and miss 'Mconv7_stage6_L2', which
        # wrongly appended a ReLU to the final L2 output).
        no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2']
        for stage in range(2, 7):
            no_relu_layers.append('Mconv7_stage%d_L1' % stage)
            no_relu_layers.append('Mconv7_stage%d_L2' % stage)

        blocks = {}
        # Shared feature extractor: 3 -> 128 channels, three 2x2 max-pools.
        block0 = OrderedDict([
            ('conv1_1', [3, 64, 3, 1, 1]),
            ('conv1_2', [64, 64, 3, 1, 1]),
            ('pool1_stage1', [2, 2, 0]),
            ('conv2_1', [64, 128, 3, 1, 1]),
            ('conv2_2', [128, 128, 3, 1, 1]),
            ('pool2_stage1', [2, 2, 0]),
            ('conv3_1', [128, 256, 3, 1, 1]),
            ('conv3_2', [256, 256, 3, 1, 1]),
            ('conv3_3', [256, 256, 3, 1, 1]),
            ('conv3_4', [256, 256, 3, 1, 1]),
            ('pool3_stage1', [2, 2, 0]),
            ('conv4_1', [256, 512, 3, 1, 1]),
            ('conv4_2', [512, 512, 3, 1, 1]),
            ('conv4_3_CPM', [512, 256, 3, 1, 1]),
            ('conv4_4_CPM', [256, 128, 3, 1, 1])
        ])

        # Stage 1
        blocks['block1_1'] = OrderedDict([
            ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]),
            ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
            ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]),
            ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
            ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])
        ])
        blocks['block1_2'] = OrderedDict([
            ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]),
            ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
            ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]),
            ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
            ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])
        ])

        self.model0 = make_layers(block0, no_relu_layers)

        # Stages 2 - 6: identical layout, only the layer names and the number
        # of output channels per branch differ.
        for i in range(2, 7):
            for branch, out_channels in ((1, 38), (2, 19)):
                suffix = 'stage%d_L%d' % (i, branch)
                blocks['block%d_%d' % (i, branch)] = OrderedDict([
                    ('Mconv1_' + suffix, [185, 128, 7, 1, 3]),
                    ('Mconv2_' + suffix, [128, 128, 7, 1, 3]),
                    ('Mconv3_' + suffix, [128, 128, 7, 1, 3]),
                    ('Mconv4_' + suffix, [128, 128, 7, 1, 3]),
                    ('Mconv5_' + suffix, [128, 128, 7, 1, 3]),
                    ('Mconv6_' + suffix, [128, 128, 1, 1, 0]),
                    ('Mconv7_' + suffix, [128, out_channels, 1, 1, 0])
                ])

        for k in blocks.keys():
            blocks[k] = make_layers(blocks[k], no_relu_layers)

        # Register as model<stage>_<branch>; all L1 branches first, then all
        # L2 branches, matching the historical attribute order.
        for branch in (1, 2):
            for stage in range(1, 7):
                setattr(self, 'model%d_%d' % (stage, branch),
                        blocks['block%d_%d' % (stage, branch)])

    def forward(self, x):
        """Run all six stages and return the last stage's two outputs.

        Args:
            x: image batch with 3 channels (first conv expects 3 inputs).

        Returns:
            Tuple ``(out6_1, out6_2)``: the stage-6 L1 (38-channel) and L2
            (19-channel) outputs.
        """
        out1 = self.model0(x)
        stage_input = out1
        for stage in range(1, 7):
            out_l1 = getattr(self, 'model%d_1' % stage)(stage_input)
            out_l2 = getattr(self, 'model%d_2' % stage)(stage_input)
            if stage < 6:
                # Next stage sees both predictions plus the backbone features.
                stage_input = torch.cat([out_l1, out_l2, out1], 1)
        return out_l1, out_l2
class handpose_model(nn.Module):
    """Six-stage single-branch convolutional network for hand pose.

    ``model1_0`` extracts 128-channel features, ``model1_1`` produces the
    first 22-channel prediction, and each of ``model2`` .. ``model6`` refines
    the previous prediction concatenated with the shared features
    (22 + 128 = 150 input channels).
    """

    def __init__(self):
        super(handpose_model, self).__init__()

        # Prediction layers: no ReLU is appended after these.
        no_relu_layers = ['conv6_2_CPM']
        no_relu_layers += ['Mconv7_stage%d' % stage for stage in range(2, 7)]

        specs = {}
        # Stage 1: feature extractor followed by the first prediction head.
        specs['block1_0'] = OrderedDict([
            ('conv1_1', [3, 64, 3, 1, 1]),
            ('conv1_2', [64, 64, 3, 1, 1]),
            ('pool1_stage1', [2, 2, 0]),
            ('conv2_1', [64, 128, 3, 1, 1]),
            ('conv2_2', [128, 128, 3, 1, 1]),
            ('pool2_stage1', [2, 2, 0]),
            ('conv3_1', [128, 256, 3, 1, 1]),
            ('conv3_2', [256, 256, 3, 1, 1]),
            ('conv3_3', [256, 256, 3, 1, 1]),
            ('conv3_4', [256, 256, 3, 1, 1]),
            ('pool3_stage1', [2, 2, 0]),
            ('conv4_1', [256, 512, 3, 1, 1]),
            ('conv4_2', [512, 512, 3, 1, 1]),
            ('conv4_3', [512, 512, 3, 1, 1]),
            ('conv4_4', [512, 512, 3, 1, 1]),
            ('conv5_1', [512, 512, 3, 1, 1]),
            ('conv5_2', [512, 512, 3, 1, 1]),
            ('conv5_3_CPM', [512, 128, 3, 1, 1])
        ])
        specs['block1_1'] = OrderedDict([
            ('conv6_1_CPM', [128, 512, 1, 1, 0]),
            ('conv6_2_CPM', [512, 22, 1, 1, 0])
        ])

        # Stages 2-6 share one layout; only the layer names change.
        for stage in range(2, 7):
            tag = 'stage%d' % stage
            specs['block%d' % stage] = OrderedDict([
                ('Mconv1_' + tag, [150, 128, 7, 1, 3]),
                ('Mconv2_' + tag, [128, 128, 7, 1, 3]),
                ('Mconv3_' + tag, [128, 128, 7, 1, 3]),
                ('Mconv4_' + tag, [128, 128, 7, 1, 3]),
                ('Mconv5_' + tag, [128, 128, 7, 1, 3]),
                ('Mconv6_' + tag, [128, 128, 1, 1, 0]),
                ('Mconv7_' + tag, [128, 22, 1, 1, 0])
            ])

        built = {}
        for key in specs:
            built[key] = make_layers(specs[key], no_relu_layers)

        self.model1_0 = built['block1_0']
        self.model1_1 = built['block1_1']
        self.model2 = built['block2']
        self.model3 = built['block3']
        self.model4 = built['block4']
        self.model5 = built['block5']
        self.model6 = built['block6']

    def forward(self, x):
        """Return the stage-6 prediction (22 channels) for image batch ``x``."""
        features = self.model1_0(x)
        prediction = self.model1_1(features)
        for refine in (self.model2, self.model3, self.model4,
                       self.model5, self.model6):
            # Each refinement stage sees the previous prediction + features.
            prediction = refine(torch.cat([prediction, features], 1))
        return prediction
================================================
FILE: stylegan_human/openpose/src/util.py
================================================
import numpy as np
import math
import cv2
import matplotlib
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
import numpy as np
import matplotlib.pyplot as plt
import cv2
def padRightDownCorner(img, stride, padValue):
    """Pad an image on the bottom and right so both sides divide ``stride``.

    Args:
        img: array indexed as ``[row, column, channel]``.
        stride: the padded height and width become multiples of this value.
        padValue: constant written into every padded pixel.

    Returns:
        ``(img_padded, pad)`` where ``pad`` is ``[up, left, down, right]``,
        the number of rows/columns added on each side (up and left are
        always 0).
    """
    h = img.shape[0]
    w = img.shape[1]

    # (-n) % stride is 0 when n is already a multiple of stride, otherwise the
    # distance up to the next multiple.
    pad = [0, 0, (-h) % stride, (-w) % stride]  # up, left, down, right

    def constant_like(block):
        # Same shape as `block`, filled with padValue; `* 0 +` keeps the
        # dtype promotion the function has always had.
        return block * 0 + padValue

    img_padded = img
    pad_up = np.tile(constant_like(img_padded[0:1, :, :]), (pad[0], 1, 1))
    img_padded = np.concatenate((pad_up, img_padded), axis=0)
    pad_left = np.tile(constant_like(img_padded[:, 0:1, :]), (1, pad[1], 1))
    img_padded = np.concatenate((pad_left, img_padded), axis=1)
    # Use the LAST row/column as the shape template. The previous `[-2:-1]`
    # slice is empty when the image has a single row/column, which silently
    # skipped the padding while still reporting it in `pad`.
    pad_down = np.tile(constant_like(img_padded[-1:, :, :]), (pad[2], 1, 1))
    img_padded = np.concatenate((img_padded, pad_down), axis=0)
    pad_right = np.tile(constant_like(img_padded[:, -1:, :]), (1, pad[3], 1))
    img_padded = np.concatenate((img_padded, pad_right), axis=1)

    return img_padded, pad
# transfer caffe model to pytorch which will match the layer name
def transfer(model, model_weights):
    """Re-key a weight dict so it matches ``model.state_dict()``.

    For each key of ``model.state_dict()``, the first dotted component is
    dropped and the remainder is used to look up the value in
    ``model_weights``.

    Args:
        model: object exposing ``state_dict()``.
        model_weights: mapping keyed by names without the leading component.

    Returns:
        Dict keyed like ``model.state_dict()``.

    Raises:
        KeyError: if a stripped name is missing from ``model_weights``.
    """
    return {
        full_name: model_weights[full_name.partition('.')[2]]
        for full_name in model.state_dict().keys()
    }
# draw the body keypoints and limbs
def draw_bodypose(canvas, candidate, subset,show_number=False):
    """Draw detected body keypoints and limbs onto an image.

    Args:
        canvas: image to draw on. Keypoint circles (and optional labels) are
            drawn directly into this array.
        candidate: array whose rows start with ``x, y`` for each detected
            keypoint.
        subset: one row per person; entries 0-17 are indices into
            ``candidate`` (``-1`` = part not found).
        show_number: when true, write each keypoint's candidate index next
            to it.

    Returns:
        The image with limbs blended in. Limb blending creates a new array,
        so use the return value rather than relying on ``canvas``.
    """
    stickwidth = 4
    # Pairs of 1-based part numbers joined by a limb.
    limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
               [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
               [1, 16], [16, 18], [3, 17], [6, 18]]

    colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
              [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
              [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]

    # Pass 1: one filled circle per detected keypoint.
    for part in range(18):
        for person in subset:
            index = int(person[part])
            if index == -1:
                continue
            x, y = candidate[index][0:2]
            center = (int(x), int(y))
            cv2.circle(canvas, center, 4, colors[part], thickness=-1)
            if show_number:
                cv2.putText(canvas, f'{index}', center, cv2.FONT_HERSHEY_SIMPLEX, 0.6,
                            (255,255,0), 1, cv2.LINE_AA)

    # Pass 2: one semi-transparent ellipse per limb. Only the first 17 entries
    # of limbSeq are drawn.
    for limb in range(17):
        part_slots = np.array(limbSeq[limb]) - 1  # 1-based -> 0-based
        for person in subset:
            index = person[part_slots]
            if -1 in index:
                continue
            overlay = canvas.copy()
            endpoints = index.astype(int)
            xs = candidate[endpoints, 0]
            ys = candidate[endpoints, 1]
            mean_y = np.mean(ys)
            mean_x = np.mean(xs)
            length = ((ys[0] - ys[1]) ** 2 + (xs[0] - xs[1]) ** 2) ** 0.5
            angle = math.degrees(math.atan2(ys[0] - ys[1], xs[0] - xs[1]))
            polygon = cv2.ellipse2Poly((int(mean_x), int(mean_y)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
            cv2.fillConvexPoly(overlay, polygon, colors[limb])
            # Blend: 40% previous image, 60% image with the new limb.
            canvas = cv2.addWeighted(canvas, 0.4, overlay, 0.6, 0)

    return canvas
# get max index of 2d array
def npmax(array):
    """Return ``(row, col)`` of the largest entry of a 2-D array."""
    best_col_per_row = array.argmax(1)
    # Row whose own maximum is the largest overall.
    row = array.max(1).argmax()
    return row, best_col_per_row[row]
# get max index of 2d array together with the value stored there
def npmax_with_score(array):
    """Return ``(row, col, value)`` for the largest entry of a 2-D array."""
    best_col_per_row = array.argmax(1)
    # Row whose own maximum is the largest overall.
    row = array.max(1).argmax()
    col = best_col_per_row[row]
    return row, col, array[row][col]
================================================
FILE: stylegan_human/pti/pti_configs/__init__.py
================================================
================================================
FILE: stylegan_human/pti/pti_configs/global_config.py
================================================
# Module-level settings shared across the PTI code. Other modules are expected
# to read (and, for run_name, overwrite) these attributes at runtime.
# NOTE(review): the meanings below are inferred from the names only; confirm
# against the code that consumes them.

## Device
cuda_visible_devices = '0'  # string form, presumably for CUDA_VISIBLE_DEVICES
device = 'cuda:0'  # torch-style device string

## Logs
training_step = 1
image_rec_result_log_snapshot = 100
pivotal_training_steps = 0
model_snapshot_interval = 400

## Run name to be updated during PTI
run_name = 'exp'
================================================
FILE: stylegan_human/pti/pti_configs/hyperparameters.py
================================================
# Hyper-parameters for inversion / pivotal tuning, grouped by purpose.
# NOTE(review): units and exact semantics are defined by the training code
# that reads these values, which is not part of this file.

## Architecture
lpips_type = 'alex'
first_inv_type = 'w+'  # 'w+'
optim_type = 'adam'

## Locality regularization
latent_ball_num_of_samples = 1
locality_regularization_interval = 1
use_locality_regularization = False
regulizer_l2_lambda = 0.1
regulizer_lpips_lambda = 0.1
regulizer_alpha = 30

## Loss
pt_l2_lambda = 1
pt_lpips_lambda = 1

## Steps
LPIPS_value_threshold = 0.04
max_pti_steps = 350
first_inv_steps = 450
max_images_to_invert = 30

## Optimization
pti_learning_rate = 5e-4
first_inv_lr = 8e-3
train_batch_size = 1
use_last_w_pivots = False
================================================
FILE: stylegan_human/pti/pti_configs/paths_config.py
================================================
import os
# File-system locations and naming keywords used by the PTI code.
# All relative paths are resolved against the process working directory.

## Pretrained models paths
e4e = './pti/e4e_w+.pt'
stylegan2_ada_shhq = './pretrained_models/stylegan_human_v2_1024.pkl'
ir_se50 = ''  # './model_ir_se50.pth' -- empty by default; set before use

## Dirs for output files
checkpoints_dir = './outputs/pti/checkpoints/'
embedding_base_dir = './outputs/pti/embeddings'
experiments_output_dir = './outputs/pti/'

## Input info
### Input dir, where the images reside
input_data_path = 'aligned_image/'
### Inversion identifier, used to keep track of the inversion results
### (both the latent code and the generator)
input_data_id = 'test'

## Keywords
pti_results_keyword = 'PTI'
e4e_results_keyword = 'e4e'
sg2_results_keyword = 'SG2'
sg2_plus_results_keyword = 'SG2_Plus'
multi_id_model_type = 'multi_id'
================================================
FILE: stylegan_human/pti/pti_models/__init__.py
================================================
================================================
FILE: stylegan_human/pti/pti_models/e4e/__init__.py
================================================
================================================
FILE: stylegan_human/pti/pti_models/e4e/encoders/__init__.py
================================================
================================================
FILE: stylegan_human/pti/pti_models/e4e/encoders/helpers.py
================================================
from collections import namedtuple
import torch
import torch.nn.functional as F
from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
"""
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""
class Flatten(Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, input):
        batch = input.shape[0]
        return input.view(batch, -1)
def l2_norm(input, axis=1):
    """Scale ``input`` to unit L2 norm along ``axis``.

    The norm is computed with ``keepdim=True`` so the division broadcasts
    back over the original shape.
    """
    return torch.div(input, torch.norm(input, 2, axis, True))
class Bottleneck(namedtuple('Block', 'in_channel depth stride')):
    """Immutable description of one ResNet unit: input channels, depth, stride."""
def get_block(in_channel, depth, num_units, stride=2):
    """Describe one ResNet stage as a list of ``Bottleneck`` specs.

    The first unit maps ``in_channel -> depth`` with the given ``stride``;
    the remaining ``num_units - 1`` units keep ``depth`` and use stride 1.
    """
    units = [Bottleneck(in_channel, depth, stride)]
    for _ in range(num_units - 1):
        units.append(Bottleneck(depth, depth, 1))
    return units
def get_blocks(num_layers):
    """Return the four-stage block layout for a 50/100/152-layer network.

    Every variant uses the same channel progression
    (64->64, 64->128, 128->256, 256->512); only the number of units per stage
    depends on ``num_layers``.

    Raises:
        ValueError: if ``num_layers`` is not 50, 100 or 152.
    """
    stage_channels = ((64, 64), (64, 128), (128, 256), (256, 512))
    units_per_stage = (
        (50, (3, 4, 14, 3)),
        (100, (3, 13, 30, 3)),
        (152, (3, 8, 36, 3)),
    )
    for supported, unit_counts in units_per_stage:
        if num_layers == supported:
            return [
                get_block(in_channel=c_in, depth=c_out, num_units=count)
                for (c_in, c_out), count in zip(stage_channels, unit_counts)
            ]
    raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
class SEModule(Module):
    """Channel gate: global-average-pool -> 1x1 conv -> ReLU -> 1x1 conv -> sigmoid.

    The resulting per-channel weights (one value per channel and sample) are
    multiplied back onto the input.

    Args:
        channels: number of input/output channels.
        reduction: the hidden 1x1 conv uses ``channels // reduction`` channels.
    """

    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        squeezed = channels // reduction
        self.avg_pool = AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(channels, squeezed, kernel_size=1, padding=0, bias=False)
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(squeezed, channels, kernel_size=1, padding=0, bias=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        gate = self.relu(self.fc1(self.avg_pool(x)))
        gate = self.sigmoid(self.fc2(gate))
        return x * gate
class bottleneck_IR(Module):
    """Residual unit: BN -> 3x3 conv -> PReLU -> 3x3 conv (strided) -> BN.

    The shortcut is a stride-only ``MaxPool2d(1, stride)`` when the channel
    count is unchanged, otherwise a strided 1x1 conv followed by BatchNorm.
    """

    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR, self).__init__()
        if in_channel != depth:
            # Channel count changes: project with a 1x1 conv.
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth)
            )
        else:
            self.shortcut_layer = MaxPool2d(1, stride)
        residual_ops = [
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
        ]
        self.res_layer = Sequential(*residual_ops)

    def forward(self, x):
        identity = self.shortcut_layer(x)
        return self.res_layer(x) + identity
class bottleneck_IR_SE(Module):
    """Residual unit like ``bottleneck_IR`` with an ``SEModule(depth, 16)`` appended.

    Residual path: BN -> 3x3 conv -> PReLU -> 3x3 conv (strided) -> BN -> SE.
    The shortcut is ``MaxPool2d(1, stride)`` when the channel count is
    unchanged, otherwise a strided 1x1 conv followed by BatchNorm.
    """

    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR_SE, self).__init__()
        if in_channel != depth:
            # Channel count changes: project with a 1x1 conv.
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth)
            )
        else:
            self.shortcut_layer = MaxPool2d(1, stride)
        residual_ops = [
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
            SEModule(depth, 16),
        ]
        self.res_layer = Sequential(*residual_ops)

    def forward(self, x):
        identity = self.shortcut_layer(x)
        return self.res_layer(x) + identity
def _upsample_add(x, y):
    """Resize ``x`` to the spatial size of ``y`` and add the two maps.

    Args:
        x: top feature map to be resized (4-D).
        y: lateral feature map (4-D); its height and width are the target size.

    Returns:
        ``y`` plus ``x`` bilinearly resized (``align_corners=True``).

    Bilinear interpolation to an explicit target size is used instead of a
    fixed ``scale_factor=2`` because a factor-2 upsample does not always
    match the lateral map when sizes are odd, e.g.
        original input size: [N,_,15,15] ->
        conv2d feature map size: [N,_,8,8] ->
        upsampled feature map size: [N,_,16,16]
    """
    _batch, _channels, target_h, target_w = y.size()
    resized = F.interpolate(x, size=(target_h, target_w), mode='bilinear', align_corners=True)
    return resized + y
================================================
FILE: stylegan_human/pti/pti_models/e4e/encoders/model_irse.py
================================================
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
from encoder4editing.models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
"""
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""
class Backbone(Module):
    """IR / IR-SE ResNet that maps an image to a unit-norm 512-d embedding.

    Args:
        input_size: 112 or 224; selects the size of the final linear layer
            (512*7*7 or 512*14*14 inputs).
        num_layers: 50, 100 or 152 (layout from ``get_blocks``).
        mode: ``'ir'`` for ``bottleneck_IR`` units, ``'ir_se'`` for
            ``bottleneck_IR_SE`` units.
        drop_ratio: dropout probability in the output head.
        affine: whether the final ``BatchNorm1d`` has learnable affine params.
    """

    def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
        super(Backbone, self).__init__()
        assert input_size in [112, 224], "input_size should be 112 or 224"
        assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
        assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE

        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))

        # Side length of the 512-channel map entering the head.
        side = 7 if input_size == 112 else 14
        self.output_layer = Sequential(BatchNorm2d(512),
                                       Dropout(drop_ratio),
                                       Flatten(),
                                       Linear(512 * side * side, 512),
                                       BatchNorm1d(512, affine=affine))

        units = [
            unit_module(spec.in_channel, spec.depth, spec.stride)
            for stage in blocks
            for spec in stage
        ]
        self.body = Sequential(*units)

    def forward(self, x):
        """Return L2-normalised embeddings for image batch ``x``."""
        embedding = self.output_layer(self.body(self.input_layer(x)))
        return l2_norm(embedding)
def IR_50(input_size):
    """Build a ``Backbone`` with num_layers=50 in 'ir' mode (dropout 0.4, affine=False)."""
    return Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
def IR_101(input_size):
    """Build the 'ir-101' ``Backbone`` (passes num_layers=100; dropout 0.4, affine=False)."""
    return Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
def IR_152(input_size):
    """Build a ``Backbone`` with num_layers=152 in 'ir' mode (dropout 0.4, affine=False)."""
    return Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
def IR_SE_50(input_size):
    """Build a ``Backbone`` with num_layers=50 in 'ir_se' mode (dropout 0.4, affine=False)."""
    return Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
def IR_SE_101(input_size):
    """Build the 'ir_se-101' ``Backbone`` (passes num_layers=100; dropout 0.4, affine=False)."""
    return Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
def IR_SE_152(input_size):
    """Build a ``Backbone`` with num_layers=152 in 'ir_se' mode (dropout 0.4, affine=False)."""
    return Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
================================================
FILE: stylegan_human/pti/pti_models/e4e/encoders/psp_encoders.py
================================================
from enum import Enum
import math
import numpy as np
import torch
from torch import nn
from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module
from pti.pti_models.e4e.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add
from pti.pti_models.e4e.stylegan2.model import EqualLinear
class ProgressiveStage(Enum):
    """Progressive schedule stages for ``Encoder4Editing``.

    The integer value is what matters: ``Encoder4Editing.forward`` adds
    per-style deltas for style indices ``1 .. min(value, style_count - 1)``.
    ``WTraining`` (0) therefore yields the base code repeated for every
    style with no deltas, and ``Inference`` (18) enables every delta the
    encoder has.
    """
    # No deltas: only the shared base code is produced.
    WTraining = 0
    # DeltaKTraining: deltas for style indices 1..K are applied.
    Delta1Training = 1
    Delta2Training = 2
    Delta3Training = 3
    Delta4Training = 4
    Delta5Training = 5
    Delta6Training = 6
    Delta7Training = 7
    Delta8Training = 8
    Delta9Training = 9
    Delta10Training = 10
    Delta11Training = 11
    Delta12Training = 12
    Delta13Training = 13
    Delta14Training = 14
    Delta15Training = 15
    Delta16Training = 16
    Delta17Training = 17
    # Default stage set by Encoder4Editing.__init__.
    Inference = 18
class GradualStyleBlock(Module):
    """Map a feature map to one ``out_c``-dimensional style vector.

    Applies ``int(log2(spatial))`` stride-2 3x3 convolutions (each followed by
    LeakyReLU), reshapes the result to ``(-1, out_c)`` and passes it through
    an ``EqualLinear`` layer.

    Args:
        in_c: channels of the incoming feature map.
        out_c: channels of every conv and size of the style vector.
        spatial: nominal side length of the incoming map; determines the
            number of stride-2 convs.
    """

    def __init__(self, in_c, out_c, spatial):
        super(GradualStyleBlock, self).__init__()
        self.out_c = out_c
        self.spatial = spatial
        num_pools = int(np.log2(spatial))
        # First conv changes the channel count, the rest keep it.
        layers = [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
                  nn.LeakyReLU()]
        for _ in range(num_pools - 1):
            layers.append(Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1))
            layers.append(nn.LeakyReLU())
        self.convs = nn.Sequential(*layers)
        self.linear = EqualLinear(out_c, out_c, lr_mul=1)

    def forward(self, x):
        reduced = self.convs(x)
        return self.linear(reduced.view(-1, self.out_c))
class GradualStyleEncoder(Module):
    """Encoder producing one style vector per generator style input.

    An IR / IR-SE body is tapped after units 6, 20 and 23. The deepest tap
    feeds the first ``coarse_ind`` style heads; the other two taps are merged
    top-down (``_upsample_add`` with 1x1 lateral convs) to feed the middle and
    fine style heads.

    Args:
        num_layers: 50, 100 or 152 (layout from ``get_blocks``).
        mode: ``'ir'`` or ``'ir_se'``.
        opts: must provide ``stylegan_size``; the number of styles is
            ``2 * log2(stylegan_size) - 2``.
    """

    def __init__(self, num_layers, mode='ir', opts=None):
        super(GradualStyleEncoder, self).__init__()
        assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        units = [
            unit_module(spec.in_channel, spec.depth, spec.stride)
            for stage in blocks
            for spec in stage
        ]
        self.body = Sequential(*units)

        self.styles = nn.ModuleList()
        log_size = int(math.log(opts.stylegan_size, 2))
        self.style_count = 2 * log_size - 2
        self.coarse_ind = 3
        self.middle_ind = 7
        for i in range(self.style_count):
            # Coarse / middle / fine heads read maps of nominal side 16 / 32 / 64.
            if i < self.coarse_ind:
                spatial = 16
            elif i < self.middle_ind:
                spatial = 32
            else:
                spatial = 64
            self.styles.append(GradualStyleBlock(512, 512, spatial))
        self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return a tensor of shape ``(batch, style_count, 512)``."""
        x = self.input_layer(x)

        # Collect the three intermediate feature maps used by the style heads.
        taps = {}
        for idx, unit in enumerate(self.body):
            x = unit(x)
            if idx in (6, 20, 23):
                taps[idx] = x
        c1, c2, c3 = taps[6], taps[20], taps[23]

        latents = [self.styles[j](c3) for j in range(self.coarse_ind)]

        p2 = _upsample_add(c3, self.latlayer1(c2))
        latents.extend(self.styles[j](p2)
                       for j in range(self.coarse_ind, self.middle_ind))

        p1 = _upsample_add(p2, self.latlayer2(c1))
        latents.extend(self.styles[j](p1)
                       for j in range(self.middle_ind, self.style_count))

        return torch.stack(latents, dim=1)
class Encoder4Editing(Module):
    """Encoder that predicts one base code plus per-style deltas.

    The first style head produces a single code from the deepest feature
    map; it is repeated for every style, and the remaining heads add deltas
    on top. How many deltas are applied is limited by ``progressive_stage``.

    Args:
        num_layers: 50, 100 or 152 (layout from ``get_blocks``).
        mode: ``'ir'`` or ``'ir_se'``.
        opts: must provide ``stylegan_size``; the number of styles is
            ``2 * log2(stylegan_size) - 2``.
    """

    def __init__(self, num_layers, mode='ir', opts=None):
        super(Encoder4Editing, self).__init__()
        assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        units = [
            unit_module(spec.in_channel, spec.depth, spec.stride)
            for stage in blocks
            for spec in stage
        ]
        self.body = Sequential(*units)

        self.styles = nn.ModuleList()
        log_size = int(math.log(opts.stylegan_size, 2))
        self.style_count = 2 * log_size - 2
        self.coarse_ind = 3
        self.middle_ind = 7

        for i in range(self.style_count):
            # Coarse / middle / fine heads read maps of nominal side 16 / 32 / 64.
            if i < self.coarse_ind:
                spatial = 16
            elif i < self.middle_ind:
                spatial = 32
            else:
                spatial = 64
            self.styles.append(GradualStyleBlock(512, 512, spatial))

        self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)

        self.progressive_stage = ProgressiveStage.Inference

    def get_deltas_starting_dimensions(self):
        ''' Get a list of the initial dimension of every delta from which it is applied '''
        # Every style index gets its own delta.
        return list(range(self.style_count))

    def set_progressive_stage(self, new_stage: ProgressiveStage):
        """Switch the progressive stage and report the change on stdout."""
        self.progressive_stage = new_stage
        print('Changed progressive stage to: ', new_stage)

    def forward(self, x):
        """Return codes of shape ``(batch, style_count, 512)``."""
        x = self.input_layer(x)

        # Collect the three intermediate feature maps used by the style heads.
        taps = {}
        for idx, unit in enumerate(self.body):
            x = unit(x)
            if idx in (6, 20, 23):
                taps[idx] = x
        c1, c2, c3 = taps[6], taps[20], taps[23]

        # Base code from the deepest map, duplicated for every style.
        w0 = self.styles[0](c3)
        w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2)

        stage = self.progressive_stage.value
        features = c3
        last_delta = min(stage + 1, self.style_count)
        for i in range(1, last_delta):
            # Switch to a finer feature map when entering the middle/fine range.
            if i == self.coarse_ind:
                p2 = _upsample_add(c3, self.latlayer1(c2))
                features = p2
            elif i == self.middle_ind:
                p1 = _upsample_add(p2, self.latlayer2(c1))
                features = p1
            w[:, i] += self.styles[i](features)
        return w
================================================
FILE: stylegan_human/pti/pti_models/e4e/latent_codes_pool.py
================================================
import random
import torch
class LatentCodesPool:
    """History buffer of previously generated w latent codes.

    Lets a discriminator be updated with a mix of current and older codes
    instead of only those from the latest encoder.
    """

    def __init__(self, pool_size):
        """Create the pool.

        Parameters:
            pool_size (int) -- capacity of the buffer; with pool_size=0 no
                               buffer is created and ``query`` is a pass-through
        """
        self.pool_size = pool_size
        if self.pool_size > 0:
            # Start with an empty buffer.
            self.num_ws = 0
            self.ws = []

    def query(self, ws):
        """Return a batch of codes drawn from the inputs and the buffer.

        Parameters:
            ws: latest codes, shape (batch, 512) or (batch, n_latent, 512)

        While the buffer is filling, inputs are stored and returned as-is.
        Once full, each input is either returned unchanged or swapped with a
        randomly chosen stored code (50% chance each). For
        (batch, n_latent, 512) inputs, one random latent per sample is used.
        """
        if self.pool_size == 0:
            return ws
        picked = []
        for w in ws:
            if w.ndim == 2:
                # Several latents for this sample: use one at random.
                w = w[random.randint(0, len(w) - 1)]
            self.handle_w(w, picked)
        return torch.stack(picked, 0)

    def handle_w(self, w, return_ws):
        """Append to ``return_ws`` either ``w`` or a stored code swapped out for it."""
        if self.num_ws < self.pool_size:
            # Buffer not full yet: store the code and return it unchanged.
            self.num_ws += 1
            self.ws.append(w)
            return_ws.append(w)
            return
        if random.uniform(0, 1) > 0.5:
            # Swap: hand back an old code and keep the new one in its slot.
            slot = random.randint(0, self.pool_size - 1)  # randint is inclusive
            previous = self.ws[slot].clone()
            self.ws[slot] = w
            return_ws.append(previous)
        else:
            return_ws.append(w)
================================================
FILE: stylegan_human/pti/pti_models/e4e/psp.py
================================================
import matplotlib
from pti.pti_configs import paths_config
matplotlib.use('Agg')
import torch
from torch import nn
from pti.pti_models.e4e.encoders import psp_encoders
from pti.pti_models.e4e.stylegan2.model import Generator
def get_keys(d, name):
    """Extract the entries of a checkpoint that belong to sub-module ``name``.

    If ``d`` has a ``'state_dict'`` entry, that nested dict is searched
    instead. Keys starting with ``name`` are kept, with ``name`` and the one
    character following it (the ``.`` separator) removed.
    """
    source = d['state_dict'] if 'state_dict' in d else d
    prefix_len = len(name)
    selected = {}
    for key, value in source.items():
        if key[:prefix_len] == name:
            selected[key[prefix_len + 1:]] = value
    return selected
class pSp(nn.Module):
    """Encoder + StyleGAN2 generator wrapper (pSp / e4e framework).

    Args:
        opts: options object. Attributes read here: ``encoder_type``,
            ``stylegan_size``, ``checkpoint_path``, ``stylegan_weights``,
            ``start_from_latent_avg`` and ``device``.
    """

    def __init__(self, opts):
        super(pSp, self).__init__()
        self.opts = opts
        # Define architecture
        self.encoder = self.set_encoder()
        self.decoder = Generator(opts.stylegan_size, 512, 8, channel_multiplier=2)
        self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256 // 2))
        # Load weights if needed
        self.load_weights()

    def set_encoder(self):
        """Instantiate the encoder selected by ``opts.encoder_type``.

        Raises:
            Exception: if ``opts.encoder_type`` is not a known encoder name.
        """
        if self.opts.encoder_type == 'GradualStyleEncoder':
            encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts)
        elif self.opts.encoder_type == 'Encoder4Editing':
            encoder = psp_encoders.Encoder4Editing(50, 'ir_se', self.opts)
        elif self.opts.encoder_type == 'SingleStyleCodeEncoder':
            # NOTE(review): psp_encoders in this package does not appear to
            # define BackboneEncoderUsingLastLayerIntoW -- confirm before
            # relying on this encoder type.
            encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts)
        else:
            raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type))
        return encoder

    def load_weights(self):
        """Load encoder/decoder weights and the average latent.

        With ``opts.checkpoint_path`` set, everything comes from that single
        checkpoint. Otherwise the encoder is initialised from the IR-SE50
        weights at ``paths_config.ir_se50`` and the decoder from
        ``opts.stylegan_weights``.
        """
        if self.opts.checkpoint_path is not None:
            print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.checkpoint_path))
            ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
            self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
            self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
            self.__load_latent_avg(ckpt)
        else:
            print('Loading encoders weights from irse50!')
            # The path lives in paths_config; the previously referenced
            # `model_paths` name was never defined in this module.
            encoder_ckpt = torch.load(paths_config.ir_se50)
            self.encoder.load_state_dict(encoder_ckpt, strict=False)
            print('Loading decoder weights from pretrained!')
            ckpt = torch.load(self.opts.stylegan_weights)
            self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
            self.__load_latent_avg(ckpt, repeat=self.encoder.style_count)

    def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
                inject_latent=None, return_latents=False, alpha=None):
        """Encode (unless ``input_code``) and decode ``x``.

        Args:
            x: images, or latent codes when ``input_code`` is true.
            resize: pool the generated images with ``face_pool``.
            latent_mask: iterable of style indices to overwrite.
            input_code: treat ``x`` as codes and skip the encoder.
            randomize_noise: forwarded to the decoder.
            inject_latent: codes used for the masked indices; without it the
                masked indices are zeroed.
            return_latents: also return the decoder's latents.
            alpha: blend weight for ``inject_latent`` (``None`` = replace).

        Returns:
            ``images`` or ``(images, result_latent)`` if ``return_latents``.
        """
        if input_code:
            codes = x
        else:
            codes = self.encoder(x)
            # normalize with respect to the center of an average face
            if self.opts.start_from_latent_avg:
                if codes.ndim == 2:
                    codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
                else:
                    codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)

        if latent_mask is not None:
            for i in latent_mask:
                if inject_latent is not None:
                    if alpha is not None:
                        codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
                    else:
                        codes[:, i] = inject_latent[:, i]
                else:
                    codes[:, i] = 0

        # Codes coming from the encoder are passed to the decoder as latents;
        # user-supplied codes are not.
        input_is_latent = not input_code
        images, result_latent = self.decoder([codes],
                                             input_is_latent=input_is_latent,
                                             randomize_noise=randomize_noise,
                                             return_latents=return_latents)

        if resize:
            images = self.face_pool(images)

        if return_latents:
            return images, result_latent
        else:
            return images

    def __load_latent_avg(self, ckpt, repeat=None):
        """Set ``self.latent_avg`` from ``ckpt`` (``None`` if it has no such entry).

        When ``repeat`` is given, the average is tiled that many times along
        the first dimension.
        """
        if 'latent_avg' in ckpt:
            self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
            if repeat is not None:
                self.latent_avg = self.latent_avg.repeat(repeat, 1)
        else:
            self.latent_avg = None
================================================
FILE: stylegan_human/pti/pti_models/e4e/stylegan2/__init__.py
================================================
================================================
FILE: stylegan_human/pti/pti_models/e4e/stylegan2/model.py
================================================
import math
import random
import torch
from torch import nn
from torch.nn import functional as F
from .op.fused_act import FusedLeakyReLU, fused_leaky_relu
from .op.upfirdn2d import upfirdn2d
class PixelNorm(nn.Module):
    """Divide each position by the RMS of its values along dim 1 (eps 1e-8)."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        mean_square = torch.mean(input ** 2, dim=1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + 1e-8)
        return input * inv_rms
def make_kernel(k):
    """Build a float32 filter kernel normalised to sum to 1.

    A 1-D input is expanded to 2-D by multiplying it with its own transpose
    (outer product); a 2-D input is used as given.
    """
    kernel = torch.tensor(k, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = kernel[None, :] * kernel[:, None]
    kernel /= kernel.sum()
    return kernel
class Upsample(nn.Module):
    """Upsample by ``factor`` with ``upfirdn2d`` and a fixed filter kernel.

    The normalised kernel is multiplied by ``factor ** 2`` and stored as the
    buffer ``kernel``; the padding pair is derived from the kernel size and
    ``factor``.
    """

    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        scaled_kernel = make_kernel(kernel) * (factor ** 2)
        self.register_buffer('kernel', scaled_kernel)

        excess = scaled_kernel.shape[0] - factor
        self.pad = ((excess + 1) // 2 + factor - 1, excess // 2)

    def forward(self, input):
        return upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
class Downsample(nn.Module):
    """Downsample by ``factor`` with ``upfirdn2d`` and a fixed filter kernel.

    The normalised kernel is stored as the buffer ``kernel``; the padding
    pair is derived from the kernel size and ``factor``.
    """

    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        normalised = make_kernel(kernel)
        self.register_buffer('kernel', normalised)

        excess = normalised.shape[0] - factor
        self.pad = ((excess + 1) // 2, excess // 2)

    def forward(self, input):
        return upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
class Blur(nn.Module):
    """Filter the input with a fixed kernel via ``upfirdn2d``.

    Args:
        kernel: filter taps passed to ``make_kernel``.
        pad: padding pair forwarded to ``upfirdn2d``.
        upsample_factor: when greater than 1, the kernel is multiplied by
            ``upsample_factor ** 2``.
    """

    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        blur_kernel = make_kernel(kernel)
        if upsample_factor > 1:
            blur_kernel = blur_kernel * (upsample_factor ** 2)

        self.register_buffer('kernel', blur_kernel)
        self.pad = pad

    def forward(self, input):
        return upfirdn2d(input, self.kernel, pad=self.pad)
class EqualConv2d(nn.Module):
    """2-D convolution whose weight is rescaled at every forward pass.

    The weight is drawn from a standard normal and multiplied by
    ``1 / sqrt(in_channel * kernel_size ** 2)`` each time it is used.

    Args:
        in_channel: input channels.
        out_channel: output channels.
        kernel_size: side of the square kernel.
        stride: convolution stride.
        padding: convolution padding.
        bias: add a zero-initialised learnable bias when true.
    """

    def __init__(
        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
    ):
        super().__init__()

        weight_shape = (out_channel, in_channel, kernel_size, kernel_size)
        self.weight = nn.Parameter(torch.randn(*weight_shape))
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

        self.stride = stride
        self.padding = padding

        self.bias = nn.Parameter(torch.zeros(out_channel)) if bias else None

    def forward(self, input):
        scaled_weight = self.weight * self.scale
        return F.conv2d(
            input,
            scaled_weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

    def __repr__(self):
        out_ch, in_ch, ksize = self.weight.shape[0], self.weight.shape[1], self.weight.shape[2]
        return (
            f'{self.__class__.__name__}({in_ch}, {out_ch},'
            f' {ksize}, stride={self.stride}, padding={self.padding})'
        )
class EqualLinear(nn.Module):
    """Linear layer with run-time weight scaling and a learning-rate multiplier.

    The weight is initialised as ``randn / lr_mul`` and multiplied by
    ``lr_mul / sqrt(in_dim)`` in every forward pass; the bias (if any) is
    multiplied by ``lr_mul``.

    Args:
        in_dim: Number of input features.
        out_dim: Number of output features.
        bias: Whether to learn a bias term.
        bias_init: Initial value of every bias entry.
        lr_mul: Learning-rate multiplier applied to weight and bias.
        activation: Any truthy value enables the scaled leaky ReLU after the
            linear map; a falsy value leaves the output linear.
    """

    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        weight = self.weight * self.scale
        # ``self.bias`` is None when the layer was built with bias=False, so
        # only scale it when it exists.
        bias = self.bias * self.lr_mul if self.bias is not None else None

        if not self.activation:
            return F.linear(input, weight, bias=bias)

        out = F.linear(input, weight)
        if bias is None:
            # Nothing to fuse: apply the same leaky ReLU slope and gain that
            # fused_leaky_relu uses by default.
            return F.leaky_relu(out, negative_slope=0.2) * (2 ** 0.5)
        return fused_leaky_relu(out, bias)

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
        )
class ScaledLeakyReLU(nn.Module):
    """Leaky ReLU whose output is multiplied by sqrt(2)."""

    def __init__(self, negative_slope=0.2):
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, input):
        gain = math.sqrt(2)
        activated = F.leaky_relu(input, negative_slope=self.negative_slope)
        return activated * gain
class ModulatedConv2d(nn.Module):
    """Convolution whose weights are modulated per sample by a style vector.

    The style is mapped to one scale per input channel, multiplied into the
    weight, optionally demodulated, and the whole batch is then run as a
    single grouped convolution (one group per sample). ``upsample`` uses a
    stride-2 transposed convolution followed by a blur; ``downsample`` blurs
    first and then applies a stride-2 convolution.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            up_pad = ((p + 1) // 2 + factor - 1, p // 2 + 1)
            self.blur = Blur(blur_kernel, pad=up_pad, upsample_factor=factor)

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            down_pad = ((p + 1) // 2, p // 2)
            self.blur = Blur(blur_kernel, pad=down_pad)

        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
        self.padding = kernel_size // 2

        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )

        # Maps the style vector to one multiplier per input channel
        # (bias initialised to 1).
        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

        self.demodulate = demodulate

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
            f'upsample={self.upsample}, downsample={self.downsample})'
        )

    def forward(self, input, style):
        batch, in_channel, height, width = input.shape
        ksize = self.kernel_size

        # Per-sample modulation of the shared weight.
        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
        weight = self.scale * self.weight * style

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

        # Fold the batch into the output-channel axis for a grouped conv.
        weight = weight.view(batch * self.out_channel, in_channel, ksize, ksize)

        if self.upsample:
            grouped = input.view(1, batch * in_channel, height, width)
            # conv_transpose2d expects (in, out/groups, kH, kW).
            weight = weight.view(batch, self.out_channel, in_channel, ksize, ksize)
            weight = weight.transpose(1, 2).reshape(
                batch * in_channel, self.out_channel, ksize, ksize
            )
            out = F.conv_transpose2d(grouped, weight, padding=0, stride=2, groups=batch)
            out = out.view(batch, self.out_channel, *out.shape[2:])
            return self.blur(out)

        if self.downsample:
            blurred = self.blur(input)
            grouped = blurred.view(1, batch * in_channel, *blurred.shape[2:])
            out = F.conv2d(grouped, weight, padding=0, stride=2, groups=batch)
        else:
            grouped = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(grouped, weight, padding=self.padding, groups=batch)

        return out.view(batch, self.out_channel, *out.shape[2:])
class NoiseInjection(nn.Module):
    """Add noise to an image, scaled by a single learned weight (init 0)."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, image, noise=None):
        if noise is not None:
            return image + self.weight * noise

        # No noise supplied: draw one standard-normal map per sample, shared
        # across channels.
        batch, _, height, width = image.shape
        sampled = image.new_empty(batch, 1, height, width).normal_()
        return image + self.weight * sampled
class ConstantInput(nn.Module):
    """Learned constant of shape (1, channel, size, size // 2).

    ``forward`` only uses its argument to read the batch size and returns the
    constant repeated that many times.
    """

    def __init__(self, channel, size=4):
        super().__init__()
        self.input = nn.Parameter(torch.randn(1, channel, size, size // 2))

    def forward(self, input):
        return self.input.repeat(input.shape[0], 1, 1, 1)
class StyledConv(nn.Module):
    """Modulated convolution followed by noise injection and FusedLeakyReLU."""

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        upsample=False,
        blur_kernel=[1, 3, 3, 1],
        demodulate=True,
    ):
        super().__init__()

        self.conv = ModulatedConv2d(
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=upsample,
            blur_kernel=blur_kernel,
            demodulate=demodulate,
        )
        self.noise = NoiseInjection()
        # FusedLeakyReLU carries the per-channel bias itself.
        self.activate = FusedLeakyReLU(out_channel)

    def forward(self, input, style, noise=None):
        features = self.conv(input, style)
        features = self.noise(features, noise=noise)
        return self.activate(features)
class ToRGB(nn.Module):
    """Project features to 3 channels and optionally add an upsampled skip.

    ``self.upsample`` only exists when the layer was built with
    ``upsample=True``; passing ``skip`` to a layer built without it fails.
    """

    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        if upsample:
            self.upsample = Upsample(blur_kernel)

        # 1x1 modulated convolution without demodulation.
        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, input, style, skip=None):
        rgb = self.conv(input, style) + self.bias
        if skip is None:
            return rgb
        return rgb + self.upsample(skip)
class Generator(nn.Module):
    """Style-based generator assembled from StyledConv / ToRGB layers.

    The constant input and all noise buffers have shape
    ``(..., 2**res, 2**res // 2)``, i.e. the width is half the height.
    """

    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=2,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
    ):
        super().__init__()

        self.size = size
        self.style_dim = style_dim

        # Mapping network: PixelNorm followed by n_mlp activated linear layers.
        mapping = [PixelNorm()]
        mapping.extend(
            EqualLinear(style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu')
            for _ in range(n_mlp)
        )
        self.style = nn.Sequential(*mapping)

        # Feature width per resolution.
        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        base_channels = self.channels[4]
        self.input = ConstantInput(base_channels)
        self.conv1 = StyledConv(
            base_channels, base_channels, 3, style_dim, blur_kernel=blur_kernel
        )
        self.to_rgb1 = ToRGB(base_channels, style_dim, upsample=False)

        self.log_size = int(math.log(size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        # One fixed noise buffer per synthesis layer, used when
        # randomize_noise=False.
        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            noise_shape = [1, 1, 2 ** res, 2 ** res // 2]
            self.noises.register_buffer(
                f'noise_{layer_idx}', torch.randn(*noise_shape)
            )

        # Two styled convolutions (the first one upsampling) and one ToRGB
        # per resolution from 8 upwards.
        in_channel = base_channels
        for log_res in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** log_res]
            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                )
            )
            self.convs.append(
                StyledConv(
                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
                )
            )
            self.to_rgbs.append(ToRGB(out_channel, style_dim))
            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2

    def make_noise(self):
        """Return a fresh list of per-layer noise maps on the model's device."""
        device = self.input.input.device

        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2 // 2, device=device)]
        for log_res in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(
                    torch.randn(1, 1, 2 ** log_res, 2 ** log_res // 2, device=device)
                )
        return noises

    def mean_latent(self, n_latent):
        """Average the mapped latents of ``n_latent`` random inputs."""
        samples = torch.randn(
            n_latent, self.style_dim, device=self.input.input.device
        )
        return self.style(samples).mean(0, keepdim=True)

    def get_latent(self, input):
        """Run ``input`` through the mapping network."""
        return self.style(input)

    def forward(
        self,
        styles,
        return_latents=False,
        return_features=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        """Synthesise an image from one or two style codes.

        Returns ``(image, latent)`` when ``return_latents`` is set,
        ``(image, features)`` when ``return_features`` is set, and
        ``(image, None)`` otherwise.
        """
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                # None makes every NoiseInjection sample its own noise.
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
                ]

        if truncation < 1:
            styles = [
                truncation_latent + truncation * (style - truncation_latent)
                for style in styles
            ]

        if len(styles) < 2:
            inject_index = self.n_latent
            single = styles[0]
            if single.ndim < 3:
                latent = single.unsqueeze(1).repeat(1, inject_index, 1)
            else:
                # Already one code per layer.
                latent = single
        else:
            # Style mixing: the first code drives layers [0, inject_index),
            # the second code the remaining ones.
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)
            head = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            tail = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
            latent = torch.cat([head, tail], 1)

        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        stages = zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        )
        for stage_idx, (conv_up, conv_same, noise_up, noise_same, to_rgb) in enumerate(stages):
            i = 1 + 2 * stage_idx
            out = conv_up(out, latent[:, i], noise=noise_up)
            out = conv_same(out, latent[:, i + 1], noise=noise_same)
            skip = to_rgb(out, latent[:, i + 2], skip)

        image = skip

        if return_latents:
            return image, latent
        if return_features:
            return image, out
        return image, None
class ConvLayer(nn.Sequential):
    """Optional blur, an EqualConv2d, and an optional activation in sequence.

    With ``downsample=True`` a Blur is prepended and the convolution uses
    stride 2 with no padding; otherwise stride 1 with ``kernel_size // 2``
    padding. When activated with ``bias=True`` the bias lives in the
    FusedLeakyReLU instead of the convolution.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        modules = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            modules.append(Blur(blur_kernel, pad=((p + 1) // 2, p // 2)))
            stride, padding = 2, 0
        else:
            stride, padding = 1, kernel_size // 2

        modules.append(
            EqualConv2d(
                in_channel,
                out_channel,
                kernel_size,
                padding=padding,
                stride=stride,
                bias=bias and not activate,
            )
        )

        if activate:
            activation = FusedLeakyReLU(out_channel) if bias else ScaledLeakyReLU(0.2)
            modules.append(activation)

        super().__init__(*modules)
        # Exposed as an attribute, as before.
        self.padding = padding
class ResBlock(nn.Module):
    """Residual block that halves the spatial resolution.

    Main path: 3x3 conv, then a down-sampling 3x3 conv. Skip path: a
    down-sampling 1x1 conv with no bias or activation. The two paths are
    summed and divided by sqrt(2).

    Args:
        in_channel: Channels of the input.
        out_channel: Channels of the output.
        blur_kernel: FIR taps forwarded to the two down-sampling layers.
    """

    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        # ``blur_kernel`` used to be accepted but silently dropped; it is now
        # passed to the layers that actually blur. The default value matches
        # ConvLayer's own default, so default construction is unchanged.
        self.conv2 = ConvLayer(
            in_channel, out_channel, 3, downsample=True, blur_kernel=blur_kernel
        )
        self.skip = ConvLayer(
            in_channel,
            out_channel,
            1,
            downsample=True,
            blur_kernel=blur_kernel,
            activate=False,
            bias=False,
        )

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        skip = self.skip(input)
        out = (out + skip) / math.sqrt(2)

        return out
class Discriminator(nn.Module):
    """Residual discriminator with a minibatch standard-deviation feature.

    The first linear layer expects ``channels[4] * 4 * 4 // 2`` inputs, i.e.
    the final feature map is assumed to be 4 x 2.
    """

    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        # 1x1 conv from RGB, then one ResBlock per halving down to 4.
        in_channel = channels[size]
        body = [ConvLayer(3, in_channel, 1)]

        log_size = int(math.log(size, 2))
        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            body.append(ResBlock(in_channel, out_channel, blur_kernel))
            in_channel = out_channel

        self.convs = nn.Sequential(*body)

        self.stddev_group = 4
        self.stddev_feat = 1

        # +1 input channel for the appended stddev map.
        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4 // 2, channels[4], activation='fused_lrelu'),
            EqualLinear(channels[4], 1),
        )

    def forward(self, input):
        features = self.convs(input)

        batch, channel, height, width = features.shape
        group = min(batch, self.stddev_group)

        # Standard deviation across each group of samples, averaged over
        # channels and positions, then broadcast back as one extra channel.
        grouped = features.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(grouped.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)

        out = torch.cat([features, stddev], 1)
        out = self.final_conv(out)
        out = out.view(batch, -1)
        return self.final_linear(out)
================================================
FILE: stylegan_human/pti/pti_models/e4e/stylegan2/op/__init__.py
================================================
from .fused_act import FusedLeakyReLU, fused_leaky_relu
from .upfirdn2d import upfirdn2d
================================================
FILE: stylegan_human/pti/pti_models/e4e/stylegan2/op/fused_act.py
================================================
import os
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
module_path = os.path.dirname(__file__)
class FusedLeakyReLU(nn.Module):
    """Module wrapper around ``fused_leaky_relu`` with a learned channel bias."""

    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()

        # One bias value per channel, initialised to zero.
        self.bias = nn.Parameter(torch.zeros(channel))
        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return fused_leaky_relu(
            input,
            self.bias,
            negative_slope=self.negative_slope,
            scale=self.scale,
        )
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    """Add a per-channel bias, apply leaky ReLU, and multiply by ``scale``.

    Args:
        input: Tensor whose dim 1 has ``bias.shape[0]`` entries.
        bias: 1-D bias tensor; broadcast over every dim except dim 1.
        negative_slope: Slope of the leaky ReLU for negative values.
        scale: Gain applied to the activated result.

    Returns:
        Tensor with the shape of ``input``, on the device of ``bias``.
    """
    rest_dim = [1] * (input.ndim - bias.ndim - 1)
    # Follow the bias (i.e. the module's) device instead of forcing CUDA:
    # identical when the model lives on the GPU, and it no longer crashes on
    # CPU-only machines.
    input = input.to(bias.device)
    biased = input + bias.view(1, bias.shape[0], *rest_dim)
    return F.leaky_relu(biased, negative_slope=negative_slope) * scale
================================================
FILE: stylegan_human/pti/pti_models/e4e/stylegan2/op/fused_bias_act.cpp
================================================
#include
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
int act, int grad, float alpha, float scale);
// Argument-validation helpers. Tensor::is_cuda() is queried directly; the
// previous x.type().is_cuda() spelling went through the deprecated
// Tensor::type() accessor.
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// Python-facing entry point: checks that `input` and `bias` are CUDA tensors
// and forwards every argument unchanged to fused_bias_act_op.
// `refer` is not validated here.
torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
    int act, int grad, float alpha, float scale) {
    CHECK_CUDA(input);
    CHECK_CUDA(bias);

    return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}
// Python binding: registers the fused_bias_act wrapper under the name
// "fused_bias_act" in the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}
================================================
FILE: stylegan_human/pti/pti_models/e4e/stylegan2/op/fused_bias_act_kernel.cu
================================================
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include
#include
#include
#include
#include
#include
#include
// Element-wise "bias + activation (or its gradient) + scale" kernel.
//
// Each thread handles up to `loop_x` elements spaced blockDim.x apart.
// For element xi the bias entry is p_b[(xi / step_b) % size_b].
// `act * 10 + grad` selects the operation:
//   10, 11 (and any unknown code): y = x
//   12, 32: y = 0
//   30: leaky ReLU of x with slope `alpha`
//   31: x where ref > 0, otherwise x * alpha
// The result is multiplied by `scale` before being stored.
template <typename scalar_t>
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
    int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

    scalar_t zero = 0.0;

    for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
        scalar_t x = p_x[xi];

        if (use_bias) {
            x += p_b[(xi / step_b) % size_b];
        }

        scalar_t ref = use_ref ? p_ref[xi] : zero;

        scalar_t y;

        switch (act * 10 + grad) {
            default:
            case 10: y = x; break;
            case 11: y = x; break;
            case 12: y = 0.0; break;

            case 30: y = (x > 0.0) ? x : x * alpha; break;
            case 31: y = (ref > 0.0) ? x : x * alpha; break;
            case 32: y = 0.0; break;
        }

        out[xi] = y * scale;
    }
}

// Host launcher for fused_bias_act_kernel on the current CUDA stream.
//
// An empty `bias` / `refer` tensor disables the bias / reference input.
// `step_b` is the product of all dimensions after dim 1, so the bias is
// indexed by dim 1 of `input`. Returns a new tensor shaped like `input`.
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
    int act, int grad, float alpha, float scale) {
    int curDevice = -1;
    cudaGetDevice(&curDevice);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

    auto x = input.contiguous();
    auto b = bias.contiguous();
    auto ref = refer.contiguous();

    int use_bias = b.numel() ? 1 : 0;
    int use_ref = ref.numel() ? 1 : 0;

    int size_x = x.numel();
    int size_b = b.numel();
    int step_b = 1;

    for (int i = 1 + 1; i < x.dim(); i++) {
        step_b *= x.size(i);
    }

    // 128 threads per block, 4 elements per thread.
    int loop_x = 4;
    int block_size = 4 * 32;
    int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

    auto y = torch::empty_like(x);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
            y.data_ptr<scalar_t>(),
            x.data_ptr<scalar_t>(),
            b.data_ptr<scalar_t>(),
            ref.data_ptr<scalar_t>(),
            act,
            grad,
            alpha,
            scale,
            loop_x,
            size_x,
            step_b,
            size_b,
            use_bias,
            use_ref
        );
    });

    return y;
}
================================================
FILE: stylegan_human/pti/pti_models/e4e/stylegan2/op/upfirdn2d.cpp
================================================
#include
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
int up_x, int up_y, int down_x, int down_y,
int pad_x0, int pad_x1, int pad_y0, int pad_y1);
// Argument-validation helpers. Tensor::is_cuda() is queried directly; the
// previous x.type().is_cuda() spelling went through the deprecated
// Tensor::type() accessor.
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// Python-facing entry point: checks that `input` and `kernel` are CUDA
// tensors and forwards every argument unchanged to upfirdn2d_op.
torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
    int up_x, int up_y, int down_x, int down_y,
    int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
    CHECK_CUDA(input);
    CHECK_CUDA(kernel);

    return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
}
// Python binding: registers the upfirdn2d wrapper under the name
// "upfirdn2d" in the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
================================================
FILE: stylegan_human/pti/pti_models/e4e/stylegan2/op/upfirdn2d.py
================================================
import os
import torch
from torch.nn import functional as F
module_path = os.path.dirname(__file__)
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample, filter and downsample ``input`` with the native implementation.

    ``up`` and ``down`` are applied to both axes, and ``pad[0]`` / ``pad[1]``
    are used as the leading / trailing padding of both axes.
    """
    pad_before, pad_after = pad[0], pad[1]
    return upfirdn2d_native(
        input, kernel, up, up, down, down, pad_before, pad_after, pad_before, pad_after
    )
def upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    """Pure-PyTorch upsample / pad / FIR-filter / downsample of a 4-D tensor.

    Steps: insert ``up - 1`` zeros after every sample, pad (or crop, for
    negative pads), convolve with the flipped kernel, then keep every
    ``down``-th sample. The result has the input's channel count.
    """
    _, channel, in_h, in_w = input.shape
    # Fold batch and channel together and append a singleton "minor" axis.
    x = input.reshape(-1, in_h, in_w, 1)
    _, in_h, in_w, minor = x.shape
    kernel_h, kernel_w = kernel.shape

    # Zero insertion: pad a singleton axis after each spatial axis, then
    # merge it back in.
    x = x.view(-1, in_h, 1, in_w, 1, minor)
    x = F.pad(x, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    x = x.view(-1, in_h * up_y, in_w * up_x, minor)

    # Positive pads add zeros; negative pads are applied as crops.
    x = F.pad(
        x, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    crop_top, crop_bottom = max(-pad_y0, 0), max(-pad_y1, 0)
    crop_left, crop_right = max(-pad_x0, 0), max(-pad_x1, 0)
    x = x[
        :,
        crop_top : x.shape[1] - crop_bottom,
        crop_left : x.shape[2] - crop_right,
        :,
    ]

    padded_h = in_h * up_y + pad_y0 + pad_y1
    padded_w = in_w * up_x + pad_x0 + pad_x1

    x = x.permute(0, 3, 1, 2)
    x = x.reshape([-1, 1, padded_h, padded_w])
    # conv2d computes a correlation, so the kernel is flipped first.
    flipped = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    x = F.conv2d(x, flipped)
    x = x.reshape(
        -1,
        minor,
        padded_h - kernel_h + 1,
        padded_w - kernel_w + 1,
    )
    x = x.permute(0, 2, 3, 1)
    x = x[:, ::down_y, ::down_x, :]

    out_h = (padded_h - kernel_h) // down_y + 1
    out_w = (padded_w - kernel_w) // down_x + 1

    return x.view(-1, channel, out_h, out_w)
================================================
FILE: stylegan_human/pti/pti_models/e4e/stylegan2/op/upfirdn2d_kernel.cu
================================================
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include
#include
#include
#include
#include
#include
#include
// Integer division rounded toward negative infinity (for the positive
// divisors it is called with). C++ `/` truncates toward zero, so the
// quotient is lowered by one when it overshoots.
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
    const int quotient = a / b;
    return (quotient * b > a) ? quotient - 1 : quotient;
}
// Run-time parameters handed by value to upfirdn2d_kernel; filled in by
// upfirdn2d_op.
struct UpFirDn2DKernelParams {
// Up/down-sampling factors and paddings, copied from the op's arguments.
int up_x;
int up_y;
int down_x;
int down_y;
int pad_x0;
int pad_x1;
int pad_y0;
int pad_y1;
// Input extents: sizes of dims 0..3 of the (contiguous) input tensor.
int major_dim;
int in_h;
int in_w;
int minor_dim;
// Filter extents: sizes of dims 0 and 1 of the kernel tensor.
int kernel_h;
int kernel_w;
// Output extents computed by upfirdn2d_op.
int out_h;
int out_w;
// Number of major indices / x-tiles each thread block iterates over.
int loop_major;
int loop_x;
};
// Tiled upsample / FIR-filter / downsample kernel.
//
// The up/down factors, the maximum filter size and the output tile size are
// compile-time template arguments so the shared-memory tiles have static
// extents; the actual sizes and paddings come from `p`. Filter taps outside
// p.kernel_h x p.kernel_w are zero-filled, and the filter is stored flipped.
// Conditions use bitwise & / | on purpose, as in the original code.
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) {
    const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
    const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;

    __shared__ volatile float sk[kernel_h][kernel_w];
    __shared__ volatile float sx[tile_in_h][tile_in_w];

    // blockIdx.x encodes (output tile row, minor index).
    int minor_idx = blockIdx.x;
    int tile_out_y = minor_idx / p.minor_dim;
    minor_idx -= tile_out_y * p.minor_dim;
    tile_out_y *= tile_out_h;
    int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
    int major_idx_base = blockIdx.z * p.loop_major;

    if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
        return;
    }

    // Load the flipped filter into shared memory.
    for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
        int ky = tap_idx / kernel_w;
        int kx = tap_idx - ky * kernel_w;
        scalar_t v = 0.0;

        if (kx < p.kernel_w & ky < p.kernel_h) {
            v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
        }

        sk[ky][kx] = v;
    }

    for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
        for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
            int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
            int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
            int tile_in_x = floor_div(tile_mid_x, up_x);
            int tile_in_y = floor_div(tile_mid_y, up_y);

            __syncthreads();

            // Load the input tile; positions outside the input read as zero.
            for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
                int rel_in_y = in_idx / tile_in_w;
                int rel_in_x = in_idx - rel_in_y * tile_in_w;
                int in_x = rel_in_x + tile_in_x;
                int in_y = rel_in_y + tile_in_y;

                scalar_t v = 0.0;

                if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
                    v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
                }

                sx[rel_in_y][rel_in_x] = v;
            }

            __syncthreads();

            // Accumulate one output sample per iteration.
            for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
                int rel_out_y = out_idx / tile_out_w;
                int rel_out_x = out_idx - rel_out_y * tile_out_w;
                int out_x = rel_out_x + tile_out_x;
                int out_y = rel_out_y + tile_out_y;

                int mid_x = tile_mid_x + rel_out_x * down_x;
                int mid_y = tile_mid_y + rel_out_y * down_y;
                int in_x = floor_div(mid_x, up_x);
                int in_y = floor_div(mid_y, up_y);
                int rel_in_x = in_x - tile_in_x;
                int rel_in_y = in_y - tile_in_y;
                int kernel_x = (in_x + 1) * up_x - mid_x - 1;
                int kernel_y = (in_y + 1) * up_y - mid_y - 1;

                scalar_t v = 0.0;

                #pragma unroll
                for (int y = 0; y < kernel_h / up_y; y++)
                    #pragma unroll
                    for (int x = 0; x < kernel_w / up_x; x++)
                        v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];

                if (out_x < p.out_w & out_y < p.out_h) {
                    out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
                }
            }
        }
    }
}

// Host launcher: fills UpFirDn2DKernelParams, picks a kernel specialisation
// ("mode") from the up/down factors and filter size, and launches it on the
// current CUDA stream. Later mode checks override earlier ones, so the
// tightest matching filter size wins. Configurations that match no mode are
// rejected instead of returning an unwritten tensor.
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
    int up_x, int up_y, int down_x, int down_y,
    int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
    int curDevice = -1;
    cudaGetDevice(&curDevice);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

    UpFirDn2DKernelParams p;

    auto x = input.contiguous();
    auto k = kernel.contiguous();

    p.major_dim = x.size(0);
    p.in_h = x.size(1);
    p.in_w = x.size(2);
    p.minor_dim = x.size(3);
    p.kernel_h = k.size(0);
    p.kernel_w = k.size(1);
    p.up_x = up_x;
    p.up_y = up_y;
    p.down_x = down_x;
    p.down_y = down_y;
    p.pad_x0 = pad_x0;
    p.pad_x1 = pad_x1;
    p.pad_y0 = pad_y0;
    p.pad_y1 = pad_y1;

    p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
    p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;

    auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());

    // Initialised so an unmatched configuration is detectable rather than
    // reading indeterminate values.
    int mode = -1;
    int tile_out_h = -1;
    int tile_out_w = -1;

    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
        mode = 1;
        tile_out_h = 16;
        tile_out_w = 64;
    }

    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
        mode = 2;
        tile_out_h = 16;
        tile_out_w = 64;
    }

    if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
        mode = 3;
        tile_out_h = 16;
        tile_out_w = 64;
    }

    if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
        mode = 4;
        tile_out_h = 16;
        tile_out_w = 64;
    }

    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
        mode = 5;
        tile_out_h = 8;
        tile_out_w = 32;
    }

    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
        mode = 6;
        tile_out_h = 8;
        tile_out_w = 32;
    }

    TORCH_CHECK(mode > 0, "upfirdn2d: unsupported combination of up/down factors and kernel size");

    dim3 block_size;
    dim3 grid_size;

    if (tile_out_h > 0 && tile_out_w > 0) {
        p.loop_major = (p.major_dim - 1) / 16384 + 1;
        p.loop_x = 1;
        block_size = dim3(32 * 8, 1, 1);
        grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                         (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                         (p.major_dim - 1) / p.loop_major + 1);
    }

    // Template arguments mirror each mode's selection condition:
    // <scalar_t, up_x, up_y, down_x, down_y, kernel_h, kernel_w, tile_out_h, tile_out_w>.
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
        switch (mode) {
        case 1:
            upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
            );
            break;

        case 2:
            upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
            );
            break;

        case 3:
            upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
            );
            break;

        case 4:
            upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
            );
            break;

        case 5:
            upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
            );
            break;

        case 6:
            upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32><<<grid_size, block_size, 0, stream>>>(
                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
            );
            break;
        }
    });

    return out;
}
================================================
FILE: stylegan_human/pti/training/__init__.py
================================================
================================================
FILE: stylegan_human/pti/training/coaches/__init__.py
================================================
================================================
FILE: stylegan_human/pti/training/coaches/base_coach.py
================================================
import abc
import os
import pickle
from argparse import Namespace
import wandb
import os.path
from .localitly_regulizer import Space_Regulizer, l2_loss
import torch
from torchvision import transforms
from lpips import LPIPS
from pti.training.projectors import w_projector
from pti.pti_configs import global_config, paths_config, hyperparameters
from pti.pti_models.e4e.psp import pSp
from utils.log_utils import log_image_from_w
from utils.models_utils import toogle_grad, load_old_G
class BaseCoach:
    """Shared state and helpers for the pivotal-tuning coaches.

    Holds the generator being tuned (``self.G``), a second copy loaded by the
    same loader for the locality regulariser (``self.original_G``), the LPIPS
    loss, the optimizer, and a cache of inverted latents keyed by image name.
    Subclasses provide ``train``.
    """

    def __init__(self, data_loader, use_wandb):
        self.use_wandb = use_wandb
        self.data_loader = data_loader
        self.w_pivots = {}
        self.image_counter = 0

        if hyperparameters.first_inv_type == 'w+':
            self.initilize_e4e()

        # Pre-processing applied to an image before it is fed to the e4e
        # encoder: resize to 256x128 and normalise to [-1, 1].
        self.e4e_image_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((256, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

        # Initialize loss
        self.lpips_loss = LPIPS(net=hyperparameters.lpips_type).to(global_config.device).eval()

        self.restart_training()

        # Initialize checkpoint dir
        self.checkpoint_dir = paths_config.checkpoints_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def restart_training(self):
        """Reload both generators and rebuild the regulariser and optimizer."""
        # Initialize networks
        self.G = load_old_G()
        toogle_grad(self.G, True)

        self.original_G = load_old_G()

        self.space_regulizer = Space_Regulizer(self.original_G, self.lpips_loss)
        self.optimizer = self.configure_optimizers()

    def get_inversion(self, w_path_dir, image_name, image):
        """Return the pivot latent for ``image`` on the configured device.

        A previously stored pivot is reused when
        ``hyperparameters.use_last_w_pivots`` is set and one can be loaded;
        otherwise the pivot is computed and saved to
        ``<embedding_dir>/0.pt``.
        """
        embedding_dir = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}'
        os.makedirs(embedding_dir, exist_ok=True)

        w_pivot = None

        if hyperparameters.use_last_w_pivots:
            w_pivot = self.load_inversions(w_path_dir, image_name)

        if not hyperparameters.use_last_w_pivots or w_pivot is None:
            w_pivot = self.calc_inversions(image, image_name)
            torch.save(w_pivot, f'{embedding_dir}/0.pt')

        w_pivot = w_pivot.to(global_config.device)
        return w_pivot

    def load_inversions(self, w_path_dir, image_name):
        """Return a cached or on-disk pivot for ``image_name``, or ``None``."""
        if image_name in self.w_pivots:
            return self.w_pivots[image_name]

        if hyperparameters.first_inv_type == 'w+':
            w_potential_path = f'{w_path_dir}/{paths_config.e4e_results_keyword}/{image_name}/0.pt'
        else:
            w_potential_path = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}/0.pt'
        if not os.path.isfile(w_potential_path):
            return None
        w = torch.load(w_potential_path).to(global_config.device)
        self.w_pivots[image_name] = w
        return w

    def calc_inversions(self, image, image_name):
        """Invert ``image`` with e4e ('w+') or with the W projector."""
        if hyperparameters.first_inv_type == 'w+':
            w = self.get_e4e_inversion(image)
        else:
            # Map the image from [-1, 1] to [0, 255] for the projector.
            id_image = torch.squeeze((image.to(global_config.device) + 1) / 2) * 255
            w = w_projector.project(self.G, id_image, device=torch.device(global_config.device), w_avg_samples=600,
                                    num_steps=hyperparameters.first_inv_steps, w_name=image_name,
                                    use_wandb=self.use_wandb)

        return w

    # NOTE(review): the class does not derive from abc.ABC, so this decorator
    # does not actually prevent instantiation — confirm whether that is intended.
    @abc.abstractmethod
    def train(self):
        pass

    def configure_optimizers(self):
        """Build the Adam optimizer over all parameters of ``self.G``."""
        optimizer = torch.optim.Adam(self.G.parameters(), lr=hyperparameters.pti_learning_rate)

        return optimizer

    def calc_loss(self, generated_images, real_images, log_name, new_G, use_ball_holder, w_batch):
        """Combine the enabled reconstruction and regularisation terms.

        Returns:
            ``(loss, l2_loss_val, loss_lpips)``. A term whose lambda is not
            positive is skipped and reported as ``None``.
        """
        loss = 0.0
        # Pre-bind both terms so the return statement is valid even when one
        # of them is disabled through its lambda.
        l2_loss_val = None
        loss_lpips = None

        if hyperparameters.pt_l2_lambda > 0:
            l2_loss_val = l2_loss(generated_images, real_images)
            if self.use_wandb:
                wandb.log({f'MSE_loss_val_{log_name}': l2_loss_val.detach().cpu()}, step=global_config.training_step)
            loss += l2_loss_val * hyperparameters.pt_l2_lambda
        if hyperparameters.pt_lpips_lambda > 0:
            loss_lpips = self.lpips_loss(generated_images, real_images)
            loss_lpips = torch.squeeze(loss_lpips)
            if self.use_wandb:
                wandb.log({f'LPIPS_loss_val_{log_name}': loss_lpips.detach().cpu()}, step=global_config.training_step)
            loss += loss_lpips * hyperparameters.pt_lpips_lambda

        if use_ball_holder and hyperparameters.use_locality_regularization:
            ball_holder_loss_val = self.space_regulizer.space_regulizer_loss(new_G, w_batch, use_wandb=self.use_wandb)
            loss += ball_holder_loss_val

        return loss, l2_loss_val, loss_lpips

    def forward(self, w):
        """Synthesise images from latents ``w`` with constant noise in fp32."""
        generated_images = self.G.synthesis(w, noise_mode='const', force_fp32=True)

        return generated_images

    def initilize_e4e(self):
        """Load the e4e encoder from its checkpoint, frozen and in eval mode."""
        ckpt = torch.load(paths_config.e4e, map_location='cpu')
        opts = ckpt['opts']
        opts['batch_size'] = hyperparameters.train_batch_size
        opts['checkpoint_path'] = paths_config.e4e
        opts = Namespace(**opts)
        self.e4e_inversion_net = pSp(opts)
        self.e4e_inversion_net.eval()
        self.e4e_inversion_net = self.e4e_inversion_net.to(global_config.device)
        toogle_grad(self.e4e_inversion_net, False)

    def get_e4e_inversion(self, image):
        """Encode the first image of the batch with e4e and return its latent."""
        # Map from [-1, 1] to [0, 1] before the PIL-based transform.
        image = (image + 1) / 2
        new_image = self.e4e_image_transform(image[0]).to(global_config.device)
        _, w = self.e4e_inversion_net(new_image.unsqueeze(0), randomize_noise=False, return_latents=True, resize=False,
                                      input_code=False)

        if self.use_wandb:
            log_image_from_w(w, self.G, 'First e4e inversion')

        return w
================================================
FILE: stylegan_human/pti/training/coaches/localitly_regulizer.py
================================================
import torch
import numpy as np
import wandb
from pti.pti_configs import hyperparameters, global_config
# Shared mean-reduced MSE criterion used by l2_loss.
l2_criterion = torch.nn.MSELoss(reduction='mean')


def l2_loss(real_images, generated_images):
    """Return the mean squared error between the two image tensors."""
    return l2_criterion(real_images, generated_images)
class Space_Regulizer:
    """Locality regulariser comparing a tuned generator with the original one.

    Random latents are moved a fixed distance from the pivot towards sampled
    codes; the tuned and original generators are then asked to produce the
    same image at those points.
    """

    def __init__(self, original_G, lpips_net):
        self.original_G = original_G
        self.morphing_regulizer_alpha = hyperparameters.regulizer_alpha
        self.lpips_loss = lpips_net

    def get_morphed_w_code(self, new_w_code, fixed_w):
        """Return ``fixed_w`` moved ``regulizer_alpha`` along the unit
        direction towards ``new_w_code``."""
        interpolation_direction = new_w_code - fixed_w
        interpolation_direction_norm = torch.norm(interpolation_direction, p=2)
        direction_to_move = hyperparameters.regulizer_alpha * interpolation_direction / interpolation_direction_norm
        # A linear-interpolation expression used to be evaluated here and
        # thrown away; it had no effect and has been removed.
        result_w = fixed_w + direction_to_move

        return result_w

    def get_image_from_ws(self, w_codes, G):
        """Synthesise one image per latent in ``w_codes`` and concatenate them."""
        return torch.cat([G.synthesis(w_code, noise_mode='none', force_fp32=True) for w_code in w_codes])

    def ball_holder_loss_lazy(self, new_G, num_of_sampled_latents, w_batch, use_wandb=False):
        """Average the enabled L2 / LPIPS discrepancies over sampled latents.

        Args:
            new_G: Generator being tuned.
            num_of_sampled_latents: Number of random z codes to draw.
            w_batch: Pivot latent(s) the samples are morphed from.
            use_wandb: Whether to log each term to wandb.
        """
        loss = 0.0

        z_samples = np.random.randn(num_of_sampled_latents, self.original_G.z_dim)
        w_samples = self.original_G.mapping(torch.from_numpy(z_samples).to(global_config.device), None,
                                            truncation_psi=0.5)
        territory_indicator_ws = [self.get_morphed_w_code(w_code.unsqueeze(0), w_batch) for w_code in w_samples]

        for w_code in territory_indicator_ws:
            new_img = new_G.synthesis(w_code, noise_mode='none', force_fp32=True)
            # The original generator only provides the target; no gradients.
            with torch.no_grad():
                old_img = self.original_G.synthesis(w_code, noise_mode='none', force_fp32=True)

            if hyperparameters.regulizer_l2_lambda > 0:
                # ``l2_loss`` is the module-level function; the former
                # ``l2_loss.l2_loss(...)`` attribute lookup raised
                # AttributeError whenever this term was enabled.
                l2_loss_val = l2_loss(old_img, new_img)
                if use_wandb:
                    wandb.log({'space_regulizer_l2_loss_val': l2_loss_val.detach().cpu()},
                              step=global_config.training_step)
                loss += l2_loss_val * hyperparameters.regulizer_l2_lambda

            if hyperparameters.regulizer_lpips_lambda > 0:
                loss_lpips = self.lpips_loss(old_img, new_img)
                loss_lpips = torch.mean(torch.squeeze(loss_lpips))
                if use_wandb:
                    wandb.log({'space_regulizer_lpips_loss_val': loss_lpips.detach().cpu()},
                              step=global_config.training_step)
                loss += loss_lpips * hyperparameters.regulizer_lpips_lambda

        return loss / len(territory_indicator_ws)

    def space_regulizer_loss(self, new_G, w_batch, use_wandb):
        """Evaluate the regulariser with the configured number of samples."""
        ret_val = self.ball_holder_loss_lazy(new_G, hyperparameters.latent_ball_num_of_samples, w_batch, use_wandb)
        return ret_val
================================================
FILE: stylegan_human/pti/training/coaches/multi_id_coach.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
import os
import torch
from tqdm import tqdm
from pti.pti_configs import paths_config, hyperparameters, global_config
from pti.training.coaches.base_coach import BaseCoach
from utils.log_utils import log_images_from_w
class MultiIDCoach(BaseCoach):
    """Coach that tunes a single generator jointly on several images."""

    def __init__(self, data_loader, use_wandb):
        super().__init__(data_loader, use_wandb)

    def train(self):
        """Invert every image to a pivot code, tune ``self.G`` on all of them, and pickle it.

        At most ``hyperparameters.max_images_to_invert`` images are used. The
        tuned generator is written to
        ``{checkpoints_dir}/model_{run_name}_multi_id.pkl`` under the key
        ``'G_ema'``.
        """
        self.G.synthesis.train()
        self.G.mapping.train()

        w_path_dir = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}'
        os.makedirs(w_path_dir, exist_ok=True)
        os.makedirs(f'{w_path_dir}/{paths_config.pti_results_keyword}', exist_ok=True)

        use_ball_holder = True
        w_pivots = []
        images = []

        # Phase 1: obtain one pivot latent per image.
        for fname, image in self.data_loader:
            if self.image_counter >= hyperparameters.max_images_to_invert:
                break

            image_name = fname[0]
            results_keyword = (paths_config.e4e_results_keyword
                               if hyperparameters.first_inv_type == 'w+'
                               else paths_config.pti_results_keyword)
            embedding_dir = f'{w_path_dir}/{results_keyword}/{image_name}'
            os.makedirs(embedding_dir, exist_ok=True)

            w_pivots.append(self.get_inversion(w_path_dir, image_name, image))
            images.append((image_name, image))
            self.image_counter += 1

        # Phase 2: one optimizer step per image for each PTI step.
        for _ in tqdm(range(hyperparameters.max_pti_steps)):
            self.image_counter = 0

            for (image_name, image), w_pivot in zip(images, w_pivots):
                if self.image_counter >= hyperparameters.max_images_to_invert:
                    break

                real_images_batch = image.to(global_config.device)
                generated_images = self.forward(w_pivot)
                loss, l2_loss_val, loss_lpips = self.calc_loss(generated_images, real_images_batch, image_name,
                                                               self.G, use_ball_holder, w_pivot)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Decides whether the NEXT calc_loss call receives use_ball_holder=True.
                step_now = global_config.training_step
                use_ball_holder = step_now % hyperparameters.locality_regularization_interval == 0

                global_config.training_step += 1
                self.image_counter += 1

        if self.use_wandb:
            log_images_from_w(w_pivots, self.G, [name for name, _ in images])

        # torch.save(self.G,
        #            f'{paths_config.checkpoints_dir}/model_{global_config.run_name}_multi_id.pt')
        import pickle
        snapshot_data = {'G_ema': self.G}
        with open(f'{paths_config.checkpoints_dir}/model_{global_config.run_name}_multi_id.pkl', 'wb') as f:
            pickle.dump(snapshot_data, f)
================================================
FILE: stylegan_human/pti/training/coaches/single_id_coach.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
import os
import torch
from tqdm import tqdm
from pti.pti_configs import paths_config, hyperparameters, global_config
from pti.training.coaches.base_coach import BaseCoach
from utils.log_utils import log_images_from_w
from torchvision.utils import save_image
class SingleIDCoach(BaseCoach):
    """Coach that tunes a fresh generator for each input image separately."""

    def __init__(self, data_loader, use_wandb):
        super().__init__(data_loader, use_wandb)

    def train(self):
        """Run pivotal tuning once per image from ``self.data_loader``.

        For every image (up to ``hyperparameters.max_images_to_invert``) the
        method:

        1. calls ``self.restart_training()``;
        2. obtains a pivot latent -- a stored one when
           ``hyperparameters.use_last_w_pivots`` is set and one could be
           loaded, otherwise a newly computed one -- and saves it to
           ``{embedding_dir}/0.pt``;
        3. optimizes for at most ``hyperparameters.max_pti_steps`` steps,
           stopping early once the LPIPS loss is at or below
           ``hyperparameters.LPIPS_value_threshold``;
        4. saves a side-by-side PNG (real | first synthesis | last synthesis)
           and pickles the generator under the key ``'G_ema'``.
        """
        w_path_dir = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}'
        os.makedirs(w_path_dir, exist_ok=True)
        os.makedirs(f'{w_path_dir}/{paths_config.pti_results_keyword}', exist_ok=True)

        use_ball_holder = True

        for fname, image in tqdm(self.data_loader):
            image_name = fname[0]

            self.restart_training()

            if self.image_counter >= hyperparameters.max_images_to_invert:
                break

            embedding_dir = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}'
            os.makedirs(embedding_dir, exist_ok=True)

            w_pivot = None
            if hyperparameters.use_last_w_pivots:
                # Presumably returns None when no stored pivot exists -- the
                # loader is defined on the base class, outside this block.
                w_pivot = self.load_inversions(w_path_dir, image_name)

            # Fall back to a fresh inversion whenever no pivot was loaded.
            # (Previously an ``elif`` made this unreachable after a failed load,
            # leaving ``w_pivot`` as None and crashing on ``.to`` below.)
            if w_pivot is None:
                w_pivot = self.calc_inversions(image, image_name)

            # w_pivot = w_pivot.detach().clone().to(global_config.device)
            w_pivot = w_pivot.to(global_config.device)

            torch.save(w_pivot, f'{embedding_dir}/0.pt')
            log_images_counter = 0
            real_images_batch = image.to(global_config.device)

            for i in range(hyperparameters.max_pti_steps):
                generated_images = self.forward(w_pivot)
                loss, l2_loss_val, loss_lpips = self.calc_loss(generated_images, real_images_batch, image_name,
                                                               self.G, use_ball_holder, w_pivot)
                if i == 0:
                    # Keep the untuned synthesis for the comparison image below.
                    tmp1 = torch.clone(generated_images)
                if i % 10 == 0:
                    print("pti loss: ", i, loss.data, loss_lpips.data)

                self.optimizer.zero_grad()

                if loss_lpips <= hyperparameters.LPIPS_value_threshold:
                    break

                loss.backward()
                self.optimizer.step()

                # Decides whether the NEXT calc_loss call receives use_ball_holder=True.
                use_ball_holder = global_config.training_step % hyperparameters.locality_regularization_interval == 0

                if self.use_wandb and log_images_counter % global_config.image_rec_result_log_snapshot == 0:
                    log_images_from_w([w_pivot], self.G, [image_name])

                global_config.training_step += 1
                log_images_counter += 1

            # save output image
            # NOTE: ``tmp1`` and ``generated_images`` only exist if the loop above
            # ran at least once, i.e. ``max_pti_steps`` must be >= 1.
            tmp = torch.cat([real_images_batch, tmp1, generated_images], axis=3)
            save_image(tmp, f"{paths_config.experiments_output_dir}/{image_name}.png", normalize=True)

            self.image_counter += 1

            # torch.save(self.G,
            #            f'{paths_config.checkpoints_dir}/model_{image_name}.pt') #'.pt'
            snapshot_data = dict()
            snapshot_data['G_ema'] = self.G
            import pickle
            with open(f'{paths_config.checkpoints_dir}/model_{image_name}.pkl', 'wb') as f:
                pickle.dump(snapshot_data, f)
================================================
FILE: stylegan_human/pti/training/projectors/__init__.py
================================================
================================================
FILE: stylegan_human/pti/training/projectors/w_plus_projector.py
================================================
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Project given image to the latent space of pretrained network pickle."""
import copy
import wandb
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from configs import global_config, hyperparameters
import dnnlib
from utils.log_utils import log_image_from_w
def project(
        G,
        target: torch.Tensor,  # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
        *,
        num_steps=1000,
        w_avg_samples=10000,
        initial_learning_rate=0.01,
        initial_noise_factor=0.05,
        lr_rampdown_length=0.25,
        lr_rampup_length=0.05,
        noise_ramp_length=0.75,
        regularize_noise_weight=1e5,
        verbose=False,
        device: torch.device,
        use_wandb=False,
        initial_w=None,
        image_log_step=global_config.image_rec_result_log_snapshot,
        w_name: str
):
    """Optimize a per-layer (W+) latent code so that ``G`` reproduces ``target``.

    A deep copy of ``G`` is frozen and used for the whole optimization. The
    starting code (``initial_w`` or the mean of ``w_avg_samples`` mapped
    samples) is repeated ``G.mapping.num_ws`` times along axis 1 and every
    layer's code is optimized independently, together with the generator's
    ``noise_const`` buffers.

    Args:
        G: Generator providing ``mapping`` and ``synthesis``.
        target: Image tensor to reproduce; its shape is asserted below.
        num_steps: Number of optimizer steps.
        w_avg_samples: Number of random latents used for the W statistics.
        initial_learning_rate: Peak learning rate of the ramp-up/ramp-down
            schedule (it overwrites the optimizer's rate on every step).
        initial_noise_factor: Scale of the noise added to the code, relative
            to the measured W standard deviation.
        lr_rampdown_length: Fraction of the run used for the cosine ramp-down.
        lr_rampup_length: Fraction of the run used for the linear ramp-up.
        noise_ramp_length: Fraction of the run after which no noise is added.
        regularize_noise_weight: Weight of the noise regularization term.
        verbose: Print per-step progress when true.
        device: Device on which the optimization runs.
        use_wandb: Log the loss and an image every ``image_log_step`` steps.
        initial_w: Optional starting code replacing the W average.
        image_log_step: Logging interval in steps.
        w_name: Label used in the wandb log entries.

    Returns:
        The optimized latent tensor ``w_opt`` (still requiring grad).
    """
    print('inside training/projectors/w_plus_projector')
    print(target.shape, G.img_channels, G.img_resolution * 2 , G.img_resolution)
    # NOTE(review): w_projector.py asserts (C, res, res // 2) for the same kind
    # of target -- confirm which convention this projector is meant to follow.
    assert target.shape == (G.img_channels, G.img_resolution * 2, G.img_resolution)

    def logprint(*args):
        if verbose:
            print(*args)

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device).float()  # type: ignore

    # Compute w stats.
    logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
    z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None)  # [N, L, C]
    w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)  # [N, 1, C]
    w_avg = np.mean(w_samples, axis=0, keepdims=True)  # [1, 1, C]
    w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5

    start_w = initial_w if initial_w is not None else w_avg

    # Setup noise inputs.
    noise_bufs = {name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name}

    # Load VGG16 feature detector.
    url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
    with dnnlib.util.open_url(url) as f:
        vgg16 = torch.jit.load(f).eval().to(device)

    # Features for target image.
    target_images = target.unsqueeze(0).to(device).to(torch.float32)
    if target_images.shape[2] > 256:
        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
    target_features = vgg16(target_images, resize_images=False, return_lpips=True)

    # One independent copy of the starting code per synthesis layer (W+).
    start_w = np.repeat(start_w, G.mapping.num_ws, axis=1)
    w_opt = torch.tensor(start_w, dtype=torch.float32, device=device,
                         requires_grad=True)  # pylint: disable=not-callable

    optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999),
                                 lr=hyperparameters.first_inv_lr)

    # Init noise.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True

    for step in tqdm(range(num_steps)):

        # Learning rate schedule.
        t = step / num_steps
        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Synth images from opt_w.
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = (w_opt + w_noise)
        synth_images = G.synthesis(ws, noise_mode='const', force_fp32=True)

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        synth_images = (synth_images + 1) * (255 / 2)
        if synth_images.shape[2] > 256:
            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

        # Features for synth images.
        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
        dist = (target_features - synth_features).square().sum()

        # Noise regularization: penalize correlation between each noise map and
        # its one-pixel shifts, repeated on successively 2x-pooled copies.
        reg_loss = 0.0
        for v in noise_bufs.values():
            noise = v[None, None, :, :]  # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
        loss = dist + reg_loss * regularize_noise_weight

        if step % image_log_step == 0:
            with torch.no_grad():
                if use_wandb:
                    global_config.training_step += 1
                    wandb.log({f'first projection _{w_name}': loss.detach().cpu()}, step=global_config.training_step)
                    log_image_from_w(w_opt, G, w_name)

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        logprint(f'step {step + 1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')

        # Normalize noise.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    del G
    return w_opt
================================================
FILE: stylegan_human/pti/training/projectors/w_projector.py
================================================
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Project given image to the latent space of pretrained network pickle."""
import copy
import wandb
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from pti.pti_configs import global_config, hyperparameters
from utils import log_utils
import dnnlib
def project(
        G,
        target: torch.Tensor,  # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
        *,
        num_steps=1000,
        w_avg_samples=10000,
        initial_learning_rate=0.01,
        initial_noise_factor=0.05,
        lr_rampdown_length=0.25,
        lr_rampup_length=0.05,
        noise_ramp_length=0.75,
        regularize_noise_weight=1e5,
        verbose=False,
        device: torch.device,
        use_wandb=False,
        initial_w=None,
        image_log_step=global_config.image_rec_result_log_snapshot,
        w_name: str
):
    """Optimize a single W latent code so that ``G`` reproduces ``target``.

    A deep copy of ``G`` is frozen and used for the whole optimization. One
    code (``initial_w`` or the mean of ``w_avg_samples`` mapped samples) is
    optimized together with the generator's ``noise_const`` buffers; for
    synthesis it is repeated ``G.mapping.num_ws`` times along axis 1.

    Args:
        G: Generator providing ``mapping`` and ``synthesis``.
        target: Image tensor of shape
            ``(G.img_channels, G.img_resolution, G.img_resolution // 2)``.
        num_steps: Number of optimizer steps.
        w_avg_samples: Number of random latents used for the W statistics.
        initial_learning_rate: Peak learning rate of the ramp-up/ramp-down
            schedule (it overwrites the optimizer's rate on every step).
        initial_noise_factor: Scale of the noise added to the code, relative
            to the measured W standard deviation.
        lr_rampdown_length: Fraction of the run used for the cosine ramp-down.
        lr_rampup_length: Fraction of the run used for the linear ramp-up.
        noise_ramp_length: Fraction of the run after which no noise is added.
        regularize_noise_weight: Weight of the noise regularization term.
        verbose: Print per-step progress when true.
        device: Device on which the optimization runs.
        use_wandb: Log the loss and an image every ``image_log_step`` steps.
        initial_w: Optional starting code replacing the W average.
        image_log_step: Logging interval in steps.
        w_name: Label used in the wandb log entries.

    Returns:
        The optimized code repeated ``G.mapping.num_ws`` times along axis 1.
    """
    print(target.shape,G.img_channels, G.img_resolution, G.img_resolution//2)
    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution // 2)

    def logprint(*args):
        if verbose:
            print(*args)

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device).float()  # type: ignore
    # Remember the layer count now: G is deleted before the final repeat, which
    # previously hard-coded 18 and so only suited generators with 18 layers.
    num_ws = G.mapping.num_ws

    # Compute w stats.
    logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
    z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None)  # [N, L, C]
    w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)  # [N, 1, C]
    w_avg = np.mean(w_samples, axis=0, keepdims=True)  # [1, 1, C]
    w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5

    start_w = initial_w if initial_w is not None else w_avg

    # Setup noise inputs.
    noise_bufs = {name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name}

    # Load VGG16 feature detector.
    url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
    with dnnlib.util.open_url(url) as f:
        vgg16 = torch.jit.load(f).eval().to(device)

    # Features for target image.
    target_images = target.unsqueeze(0).to(device).to(torch.float32)
    if target_images.shape[2] > 256:
        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
    target_features = vgg16(target_images, resize_images=False, return_lpips=True)

    w_opt = torch.tensor(start_w, dtype=torch.float32, device=device,
                         requires_grad=True)  # pylint: disable=not-callable
    optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999),
                                 lr=hyperparameters.first_inv_lr)

    # Init noise.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True

    for step in range(num_steps):

        # Learning rate schedule.
        t = step / num_steps
        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Synth images from opt_w.
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = (w_opt + w_noise).repeat([1, num_ws, 1])
        synth_images = G.synthesis(ws, noise_mode='const', force_fp32=True)

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        synth_images = (synth_images + 1) * (255 / 2)
        if synth_images.shape[2] > 256:
            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

        # Features for synth images.
        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
        dist = (target_features - synth_features).square().sum()

        # Noise regularization: penalize correlation between each noise map and
        # its one-pixel shifts, repeated on successively 2x-pooled copies.
        reg_loss = 0.0
        for v in noise_bufs.values():
            noise = v[None, None, :, :]  # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
        loss = dist + reg_loss * regularize_noise_weight

        if step % 10 == 0:
            print("project loss", step, loss.data)

        if step % image_log_step == 0:
            with torch.no_grad():
                if use_wandb:
                    global_config.training_step += 1
                    wandb.log({f'first projection _{w_name}': loss.detach().cpu()}, step=global_config.training_step)
                    log_utils.log_image_from_w(w_opt.repeat([1, num_ws, 1]), G, w_name)

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        logprint(f'step {step + 1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')

        # Normalize noise.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    del G
    return w_opt.repeat([1, num_ws, 1])
================================================
FILE: stylegan_human/run_pti.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
from random import choice
from string import ascii_uppercase
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
import os
from pti.pti_configs import global_config, paths_config
import wandb
from pti.training.coaches.multi_id_coach import MultiIDCoach
from pti.training.coaches.single_id_coach import SingleIDCoach
from utils.ImagesDataset import ImagesDataset
def run_PTI(run_name='', use_wandb=False, use_multi_id_training=False):
    """Run pivotal tuning on the images under ``paths_config.input_data_path``.

    Args:
        run_name: Name stored in ``global_config.run_name``; when empty, a
            random string of 12 uppercase letters is generated instead.
        use_wandb: Start a wandb run and pass the flag on to the coach.
        use_multi_id_training: Use ``MultiIDCoach`` instead of
            ``SingleIDCoach``.

    Returns:
        The run name that was used.
    """
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = global_config.cuda_visible_devices

    global_config.run_name = (
        run_name if run_name != '' else ''.join(choice(ascii_uppercase) for _ in range(12))
    )

    if use_wandb:
        wandb.init(project=paths_config.pti_results_keyword, reinit=True, name=global_config.run_name)

    # Reset the global step counters for this run.
    global_config.pivotal_training_steps = 1
    global_config.training_step = 1

    embedding_dir_path = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}/{paths_config.pti_results_keyword}'
    # print('embedding_dir_path: ', embedding_dir_path) #./embeddings/barcelona/PTI
    os.makedirs(embedding_dir_path, exist_ok=True)

    # Images are resized to 1024x512 and normalized with mean/std 0.5 per channel.
    preprocess = transforms.Compose([
        transforms.Resize((1024, 512)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    dataset = ImagesDataset(paths_config.input_data_path, preprocess)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    coach_cls = MultiIDCoach if use_multi_id_training else SingleIDCoach
    coach = coach_cls(dataloader, use_wandb)
    coach.train()

    return global_config.run_name
if __name__ == '__main__':
    # Script entry point: single-identity run, random run name, no wandb logging.
    run_PTI(use_multi_id_training=False, use_wandb=False, run_name='')
================================================
FILE: stylegan_human/style_mixing.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
#
import os
import re
from typing import List
import legacy
import click
import dnnlib
import numpy as np
import PIL.Image
import torch
"""
Style mixing using pretrained network pickle.
Examples:
\b
python style_mixing.py --network=pretrained_models/stylegan_human_v2_1024.pkl --rows=85,100,75,458,1500 \\
--cols=55,821,1789,293 --styles=0-3 --outdir=outputs/stylemixing
"""
@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--rows', 'row_seeds', type=legacy.num_range, help='Random seeds to use for image rows', required=True)
@click.option('--cols', 'col_seeds', type=legacy.num_range, help='Random seeds to use for image columns', required=True)
@click.option('--styles', 'col_styles', type=legacy.num_range, help='Style layer range', default='0-6', show_default=True)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=0.8, show_default=True)
@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
@click.option('--outdir', type=str, required=True, default='outputs/stylemixing')
def generate_style_mix(
    network_pkl: str,
    row_seeds: List[int],
    col_seeds: List[int],
    col_styles: List[int],
    truncation_psi: float,
    noise_mode: str,
    outdir: str
):
    # Builds a grid image: the first row/column hold the unmixed seed images,
    # and every other cell takes the row seed's latent with the layers listed
    # in `col_styles` replaced by the column seed's. Saved as `<outdir>/grid.png`.
    print('Loading networks from "%s"...' % network_pkl)
    if torch.cuda.is_available():
        device_name = 'cuda'
    elif torch.backends.mps.is_available():
        device_name = 'mps'
    else:
        device_name = 'cpu'
    device = torch.device(device_name)
    # float64 everywhere except on MPS, where float32 is used instead.
    dtype = torch.float32 if device.type == 'mps' else torch.float64
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype)

    os.makedirs(outdir, exist_ok=True)

    print('Generating W vectors...')
    all_seeds = list(set(row_seeds + col_seeds))
    latents = [np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds]
    all_z = np.stack(latents)
    all_w = G.mapping(torch.from_numpy(all_z).to(device, dtype=dtype), None)
    w_avg = G.mapping.w_avg
    # Truncation: pull every code towards the average by `truncation_psi`.
    all_w = w_avg + (all_w - w_avg) * truncation_psi
    w_dict = dict(zip(all_seeds, list(all_w)))

    print('Generating images...')
    all_images = G.synthesis(all_w, noise_mode=noise_mode)
    all_images = (all_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
    image_dict = {}
    for seed, image in zip(all_seeds, list(all_images)):
        image_dict[(seed, seed)] = image

    print('Generating style-mixed images...')
    for row_seed in row_seeds:
        for col_seed in col_seeds:
            mixed_w = w_dict[row_seed].clone()
            mixed_w[col_styles] = w_dict[col_seed][col_styles]
            mixed = G.synthesis(mixed_w[np.newaxis], noise_mode=noise_mode)
            mixed = (mixed.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            image_dict[(row_seed, col_seed)] = mixed[0].cpu().numpy()

    os.makedirs(outdir, exist_ok=True)
    # print('Saving images...')
    # for (row_seed, col_seed), image in image_dict.items():
    #     PIL.Image.fromarray(image, 'RGB').save(f'{outdir}/{row_seed}-{col_seed}.png')

    print('Saving image grid...')
    H = G.img_resolution
    W = G.img_resolution // 2
    grid_size = (W * (len(col_seeds) + 1), H * (len(row_seeds) + 1))
    canvas = PIL.Image.new('RGB', grid_size, 'black')
    # Index 0 of each axis is the header row/column; the leading 0 is a placeholder seed.
    for row_idx, row_seed in enumerate([0] + row_seeds):
        for col_idx, col_seed in enumerate([0] + col_seeds):
            if row_idx == 0 and col_idx == 0:
                continue
            if col_idx == 0:
                key = (row_seed, row_seed)
            elif row_idx == 0:
                key = (col_seed, col_seed)
            else:
                key = (row_seed, col_seed)
            tile = PIL.Image.fromarray(image_dict[key], 'RGB')
            canvas.paste(tile, (W * col_idx, H * row_idx))
    canvas.save(f'{outdir}/grid.png')
#----------------------------------------------------------------------------
if __name__ == "__main__":
generate_style_mix() # pylint: disable=no-value-for-parameter
#----------------------------------------------------------------------------
================================================
FILE: stylegan_human/stylemixing_video.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
"""Here we demo style-mixing results using StyleGAN2 pretrained model.
Script reference: https://github.com/PDillis/stylegan2-fun """
import argparse
import legacy
import scipy
import numpy as np
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib
from typing import List
import re
import sys
import os
import click
import torch
os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
import moviepy.editor
"""
Generate style mixing video.
Examples:
\b
python stylemixing_video.py --network=pretrained_models/stylegan_human_v2_1024.pkl --row-seed=3859 \\
--col-seeds=3098,31759,3791 --col-styles=8-12 --trunc=0.8 --outdir=outputs/stylemixing_video
"""
@click.command()
@click.option('--network', 'network_pkl', help='Path to network pickle filename', required=True)
@click.option('--row-seed', 'src_seed', type=legacy.num_range, help='Random seed to use for image source row', required=True)
@click.option('--col-seeds', 'dst_seeds', type=legacy.num_range, help='Random seeds to use for image columns (style)', required=True)
@click.option('--col-styles', 'col_styles', type=legacy.num_range, help='Style layer range', default='0-6', show_default=True)
@click.option('--only-stylemix', 'only_stylemix', help='Add flag to only show the style mxied images in the video',default=False)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--duration-sec', 'duration_sec', type=float, help='Duration of video', default=10, show_default=True)
@click.option('--fps', 'mp4_fps', type=int, help='FPS of generated video', default=10, show_default=True)
@click.option('--indent-range', 'indent_range', type=int, default=30)
@click.option('--outdir', help='Root directory for run results', default='outputs/stylemixing_video', show_default=True, metavar='DIR')
def style_mixing_video(network_pkl: str,
                       src_seed: List[int],  # Seed of the source image style (row)
                       dst_seeds: List[int],  # Seeds of the destination image styles (columns)
                       col_styles: List[int],  # Styles to transfer from first row to first column
                       truncation_psi: float = 1,
                       only_stylemix: bool = False,  # True if user wishes to show only the style transferred result
                       duration_sec: float = 10,
                       smoothing_sec=1.0,
                       mp4_fps: int = 10,
                       mp4_codec="libx264",
                       mp4_bitrate="16M",
                       minibatch_size=8,
                       noise_mode='const',
                       indent_range: int = 30,
                       outdir: str = 'outputs/stylemixing_video'):
    # Renders an mp4 in which a smoothly varying source latent (left column) is
    # style-mixed into each fixed destination image (top row).
    #
    # The source latents are Gaussian noise seeded by `src_seed`, smoothed along
    # the frame axis and re-normalized. For every frame, the layers listed in
    # `col_styles` of each destination code are replaced by the source code's.
    # `only_stylemix` and `minibatch_size` are accepted but not used here.
    # The file is written to `<outdir>/<cols>x<rows>-style-mixing_<min>_<max>.mp4`.

    # Import the submodule explicitly: a bare `import scipy` is not guaranteed
    # to make `scipy.ndimage` available on every SciPy version.
    import scipy.ndimage

    # Calculate the number of frames:
    print('col_seeds: ', dst_seeds)
    num_frames = int(np.rint(duration_sec * mp4_fps))
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
    # float64 everywhere except on MPS, where float32 is used instead.
    dtype = torch.float32 if device.type == 'mps' else torch.float64
    with dnnlib.util.open_url(network_pkl) as f:
        Gs = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype)

    print(Gs.num_ws, Gs.w_dim, Gs.img_resolution)
    max_style = int(2 * np.log2(Gs.img_resolution)) - 3
    assert max(col_styles) <= max_style, f"Maximum col-style allowed: {max_style}"

    # Left col latents
    print('Generating Source W vectors...')
    src_shape = [num_frames] + [Gs.z_dim]
    src_z = np.random.RandomState(*src_seed).randn(*src_shape).astype(np.float32)  # [frames, src, component]
    # Smooth along the frame axis only (sigma 0 on the latent axis); "wrap" makes the clip loop.
    src_z = scipy.ndimage.gaussian_filter(src_z, [smoothing_sec * mp4_fps] + [0] * (2 - 1), mode="wrap")
    src_z /= np.sqrt(np.mean(np.square(src_z)))
    # Map into the detangled latent space W and do truncation trick
    src_w = Gs.mapping(torch.from_numpy(src_z).to(device, dtype=dtype), None)
    w_avg = Gs.mapping.w_avg
    src_w = w_avg + (src_w - w_avg) * truncation_psi

    # Top row latents (fixed reference)
    print('Generating Destination W vectors...')
    dst_z = np.stack([np.random.RandomState(seed).randn(Gs.z_dim) for seed in dst_seeds])
    dst_w = Gs.mapping(torch.from_numpy(dst_z).to(device, dtype=dtype), None)
    dst_w = w_avg + (dst_w - w_avg) * truncation_psi
    # Get the width and height of each image:
    H = Gs.img_resolution  # 1024
    W = Gs.img_resolution // 2  # 512

    # Generate ALL the source images:
    src_images = Gs.synthesis(src_w, noise_mode=noise_mode)
    src_images = (src_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)

    # Generate the column images:
    dst_images = Gs.synthesis(dst_w, noise_mode=noise_mode)
    dst_images = (dst_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)

    print('Generating full video (including source and destination images)')
    # Generate our canvas where we will paste all the generated images:
    canvas = PIL.Image.new("RGB", ((W - indent_range) * (len(dst_seeds) + 1), H * (len(src_seed) + 1)), "white")  # W, H

    for col, dst_image in enumerate(list(dst_images)):
        canvas.paste(PIL.Image.fromarray(dst_image.cpu().numpy(), "RGB"), ((col + 1) * (W - indent_range), 0))

    # Aux functions: Frame generation func for moviepy.
    def make_frame(t):
        # Get the frame number according to time t:
        frame_idx = int(np.clip(np.round(t * mp4_fps), 0, num_frames - 1))
        # We wish the image belonging to the frame at time t:
        src_image = src_images[frame_idx]  # always in the same place
        canvas.paste(PIL.Image.fromarray(src_image.cpu().numpy(), "RGB"), (0 - indent_range, H))  # Paste it to the lower left

        # Now, for each of the column images:
        for col in range(len(dst_images)):
            # Copy this column's latent with a leading batch dimension. The copy
            # stays on the device and keeps `dst_w` untouched by the in-place
            # assignment below.
            w_col = dst_w[col:col + 1].detach().clone()
            # Replace the values defined by col_styles:
            w_col[:, col_styles] = src_w[frame_idx, col_styles]
            # Generate these synthesized images:
            col_images = Gs.synthesis(w_col, noise_mode=noise_mode)
            col_images = (col_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            # Paste them in their respective spot:
            for row, image in enumerate(list(col_images)):
                canvas.paste(
                    PIL.Image.fromarray(image.cpu().numpy(), "RGB"),
                    ((col + 1) * (W - indent_range), (row + 1) * H),
                )
        return np.array(canvas)

    # Generate video using make_frame:
    print('Generating style-mixed video...')
    videoclip = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
    grid_size = [len(dst_seeds), len(src_seed)]
    mp4 = "{}x{}-style-mixing_{}_{}.mp4".format(*grid_size, min(col_styles), max(col_styles))
    os.makedirs(outdir, exist_ok=True)
    videoclip.write_videofile(os.path.join(outdir, mp4),
                              fps=mp4_fps,
                              codec=mp4_codec,
                              bitrate=mp4_bitrate)
if __name__ == "__main__":
style_mixing_video()
================================================
FILE: stylegan_human/torch_utils/__init__.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
# empty
================================================
FILE: stylegan_human/torch_utils/custom_ops.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import os
import glob
import torch
import torch.utils.cpp_extension
import importlib
import hashlib
import shutil
from pathlib import Path
import re
import uuid
from torch.utils.file_baton import FileBaton
#----------------------------------------------------------------------------
# Global options.
# Amount of status output while setting up plugins: 'none', 'brief' or 'full'.
verbosity = 'brief'
#----------------------------------------------------------------------------
# Internal helper funcs.
def _find_compiler_bindir():
    """Search well-known Visual Studio install locations for a compiler directory.

    The glob patterns are tried in order; for the first pattern that matches
    at least one path, the lexicographically greatest match is returned.

    Returns:
        The matching directory path, or None when no pattern matches anything.
    """
    candidates = (
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
    )
    for candidate in candidates:
        found = glob.glob(candidate)
        if found:
            # Same pick as sorting and taking the last element.
            return max(found)
    return None
def _get_mangled_gpu_name():
    """Return the current CUDA device name in a filesystem-friendly form.

    The name reported by `torch.cuda.get_device_name()` is lower-cased and
    every character outside `[a-z0-9_-]` is replaced by '-'.
    """
    device_name = torch.cuda.get_device_name().lower()
    return ''.join(
        ch if re.match('[a-z0-9_-]+', ch) else '-'
        for ch in device_name
    )
#----------------------------------------------------------------------------
# Main entry point for compiling and loading C++/CUDA plugins.
_cached_plugins = dict()

def get_plugin(module_name, sources, **build_kwargs):
    """Build (when necessary), import and memoize a PyTorch C++/CUDA extension.

    Args:
        module_name: Name of the extension module; also used as the cache key.
        sources: Paths of the source files handed to the extension builder.
        **build_kwargs: Forwarded unchanged to `torch.utils.cpp_extension.load()`.

    Returns:
        The imported extension module. Later calls with the same
        `module_name` return the cached module without rebuilding.

    Raises:
        RuntimeError: On Windows, when `cl.exe` is not reachable and no known
            Visual Studio compiler directory can be found.
    """
    assert verbosity in ['none', 'brief', 'full']

    # Fast path: an earlier call already loaded this module.
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    # Progress message; the 'brief' line is completed by 'Done.' or 'Failed!' below.
    if verbosity == 'full':
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == 'brief':
        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)

    try: # pylint: disable=too-many-nested-blocks
        # On Windows, add a compiler directory to PATH if cl.exe cannot be found.
        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
            msvc_dir = _find_compiler_bindir()
            if msvc_dir is None:
                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
            os.environ['PATH'] += ';' + msvc_dir

        verbose_build = (verbosity == 'full')

        # Incremental-build trick: when every source lives in one directory
        # and TORCH_EXTENSIONS_DIR is set, all files of that directory are
        # staged under <build_dir>/<md5 of their contents> and compiled from
        # there. Unchanged inputs map to the same staged files, so file names
        # and timestamps stay stable across runs and rebuilds stay incremental.
        unique_dirs = {os.path.dirname(path) for path in sources}
        use_digest_dir = len(unique_dirs) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ)
        if not use_digest_dir:
            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
        else:
            (src_dir,) = unique_dirs
            dir_files = sorted(entry for entry in Path(src_dir).iterdir() if entry.is_file())

            # Combined digest over every file of the op directory
            # (usually .cu, .cpp, .py and .h files).
            digest = hashlib.md5()
            for entry in dir_files:
                with open(entry, 'rb') as stream:
                    digest.update(stream.read())

            build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
            digest_dir = os.path.join(build_dir, digest.hexdigest())

            if not os.path.isdir(digest_dir):
                os.makedirs(digest_dir, exist_ok=True)
                lock = FileBaton(os.path.join(digest_dir, 'lock'))
                if not lock.try_acquire():
                    # Another process is staging the same digest directory;
                    # wait for it to finish, then carry on.
                    lock.wait()
                else:
                    try:
                        for entry in dir_files:
                            shutil.copyfile(entry, os.path.join(digest_dir, os.path.basename(entry)))
                    finally:
                        lock.release()

            staged_sources = [os.path.join(digest_dir, os.path.basename(path)) for path in sources]
            torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
                verbose=verbose_build, sources=staged_sources, **build_kwargs)

        module = importlib.import_module(module_name)

    except BaseException:
        # Finish the 'brief' status line, then let the error propagate untouched.
        if verbosity == 'brief':
            print('Failed!')
        raise

    # Report success and remember the module for later calls.
    if verbosity == 'full':
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == 'brief':
        print('Done.')
    _cached_plugins[module_name] = module
    return module
#----------------------------------------------------------------------------
def get_plugin_v3(module_name, sources, headers=None, source_dir=None, **build_kwargs):
    """Build (when necessary), import and memoize a PyTorch C++/CUDA extension.

    Variant of `get_plugin()` that takes header files explicitly and keys the
    staged build directory on both the source digest and the GPU name.

    Args:
        module_name: Name of the extension module; also used as the cache key.
        sources: List of source file names handed to the extension builder.
        headers: Optional list of header file names. They are hashed and
            staged next to the sources but not passed to the builder.
        source_dir: Optional directory that every entry of `sources` and
            `headers` is joined onto.
        **build_kwargs: Forwarded unchanged to `torch.utils.cpp_extension.load()`.

    Returns:
        The imported extension module (cached per `module_name`).

    Raises:
        RuntimeError: On Windows, when `cl.exe` is not reachable and no known
            Visual Studio compiler directory can be found.
    """
    assert verbosity in ['none', 'brief', 'full']
    headers = [] if headers is None else headers
    if source_dir is not None:
        sources = [os.path.join(source_dir, fname) for fname in sources]
        headers = [os.path.join(source_dir, fname) for fname in headers]

    # Fast path: an earlier call already loaded this module.
    try:
        return _cached_plugins[module_name]
    except KeyError:
        pass

    # Progress message; the 'brief' line is completed by 'Done.' or 'Failed!' below.
    if verbosity == 'full':
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == 'brief':
        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
    verbose_build = (verbosity == 'full')

    # Compile and load.
    try: # pylint: disable=too-many-nested-blocks
        # On Windows, add a compiler directory to PATH if cl.exe cannot be found.
        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
            msvc_dir = _find_compiler_bindir()
            if msvc_dir is None:
                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
            os.environ['PATH'] += ';' + msvc_dir

        # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
        # break the build or unnecessarily restrict what's available to nvcc.
        # Blank it so nvcc decides based on what is available on the machine.
        os.environ['TORCH_CUDA_ARCH_LIST'] = ''

        # Incremental-build trick: stage all inputs in a cached directory
        # named after their combined md5 digest, copying only when that
        # directory does not exist yet. Unchanged inputs keep their staged
        # names and timestamps, which keeps rebuilds incremental.
        #
        # Applied whenever all inputs share a single directory. It used to be
        # conditional on TORCH_EXTENSIONS_DIR being set; it is now done
        # regardless, to work around the *.cu dependency bug in ninja config.
        staged_inputs = sorted(sources + headers)
        input_dirs = {os.path.dirname(fname) for fname in staged_inputs}
        if len(input_dirs) != 1:
            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
        else:
            # Combined digest over every source and header file.
            digest = hashlib.md5()
            for fname in staged_inputs:
                with open(fname, 'rb') as stream:
                    digest.update(stream.read())
            source_digest = digest.hexdigest()

            # Pick the cached build directory name.
            build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
            cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')

            if not os.path.isdir(cached_build_dir):
                # Populate a uniquely named scratch directory first, then move
                # it into place in a single step.
                tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
                os.makedirs(tmpdir)
                for fname in staged_inputs:
                    shutil.copyfile(fname, os.path.join(tmpdir, os.path.basename(fname)))
                try:
                    os.replace(tmpdir, cached_build_dir) # atomic
                except OSError:
                    # The target already exists; drop the scratch copy. If the
                    # target is still missing, the failure was genuine.
                    shutil.rmtree(tmpdir)
                    if not os.path.isdir(cached_build_dir):
                        raise

            # Compile from the staged copies.
            staged_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
            torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
                verbose=verbose_build, sources=staged_sources, **build_kwargs)

        # Load.
        module = importlib.import_module(module_name)

    except BaseException:
        # Finish the 'brief' status line, then let the error propagate untouched.
        if verbosity == 'brief':
            print('Failed!')
        raise

    # Report success and remember the module for later calls.
    if verbosity == 'full':
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == 'brief':
        print('Done.')
    _cached_plugins[module_name] = module
    return module
================================================
FILE: stylegan_human/torch_utils/misc.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import re
import contextlib
import numpy as np
import torch
import warnings
import dnnlib
#----------------------------------------------------------------------------
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.
_constant_cache = dict()

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    """Return a cached tensor holding `value`, optionally broadcast to `shape`.

    Args:
        value: Anything `np.asarray()` accepts.
        shape: Optional target shape the value is broadcast to.
        dtype: Tensor dtype; defaults to `torch.get_default_dtype()`.
        device: Target device; defaults to CPU.
        memory_format: Defaults to `torch.contiguous_format`.

    Returns:
        A contiguous tensor. Calls with the same value bytes, shape, dtype,
        device and memory format return the very same tensor object, so the
        result must not be modified in place.
    """
    value = np.asarray(value)
    shape = None if shape is None else tuple(shape)
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device('cpu') if device is None else device
    memory_format = torch.contiguous_format if memory_format is None else memory_format

    cache_key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    cached = _constant_cache.get(cache_key, None)
    if cached is not None:
        return cached

    # Cache miss: build the tensor once and remember it.
    result = torch.as_tensor(value.copy(), dtype=dtype, device=device)
    if shape is not None:
        result, _ = torch.broadcast_tensors(result, torch.empty(shape))
    result = result.contiguous(memory_format=memory_format)
    _constant_cache[cache_key] = result
    return result
#----------------------------------------------------------------------------
# Replace NaN/Inf with specified numerical values.
if hasattr(torch, 'nan_to_num'):
    nan_to_num = torch.nan_to_num # 1.8.0a0
else:
    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
        """Fallback for `torch.nan_to_num` on PyTorch builds that lack it.

        NaNs become zero (only `nan == 0` is supported) and the result is
        clamped to `[neginf, posinf]`, which default to the finite range of
        the input dtype.
        """
        assert isinstance(input, torch.Tensor)
        upper = torch.finfo(input.dtype).max if posinf is None else posinf
        lower = torch.finfo(input.dtype).min if neginf is None else neginf
        assert nan == 0
        # nansum over a fresh leading axis of length one zeroes the NaNs.
        without_nans = input.unsqueeze(0).nansum(0)
        return torch.clamp(without_nans, min=lower, max=upper, out=out)
#----------------------------------------------------------------------------
# Symbolic assert.
# torch._assert is available from 1.8.0a0; 1.7.0 only offers torch.Assert.
# The conditional expression touches torch.Assert only when _assert is missing.
symbolic_assert = torch._assert if hasattr(torch, '_assert') else torch.Assert # pylint: disable=protected-access
#----------------------------------------------------------------------------
# Context manager to suppress known warnings in torch.jit.trace().
class suppress_tracer_warnings(warnings.catch_warnings):
    """`warnings.catch_warnings` variant that also silences `torch.jit.TracerWarning`.

    The previous warning filters are restored on exit by the base class.
    """

    def __enter__(self):
        # Let the base class snapshot the current filters, then add our own.
        warnings.catch_warnings.__enter__(self)
        warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
        # Return the context manager itself, not the base class' return value.
        return self
#----------------------------------------------------------------------------
# Assert that the shape of a tensor matches the given list of integers.
# None indicates that the size of a dimension is allowed to vary.
# Performs symbolic assertion when used in torch.jit.trace().
def assert_shape(tensor, ref_shape):
    """Check that `tensor.shape` matches `ref_shape` dimension by dimension.

    Args:
        tensor: Object exposing `ndim` and `shape`.
        ref_shape: Expected sizes. A `None` entry accepts any size; an entry
            (or an actual size) that is itself a `torch.Tensor` is compared
            through `symbolic_assert` rather than a Python comparison.

    Raises:
        AssertionError: On a rank mismatch or a plain-integer size mismatch.
    """
    expected_ndim = len(ref_shape)
    if tensor.ndim != expected_ndim:
        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {expected_ndim}')

    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        if ref_size is None:
            continue # wildcard dimension

        if isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
            continue

        if isinstance(size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
            continue

        if size != ref_size:
            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
#----------------------------------------------------------------------------
# Function decorator that calls torch.autograd.profiler.record_function().
def profiled_function(fn):
    """Decorator that runs `fn` inside a `torch.autograd.profiler.record_function` scope.

    The profiler range is labelled with `fn.__name__`. The wrapper keeps the
    metadata of `fn` (`__name__`, `__qualname__`, `__doc__`, `__module__`) and
    exposes the original callable as `__wrapped__`.

    Args:
        fn: Callable to wrap.

    Returns:
        A callable with the same call signature and return value as `fn`.
    """
    import functools # local import so this block does not rely on a file-level import

    @functools.wraps(fn)
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)
    return decorator
#----------------------------------------------------------------------------
# Sampler for torch.utils.data.DataLoader that loops over the dataset
# indefinitely, shuffling items as it goes.
class InfiniteSampler(torch.utils.data.Sampler):
    """Sampler that yields dataset indices forever, split across replicas.

    Args:
        dataset: Sized, non-empty dataset.
        rank: Index of this replica; must satisfy `0 <= rank < num_replicas`.
        num_replicas: Total number of replicas sharing the index stream.
        shuffle: Whether to permute the order and keep re-shuffling it.
        seed: Seed of the private `np.random.RandomState` used for shuffling.
        window_size: Fraction (0..1) of the dataset size used as the range of
            the random back-offset for the on-the-fly swaps.
    """

    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size

    def __iter__(self):
        order = np.arange(len(self.dataset))
        rnd, window = None, 0
        if self.shuffle:
            rnd = np.random.RandomState(self.seed)
            rnd.shuffle(order)
            window = int(np.rint(order.size * self.window_size))

        num_items = order.size
        step = 0
        while True:
            pos = step % num_items
            # Every replica walks the same stream; each keeps its own share of steps.
            if step % self.num_replicas == self.rank:
                yield order[pos]
            # Swap the current slot with one up to `window - 1` positions back
            # (wrapping around), so the order keeps changing over time.
            if window >= 2:
                partner = (pos - rnd.randint(window)) % num_items
                order[pos], order[partner] = order[partner], order[pos]
            step += 1
#----------------------------------------------------------------------------
# Utilities for operating with torch.nn.Module parameters and buffers.
def params_and_buffers(module):
    """Return every parameter of `module`, followed by every buffer, as one list."""
    assert isinstance(module, torch.nn.Module)
    return [*module.parameters(), *module.buffers()]
def named_params_and_buffers(module):
    """Return `(name, tensor)` pairs for all parameters, then all buffers, of `module`."""
    assert isinstance(module, torch.nn.Module)
    return [*module.named_parameters(), *module.named_buffers()]
def copy_params_and_buffers(src_module, dst_module, require_all=False):
    """Copy same-named parameters and buffers from `src_module` into `dst_module`.

    Values are written in place into the tensors of `dst_module`; the
    `requires_grad` flag of each destination tensor is left as it was.

    Args:
        src_module: Module providing the values.
        dst_module: Module whose tensors are overwritten.
        require_all: When True, every tensor of `dst_module` must have a
            same-named counterpart in `src_module`.

    Raises:
        AssertionError: If `require_all` is set and a destination tensor has
            no counterpart in `src_module`.
    """
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = dict(named_params_and_buffers(src_module))
    # Autograd rejects in-place writes to leaf tensors that require grad, so
    # the copy runs with gradient tracking disabled. This lets trainable
    # destination parameters be overwritten as well.
    with torch.no_grad():
        for name, tensor in named_params_and_buffers(dst_module):
            assert (name in src_tensors) or (not require_all)
            if name in src_tensors:
                tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
#----------------------------------------------------------------------------
# Context manager for easily enabling/disabling DistributedDataParallel
# synchronization.
@contextlib.contextmanager
def ddp_sync(module, sync):
    """Context manager that optionally disables DistributedDataParallel synchronization.

    When `module` is a `DistributedDataParallel` instance and `sync` is
    falsy, the body runs inside `module.no_sync()`. In every other case the
    body runs unchanged.
    """
    assert isinstance(module, torch.nn.Module)
    is_ddp = isinstance(module, torch.nn.parallel.DistributedDataParallel)
    if is_ddp and not sync:
        with module.no_sync():
            yield
        return
    yield
#----------------------------------------------------------------------------
# Check DistributedDataParallel consistency across processes.
def check_ddp_consistency(module, ignore_regex=None):
    """Assert that this process holds the same parameters and buffers as rank 0.

    Each tensor is cloned, the clone is overwritten by a broadcast from
    source rank 0, and the two are compared after NaN/Inf replacement.

    Args:
        module: Module whose parameters and buffers are checked.
        ignore_regex: Optional pattern; tensors whose full name
            (`<ClassName>.<tensor name>`) fully matches it are skipped.

    Raises:
        AssertionError: With the full tensor name as message on a mismatch.
    """
    assert isinstance(module, torch.nn.Module)
    prefix = type(module).__name__ + '.'
    for name, tensor in named_params_and_buffers(module):
        fullname = prefix + name
        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
            continue
        local = tensor.detach()
        reference = local.clone()
        torch.distributed.broadcast(tensor=reference, src=0)
        assert (nan_to_num(local) == nan_to_num(reference)).all(), fullname
#----------------------------------------------------------------------------
# Print summary table of module hierarchy.
def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    """Run `module` once on `inputs` and print a per-submodule summary table.

    Forward hooks record the tensor outputs of every submodule whose nesting
    depth is at most `max_nesting`. The printed table lists, per recorded
    module, the number of parameter and buffer elements not already counted
    for an earlier row, plus the shape and dtype of each output tensor.

    Args:
        module: Module to summarize; must not be a `torch.jit.ScriptModule`.
        inputs: Tuple or list of positional arguments for `module`.
        max_nesting: Deepest submodule nesting level that is recorded.
        skip_redundant: Drop rows that add no new parameters, buffers or outputs.

    Returns:
        Whatever `module(*inputs)` returned.
    """
    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Register hooks.
    entries = []
    nesting = [0] # single-element list so the nested hooks can mutate the counter
    def pre_hook(_mod, _inputs):
        nesting[0] += 1
    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]

    # Run module. The hooks are removed even if the forward pass raises, so a
    # failed summary does not leave them attached to the module.
    try:
        outputs = module(*inputs)
    finally:
        for hook in hooks:
            hook.remove()

    # Identify unique outputs, parameters, and buffers.
    tensors_seen = set()
    for e in entries:
        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}

    # Filter out redundant entries.
    if skip_redundant:
        entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]

    # Construct table.
    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
    rows += [['---'] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = '' if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        # One shape/dtype string per output tensor (each output's own shape).
        output_shapes = [str(list(t.shape)) for t in e.outputs]
        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
        rows += [[
            name + (':0' if len(e.outputs) >= 2 else ''),
            str(param_size) if param_size else '-',
            str(buffer_size) if buffer_size else '-',
            (output_shapes + ['-'])[0],
            (output_dtypes + ['-'])[0],
        ]]
        # Additional outputs get their own ':<idx>' rows.
        for idx in range(1, len(e.outputs)):
            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
        param_total += param_size
        buffer_total += buffer_size
    rows += [['---'] * len(rows[0])]
    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]

    # Print table with every column padded to its widest cell.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
    print()
    return outputs
#----------------------------------------------------------------------------
================================================
FILE: stylegan_human/torch_utils/models.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py
import math
import random
import functools
import operator
import torch
from torch import nn
from torch.nn import functional as F
import torch.nn.init as init
from torch.autograd import Function
from .op_edit import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
class PixelNorm(nn.Module):
    """Scale the input to unit mean square along dimension 1 (epsilon 1e-8)."""

    def forward(self, input):
        mean_square = torch.mean(input ** 2, dim=1, keepdim=True)
        return input * torch.rsqrt(mean_square + 1e-8)
def make_kernel(k):
    """Build a normalized float32 filter kernel from `k`.

    A 1-D input is turned into a 2-D kernel via its outer product with
    itself. The result is divided by its sum, so its entries add up to one.
    """
    kernel = torch.tensor(k, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = kernel.unsqueeze(0) * kernel.unsqueeze(1)
    return kernel / kernel.sum()
class Upsample(nn.Module):
    """Upsample by `factor` through `upfirdn2d` with a fixed filter kernel.

    The normalized kernel is multiplied by `factor ** 2` and stored as the
    buffer `kernel`; the padding pair is derived from kernel size and factor.
    """

    def __init__(self, kernel, factor=2):
        super().__init__()
        self.factor = factor

        fir = make_kernel(kernel) * (factor ** 2)
        self.register_buffer("kernel", fir)

        extra = fir.shape[0] - factor
        self.pad = ((extra + 1) // 2 + factor - 1, extra // 2)

    def forward(self, input):
        return upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
class Downsample(nn.Module):
    """Downsample by `factor` through `upfirdn2d` with a fixed filter kernel.

    The normalized kernel is stored as the buffer `kernel`; the padding pair
    is derived from kernel size and factor.
    """

    def __init__(self, kernel, factor=2):
        super().__init__()
        self.factor = factor

        fir = make_kernel(kernel)
        self.register_buffer("kernel", fir)

        extra = fir.shape[0] - factor
        self.pad = ((extra + 1) // 2, extra // 2)

    def forward(self, input):
        return upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
class Blur(nn.Module):
    """Apply a fixed filter kernel through `upfirdn2d` with the given padding.

    Args:
        kernel: Kernel taps passed to `make_kernel`.
        pad: Padding argument forwarded to `upfirdn2d`.
        upsample_factor: When greater than one, the kernel is multiplied by
            `upsample_factor ** 2`.
    """

    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()
        fir = make_kernel(kernel)
        if upsample_factor > 1:
            fir = fir * (upsample_factor ** 2)
        self.register_buffer("kernel", fir)
        self.pad = pad

    def forward(self, input):
        return upfirdn2d(input, self.kernel, pad=self.pad)
class EqualConv2d(nn.Module):
    """2-D convolution whose weight is rescaled at call time.

    The weight is drawn from a standard normal and multiplied in `forward`
    by `1 / sqrt(in_channel * kernel_size ** 2)`.

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output channels.
        kernel_size: Side length of the square kernel.
        stride: Convolution stride.
        padding: Convolution padding.
        bias: Whether to learn a (zero-initialized) bias.
    """

    def __init__(
        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
    ):
        super().__init__()

        weight_shape = (out_channel, in_channel, kernel_size, kernel_size)
        self.weight = nn.Parameter(torch.randn(*weight_shape))
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

        self.stride = stride
        self.padding = padding

        self.bias = nn.Parameter(torch.zeros(out_channel)) if bias else None

    def forward(self, input):
        scaled_weight = self.weight * self.scale
        return F.conv2d(
            input,
            scaled_weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

    def __repr__(self):
        out_ch, in_ch, ksize = self.weight.shape[:3]
        return (
            f"{self.__class__.__name__}({in_ch}, {out_ch},"
            f" {ksize}, stride={self.stride}, padding={self.padding})"
        )
class EqualLinear(nn.Module):
    """Linear layer whose weight and bias are rescaled at call time.

    The weight is initialized as `randn / lr_mul` and multiplied in `forward`
    by `lr_mul / sqrt(in_dim)`; the bias is multiplied by `lr_mul`.

    Args:
        in_dim: Number of input features.
        out_dim: Number of output features.
        bias: Whether to learn a bias. With `bias=False` the layer has no
            bias parameter and `forward` applies none.
        bias_init: Initial value of every bias element.
        lr_mul: Multiplier applied as described above.
        activation: Any truthy value routes the output through
            `fused_leaky_relu`; falsy means a plain linear layer.
    """

    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        weight = self.weight * self.scale
        # A layer built with bias=False stores None; scaling it would raise.
        bias = None if self.bias is None else self.bias * self.lr_mul

        if self.activation:
            if bias is None:
                # Keep the existing two-argument call to fused_leaky_relu: an
                # explicit zero bias stands in for "no bias".
                bias = self.weight.new_zeros(self.weight.shape[0])
            out = F.linear(input, weight)
            return fused_leaky_relu(out, bias)

        return F.linear(input, weight, bias=bias)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
        )
class ScaledLeakyReLU(nn.Module):
    """Leaky ReLU whose output is multiplied by sqrt(2)."""

    def __init__(self, negative_slope=0.2):
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, input):
        activated = F.leaky_relu(input, negative_slope=self.negative_slope)
        return activated * math.sqrt(2)
class ModulatedConv2d(nn.Module):
    """Convolution whose weight is modulated per sample by a style vector.

    The style is mapped by an `EqualLinear` layer to one scale per input
    channel. The scaled weights are optionally demodulated (normalized per
    output channel) and applied as one grouped convolution over the batch.

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output channels.
        kernel_size: Side length of the square kernel.
        style_dim: Size of the style vector.
        demodulate: Whether to normalize the modulated weights.
        upsample: Use a stride-2 transposed convolution followed by a blur.
        downsample: Blur first, then use a stride-2 convolution.
        blur_kernel: Taps of the blur filter used when resampling.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        # Blur padding depends on the resampling direction. If both flags are
        # set, the downsample blur assigned last is the one that stays.
        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            up_pad = ((p + 1) // 2 + factor - 1, p // 2 + 1)
            self.blur = Blur(blur_kernel, pad=up_pad, upsample_factor=factor)

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            down_pad = ((p + 1) // 2, p // 2)
            self.blur = Blur(blur_kernel, pad=down_pad)

        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )

        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

        self.demodulate = demodulate

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
            f"upsample={self.upsample}, downsample={self.downsample})"
        )

    def forward(self, input, style):
        batch, in_channel, height, width = input.shape
        ksize = self.kernel_size
        out_channel = self.out_channel

        # One modulated weight set per sample: (batch, out, in, k, k).
        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
        weight = self.scale * self.weight * style

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, out_channel, 1, 1, 1)

        # Fold the batch into the channel axis so that a grouped convolution
        # (groups=batch) applies each sample's own weights.
        weight = weight.view(batch * out_channel, in_channel, ksize, ksize)

        if self.upsample:
            folded = input.view(1, batch * in_channel, height, width)
            # conv_transpose2d expects (in, out, k, k) per group.
            weight = weight.view(batch, out_channel, in_channel, ksize, ksize)
            weight = weight.transpose(1, 2).reshape(
                batch * in_channel, out_channel, ksize, ksize
            )
            out = F.conv_transpose2d(folded, weight, padding=0, stride=2, groups=batch)
            _, _, out_h, out_w = out.shape
            out = out.view(batch, out_channel, out_h, out_w)
            return self.blur(out)

        if self.downsample:
            blurred = self.blur(input)
            _, _, blur_h, blur_w = blurred.shape
            folded = blurred.view(1, batch * in_channel, blur_h, blur_w)
            out = F.conv2d(folded, weight, padding=0, stride=2, groups=batch)
        else:
            folded = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(folded, weight, padding=self.padding, groups=batch)

        _, _, out_h, out_w = out.shape
        return out.view(batch, out_channel, out_h, out_w)
class NoiseInjection(nn.Module):
    """Add noise to an image, scaled by a single learned weight (initially zero)."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, image, noise=None):
        if noise is not None:
            return image + self.weight * noise

        # No noise supplied: draw one standard-normal channel per sample,
        # matching the image's batch size, height and width.
        batch, _, height, width = image.shape
        sampled = image.new_empty(batch, 1, height, width).normal_()
        return image + self.weight * sampled
class ConstantInput(nn.Module):
    """Learned constant tensor of shape (1, channel, size, size // 2), tiled per batch.

    `forward` only reads the batch size (first dimension) of its argument.
    """

    def __init__(self, channel, size=4):
        super().__init__()
        self.input = nn.Parameter(torch.randn(1, channel, size, size // 2))

    def forward(self, input):
        return self.input.repeat(input.shape[0], 1, 1, 1)
class StyledConv(nn.Module):
    """Modulated convolution followed by noise injection and a fused leaky ReLU.

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output channels.
        kernel_size: Side length of the square kernel.
        style_dim: Size of the style vector.
        upsample: Forwarded to `ModulatedConv2d`.
        blur_kernel: Forwarded to `ModulatedConv2d`.
        demodulate: Forwarded to `ModulatedConv2d`.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        upsample=False,
        blur_kernel=[1, 3, 3, 1],
        demodulate=True,
    ):
        super().__init__()

        conv_options = dict(
            upsample=upsample,
            blur_kernel=blur_kernel,
            demodulate=demodulate,
        )
        self.conv = ModulatedConv2d(
            in_channel, out_channel, kernel_size, style_dim, **conv_options
        )
        self.noise = NoiseInjection()
        self.activate = FusedLeakyReLU(out_channel)

    def forward(self, input, style, noise=None):
        convolved = self.conv(input, style)
        noisy = self.noise(convolved, noise=noise)
        return self.activate(noisy)
class ToRGB(nn.Module):
    """Project features to 3 channels with a 1x1 modulated convolution (no demodulation).

    When a `skip` image is passed to `forward`, it is upsampled and added to
    the result. The upsampler exists only if the layer was built with
    `upsample=True`.
    """

    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        if upsample:
            self.upsample = Upsample(blur_kernel)

        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, input, style, skip=None):
        rgb = self.conv(input, style) + self.bias
        if skip is None:
            return rgb
        return rgb + self.upsample(skip)
class Generator(nn.Module):
    """StyleGAN2-style synthesis network with a mapping MLP.

    The per-layer noise buffers are created with a width of half their
    height (`2 ** res` by `2 ** res // 2`).

    Args:
        size: Output resolution key; `log2(size)` sets the number of layers.
        style_dim: Size of the latent/style vectors.
        n_mlp: Number of `EqualLinear` layers in the mapping network.
        channel_multiplier: Scales the channel counts of the larger resolutions.
        blur_kernel: Blur filter taps forwarded to the styled convolutions.
        lr_mlp: `lr_mul` of the mapping layers.
        small: Use 64-channel layers everywhere (only for `size <= 64`).
        small_isaac: Use the reduced 256/128-channel table.

    Raises:
        ValueError: If `small` is set together with `size > 64`.
    """

    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=1,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
        small=False,
        small_isaac=False,
    ):
        super().__init__()

        self.size = size

        if small and size > 64:
            raise ValueError("small only works for sizes <= 64")

        self.style_dim = style_dim

        # Mapping network: pixel norm followed by n_mlp activated linear layers.
        layers = [PixelNorm()]
        for _ in range(n_mlp):
            layers.append(
                EqualLinear(
                    style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
                )
            )
        self.style = nn.Sequential(*layers)

        # Channel count per resolution.
        if small:
            self.channels = {
                4: 64 * channel_multiplier,
                8: 64 * channel_multiplier,
                16: 64 * channel_multiplier,
                32: 64 * channel_multiplier,
                64: 64 * channel_multiplier,
            }
        elif small_isaac:
            self.channels = {4: 256, 8: 256, 16: 256, 32: 256, 64: 128, 128: 128}
        else:
            self.channels = {
                4: 512,
                8: 512,
                16: 512,
                32: 512,
                64: 256 * channel_multiplier,
                128: 128 * channel_multiplier,
                256: 64 * channel_multiplier,
                512: 32 * channel_multiplier,
                1024: 16 * channel_multiplier,
            }

        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(
            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
        )
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        self.log_size = int(math.log(size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channel = self.channels[4]

        # Fixed noise buffers, used when forward() runs with randomize_noise=False.
        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2 ** res, 2 ** res // 2]
            self.noises.register_buffer(
                "noise_{}".format(layer_idx), torch.randn(*shape)
            )

        # Two styled convolutions (the first one upsampling) and one ToRGB per resolution.
        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** i]

            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                )
            )
            self.convs.append(
                StyledConv(
                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
                )
            )
            self.to_rgbs.append(ToRGB(out_channel, style_dim))

            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2

    def make_noise(self):
        """Return a fresh list of random noise tensors, one per noise input."""
        device = self.input.input.device

        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2 // 2, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i // 2, device=device))

        return noises

    def mean_latent(self, n_latent):
        """Return the mean mapped latent over `n_latent` random inputs, shape (1, style_dim)."""
        latent_in = torch.randn(
            n_latent, self.style_dim, device=self.input.input.device
        )
        latent = self.style(latent_in).mean(0, keepdim=True)

        return latent

    def get_latent(self, input):
        """Run `input` through the mapping network."""
        return self.style(input)

    def forward(
        self,
        styles,
        return_latents=False,
        return_features=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
        real=False,
    ):
        """Synthesize an image from style codes.

        Args:
            styles: With `real=False`, a list of one or two style tensors.
                With `real=True`, a single latent tensor with one row per
                layer (the original comments give (1, 18, 512) for PTI output).
            return_latents: Return the per-layer latent as second element.
            return_features: Return the dict of intermediate features as
                second element (only checked if `return_latents` is False).
            inject_index: Layer index at which the second style takes over
                when two styles are given; defaults to 4 in that case.
            truncation: Values below 1 interpolate styles towards `truncation_latent`.
            truncation_latent: Reference latent used for truncation.
            input_is_latent: Skip the mapping network.
            noise: Optional list of per-layer noise tensors.
            randomize_noise: With `noise=None`, sample fresh noise per layer
                instead of using the stored buffers.
            real: Treat `styles` as one per-layer latent tensor, not a list.

        Returns:
            `(image, latent)`, `(image, features)` or `(image, None)`,
            depending on `return_latents` / `return_features`.
        """
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, "noise_{}".format(i))
                    for i in range(self.num_layers)
                ]

        if truncation < 1:
            if not real:
                styles = [
                    truncation_latent + truncation * (style - truncation_latent)
                    for style in styles
                ]
            else:
                # Expand the reference latent to one row per layer of THIS
                # generator: (1, style_dim) --> (1, n_latent, style_dim).
                # n_latent is 18 only for size 1024, so a fixed 18 would not
                # broadcast against the latents of smaller generators.
                truncation_latent = truncation_latent.repeat(self.n_latent, 1).unsqueeze(0)
                styles = torch.add(
                    truncation_latent,
                    torch.mul(torch.sub(styles, truncation_latent), truncation),
                )

        if not real:
            if len(styles) < 2:
                # Single style: use it for every layer.
                inject_index = self.n_latent

                if styles[0].ndim < 3:
                    latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
                else:
                    latent = styles[0]
            elif isinstance(styles, list):
                # Style mixing: first style up to inject_index, second one after.
                if inject_index is None:
                    inject_index = 4

                latent = styles[0].unsqueeze(0)
                if latent.shape[1] == 1:
                    latent = latent.repeat(1, inject_index, 1)
                else:
                    latent = latent[:, :inject_index, :]
                latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
                latent = torch.cat([latent, latent2], 1)
        else:
            # Input already is a per-layer latent tensor (real PTI output).
            latent = styles

        features = {}
        out = self.input(latent)
        features["out_0"] = out
        out = self.conv1(out, latent[:, 0], noise=noise[0])
        features["conv1_0"] = out

        skip = self.to_rgb1(out, latent[:, 1])
        features["skip_0"] = skip

        # Each resolution consumes two conv latents and shares the next one with ToRGB.
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            features["conv1_{}".format(i)] = out
            out = conv2(out, latent[:, i + 1], noise=noise2)
            features["conv2_{}".format(i)] = out
            skip = to_rgb(out, latent[:, i + 2], skip)
            features["skip_{}".format(i)] = skip

            i += 2

        image = skip

        if return_latents:
            return image, latent
        elif return_features:
            return image, features
        else:
            return image, None
class ConvLayer(nn.Sequential):
    """Sequential block: optional blur, an `EqualConv2d`, optional activation.

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output channels.
        kernel_size: Side length of the square kernel.
        downsample: Prepend a blur and use stride 2 with zero padding.
        blur_kernel: Blur filter taps (only used when downsampling).
        bias: With activation, selects `FusedLeakyReLU` over
            `ScaledLeakyReLU`; without activation, enables the conv bias.
        activate: Whether to append an activation layer.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            layers.append(Blur(blur_kernel, pad=((p + 1) // 2, p // 2)))

        stride = 2 if downsample else 1
        # Stored before nn.Sequential is initialized, as a plain attribute.
        self.padding = 0 if downsample else kernel_size // 2

        # The conv carries the bias only when no activation layer follows.
        conv = EqualConv2d(
            in_channel,
            out_channel,
            kernel_size,
            padding=self.padding,
            stride=stride,
            bias=bias and not activate,
        )
        layers.append(conv)

        if activate:
            activation = FusedLeakyReLU(out_channel) if bias else ScaledLeakyReLU(0.2)
            layers.append(activation)

        super().__init__(*layers)
class ResBlock(nn.Module):
    """Residual block: two 3x3 conv layers (the second downsampling) plus a 1x1 downsampling skip.

    The sum of both paths is divided by sqrt(2).

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output channels.
        blur_kernel: Blur filter taps for the downsampling layers. The
            argument is forwarded to every `ConvLayer`; with the default
            value the block is built exactly as before.
    """

    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3, blur_kernel=blur_kernel)
        self.conv2 = ConvLayer(
            in_channel, out_channel, 3, downsample=True, blur_kernel=blur_kernel
        )

        self.skip = ConvLayer(
            in_channel,
            out_channel,
            1,
            downsample=True,
            blur_kernel=blur_kernel,
            activate=False,
            bias=False,
        )

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        skip = self.skip(input)
        out = (out + skip) / math.sqrt(2)

        return out
class StyleDiscriminator(nn.Module):
    """Residual discriminator with a minibatch-stddev feature before the final layers.

    Args:
        size: Input resolution key into the channel table.
        channel_multiplier: Scales the channel counts of the larger resolutions.
        blur_kernel: Passed to every `ResBlock`.
        small: Use 64 channels at every resolution.
    """

    def __init__(
        self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], small=False
    ):
        super().__init__()

        if small:
            channels = {4: 64, 8: 64, 16: 64, 32: 64, 64: 64}
        else:
            channels = {
                4: 512,
                8: 512,
                16: 512,
                32: 512,
                64: 256 * channel_multiplier,
                128: 128 * channel_multiplier,
                256: 64 * channel_multiplier,
                512: 32 * channel_multiplier,
                1024: 16 * channel_multiplier,
            }

        # 1x1 "from RGB" layer, then one ResBlock per halving of resolution.
        blocks = [ConvLayer(3, channels[size], 1)]
        log_size = int(math.log(size, 2))

        in_channel = channels[size]
        for level in range(log_size, 2, -1):
            out_channel = channels[2 ** (level - 1)]
            blocks.append(ResBlock(in_channel, out_channel, blur_kernel))
            in_channel = out_channel

        self.convs = nn.Sequential(*blocks)

        self.stddev_group = 4
        self.stddev_feat = 1

        # One extra input channel for the stddev feature map.
        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
            EqualLinear(channels[4], 1),
        )

    def forward(self, input):
        """Return `(score, features)`; `features` holds every block output plus the final conv output."""
        h = input
        h_list = []

        for block in self.convs:
            h = block(h)
            h_list.append(h)

        out = h
        batch, channel, height, width = out.shape

        # Standard deviation over groups of samples, averaged into a single
        # feature map that is appended as an extra channel.
        group = min(batch, self.stddev_group)
        stddev = out.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)
        out = torch.cat([out, stddev], 1)

        out = self.final_conv(out)
        h_list.append(out)

        flat = out.view(batch, -1)
        return self.final_linear(flat), h_list
class StyleEncoder(nn.Module):
    """Convolutional encoder producing two ``w_dim``-sized vectors per image.

    The network is a 1x1 ``ConvLayer``, one ``ResBlock`` per halving down to
    4x4, and a final 4x4 ``EqualConv2d`` with ``2 * w_dim`` output channels.
    """

    def __init__(self, size, w_dim=512):
        super().__init__()

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256,
            128: 128,
            256: 64,
            512: 32,
            1024: 16,
        }

        self.w_dim = w_dim
        log_size = int(math.log(size, 2))

        stages = [ConvLayer(3, channels[size], 1)]
        width = channels[size]
        for level in range(log_size, 2, -1):
            next_width = channels[2 ** (level - 1)]
            stages.append(ResBlock(width, next_width))
            width = next_width

        stages.append(EqualConv2d(width, 2 * self.w_dim, 4, padding=0, bias=False))
        self.convs = nn.Sequential(*stages)

    def forward(self, input):
        """Return the first and second ``w_dim`` halves of the encoder output."""
        flat = self.convs(input).view(len(input), 2 * self.w_dim)
        first_half = flat[:, : self.w_dim]
        second_half = flat[:, self.w_dim :]
        return first_half, second_half
def kaiming_init(m):
    """Initialise module ``m`` in place; modules of other types are untouched.

    ``nn.Linear`` / ``nn.Conv2d`` weights get Kaiming-normal values,
    ``nn.BatchNorm1d`` / ``nn.BatchNorm2d`` weights are set to 1, and in both
    cases an existing bias is zeroed.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.kaiming_normal_(m.weight)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        m.weight.data.fill_(1)
    else:
        return

    if m.bias is not None:
        m.bias.data.fill_(0)
def normal_init(m):
    """Initialise module ``m`` in place; modules of other types are untouched.

    ``nn.Linear`` / ``nn.Conv2d`` weights are drawn from N(0, 0.02),
    ``nn.BatchNorm1d`` / ``nn.BatchNorm2d`` weights are set to 1, and in both
    cases an existing bias is zeroed.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.normal_(m.weight, 0, 0.02)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        m.weight.data.fill_(1)
    else:
        return

    if m.bias is not None:
        m.bias.data.fill_(0)
================================================
FILE: stylegan_human/torch_utils/models_face.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
import math
import random
import functools
import operator
import torch
from torch import nn
from torch.nn import functional as F
import torch.nn.init as init
from torch.autograd import Function
from .op_edit import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
class PixelNorm(nn.Module):
    """Scale the input so its mean square along dim 1 is (approximately) one."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Return ``input * rsqrt(mean(input**2, dim=1) + 1e-8)``."""
        mean_square = torch.mean(input ** 2, dim=1, keepdim=True)
        return input * torch.rsqrt(mean_square + 1e-8)
def make_kernel(k):
    """Build a float32 2-D FIR kernel from ``k``, normalised to sum to one.

    A 1-D ``k`` is expanded to its separable 2-D outer product first.
    """
    kernel = torch.tensor(k, dtype=torch.float32)

    if kernel.ndim == 1:
        kernel = kernel.unsqueeze(0) * kernel.unsqueeze(1)

    kernel /= kernel.sum()
    return kernel
class Upsample(nn.Module):
    """Upsample by ``factor`` with a FIR kernel via ``upfirdn2d``.

    The normalised kernel is multiplied by ``factor ** 2`` and stored as the
    ``kernel`` buffer; the padding pair is derived from the kernel size.
    """

    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        fir = make_kernel(kernel) * (factor ** 2)
        self.register_buffer("kernel", fir)

        excess = fir.shape[0] - factor
        self.pad = ((excess + 1) // 2 + factor - 1, excess // 2)

    def forward(self, input):
        """Return ``input`` upsampled by ``self.factor``."""
        return upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
class Downsample(nn.Module):
    """Downsample by ``factor`` with a FIR kernel via ``upfirdn2d``.

    The normalised kernel is stored as the ``kernel`` buffer; the padding pair
    is derived from the kernel size.
    """

    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        fir = make_kernel(kernel)
        self.register_buffer("kernel", fir)

        excess = fir.shape[0] - factor
        self.pad = ((excess + 1) // 2, excess // 2)

    def forward(self, input):
        """Return ``input`` downsampled by ``self.factor``."""
        return upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
class Blur(nn.Module):
    """Filter the input with a FIR kernel using the given padding pair.

    When ``upsample_factor > 1`` the normalised kernel is multiplied by
    ``upsample_factor ** 2`` before being stored as the ``kernel`` buffer.
    """

    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        fir = make_kernel(kernel)
        if upsample_factor > 1:
            fir = fir * (upsample_factor ** 2)

        self.register_buffer("kernel", fir)
        self.pad = pad

    def forward(self, input):
        """Return ``upfirdn2d(input, kernel, pad=pad)`` (no resampling)."""
        return upfirdn2d(input, self.kernel, pad=self.pad)
class EqualConv2d(nn.Module):
    """2-D convolution whose weight is rescaled at run time.

    The weight is stored as unit-variance noise and multiplied by
    ``1 / sqrt(in_channel * kernel_size ** 2)`` on every forward pass.
    """

    def __init__(
        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
    ):
        super().__init__()

        weight_shape = (out_channel, in_channel, kernel_size, kernel_size)
        self.weight = nn.Parameter(torch.randn(*weight_shape))
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

        self.stride = stride
        self.padding = padding

        self.bias = nn.Parameter(torch.zeros(out_channel)) if bias else None

    def forward(self, input):
        """Convolve ``input`` with the scaled weight and the optional bias."""
        scaled_weight = self.weight * self.scale
        return F.conv2d(
            input,
            scaled_weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
            f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
        )
class EqualLinear(nn.Module):
    """Linear layer with run-time weight scaling and a learning-rate multiplier.

    Args:
        in_dim: number of input features.
        out_dim: number of output features.
        bias: whether to create a learnable bias.
        bias_init: initial value of every bias entry.
        lr_mul: the stored weight is divided by this at init, and both weight
            scale and bias are multiplied by it in ``forward``.
        activation: any truthy value routes the output through
            ``fused_leaky_relu``; falsy means a plain linear layer.
    """

    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        """Apply the scaled linear map, then the fused activation if enabled."""
        weight = self.weight * self.scale
        # A layer built with bias=False has self.bias = None; multiplying that
        # by lr_mul would raise TypeError, so only scale a bias that exists.
        bias = None if self.bias is None else self.bias * self.lr_mul

        if self.activation:
            out = F.linear(input, weight)
            if bias is None:
                # fused_leaky_relu always expects a bias tensor; a zero bias
                # leaves the pre-activation unchanged.
                bias = out.new_zeros(self.weight.shape[0])
            return fused_leaky_relu(out, bias)

        return F.linear(input, weight, bias=bias)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
        )
class ScaledLeakyReLU(nn.Module):
    """Leaky ReLU whose output is multiplied by ``sqrt(2)``."""

    def __init__(self, negative_slope=0.2):
        super().__init__()

        self.negative_slope = negative_slope

    def forward(self, input):
        """Return ``leaky_relu(input, negative_slope) * sqrt(2)``."""
        activated = F.leaky_relu(input, negative_slope=self.negative_slope)
        return activated * math.sqrt(2)
class ModulatedConv2d(nn.Module):
    """Convolution whose weight is modulated per sample by a style vector.

    The style is mapped to one scale per input channel (``self.modulation``),
    multiplied into the shared weight, optionally demodulated, and the
    per-sample weights are applied with a single grouped convolution
    (``groups=batch``).

    Args:
        in_channel: input feature channels.
        out_channel: output feature channels.
        kernel_size: square kernel size.
        style_dim: size of the style vector fed to ``self.modulation``.
        demodulate: normalise each per-sample output filter to unit norm.
        upsample: use a stride-2 transposed convolution followed by a blur.
        downsample: blur first, then use a stride-2 convolution.
        blur_kernel: FIR kernel for the blur used by the two resampling modes.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()

        # NOTE(review): stored but not read below; forward() uses a literal 1e-8.
        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)

        # If both flags are set, this assignment replaces the upsample blur.
        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        # Leading dim of 1 broadcasts against the per-sample style in forward().
        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )

        # bias_init=1 so the initial modulation is centred on one.
        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

        self.demodulate = demodulate

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
            f"upsample={self.upsample}, downsample={self.downsample})"
        )

    def forward(self, input, style):
        """Convolve ``input`` (batch, in_channel, H, W) with style-modulated weights.

        ``style`` must yield ``batch * in_channel`` values after
        ``self.modulation``; the result has ``self.out_channel`` channels.
        """
        batch, in_channel, height, width = input.shape

        # Per-sample, per-input-channel scale broadcast over the shared weight.
        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
        weight = self.scale * self.weight * style

        if self.demodulate:
            # Inverse L2 norm of every (sample, output-channel) filter.
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

        # Fold the batch into the output-channel axis for the grouped conv.
        weight = weight.view(
            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
        )

        if self.upsample:
            # Fold the batch into channels so groups=batch applies one weight set per sample.
            input = input.view(1, batch * in_channel, height, width)
            # conv_transpose2d expects (in, out/groups, kH, kW): swap the channel axes.
            weight = weight.view(
                batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
            )
            weight = weight.transpose(1, 2).reshape(
                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
            )
            out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
            out = self.blur(out)

        elif self.downsample:
            input = self.blur(input)
            _, _, height, width = input.shape
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        else:
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        return out
class NoiseInjection(nn.Module):
    """Add noise to a feature map, scaled by a single learnable weight."""

    def __init__(self):
        super().__init__()

        # Starts at zero, so the layer is initially an identity.
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, image, noise=None):
        """Return ``image + weight * noise``.

        When ``noise`` is ``None`` a fresh standard-normal map of shape
        ``(batch, 1, height, width)`` is sampled.
        """
        if noise is None:
            batch, _, height, width = image.shape
            noise = image.new_empty(batch, 1, height, width).normal_()

        scaled_noise = self.weight * noise
        return image + scaled_noise
class ConstantInput(nn.Module):
    """Learned constant tensor of shape ``(1, channel, size, size)``."""

    def __init__(self, channel, size=4):
        super().__init__()

        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, input):
        """Return the constant repeated once per row of ``input``.

        Only ``input.shape[0]`` is read from the argument.
        """
        batch_size = input.shape[0]
        return self.input.repeat(batch_size, 1, 1, 1)
class StyledConv(nn.Module):
    """Modulated convolution, then noise injection, then fused leaky ReLU."""

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        upsample=False,
        blur_kernel=[1, 3, 3, 1],
        demodulate=True,
    ):
        super().__init__()

        conv_options = dict(
            upsample=upsample,
            blur_kernel=blur_kernel,
            demodulate=demodulate,
        )
        self.conv = ModulatedConv2d(
            in_channel, out_channel, kernel_size, style_dim, **conv_options
        )
        self.noise = NoiseInjection()
        # The activation owns the per-channel bias.
        self.activate = FusedLeakyReLU(out_channel)

    def forward(self, input, style, noise=None):
        """Return ``activate(noise(conv(input, style), noise))``."""
        modulated = self.conv(input, style)
        noisy = self.noise(modulated, noise=noise)
        return self.activate(noisy)
class ToRGB(nn.Module):
    """Project features to 3 channels and optionally add an upsampled skip image.

    ``self.upsample`` only exists when the layer is built with
    ``upsample=True``; passing ``skip`` to a layer built without it fails.
    """

    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        if upsample:
            self.upsample = Upsample(blur_kernel)

        # 1x1 modulated conv without demodulation.
        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, input, style, skip=None):
        """Return ``conv(input, style) + bias`` plus ``upsample(skip)`` if given."""
        rgb = self.conv(input, style)
        rgb = rgb + self.bias

        if skip is None:
            return rgb

        return rgb + self.upsample(skip)
class Generator(nn.Module):
    """Style-based image generator with a skip-connection RGB output.

    A mapping network (``self.style``) turns input vectors into style vectors;
    a learned constant is then refined by pairs of ``StyledConv`` layers, one
    pair per resolution doubling, while ``ToRGB`` layers accumulate the image.

    Args:
        size: output resolution; ``int(log2(size))`` sets the number of stages.
        style_dim: dimensionality of the style vectors.
        n_mlp: number of ``EqualLinear`` layers in the mapping network.
        channel_multiplier: scales channel counts (see the tables below).
        blur_kernel: FIR kernel forwarded to the styled convolutions.
        lr_mlp: learning-rate multiplier of the mapping network layers.
        small: use the 64-channel table (only valid for ``size <= 64``).
        small_isaac: use the alternative 256/128-channel table.
    """

    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=1,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
        small=False,
        small_isaac=False,
    ):
        super().__init__()

        self.size = size

        if small and size > 64:
            raise ValueError("small only works for sizes <= 64")

        self.style_dim = style_dim

        # Mapping network: PixelNorm followed by n_mlp activated linear layers.
        layers = [PixelNorm()]

        for i in range(n_mlp):
            layers.append(
                EqualLinear(
                    style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
                )
            )

        self.style = nn.Sequential(*layers)

        # Channel count per resolution.
        if small:
            self.channels = {
                4: 64 * channel_multiplier,
                8: 64 * channel_multiplier,
                16: 64 * channel_multiplier,
                32: 64 * channel_multiplier,
                64: 64 * channel_multiplier,
            }
        elif small_isaac:
            self.channels = {4: 256, 8: 256, 16: 256, 32: 256, 64: 128, 128: 128}
        else:
            self.channels = {
                4: 512,
                8: 512,
                16: 512,
                32: 512,
                64: 256 * channel_multiplier,
                128: 128 * channel_multiplier,
                256: 64 * channel_multiplier,
                512: 32 * channel_multiplier,
                1024: 16 * channel_multiplier,
            }

        # 4x4 stage: learned constant, one styled conv, first RGB projection.
        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(
            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
        )
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        self.log_size = int(math.log(size, 2))
        # One noise input for conv1 plus two per higher-resolution stage.
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.ModuleList()
        # NOTE(review): `upsamples` is created but never populated in this class.
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        # Plain Module used as a container for the fixed-noise buffers.
        self.noises = nn.Module()

        in_channel = self.channels[4]

        # Fixed noise maps: layer 0 is 4x4, then two layers per resolution.
        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2 ** res, 2 ** res]
            self.noises.register_buffer(
                "noise_{}".format(layer_idx), torch.randn(*shape)
            )

        # For every resolution 8..size: an upsampling conv, a plain conv, a ToRGB.
        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** i]

            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                )
            )

            self.convs.append(
                StyledConv(
                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
                )
            )

            self.to_rgbs.append(ToRGB(out_channel, style_dim))

            in_channel = out_channel

        # Number of style vectors consumed by forward().
        self.n_latent = self.log_size * 2 - 2

    def make_noise(self):
        """Return a fresh list of noise maps, one per noise input, on the model's device."""
        device = self.input.input.device

        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))

        return noises

    def mean_latent(self, n_latent):
        """Return the mean mapped style over ``n_latent`` random inputs, shape (1, style_dim)."""
        latent_in = torch.randn(
            n_latent, self.style_dim, device=self.input.input.device
        )
        latent = self.style(latent_in).mean(0, keepdim=True)

        return latent

    def get_latent(self, input):
        """Run ``input`` through the mapping network."""
        return self.style(input)

    def forward(
        self,
        styles,
        return_latents=False,
        return_features=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        """Generate an image from one or two styles.

        Args:
            styles: sequence of style inputs (indexed and iterated below).
            return_latents: return the per-layer latent as the second value.
            return_features: return the dict of intermediate activations as
                the second value (only consulted if ``return_latents`` is False).
            inject_index: layer index where the second style takes over;
                defaults to 4 when two styles are given.
            truncation: values below 1 pull styles towards ``truncation_latent``.
            truncation_latent: reference latent for truncation; required when
                ``truncation < 1``.
            input_is_latent: skip the mapping network.
            noise: explicit list of per-layer noise maps.
            randomize_noise: when ``noise`` is None, sample fresh noise in each
                layer instead of using the stored buffers.

        Returns:
            ``(image, latent)``, ``(image, features)`` or ``(image, None)``.
        """
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                # None makes each NoiseInjection sample its own noise.
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, "noise_{}".format(i))
                    for i in range(self.num_layers)
                ]

        if truncation < 1:
            style_t = []

            for style in styles:
                style_t.append(
                    truncation_latent + truncation * (style - truncation_latent)
                )

            styles = style_t

        if len(styles) < 2:
            inject_index = self.n_latent

            if styles[0].ndim < 3:
                # One vector per sample: repeat it for every layer.
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:
                # Already has a per-layer axis; used as-is.
                latent = styles[0]
        else:
            if inject_index is None:
                inject_index = 4

            # A leading axis is added in front of the first style; if the next
            # axis has size 1 it is repeated, otherwise truncated, to
            # inject_index entries.
            latent = styles[0].unsqueeze(0)
            if latent.shape[1] == 1:
                latent = latent.repeat(1, inject_index, 1)
            else:
                latent = latent[:, :inject_index, :]
            # The second style fills the remaining layers.
            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        # Intermediate activations keyed by layer name.
        features = {}
        out = self.input(latent)
        features["out_0"] = out

        out = self.conv1(out, latent[:, 0], noise=noise[0])
        features["conv1_0"] = out

        skip = self.to_rgb1(out, latent[:, 1])
        features["skip_0"] = skip

        # Each stage consumes latent[i], latent[i + 1] and (for ToRGB) latent[i + 2].
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            features["conv1_{}".format(i)] = out
            out = conv2(out, latent[:, i + 1], noise=noise2)
            features["conv2_{}".format(i)] = out
            skip = to_rgb(out, latent[:, i + 2], skip)
            features["skip_{}".format(i)] = skip

            i += 2

        image = skip

        if return_latents:
            return image, latent
        elif return_features:
            return image, features
        else:
            return image, None
class ConvLayer(nn.Sequential):
    """Sequential stack: optional blur, an equalized conv, optional activation.

    With ``downsample`` the input is first blurred with ``blur_kernel`` and the
    convolution uses stride 2 without padding; otherwise the convolution uses
    stride 1 and ``kernel_size // 2`` padding.  When ``activate`` is set the
    convolution itself carries no bias: the bias (if requested) lives in the
    ``FusedLeakyReLU`` that follows, else a ``ScaledLeakyReLU(0.2)`` is used.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        stages = []

        if not downsample:
            conv_stride = 1
            self.padding = kernel_size // 2
        else:
            # Total blur padding for a factor-2 downsample, split front/back.
            total_pad = (len(blur_kernel) - 2) + (kernel_size - 1)
            stages.append(
                Blur(blur_kernel, pad=((total_pad + 1) // 2, total_pad // 2))
            )
            conv_stride = 2
            self.padding = 0

        stages.append(
            EqualConv2d(
                in_channel,
                out_channel,
                kernel_size,
                padding=self.padding,
                stride=conv_stride,
                bias=bias and not activate,
            )
        )

        if activate:
            stages.append(
                FusedLeakyReLU(out_channel) if bias else ScaledLeakyReLU(0.2)
            )

        super().__init__(*stages)
class ResBlock(nn.Module):
    """Residual block that halves the spatial resolution.

    The main path is a 3x3 ``ConvLayer`` followed by a downsampling 3x3
    ``ConvLayer``; the shortcut is a downsampling 1x1 ``ConvLayer`` without
    bias or activation.  Both paths are summed and divided by ``sqrt(2)``.
    """

    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        # NOTE(review): ``blur_kernel`` is accepted but not forwarded, so the
        # inner ConvLayers always use their own default kernel.
        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
        self.skip = ConvLayer(
            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
        )

    def forward(self, input):
        """Return ``(conv2(conv1(input)) + skip(input)) / sqrt(2)``."""
        main = self.conv2(self.conv1(input))
        shortcut = self.skip(input)
        return (main + shortcut) / math.sqrt(2)
class StyleDiscriminator(nn.Module):
    """Residual discriminator that also returns its intermediate activations.

    Args:
        size: input image resolution; must be a key of the channel table.
        channel_multiplier: scales the channel counts at resolutions >= 64
            (ignored when ``small`` is set).
        blur_kernel: FIR kernel handed to every ``ResBlock``.
        small: use a flat 64-channel table covering resolutions 4..64.
    """

    def __init__(
        self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], small=False
    ):
        super().__init__()

        if small:
            channels = dict.fromkeys((4, 8, 16, 32, 64), 64)
        else:
            channels = dict.fromkeys((4, 8, 16, 32), 512)
            channels.update(
                {
                    64: 256 * channel_multiplier,
                    128: 128 * channel_multiplier,
                    256: 64 * channel_multiplier,
                    512: 32 * channel_multiplier,
                    1024: 16 * channel_multiplier,
                }
            )

        # 1x1 "from RGB" layer followed by one ResBlock per halving down to 4x4.
        blocks = [ConvLayer(3, channels[size], 1)]
        log_size = int(math.log(size, 2))
        width = channels[size]
        for level in range(log_size, 2, -1):
            next_width = channels[2 ** (level - 1)]
            blocks.append(ResBlock(width, next_width, blur_kernel))
            width = next_width
        self.convs = nn.Sequential(*blocks)

        # Minibatch-stddev settings used in forward().
        self.stddev_group = 4
        self.stddev_feat = 1

        # +1 input channel for the appended stddev feature map.
        self.final_conv = ConvLayer(width + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
            EqualLinear(channels[4], 1),
        )

    def forward(self, input):
        """Return ``(score, features)``.

        ``features`` holds the output of every block in ``self.convs`` followed
        by the output of ``final_conv``.
        """
        feats = []
        h = input
        for block in self.convs:
            h = block(h)
            feats.append(h)

        batch, channel, height, width = h.shape
        group = min(batch, self.stddev_group)
        # Standard deviation over each group of samples, averaged to one scalar
        # per group member and broadcast as an extra feature map.
        stddev = h.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)

        h = self.final_conv(torch.cat([h, stddev], 1))
        feats.append(h)

        score = self.final_linear(h.view(batch, -1))
        return score, feats
class StyleEncoder(nn.Module):
    """Convolutional encoder producing two ``w_dim``-sized vectors per image.

    The network is a 1x1 ``ConvLayer``, one ``ResBlock`` per halving down to
    4x4, and a final 4x4 ``EqualConv2d`` with ``2 * w_dim`` output channels.
    """

    def __init__(self, size, w_dim=512):
        super().__init__()

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256,
            128: 128,
            256: 64,
            512: 32,
            1024: 16,
        }

        self.w_dim = w_dim
        log_size = int(math.log(size, 2))

        stages = [ConvLayer(3, channels[size], 1)]
        width = channels[size]
        for level in range(log_size, 2, -1):
            next_width = channels[2 ** (level - 1)]
            stages.append(ResBlock(width, next_width))
            width = next_width

        stages.append(EqualConv2d(width, 2 * self.w_dim, 4, padding=0, bias=False))
        self.convs = nn.Sequential(*stages)

    def forward(self, input):
        """Return the first and second ``w_dim`` halves of the encoder output."""
        flat = self.convs(input).view(len(input), 2 * self.w_dim)
        first_half = flat[:, : self.w_dim]
        second_half = flat[:, self.w_dim :]
        return first_half, second_half
def kaiming_init(m):
    """Initialise module ``m`` in place; modules of other types are untouched.

    ``nn.Linear`` / ``nn.Conv2d`` weights get Kaiming-normal values,
    ``nn.BatchNorm1d`` / ``nn.BatchNorm2d`` weights are set to 1, and in both
    cases an existing bias is zeroed.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.kaiming_normal_(m.weight)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        m.weight.data.fill_(1)
    else:
        return

    if m.bias is not None:
        m.bias.data.fill_(0)
def normal_init(m):
    """Initialise module ``m`` in place; modules of other types are untouched.

    ``nn.Linear`` / ``nn.Conv2d`` weights are drawn from N(0, 0.02),
    ``nn.BatchNorm1d`` / ``nn.BatchNorm2d`` weights are set to 1, and in both
    cases an existing bias is zeroed.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.normal_(m.weight, 0, 0.02)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        m.weight.data.fill_(1)
    else:
        return

    if m.bias is not None:
        m.bias.data.fill_(0)
================================================
FILE: stylegan_human/torch_utils/op_edit/__init__.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
from .fused_act import FusedLeakyReLU, fused_leaky_relu
from .upfirdn2d import upfirdn2d
================================================
FILE: stylegan_human/torch_utils/op_edit/fused_act.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
import os
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load
# Directory of this module; the extension sources are looked up next to it.
module_path = os.path.dirname(__file__)
# Build and load the "fused" C++/CUDA extension when this module is imported.
# NOTE(review): this runs unconditionally, so importing the module requires a
# working extension build even if only the CPU path of fused_leaky_relu is used.
fused = load(
    "fused",
    sources=[
        os.path.join(module_path, "fused_bias_act.cpp"),
        os.path.join(module_path, "fused_bias_act_kernel.cu"),
    ],
)
class FusedLeakyReLUFunctionBackward(Function):
    """Backward pass of the fused bias + leaky ReLU, itself differentiable.

    It is written as an autograd ``Function`` so that second-order gradients
    can flow through the first-order gradient computation.
    """

    @staticmethod
    def forward(ctx, grad_output, out, negative_slope, scale):
        """Return ``(grad_input, grad_bias)`` for the fused activation.

        ``out`` is the forward output; it is passed to the op as the reference
        tensor and saved for the double-backward pass.
        """
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        # Empty tensor = "no bias" for the op (it checks numel()).
        empty = grad_output.new_empty(0)

        # act=3, grad=1: first-order gradient of the leaky ReLU.
        grad_input = fused.fused_bias_act(
            grad_output, empty, out, 3, 1, negative_slope, scale
        )

        # The bias gradient sums over every dimension except dim 1.
        dim = [0]

        if grad_input.ndim > 2:
            dim += list(range(2, grad_input.ndim))

        grad_bias = grad_input.sum(dim).detach()

        return grad_input, grad_bias

    @staticmethod
    def backward(ctx, gradgrad_input, gradgrad_bias):
        """Propagate second-order gradients; only ``grad_output`` receives one."""
        (out,) = ctx.saved_tensors
        gradgrad_out = fused.fused_bias_act(
            gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
        )

        # No gradients for out, negative_slope and scale.
        return gradgrad_out, None, None, None
class FusedLeakyReLUFunction(Function):
    """Autograd wrapper around the fused bias + leaky ReLU extension op."""

    @staticmethod
    def forward(ctx, input, bias, negative_slope, scale):
        """Return the op's output for ``input`` and ``bias`` and save it for backward."""
        # Empty tensor = "no reference" for the op (it checks numel()).
        empty = input.new_empty(0)
        # act=3, grad=0: forward leaky ReLU.
        out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``input`` and ``bias`` (none for the scalars)."""
        (out,) = ctx.saved_tensors

        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
            grad_output, out, ctx.negative_slope, ctx.scale
        )

        return grad_input, grad_bias, None, None
class FusedLeakyReLU(nn.Module):
    """Module holding a per-channel bias and applying ``fused_leaky_relu``."""

    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()

        self.negative_slope = negative_slope
        self.scale = scale
        self.bias = nn.Parameter(torch.zeros(channel))

    def forward(self, input):
        """Return ``fused_leaky_relu(input, bias, negative_slope, scale)``."""
        return fused_leaky_relu(
            input,
            self.bias,
            self.negative_slope,
            self.scale,
        )
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    """Add ``bias`` along dim 1, apply leaky ReLU, and multiply by ``scale``.

    CPU tensors use a pure-PyTorch implementation; every other device goes
    through the compiled extension via ``FusedLeakyReLUFunction``.

    Args:
        input: tensor whose dim 1 matches ``bias.shape[0]``.
        bias: 1-D bias tensor.
        negative_slope: slope of the leaky ReLU for negative values.
        scale: multiplier applied to the activated output.
    """
    if input.device.type == "cpu":
        # Reshape the bias to (1, C, 1, ...) so it broadcasts over dim 1.
        rest_dim = [1] * (input.ndim - bias.ndim - 1)
        biased = input + bias.view(1, bias.shape[0], *rest_dim)
        # Honour the caller's slope; it used to be hard-coded to 0.2 here,
        # which made CPU results diverge from the extension path.
        return F.leaky_relu(biased, negative_slope=negative_slope) * scale

    return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
================================================
FILE: stylegan_human/torch_utils/op_edit/fused_bias_act.cpp
================================================
// Copyright (c) SenseTime Research. All rights reserved.
#include
// CUDA implementation, defined in fused_bias_act_kernel.cu.
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
    int act, int grad, float alpha, float scale);

// Argument checks. Tensor::is_cuda() replaces the deprecated
// Tensor::type().is_cuda() accessor.
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// Validates that every tensor lives on a CUDA device and forwards to the CUDA
// implementation. Contiguity is not required here: the implementation calls
// .contiguous() on all three tensors itself.
torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
    int act, int grad, float alpha, float scale) {
    CHECK_CUDA(input);
    CHECK_CUDA(bias);
    // `refer` is dereferenced by the kernel as well, so it needs the same check.
    CHECK_CUDA(refer);

    return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}
// Python binding: exposes fused_bias_act() under the name "fused_bias_act".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}
================================================
FILE: stylegan_human/torch_utils/op_edit/fused_bias_act_kernel.cu
================================================
// Copyright (c) SenseTime Research. All rights reserved.
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include
#include
#include
#include
#include
#include
#include
// Element-wise fused bias + activation kernel.
//
// Each thread handles up to `loop_x` elements spaced `blockDim.x` apart.
// `act * 10 + grad` selects the operation:
//   10 / 11 : linear, forward / first-order gradient (both y = x)
//   30      : leaky ReLU forward
//   31      : leaky ReLU first-order gradient (sign taken from `p_ref`)
//   12 / 32 : second-order gradient, always zero
// The result is multiplied by `scale` in every case.
template <typename scalar_t>
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
    int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

    scalar_t zero = 0.0;

    for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
        scalar_t x = p_x[xi];

        if (use_bias) {
            // Runs of `step_b` consecutive elements share one bias entry.
            x += p_b[(xi / step_b) % size_b];
        }

        scalar_t ref = use_ref ? p_ref[xi] : zero;

        scalar_t y;

        switch (act * 10 + grad) {
            default:
            case 10: y = x; break;
            case 11: y = x; break;
            case 12: y = 0.0; break;

            case 30: y = (x > 0.0) ? x : x * alpha; break;
            case 31: y = (ref > 0.0) ? x : x * alpha; break;
            case 32: y = 0.0; break;
        }

        out[xi] = y * scale;
    }
}
// Host-side launcher for fused_bias_act_kernel.
//
// An empty `bias` or `refer` tensor (numel() == 0) disables the corresponding
// kernel input. Returns a new tensor with the shape and dtype of `input`.
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
    int act, int grad, float alpha, float scale) {
    int curDevice = -1;
    cudaGetDevice(&curDevice);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

    auto x = input.contiguous();
    auto b = bias.contiguous();
    auto ref = refer.contiguous();

    int use_bias = b.numel() ? 1 : 0;
    int use_ref = ref.numel() ? 1 : 0;

    int size_x = x.numel();
    int size_b = b.numel();
    // Elements per bias entry: product of all dimensions after dim 1.
    int step_b = 1;

    for (int i = 1 + 1; i < x.dim(); i++) {
        step_b *= x.size(i);
    }

    // Each block of 128 threads covers loop_x * block_size elements.
    int loop_x = 4;
    int block_size = 4 * 32;
    int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

    auto y = torch::empty_like(x);

    // Instantiate the kernel for the tensor's scalar type and launch it on the
    // current stream with no dynamic shared memory.
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
            y.data_ptr<scalar_t>(),
            x.data_ptr<scalar_t>(),
            b.data_ptr<scalar_t>(),
            ref.data_ptr<scalar_t>(),
            act,
            grad,
            alpha,
            scale,
            loop_x,
            size_x,
            step_b,
            size_b,
            use_bias,
            use_ref
        );
    });

    return y;
}
================================================
FILE: stylegan_human/torch_utils/op_edit/upfirdn2d.cpp
================================================
// Copyright (c) SenseTime Research. All rights reserved.
#include
// CUDA implementation, declared here and defined outside this file.
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
                            int up_x, int up_y, int down_x, int down_y,
                            int pad_x0, int pad_x1, int pad_y0, int pad_y1);

// Argument checks. Tensor::is_cuda() replaces the deprecated
// Tensor::type().is_cuda() accessor.
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// Validates that both tensors live on a CUDA device and forwards all
// resampling parameters unchanged to the CUDA implementation.
torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
                        int up_x, int up_y, int down_x, int down_y,
                        int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
    CHECK_CUDA(input);
    CHECK_CUDA(kernel);

    return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
}
// Python binding: exposes upfirdn2d() under the name "upfirdn2d".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
================================================
FILE: stylegan_human/torch_utils/op_edit/upfirdn2d.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
import os
import torch
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load
# Directory of this module; the extension sources are looked up next to it.
module_path = os.path.dirname(__file__)
# Build and load the "upfirdn2d" C++/CUDA extension when this module is imported.
# NOTE(review): this runs unconditionally, so importing the module requires a
# working extension build even if only upfirdn2d_native (CPU) is used.
upfirdn2d_op = load(
    "upfirdn2d",
    sources=[
        os.path.join(module_path, "upfirdn2d.cpp"),
        os.path.join(module_path, "upfirdn2d_kernel.cu"),
    ],
)
class UpFirDn2dBackward(Function):
    """Gradient of ``UpFirDn2d`` with respect to its input, itself differentiable.

    The gradient is computed by running the same extension op with the up and
    down factors swapped, the flipped kernel, and the gradient padding.
    """

    @staticmethod
    def forward(
        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
    ):
        """Return the gradient w.r.t. the forward input, shaped like ``in_size``.

        ``grad_kernel`` is the flipped kernel and ``g_pad`` the padding for the
        gradient pass; ``in_size`` is the forward input shape and ``out_size``
        the forward output ``(height, width)``.
        """

        up_x, up_y = up
        down_x, down_y = down
        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

        # The op works on (N, H, W, 1) tensors.
        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)

        # Note the swapped roles: upsample by `down`, downsample by `up`.
        grad_input = upfirdn2d_op.upfirdn2d(
            grad_output,
            grad_kernel,
            down_x,
            down_y,
            up_x,
            up_y,
            g_pad_x0,
            g_pad_x1,
            g_pad_y0,
            g_pad_y1,
        )
        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])

        # The original kernel and parameters are kept for the double-backward pass.
        ctx.save_for_backward(kernel)

        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        ctx.up_x = up_x
        ctx.up_y = up_y
        ctx.down_x = down_x
        ctx.down_y = down_y
        ctx.pad_x0 = pad_x0
        ctx.pad_x1 = pad_x1
        ctx.pad_y0 = pad_y0
        ctx.pad_y1 = pad_y1
        ctx.in_size = in_size
        ctx.out_size = out_size

        return grad_input

    @staticmethod
    def backward(ctx, gradgrad_input):
        """Second-order gradient: re-apply the forward op to ``gradgrad_input``."""
        (kernel,) = ctx.saved_tensors

        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)

        gradgrad_out = upfirdn2d_op.upfirdn2d(
            gradgrad_input,
            kernel,
            ctx.up_x,
            ctx.up_y,
            ctx.down_x,
            ctx.down_y,
            ctx.pad_x0,
            ctx.pad_x1,
            ctx.pad_y0,
            ctx.pad_y1,
        )
        # Back to (batch, channel, out_h, out_w).
        gradgrad_out = gradgrad_out.view(
            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
        )

        # Only grad_output receives a gradient; the other eight inputs do not.
        return gradgrad_out, None, None, None, None, None, None, None, None
class UpFirDn2d(Function):
    """Autograd wrapper around the ``upfirdn2d`` extension op (upsample, FIR filter, downsample)."""

    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        """Apply the op to a (batch, channel, H, W) ``input``.

        Args:
            input: 4-D input tensor.
            kernel: 2-D FIR kernel.
            up: ``(up_x, up_y)`` upsampling factors.
            down: ``(down_x, down_y)`` downsampling factors.
            pad: ``(pad_x0, pad_x1, pad_y0, pad_y1)`` padding.

        Returns:
            Tensor of shape ``(-1, channel, out_h, out_w)``.
        """
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        kernel_h, kernel_w = kernel.shape
        batch, channel, in_h, in_w = input.shape
        ctx.in_size = input.shape

        # The op works on (N, H, W, 1) tensors.
        input = input.reshape(-1, in_h, in_w, 1)

        # The flipped kernel is what the backward pass filters with.
        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
        ctx.out_size = (out_h, out_w)

        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)

        # Padding used by the gradient pass (see UpFirDn2dBackward).
        g_pad_x0 = kernel_w - pad_x0 - 1
        g_pad_y0 = kernel_h - pad_y0 - 1
        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

        out = upfirdn2d_op.upfirdn2d(
            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
        )
        out = out.view(-1, channel, out_h, out_w)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient for ``input``; no other argument receives one."""
        kernel, grad_kernel = ctx.saved_tensors

        grad_input = UpFirDn2dBackward.apply(
            grad_output,
            kernel,
            grad_kernel,
            ctx.up,
            ctx.down,
            ctx.pad,
            ctx.g_pad,
            ctx.in_size,
            ctx.out_size,
        )

        return grad_input, None, None, None, None
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample, FIR-filter and downsample ``input`` with equal x/y settings.

    CPU tensors use ``upfirdn2d_native``; all other devices use the compiled
    extension through ``UpFirDn2d``.  ``pad`` is a ``(before, after)`` pair
    applied to both spatial axes.
    """
    pad_before, pad_after = pad[0], pad[1]

    if input.device.type != "cpu":
        return UpFirDn2d.apply(
            input,
            kernel,
            (up, up),
            (down, down),
            (pad_before, pad_after, pad_before, pad_after),
        )

    return upfirdn2d_native(
        input, kernel, up, up, down, down, pad_before, pad_after, pad_before, pad_after
    )
def upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    """Pure-PyTorch upsample / FIR filter / downsample of a 4-D tensor.

    Steps: insert ``up - 1`` zeros after every sample, pad (or crop, for
    negative padding), correlate with the flipped ``kernel``, then keep every
    ``down``-th sample.  Returns a tensor of shape
    ``(-1, channel, out_h, out_w)``.
    """
    _, channel, in_h, in_w = input.shape
    planes = input.reshape(-1, in_h, in_w, 1)
    _, in_h, in_w, minor = planes.shape
    kernel_h, kernel_w = kernel.shape

    # Zero insertion: give each sample a trailing axis and pad that axis.
    out = planes.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    # Positive padding is added here; negative padding is cropped right after.
    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    crop_top, crop_bottom = max(-pad_y0, 0), max(-pad_y1, 0)
    crop_left, crop_right = max(-pad_x0, 0), max(-pad_x1, 0)
    out = out[
        :,
        crop_top : out.shape[1] - crop_bottom,
        crop_left : out.shape[2] - crop_right,
        :,
    ]

    padded_h = in_h * up_y + pad_y0 + pad_y1
    padded_w = in_w * up_x + pad_x0 + pad_x1

    # Filter every plane independently as a single-channel image.
    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, padded_h, padded_w])
    flipped = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, flipped)
    out = out.reshape(
        -1,
        minor,
        padded_h - kernel_h + 1,
        padded_w - kernel_w + 1,
    )

    # Decimate.
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (padded_h - kernel_h) // down_y + 1
    out_w = (padded_w - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)
================================================
FILE: stylegan_human/torch_utils/op_edit/upfirdn2d_kernel.cu
================================================
// Copyright (c) SenseTime Research. All rights reserved.
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include
#include
#include
#include
#include
#include
#include
// Integer division of a by b rounded down (toward negative infinity) for a
// positive divisor b; plain `/` would round negative quotients toward zero.
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
  const int quotient = a / b;
  // Truncation overshot the true quotient exactly when multiplying back
  // exceeds the dividend; step down by one in that case.
  return (quotient * b > a) ? quotient - 1 : quotient;
}
// Parameter bundle for the upfirdn2d CUDA kernels. upfirdn2d_op() fills it in
// on the host and passes it to the kernels by value.
struct UpFirDn2DKernelParams {
// Upsampling factors along x and y.
int up_x;
int up_y;
// Downsampling factors along x and y.
int down_x;
int down_y;
// Padding along x and y. The kernels subtract pad_x0 / pad_y0 when mapping an
// output coordinate back to the input; pad_x1 / pad_y1 only enter the
// output-size computation in upfirdn2d_op().
int pad_x0;
int pad_x1;
int pad_y0;
int pad_y1;
// Input extents; the kernels index the input as
// [major_dim, in_h, in_w, minor_dim].
int major_dim;
int in_h;
int in_w;
int minor_dim;
// Extents of the 2D filter.
int kernel_h;
int kernel_w;
// Spatial extents of the output, indexed as [major_dim, out_h, out_w, minor_dim].
int out_h;
int out_w;
// Number of consecutive major indices handled per grid z-slot, and number of
// output-x steps iterated per thread (large kernel) or per block (tiled kernel).
int loop_major;
int loop_x;
};
// Generic upfirdn2d kernel: every thread accumulates its output samples by
// reading input and filter directly through the given pointers (there is no
// shared-memory staging in this kernel). upfirdn2d_op() launches it in the
// `default` case, i.e. when no tiled specialization was selected.
// NOTE(review): the template parameter list is not visible in this listing;
// scalar_t is presumably the single template type parameter - confirm.
template
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
const scalar_t *kernel,
const UpFirDn2DKernelParams p) {
// The x thread index encodes the pair (out_y, minor_idx).
int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
int out_y = minor_idx / p.minor_dim;
minor_idx -= out_y * p.minor_dim;
// First output column of this thread; further columns are blockDim.y apart.
int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
// First major index of this grid z-slot.
int major_idx_base = blockIdx.z * p.loop_major;
// Threads whose starting coordinates fall outside the output have no work.
if (out_x_base >= p.out_w || out_y >= p.out_h ||
major_idx_base >= p.major_dim) {
return;
}
// Row bookkeeping depends only on out_y, so it is computed once per thread:
// in_y is the first input row used (clamped to [0, in_h]), h the number of
// input rows used, kernel_y the filter row paired with input row in_y.
int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
for (int loop_major = 0, major_idx = major_idx_base;
loop_major < p.loop_major && major_idx < p.major_dim;
loop_major++, major_idx++) {
for (int loop_x = 0, out_x = out_x_base;
loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
// Same bookkeeping as above, for the column direction.
int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
// Start pointers into the [major, in_h, in_w, minor] input and the
// [kernel_h, kernel_w] filter.
const scalar_t *x_p =
&input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
minor_idx];
const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
// Pointer strides per step in x and y. The filter strides are negative:
// the filter is walked backwards while the input is walked forwards, and
// it advances by `up` taps per input sample.
int x_px = p.minor_dim;
int k_px = -p.up_x;
int x_py = p.in_w * p.minor_dim;
int k_py = -p.up_y * p.kernel_w;
scalar_t v = 0.0f;
for (int y = 0; y < h; y++) {
for (int x = 0; x < w; x++) {
// NOTE(review): the static_cast target type is not visible in this listing.
v += static_cast(*x_p) * static_cast(*k_p);
x_p += x_px;
k_p += k_px;
}
// Rewind the w column steps and move both pointers to the next row.
x_p += x_py - w * x_px;
k_p += k_py - w * k_px;
}
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
minor_idx] = v;
}
}
}
// Tiled upfirdn2d kernel: a thread block stages the (flipped) filter and one
// input tile in shared memory, then each thread computes output samples of the
// corresponding output tile from those staged copies.
// NOTE(review): the template parameter list is not visible in this listing;
// scalar_t, up_x, up_y, down_x, down_y, kernel_h, kernel_w, tile_out_h and
// tile_out_w are used unqualified below and are presumably template
// parameters (compile-time constants) - confirm.
template
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
const scalar_t *kernel,
const UpFirDn2DKernelParams p) {
// Input tile extent needed to produce one tile_out_h x tile_out_w output tile.
const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
// Shared staging buffers: sk holds the filter taps, sx the current input tile.
__shared__ volatile float sk[kernel_h][kernel_w];
__shared__ volatile float sx[tile_in_h][tile_in_w];
// blockIdx.x encodes the pair (tile row, minor_idx).
int minor_idx = blockIdx.x;
int tile_out_y = minor_idx / p.minor_dim;
minor_idx -= tile_out_y * p.minor_dim;
tile_out_y *= tile_out_h;
int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
int major_idx_base = blockIdx.z * p.loop_major;
// The conditions in this kernel are combined with bitwise | and &, so both
// operands are always evaluated (no short-circuit); the operands are plain
// comparisons, so the result equals the logical form.
if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
major_idx_base >= p.major_dim) {
return;
}
// Stage the filter into sk, flipped in both axes. Taps beyond the runtime
// filter size (p.kernel_h x p.kernel_w) are stored as zero, so a runtime
// filter smaller than the compile-time kernel_h x kernel_w is zero-extended.
for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
tap_idx += blockDim.x) {
int ky = tap_idx / kernel_w;
int kx = tap_idx - ky * kernel_w;
scalar_t v = 0.0;
if (kx < p.kernel_w & ky < p.kernel_h) {
v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
}
sk[ky][kx] = v;
}
for (int loop_major = 0, major_idx = major_idx_base;
loop_major < p.loop_major & major_idx < p.major_dim;
loop_major++, major_idx++) {
for (int loop_x = 0, tile_out_x = tile_out_x_base;
loop_x < p.loop_x & tile_out_x < p.out_w;
loop_x++, tile_out_x += tile_out_w) {
// Input coordinate of the tile's first sample.
int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
int tile_in_x = floor_div(tile_mid_x, up_x);
int tile_in_y = floor_div(tile_mid_y, up_y);
// Barrier before sx is overwritten - presumably so that no thread is still
// reading the previous tile (and, on the first pass, so the sk writes above
// are complete before use).
__syncthreads();
// Stage the input tile into sx; samples outside the input read as zero.
for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
in_idx += blockDim.x) {
int rel_in_y = in_idx / tile_in_w;
int rel_in_x = in_idx - rel_in_y * tile_in_w;
int in_x = rel_in_x + tile_in_x;
int in_y = rel_in_y + tile_in_y;
scalar_t v = 0.0;
if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
p.minor_dim +
minor_idx];
}
sx[rel_in_y][rel_in_x] = v;
}
// Barrier: the whole tile must be staged before any thread reads it.
__syncthreads();
// Each thread computes the output samples out_idx, out_idx + blockDim.x, ...
for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
out_idx += blockDim.x) {
int rel_out_y = out_idx / tile_out_w;
int rel_out_x = out_idx - rel_out_y * tile_out_w;
int out_x = rel_out_x + tile_out_x;
int out_y = rel_out_y + tile_out_y;
int mid_x = tile_mid_x + rel_out_x * down_x;
int mid_y = tile_mid_y + rel_out_y * down_y;
int in_x = floor_div(mid_x, up_x);
int in_y = floor_div(mid_y, up_y);
// Position of the first contributing sample inside the staged tile.
int rel_in_x = in_x - tile_in_x;
int rel_in_y = in_y - tile_in_y;
// First filter tap for this output sample; later taps are `up` apart.
int kernel_x = (in_x + 1) * up_x - mid_x - 1;
int kernel_y = (in_y + 1) * up_y - mid_y - 1;
scalar_t v = 0.0;
// Accumulate over the kernel_h / up_y by kernel_w / up_x contributing taps.
#pragma unroll
for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
for (int x = 0; x < kernel_w / up_x; x++)
v += sx[rel_in_y + y][rel_in_x + x] *
sk[kernel_y + y * up_y][kernel_x + x * up_x];
// Tiles may overhang the output; only in-range samples are written.
if (out_x < p.out_w & out_y < p.out_h) {
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
minor_idx] = v;
}
}
}
}
}
// Host entry point: fills UpFirDn2DKernelParams from the tensors and scalar
// arguments, allocates the output, picks a kernel variant and launches it on
// the current CUDA stream.
//
// input  - tensor read as [major_dim, in_h, in_w, minor_dim] (sizes 0..3).
// kernel - 2D filter read as [kernel_h, kernel_w].
// up_* / down_* / pad_* - resampling factors and padding per axis.
// Returns a new tensor of shape [major_dim, out_h, out_w, minor_dim] with the
// same options as the (contiguous) input.
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
const torch::Tensor &kernel, int up_x, int up_y,
int down_x, int down_y, int pad_x0, int pad_x1,
int pad_y0, int pad_y1) {
// Look up the stream of the currently selected device. The return value of
// cudaGetDevice() is not checked here.
int curDevice = -1;
cudaGetDevice(&curDevice);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
UpFirDn2DKernelParams p;
// The kernels index with dense strides, so work on contiguous tensors.
auto x = input.contiguous();
auto k = kernel.contiguous();
p.major_dim = x.size(0);
p.in_h = x.size(1);
p.in_w = x.size(2);
p.minor_dim = x.size(3);
p.kernel_h = k.size(0);
p.kernel_w = k.size(1);
p.up_x = up_x;
p.up_y = up_y;
p.down_x = down_x;
p.down_y = down_y;
p.pad_x0 = pad_x0;
p.pad_x1 = pad_x1;
p.pad_y0 = pad_y0;
p.pad_y1 = pad_y1;
// Output size: (upsampled + padded extent - kernel extent) / down + 1,
// written with the "+ down" folded into the numerator.
p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
p.down_y;
p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
p.down_x;
auto out =
at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
// Kernel variant selection. mode stays -1 (generic kernel) unless one of the
// cases below matches. The checks are independent `if`s, not an else-chain,
// so when two cases match the later one wins (e.g. a 3x3 filter with
// up = down = 1 matches modes 1 and 2 and ends up as mode 2).
int mode = -1;
int tile_out_h = -1;
int tile_out_w = -1;
// up 1, down 1, filter up to 4x4.
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 4 && p.kernel_w <= 4) {
mode = 1;
tile_out_h = 16;
tile_out_w = 64;
}
// up 1, down 1, filter up to 3x3.
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 3 && p.kernel_w <= 3) {
mode = 2;
tile_out_h = 16;
tile_out_w = 64;
}
// up 2, down 1, filter up to 4x4.
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 4 && p.kernel_w <= 4) {
mode = 3;
tile_out_h = 16;
tile_out_w = 64;
}
// up 2, down 1, filter up to 2x2.
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 2 && p.kernel_w <= 2) {
mode = 4;
tile_out_h = 16;
tile_out_w = 64;
}
// up 1, down 2, filter up to 4x4.
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
p.kernel_h <= 4 && p.kernel_w <= 4) {
mode = 5;
tile_out_h = 8;
tile_out_w = 32;
}
// up 1, down 2, filter up to 2x2.
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
p.kernel_h <= 2 && p.kernel_w <= 2) {
mode = 6;
tile_out_h = 8;
tile_out_w = 32;
}
dim3 block_size;
dim3 grid_size;
if (tile_out_h > 0 && tile_out_w > 0) {
// Tiled kernel: 256 threads per block; grid x covers (tile rows x minor),
// grid y the tile columns, grid z the major groups. loop_major grows with
// major_dim so that grid z is at most 16384.
p.loop_major = (p.major_dim - 1) / 16384 + 1;
p.loop_x = 1;
block_size = dim3(32 * 8, 1, 1);
grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
(p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
(p.major_dim - 1) / p.loop_major + 1);
} else {
// Generic kernel: 4 x 32 threads per block; each thread walks loop_x = 4
// output columns spaced blockDim.y apart.
p.loop_major = (p.major_dim - 1) / 16384 + 1;
p.loop_x = 4;
block_size = dim3(4, 32, 1);
grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
(p.out_w - 1) / (p.loop_x * block_size.y) + 1,
(p.major_dim - 1) / p.loop_major + 1);
}
// Launch the selected variant for the tensor's floating-point type.
// NOTE(review): the template arguments of each upfirdn2d_kernel case, the
// launch configuration inside <<< >>> and the data_ptr element type are not
// visible in this listing - confirm them against the intact file.
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
switch (mode) {
case 1:
upfirdn2d_kernel
<<>>(out.data_ptr(),
x.data_ptr(),
k.data_ptr(), p);
break;
case 2:
upfirdn2d_kernel
<<>>(out.data_ptr(),
x.data_ptr(),
k.data_ptr(), p);
break;
case 3:
upfirdn2d_kernel
<<>>(out.data_ptr(),
x.data_ptr(),
k.data_ptr(), p);
break;
case 4:
upfirdn2d_kernel
<<>>(out.data_ptr(),
x.data_ptr(),
k.data_ptr(), p);
break;
case 5:
upfirdn2d_kernel
<<>>(out.data_ptr(),
x.data_ptr(),
k.data_ptr(), p);
break;
case 6:
upfirdn2d_kernel
<<>>(out.data_ptr(),
x.data_ptr(),
k.data_ptr(), p);
break;
default:
upfirdn2d_kernel_large<<>>(
out.data_ptr(), x.data_ptr(),
k.data_ptr(), p);
}
});
return out;
}
================================================
FILE: stylegan_human/torch_utils/ops/__init__.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
#empty
================================================
FILE: stylegan_human/torch_utils/ops/bias_act.cpp
================================================
// Copyright (c) SenseTime Research. All rights reserved.
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include
#include
#include
#include "bias_act.h"
//------------------------------------------------------------------------
// Returns true when x and y have the same rank and sizes and, for every
// dimension holding at least two elements, the same stride. Strides of
// dimensions with fewer than two elements are ignored.
static bool has_same_layout(torch::Tensor x, torch::Tensor y)
{
    const int64_t rank = x.dim();
    if (rank != y.dim())
        return false;

    for (int64_t d = 0; d < rank; ++d)
    {
        const int64_t extent = x.size(d);
        const bool size_differs = (extent != y.size(d));
        // The stride is only compared once the sizes are known to agree and
        // the dimension holds at least two elements.
        if (size_differs || (extent >= 2 && x.stride(d) != y.stride(d)))
            return false;
    }
    return true;
}
//------------------------------------------------------------------------
// Host entry point of the fused bias + activation op. Validates the arguments,
// allocates the output and launches the CUDA kernel matching `act` and the
// dtype of x on the current CUDA stream.
//
// x      - input tensor; must be on a CUDA device, non-overlapping and dense.
// b      - rank-1 bias tensor, or an empty tensor to disable the bias.
// xref   - optional reference input (empty to disable); same shape/layout as x.
// yref   - optional reference output (empty to disable); same shape/layout as x.
// dy     - optional incoming gradient (empty to disable); same shape/layout as x.
// grad   - gradient order evaluated by the kernel (must be >= 0).
// dim    - dimension of x that the elements of b correspond to.
// act    - activation index used to select the kernel.
// alpha, gain, clamp - activation shape parameter, output scale and clamp
//          limit; the kernel only clamps when clamp >= 0.
// Returns a new tensor with the same layout as x.
static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
{
    // Validate arguments.
    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
    TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
    TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
    TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
    // The message now mentions the shape as well, since dy.sizes() is checked.
    TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same shape, dtype, and device as x");
    // Element counts are stored as int in the kernel parameters below.
    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
    TORCH_CHECK(b.dim() == 1, "b must have rank 1");
    TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
    TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
    TORCH_CHECK(grad >= 0, "grad must be non-negative");

    // Validate layout.
    TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
    TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
    TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
    TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
    TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");

    // Create output tensor.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
    torch::Tensor y = torch::empty_like(x);
    TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");

    // Initialize CUDA kernel parameters. Empty optional tensors become NULL.
    bias_act_kernel_params p;
    p.x = x.data_ptr();
    p.b = (b.numel()) ? b.data_ptr() : NULL;
    p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
    p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
    p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
    p.y = y.data_ptr();
    p.grad = grad;
    p.act = act;
    p.alpha = alpha;
    p.gain = gain;
    p.clamp = clamp;
    p.sizeX = (int)x.numel();
    p.sizeB = (int)b.numel();
    // Step between consecutive bias elements in x's memory; 1 keeps the
    // kernel's division well-defined when there is no bias.
    p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;

    // Choose CUDA kernel. The dispatch label is what PyTorch prints when the
    // dtype is unsupported, so it names this op. T cannot be deduced from `p`,
    // so the element type chosen by the dispatch macro is passed explicitly.
    void* kernel = NULL;
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "bias_act_cuda", [&]
    {
        kernel = choose_bias_act_kernel<scalar_t>(p);
    });
    TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");

    // Launch CUDA kernel: 128 threads per block, each covering up to loopX elements.
    p.loopX = 4;
    int blockSize = 4 * 32;
    int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
    void* args[] = {&p};
    AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
    return y;
}
//------------------------------------------------------------------------
// Python binding: registers the C++ function bias_act() under the name
// "bias_act" on the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("bias_act", &bias_act);
}
//------------------------------------------------------------------------
================================================
FILE: stylegan_human/torch_utils/ops/bias_act.cu
================================================
// Copyright (c) SenseTime Research. All rights reserved.
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include
#include "bias_act.h"
//------------------------------------------------------------------------
// Helpers.
// Maps a tensor element type to the type used for arithmetic inside the
// kernel (exposed as InternalType::scalar_t): one mapping yields double, the
// other two yield float.
// NOTE(review): the template argument of each specialization is not visible in
// this listing; presumably double -> double, float -> float and a half type
// -> float - confirm against the intact file.
template struct InternalType;
template <> struct InternalType { typedef double scalar_t; };
template <> struct InternalType { typedef float scalar_t; };
template <> struct InternalType { typedef float scalar_t; };
//------------------------------------------------------------------------
// CUDA kernel.
// Fused bias + activation kernel. Depending on p.grad it evaluates the
// activation itself (0), or the first (1) / second (2) order gradient
// expression, then applies gain and the optional clamp.
// NOTE(review): the template parameter list is not visible in this listing;
// T (element type of the buffers) and A (activation index compared against
// 1..9 below) are presumably the template parameters - confirm.
template
__global__ void bias_act_kernel(bias_act_kernel_params p)
{
// Arithmetic is done in InternalType's scalar_t, not in the storage type T.
typedef typename InternalType::scalar_t scalar_t;
int G = p.grad;
scalar_t alpha = (scalar_t)p.alpha;
scalar_t gain = (scalar_t)p.gain;
scalar_t clamp = (scalar_t)p.clamp;
scalar_t one = (scalar_t)1;
scalar_t two = (scalar_t)2;
// Thresholds beyond which the exp-based formulas below are replaced by their
// limiting values.
scalar_t expRange = (scalar_t)80;
scalar_t halfExpRange = (scalar_t)40;
scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
// Loop over elements.
// Each thread handles up to p.loopX elements spaced blockDim.x apart.
int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
{
// Load.
// Absent optional buffers (NULL) read as 0, except dy which reads as 1.
scalar_t x = (scalar_t)((const T*)p.x)[xi];
// Bias element for this position: linear index divided by the step of the
// bias dimension, wrapped to the bias length.
scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
// Reference output with the gain divided out again (0 when gain is 0).
scalar_t yy = (gain != 0) ? yref / gain : 0;
scalar_t y = 0;
// Apply bias.
// The bias goes onto x when evaluating the function itself (G == 0) and onto
// the reference input xref for the gradient orders.
((G == 0) ? x : xref) += b;
// In the G >= 1 branches below, x is the value the derivative factor is
// multiplied with, while xref / yy describe the point of evaluation.
// linear
if (A == 1)
{
if (G == 0) y = x;
if (G == 1) y = x;
}
// relu
if (A == 2)
{
if (G == 0) y = (x > 0) ? x : 0;
if (G == 1) y = (yy > 0) ? x : 0;
}
// lrelu
if (A == 3)
{
if (G == 0) y = (x > 0) ? x : x * alpha;
if (G == 1) y = (yy > 0) ? x : x * alpha;
}
// tanh
if (A == 4)
{
if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
if (G == 1) y = x * (one - yy * yy);
if (G == 2) y = x * (one - yy * yy) * (-two * yy);
}
// sigmoid
if (A == 5)
{
if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
if (G == 1) y = x * yy * (one - yy);
if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
}
// elu
if (A == 6)
{
if (G == 0) y = (x >= 0) ? x : exp(x) - one;
if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
}
// selu
if (A == 7)
{
if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
}
// softplus
if (A == 8)
{
if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
if (G == 1) y = x * (one - exp(-yy));
if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
}
// swish
// Unlike the cases above, the gradient orders work from the reference input
// xref, and yref is recomputed from xref for the clamp test further down.
if (A == 9)
{
if (G == 0)
y = (x < -expRange) ? 0 : x / (exp(-x) + one);
else
{
scalar_t c = exp(xref);
scalar_t d = c + one;
if (G == 1)
y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
else
y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
}
}
// Apply gain.
y *= gain * dy;
// Clamp.
// A negative clamp disables clamping. For G == 0 the result is limited to
// [-clamp, clamp]; for the gradient orders the result is zeroed wherever the
// reference output lies outside the open interval (-clamp, clamp).
if (clamp >= 0)
{
if (G == 0)
y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
else
y = (yref > -clamp & yref < clamp) ? y : 0;
}
// Store.
((T*)p.y)[xi] = (T)y;
}
}
//------------------------------------------------------------------------
// CUDA kernel selection.
// Returns the kernel entry point for the activation index p.act (1..9) as an
// untyped pointer, or NULL when the index is not one of the handled values.
// NOTE(review): the template parameter list and the template arguments of each
// bias_act_kernel reference are not visible in this listing; presumably each
// line selects the instantiation for the element type and the matching
// activation index - confirm against the intact file.
template void* choose_bias_act_kernel(const bias_act_kernel_params& p)
{
if (p.act == 1) return (void*)bias_act_kernel;
if (p.act == 2) return (void*)bias_act_kernel;
if (p.act == 3) return (void*)bias_act_kernel;
if (p.act == 4) return (void*)bias_act_kernel;
if (p.act == 5) return (void*)bias_act_kernel;
if (p.act == 6) return (void*)bias_act_kernel;
if (p.act == 7) return (void*)bias_act_kernel;
if (p.act == 8) return (void*)bias_act_kernel;
if (p.act == 9) return (void*)bias_act_kernel;
return NULL;
}
//------------------------------------------------------------------------
// Explicit template instantiations of choose_bias_act_kernel, one per
// supported element type, so the definitions are emitted in this translation
// unit for callers that only see the declaration in bias_act.h.
// NOTE(review): the template arguments are not visible in this listing.
template void* choose_bias_act_kernel (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel (const bias_act_kernel_params& p);
//------------------------------------------------------------------------
================================================
FILE: stylegan_human/torch_utils/ops/bias_act.h
================================================
// Copyright (c) SenseTime Research. All rights reserved.
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//------------------------------------------------------------------------
// CUDA kernel parameters.
// Parameter bundle for the bias_act CUDA kernel. The host code fills it in and
// passes it by value at launch. Buffers are untyped; the kernel casts them to
// its element type.
struct bias_act_kernel_params
{
// Input buffer, optional bias, optional reference input / reference output,
// optional incoming gradient, and the output buffer.
const void* x; // [sizeX]
const void* b; // [sizeB] or NULL
const void* xref; // [sizeX] or NULL
const void* yref; // [sizeX] or NULL
const void* dy; // [sizeX] or NULL
void* y; // [sizeX]
// Gradient order to evaluate and activation index used for kernel selection.
int grad;
int act;
// Activation shape parameter, output scale, and clamp limit (the kernel only
// clamps when clamp >= 0).
float alpha;
float gain;
float clamp;
// Element counts of x and b.
int sizeX;
int sizeB;
// Step of the bias dimension in x's memory; the kernel maps a linear index xi
// to the bias element (xi / stepB) % sizeB.
int stepB;
// Maximum number of elements each thread processes.
int loopX;
};
//------------------------------------------------------------------------
// CUDA kernel selection.
// Returns the kernel entry point to launch for p.act, or NULL if there is
// none. Declared here for the host code; the definition lives in bias_act.cu.
// NOTE(review): the template parameter list is not visible in this listing.
template void* choose_bias_act_kernel(const bias_act_kernel_params& p);
//------------------------------------------------------------------------
================================================
FILE: stylegan_human/torch_utils/ops/bias_act.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Custom PyTorch ops for efficient bias and activation."""
import os
import warnings
import numpy as np
import torch
import dnnlib
import traceback
from .. import custom_ops
from .. import misc
#----------------------------------------------------------------------------
# Registry of supported activation functions, keyed by name. Each entry holds:
#   func         - reference implementation, called as func(x, alpha=...).
#   def_alpha    - alpha used when the caller passes alpha=None.
#   def_gain     - gain used when the caller passes gain=None.
#   cuda_idx     - activation index handed to the CUDA plugin.
#   ref          - which tensors the CUDA backward pass needs saved:
#                  'x' (input), 'y' (output) or '' (neither).
#   has_2nd_grad - whether the CUDA path evaluates a second-order gradient.
activation_funcs = {
'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
}
#----------------------------------------------------------------------------
_inited = False  # True once a plugin build has been attempted.
_plugin = None  # Compiled extension module, or None if unavailable.
_null_tensor = torch.empty([0])  # Placeholder passed for absent optional tensors.


def _init():
    """Build/load the CUDA plugin on first use and report whether it is usable.

    The build is attempted at most once per process. If it fails, a warning
    with the traceback is issued and later calls return ``False`` without
    retrying, so callers fall back to the reference implementation.

    Returns:
        bool: ``True`` if the compiled plugin is available.
    """
    global _inited, _plugin
    if not _inited:
        _inited = True
        here = os.path.dirname(__file__)
        sources = [os.path.join(here, s) for s in ('bias_act.cpp', 'bias_act.cu')]
        try:
            _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
        except Exception:
            # Only ordinary errors trigger the fallback; KeyboardInterrupt and
            # SystemExit raised during a long build must reach the caller.
            warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
    return _plugin is not None
#----------------------------------------------------------------------------
def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
    r"""Fused bias and activation function.

    Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
    and scales the result by `gain`. Each of the steps is optional. In most cases,
    the fused op is considerably more efficient than performing the same calculation
    using standard PyTorch ops. It supports first and second order gradients,
    but not third order gradients.

    Args:
        x:      Input activation tensor. Can be of any shape.
        b:      Bias vector, or `None` to disable. Must be a 1D tensor of the same type
                as `x`. The shape must be known, and it must match the dimension of `x`
                corresponding to `dim`.
        dim:    The dimension in `x` corresponding to the elements of `b`.
                The value of `dim` is ignored if `b` is not specified.
        act:    Name of the activation function to evaluate, or `"linear"` to disable.
                Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
                See `activation_funcs` for a full list. `None` is not allowed.
        alpha:  Shape parameter for the activation function, or `None` to use the default.
        gain:   Scaling factor for the output tensor, or `None` to use default.
                See `activation_funcs` for the default scaling of each activation function.
                If unsure, consider specifying 1.
        clamp:  Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
                the clamping (default).
        impl:   Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
                The CUDA path is only taken when `x` lives on a CUDA device and the
                plugin could be initialized; otherwise the reference path is used.

    Returns:
        Tensor of the same shape and datatype as `x`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']

    options = dict(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)

    # The plugin is only initialized once the cheaper checks have passed.
    wants_cuda = impl == 'cuda' and x.device.type == 'cuda'
    if wants_cuda and _init():
        return _bias_act_cuda(**options).apply(x, b)
    return _bias_act_ref(x=x, b=b, **options)
#----------------------------------------------------------------------------
@misc.profiled_function
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Slow reference implementation of `bias_act()` using standard PyTorch ops.

    Applies, in order: the optional bias along `dim`, the activation named by
    `act`, the gain (skipped when it equals 1) and the optional clamp.
    `alpha` and `gain` default to the values registered in `activation_funcs`.
    """
    assert isinstance(x, torch.Tensor)
    assert clamp is None or clamp >= 0

    # Resolve defaults; a negative clamp value means "no clamping".
    spec = activation_funcs[act]
    slope = float(spec.def_alpha if alpha is None else alpha)
    scale = float(spec.def_gain if gain is None else gain)
    limit = float(-1 if clamp is None else clamp)

    y = x

    # Add bias, reshaped so that it broadcasts along every axis except `dim`.
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.ndim == 1
        assert 0 <= dim < x.ndim
        assert b.shape[0] == x.shape[dim]
        bias_shape = [-1 if axis == dim else 1 for axis in range(x.ndim)]
        y = y + b.reshape(bias_shape)

    # Evaluate activation function.
    y = spec.func(y, alpha=slope)

    # Scale by gain.
    if scale != 1:
        y = y * scale

    # Clamp.
    if limit >= 0:
        y = y.clamp(-limit, limit) # pylint: disable=invalid-unary-operand-type
    return y
#----------------------------------------------------------------------------
# Cache of generated autograd Function classes, keyed by
# (dim, act, alpha, gain, clamp), so each configuration is built only once.
_bias_act_cuda_cache = dict()
def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
"""Fast CUDA implementation of `bias_act()` using custom ops.
"""
# Returns a torch.autograd.Function subclass specialized for the given
# configuration; the caller invokes it as `.apply(x, b)`.
# Parse arguments.
assert clamp is None or clamp >= 0
spec = activation_funcs[act]
alpha = float(alpha if alpha is not None else spec.def_alpha)
gain = float(gain if gain is not None else spec.def_gain)
# A negative clamp value tells the plugin that clamping is disabled.
clamp = float(clamp if clamp is not None else -1)
# Lookup from cache.
key = (dim, act, alpha, gain, clamp)
if key in _bias_act_cuda_cache:
return _bias_act_cuda_cache[key]
# Forward op.
class BiasActCuda(torch.autograd.Function):
@staticmethod
def forward(ctx, x, b): # pylint: disable=arguments-differ
# Keep channels_last when x has more than two dims and unit stride in dim 1;
# otherwise use the default contiguous format. Remembered for backward.
ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
x = x.contiguous(memory_format=ctx.memory_format)
b = b.contiguous() if b is not None else _null_tensor
y = x
# Skip the plugin call when the op is the identity
# (linear activation, unit gain, no clamp, no bias).
if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
# Save only what the gradient needs: x and b when the activation references
# its input or has a second-order gradient, y when it references its output.
ctx.save_for_backward(
x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
y if 'y' in spec.ref else _null_tensor)
return y
@staticmethod
def backward(ctx, dy): # pylint: disable=arguments-differ
dy = dy.contiguous(memory_format=ctx.memory_format)
x, b, y = ctx.saved_tensors
dx = None
db = None
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
# With a linear activation, unit gain and no clamp the gradient passes
# through unchanged.
dx = dy
if act != 'linear' or gain != 1 or clamp >= 0:
dx = BiasActCudaGrad.apply(dy, x, b, y)
if ctx.needs_input_grad[1]:
# Bias gradient: reduce dx over every dimension except `dim`.
db = dx.sum([i for i in range(dx.ndim) if i != dim])
return dx, db
# Backward op.
# Implemented as its own autograd Function so that it can be differentiated
# again (second-order gradients).
class BiasActCudaGrad(torch.autograd.Function):
@staticmethod
def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
# Plugin call with grad order 1: dy takes the input slot, x and y are passed
# as the reference input / reference output.
dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
ctx.save_for_backward(
dy if spec.has_2nd_grad else _null_tensor,
x, b, y)
return dx
@staticmethod
def backward(ctx, d_dx): # pylint: disable=arguments-differ
d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
dy, x, b, y = ctx.saved_tensors
d_dy = None
d_x = None
d_b = None
d_y = None
if ctx.needs_input_grad[0]:
# Gradient w.r.t. dy: the same first-order op applied to d_dx.
d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
# Plugin call with grad order 2; the saved dy is passed in the dy slot.
d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
if spec.has_2nd_grad and ctx.needs_input_grad[2]:
d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
# d_y is always None: no gradient is returned for the saved output y.
return d_dy, d_x, d_b, d_y
# Add to cache.
_bias_act_cuda_cache[key] = BiasActCuda
return BiasActCuda
#----------------------------------------------------------------------------
================================================
FILE: stylegan_human/torch_utils/ops/conv2d_gradfix.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Custom replacement for `torch.nn.functional.conv2d` that supports
arbitrarily high order gradients with zero performance penalty."""
import warnings
import contextlib
import torch
# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
# pylint: disable=protected-access
#----------------------------------------------------------------------------
enabled = False                     # Enable the custom op by setting this to true.
weight_gradients_disabled = False   # Forcefully disable computation of gradients with respect to the weights.


@contextlib.contextmanager
def no_weight_gradients():
    """Context manager that disables weight-gradient computation inside its body.

    Sets the module-level ``weight_gradients_disabled`` flag to ``True`` for
    the duration of the ``with`` block and restores the previous value on
    exit, including when the body raises.
    """
    global weight_gradients_disabled
    old = weight_gradients_disabled
    weight_gradients_disabled = True
    try:
        yield
    finally:
        # Restore even on error; otherwise one failing step would leave
        # weight gradients disabled for the rest of the process.
        weight_gradients_disabled = old
#----------------------------------------------------------------------------
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Replacement for `torch.nn.functional.conv2d` with the same arguments.

    Uses the custom autograd op when `_should_use_custom_op(input)` allows it
    and falls back to the stock PyTorch function otherwise.
    """
    if not _should_use_custom_op(input):
        return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)

    op = _conv2d_gradfix(
        transpose=False,
        weight_shape=weight.shape,
        stride=stride,
        padding=padding,
        output_padding=0,
        dilation=dilation,
        groups=groups,
    )
    return op.apply(input, weight, bias)
def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    """Replacement for `torch.nn.functional.conv_transpose2d` with the same arguments.

    Uses the custom autograd op when `_should_use_custom_op(input)` allows it
    and falls back to the stock PyTorch function otherwise.
    """
    if not _should_use_custom_op(input):
        return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)

    op = _conv2d_gradfix(
        transpose=True,
        weight_shape=weight.shape,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        groups=groups,
        dilation=dilation,
    )
    return op.apply(input, weight, bias)
#----------------------------------------------------------------------------
def _should_use_custom_op(input):
    """Decide whether the custom conv op may be used for `input`.

    Requires the module-level `enabled` flag, an enabled cuDNN backend, a CUDA
    tensor, and a PyTorch version whose string starts with one of the listed
    prefixes. When only the version check fails, a warning is emitted.
    """
    assert isinstance(input, torch.Tensor)

    # The cuDNN flag is only consulted when the op is enabled at all.
    if not (enabled and torch.backends.cudnn.enabled):
        return False
    if input.device.type != 'cuda':
        return False

    supported_prefixes = ('1.7.', '1.8.', '1.9')
    if torch.__version__.startswith(supported_prefixes):
        return True

    warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
    return False
def _tuple_of_ints(xs, ndim):
    """Normalize `xs` to a tuple of `ndim` ints.

    A tuple or list is converted as-is; any other value is repeated `ndim`
    times. Asserts that the result has length `ndim` and contains only ints.
    """
    if isinstance(xs, (tuple, list)):
        values = tuple(xs)
    else:
        values = (xs,) * ndim
    assert len(values) == ndim
    assert all(isinstance(v, int) for v in values)
    return values
#----------------------------------------------------------------------------
# Cache of generated autograd Function classes, keyed by the full configuration tuple built below.
_conv2d_gradfix_cache = dict()
# Build (or fetch from the cache) a torch.autograd.Function subclass specialized for one
# convolution configuration. `transpose` selects conv_transpose2d over conv2d in the forward
# pass; the remaining arguments are captured by the nested classes through closure.
# Returns the `Conv2d` class; it is invoked through `.apply(input, weight, bias)`.
def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
# Parse arguments.
# Scalars are expanded to 2-tuples so that the cache key below is hashable and canonical.
ndim = 2
weight_shape = tuple(weight_shape)
stride = _tuple_of_ints(stride, ndim)
padding = _tuple_of_ints(padding, ndim)
output_padding = _tuple_of_ints(output_padding, ndim)
dilation = _tuple_of_ints(dilation, ndim)
# Lookup from cache.
key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
if key in _conv2d_gradfix_cache:
return _conv2d_gradfix_cache[key]
# Validate arguments.
# Validation only runs on a cache miss; a cached key has already passed these checks.
assert groups >= 1
assert len(weight_shape) == ndim + 2
assert all(stride[i] >= 1 for i in range(ndim))
assert all(padding[i] >= 0 for i in range(ndim))
assert all(dilation[i] >= 0 for i in range(ndim))
if not transpose:
assert all(output_padding[i] == 0 for i in range(ndim))
else: # transpose
assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
# Helpers.
common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
# Output padding to hand to the opposite-direction op used in the backward passes.
# When this op is itself transposed the opposite op is a plain convolution, so [0, 0] is
# returned; otherwise the value is derived per spatial axis from the two shapes.
def calc_output_padding(input_shape, output_shape):
if transpose:
return [0, 0]
return [
input_shape[i + 2]
- (output_shape[i + 2] - 1) * stride[i]
- (1 - 2 * padding[i])
- dilation[i] * (weight_shape[i + 2] - 1)
for i in range(ndim)
]
# Forward & backward.
class Conv2d(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias):
# The class is specialized for one weight shape; reject anything else.
assert weight.shape == weight_shape
if not transpose:
output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
else: # transpose
output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
ctx.save_for_backward(input, weight)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
grad_input = None
grad_weight = None
grad_bias = None
if ctx.needs_input_grad[0]:
# Input gradient: apply the opposite-direction op (obtained from this same factory)
# to grad_output with the saved weight and no bias.
p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
assert grad_input.shape == input.shape
# Weight gradient is skipped (left as None) while the module-level
# `weight_gradients_disabled` flag is set.
if ctx.needs_input_grad[1] and not weight_gradients_disabled:
grad_weight = Conv2dGradWeight.apply(grad_output, input)
assert grad_weight.shape == weight_shape
if ctx.needs_input_grad[2]:
# Bias gradient: reduce over every dimension except dim 1.
grad_bias = grad_output.sum([0, 2, 3])
return grad_input, grad_weight, grad_bias
# Gradient with respect to the weights.
class Conv2dGradWeight(torch.autograd.Function):
@staticmethod
def forward(ctx, grad_output, input):
# Looks up the cuDNN backward-weight operator by name through a private torch._C hook;
# the operator name depends on whether this configuration is transposed.
op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
assert grad_weight.shape == weight_shape
ctx.save_for_backward(grad_output, input)
return grad_weight
@staticmethod
def backward(ctx, grad2_grad_weight):
# Second-order pass: propagates the gradient of grad_weight back to this Function's
# two inputs (grad_output and input).
grad_output, input = ctx.saved_tensors
grad2_grad_output = None
grad2_input = None
if ctx.needs_input_grad[0]:
# Re-run the forward op with grad2_grad_weight standing in for the weight.
grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
assert grad2_grad_output.shape == grad_output.shape
if ctx.needs_input_grad[1]:
# Same opposite-direction construction as in Conv2d.backward, with
# grad2_grad_weight standing in for the weight.
p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
assert grad2_input.shape == input.shape
return grad2_grad_output, grad2_input
# Remember the generated class so later calls with the same configuration reuse it.
_conv2d_gradfix_cache[key] = Conv2d
return Conv2d
#----------------------------------------------------------------------------
================================================
FILE: stylegan_human/torch_utils/ops/conv2d_resample.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""2D convolution with optional up/downsampling."""
import torch
from .. import misc
from . import conv2d_gradfix
from . import upfirdn2d
from .upfirdn2d import _parse_padding
from .upfirdn2d import _get_filter_size
#----------------------------------------------------------------------------
def _get_weight_shape(w):
    """Return the shape of `w` as a list of plain Python ints.

    The conversion runs with tracer warnings suppressed, and the resulting
    list is checked against `w` via `misc.assert_shape` before it is returned.
    """
    with misc.suppress_tracer_warnings():  # the sizes are meant to be treated as constants
        dims = list(map(int, w.shape))
    misc.assert_shape(w, dims)
    return dims
#----------------------------------------------------------------------------
def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
    """Dispatch to `conv2d_gradfix.conv2d()` or `conv2d_gradfix.conv_transpose2d()`.

    `flip_weight=False` mirrors the kernel along both spatial axes first. A
    special case handles un-strided, un-padded, non-transposed 1x1 kernels on
    inputs whose channel stride is 1 when fewer than 64 channels are involved.
    """
    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)

    # conv2d() computes a correlation (flip_weight=True); a true convolution
    # (flip_weight=False) needs the kernel flipped spatially.
    if not flip_weight:
        w = w.flip([2, 3])

    # Sidestep a cuDNN 8.0.5 performance pitfall that shows up for
    # 1x1 kernel + memory_format=channels_last + fewer than 64 channels.
    plain_1x1 = kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose
    if plain_1x1 and x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
        if out_channels <= 4 and groups == 1:
            # Very few output channels: express the 1x1 conv as a matrix product.
            batch, height, width = x.shape[0], x.shape[2], x.shape[3]
            kernel_matrix = w.squeeze(3).squeeze(2)
            flat = kernel_matrix @ x.reshape([batch, in_channels_per_group, -1])
            x = flat.reshape([batch, out_channels, height, width])
        else:
            x = x.to(memory_format=torch.contiguous_format)
            w = w.to(memory_format=torch.contiguous_format)
            x = conv2d_gradfix.conv2d(x, w, groups=groups)
        return x.to(memory_format=torch.channels_last)

    # Everything else goes straight to conv2d_gradfix.
    if transpose:
        conv_op = conv2d_gradfix.conv_transpose2d
    else:
        conv_op = conv2d_gradfix.conv2d
    return conv_op(x, w, stride=stride, padding=padding, groups=groups)
#----------------------------------------------------------------------------
@misc.profiled_function
def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
    r"""Convolve `x` with `w`, optionally upsampling and/or downsampling.

    All padding is applied once, up front; nothing is padded between the
    individual resampling and convolution steps.

    Args:
        x:              Input tensor of shape
                        `[batch_size, in_channels, in_height, in_width]`.
        w:              Weight tensor of shape
                        `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
        f:              Low-pass filter used for resampling, prepared beforehand
                        with upfirdn2d.setup_filter(). None = identity (default).
        up:             Integer upsampling factor (default: 1).
        down:           Integer downsampling factor (default: 1).
        padding:        Padding relative to the upsampled image: a single number,
                        a list/tuple `[x, y]`, or
                        `[x_before, x_after, y_before, y_after]` (default: 0).
        groups:         Number of groups the input channels are split into (default: 1).
        flip_weight:    True = correlation, False = convolution (default: True).
        flip_filter:    True = correlation, False = convolution (default: False).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    # Argument checks.
    assert isinstance(x, torch.Tensor) and (x.ndim == 4)
    assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
    assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
    assert isinstance(up, int) and (up >= 1)
    assert isinstance(down, int) and (down >= 1)
    assert isinstance(groups, int) and (groups >= 1)
    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
    fw, fh = _get_filter_size(f)
    px0, px1, py0, py1 = _parse_padding(padding)

    # Fold the footprint of the resampling filter into the padding.
    if up > 1:
        px0, px1 = px0 + (fw + up - 1) // 2, px1 + (fw - up) // 2
        py0, py1 = py0 + (fh + up - 1) // 2, py1 + (fh - up) // 2
    if down > 1:
        px0, px1 = px0 + (fw - down + 1) // 2, px1 + (fw - down) // 2
        py0, py1 = py0 + (fh - down + 1) // 2, py1 + (fh - down) // 2

    is_1x1 = kw == 1 and kh == 1
    down_only = down > 1 and up == 1
    up_only = up > 1 and down == 1

    # Fast path: 1x1 kernel, downsampling only => resample first, convolve second.
    if is_1x1 and down_only:
        resampled = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter)
        return _conv2d_wrapper(x=resampled, w=w, groups=groups, flip_weight=flip_weight)

    # Fast path: 1x1 kernel, upsampling only => convolve first, resample second.
    if is_1x1 and up_only:
        convolved = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
        return upfirdn2d.upfirdn2d(x=convolved, f=f, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter)

    # Fast path: downsampling only => filter, then strided convolution.
    if down_only:
        filtered = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter)
        return _conv2d_wrapper(x=filtered, w=w, stride=down, groups=groups, flip_weight=flip_weight)

    # Fast path: upsampling (optionally followed by downsampling) => transposed strided convolution.
    if up > 1:
        # Swap the in/out channel axes of the weights, per group when grouped.
        if groups == 1:
            w = w.transpose(0, 1)
        else:
            w = (
                w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
                .transpose(1, 2)
                .reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
            )
        px0 -= kw - 1
        px1 -= kw - up
        py0 -= kh - 1
        py1 -= kh - up
        # Portion of the (negative) padding that the transposed conv can absorb itself.
        pxt = max(min(-px0, -px1), 0)
        pyt = max(min(-py0, -py1), 0)
        x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt, pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
        x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt], gain=up**2, flip_filter=flip_filter)
        if down > 1:
            x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
        return x

    # Fast path: no resampling and symmetric, non-negative padding => plain conv2d.
    if up == 1 and down == 1 and px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
        return _conv2d_wrapper(x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight)

    # Fallback: generic reference implementation.
    up_filter = f if up > 1 else None
    x = upfirdn2d.upfirdn2d(x=x, f=up_filter, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter)
    x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
    if down > 1:
        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
    return x
#----------------------------------------------------------------------------
================================================
FILE: stylegan_human/torch_utils/ops/filtered_lrelu.cpp
================================================
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include
#include
#include
#include "filtered_lrelu.h"
//------------------------------------------------------------------------
// Host-side entry point for the filtered leaky-ReLU op (exported to Python as "filtered_lrelu"
// by the PYBIND11_MODULE block in this file).
// Validates the tensors, derives the output and sign-tensor sizes, fills in
// filtered_lrelu_kernel_params, then launches the filter setup kernel, the filter copy and the
// main kernel on the current CUDA stream.
// Returns a tuple (y, so, status). status is 0 on success; when no matching kernel exists the two
// tensors are default-constructed and status is -1. `so` is only allocated when writeSigns is set.
// NOTE(review): the return type reads as a bare `std::tuple`, and the repeated
// choose_filtered_lrelu_kernel / copy_filters calls below are textually identical; template
// arguments appear to be missing from this listing - confirm against the upstream file.
static std::tuple filtered_lrelu(
torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
{
// Set CUDA device.
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
// Validate arguments.
TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
TORCH_CHECK(x.numel() > 0, "x is empty");
TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
TORCH_CHECK(fu.numel() > 0, "fu is empty");
TORCH_CHECK(fd.numel() > 0, "fd is empty");
TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");
// Figure out how much shared memory is available on the device.
int maxSharedBytes = 0;
AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
int sharedKB = maxSharedBytes >> 10;
// Populate enough launch parameters to check if a CUDA kernel exists.
filtered_lrelu_kernel_params p;
p.up = up;
p.down = down;
p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel(p, sharedKB);
if (!test_spec.exec)
{
// No kernel found - return empty tensors and indicate missing kernel with return code of -1.
return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
}
// Input/output element size.
int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;
// Input sizes.
// The fut_*/fdt_* values are the filter extents minus one ("taps" beyond the first).
int64_t xw = (int)x.size(3);
int64_t xh = (int)x.size(2);
int64_t fut_w = (int)fu.size(-1) - 1;
int64_t fut_h = (int)fu.size(0) - 1;
int64_t fdt_w = (int)fd.size(-1) - 1;
int64_t fdt_h = (int)fd.size(0) - 1;
// Logical size of upsampled buffer.
int64_t cw = xw * up + (px0 + px1) - fut_w;
int64_t ch = xh * up + (py0 + py1) - fut_h;
TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");
// Compute output size and allocate.
// Ceiling division of the filtered extent by the downsampling factor.
int64_t yw = (cw - fdt_w + (down - 1)) / down;
int64_t yh = (ch - fdt_h + (down - 1)) / down;
TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());
// Allocate sign tensor.
// `s` is the sign tensor the kernel works with: the caller-provided `si` when reading, or the
// freshly allocated `so` when writing. Its last dimension is in bytes, each covering four
// elements (hence the >> 2 / << 2 conversions below).
torch::Tensor so;
torch::Tensor s = si;
bool readSigns = !!s.numel();
int64_t sw_active = 0; // Active width of sign tensor.
if (writeSigns)
{
sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16.
TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
}
else if (readSigns)
sw_active = s.size(3) << 2;
// Validate sign tensor if in use.
if (readSigns || writeSigns)
{
TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
}
// Populate rest of CUDA kernel parameters.
p.x = x.data_ptr();
p.y = y.data_ptr();
p.b = b.data_ptr();
p.s = (readSigns || writeSigns) ? s.data_ptr() : 0;
p.fu = fu.data_ptr();
p.fd = fd.data_ptr();
p.pad0 = make_int2(px0, py0);
p.gain = gain;
p.slope = slope;
p.clamp = clamp;
p.flip = (flip_filters) ? 1 : 0;
// Shapes are packed as (width, height, channels, batch).
p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
p.sOfs = make_int2(sx, sy);
p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.
// x, y, b strides are in bytes.
p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
p.bStride = sz * b.stride(0);
// fu, fd strides are in elements.
p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);
// Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
// The min/max sums bound the most negative and most positive byte offsets reachable in x and y.
bool index64b = false;
if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
if (s.numel() > INT_MAX) index64b = true;
// Choose CUDA kernel.
filtered_lrelu_kernel_spec spec = { 0 };
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
{
if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
{
// Choose kernel based on index type, datatype and sign read/write modes.
// No branch covers writeSigns && readSigns together; in that case spec stays zeroed and the
// TORCH_CHECK after this dispatch fires.
if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
}
});
TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found") // This should not happen because we tested earlier that kernel exists.
// Launch CUDA kernel.
// Grid: one CTA per output tile in x/y, and one z slice per (channel, batch) pair.
void* args[] = {&p};
int bx = spec.numWarps * 32;
int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
int gz = p.yShape.z * p.yShape.w;
// Repeat multiple horizontal tiles in a CTA?
if (spec.xrep)
{
p.tilesXrep = spec.xrep;
p.tilesXdim = gx;
gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
std::swap(gx, gy);
}
else
{
p.tilesXrep = 0;
p.tilesXdim = 0;
}
// Launch filter setup kernel.
// Single block of 1024 threads.
AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));
// Copy kernels to constant memory.
if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream())));
else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream())));
else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream())));
// Set cache and shared memory configurations for main kernel.
AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));
// Launch main kernel.
const int maxSubGz = 65535; // CUDA maximum for grid z dimension.
for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
{
// p is passed by pointer in args, so each launch sees the updated z offset.
p.blockZofs = zofs;
int subGz = std::min(maxSubGz, gz - zofs);
AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
}
// Done.
return std::make_tuple(y, so, 0);
}
//------------------------------------------------------------------------
// Host-side entry point for the activation-only kernel (exported to Python as
// "filtered_lrelu_act_" by the PYBIND11_MODULE block in this file).
// No output data tensor is allocated here: the kernel receives only x and the sign tensor.
// Returns `so`: the newly allocated sign tensor when writeSigns is set, otherwise a
// default-constructed tensor.
// NOTE(review): the three choose_filtered_lrelu_act_kernel() calls below are textually identical;
// template arguments appear to be missing from this listing - confirm against the upstream file.
static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)
{
// Set CUDA device.
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
// Validate arguments.
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
TORCH_CHECK(x.numel() > 0, "x is empty");
TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64");
// Output signs if we don't have sign input.
torch::Tensor so;
torch::Tensor s = si;
bool readSigns = !!s.numel();
if (writeSigns)
{
int64_t sw = x.size(3);
sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
// Last dimension is in bytes: four elements per byte, hence sw >> 2.
s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
}
// Validate sign tensor if in use.
if (readSigns || writeSigns)
{
TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large");
}
// Initialize CUDA kernel parameters.
filtered_lrelu_act_kernel_params p;
p.x = x.data_ptr();
p.s = (readSigns || writeSigns) ? s.data_ptr() : 0;
p.gain = gain;
p.slope = slope;
p.clamp = clamp;
p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
// Strides are passed exactly as x.stride() reports them (not scaled by the element size,
// unlike the byte strides set up in filtered_lrelu()).
p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
p.sOfs = make_int2(sx, sy);
// Choose CUDA kernel.
void* func = 0;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&]
{
// writeSigns takes precedence over readSigns when both are set.
if (writeSigns)
func = choose_filtered_lrelu_act_kernel();
else if (readSigns)
func = choose_filtered_lrelu_act_kernel();
else
func = choose_filtered_lrelu_act_kernel();
});
TORCH_CHECK(func, "internal error - CUDA kernel not found");
// Launch CUDA kernel.
void* args[] = {&p};
int bx = 128; // 4 warps per block.
// Logical size of launch = writeSigns ? p.s : p.x
uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
gx = (gx - 1) / bx + 1;
// Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.
const uint32_t gmax = 65535;
gy = std::min(gy, gmax);
gz = std::min(gz, gmax);
// Launch.
AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));
return so;
}
//------------------------------------------------------------------------
// Python bindings: registers the two host functions defined in this file on the extension module.
// The trailing underscore in "filtered_lrelu_act_" marks the binding that modifies its data
// tensor in place.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
}
//------------------------------------------------------------------------
================================================
FILE: stylegan_human/torch_utils/ops/filtered_lrelu.cu
================================================
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include
#include "filtered_lrelu.h"
#include
//------------------------------------------------------------------------
// Helpers.
// Identifiers for the four combinations of separable/full upsampling and downsampling filters.
// filtered_lrelu_kernel compares its `filterMode` against these values to pick buffer layouts.
enum // Filter modes.
{
MODE_SUSD = 0, // Separable upsampling, separable downsampling.
MODE_FUSD = 1, // Full upsampling, separable downsampling.
MODE_SUFD = 2, // Separable upsampling, full downsampling.
MODE_FUFD = 3, // Full upsampling, full downsampling.
};
// Trait that supplies the scalar and 2-/4-component vector types used for internal computation,
// together with zero-vector constructors and a symmetric clamp to [-c, c].
// NOTE(review): the template parameter list and the specialization arguments are not visible in
// this listing (all three specializations read as bare `InternalType`). The first one computes in
// double and the other two in float - presumably double / float / half data types; confirm
// against the upstream file.
template struct InternalType;
template <> struct InternalType
{
typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t;
__device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); }
__device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); }
__device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); }
};
template <> struct InternalType
{
typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
__device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
__device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
__device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
};
template <> struct InternalType
{
typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
__device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
__device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
__device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
};
// Function-like macros: arguments may be evaluated more than once, so pass side-effect-free
// expressions only.
#define MIN(A, B) ((A) < (B) ? (A) : (B))
#define MAX(A, B) ((A) > (B) ? (A) : (B))
// Ceiling division of A by B. B == 1 returns A unchanged; B == 2 and B == 4 use an add-and-shift
// on the value cast to int; any other B adds (B - 1) before dividing only when A is positive.
#define CEIL_DIV(A, B) (((B)==1) ? (A) : \
((B)==2) ? ((int)((A)+1) >> 1) : \
((B)==4) ? ((int)((A)+3) >> 2) : \
(((A) + ((A) > 0 ? (B) - 1 : 0)) / (B)))
// Splits linear index i into y = i / N and x = i - y * N (the remainder).
// When N is not a power of two and N <= 256, the quotient is computed with a multiply and a
// 24-bit shift instead of a division; that path only works up to blocks of size 256 x 256.
// For every other N (including all powers of two) a plain i / N is used.
// NOTE(review): N is used as a compile-time constant, but the template parameter list is not
// visible in this listing - confirm against the upstream file.
template __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i)
{
if ((N & (N-1)) && N <= 256)
y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256.
else
y = i/N;
x = i - y*N;
}
// Type cast stride before reading it.
// Reads the storage of the 64-bit stride `x` through a reinterpret_cast and returns it as T.
// NOTE(review): the template parameter list and the reinterpret_cast target type are not visible
// in this listing - confirm against the upstream file.
template __device__ __forceinline__ T get_stride(const int64_t& x)
{
return *reinterpret_cast(&x);
}
//------------------------------------------------------------------------
// Filters, setup kernel, copying function.
// Upper bound on the filter extent per axis; each filter occupies a
// MAX_FILTER_SIZE * MAX_FILTER_SIZE slot of floats in the buffers below.
#define MAX_FILTER_SIZE 32
// Combined up/down filter buffers so that transfer can be done with one copy.
__device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel.
__device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel.
// Accessors to combined buffers to index up/down filters individually.
// The upsampling filter is the first slot of each buffer, the downsampling filter the second.
#define c_fu (c_fbuf)
#define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
#define g_fu (g_fbuf)
#define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
// Set up filters into global memory buffer.
// Fills every one of the MAX_FILTER_SIZE * MAX_FILTER_SIZE slots of g_fu and g_fd: entries
// outside the actual filter extent are written as 0, the rest are read from p.fu / p.fd using
// the element strides in p.fuStride / p.fdStride. Threads stride over the slots by blockDim.x.
static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p)
{
for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x)
{
// (x, y) is the position of slot idx within a MAX_FILTER_SIZE-wide grid.
int x, y;
fast_div_mod(x, y, idx);
// With p.flip set the source filter is indexed directly; otherwise it is read in reverse
// order along each axis.
int fu_x = p.flip ? x : (p.fuShape.x - 1 - x);
int fu_y = p.flip ? y : (p.fuShape.y - 1 - y);
// A shape with y > 0 is a full 2D filter; y == 0 is a 1D filter stored in row 0 only.
if (p.fuShape.y > 0)
g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y];
else
g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x];
// Same treatment for the downsampling filter.
int fd_x = p.flip ? x : (p.fdShape.x - 1 - x);
int fd_y = p.flip ? y : (p.fdShape.y - 1 - y);
if (p.fdShape.y > 0)
g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y];
else
g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x];
}
}
// Host function to copy filters written by setup kernel into constant buffer for main kernel.
// Resolves the device address of g_fbuf, then enqueues one device-to-device copy of the whole
// combined buffer (both filters) into c_fbuf on `stream`. Returns the first CUDA error hit.
// NOTE(review): the template parameter list is not visible in this listing and no template
// parameter is referenced in the body - confirm against the upstream file.
template static cudaError_t copy_filters(cudaStream_t stream)
{
void* src = 0;
cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf);
if (err) return err;
return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream);
}
//------------------------------------------------------------------------
// Coordinate spaces:
// - Relative to input tensor: inX, inY, tileInX, tileInY
// - Relative to input tile: relInX, relInY, tileInW, tileInH
// - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH
// - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH
// - Relative to output tensor: outX, outY, tileOutX, tileOutY
//
// Relationships between coordinate spaces:
// - inX = tileInX + relInX
// - inY = tileInY + relInY
// - relUpX = relInX * up + phaseInX
// - relUpY = relInY * up + phaseInY
// - relUpX = relOutX * down
// - relUpY = relOutY * down
// - outX = tileOutX + relOutX
// - outY = tileOutY + relOutY
extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer.
template
static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p)
{
// Check that we don't try to support non-existing filter modes.
static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported");
static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported");
static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor");
static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor");
static_assert(fuSize % up == 0, "upsampling filter size must be divisible with upsampling factor");
static_assert(fdSize % down == 0, "downsampling filter size must be divisible with downsampling factor");
static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE");
static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters");
static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters");
static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4");
static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4");
// Static definitions.
typedef typename InternalType::scalar_t scalar_t;
typedef typename InternalType::vec2_t vec2_t;
typedef typename InternalType::vec4_t vec4_t;
const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4.
const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height.
const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width.
const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height.
const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up.
const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4.
// Merge 1x1 downsampling into last upsampling step for upf1 and ups2.
const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD));
// Sizes of logical buffers.
const int szIn = tileInH_up * tileInW;
const int szUpX = tileInH_up * tileUpW;
const int szUpXY = downInline ? 0 : (tileUpH * tileUpW);
const int szDownX = tileUpH * tileOutW;
// Sizes for shared memory arrays.
const int s_buf0_size_base =
(filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) :
(filterMode == MODE_FUSD) ? MAX(szIn, szDownX) :
(filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) :
(filterMode == MODE_FUFD) ? szIn :
-1;
const int s_buf1_size_base =
(filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) :
(filterMode == MODE_FUSD) ? szUpXY :
(filterMode == MODE_SUFD) ? szUpX :
(filterMode == MODE_FUFD) ? szUpXY :
-1;
// Ensure U128 alignment.
const int s_buf0_size = (s_buf0_size_base + 3) & ~3;
const int s_buf1_size = (s_buf1_size_base + 3) & ~3;
// Check at compile time that we don't use too much shared memory.
static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow");
// Declare shared memory arrays.
scalar_t* s_buf0;
scalar_t* s_buf1;
if (sharedKB <= 48)
{
// Allocate shared memory arrays here.
__shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused.
s_buf0 = s_buf0_st;
s_buf1 = s_buf0 + s_buf0_size;
}
else
{
// Use the dynamically allocated shared memory array.
s_buf0 = (scalar_t*)s_buf_raw;
s_buf1 = s_buf0 + s_buf0_size;
}
// Pointers to the buffers.
scalar_t* s_tileIn; // Input tile: [relInY * tileInW + relInX]
scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX]
scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX]
scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX]
if (filterMode == MODE_SUSD)
{
s_tileIn = s_buf0;
s_tileUpX = s_buf1;
s_tileUpXY = s_buf0;
s_tileDownX = s_buf1;
}
else if (filterMode == MODE_FUSD)
{
s_tileIn = s_buf0;
s_tileUpXY = s_buf1;
s_tileDownX = s_buf0;
}
else if (filterMode == MODE_SUFD)
{
s_tileIn = s_buf0;
s_tileUpX = s_buf1;
s_tileUpXY = s_buf0;
}
else if (filterMode == MODE_FUFD)
{
s_tileIn = s_buf0;
s_tileUpXY = s_buf1;
}
// Allow large grids in z direction via per-launch offset.
int channelIdx = blockIdx.z + p.blockZofs;
int batchIdx = channelIdx / p.yShape.z;
channelIdx -= batchIdx * p.yShape.z;
// Offset to output feature map. In bytes.
index_t mapOfsOut = channelIdx * get_stride(p.yStride.z) + batchIdx * get_stride(p.yStride.w);
// Sign shift amount.
uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6;
// Inner tile loop.
#pragma unroll 1
for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++)
{
// Locate output tile.
int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x;
int tileOutX = tileX * tileOutW;
int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH;
// Locate input tile.
int tmpX = tileOutX * down - p.pad0.x;
int tmpY = tileOutY * down - p.pad0.y;
int tileInX = CEIL_DIV(tmpX, up);
int tileInY = CEIL_DIV(tmpY, up);
const int phaseInX = tileInX * up - tmpX;
const int phaseInY = tileInY * up - tmpY;
// Extra sync if input and output buffers are the same and we are not on first tile.
if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline)))
__syncthreads();
// Load input tile & apply bias. Unrolled.
scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride(p.bStride)));
index_t mapOfsIn = channelIdx * get_stride(p.xStride.z) + batchIdx * get_stride(p.xStride.w);
int idx = threadIdx.x;
const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock);
#pragma unroll
for (int loop = 0; loop < loopCountIN; loop++)
{
int relInX, relInY;
fast_div_mod(relInX, relInY, idx);
int inX = tileInX + relInX;
int inY = tileInY + relInY;
scalar_t v = 0;
if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y)
v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride(p.xStride.x) + inY * get_stride(p.xStride.y) + mapOfsIn))) + b;
bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH);
if (!skip)
s_tileIn[idx] = v;
idx += threadsPerBlock;
}
if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter.
{
// Horizontal upsampling.
__syncthreads();
if (up == 4)
{
for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
{
int relUpX0, relInY;
fast_div_mod(relUpX0, relInY, idx);
int relInX0 = relUpX0 / up;
int src0 = relInX0 + tileInW * relInY;
int dst = relInY * tileUpW + relUpX0;
vec4_t v = InternalType::zero_vec4();
scalar_t a = s_tileIn[src0];
if (phaseInX == 0)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 0];
a = s_tileIn[src0 + step + 1];
v.y += a * (scalar_t)c_fu[step * up + 3];
v.z += a * (scalar_t)c_fu[step * up + 2];
v.w += a * (scalar_t)c_fu[step * up + 1];
}
}
else if (phaseInX == 1)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 1];
v.y += a * (scalar_t)c_fu[step * up + 0];
a = s_tileIn[src0 + step + 1];
v.z += a * (scalar_t)c_fu[step * up + 3];
v.w += a * (scalar_t)c_fu[step * up + 2];
}
}
else if (phaseInX == 2)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 2];
v.y += a * (scalar_t)c_fu[step * up + 1];
v.z += a * (scalar_t)c_fu[step * up + 0];
a = s_tileIn[src0 + step + 1];
v.w += a * (scalar_t)c_fu[step * up + 3];
}
}
else // (phaseInX == 3)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 3];
v.y += a * (scalar_t)c_fu[step * up + 2];
v.z += a * (scalar_t)c_fu[step * up + 1];
v.w += a * (scalar_t)c_fu[step * up + 0];
a = s_tileIn[src0 + step + 1];
}
}
s_tileUpX[dst+0] = v.x;
s_tileUpX[dst+1] = v.y;
s_tileUpX[dst+2] = v.z;
s_tileUpX[dst+3] = v.w;
}
}
else if (up == 2)
{
bool p0 = (phaseInX == 0);
for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
{
int relUpX0, relInY;
fast_div_mod(relUpX0, relInY, idx);
int relInX0 = relUpX0 / up;
int src0 = relInX0 + tileInW * relInY;
int dst = relInY * tileUpW + relUpX0;
vec2_t v = InternalType::zero_vec2();
scalar_t a = s_tileIn[src0];
if (p0) // (phaseInX == 0)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 0];
a = s_tileIn[src0 + step + 1];
v.y += a * (scalar_t)c_fu[step * up + 1];
}
}
else // (phaseInX == 1)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 1];
v.y += a * (scalar_t)c_fu[step * up + 0];
a = s_tileIn[src0 + step + 1];
}
}
s_tileUpX[dst+0] = v.x;
s_tileUpX[dst+1] = v.y;
}
}
// Vertical upsampling & nonlinearity.
__syncthreads();
int groupMask = 15 << ((threadIdx.x & 31) & ~3);
int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes.
if (up == 4)
{
minY -= 3; // Adjust according to block height.
for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
{
int relUpX, relInY0;
fast_div_mod(relUpX, relInY0, idx);
int relUpY0 = relInY0 * up;
int src0 = relInY0 * tileUpW + relUpX;
int dst = relUpY0 * tileUpW + relUpX;
vec4_t v = InternalType::zero_vec4();
scalar_t a = s_tileUpX[src0];
if (phaseInY == 0)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 0];
a = s_tileUpX[src0 + (step + 1) * tileUpW];
v.y += a * (scalar_t)c_fu[step * up + 3];
v.z += a * (scalar_t)c_fu[step * up + 2];
v.w += a * (scalar_t)c_fu[step * up + 1];
}
}
else if (phaseInY == 1)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 1];
v.y += a * (scalar_t)c_fu[step * up + 0];
a = s_tileUpX[src0 + (step + 1) * tileUpW];
v.z += a * (scalar_t)c_fu[step * up + 3];
v.w += a * (scalar_t)c_fu[step * up + 2];
}
}
else if (phaseInY == 2)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 2];
v.y += a * (scalar_t)c_fu[step * up + 1];
v.z += a * (scalar_t)c_fu[step * up + 0];
a = s_tileUpX[src0 + (step + 1) * tileUpW];
v.w += a * (scalar_t)c_fu[step * up + 3];
}
}
else // (phaseInY == 3)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 3];
v.y += a * (scalar_t)c_fu[step * up + 2];
v.z += a * (scalar_t)c_fu[step * up + 1];
v.w += a * (scalar_t)c_fu[step * up + 0];
a = s_tileUpX[src0 + (step + 1) * tileUpW];
}
}
int x = tileOutX * down + relUpX;
int y = tileOutY * down + relUpY0;
int signX = x + p.sOfs.x;
int signY = y + p.sOfs.y;
int signZ = blockIdx.z + p.blockZofs;
int signXb = signX >> 2;
index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
index_t si1 = si0 + p.sShape.x;
index_t si2 = si0 + p.sShape.x * 2;
index_t si3 = si0 + p.sShape.x * 3;
v.x *= (scalar_t)((float)up * (float)up * p.gain);
v.y *= (scalar_t)((float)up * (float)up * p.gain);
v.z *= (scalar_t)((float)up * (float)up * p.gain);
v.w *= (scalar_t)((float)up * (float)up * p.gain);
if (signWrite)
{
if (!enableWriteSkip)
{
// Determine and write signs.
int sx = __float_as_uint(v.x) >> 31 << 0;
int sy = __float_as_uint(v.y) >> 31 << 8;
int sz = __float_as_uint(v.z) >> 31 << 16;
int sw = __float_as_uint(v.w) >> 31 << 24;
if (sx) v.x *= p.slope;
if (sy) v.y *= p.slope;
if (sz) v.z *= p.slope;
if (sw) v.w *= p.slope;
if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); }
if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); }
if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); }
if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); }
if ((uint32_t)signXb < p.swLimit && signY >= minY)
{
// Combine signs.
uint32_t s = sx + sy + sw + sz;
s <<= (signX & 3) << 1;
s |= __shfl_xor_sync(groupMask, s, 1);
s |= __shfl_xor_sync(groupMask, s, 2);
// Write signs.
if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
}
}
else
{
// Determine and write signs.
if ((uint32_t)signXb < p.swLimit && signY >= minY)
{
int sx = __float_as_uint(v.x) >> 31 << 0;
int sy = __float_as_uint(v.y) >> 31 << 8;
int sz = __float_as_uint(v.z) >> 31 << 16;
int sw = __float_as_uint(v.w) >> 31 << 24;
if (sx) v.x *= p.slope;
if (sy) v.y *= p.slope;
if (sz) v.z *= p.slope;
if (sw) v.w *= p.slope;
if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); }
if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); }
if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); }
if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); }
// Combine signs.
uint32_t s = sx + sy + sw + sz;
s <<= (signX & 3) << 1;
s |= __shfl_xor_sync(groupMask, s, 1);
s |= __shfl_xor_sync(groupMask, s, 2);
// Write signs.
if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
}
else
{
// Just compute the values.
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp);
if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp);
}
}
}
else if (signRead) // Read signs and apply.
{
if ((uint32_t)signXb < p.swLimit)
{
int ss = (signX & 3) << 1;
if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; }
if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; }
}
}
else // Forward pass with no sign write.
{
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType