Repository: XingangPan/DragGAN Branch: main Commit: 336f120ce126 Files: 186 Total size: 1.5 MB Directory structure: gitextract_5nncgy8o/ ├── .gitignore ├── Dockerfile ├── LICENSE.txt ├── README.md ├── dnnlib/ │ ├── __init__.py │ └── util.py ├── environment.yml ├── gen_images.py ├── gradio_utils/ │ ├── __init__.py │ └── utils.py ├── gui_utils/ │ ├── __init__.py │ ├── gl_utils.py │ ├── glfw_window.py │ ├── imgui_utils.py │ ├── imgui_window.py │ └── text_utils.py ├── legacy.py ├── requirements.txt ├── stylegan_human/ │ ├── .gitignore │ ├── PP_HumanSeg/ │ │ ├── deploy/ │ │ │ └── infer.py │ │ ├── export_model/ │ │ │ └── download_export_model.py │ │ └── pretrained_model/ │ │ └── download_pretrained_model.py │ ├── README.md │ ├── __init__.py │ ├── alignment.py │ ├── bg_white.py │ ├── dnnlib/ │ │ ├── __init__.py │ │ ├── tflib/ │ │ │ ├── __init__.py │ │ │ ├── autosummary.py │ │ │ ├── custom_ops.py │ │ │ ├── network.py │ │ │ ├── ops/ │ │ │ │ ├── __init__.py │ │ │ │ ├── fused_bias_act.cu │ │ │ │ ├── fused_bias_act.py │ │ │ │ ├── upfirdn_2d.cu │ │ │ │ └── upfirdn_2d.py │ │ │ ├── optimizer.py │ │ │ └── tfutil.py │ │ └── util.py │ ├── docs/ │ │ └── Dataset.md │ ├── edit/ │ │ ├── __init__.py │ │ ├── edit_config.py │ │ └── edit_helper.py │ ├── edit.py │ ├── environment.yml │ ├── generate.py │ ├── insetgan.py │ ├── interpolation.py │ ├── latent_direction/ │ │ └── ss_statics/ │ │ ├── bottom_length_statis/ │ │ │ ├── 3/ │ │ │ │ └── statis.csv │ │ │ ├── 4/ │ │ │ │ └── statis.csv │ │ │ └── 5/ │ │ │ └── statis.csv │ │ └── upper_length_statis/ │ │ └── 5/ │ │ └── statis.csv │ ├── legacy.py │ ├── openpose/ │ │ ├── model/ │ │ │ └── .gitkeep │ │ └── src/ │ │ ├── __init__.py │ │ ├── body.py │ │ ├── model.py │ │ └── util.py │ ├── pti/ │ │ ├── pti_configs/ │ │ │ ├── __init__.py │ │ │ ├── global_config.py │ │ │ ├── hyperparameters.py │ │ │ └── paths_config.py │ │ ├── pti_models/ │ │ │ ├── __init__.py │ │ │ └── e4e/ │ │ │ ├── __init__.py │ │ │ ├── encoders/ │ │ │ │ ├── __init__.py │ │ │ │ ├── helpers.py │ │ │ │ ├── model_irse.py │ │ │ │ └── psp_encoders.py │ │ │ ├── latent_codes_pool.py │ │ │ ├── psp.py │ │ │ └── stylegan2/ │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── op/ │ │ │ ├── __init__.py │ │ │ ├── fused_act.py │ │ │ ├── fused_bias_act.cpp │ │ │ ├── fused_bias_act_kernel.cu │ │ │ ├── upfirdn2d.cpp │ │ │ ├── upfirdn2d.py │ │ │ └── upfirdn2d_kernel.cu │ │ └── training/ │ │ ├── __init__.py │ │ ├── coaches/ │ │ │ ├── __init__.py │ │ │ ├── base_coach.py │ │ │ ├── localitly_regulizer.py │ │ │ ├── multi_id_coach.py │ │ │ └── single_id_coach.py │ │ └── projectors/ │ │ ├── __init__.py │ │ ├── w_plus_projector.py │ │ └── w_projector.py │ ├── run_pti.py │ ├── style_mixing.py │ ├── stylemixing_video.py │ ├── torch_utils/ │ │ ├── __init__.py │ │ ├── custom_ops.py │ │ ├── misc.py │ │ ├── models.py │ │ ├── models_face.py │ │ ├── op_edit/ │ │ │ ├── __init__.py │ │ │ ├── fused_act.py │ │ │ ├── fused_bias_act.cpp │ │ │ ├── fused_bias_act_kernel.cu │ │ │ ├── upfirdn2d.cpp │ │ │ ├── upfirdn2d.py │ │ │ └── upfirdn2d_kernel.cu │ │ ├── ops/ │ │ │ ├── __init__.py │ │ │ ├── bias_act.cpp │ │ │ ├── bias_act.cu │ │ │ ├── bias_act.h │ │ │ ├── bias_act.py │ │ │ ├── conv2d_gradfix.py │ │ │ ├── conv2d_resample.py │ │ │ ├── filtered_lrelu.cpp │ │ │ ├── filtered_lrelu.cu │ │ │ ├── filtered_lrelu.h │ │ │ ├── filtered_lrelu.py │ │ │ ├── filtered_lrelu_ns.cu │ │ │ ├── filtered_lrelu_rd.cu │ │ │ ├── filtered_lrelu_wr.cu │ │ │ ├── fma.py │ │ │ ├── grid_sample_gradfix.py │ │ │ ├── upfirdn2d.cpp │ │ │ ├── upfirdn2d.cu │ │ │ ├── upfirdn2d.h │ │ │ └── 
upfirdn2d.py │ │ ├── persistence.py │ │ └── training_stats.py │ ├── training/ │ │ ├── __init__.py │ │ ├── augment.py │ │ ├── dataset.py │ │ ├── loss.py │ │ ├── networks_stylegan2.py │ │ ├── networks_stylegan3.py │ │ └── training_loop.py │ ├── training_scripts/ │ │ ├── sg2/ │ │ │ ├── train.py │ │ │ └── training/ │ │ │ ├── dataset.py │ │ │ └── networks.py │ │ └── sg3/ │ │ ├── train.py │ │ └── training/ │ │ ├── dataset.py │ │ ├── networks_stylegan2.py │ │ └── networks_stylegan3.py │ └── utils/ │ ├── ImagesDataset.py │ ├── __init__.py │ ├── data_utils.py │ ├── face_alignment.py │ ├── log_utils.py │ ├── models_utils.py │ └── util.py ├── torch_utils/ │ ├── __init__.py │ ├── custom_ops.py │ ├── misc.py │ ├── ops/ │ │ ├── __init__.py │ │ ├── bias_act.cpp │ │ ├── bias_act.cu │ │ ├── bias_act.h │ │ ├── bias_act.py │ │ ├── conv2d_gradfix.py │ │ ├── conv2d_resample.py │ │ ├── filtered_lrelu.cpp │ │ ├── filtered_lrelu.cu │ │ ├── filtered_lrelu.h │ │ ├── filtered_lrelu.py │ │ ├── filtered_lrelu_ns.cu │ │ ├── filtered_lrelu_rd.cu │ │ ├── filtered_lrelu_wr.cu │ │ ├── fma.py │ │ ├── grid_sample_gradfix.py │ │ ├── upfirdn2d.cpp │ │ ├── upfirdn2d.cu │ │ ├── upfirdn2d.h │ │ └── upfirdn2d.py │ ├── persistence.py │ └── training_stats.py ├── training/ │ ├── __init__.py │ ├── augment.py │ ├── dataset.py │ ├── loss.py │ ├── networks_stylegan2.py │ ├── networks_stylegan3.py │ └── training_loop.py ├── visualizer_drag.py ├── visualizer_drag_gradio.py └── viz/ ├── __init__.py ├── capture_widget.py ├── drag_widget.py ├── latent_widget.py ├── pickle_widget.py └── renderer.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Created by .ignore support plugin (hsz.mobi) ### Python template # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python env/ build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ *.egg-info/ .installed.cfg *.egg # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover .hypothesis/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # IPython Notebook .ipynb_checkpoints # pyenv .python-version # celery beat schedule file celerybeat-schedule # dotenv .env # virtualenv venv/ ENV/ # Spyder project settings .spyderproject # Rope project settings .ropeproject ### VirtualEnv template # Virtualenv # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ .Python [Bb]in [Ii]nclude [Ll]ib [Ll]ib64 [Ll]ocal [Ss]cripts !scripts/ pyvenv.cfg .venv pip-selfcheck.json ### JetBrains template # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 # User-specific stuff: .idea/workspace.xml .idea/tasks.xml .idea/dictionaries .idea/vcs.xml .idea/jsLibraryMappings.xml # Sensitive or high-churn files: .idea/dataSources.ids .idea/dataSources.xml .idea/dataSources.local.xml .idea/sqlDataSources.xml .idea/dynamic.xml .idea/uiDesigner.xml # Gradle: .idea/gradle.xml .idea/libraries # Mongo Explorer plugin: .idea/mongoSettings.xml .idea/ ## File-based project format: *.iws ## Plugin-specific files: # IntelliJ /out/ # mpeltonen/sbt-idea plugin .idea_modules/ # JIRA plugin atlassian-ide-plugin.xml # Crashlytics plugin (for Android Studio and IntelliJ) com_crashlytics_export_strings.xml crashlytics.properties crashlytics-build.properties fabric.properties # Mac related .DS_Store checkpoints ================================================ FILE: Dockerfile ================================================ FROM nvcr.io/nvidia/pytorch:23.05-py3 ENV PYTHONDONTWRITEBYTECODE 1 ENV PYTHONUNBUFFERED 1 RUN apt-get update && apt-get install -y --no-install-recommends \ make \ pkgconf \ xz-utils \ xorg-dev \ libgl1-mesa-dev \ libglu1-mesa-dev \ libxrandr-dev \ libxinerama-dev \ libxcursor-dev \ libxi-dev \ libxxf86vm-dev \ && rm -rf /var/lib/apt/lists/* RUN pip install --no-cache-dir --upgrade pip COPY requirements.txt . RUN pip install --no-cache-dir -r requirements.txt WORKDIR /workspace RUN (printf '#!/bin/bash\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh ENTRYPOINT ["/entry.sh"] ================================================ FILE: LICENSE.txt ================================================ Copyright (c) 2021, NVIDIA Corporation & affiliates. All rights reserved. NVIDIA Source Code License for StyleGAN3 ======================================================================= 1. Definitions "Licensor" means any person or entity that distributes its Work. "Software" means the original work of authorship made available under this License. "Work" means the Software and any additions to or derivative works of the Software that are made available under this License. The terms "reproduce," "reproduction," "derivative works," and "distribution" have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
Works, including the Software, are "made available" under this License by including in or with the Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License. 2. License Grants 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 3. Limitations 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work. 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work ("Your Terms") only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself. 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA and its affiliates may use the Work and any derivative works commercially. As used herein, "non-commercially" means for research or evaluation purposes only. 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this License from such Licensor (including the grant in Section 2.1) will terminate immediately. 3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this License. 3.6 Termination. If you violate any term of this License, then your rights under this License (including the grant in Section 2.1) will terminate immediately. 4. Disclaimer of Warranty. THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 5. Limitation of Liability. EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. ======================================================================= ================================================ FILE: README.md ================================================

# Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold

Xingang Pan · Ayush Tewari · Thomas Leimkühler · Lingjie Liu · Abhimitra Meka · Christian Theobalt

SIGGRAPH 2023 Conference Proceedings



## Web Demos

[![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/XingangPan/DragGAN)

A Hugging Face demo is also available.

## Requirements

If you have a CUDA-capable graphics card, please follow the requirements of [NVlabs/stylegan3](https://github.com/NVlabs/stylegan3#requirements). The usual installation steps are the following commands, which set up the correct CUDA version and all the required Python packages:

```
conda env create -f environment.yml
conda activate stylegan3
```

Then install the additional requirements:

```
pip install -r requirements.txt
```

Otherwise (for GPU acceleration on macOS with Apple Silicon M1/M2, or for CPU only), try the following:

```sh
cat environment.yml | \
  grep -v -E 'nvidia|cuda' > environment-no-nvidia.yml && \
  conda env create -f environment-no-nvidia.yml
conda activate stylegan3

# On macOS
export PYTORCH_ENABLE_MPS_FALLBACK=1
```

## Run Gradio visualizer in Docker

The provided Docker image is based on the NGC PyTorch repository. To quickly try out the visualizer in Docker, run the following:

```sh
# Before you build the Docker container, make sure you have cloned this repo and downloaded the pretrained model with `python scripts/download_model.py`.
docker build . -t draggan:latest
docker run -p 7860:7860 -v "$PWD":/workspace/src -it draggan:latest bash
# (Use GPU) If you want to use your NVIDIA GPU for acceleration inside Docker, add the flag `--gpus all`, e.g.:
# docker run --gpus all -p 7860:7860 -v "$PWD":/workspace/src -it draggan:latest bash
cd src && python visualizer_drag_gradio.py --listen
```

Now you can open a shared link from Gradio (printed in the terminal console). Beware that the Docker image takes about 25 GB of disk space!

## Download pre-trained StyleGAN2 weights

To download pre-trained weights, simply run:

```
python scripts/download_model.py
```

If you want to try StyleGAN-Human and the Landscapes HQ (LHQ) dataset, please download the weights from these links: [StyleGAN-Human](https://drive.google.com/file/d/1dlFEHbu-WzQWJl7nBBZYcTyo000H9hVm/view?usp=sharing), [LHQ](https://drive.google.com/file/d/16twEf0T9QINAEoMsWefoWiyhcTd-aiWc/view?usp=sharing), and put them under `./checkpoints`. Feel free to try other pretrained StyleGAN models.

## Run DragGAN GUI

To start the DragGAN GUI, simply run:

```sh
sh scripts/gui.sh
```

If you are using Windows, you can run:

```
.\scripts\gui.bat
```

This GUI supports editing GAN-generated images. To edit a real image, you need to first perform GAN inversion using tools like [PTI](https://github.com/danielroich/PTI), then load the new latent code and model weights into the GUI (a minimal loading sketch is given near the end of this README).

You can also run the DragGAN Gradio demo, which works on both Windows and Linux:

```sh
python visualizer_drag_gradio.py
```

## Acknowledgement

This code is developed based on [StyleGAN3](https://github.com/NVlabs/stylegan3). Part of the code is borrowed from [StyleGAN-Human](https://github.com/stylegan-human/StyleGAN-Human). (Cheers to the community as well!)

## License

The code related to the DragGAN algorithm is licensed under [CC-BY-NC](https://creativecommons.org/licenses/by-nc/4.0/). However, most of this project is available under separate license terms: all code used or modified from [StyleGAN3](https://github.com/NVlabs/stylegan3) is under the [NVIDIA Source Code License](https://github.com/NVlabs/stylegan3/blob/main/LICENSE.txt). Any form of use and derivative of this code must preserve the watermarking functionality showing "AI Generated".
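## Loading an inverted latent (sketch)

For reference, here is a minimal sketch of how a checkpoint and latent code produced by GAN inversion might be loaded outside the GUI, following the same loading pattern as `gen_images.py`. The file paths and the saved-latent format below are illustrative assumptions, not something this repo prescribes:

```python
import torch
import dnnlib
import legacy

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the (hypothetical) PTI-tuned generator pickle, exactly as gen_images.py does.
with dnnlib.util.open_url('./checkpoints/model_pti.pkl') as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device)

# Load a saved latent code. Assumption: a W+ tensor of shape [1, G.num_ws, G.w_dim]
# saved with torch.save; adjust to whatever your inversion tool produces.
ws = torch.load('./checkpoints/latent_w.pt', map_location=device)
assert ws.shape == (1, G.num_ws, G.w_dim)

# Synthesize directly from W+, bypassing the mapping network.
img = G.synthesis(ws, noise_mode='const')  # [1, C, H, W], values roughly in [-1, 1]
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
```

From here the image tensor can be saved with PIL exactly as `gen_images.py` does.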
## BibTeX

```bibtex
@inproceedings{pan2023draggan,
    title={Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold},
    author={Pan, Xingang and Tewari, Ayush and Leimk{\"u}hler, Thomas and Liu, Lingjie and Meka, Abhimitra and Theobalt, Christian},
    booktitle = {ACM SIGGRAPH 2023 Conference Proceedings},
    year={2023}
}
```

================================================ FILE: dnnlib/__init__.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. from .util import EasyDict, make_cache_dir_path
================================================ FILE: dnnlib/util.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Miscellaneous utility classes and functions.""" import ctypes import fnmatch import importlib import inspect import numpy as np import os import shutil import sys import types import io import pickle import re import requests import html import hashlib import glob import tempfile import urllib import urllib.request import uuid from distutils.util import strtobool from typing import Any, List, Tuple, Union # Util classes # ------------------------------------------------------------------------------------------ class EasyDict(dict): """Convenience class that behaves like a dict but allows access with the attribute syntax.""" def __getattr__(self, name: str) -> Any: try: return self[name] except KeyError: raise AttributeError(name) def __setattr__(self, name: str, value: Any) -> None: self[name] = value def __delattr__(self, name: str) -> None: del self[name] class Logger(object): """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True): self.file = None if file_name is not None: self.file = open(file_name, file_mode) self.should_flush = should_flush self.stdout = sys.stdout self.stderr = sys.stderr sys.stdout = self sys.stderr = self def __enter__(self) -> "Logger": return self def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: self.close() def write(self, text: Union[str, bytes]) -> None: """Write text to stdout (and a file) and optionally flush.""" if isinstance(text, bytes): text = text.decode() if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash return if self.file is not None: self.file.write(text) self.stdout.write(text) if self.should_flush: self.flush() def flush(self) -> None: """Flush written text to both stdout and a file, if open.""" if self.file is not None: self.file.flush() self.stdout.flush() def close(self) -> None: """Flush, close possible
files, and remove stdout/stderr mirroring.""" self.flush() # if using multiple loggers, prevent closing in wrong order if sys.stdout is self: sys.stdout = self.stdout if sys.stderr is self: sys.stderr = self.stderr if self.file is not None: self.file.close() self.file = None # Cache directories # ------------------------------------------------------------------------------------------ _dnnlib_cache_dir = None def set_cache_dir(path: str) -> None: global _dnnlib_cache_dir _dnnlib_cache_dir = path def make_cache_dir_path(*paths: str) -> str: if _dnnlib_cache_dir is not None: return os.path.join(_dnnlib_cache_dir, *paths) if 'DNNLIB_CACHE_DIR' in os.environ: return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) if 'HOME' in os.environ: return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths) if 'USERPROFILE' in os.environ: return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) # Small util functions # ------------------------------------------------------------------------------------------ def format_time(seconds: Union[int, float]) -> str: """Convert the seconds to human readable string with days, hours, minutes and seconds.""" s = int(np.rint(seconds)) if s < 60: return "{0}s".format(s) elif s < 60 * 60: return "{0}m {1:02}s".format(s // 60, s % 60) elif s < 24 * 60 * 60: return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) else: return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) def format_time_brief(seconds: Union[int, float]) -> str: """Convert the seconds to human readable string with days, hours, minutes and seconds.""" s = int(np.rint(seconds)) if s < 60: return "{0}s".format(s) elif s < 60 * 60: return "{0}m {1:02}s".format(s // 60, s % 60) elif s < 24 * 60 * 60: return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60) else: return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24) def ask_yes_no(question: str) -> bool: """Ask the user the question until the user inputs a valid answer.""" while True: try: print("{0} [y/n]".format(question)) return strtobool(input().lower()) except ValueError: pass def tuple_product(t: Tuple) -> Any: """Calculate the product of the tuple elements.""" result = 1 for v in t: result *= v return result _str_to_ctype = { "uint8": ctypes.c_ubyte, "uint16": ctypes.c_uint16, "uint32": ctypes.c_uint32, "uint64": ctypes.c_uint64, "int8": ctypes.c_byte, "int16": ctypes.c_int16, "int32": ctypes.c_int32, "int64": ctypes.c_int64, "float32": ctypes.c_float, "float64": ctypes.c_double } def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" type_str = None if isinstance(type_obj, str): type_str = type_obj elif hasattr(type_obj, "__name__"): type_str = type_obj.__name__ elif hasattr(type_obj, "name"): type_str = type_obj.name else: raise RuntimeError("Cannot infer type name from input") assert type_str in _str_to_ctype.keys() my_dtype = np.dtype(type_str) my_ctype = _str_to_ctype[type_str] assert my_dtype.itemsize == ctypes.sizeof(my_ctype) return my_dtype, my_ctype def is_pickleable(obj: Any) -> bool: try: with io.BytesIO() as stream: pickle.dump(obj, stream) return True except: return False # Functionality to import modules/objects by name, and call functions by name # 
------------------------------------------------------------------------------------------ def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: """Searches for the underlying module behind the name to some python object. Returns the module and the object name (original name with module part removed).""" # allow convenience shorthands, substitute them by full names obj_name = re.sub("^np.", "numpy.", obj_name) obj_name = re.sub("^tf.", "tensorflow.", obj_name) # list alternatives for (module_name, local_obj_name) parts = obj_name.split(".") name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] # try each alternative in turn for module_name, local_obj_name in name_pairs: try: module = importlib.import_module(module_name) # may raise ImportError get_obj_from_module(module, local_obj_name) # may raise AttributeError return module, local_obj_name except: pass # maybe some of the modules themselves contain errors? for module_name, _local_obj_name in name_pairs: try: importlib.import_module(module_name) # may raise ImportError except ImportError: if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): raise # maybe the requested attribute is missing? for module_name, local_obj_name in name_pairs: try: module = importlib.import_module(module_name) # may raise ImportError get_obj_from_module(module, local_obj_name) # may raise AttributeError except ImportError: pass # we are out of luck, but we have no idea why raise ImportError(obj_name) def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: """Traverses the object name and returns the last (rightmost) python object.""" if obj_name == '': return module obj = module for part in obj_name.split("."): obj = getattr(obj, part) return obj def get_obj_by_name(name: str) -> Any: """Finds the python object with the given name.""" module, obj_name = get_module_from_obj_name(name) return get_obj_from_module(module, obj_name) def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: """Finds the python object with the given name and calls it as a function.""" assert func_name is not None func_obj = get_obj_by_name(func_name) assert callable(func_obj) return func_obj(*args, **kwargs) def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: """Finds the python class with the given name and constructs it with the given arguments.""" return call_func_by_name(*args, func_name=class_name, **kwargs) def get_module_dir_by_obj_name(obj_name: str) -> str: """Get the directory path of the module containing the given object name.""" module, _ = get_module_from_obj_name(obj_name) return os.path.dirname(inspect.getfile(module)) def is_top_level_function(obj: Any) -> bool: """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ def get_top_level_function_name(obj: Any) -> str: """Return the fully-qualified name of a top-level function.""" assert is_top_level_function(obj) module = obj.__module__ if module == '__main__': module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] return module + "." 
+ obj.__name__ # File system helpers # ------------------------------------------------------------------------------------------ def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: """List all files recursively in a given directory while ignoring given file and directory names. Returns list of tuples containing both absolute and relative paths.""" assert os.path.isdir(dir_path) base_name = os.path.basename(os.path.normpath(dir_path)) if ignores is None: ignores = [] result = [] for root, dirs, files in os.walk(dir_path, topdown=True): for ignore_ in ignores: dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] # dirs need to be edited in-place for d in dirs_to_remove: dirs.remove(d) files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] absolute_paths = [os.path.join(root, f) for f in files] relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] if add_base_to_relative: relative_paths = [os.path.join(base_name, p) for p in relative_paths] assert len(absolute_paths) == len(relative_paths) result += zip(absolute_paths, relative_paths) return result def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: """Takes in a list of tuples of (src, dst) paths and copies files. Will create all necessary directories.""" for file in files: target_dir_name = os.path.dirname(file[1]) # will create all intermediate-level directories if not os.path.exists(target_dir_name): os.makedirs(target_dir_name) shutil.copyfile(file[0], file[1]) # URL helpers # ------------------------------------------------------------------------------------------ def is_url(obj: Any, allow_file_urls: bool = False) -> bool: """Determine whether the given object is a valid URL string.""" if not isinstance(obj, str) or not "://" in obj: return False if allow_file_urls and obj.startswith('file://'): return True try: res = requests.compat.urlparse(obj) if not res.scheme or not res.netloc or not "." in res.netloc: return False res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) if not res.scheme or not res.netloc or not "." in res.netloc: return False except: return False return True def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any: """Download the given URL and return a binary-mode file object to access the data.""" assert num_attempts >= 1 assert not (return_filename and (not cache)) # Doesn't look like an URL scheme so interpret it as a local filename. if not re.match('^[a-z]+://', url): return url if return_filename else open(url, "rb") # Handle file URLs. This code handles unusual file:// patterns that # arise on Windows: # # file:///c:/foo.txt # # which would translate to a local '/c:/foo.txt' filename that's # invalid. Drop the forward slash for such pathnames. # # If you touch this code path, you should test it on both Linux and # Windows. # # Some internet resources suggest using urllib.request.url2pathname() but # but that converts forward slashes to backslashes and this causes # its own set of problems. if url.startswith('file://'): filename = urllib.parse.urlparse(url).path if re.match(r'^/[a-zA-Z]:', filename): filename = filename[1:] return filename if return_filename else open(filename, "rb") assert is_url(url) # Lookup from cache. 
if cache_dir is None: cache_dir = make_cache_dir_path('downloads') url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() if cache: cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) if len(cache_files) == 1: filename = cache_files[0] return filename if return_filename else open(filename, "rb") # Download. url_name = None url_data = None with requests.Session() as session: if verbose: print("Downloading %s ..." % url, end="", flush=True) for attempts_left in reversed(range(num_attempts)): try: with session.get(url) as res: res.raise_for_status() if len(res.content) == 0: raise IOError("No data received") if len(res.content) < 8192: content_str = res.content.decode("utf-8") if "download_warning" in res.headers.get("Set-Cookie", ""): links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] if len(links) == 1: url = requests.compat.urljoin(url, links[0]) raise IOError("Google Drive virus checker nag") if "Google Drive - Quota exceeded" in content_str: raise IOError("Google Drive download quota exceeded -- please try again later") match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) url_name = match[1] if match else url url_data = res.content if verbose: print(" done") break except KeyboardInterrupt: raise except: if not attempts_left: if verbose: print(" failed") raise if verbose: print(".", end="", flush=True) # Save to cache. if cache: safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) os.makedirs(cache_dir, exist_ok=True) with open(temp_file, "wb") as f: f.write(url_data) os.replace(temp_file, cache_file) # atomic if return_filename: return cache_file # Return data as file object. assert not return_filename return io.BytesIO(url_data) ================================================ FILE: environment.yml ================================================ name: stylegan3 channels: - pytorch - nvidia dependencies: - python >= 3.8 - pip - numpy>=1.25 - click>=8.0 - pillow=9.4.0 - scipy=1.11.1 - pytorch>=2.0.1 - torchvision>=0.15.2 - cudatoolkit=11.1 - requests=2.26.0 - tqdm=4.62.2 - ninja=1.10.2 - matplotlib=3.4.2 - imageio=2.9.0 - pip: - imgui==2.0.0 - glfw==2.6.1 - gradio==3.35.2 - pyopengl==3.1.5 - imageio-ffmpeg==0.4.3 # pyspng is currently broken on MacOS (see https://github.com/nurpax/pyspng/pull/6 for instance) - pyspng-seunglab ================================================ FILE: gen_images.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Generate images using pretrained network pickle.""" import os import re from typing import List, Optional, Tuple, Union import click import dnnlib import numpy as np import PIL.Image import torch import legacy #---------------------------------------------------------------------------- def parse_range(s: Union[str, List]) -> List[int]: '''Parse a comma separated list of numbers or ranges and return a list of ints. 
Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] ''' if isinstance(s, list): return s ranges = [] range_re = re.compile(r'^(\d+)-(\d+)$') for p in s.split(','): m = range_re.match(p) if m: ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) else: ranges.append(int(p)) return ranges #---------------------------------------------------------------------------- def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]: '''Parse a floating point 2-vector of syntax 'a,b'. Example: '0,1' returns (0,1) ''' if isinstance(s, tuple): return s parts = s.split(',') if len(parts) == 2: return (float(parts[0]), float(parts[1])) raise ValueError(f'cannot parse 2-vector {s}') #---------------------------------------------------------------------------- def make_transform(translate: Tuple[float,float], angle: float): m = np.eye(3) s = np.sin(angle/360.0*np.pi*2) c = np.cos(angle/360.0*np.pi*2) m[0][0] = c m[0][1] = s m[0][2] = translate[0] m[1][0] = -s m[1][1] = c m[1][2] = translate[1] return m #---------------------------------------------------------------------------- @click.command() @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) @click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True) @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)') @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True) @click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2') @click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE') @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR') def generate_images( network_pkl: str, seeds: List[int], truncation_psi: float, noise_mode: str, outdir: str, translate: Tuple[float,float], rotate: float, class_idx: Optional[int] ): """Generate images using pretrained network pickle. Examples: \b # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left). python gen_images.py --outdir=out --trunc=1 --seeds=2 \\ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl \b # Generate uncurated images with truncation using the MetFaces-U dataset python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl """ print('Loading networks from "%s"...' % network_pkl) device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu') dtype = torch.float32 if device.type == 'mps' else torch.float64 with dnnlib.util.open_url(network_pkl) as f: G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype) # type: ignore # import pickle # G = legacy.load_network_pkl(f) # output = open('checkpoints/stylegan2-car-config-f-pt.pkl', 'wb') # pickle.dump(G, output) os.makedirs(outdir, exist_ok=True) # Labels.
label = torch.zeros([1, G.c_dim], device=device) if G.c_dim != 0: if class_idx is None: raise click.ClickException('Must specify class label with --class when using a conditional network') label[:, class_idx] = 1 else: if class_idx is not None: print('warn: --class=lbl ignored when running on an unconditional network') # Generate images. for seed_idx, seed in enumerate(seeds): print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds))) z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device, dtype=dtype) # Construct an inverse rotation/translation matrix and pass to the generator. The # generator expects this matrix as an inverse to avoid potentially failing numerical # operations in the network. if hasattr(G.synthesis, 'input'): m = make_transform(translate, rotate) m = np.linalg.inv(m) G.synthesis.input.transform.copy_(torch.from_numpy(m)) img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode) img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png') #---------------------------------------------------------------------------- if __name__ == "__main__": generate_images() # pylint: disable=no-value-for-parameter #---------------------------------------------------------------------------- ================================================ FILE: gradio_utils/__init__.py ================================================ from .utils import (ImageMask, draw_mask_on_image, draw_points_on_image, get_latest_points_pair, get_valid_mask, on_change_single_global_state) __all__ = [ 'draw_mask_on_image', 'draw_points_on_image', 'on_change_single_global_state', 'get_latest_points_pair', 'get_valid_mask', 'ImageMask' ] ================================================ FILE: gradio_utils/utils.py ================================================ import gradio as gr import numpy as np from PIL import Image, ImageDraw class ImageMask(gr.components.Image): """ Sets: source="upload", tool="sketch" """ is_template = True def __init__(self, **kwargs): super().__init__(source="upload", tool="sketch", interactive=False, **kwargs) def preprocess(self, x): if x is None: return x if self.tool == "sketch" and self.source in ["upload", "webcam"] and type(x) != dict: decode_image = gr.processing_utils.decode_base64_to_image(x) width, height = decode_image.size mask = np.ones((height, width, 4), dtype=np.uint8) mask[..., -1] = 255 mask = self.postprocess(mask) x = {'image': x, 'mask': mask} return super().preprocess(x) def get_valid_mask(mask: np.ndarray): """Convert mask from gr.Image(0 to 255, RGBA) to binary mask.
""" if mask.ndim == 3: mask_pil = Image.fromarray(mask).convert('L') mask = np.array(mask_pil) if mask.max() == 255: mask = mask / 255 return mask def draw_points_on_image(image, points, curr_point=None, highlight_all=True, radius_scale=0.01): overlay_rgba = Image.new("RGBA", image.size, 0) overlay_draw = ImageDraw.Draw(overlay_rgba) for point_key, point in points.items(): if ((curr_point is not None and curr_point == point_key) or highlight_all): p_color = (255, 0, 0) t_color = (0, 0, 255) else: p_color = (255, 0, 0, 35) t_color = (0, 0, 255, 35) rad_draw = int(image.size[0] * radius_scale) p_start = point.get("start_temp", point["start"]) p_target = point["target"] if p_start is not None and p_target is not None: p_draw = int(p_start[0]), int(p_start[1]) t_draw = int(p_target[0]), int(p_target[1]) overlay_draw.line( (p_draw[0], p_draw[1], t_draw[0], t_draw[1]), fill=(255, 255, 0), width=2, ) if p_start is not None: p_draw = int(p_start[0]), int(p_start[1]) overlay_draw.ellipse( ( p_draw[0] - rad_draw, p_draw[1] - rad_draw, p_draw[0] + rad_draw, p_draw[1] + rad_draw, ), fill=p_color, ) if curr_point is not None and curr_point == point_key: # overlay_draw.text(p_draw, "p", font=font, align="center", fill=(0, 0, 0)) overlay_draw.text(p_draw, "p", align="center", fill=(0, 0, 0)) if p_target is not None: t_draw = int(p_target[0]), int(p_target[1]) overlay_draw.ellipse( ( t_draw[0] - rad_draw, t_draw[1] - rad_draw, t_draw[0] + rad_draw, t_draw[1] + rad_draw, ), fill=t_color, ) if curr_point is not None and curr_point == point_key: # overlay_draw.text(t_draw, "t", font=font, align="center", fill=(0, 0, 0)) overlay_draw.text(t_draw, "t", align="center", fill=(0, 0, 0)) return Image.alpha_composite(image.convert("RGBA"), overlay_rgba).convert("RGB") def draw_mask_on_image(image, mask): im_mask = np.uint8(mask * 255) im_mask_rgba = np.concatenate( ( np.tile(im_mask[..., None], [1, 1, 3]), 45 * np.ones( (im_mask.shape[0], im_mask.shape[1], 1), dtype=np.uint8), ), axis=-1, ) im_mask_rgba = Image.fromarray(im_mask_rgba).convert("RGBA") return Image.alpha_composite(image.convert("RGBA"), im_mask_rgba).convert("RGB") def on_change_single_global_state(keys, value, global_state, map_transform=None): if map_transform is not None: value = map_transform(value) curr_state = global_state if isinstance(keys, str): last_key = keys else: for k in keys[:-1]: curr_state = curr_state[k] last_key = keys[-1] curr_state[last_key] = value return global_state def get_latest_points_pair(points_dict): if not points_dict: return None point_idx = list(points_dict.keys()) latest_point_idx = max(point_idx) return latest_point_idx ================================================ FILE: gui_utils/__init__.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. # empty ================================================ FILE: gui_utils/gl_utils.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. import math import os import functools import contextlib import numpy as np import OpenGL.GL as gl import OpenGL.GL.ARB.texture_float import dnnlib #---------------------------------------------------------------------------- def init_egl(): assert os.environ['PYOPENGL_PLATFORM'] == 'egl' # Must be set before importing OpenGL. import OpenGL.EGL as egl import ctypes # Initialize EGL. display = egl.eglGetDisplay(egl.EGL_DEFAULT_DISPLAY) assert display != egl.EGL_NO_DISPLAY major = ctypes.c_int32() minor = ctypes.c_int32() ok = egl.eglInitialize(display, major, minor) assert ok assert major.value * 10 + minor.value >= 14 # Choose config. config_attribs = [ egl.EGL_RENDERABLE_TYPE, egl.EGL_OPENGL_BIT, egl.EGL_SURFACE_TYPE, egl.EGL_PBUFFER_BIT, egl.EGL_NONE ] configs = (ctypes.c_int32 * 1)() num_configs = ctypes.c_int32() ok = egl.eglChooseConfig(display, config_attribs, configs, 1, num_configs) assert ok assert num_configs.value == 1 config = configs[0] # Create dummy pbuffer surface. surface_attribs = [ egl.EGL_WIDTH, 1, egl.EGL_HEIGHT, 1, egl.EGL_NONE ] surface = egl.eglCreatePbufferSurface(display, config, surface_attribs) assert surface != egl.EGL_NO_SURFACE # Setup GL context. ok = egl.eglBindAPI(egl.EGL_OPENGL_API) assert ok context = egl.eglCreateContext(display, config, egl.EGL_NO_CONTEXT, None) assert context != egl.EGL_NO_CONTEXT ok = egl.eglMakeCurrent(display, surface, surface, context) assert ok #---------------------------------------------------------------------------- _texture_formats = { ('uint8', 1): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE, internalformat=gl.GL_LUMINANCE8), ('uint8', 2): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE_ALPHA, internalformat=gl.GL_LUMINANCE8_ALPHA8), ('uint8', 3): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGB, internalformat=gl.GL_RGB8), ('uint8', 4): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGBA, internalformat=gl.GL_RGBA8), ('float32', 1): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_LUMINANCE, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE32F_ARB), ('float32', 2): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_LUMINANCE_ALPHA, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE_ALPHA32F_ARB), ('float32', 3): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_RGB, internalformat=gl.GL_RGB32F), ('float32', 4): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_RGBA, internalformat=gl.GL_RGBA32F), } def get_texture_format(dtype, channels): return _texture_formats[(np.dtype(dtype).name, int(channels))] #---------------------------------------------------------------------------- def prepare_texture_data(image): image = np.asarray(image) if image.ndim == 2: image = image[:, :, np.newaxis] if image.dtype.name == 'float64': image = image.astype('float32') return image #---------------------------------------------------------------------------- def draw_pixels(image, *, pos=0, zoom=1, align=0, rint=True): pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2]) zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2]) align = np.broadcast_to(np.asarray(align, dtype='float32'), [2]) image = 
prepare_texture_data(image) height, width, channels = image.shape size = zoom * [width, height] pos = pos - size * align if rint: pos = np.rint(pos) fmt = get_texture_format(image.dtype, channels) gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_PIXEL_MODE_BIT) gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT) gl.glRasterPos2f(pos[0], pos[1]) gl.glPixelZoom(zoom[0], -zoom[1]) gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1) gl.glDrawPixels(width, height, fmt.format, fmt.type, image) gl.glPopClientAttrib() gl.glPopAttrib() #---------------------------------------------------------------------------- def read_pixels(width, height, *, pos=0, dtype='uint8', channels=3): pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2]) dtype = np.dtype(dtype) fmt = get_texture_format(dtype, channels) image = np.empty([height, width, channels], dtype=dtype) gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT) gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1) gl.glReadPixels(int(np.round(pos[0])), int(np.round(pos[1])), width, height, fmt.format, fmt.type, image) gl.glPopClientAttrib() return np.flipud(image) #---------------------------------------------------------------------------- class Texture: def __init__(self, *, image=None, width=None, height=None, channels=None, dtype=None, bilinear=True, mipmap=True): self.gl_id = None self.bilinear = bilinear self.mipmap = mipmap # Determine size and dtype. if image is not None: image = prepare_texture_data(image) self.height, self.width, self.channels = image.shape self.dtype = image.dtype else: assert width is not None and height is not None self.width = width self.height = height self.channels = channels if channels is not None else 3 self.dtype = np.dtype(dtype) if dtype is not None else np.uint8 # Validate size and dtype. assert isinstance(self.width, int) and self.width >= 0 assert isinstance(self.height, int) and self.height >= 0 assert isinstance(self.channels, int) and self.channels >= 1 assert self.is_compatible(width=width, height=height, channels=channels, dtype=dtype) # Create texture object. 
self.gl_id = gl.glGenTextures(1) with self.bind(): gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE) gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE) gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR if self.bilinear else gl.GL_NEAREST) gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR_MIPMAP_LINEAR if self.mipmap else gl.GL_NEAREST) self.update(image) def delete(self): if self.gl_id is not None: gl.glDeleteTextures([self.gl_id]) self.gl_id = None def __del__(self): try: self.delete() except: pass @contextlib.contextmanager def bind(self): prev_id = gl.glGetInteger(gl.GL_TEXTURE_BINDING_2D) gl.glBindTexture(gl.GL_TEXTURE_2D, self.gl_id) yield gl.glBindTexture(gl.GL_TEXTURE_2D, prev_id) def update(self, image): if image is not None: image = prepare_texture_data(image) assert self.is_compatible(image=image) with self.bind(): fmt = get_texture_format(self.dtype, self.channels) gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT) gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1) gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, fmt.internalformat, self.width, self.height, 0, fmt.format, fmt.type, image) if self.mipmap: gl.glGenerateMipmap(gl.GL_TEXTURE_2D) gl.glPopClientAttrib() def draw(self, *, pos=0, zoom=1, align=0, rint=False, color=1, alpha=1, rounding=0): zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2]) size = zoom * [self.width, self.height] with self.bind(): gl.glPushAttrib(gl.GL_ENABLE_BIT) gl.glEnable(gl.GL_TEXTURE_2D) draw_rect(pos=pos, size=size, align=align, rint=rint, color=color, alpha=alpha, rounding=rounding) gl.glPopAttrib() def is_compatible(self, *, image=None, width=None, height=None, channels=None, dtype=None): # pylint: disable=too-many-return-statements if image is not None: if image.ndim != 3: return False ih, iw, ic = image.shape if not self.is_compatible(width=iw, height=ih, channels=ic, dtype=image.dtype): return False if width is not None and self.width != width: return False if height is not None and self.height != height: return False if channels is not None and self.channels != channels: return False if dtype is not None and self.dtype != dtype: return False return True #---------------------------------------------------------------------------- class Framebuffer: def __init__(self, *, texture=None, width=None, height=None, channels=None, dtype=None, msaa=0): self.texture = texture self.gl_id = None self.gl_color = None self.gl_depth_stencil = None self.msaa = msaa # Determine size and dtype. if texture is not None: assert isinstance(self.texture, Texture) self.width = texture.width self.height = texture.height self.channels = texture.channels self.dtype = texture.dtype else: assert width is not None and height is not None self.width = width self.height = height self.channels = channels if channels is not None else 4 self.dtype = np.dtype(dtype) if dtype is not None else np.float32 # Validate size and dtype. assert isinstance(self.width, int) and self.width >= 0 assert isinstance(self.height, int) and self.height >= 0 assert isinstance(self.channels, int) and self.channels >= 1 assert width is None or width == self.width assert height is None or height == self.height assert channels is None or channels == self.channels assert dtype is None or dtype == self.dtype # Create framebuffer object. self.gl_id = gl.glGenFramebuffers(1) with self.bind(): # Setup color buffer. 
if self.texture is not None: assert self.msaa == 0 gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_TEXTURE_2D, self.texture.gl_id, 0) else: fmt = get_texture_format(self.dtype, self.channels) self.gl_color = gl.glGenRenderbuffers(1) gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_color) gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, fmt.internalformat, self.width, self.height) gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_RENDERBUFFER, self.gl_color) # Setup depth/stencil buffer. self.gl_depth_stencil = gl.glGenRenderbuffers(1) gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_depth_stencil) gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, gl.GL_DEPTH24_STENCIL8, self.width, self.height) gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_DEPTH_STENCIL_ATTACHMENT, gl.GL_RENDERBUFFER, self.gl_depth_stencil) def delete(self): if self.gl_id is not None: gl.glDeleteFramebuffers([self.gl_id]) self.gl_id = None if self.gl_color is not None: gl.glDeleteRenderbuffers(1, [self.gl_color]) self.gl_color = None if self.gl_depth_stencil is not None: gl.glDeleteRenderbuffers(1, [self.gl_depth_stencil]) self.gl_depth_stencil = None def __del__(self): try: self.delete() except: pass @contextlib.contextmanager def bind(self): prev_fbo = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING) prev_rbo = gl.glGetInteger(gl.GL_RENDERBUFFER_BINDING) gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.gl_id) if self.width is not None and self.height is not None: gl.glViewport(0, 0, self.width, self.height) yield gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, prev_fbo) gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, prev_rbo) def blit(self, dst=None): assert dst is None or isinstance(dst, Framebuffer) with self.bind(): gl.glBindFramebuffer(gl.GL_DRAW_FRAMEBUFFER, 0 if dst is None else dst.gl_id) gl.glBlitFramebuffer(0, 0, self.width, self.height, 0, 0, self.width, self.height, gl.GL_COLOR_BUFFER_BIT, gl.GL_NEAREST) #---------------------------------------------------------------------------- def draw_shape(vertices, *, mode=gl.GL_TRIANGLE_FAN, pos=0, size=1, color=1, alpha=1): assert vertices.ndim == 2 and vertices.shape[1] == 2 pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2]) size = np.broadcast_to(np.asarray(size, dtype='float32'), [2]) color = np.broadcast_to(np.asarray(color, dtype='float32'), [3]) alpha = np.clip(np.broadcast_to(np.asarray(alpha, dtype='float32'), []), 0, 1) gl.glPushClientAttrib(gl.GL_CLIENT_VERTEX_ARRAY_BIT) gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_TRANSFORM_BIT) gl.glMatrixMode(gl.GL_MODELVIEW) gl.glPushMatrix() gl.glEnableClientState(gl.GL_VERTEX_ARRAY) gl.glEnableClientState(gl.GL_TEXTURE_COORD_ARRAY) gl.glVertexPointer(2, gl.GL_FLOAT, 0, vertices) gl.glTexCoordPointer(2, gl.GL_FLOAT, 0, vertices) gl.glTranslate(pos[0], pos[1], 0) gl.glScale(size[0], size[1], 1) gl.glColor4f(color[0] * alpha, color[1] * alpha, color[2] * alpha, alpha) gl.glDrawArrays(mode, 0, vertices.shape[0]) gl.glPopMatrix() gl.glPopAttrib() gl.glPopClientAttrib() #---------------------------------------------------------------------------- def draw_arrow(x1, y1, x2, y2, l=10, width=1.0): # Compute the length and angle of the arrow dx = x2 - x1 dy = y2 - y1 length = math.sqrt(dx**2 + dy**2) if length < l: return angle = math.atan2(dy, dx) # Save the current modelview matrix gl.glPushMatrix() # Translate and rotate the coordinate system gl.glTranslatef(x1, y1, 0.0) gl.glRotatef(angle * 180.0 / math.pi, 0.0, 0.0, 1.0) # Set the line width
gl.glLineWidth(width) # gl.glColor3f(0.75, 0.75, 0.75) # Begin drawing lines gl.glBegin(gl.GL_LINES) # Draw the shaft of the arrow gl.glVertex2f(0.0, 0.0) gl.glVertex2f(length, 0.0) # Draw the head of the arrow gl.glVertex2f(length, 0.0) gl.glVertex2f(length - 2 * l, l) gl.glVertex2f(length, 0.0) gl.glVertex2f(length - 2 * l, -l) # End drawing lines gl.glEnd() # Restore the modelview matrix gl.glPopMatrix() #---------------------------------------------------------------------------- def draw_rect(*, pos=0, pos2=None, size=None, align=0, rint=False, color=1, alpha=1, rounding=0): assert pos2 is None or size is None pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2]) pos2 = np.broadcast_to(np.asarray(pos2, dtype='float32'), [2]) if pos2 is not None else None size = np.broadcast_to(np.asarray(size, dtype='float32'), [2]) if size is not None else None size = size if size is not None else pos2 - pos if pos2 is not None else np.array([1, 1], dtype='float32') pos = pos - size * align if rint: pos = np.rint(pos) rounding = np.broadcast_to(np.asarray(rounding, dtype='float32'), [2]) rounding = np.minimum(np.abs(rounding) / np.maximum(np.abs(size), 1e-8), 0.5) if np.min(rounding) == 0: rounding *= 0 vertices = _setup_rect(float(rounding[0]), float(rounding[1])) draw_shape(vertices, mode=gl.GL_TRIANGLE_FAN, pos=pos, size=size, color=color, alpha=alpha) @functools.lru_cache(maxsize=10000) def _setup_rect(rx, ry): t = np.linspace(0, np.pi / 2, 1 if max(rx, ry) == 0 else 64) s = 1 - np.sin(t); c = 1 - np.cos(t) x = [c * rx, 1 - s * rx, 1 - c * rx, s * rx] y = [s * ry, c * ry, 1 - s * ry, 1 - c * ry] v = np.stack([x, y], axis=-1).reshape(-1, 2) return v.astype('float32') #---------------------------------------------------------------------------- def draw_circle(*, center=0, radius=100, hole=0, color=1, alpha=1): hole = np.broadcast_to(np.asarray(hole, dtype='float32'), []) vertices = _setup_circle(float(hole)) draw_shape(vertices, mode=gl.GL_TRIANGLE_STRIP, pos=center, size=radius, color=color, alpha=alpha) @functools.lru_cache(maxsize=10000) def _setup_circle(hole): t = np.linspace(0, np.pi * 2, 128) s = np.sin(t); c = np.cos(t) v = np.stack([c, s, c * hole, s * hole], axis=-1).reshape(-1, 2) return v.astype('float32') #---------------------------------------------------------------------------- ================================================ FILE: gui_utils/glfw_window.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. import time import glfw import OpenGL.GL as gl from . 
import gl_utils #---------------------------------------------------------------------------- class GlfwWindow: # pylint: disable=too-many-public-methods def __init__(self, *, title='GlfwWindow', window_width=1920, window_height=1080, deferred_show=True, close_on_esc=True): self._glfw_window = None self._drawing_frame = False self._frame_start_time = None self._frame_delta = 0 self._fps_limit = None self._vsync = None self._skip_frames = 0 self._deferred_show = deferred_show self._close_on_esc = close_on_esc self._esc_pressed = False self._drag_and_drop_paths = None self._capture_next_frame = False self._captured_frame = None # Create window. glfw.init() glfw.window_hint(glfw.VISIBLE, False) self._glfw_window = glfw.create_window(width=window_width, height=window_height, title=title, monitor=None, share=None) self._attach_glfw_callbacks() self.make_context_current() # Adjust window. self.set_vsync(False) self.set_window_size(window_width, window_height) if not self._deferred_show: glfw.show_window(self._glfw_window) def close(self): if self._drawing_frame: self.end_frame() if self._glfw_window is not None: glfw.destroy_window(self._glfw_window) self._glfw_window = None #glfw.terminate() # Commented out to play it nice with other glfw clients. def __del__(self): try: self.close() except: pass @property def window_width(self): return self.content_width @property def window_height(self): return self.content_height + self.title_bar_height @property def content_width(self): width, _height = glfw.get_window_size(self._glfw_window) return width @property def content_height(self): _width, height = glfw.get_window_size(self._glfw_window) return height @property def title_bar_height(self): _left, top, _right, _bottom = glfw.get_window_frame_size(self._glfw_window) return top @property def monitor_width(self): _, _, width, _height = glfw.get_monitor_workarea(glfw.get_primary_monitor()) return width @property def monitor_height(self): _, _, _width, height = glfw.get_monitor_workarea(glfw.get_primary_monitor()) return height @property def frame_delta(self): return self._frame_delta def set_title(self, title): glfw.set_window_title(self._glfw_window, title) def set_window_size(self, width, height): width = min(width, self.monitor_width) height = min(height, self.monitor_height) glfw.set_window_size(self._glfw_window, width, max(height - self.title_bar_height, 0)) if width == self.monitor_width and height == self.monitor_height: self.maximize() def set_content_size(self, width, height): self.set_window_size(width, height + self.title_bar_height) def maximize(self): glfw.maximize_window(self._glfw_window) def set_position(self, x, y): glfw.set_window_pos(self._glfw_window, x, y + self.title_bar_height) def center(self): self.set_position((self.monitor_width - self.window_width) // 2, (self.monitor_height - self.window_height) // 2) def set_vsync(self, vsync): vsync = bool(vsync) if vsync != self._vsync: glfw.swap_interval(1 if vsync else 0) self._vsync = vsync def set_fps_limit(self, fps_limit): self._fps_limit = int(fps_limit) def should_close(self): return glfw.window_should_close(self._glfw_window) or (self._close_on_esc and self._esc_pressed) def skip_frame(self): self.skip_frames(1) def skip_frames(self, num): # Do not update window for the next N frames. 
self._skip_frames = max(self._skip_frames, int(num)) def is_skipping_frames(self): return self._skip_frames > 0 def capture_next_frame(self): self._capture_next_frame = True def pop_captured_frame(self): frame = self._captured_frame self._captured_frame = None return frame def pop_drag_and_drop_paths(self): paths = self._drag_and_drop_paths self._drag_and_drop_paths = None return paths def draw_frame(self): # To be overridden by subclass. self.begin_frame() # Rendering code goes here. self.end_frame() def make_context_current(self): if self._glfw_window is not None: glfw.make_context_current(self._glfw_window) def begin_frame(self): # End previous frame. if self._drawing_frame: self.end_frame() # Apply FPS limit. if self._frame_start_time is not None and self._fps_limit is not None: delay = self._frame_start_time - time.perf_counter() + 1 / self._fps_limit if delay > 0: time.sleep(delay) cur_time = time.perf_counter() if self._frame_start_time is not None: self._frame_delta = cur_time - self._frame_start_time self._frame_start_time = cur_time # Process events. glfw.poll_events() # Begin frame. self._drawing_frame = True self.make_context_current() # Initialize GL state. gl.glViewport(0, 0, self.content_width, self.content_height) gl.glMatrixMode(gl.GL_PROJECTION) gl.glLoadIdentity() gl.glTranslate(-1, 1, 0) gl.glScale(2 / max(self.content_width, 1), -2 / max(self.content_height, 1), 1) gl.glMatrixMode(gl.GL_MODELVIEW) gl.glLoadIdentity() gl.glEnable(gl.GL_BLEND) gl.glBlendFunc(gl.GL_ONE, gl.GL_ONE_MINUS_SRC_ALPHA) # Pre-multiplied alpha. # Clear. gl.glClearColor(0, 0, 0, 1) gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) def end_frame(self): assert self._drawing_frame self._drawing_frame = False # Skip frames if requested. if self._skip_frames > 0: self._skip_frames -= 1 return # Capture frame if requested. if self._capture_next_frame: self._captured_frame = gl_utils.read_pixels(self.content_width, self.content_height) self._capture_next_frame = False # Update window. if self._deferred_show: glfw.show_window(self._glfw_window) self._deferred_show = False glfw.swap_buffers(self._glfw_window) def _attach_glfw_callbacks(self): glfw.set_key_callback(self._glfw_window, self._glfw_key_callback) glfw.set_drop_callback(self._glfw_window, self._glfw_drop_callback) def _glfw_key_callback(self, _window, key, _scancode, action, _mods): if action == glfw.PRESS and key == glfw.KEY_ESCAPE: self._esc_pressed = True def _glfw_drop_callback(self, _window, paths): self._drag_and_drop_paths = paths #---------------------------------------------------------------------------- ================================================ FILE: gui_utils/imgui_utils.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
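# Thin convenience layer over pyimgui used by this GUI: a default style, context managers for graying out or sizing widgets, and composite controls such as buttons, drag sources, and invisible overlay windows.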
import contextlib import imgui #---------------------------------------------------------------------------- def set_default_style(color_scheme='dark', spacing=9, indent=23, scrollbar=27): s = imgui.get_style() s.window_padding = [spacing, spacing] s.item_spacing = [spacing, spacing] s.item_inner_spacing = [spacing, spacing] s.columns_min_spacing = spacing s.indent_spacing = indent s.scrollbar_size = scrollbar s.frame_padding = [4, 3] s.window_border_size = 1 s.child_border_size = 1 s.popup_border_size = 1 s.frame_border_size = 1 s.window_rounding = 0 s.child_rounding = 0 s.popup_rounding = 3 s.frame_rounding = 3 s.scrollbar_rounding = 3 s.grab_rounding = 3 getattr(imgui, f'style_colors_{color_scheme}')(s) c0 = s.colors[imgui.COLOR_MENUBAR_BACKGROUND] c1 = s.colors[imgui.COLOR_FRAME_BACKGROUND] s.colors[imgui.COLOR_POPUP_BACKGROUND] = [x * 0.7 + y * 0.3 for x, y in zip(c0, c1)][:3] + [1] #---------------------------------------------------------------------------- @contextlib.contextmanager def grayed_out(cond=True): if cond: s = imgui.get_style() text = s.colors[imgui.COLOR_TEXT_DISABLED] grab = s.colors[imgui.COLOR_SCROLLBAR_GRAB] back = s.colors[imgui.COLOR_MENUBAR_BACKGROUND] imgui.push_style_color(imgui.COLOR_TEXT, *text) imgui.push_style_color(imgui.COLOR_CHECK_MARK, *grab) imgui.push_style_color(imgui.COLOR_SLIDER_GRAB, *grab) imgui.push_style_color(imgui.COLOR_SLIDER_GRAB_ACTIVE, *grab) imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND, *back) imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_HOVERED, *back) imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_ACTIVE, *back) imgui.push_style_color(imgui.COLOR_BUTTON, *back) imgui.push_style_color(imgui.COLOR_BUTTON_HOVERED, *back) imgui.push_style_color(imgui.COLOR_BUTTON_ACTIVE, *back) imgui.push_style_color(imgui.COLOR_HEADER, *back) imgui.push_style_color(imgui.COLOR_HEADER_HOVERED, *back) imgui.push_style_color(imgui.COLOR_HEADER_ACTIVE, *back) imgui.push_style_color(imgui.COLOR_POPUP_BACKGROUND, *back) yield imgui.pop_style_color(14) else: yield #---------------------------------------------------------------------------- @contextlib.contextmanager def item_width(width=None): if width is not None: imgui.push_item_width(width) yield imgui.pop_item_width() else: yield #---------------------------------------------------------------------------- def scoped_by_object_id(method): def decorator(self, *args, **kwargs): imgui.push_id(str(id(self))) res = method(self, *args, **kwargs) imgui.pop_id() return res return decorator #---------------------------------------------------------------------------- def button(label, width=0, enabled=True): with grayed_out(not enabled): clicked = imgui.button(label, width=width) clicked = clicked and enabled return clicked #---------------------------------------------------------------------------- def collapsing_header(text, visible=None, flags=0, default=False, enabled=True, show=True): expanded = False if show: if default: flags |= imgui.TREE_NODE_DEFAULT_OPEN if not enabled: flags |= imgui.TREE_NODE_LEAF with grayed_out(not enabled): expanded, visible = imgui.collapsing_header(text, visible=visible, flags=flags) expanded = expanded and enabled return expanded, visible #---------------------------------------------------------------------------- def popup_button(label, width=0, enabled=True): if button(label, width, enabled): imgui.open_popup(label) opened = imgui.begin_popup(label) return opened #---------------------------------------------------------------------------- def 
input_text(label, value, buffer_length, flags, width=None, help_text=''): old_value = value color = list(imgui.get_style().colors[imgui.COLOR_TEXT]) if value == '': color[-1] *= 0.5 with item_width(width): imgui.push_style_color(imgui.COLOR_TEXT, *color) value = value if value != '' else help_text changed, value = imgui.input_text(label, value, buffer_length, flags) value = value if value != help_text else '' imgui.pop_style_color(1) if not flags & imgui.INPUT_TEXT_ENTER_RETURNS_TRUE: changed = (value != old_value) return changed, value #---------------------------------------------------------------------------- def drag_previous_control(enabled=True): dragging = False dx = 0 dy = 0 if imgui.begin_drag_drop_source(imgui.DRAG_DROP_SOURCE_NO_PREVIEW_TOOLTIP): if enabled: dragging = True dx, dy = imgui.get_mouse_drag_delta() imgui.reset_mouse_drag_delta() imgui.end_drag_drop_source() return dragging, dx, dy #---------------------------------------------------------------------------- def drag_button(label, width=0, enabled=True): clicked = button(label, width=width, enabled=enabled) dragging, dx, dy = drag_previous_control(enabled=enabled) return clicked, dragging, dx, dy #---------------------------------------------------------------------------- def drag_hidden_window(label, x, y, width, height, enabled=True): imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0) imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0) imgui.set_next_window_position(x, y) imgui.set_next_window_size(width, height) imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE)) dragging, dx, dy = drag_previous_control(enabled=enabled) imgui.end() imgui.pop_style_color(2) return dragging, dx, dy #---------------------------------------------------------------------------- def click_hidden_window(label, x, y, width, height, img_w, img_h, enabled=True): imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0) imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0) imgui.set_next_window_position(x, y) imgui.set_next_window_size(width, height) imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE)) clicked, down = False, False img_x, img_y = 0, 0 if imgui.is_mouse_down(): posx, posy = imgui.get_mouse_pos() if posx >= x and posx < x + width and posy >= y and posy < y + height: if imgui.is_mouse_clicked(): clicked = True down = True img_x = round((posx - x) / (width - 1) * (img_w - 1)) img_y = round((posy - y) / (height - 1) * (img_h - 1)) imgui.end() imgui.pop_style_color(2) return clicked, down, img_x, img_y #---------------------------------------------------------------------------- ================================================ FILE: gui_utils/imgui_window.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. import os import imgui import imgui.integrations.glfw from . import glfw_window from . import imgui_utils from . 
import text_utils #---------------------------------------------------------------------------- class ImguiWindow(glfw_window.GlfwWindow): def __init__(self, *, title='ImguiWindow', font=None, font_sizes=range(14,24), **glfw_kwargs): if font is None: font = text_utils.get_default_font() font_sizes = {int(size) for size in font_sizes} super().__init__(title=title, **glfw_kwargs) # Init fields. self._imgui_context = None self._imgui_renderer = None self._imgui_fonts = None self._cur_font_size = max(font_sizes) # Delete leftover imgui.ini to avoid unexpected behavior. if os.path.isfile('imgui.ini'): os.remove('imgui.ini') # Init ImGui. self._imgui_context = imgui.create_context() self._imgui_renderer = _GlfwRenderer(self._glfw_window) self._attach_glfw_callbacks() imgui.get_io().ini_saving_rate = 0 # Disable creating imgui.ini at runtime. imgui.get_io().mouse_drag_threshold = 0 # Improve behavior with imgui_utils.drag_custom(). self._imgui_fonts = {size: imgui.get_io().fonts.add_font_from_file_ttf(font, size) for size in font_sizes} self._imgui_renderer.refresh_font_texture() def close(self): self.make_context_current() self._imgui_fonts = None if self._imgui_renderer is not None: self._imgui_renderer.shutdown() self._imgui_renderer = None if self._imgui_context is not None: #imgui.destroy_context(self._imgui_context) # Commented out to avoid creating imgui.ini at the end. self._imgui_context = None super().close() def _glfw_key_callback(self, *args): super()._glfw_key_callback(*args) self._imgui_renderer.keyboard_callback(*args) @property def font_size(self): return self._cur_font_size @property def spacing(self): return round(self._cur_font_size * 0.4) def set_font_size(self, target): # Applied on next frame. self._cur_font_size = min((abs(key - target), key) for key in self._imgui_fonts.keys())[1] def begin_frame(self): # Begin glfw frame. super().begin_frame() # Process imgui events. self._imgui_renderer.mouse_wheel_multiplier = self._cur_font_size / 10 if self.content_width > 0 and self.content_height > 0: self._imgui_renderer.process_inputs() # Begin imgui frame. imgui.new_frame() imgui.push_font(self._imgui_fonts[self._cur_font_size]) imgui_utils.set_default_style(spacing=self.spacing, indent=self.font_size, scrollbar=self.font_size+4) def end_frame(self): imgui.pop_font() imgui.render() imgui.end_frame() self._imgui_renderer.render(imgui.get_draw_data()) super().end_frame() #---------------------------------------------------------------------------- # Wrapper class for GlfwRenderer to fix a mouse wheel bug on Linux. class _GlfwRenderer(imgui.integrations.glfw.GlfwRenderer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.mouse_wheel_multiplier = 1 def scroll_callback(self, window, x_offset, y_offset): self.io.mouse_wheel += y_offset * self.mouse_wheel_multiplier #---------------------------------------------------------------------------- ================================================ FILE: gui_utils/text_utils.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
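# Text rasterization helpers: strings are rendered to [mask, alpha] uint8 arrays via PIL fonts, optionally with an outline or drop shadow, and the results (including ready-to-draw GL textures) are cached with functools.lru_cache.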
import functools from typing import Optional import dnnlib import numpy as np import PIL.Image import PIL.ImageFont import scipy.ndimage from . import gl_utils #---------------------------------------------------------------------------- def get_default_font(): url = 'http://fonts.gstatic.com/s/opensans/v17/mem8YaGs126MiZpBA-U1UpcaXcl0Aw.ttf' # Open Sans regular return dnnlib.util.open_url(url, return_filename=True) #---------------------------------------------------------------------------- @functools.lru_cache(maxsize=None) def get_pil_font(font=None, size=32): if font is None: font = get_default_font() return PIL.ImageFont.truetype(font=font, size=size) #---------------------------------------------------------------------------- def get_array(string, *, dropshadow_radius: int=None, **kwargs): if dropshadow_radius is not None: offset_x = int(np.ceil(dropshadow_radius*2/3)) offset_y = int(np.ceil(dropshadow_radius*2/3)) return _get_array_priv(string, dropshadow_radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs) else: return _get_array_priv(string, **kwargs) @functools.lru_cache(maxsize=10000) def _get_array_priv( string: str, *, size: int = 32, max_width: Optional[int]=None, max_height: Optional[int]=None, min_size=10, shrink_coef=0.8, dropshadow_radius: int=None, offset_x: int=None, offset_y: int=None, **kwargs ): cur_size = size array = None while True: if dropshadow_radius is not None: # separate implementation for dropshadow text rendering array = _get_array_impl_dropshadow(string, size=cur_size, radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs) else: array = _get_array_impl(string, size=cur_size, **kwargs) height, width, _ = array.shape if (max_width is None or width <= max_width) and (max_height is None or height <= max_height) or (cur_size <= min_size): break cur_size = max(int(cur_size * shrink_coef), min_size) return array #---------------------------------------------------------------------------- @functools.lru_cache(maxsize=10000) def _get_array_impl(string, *, font=None, size=32, outline=0, outline_pad=3, outline_coef=3, outline_exp=2, line_pad: int=None): pil_font = get_pil_font(font=font, size=size) lines = [pil_font.getmask(line, 'L') for line in string.split('\n')] lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines] width = max(line.shape[1] for line in lines) lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines] line_spacing = line_pad if line_pad is not None else size // 2 lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:] mask = np.concatenate(lines, axis=0) alpha = mask if outline > 0: mask = np.pad(mask, int(np.ceil(outline * outline_pad)), mode='constant', constant_values=0) alpha = mask.astype(np.float32) / 255 alpha = scipy.ndimage.gaussian_filter(alpha, outline) alpha = 1 - np.maximum(1 - alpha * outline_coef, 0) ** outline_exp alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8) alpha = np.maximum(alpha, mask) return np.stack([mask, alpha], axis=-1) #---------------------------------------------------------------------------- @functools.lru_cache(maxsize=10000) def _get_array_impl_dropshadow(string, *, font=None, size=32, radius: int, offset_x: int, offset_y: int, line_pad: int=None, **kwargs): assert (offset_x > 0) and (offset_y > 0) pil_font = get_pil_font(font=font, size=size) lines = [pil_font.getmask(line, 'L') for line in string.split('\n')] lines = 
[np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines] width = max(line.shape[1] for line in lines) lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines] line_spacing = line_pad if line_pad is not None else size // 2 lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:] mask = np.concatenate(lines, axis=0) alpha = mask mask = np.pad(mask, 2*radius + max(abs(offset_x), abs(offset_y)), mode='constant', constant_values=0) alpha = mask.astype(np.float32) / 255 alpha = scipy.ndimage.gaussian_filter(alpha, radius) alpha = 1 - np.maximum(1 - alpha * 1.5, 0) ** 1.4 alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8) alpha = np.pad(alpha, [(offset_y, 0), (offset_x, 0)], mode='constant')[:-offset_y, :-offset_x] alpha = np.maximum(alpha, mask) return np.stack([mask, alpha], axis=-1) #---------------------------------------------------------------------------- @functools.lru_cache(maxsize=10000) def get_texture(string, bilinear=True, mipmap=True, **kwargs): return gl_utils.Texture(image=get_array(string, **kwargs), bilinear=bilinear, mipmap=mipmap) #---------------------------------------------------------------------------- ================================================ FILE: legacy.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Converting legacy network pickle into the new format.""" import click import pickle import re import copy import numpy as np import torch import dnnlib from torch_utils import misc #---------------------------------------------------------------------------- def load_network_pkl(f, force_fp16=False): data = _LegacyUnpickler(f).load() # Legacy TensorFlow pickle => convert. if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data): tf_G, tf_D, tf_Gs = data G = convert_tf_generator(tf_G) D = convert_tf_discriminator(tf_D) G_ema = convert_tf_generator(tf_Gs) data = dict(G=G, D=D, G_ema=G_ema) # Add missing fields. if 'training_set_kwargs' not in data: data['training_set_kwargs'] = None if 'augment_pipe' not in data: data['augment_pipe'] = None # Validate contents. assert isinstance(data['G'], torch.nn.Module) assert isinstance(data['D'], torch.nn.Module) assert isinstance(data['G_ema'], torch.nn.Module) assert isinstance(data['training_set_kwargs'], (dict, type(None))) assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None))) # Force FP16. 
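# When force_fp16 is set, each network is rebuilt with num_fp16_res=4 and conv_clamp=256 patched into its (synthesis) kwargs, and the original parameters and buffers are copied over unchanged, so only the compute precision differs.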
if force_fp16: for key in ['G', 'D', 'G_ema']: old = data[key] kwargs = copy.deepcopy(old.init_kwargs) fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs) fp16_kwargs.num_fp16_res = 4 fp16_kwargs.conv_clamp = 256 if kwargs != old.init_kwargs: new = type(old)(**kwargs).eval().requires_grad_(False) misc.copy_params_and_buffers(old, new, require_all=True) data[key] = new return data #---------------------------------------------------------------------------- class _TFNetworkStub(dnnlib.EasyDict): pass class _LegacyUnpickler(pickle.Unpickler): def find_class(self, module, name): if module == 'dnnlib.tflib.network' and name == 'Network': return _TFNetworkStub return super().find_class(module, name) #---------------------------------------------------------------------------- def _collect_tf_params(tf_net): # pylint: disable=protected-access tf_params = dict() def recurse(prefix, tf_net): for name, value in tf_net.variables: tf_params[prefix + name] = value for name, comp in tf_net.components.items(): recurse(prefix + name + '/', comp) recurse('', tf_net) return tf_params #---------------------------------------------------------------------------- def _populate_module_params(module, *patterns): for name, tensor in misc.named_params_and_buffers(module): found = False value = None for pattern, value_fn in zip(patterns[0::2], patterns[1::2]): match = re.fullmatch(pattern, name) if match: found = True if value_fn is not None: value = value_fn(*match.groups()) break try: assert found if value is not None: tensor.copy_(torch.from_numpy(np.array(value))) except: print(name, list(tensor.shape)) raise #---------------------------------------------------------------------------- def convert_tf_generator(tf_G): if tf_G.version < 4: raise ValueError('TensorFlow pickle version too low') # Collect kwargs. tf_kwargs = tf_G.static_kwargs known_kwargs = set() def kwarg(tf_name, default=None, none=None): known_kwargs.add(tf_name) val = tf_kwargs.get(tf_name, default) return val if val is not None else none # Convert kwargs. from training import networks_stylegan2 network_class = networks_stylegan2.Generator kwargs = dnnlib.EasyDict( z_dim = kwarg('latent_size', 512), c_dim = kwarg('label_size', 0), w_dim = kwarg('dlatent_size', 512), img_resolution = kwarg('resolution', 1024), img_channels = kwarg('num_channels', 3), channel_base = kwarg('fmap_base', 16384) * 2, channel_max = kwarg('fmap_max', 512), num_fp16_res = kwarg('num_fp16_res', 0), conv_clamp = kwarg('conv_clamp', None), architecture = kwarg('architecture', 'skip'), resample_filter = kwarg('resample_kernel', [1,3,3,1]), use_noise = kwarg('use_noise', True), activation = kwarg('nonlinearity', 'lrelu'), mapping_kwargs = dnnlib.EasyDict( num_layers = kwarg('mapping_layers', 8), embed_features = kwarg('label_fmaps', None), layer_features = kwarg('mapping_fmaps', None), activation = kwarg('mapping_nonlinearity', 'lrelu'), lr_multiplier = kwarg('mapping_lrmul', 0.01), w_avg_beta = kwarg('w_avg_beta', 0.995, none=1), ), ) # Check for unknown kwargs. kwarg('truncation_psi') kwarg('truncation_cutoff') kwarg('style_mixing_prob') kwarg('structure') kwarg('conditioning') kwarg('fused_modconv') unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs) if len(unknown_kwargs) > 0: raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0]) # Collect params. 
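# Old progressive-growing pickles name their output layers 'ToRGB_lodN'; the loop below remaps them to the '{r}x{r}/ToRGB' keys expected by the patterns passed to _populate_module_params, switching the architecture to 'orig' accordingly.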
tf_params = _collect_tf_params(tf_G) for name, value in list(tf_params.items()): match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name) if match: r = kwargs.img_resolution // (2 ** int(match.group(1))) tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value kwargs.architecture = 'orig' #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}') # Convert params. G = network_class(**kwargs).eval().requires_grad_(False) # pylint: disable=unnecessary-lambda # pylint: disable=f-string-without-interpolation _populate_module_params(G, r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'], r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(), r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'], r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(), r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'], r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0], r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1), r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'], r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0], r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'], r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(), r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1, r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1), r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'], r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0], r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'], r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(), r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1, r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1), r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'], r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0], r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'], r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(), r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1, r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1), r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'], r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(), r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1, r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1), r'.*\.resample_filter', None, r'.*\.act_filter', None, ) return G
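#----------------------------------------------------------------------------
# Illustrative usage (a sketch, not part of the original file): the converter
# above is normally reached through load_network_pkl() rather than called
# directly. The pickle path below is hypothetical.
#
#   import pickle
#   import dnnlib
#   import legacy
#
#   with dnnlib.util.open_url('legacy_tf_network.pkl') as f:
#       data = legacy.load_network_pkl(f)   # dict with 'G', 'D', 'G_ema', ...
#   G_ema = data['G_ema']                   # converted torch.nn.Module
#   with open('converted.pkl', 'wb') as f:
#       pickle.dump(data, f)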
#---------------------------------------------------------------------------- def convert_tf_discriminator(tf_D): if tf_D.version < 4: raise ValueError('TensorFlow pickle version too low') # Collect kwargs. tf_kwargs = tf_D.static_kwargs known_kwargs = set() def kwarg(tf_name, default=None): known_kwargs.add(tf_name) return tf_kwargs.get(tf_name, default) # Convert kwargs. kwargs = dnnlib.EasyDict( c_dim = kwarg('label_size', 0), img_resolution = kwarg('resolution', 1024), img_channels = kwarg('num_channels', 3), architecture = kwarg('architecture', 'resnet'), channel_base = kwarg('fmap_base', 16384) * 2, channel_max = kwarg('fmap_max', 512), num_fp16_res = kwarg('num_fp16_res', 0), conv_clamp = kwarg('conv_clamp', None), cmap_dim = kwarg('mapping_fmaps', None), block_kwargs = dnnlib.EasyDict( activation = kwarg('nonlinearity', 'lrelu'), resample_filter = kwarg('resample_kernel', [1,3,3,1]), freeze_layers = kwarg('freeze_layers', 0), ), mapping_kwargs = dnnlib.EasyDict( num_layers = kwarg('mapping_layers', 0), embed_features = kwarg('mapping_fmaps', None), layer_features = kwarg('mapping_fmaps', None), activation = kwarg('nonlinearity', 'lrelu'), lr_multiplier = kwarg('mapping_lrmul', 0.1), ), epilogue_kwargs = dnnlib.EasyDict( mbstd_group_size = kwarg('mbstd_group_size', None), mbstd_num_channels = kwarg('mbstd_num_features', 1), activation = kwarg('nonlinearity', 'lrelu'), ), ) # Check for unknown kwargs. kwarg('structure') kwarg('conditioning') unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs) if len(unknown_kwargs) > 0: raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0]) # Collect params. tf_params = _collect_tf_params(tf_D) for name, value in list(tf_params.items()): match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name) if match: r = kwargs.img_resolution // (2 ** int(match.group(1))) tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value kwargs.architecture = 'orig' #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}') # Convert params. 
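# The discriminator is rebuilt as a PyTorch module below and populated from the collected TF tensors; conv weights are transposed from TF's [kh, kw, in, out] layout to PyTorch's [out, in, kh, kw], and dense weights from [in, out] to [out, in].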
from training import networks_stylegan2 D = networks_stylegan2.Discriminator(**kwargs).eval().requires_grad_(False) # pylint: disable=unnecessary-lambda # pylint: disable=f-string-without-interpolation _populate_module_params(D, r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1), r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'], r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1), r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'], r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1), r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(), r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'], r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(), r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'], r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1), r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'], r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(), r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'], r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(), r'b4\.out\.bias', lambda: tf_params[f'Output/bias'], r'.*\.resample_filter', None, ) return D #---------------------------------------------------------------------------- @click.command() @click.option('--source', help='Input pickle', required=True, metavar='PATH') @click.option('--dest', help='Output pickle', required=True, metavar='PATH') @click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True) def convert_network_pickle(source, dest, force_fp16): """Convert legacy network pickle into the native PyTorch format. The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA. It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks. Example: \b python legacy.py \\ --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\ --dest=stylegan2-cat-config-f.pkl """ print(f'Loading "{source}"...') with dnnlib.util.open_url(source) as f: data = load_network_pkl(f, force_fp16=force_fp16) print(f'Saving "{dest}"...') with open(dest, 'wb') as f: pickle.dump(data, f) print('Done.') #---------------------------------------------------------------------------- if __name__ == "__main__": convert_network_pickle() # pylint: disable=no-value-for-parameter #---------------------------------------------------------------------------- ================================================ FILE: requirements.txt ================================================ torch>=2.0.0 scipy>=1.11.1 Ninja==1.10.2 gradio>=3.35.2 imageio-ffmpeg>=0.4.3 huggingface_hub hf_transfer pyopengl imgui glfw==2.6.1 pillow>=9.4.0 torchvision>=0.15.2 imageio>=2.9.0 ================================================ FILE: stylegan_human/.gitignore ================================================ .DS_Store __pycache__ *.pt *.pth *.pdparams *.pdiparams *.pdmodel *.pkl *.info *.yaml ================================================ FILE: stylegan_human/PP_HumanSeg/deploy/infer.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. 
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import codecs import os import time import yaml import numpy as np import cv2 import paddle import paddleseg.transforms as T from paddle.inference import create_predictor, PrecisionType from paddle.inference import Config as PredictConfig from paddleseg.core.infer import reverse_transform from paddleseg.cvlibs import manager from paddleseg.utils import TimeAverager from ..scripts.optic_flow_process import optic_flow_process class DeployConfig: def __init__(self, path): with codecs.open(path, 'r', 'utf-8') as file: self.dic = yaml.load(file, Loader=yaml.FullLoader) self._transforms = self._load_transforms(self.dic['Deploy'][ 'transforms']) self._dir = os.path.dirname(path) @property def transforms(self): return self._transforms @property def model(self): return os.path.join(self._dir, self.dic['Deploy']['model']) @property def params(self): return os.path.join(self._dir, self.dic['Deploy']['params']) def _load_transforms(self, t_list): com = manager.TRANSFORMS transforms = [] for t in t_list: ctype = t.pop('type') transforms.append(com[ctype](**t)) return transforms class Predictor: def __init__(self, args): self.cfg = DeployConfig(args.cfg) self.args = args self.compose = T.Compose(self.cfg.transforms) resize_h, resize_w = args.input_shape self.disflow = cv2.DISOpticalFlow_create( cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST) self.prev_gray = np.zeros((resize_h, resize_w), np.uint8) self.prev_cfd = np.zeros((resize_h, resize_w), np.float32) self.is_init = True pred_cfg = PredictConfig(self.cfg.model, self.cfg.params) pred_cfg.disable_glog_info() if self.args.use_gpu: pred_cfg.enable_use_gpu(100, 0) self.predictor = create_predictor(pred_cfg) if self.args.test_speed: self.cost_averager = TimeAverager() def preprocess(self, img): ori_shapes = [] processed_imgs = [] processed_img = self.compose(img)[0] processed_imgs.append(processed_img) ori_shapes.append(img.shape) return processed_imgs, ori_shapes def run(self, img, bg): input_names = self.predictor.get_input_names() input_handle = self.predictor.get_input_handle(input_names[0]) processed_imgs, ori_shapes = self.preprocess(img) data = np.array(processed_imgs) input_handle.reshape(data.shape) input_handle.copy_from_cpu(data) if self.args.test_speed: start = time.time() self.predictor.run() if self.args.test_speed: self.cost_averager.record(time.time() - start) output_names = self.predictor.get_output_names() output_handle = self.predictor.get_output_handle(output_names[0]) output = output_handle.copy_to_cpu() return self.postprocess(output, img, ori_shapes[0], bg) def postprocess(self, pred, img, ori_shape, bg): if not os.path.exists(self.args.save_dir): os.makedirs(self.args.save_dir) resize_w = pred.shape[-1] resize_h = pred.shape[-2] if self.args.soft_predict: if self.args.use_optic_flow: score_map = pred[:, 1, :, :].squeeze(0) score_map = 255 * score_map cur_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) cur_gray = cv2.resize(cur_gray, (resize_w, 
resize_h)) optflow_map = optic_flow_process(cur_gray, score_map, self.prev_gray, self.prev_cfd, \ self.disflow, self.is_init) self.prev_gray = cur_gray.copy() self.prev_cfd = optflow_map.copy() self.is_init = False score_map = np.repeat(optflow_map[:, :, np.newaxis], 3, axis=2) score_map = np.transpose(score_map, [2, 0, 1])[np.newaxis, ...] score_map = reverse_transform( paddle.to_tensor(score_map), ori_shape, self.cfg.transforms, mode='bilinear') alpha = np.transpose(score_map.numpy().squeeze(0), [1, 2, 0]) / 255 else: score_map = pred[:, 1, :, :] score_map = score_map[np.newaxis, ...] score_map = reverse_transform( paddle.to_tensor(score_map), ori_shape, self.cfg.transforms, mode='bilinear') alpha = np.transpose(score_map.numpy().squeeze(0), [1, 2, 0]) else: if pred.ndim == 3: pred = pred[:, np.newaxis, ...] result = reverse_transform( paddle.to_tensor( pred, dtype='float32'), ori_shape, self.cfg.transforms, mode='bilinear') result = np.array(result) if self.args.add_argmax: result = np.argmax(result, axis=1) else: result = result.squeeze(1) alpha = np.transpose(result, [1, 2, 0]) # background replace h, w, _ = img.shape if bg is None: bg = np.ones_like(img)*255 else: bg = cv2.resize(bg, (w, h)) if bg.ndim == 2: bg = bg[..., np.newaxis] comb = (alpha * img + (1 - alpha) * bg).astype(np.uint8) return comb, alpha, bg, img ================================================ FILE: stylegan_human/PP_HumanSeg/export_model/download_export_model.py ================================================ # coding: utf8 # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import sys import os LOCAL_PATH = os.path.dirname(os.path.abspath(__file__)) TEST_PATH = os.path.join(LOCAL_PATH, "../../../", "test") sys.path.append(TEST_PATH) from paddleseg.utils.download import download_file_and_uncompress model_urls = { "pphumanseg_lite_portrait_398x224_with_softmax": "https://paddleseg.bj.bcebos.com/dygraph/ppseg/ppseg_lite_portrait_398x224_with_softmax.tar.gz", "deeplabv3p_resnet50_os8_humanseg_512x512_100k_with_softmax": "https://paddleseg.bj.bcebos.com/dygraph/humanseg/export/deeplabv3p_resnet50_os8_humanseg_512x512_100k_with_softmax.zip", "fcn_hrnetw18_small_v1_humanseg_192x192_with_softmax": "https://paddleseg.bj.bcebos.com/dygraph/humanseg/export/fcn_hrnetw18_small_v1_humanseg_192x192_with_softmax.zip", "pphumanseg_lite_generic_humanseg_192x192_with_softmax": "https://paddleseg.bj.bcebos.com/dygraph/humanseg/export/pphumanseg_lite_generic_192x192_with_softmax.zip", } if __name__ == "__main__": for model_name, url in model_urls.items(): download_file_and_uncompress( url=url, savepath=LOCAL_PATH, extrapath=LOCAL_PATH, extraname=model_name) print("Export model download success!") ================================================ FILE: stylegan_human/PP_HumanSeg/pretrained_model/download_pretrained_model.py ================================================ # coding: utf8 # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import sys import os LOCAL_PATH = os.path.dirname(os.path.abspath(__file__)) TEST_PATH = os.path.join(LOCAL_PATH, "../../../", "test") sys.path.append(TEST_PATH) from paddleseg.utils.download import download_file_and_uncompress model_urls = { "pphumanseg_lite_portrait_398x224": "https://paddleseg.bj.bcebos.com/dygraph/ppseg/ppseg_lite_portrait_398x224.tar.gz", "deeplabv3p_resnet50_os8_humanseg_512x512_100k": "https://paddleseg.bj.bcebos.com/dygraph/humanseg/train/deeplabv3p_resnet50_os8_humanseg_512x512_100k.zip", "fcn_hrnetw18_small_v1_humanseg_192x192": "https://paddleseg.bj.bcebos.com/dygraph/humanseg/train/fcn_hrnetw18_small_v1_humanseg_192x192.zip", "pphumanseg_lite_generic_human_192x192": "https://paddleseg.bj.bcebos.com/dygraph/humanseg/train/pphumanseg_lite_generic_192x192.zip", } if __name__ == "__main__": for model_name, url in model_urls.items(): download_file_and_uncompress( url=url, savepath=LOCAL_PATH, extrapath=LOCAL_PATH, extraname=model_name) print("Pretrained model download success!")

================================================
FILE: stylegan_human/README.md
================================================

# StyleGAN-Human: A Data-Centric Odyssey of Human Generation

> **Abstract:** *Unconditional human image generation is an important task in vision and graphics, which enables various applications in the creative industry. Existing studies in this field mainly focus on "network engineering" such as designing new components and objective functions. This work takes a data-centric perspective and investigates multiple critical aspects in "data engineering", which we believe would complement the current practice. To facilitate a comprehensive study, we collect and annotate a large-scale human image dataset with over 230K samples capturing diverse poses and textures. Equipped with this large dataset, we rigorously investigate three essential factors in data engineering for StyleGAN-based human generation, namely data size, data distribution, and data alignment. Extensive experiments reveal several valuable observations w.r.t. these aspects: 1) Large-scale data, more than 40K images, are needed to train a high-fidelity unconditional human generation model with vanilla StyleGAN. 2) A balanced training set helps improve the generation quality with rare face poses compared to the long-tailed counterpart, whereas simply balancing the clothing texture distribution does not effectively bring an improvement. 3) Human GAN models with body centers for alignment outperform models trained using face centers or pelvis points as alignment anchors. In addition, a model zoo and human editing applications are demonstrated to facilitate future research in the community.*
**Keywords:** Human Image Generation, Data-Centric, StyleGAN

[Jianglin Fu](mailto:fujianglin@sensetime.com), [Shikai Li](mailto:lishikai@sensetime.com), [Yuming Jiang](https://yumingj.github.io/), [Kwan-Yee Lin](https://kwanyeelin.github.io/), [Chen Qian](https://scholar.google.com/citations?user=AerkT0YAAAAJ&hl=zh-CN), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/), [Wayne Wu](https://wywu.github.io/), and [Ziwei Liu](https://liuziwei7.github.io/)
**[[Demo Video]](https://youtu.be/nIrb9hwsdcI)** | **[[Project Page]](https://stylegan-human.github.io/)** | **[[Paper]](https://arxiv.org/pdf/2204.11823.pdf)**

## Updates
- [20/07/2022] [SHHQ-1.0](./docs/Dataset.md) dataset with 40K images is released! :sparkles:
- [15/06/2022] Data alignment and real-image inversion scripts are released.
- [26/04/2022] Technical report released!
- [22/04/2022] Technical report will be released before May.
- [21/04/2022] The codebase and project page are created.

## Data Download
The first version, SHHQ-1.0, with 40K images, is released. To download and use the dataset, please read the instructions in [Dataset.md](./docs/Dataset.md). (We are currently receiving a large number of applications and need to verify each applicant carefully; please be patient, and we will reply to you as soon as possible.)

## Model Zoo

| Structure | 1024x512 | Metric | Scores | 512x256 | Metric | Scores |
| --------- | :----------: | :----------: | :----------: | :-----: | :-----: | :-----: |
| StyleGAN1 | [stylegan_human_v1_1024.pkl](https://drive.google.com/file/d/1h-R-IV-INGdPEzj4P9ml6JTEvihuNgLX/view?usp=sharing) | fid50k | 3.79 | to be released | - | - |
| StyleGAN2 | [stylegan_human_v2_1024.pkl](https://drive.google.com/file/d/1FlAb1rYa0r_--Zj_ML8e6shmaF28hQb5/view?usp=sharing) | fid50k_full | 1.57 | [stylegan_human_v2_512.pkl](https://drive.google.com/file/d/1dlFEHbu-WzQWJl7nBBZYcTyo000H9hVm/view?usp=sharing) | fid50k_full | 1.97 |
| StyleGAN3 | to be released | - | - | [stylegan_human_v3_512.pkl](https://drive.google.com/file/d/1_274jk_N6WSCkKWeu7hjHycqGvbuOFf5/view?usp=sharing) | fid50k_full | 2.54 |

## Web Demo

Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo for generation: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/hysts/StyleGAN-Human) and interpolation [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/hysts/StyleGAN-Human-Interpolation)

We also provide a Colab demo that lets you synthesize images with the provided models and visualize the performance of style-mixing, interpolation, and attribute editing. The notebook guides you through installing the necessary environment and downloading the pretrained models. The output images can be found in `./StyleGAN-Human/outputs/`. Hope you enjoy!

## Usage

### System requirements
* The original code bases are [stylegan (tensorflow)](https://github.com/NVlabs/stylegan), [stylegan2-ada (pytorch)](https://github.com/NVlabs/stylegan2-ada-pytorch), and [stylegan3 (pytorch)](https://github.com/NVlabs/stylegan3), released by NVIDIA.
* We tested with Python 3.8.5 and PyTorch 1.9.1 with CUDA 11.1. (See https://pytorch.org for PyTorch install instructions.)

### Installation
To work with this project on your own machine, install the environment as follows:
```
conda env create -f environment.yml
conda activate stylehuman
# [Optional: tensorflow 1.x is required for StyleGAN1.]
pip install nvidia-pyindex
pip install nvidia-tensorflow[horovod]
pip install nvidia-tensorboard==1.15
```
Extra notes:
1. If you run into CUDA version conflicts, try emptying `LD_LIBRARY_PATH`. For example:
```
LD_LIBRARY_PATH=; python generate.py --outdir=out/stylegan_human_v2_1024 --trunc=1 --seeds=1,3,5,7 --network=pretrained_models/stylegan_human_v2_1024.pkl --version 2
```
2. We found the following troubleshooting links helpful: [1.](https://github.com/NVlabs/stylegan3), [2.](https://github.com/NVlabs/stylegan3/blob/main/docs/troubleshooting.md)

### Train
The training scripts are based on the original [stylegan1](https://github.com/NVlabs/stylegan), [stylegan2-ada](https://github.com/NVlabs/stylegan2-ada-pytorch), and [stylegan3](https://github.com/NVlabs/stylegan3) with minor changes. Here we only provide the scripts with modifications for SG2 and SG3. You can replace the old files with the provided scripts to train. (This assumes SHHQ-1.0 is placed under data/.)

#### Train Stylegan2-ada-pytorch with SHHQ-1.0
```
python train.py --outdir=training_results/sg2/ --data=data/SHHQ-1.0/ \
--gpus=8 --aug=noaug --mirror=1 --snap=250 --cfg=shhq --square=False
```

#### Train Stylegan3 with SHHQ-1.0
```
python train.py --outdir=training_results/sg3/ --cfg=stylegan3-r --gpus=8 --batch=32 --gamma=12.4 \
--mirror=1 --aug=noaug --data=data/SHHQ-1.0/ --square=False --snap=250
```

### Pretrained models
Please put the pretrained models downloaded from the [Model Zoo above](#Model-Zoo) under the folder 'pretrained_models'.

### Generate full-body human images using our pretrained model
```
# Generate human full-body images without truncation
python generate.py --outdir=outputs/generate/stylegan_human_v2_1024 --trunc=1 --seeds=1,3,5,7 --network=pretrained_models/stylegan_human_v2_1024.pkl --version 2

# Generate human full-body images with truncation
python generate.py --outdir=outputs/generate/stylegan_human_v2_1024 --trunc=0.8 --seeds=0-10 --network=pretrained_models/stylegan_human_v2_1024.pkl --version 2

# Generate human full-body images using stylegan V1
python generate.py --outdir=outputs/generate/stylegan_human_v1_1024 --network=pretrained_models/stylegan_human_v1_1024.pkl --version 1 --seeds=1,3,5

# Generate human full-body images using stylegan V3
python generate.py --outdir=outputs/generate/stylegan_human_v3_512 --network=pretrained_models/stylegan_human_v3_512.pkl --version 3 --seeds=1,3,5
```

#### Note: The following demos are generated based on models related to StyleGAN V2 (stylegan_human_v2_512.pkl and stylegan_human_v2_1024.pkl). If you want to see results for V1 or V3, you need to change the loading method of the corresponding models.

### Interpolation
```
python interpolation.py --network=pretrained_models/stylegan_human_v2_1024.pkl --seeds=85,100 --outdir=outputs/inter_gifs
```

### Style-mixing **image** using stylegan2
```
python style_mixing.py --network=pretrained_models/stylegan_human_v2_1024.pkl --rows=85,100,75,458,1500 \
    --cols=55,821,1789,293 --styles=0-3 --outdir=outputs/stylemixing
```

### Style-mixing **video** using stylegan2
```
python stylemixing_video.py --network=pretrained_models/stylegan_human_v2_1024.pkl --row-seed=3859 \
    --col-seeds=3098,31759,3791 --col-styles=8-12 --trunc=0.8 --outdir=outputs/stylemixing_video
```

### Aligned raw images
For alignment, we use [openpose-pytorch](https://github.com/Hzzone/pytorch-openpose) for body-keypoint detection and [PaddlePaddle](https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.5/contrib/PP-HumanSeg) for human segmentation. Before running the alignment script, a few models need to be downloaded:
1. Download [body_pose_model.pth](https://drive.google.com/drive/folders/1JsvI4M4ZTg98fmnCZLFM-3TeovnCRElG?usp=sharing) and place it into openpose/model/.
2. Download and extract [deeplabv3p_resnet50_os8_humanseg_512x512_100k_with_softmax](https://paddleseg.bj.bcebos.com/dygraph/humanseg/export/deeplabv3p_resnet50_os8_humanseg_512x512_100k_with_softmax.zip) into PP_HumanSeg/export_model/deeplabv3p_resnet50_os8_humanseg_512x512_100k_with_softmax.
3. Download and extract [deeplabv3p_resnet50_os8_humanseg_512x512_100k](https://paddleseg.bj.bcebos.com/dygraph/humanseg/train/deeplabv3p_resnet50_os8_humanseg_512x512_100k.zip) into PP_HumanSeg/pretrained_model/deeplabv3p_resnet50_os8_humanseg_512x512_100k.
4. Install paddleseg:
```
pip install paddleseg
```
Then you can start alignment:
```
python alignment.py --image-folder img/test/ --output-folder aligned_image/
```

### Invert real image with [PTI](https://github.com/danielroich/PTI)
Before inversion, please download our PTI weights: [e4e_w+.pt](https://drive.google.com/file/d/1NUfSJqLhsrU7c9PwAtlZ9xtrxhzS_6tu/view?usp=sharing) into /pti/.

A few parameters you can change:
- /pti/pti_configs/hyperparameters.py:
  - first_inv_type = 'w+' -> use the pretrained e4e encoder
  - first_inv_type = 'w' -> use projection and optimization
- /pti/pti_configs/paths_config.py:
  - input_data_path: path of real images
  - e4e: path of e4e_w+.pt
  - stylegan2_ada_shhq: pretrained stylegan2-ada model for SHHQ

```
python run_pti.py
```
Note: we use the test image under 'aligned_image/' (the output of alignment.py); the inverted latent code and fine-tuned generator will be saved in 'outputs/pti/'.

### Editing with InterfaceGAN, StyleSpace, and Sefa
```
python edit.py --network pretrained_models/stylegan_human_v2_1024.pkl --attr_name upper_length \
    --seeds 61531,61570,61571,61610 --outdir outputs/edit_results
```

### Editing using inverted latent code
```
python edit.py --network outputs/pti/checkpoints/model_test.pkl --attr_name upper_length \
    --outdir outputs/edit_results --real True --real_w_path outputs/pti/embeddings/test/PTI/test/0.pt --real_img_path aligned_image/test.png
```
Note:
1. Only `upper_length` and `bottom_length` are available as `attr_name` in this demo.
2. Layers to control and editing strength are set in edit/edit_config.py.

### Demo for [InsetGAN](https://arxiv.org/abs/2203.07293)
We implement a quick demo using the key idea from InsetGAN: combining a face generated by an FFHQ model with a human body generated by our pretrained model, and optimizing both face and body latent codes to obtain a coherent full-body image. Before running the script, you need to download the [FFHQ face model](https://docs.google.com/uc?export=download&confirm=t&id=125OG7SMkXI-Kf2aqiwLLHyCvSW-gZk3M) (or use your own face model), as well as the [pretrained face landmark](https://docs.google.com/uc?export=download&confirm=&id=1A82DnJBJzt8wI2J8ZrCK5fgHcQ2-tcWM) and the [pretrained CNN face detection model for dlib](https://docs.google.com/uc?export=download&confirm=&id=1MduBgju5KFNrQfDLoQXJ_1_h5MnctCIG).
```
python insetgan.py --body_network=pretrained_models/stylegan_human_v2_1024.pkl --face_network=pretrained_models/ffhq.pkl \
    --body_seed=82 --face_seed=43 --trunc=0.6 --outdir=outputs/insetgan/ --video 1
```

## Results

### Editing with inverted real image
(from left to right: real image | inverted image | InterFaceGAN result | StyleSpace result | SeFa result)

https://user-images.githubusercontent.com/98547009/173773800-bb7fe54a-84d3-4b30-9864-a6b7b311f8ff.mp4

### For more demos, please visit our [**web page**](https://stylegan-human.github.io/).
## TODO List - [ ] Release 1024x512 version of StyleGAN-Human based on StyleGAN3 - [ ] Release 512x256 version of StyleGAN-Human based on StyleGAN1 - [ ] Extension of downstream application (InsetGAN): add a face inversion interface to support fusing a user's face image with a StyleGAN-Human body image - [x] Add Inversion Script into the provided editing pipeline - [ ] Release Dataset ## Related Works * (SIGGRAPH 2022) **Text2Human: Text-Driven Controllable Human Image Generation**, Yuming Jiang et al. [[Paper](https://arxiv.org/pdf/2205.15996.pdf)], [[Code](https://github.com/yumingj/Text2Human)], [[Project Page](https://yumingj.github.io/projects/Text2Human.html)], [[Dataset](https://github.com/yumingj/DeepFashion-MultiModal)] * (ICCV 2021) **Talk-to-Edit: Fine-Grained Facial Editing via Dialog**, Yuming Jiang et al. [[Paper](https://arxiv.org/abs/2109.04425)], [[Code](https://github.com/yumingj/Talk-to-Edit)], [[Project Page](https://www.mmlab-ntu.com/project/talkedit/)], [[Dataset](https://mmlab.ie.cuhk.edu.hk/projects/CelebA/CelebA_Dialog.html)] * (Technical Report 2022) **Generalizable Neural Performer: Learning Robust Radiance Fields for Human Novel View Synthesis**, Wei Cheng et al. [[Paper](https://arxiv.org/pdf/2204.11798.pdf)], [[Code](https://github.com/generalizable-neural-performer/gnr)], [[Project Page](https://generalizable-neural-performer.github.io/)], [[Dataset](https://generalizable-neural-performer.github.io/genebody.html)] ## Citation If you find this work useful for your research, please consider citing our paper: ```bibtex @article{fu2022styleganhuman, title={StyleGAN-Human: A Data-Centric Odyssey of Human Generation}, author={Fu, Jianglin and Li, Shikai and Jiang, Yuming and Lin, Kwan-Yee and Qian, Chen and Loy, Chen-Change and Wu, Wayne and Liu, Ziwei}, journal = {arXiv preprint}, volume = {arXiv:2204.11823}, year = {2022} } ``` ## Acknowledgement Part of the code is borrowed from [stylegan (tensorflow)](https://github.com/NVlabs/stylegan), [stylegan2-ada (pytorch)](https://github.com/NVlabs/stylegan2-ada-pytorch), [stylegan3 (pytorch)](https://github.com/NVlabs/stylegan3). ================================================ FILE: stylegan_human/__init__.py ================================================ ================================================ FILE: stylegan_human/alignment.py ================================================ # Copyright (c) SenseTime Research. All rights reserved.
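# Overview: detect body keypoints with OpenPose and segment the person with PP-HumanSeg, then crop and pad each image to a 1:2 (width:height) ratio and resize it to 512x1024 for StyleGAN-Human. # Usage (models downloaded as described in the README): python alignment.py --image-folder img/test/ --output-folder aligned_image/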
import os import argparse import numpy as np import torch from torch.utils.data import DataLoader from torchvision.transforms import transforms from utils.ImagesDataset import ImagesDataset import cv2 import time import copy import imutils # for openpose body keypoint detector : # (src:https://github.com/Hzzone/pytorch-openpose) from openpose.src import util from openpose.src.body import Body # for paddlepaddle human segmentation : #(src: https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/contrib/PP-HumanSeg/) from PP_HumanSeg.deploy.infer import Predictor as PP_HumenSeg_Predictor import math def angle_between_points(p0,p1,p2): if p0[1]==-1 or p1[1]==-1 or p2[1]==-1: return -1 a = (p1[0]-p0[0])**2 + (p1[1]-p0[1])**2 b = (p1[0]-p2[0])**2 + (p1[1]-p2[1])**2 c = (p2[0]-p0[0])**2 + (p2[1]-p0[1])**2 if a * b == 0: return -1 return math.acos((a+b-c) / math.sqrt(4*a*b)) * 180 / math.pi def crop_img_with_padding(img, keypoints, rect): person_xmin,person_xmax, ymin, ymax= rect img_h,img_w,_ = img.shape ## find body center using keypoints middle_shoulder_x = keypoints[1][0] middle_hip_x = (keypoints[8][0] + keypoints[11][0]) // 2 mid_x = (middle_hip_x + middle_shoulder_x) // 2 mid_y = (ymin + ymax) // 2 ## find which side (left or right) is farther from center x, and use that side if abs(mid_x-person_xmin) > abs(person_xmax-mid_x): # left side farther xmin = person_xmin xmax = mid_x + (mid_x-person_xmin) else: ############### may be negative ### in this case, the script won't output any image, leave the case like this ### since we don't want to pad the human body xmin = mid_x - (person_xmax-mid_x) xmax = person_xmax w = xmax - xmin h = ymax - ymin ## pad rectangle to w:h = 1:2 ## calculate desired border length if h / w >= 2: #pad horizontally target_w = h // 2 xmin_prime = int(mid_x - target_w / 2) xmax_prime = int(mid_x + target_w / 2) if xmin_prime < 0: pad_left = abs(xmin_prime)# - xmin xmin = 0 else: pad_left = 0 xmin = xmin_prime if xmax_prime > img_w: pad_right = xmax_prime - img_w xmax = img_w else: pad_right = 0 xmax = xmax_prime cropped_img = img[int(ymin):int(ymax), int(xmin):int(xmax)] im_pad = cv2.copyMakeBorder(cropped_img, 0, 0, int(pad_left), int(pad_right), cv2.BORDER_REPLICATE) else: #pad vertically target_h = w * 2 ymin_prime = mid_y - (target_h / 2) ymax_prime = mid_y + (target_h / 2) if ymin_prime < 0: pad_up = abs(ymin_prime)# - ymin ymin = 0 else: pad_up = 0 ymin = ymin_prime if ymax_prime > img_h: pad_down = ymax_prime - img_h ymax = img_h else: pad_down = 0 ymax = ymax_prime print(ymin,ymax, xmin,xmax, img.shape) cropped_img = img[int(ymin):int(ymax), int(xmin):int(xmax)] im_pad = cv2.copyMakeBorder(cropped_img, int(pad_up), int(pad_down), 0, 0, cv2.BORDER_REPLICATE) result = cv2.resize(im_pad,(512,1024),interpolation = cv2.INTER_AREA) return result def run(args): os.makedirs(args.output_folder, exist_ok=True) dataset = ImagesDataset(args.image_folder, transforms.Compose([transforms.ToTensor()])) dataloader = DataLoader(dataset, batch_size=1, shuffle=False) body_estimation = Body('openpose/model/body_pose_model.pth') total = len(dataloader) print('Num of dataloader : ', total) os.makedirs(f'{args.output_folder}', exist_ok=True) # os.makedirs(f'{args.output_folder}/middle_result', exist_ok=True) ## initialize HumanSeg human_seg_args = {} human_seg_args['cfg'] = 'PP_HumanSeg/export_model/deeplabv3p_resnet50_os8_humanseg_512x512_100k_with_softmax/deploy.yaml' human_seg_args['input_shape'] = [1024,512] human_seg_args['save_dir'] = args.output_folder human_seg_args['soft_predict'] =
False human_seg_args['use_gpu'] = True human_seg_args['test_speed'] = False human_seg_args['use_optic_flow'] = False human_seg_args['add_argmax'] = True human_seg_args = argparse.Namespace(**human_seg_args) human_seg = PP_HumenSeg_Predictor(human_seg_args) from tqdm import tqdm for fname, image in tqdm(dataloader): # try: ## tensor to numpy image fname = fname[0] print(f'Processing \'{fname}\'.') image = (image.permute(0, 2, 3, 1) * 255).clamp(0, 255) image = image.squeeze(0).numpy() # --> tensor to numpy, (H,W,C) # avoid super high res img if image.shape[0] >= 2000: # height ### for shein image ratio = image.shape[0]/1200 #height dim = (int(image.shape[1]/ratio),1200)#(width, height) image = cv2.resize(image, dim, interpolation = cv2.INTER_AREA) image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) ## create segmentation # mybg = cv2.imread('mybg.png') comb, segmentation, bg, ori_img = human_seg.run(image,None) #mybg) # cv2.imwrite('comb.png',comb) # [0,255] # cv2.imwrite('alpha.png',segmentation*255) # segmentation [0,1] --> [0,255] # cv2.imwrite('bg.png',bg) #[0,255] # cv2.imwrite('ori_img.png',ori_img) # [0,255] masks_np = (segmentation* 255)# .byte().cpu().numpy() #1024,512,1 mask0_np = masks_np[:,:,0].astype(np.uint8)#[0, :, :] contours = cv2.findContours(mask0_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) cnts = imutils.grab_contours(contours) c = max(cnts, key=cv2.contourArea) extTop = tuple(c[c[:, :, 1].argmin()][0]) extBot = tuple(c[c[:, :, 1].argmax()][0]) extBot = list(extBot) extTop = list(extTop) pad_range = int((extBot[1]-extTop[1])*0.05) if (int(extTop[1])<=5 and int(extTop[1])>0) and (comb.shape[0]>int(extBot[1]) and int(extBot[1])>=comb.shape[0]-5): # seg mask already reaches the edge #pad with pure white, top 100 px, bottom 100 px comb= cv2.copyMakeBorder(comb,pad_range+5,pad_range+5,0,0,cv2.BORDER_CONSTANT,value=[255,255,255]) elif int(extTop[1])<=0 or int(extBot[1])>=comb.shape[0]: print('PAD: body out of boundary', fname) # should not happen return {} else: comb = cv2.copyMakeBorder(comb, pad_range+5, pad_range+5, 0, 0, cv2.BORDER_REPLICATE) #105 instead of 100: give some extra space extBot[1] = extBot[1] + pad_range+5 extTop[1] = extTop[1] + pad_range+5 extLeft = tuple(c[c[:, :, 0].argmin()][0]) extRight = tuple(c[c[:, :, 0].argmax()][0]) extLeft = list(extLeft) extRight = list(extRight) person_ymin = int(extTop[1])-pad_range # 100 person_ymax = int(extBot[1])+pad_range # 100 #height if person_ymin<0 or person_ymax>comb.shape[0]: # out of range return {} person_xmin = int(extLeft[0]) person_xmax = int(extRight[0]) rect = [person_xmin,person_xmax,person_ymin, person_ymax] # recimg = copy.deepcopy(comb) # cv2.rectangle(recimg,(person_xmin,person_ymin),(person_xmax,person_ymax),(0,255,0),2) # cv2.imwrite(f'{args.output_folder}/middle_result/{fname}_rec.png',recimg) ## detect keypoints keypoints, subset = body_estimation(comb) # print(keypoints, subset, len(subset)) if len(subset) != 1 or (len(subset)==1 and subset[0][-1]<15): print(f'Processing \'{fname}\'. Please use an image that contains one person only; also check the segmentation mask.') continue # canvas = copy.deepcopy(comb) # canvas = util.draw_bodypose(canvas, keypoints, subset, show_number=True) # cv2.imwrite(f'{args.output_folder}/middle_result/{fname}_keypoints.png',canvas) comb = crop_img_with_padding(comb, keypoints, rect) cv2.imwrite(f'{args.output_folder}/{fname}.png', comb) print(f' -- Finished processing \'{fname}\'. --') # except: # print(f'Processing \'{fname}\'.
Did not satisfy the alignment strategy.') if __name__ == '__main__': torch.backends.cudnn.benchmark = True torch.backends.cudnn.deterministic = False t1 = time.time() arg_formatter = argparse.ArgumentDefaultsHelpFormatter description = 'StyleGAN-Human data process' parser = argparse.ArgumentParser(formatter_class=arg_formatter, description=description) parser.add_argument('--image-folder', type=str, dest='image_folder') parser.add_argument('--output-folder', dest='output_folder', default='results', type=str) # parser.add_argument('--cfg', dest='cfg for segmentation', default='PP_HumanSeg/export_model/ppseg_lite_portrait_398x224_with_softmax/deploy.yaml', type=str) print('parsing arguments') cmd_args = parser.parse_args() run(cmd_args) print('total time elapsed: ', str(time.time() - t1)) ================================================ FILE: stylegan_human/bg_white.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. import os import click import cv2 import numpy as np def bg_white(seg, raw, blur_level=3, gaussian=81): seg = cv2.blur(seg, (blur_level, blur_level)) empty = np.ones_like(seg) seg_bg = (empty - seg) * 255 seg_bg = cv2.GaussianBlur(seg_bg,(gaussian,gaussian),0) background_mask = cv2.cvtColor(255 - cv2.cvtColor(seg, cv2.COLOR_BGR2GRAY), cv2.COLOR_GRAY2BGR) masked_fg = (raw * (1 / 255)) * (seg * (1 / 255)) masked_bg = (seg_bg * (1 / 255)) * (background_mask * (1 / 255)) frame = np.uint8(cv2.add(masked_bg,masked_fg)*255) return frame """ Turn the background white. Examples: \b python bg_white.py --raw_img_dir=./SHHQ-1.0/no_segment/ --raw_seg_dir=./SHHQ-1.0/segments/ \\ --outdir=./SHHQ-1.0/bg_white/ """ @click.command() @click.pass_context @click.option('--raw_img_dir', default="./SHHQ-1.0/no_segment/", help='folder of raw images', required=True) @click.option('--raw_seg_dir', default='./SHHQ-1.0/segments/', help='folder of segmentation masks', required=True) @click.option('--outdir', help='Where to save the output images', default="./SHHQ-1.0/bg_white/", type=str, required=True, metavar='DIR') def main( ctx: click.Context, raw_img_dir: str, raw_seg_dir: str, outdir: str): os.makedirs(outdir, exist_ok=True) files = os.listdir(raw_img_dir) for file in files: print(file) raw = cv2.imread(os.path.join(raw_img_dir, file)) seg = cv2.imread(os.path.join(raw_seg_dir, file)) assert raw is not None assert seg is not None white_frame = bg_white(seg, raw) cv2.imwrite(os.path.join(outdir,file), white_frame) if __name__ == "__main__": main() ================================================ FILE: stylegan_human/dnnlib/__init__.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. from .util import EasyDict, make_cache_dir_path ================================================ FILE: stylegan_human/dnnlib/tflib/__init__.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. # # This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit # https://nvlabs.github.io/stylegan2/license.html from . import autosummary from . import network from . import optimizer from . import tfutil from . import custom_ops from .tfutil import * from .network import Network from .optimizer import Optimizer from .custom_ops import get_plugin ================================================ FILE: stylegan_human/dnnlib/tflib/autosummary.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. # # This work is made available under the Nvidia Source Code License-NC. # To view a copy of this license, visit # https://nvlabs.github.io/stylegan2/license.html """Helper for adding automatically tracked values to Tensorboard. Autosummary creates an identity op that internally keeps track of the input values and automatically shows up in TensorBoard. The reported value represents an average over input components. The average is accumulated constantly over time and flushed when save_summaries() is called. Notes: - The output tensor must be used as an input for something else in the graph. Otherwise, the autosummary op will not get executed, and the average value will not get accumulated. - It is perfectly fine to include autosummaries with the same name in several places throughout the graph, even if they are executed concurrently. - It is ok to also pass in a python scalar or numpy array. In this case, it is added to the average immediately. """ from collections import OrderedDict import numpy as np import tensorflow as tf from tensorboard import summary as summary_lib from tensorboard.plugins.custom_scalar import layout_pb2 from . import tfutil from .tfutil import TfExpression from .tfutil import TfExpressionEx # Enable "Custom scalars" tab in TensorBoard for advanced formatting. # Disabled by default to reduce tfevents file size. enable_custom_scalars = False _dtype = tf.float64 _vars = OrderedDict() # name => [var, ...] _immediate = OrderedDict() # name => update_op, update_value _finalized = False _merge_op = None def _create_var(name: str, value_expr: TfExpression) -> TfExpression: """Internal helper for creating autosummary accumulators.""" assert not _finalized name_id = name.replace("/", "_") v = tf.cast(value_expr, _dtype) if v.shape.is_fully_defined(): size = np.prod(v.shape.as_list()) size_expr = tf.constant(size, dtype=_dtype) else: size = None size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype)) if size == 1: if v.shape.ndims != 0: v = tf.reshape(v, []) v = [size_expr, v, tf.square(v)] else: v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))] v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype)) with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None): var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)] update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v)) if name in _vars: _vars[name].append(var) else: _vars[name] = [var] return update_op def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None, condition: TfExpressionEx = True) -> TfExpressionEx: """Create a new autosummary. Args: name: Name to use in TensorBoard value: TensorFlow expression or python value to track passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node. 
Example use of the passthru mechanism: n = autosummary('l2loss', loss, passthru=n) This is a shorthand for the following code: with tf.control_dependencies([autosummary('l2loss', loss)]): n = tf.identity(n) """ tfutil.assert_tf_initialized() name_id = name.replace("/", "_") if tfutil.is_tf_expression(value): with tf.name_scope("summary_" + name_id), tf.device(value.device): condition = tf.convert_to_tensor(condition, name='condition') update_op = tf.cond(condition, lambda: tf.group(_create_var(name, value)), tf.no_op) with tf.control_dependencies([update_op]): return tf.identity(value if passthru is None else passthru) else: # python scalar or numpy array assert not tfutil.is_tf_expression(passthru) assert not tfutil.is_tf_expression(condition) if condition: if name not in _immediate: with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None): update_value = tf.placeholder(_dtype) update_op = _create_var(name, update_value) _immediate[name] = update_op, update_value update_op, update_value = _immediate[name] tfutil.run(update_op, {update_value: value}) return value if passthru is None else passthru def finalize_autosummaries() -> None: """Create the necessary ops to include autosummaries in TensorBoard report. Note: This should be done only once per graph. """ global _finalized tfutil.assert_tf_initialized() if _finalized: return None _finalized = True tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list]) # Create summary ops. with tf.device(None), tf.control_dependencies(None): for name, vars_list in _vars.items(): name_id = name.replace("/", "_") with tfutil.absolute_name_scope("Autosummary/" + name_id): moments = tf.add_n(vars_list) moments /= moments[0] with tf.control_dependencies([moments]): # read before resetting reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list] with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting mean = moments[1] std = tf.sqrt(moments[2] - tf.square(moments[1])) tf.summary.scalar(name, mean) if enable_custom_scalars: tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std) tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std) # Setup layout for custom scalars. layout = None if enable_custom_scalars: cat_dict = OrderedDict() for series_name in sorted(_vars.keys()): p = series_name.split("/") cat = p[0] if len(p) >= 2 else "" chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1] if cat not in cat_dict: cat_dict[cat] = OrderedDict() if chart not in cat_dict[cat]: cat_dict[cat][chart] = [] cat_dict[cat][chart].append(series_name) categories = [] for cat_name, chart_dict in cat_dict.items(): charts = [] for chart_name, series_names in chart_dict.items(): series = [] for series_name in series_names: series.append(layout_pb2.MarginChartContent.Series( value=series_name, lower="xCustomScalars/" + series_name + "/margin_lo", upper="xCustomScalars/" + series_name + "/margin_hi")) margin = layout_pb2.MarginChartContent(series=series) charts.append(layout_pb2.Chart(title=chart_name, margin=margin)) categories.append(layout_pb2.Category(title=cat_name, chart=charts)) layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories)) return layout def save_summaries(file_writer, global_step=None): """Call FileWriter.add_summary() with all summaries in the default graph, automatically finalizing and merging them on the first call. 
""" global _merge_op tfutil.assert_tf_initialized() if _merge_op is None: layout = finalize_autosummaries() if layout is not None: file_writer.add_summary(layout) with tf.device(None), tf.control_dependencies(None): _merge_op = tf.summary.merge_all() file_writer.add_summary(_merge_op.eval(), global_step) ================================================ FILE: stylegan_human/dnnlib/tflib/custom_ops.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. # # This work is made available under the Nvidia Source Code License-NC. # To view a copy of this license, visit # https://nvlabs.github.io/stylegan2/license.html """TensorFlow custom ops builder. """ import os import re import uuid import hashlib import tempfile import shutil import tensorflow as tf from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module #---------------------------------------------------------------------------- # Global options. cuda_cache_path = os.path.join(os.path.dirname(__file__), '_cudacache') cuda_cache_version_tag = 'v1' do_not_hash_included_headers = False # Speed up compilation by assuming that headers included by the CUDA code never change. Unsafe! verbose = True # Print status messages to stdout. compiler_bindir_search_path = [ 'C:/Program Files (x86)/Microsoft Visual Studio/2017/Community/VC/Tools/MSVC/14.14.26428/bin/Hostx64/x64', 'C:/Program Files (x86)/Microsoft Visual Studio/2019/Community/VC/Tools/MSVC/14.23.28105/bin/Hostx64/x64', 'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin', ] #---------------------------------------------------------------------------- # Internal helper funcs. def _find_compiler_bindir(): for compiler_path in compiler_bindir_search_path: if os.path.isdir(compiler_path): return compiler_path return None def _get_compute_cap(device): caps_str = device.physical_device_desc m = re.search('compute capability: (\\d+).(\\d+)', caps_str) major = m.group(1) minor = m.group(2) return (major, minor) def _get_cuda_gpu_arch_string(): gpus = [x for x in device_lib.list_local_devices() if x.device_type == 'GPU'] if len(gpus) == 0: raise RuntimeError('No GPU devices found') (major, minor) = _get_compute_cap(gpus[0]) return 'sm_%s%s' % (major, minor) def _run_cmd(cmd): with os.popen(cmd) as pipe: output = pipe.read() status = pipe.close() if status is not None: raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, output)) def _prepare_nvcc_cli(opts): cmd = 'nvcc ' + opts.strip() cmd += ' --disable-warnings' cmd += ' --include-path "%s"' % tf.sysconfig.get_include() cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'protobuf_archive', 'src') cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'com_google_absl') cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'eigen_archive') compiler_bindir = _find_compiler_bindir() if compiler_bindir is None: # Require that _find_compiler_bindir succeeds on Windows. Allow # nvcc to use whatever is the default on Linux. if os.name == 'nt': raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' 
% __file__) else: cmd += ' --compiler-bindir "%s"' % compiler_bindir cmd += ' 2>&1' return cmd #---------------------------------------------------------------------------- # Main entry point. _plugin_cache = dict() def get_plugin(cuda_file): cuda_file_base = os.path.basename(cuda_file) cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base) # Already in cache? if cuda_file in _plugin_cache: return _plugin_cache[cuda_file] # Setup plugin. if verbose: print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True) try: # Hash CUDA source. md5 = hashlib.md5() with open(cuda_file, 'rb') as f: md5.update(f.read()) md5.update(b'\n') # Hash headers included by the CUDA code by running it through the preprocessor. if not do_not_hash_included_headers: if verbose: print('Preprocessing... ', end='', flush=True) with tempfile.TemporaryDirectory() as tmp_dir: tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext) _run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))) with open(tmp_file, 'rb') as f: bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8') # __FILE__ in error check macros good_file_str = ('"' + cuda_file_base + '"').encode('utf-8') for ln in f: if not ln.startswith(b'# ') and not ln.startswith(b'#line '): # ignore line number pragmas ln = ln.replace(bad_file_str, good_file_str) md5.update(ln) md5.update(b'\n') # Select compiler options. compile_opts = '' if os.name == 'nt': compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.lib') elif os.name == 'posix': compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.so') compile_opts += ' --compiler-options \'-fPIC -D_GLIBCXX_USE_CXX11_ABI=0\'' else: assert False # not Windows or Linux, w00t? compile_opts += ' --gpu-architecture=%s' % _get_cuda_gpu_arch_string() compile_opts += ' --use_fast_math' nvcc_cmd = _prepare_nvcc_cli(compile_opts) # Hash build configuration. md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n') md5.update(('tf.VERSION: ' + tf.VERSION).encode('utf-8') + b'\n') md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n') # Compile if not already compiled. bin_file_ext = '.dll' if os.name == 'nt' else '.so' bin_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + md5.hexdigest() + bin_file_ext) if not os.path.isfile(bin_file): if verbose: print('Compiling... ', end='', flush=True) with tempfile.TemporaryDirectory() as tmp_dir: tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext) _run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)) os.makedirs(cuda_cache_path, exist_ok=True) intermediate_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext) shutil.copyfile(tmp_file, intermediate_file) os.rename(intermediate_file, bin_file) # atomic # Load. if verbose: print('Loading... ', end='', flush=True) plugin = tf.load_op_library(bin_file) # Add to cache. _plugin_cache[cuda_file] = plugin if verbose: print('Done.', flush=True) return plugin except: if verbose: print('Failed!', flush=True) raise #---------------------------------------------------------------------------- ================================================ FILE: stylegan_human/dnnlib/tflib/network.py ================================================ # Copyright (c) SenseTime Research. 
All rights reserved. # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. # # This work is made available under the Nvidia Source Code License-NC. # To view a copy of this license, visit # https://nvlabs.github.io/stylegan2/license.html """Helper for managing networks.""" import types import inspect import re import uuid import sys import numpy as np import tensorflow as tf from collections import OrderedDict from typing import Any, List, Tuple, Union from . import tfutil from .. import util from .tfutil import TfExpression, TfExpressionEx _import_handlers = [] # Custom import handlers for dealing with legacy data in pickle import. _import_module_src = dict() # Source code for temporary modules created during pickle import. def import_handler(handler_func): """Function decorator for declaring custom import handlers.""" _import_handlers.append(handler_func) return handler_func class Network: """Generic network abstraction. Acts as a convenience wrapper for a parameterized network construction function, providing several utility methods and convenient access to the inputs/outputs/weights. Network objects can be safely pickled and unpickled for long-term archival purposes. The pickling works reliably as long as the underlying network construction function is defined in a standalone Python module that has no side effects or application-specific imports. Args: name: Network name. Used to select TensorFlow name and variable scopes. func_name: Fully qualified name of the underlying network construction function, or a top-level function object. static_kwargs: Keyword arguments to be passed in to the network construction function. Attributes: name: User-specified name, defaults to build func name if None. scope: Unique TensorFlow scope containing template graph and variables, derived from the user-specified name. static_kwargs: Arguments passed to the user-supplied build func. components: Container for sub-networks. Passed to the build func, and retained between calls. num_inputs: Number of input tensors. num_outputs: Number of output tensors. input_shapes: Input tensor shapes (NC or NCHW), including minibatch dimension. output_shapes: Output tensor shapes (NC or NCHW), including minibatch dimension. input_shape: Short-hand for input_shapes[0]. output_shape: Short-hand for output_shapes[0]. input_templates: Input placeholders in the template graph. output_templates: Output tensors in the template graph. input_names: Name string for each input. output_names: Name string for each output. own_vars: Variables defined by this network (local_name => var), excluding sub-networks. vars: All variables (local_name => var). trainables: All trainable variables (local_name => var). var_global_to_local: Mapping from variable global names to local names. """ def __init__(self, name: str = None, func_name: Any = None, **static_kwargs): tfutil.assert_tf_initialized() assert isinstance(name, str) or name is None assert func_name is not None assert isinstance(func_name, str) or util.is_top_level_function(func_name) assert util.is_pickleable(static_kwargs) self._init_fields() self.name = name self.static_kwargs = util.EasyDict(static_kwargs) # Locate the user-specified network build function. 
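# func_name may be a fully qualified name ('module.func') or a top-level function object; objects are converted to their qualified name first.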
if util.is_top_level_function(func_name): func_name = util.get_top_level_function_name(func_name) module, self._build_func_name = util.get_module_from_obj_name(func_name) self._build_func = util.get_obj_from_module(module, self._build_func_name) assert callable(self._build_func) # Dig up source code for the module containing the build function. self._build_module_src = _import_module_src.get(module, None) if self._build_module_src is None: self._build_module_src = inspect.getsource(module) # Init TensorFlow graph. self._init_graph() self.reset_own_vars() def _init_fields(self) -> None: self.name = None self.scope = None self.static_kwargs = util.EasyDict() self.components = util.EasyDict() self.num_inputs = 0 self.num_outputs = 0 self.input_shapes = [[]] self.output_shapes = [[]] self.input_shape = [] self.output_shape = [] self.input_templates = [] self.output_templates = [] self.input_names = [] self.output_names = [] self.own_vars = OrderedDict() self.vars = OrderedDict() self.trainables = OrderedDict() self.var_global_to_local = OrderedDict() self._build_func = None # User-supplied build function that constructs the network. self._build_func_name = None # Name of the build function. self._build_module_src = None # Full source code of the module containing the build function. self._run_cache = dict() # Cached graph data for Network.run(). def _init_graph(self) -> None: # Collect inputs. self.input_names = [] for param in inspect.signature(self._build_func).parameters.values(): if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty: self.input_names.append(param.name) self.num_inputs = len(self.input_names) assert self.num_inputs >= 1 # Choose name and scope. if self.name is None: self.name = self._build_func_name assert re.match("^[A-Za-z0-9_.\\-]*$", self.name) with tf.name_scope(None): self.scope = tf.get_default_graph().unique_name(self.name, mark_as_used=True) # Finalize build func kwargs. build_kwargs = dict(self.static_kwargs) build_kwargs["is_template_graph"] = True build_kwargs["components"] = self.components # Build template graph. with tfutil.absolute_variable_scope(self.scope, reuse=False), tfutil.absolute_name_scope(self.scope): # ignore surrounding scopes assert tf.get_variable_scope().name == self.scope assert tf.get_default_graph().get_name_scope() == self.scope with tf.control_dependencies(None): # ignore surrounding control dependencies self.input_templates = [tf.placeholder(tf.float32, name=name) for name in self.input_names] out_expr = self._build_func(*self.input_templates, **build_kwargs) # Collect outputs. assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple) self.output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr) self.num_outputs = len(self.output_templates) assert self.num_outputs >= 1 assert all(tfutil.is_tf_expression(t) for t in self.output_templates) # Perform sanity checks. if any(t.shape.ndims is None for t in self.input_templates): raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.") if any(t.shape.ndims is None for t in self.output_templates): raise ValueError("Network output shapes not defined. 
Please call x.set_shape() where applicable.") if any(not isinstance(comp, Network) for comp in self.components.values()): raise ValueError("Components of a Network must be Networks themselves.") if len(self.components) != len(set(comp.name for comp in self.components.values())): raise ValueError("Components of a Network must have unique names.") # List inputs and outputs. self.input_shapes = [t.shape.as_list() for t in self.input_templates] self.output_shapes = [t.shape.as_list() for t in self.output_templates] self.input_shape = self.input_shapes[0] self.output_shape = self.output_shapes[0] self.output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates] # List variables. self.own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/")) self.vars = OrderedDict(self.own_vars) self.vars.update((comp.name + "/" + name, var) for comp in self.components.values() for name, var in comp.vars.items()) self.trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable) self.var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items()) def reset_own_vars(self) -> None: """Re-initialize all variables of this network, excluding sub-networks.""" tfutil.run([var.initializer for var in self.own_vars.values()]) def reset_vars(self) -> None: """Re-initialize all variables of this network, including sub-networks.""" tfutil.run([var.initializer for var in self.vars.values()]) def reset_trainables(self) -> None: """Re-initialize all trainable variables of this network, including sub-networks.""" tfutil.run([var.initializer for var in self.trainables.values()]) def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]: """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s).""" assert len(in_expr) == self.num_inputs assert not all(expr is None for expr in in_expr) # Finalize build func kwargs. build_kwargs = dict(self.static_kwargs) build_kwargs.update(dynamic_kwargs) build_kwargs["is_template_graph"] = False build_kwargs["components"] = self.components # Build TensorFlow graph to evaluate the network. with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name): assert tf.get_variable_scope().name == self.scope valid_inputs = [expr for expr in in_expr if expr is not None] final_inputs = [] for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes): if expr is not None: expr = tf.identity(expr, name=name) else: expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name) final_inputs.append(expr) out_expr = self._build_func(*final_inputs, **build_kwargs) # Propagate input shapes back to the user-specified expressions. for expr, final in zip(in_expr, final_inputs): if isinstance(expr, tf.Tensor): expr.set_shape(final.shape) # Express outputs in the desired format. 
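# The build func returns either a single expression or a tuple; wrap into a list only when the caller requested one.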
assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple) if return_as_list: out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr) return out_expr def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str: """Get the local name of a given variable, without any surrounding name scopes.""" assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str) global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name return self.var_global_to_local[global_name] def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression: """Find variable by local or global name.""" assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str) return self.vars[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray: """Get the value of a given variable as NumPy array. Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible.""" return self.find_var(var_or_local_name).eval() def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None: """Set the value of a given variable based on the given NumPy array. Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible.""" tfutil.set_vars({self.find_var(var_or_local_name): new_value}) def __getstate__(self) -> dict: """Pickle export.""" state = dict() state["version"] = 4 state["name"] = self.name state["static_kwargs"] = dict(self.static_kwargs) state["components"] = dict(self.components) state["build_module_src"] = self._build_module_src state["build_func_name"] = self._build_func_name state["variables"] = list(zip(self.own_vars.keys(), tfutil.run(list(self.own_vars.values())))) return state def __setstate__(self, state: dict) -> None: """Pickle import.""" # pylint: disable=attribute-defined-outside-init tfutil.assert_tf_initialized() self._init_fields() # Execute custom import handlers. for handler in _import_handlers: state = handler(state) # Set basic fields. assert state["version"] in [2, 3, 4] self.name = state["name"] self.static_kwargs = util.EasyDict(state["static_kwargs"]) self.components = util.EasyDict(state.get("components", {})) self._build_module_src = state["build_module_src"] self._build_func_name = state["build_func_name"] # Create temporary module from the imported source code. module_name = "_tflib_network_import_" + uuid.uuid4().hex module = types.ModuleType(module_name) sys.modules[module_name] = module _import_module_src[module] = self._build_module_src exec(self._build_module_src, module.__dict__) # pylint: disable=exec-used # Locate network build function in the temporary module. self._build_func = util.get_obj_from_module(module, self._build_func_name) assert callable(self._build_func) # Init TensorFlow graph. 
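# Rebuild the template graph from the imported source code, then restore the pickled variable values.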
self._init_graph() self.reset_own_vars() tfutil.set_vars({self.find_var(name): value for name, value in state["variables"]}) def clone(self, name: str = None, **new_static_kwargs) -> "Network": """Create a clone of this network with its own copy of the variables.""" # pylint: disable=protected-access net = object.__new__(Network) net._init_fields() net.name = name if name is not None else self.name net.static_kwargs = util.EasyDict(self.static_kwargs) net.static_kwargs.update(new_static_kwargs) net._build_module_src = self._build_module_src net._build_func_name = self._build_func_name net._build_func = self._build_func net._init_graph() net.copy_vars_from(self) return net def copy_own_vars_from(self, src_net: "Network") -> None: """Copy the values of all variables from the given network, excluding sub-networks.""" names = [name for name in self.own_vars.keys() if name in src_net.own_vars] tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names})) def copy_vars_from(self, src_net: "Network") -> None: """Copy the values of all variables from the given network, including sub-networks.""" names = [name for name in self.vars.keys() if name in src_net.vars] tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names})) def copy_trainables_from(self, src_net: "Network") -> None: """Copy the values of all trainable variables from the given network, including sub-networks.""" names = [name for name in self.trainables.keys() if name in src_net.trainables] tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names})) def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network": """Create new network with the given parameters, and copy all variables from this network.""" if new_name is None: new_name = self.name static_kwargs = dict(self.static_kwargs) static_kwargs.update(new_static_kwargs) net = Network(name=new_name, func_name=new_func_name, **static_kwargs) net.copy_vars_from(self) return net def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation: """Construct a TensorFlow op that updates the variables of this network to be slightly closer to those of the given network.""" with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"): ops = [] for name, var in self.vars.items(): if name in src_net.vars: cur_beta = beta if name in self.trainables else beta_nontrainable new_value = tfutil.lerp(src_net.vars[name], var, cur_beta) ops.append(var.assign(new_value)) return tf.group(*ops) def run(self, *in_arrays: Tuple[Union[np.ndarray, None], ...], input_transform: dict = None, output_transform: dict = None, return_as_list: bool = False, print_progress: bool = False, minibatch_size: int = None, num_gpus: int = 1, assume_frozen: bool = False, **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]: """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s). Args: input_transform: A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network. The dict must contain a 'func' field that points to a top-level function. The function is called with the input TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs. output_transform: A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network. 
The dict must contain a 'func' field that points to a top-level function. The function is called with the output TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs. return_as_list: True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs. print_progress: Print progress to the console? Useful for very large input arrays. minibatch_size: Maximum minibatch size to use, None = disable batching. num_gpus: Number of GPUs to use. assume_frozen: Improve multi-GPU performance by assuming that the trainable parameters will remain unchanged between calls. dynamic_kwargs: Additional keyword arguments to be passed into the network build function. """ assert len(in_arrays) == self.num_inputs assert not all(arr is None for arr in in_arrays) assert input_transform is None or util.is_top_level_function(input_transform["func"]) assert output_transform is None or util.is_top_level_function(output_transform["func"]) output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs) num_items = in_arrays[0].shape[0] if minibatch_size is None: minibatch_size = num_items # Construct unique hash key from all arguments that affect the TensorFlow graph. key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs) def unwind_key(obj): if isinstance(obj, dict): return [(key, unwind_key(value)) for key, value in sorted(obj.items())] if callable(obj): return util.get_top_level_function_name(obj) return obj key = repr(unwind_key(key)) # Build graph. if key not in self._run_cache: with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None): with tf.device("/cpu:0"): in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names] in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr])) out_split = [] for gpu in range(num_gpus): with tf.device("/gpu:%d" % gpu): net_gpu = self.clone() if assume_frozen else self in_gpu = in_split[gpu] if input_transform is not None: in_kwargs = dict(input_transform) in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs) in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu) assert len(in_gpu) == self.num_inputs out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs) if output_transform is not None: out_kwargs = dict(output_transform) out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs) out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu) assert len(out_gpu) == self.num_outputs out_split.append(out_gpu) with tf.device("/cpu:0"): out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)] self._run_cache[key] = in_expr, out_expr # Run minibatches. in_expr, out_expr = self._run_cache[key] out_arrays = [np.empty([num_items] + expr.shape.as_list()[1:], expr.dtype.name) for expr in out_expr] for mb_begin in range(0, num_items, minibatch_size): if print_progress: print("\r%d / %d" % (mb_begin, num_items), end="") mb_end = min(mb_begin + minibatch_size, num_items) mb_num = mb_end - mb_begin mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)] mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in))) for dst, src in zip(out_arrays, mb_out): dst[mb_begin: mb_end] = src # Done.
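# Print the final progress line, then unpack a single output unless a list was requested.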
if print_progress: print("\r%d / %d" % (num_items, num_items)) if not return_as_list: out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays) return out_arrays def list_ops(self) -> List[TfExpression]: include_prefix = self.scope + "/" exclude_prefix = include_prefix + "_" ops = tf.get_default_graph().get_operations() ops = [op for op in ops if op.name.startswith(include_prefix)] ops = [op for op in ops if not op.name.startswith(exclude_prefix)] return ops def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]: """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to individual layers of the network. Mainly intended to be used for reporting.""" layers = [] def recurse(scope, parent_ops, parent_vars, level): # Ignore specific patterns. if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]): return # Filter ops and vars by scope. global_prefix = scope + "/" local_prefix = global_prefix[len(self.scope) + 1:] cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]] cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]] if not cur_ops and not cur_vars: return # Filter out all ops related to variables. for var in [op for op in cur_ops if op.type.startswith("Variable")]: var_prefix = var.name + "/" cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)] # Scope does not contain ops as immediate children => recurse deeper. contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type not in ["Identity", "Cast", "Transpose"] for op in cur_ops) if (level == 0 or not contains_direct_ops) and (len(cur_ops) + len(cur_vars)) > 1: visited = set() for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]: token = rel_name.split("/")[0] if token not in visited: recurse(global_prefix + token, cur_ops, cur_vars, level + 1) visited.add(token) return # Report layer. 
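# Reached a leaf scope: record its local name, its last output tensor, and its trainable variables.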
layer_name = scope[len(self.scope) + 1:] layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1] layer_trainables = [var for _name, var in cur_vars if var.trainable] layers.append((layer_name, layer_output, layer_trainables)) recurse(self.scope, self.list_ops(), list(self.vars.items()), 0) return layers def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None: """Print a summary table of the network structure.""" rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]] rows += [["---"] * 4] total_params = 0 for layer_name, layer_output, layer_trainables in self.list_layers(): num_params = sum(int(np.prod(var.shape.as_list())) for var in layer_trainables) weights = [var for var in layer_trainables if var.name.endswith("/weight:0")] weights.sort(key=lambda x: len(x.name)) if len(weights) == 0 and len(layer_trainables) == 1: weights = layer_trainables total_params += num_params if not hide_layers_with_no_params or num_params != 0: num_params_str = str(num_params) if num_params > 0 else "-" output_shape_str = str(layer_output.shape) weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-" rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]] rows += [["---"] * 4] rows += [["Total", str(total_params), "", ""]] widths = [max(len(cell) for cell in column) for column in zip(*rows)] print() for row in rows: print(" ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths))) print() def setup_weight_histograms(self, title: str = None) -> None: """Construct summary ops to include histograms of all trainable parameters in TensorBoard.""" if title is None: title = self.name with tf.name_scope(None), tf.device(None), tf.control_dependencies(None): for local_name, var in self.trainables.items(): if "/" in local_name: p = local_name.split("/") name = title + "_" + p[-1] + "/" + "_".join(p[:-1]) else: name = title + "_toplevel/" + local_name tf.summary.histogram(name, var) #---------------------------------------------------------------------------- # Backwards-compatible emulation of legacy output transformation in Network.run(). 
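# Older callers passed out_mul/out_add/out_shrink/out_dtype directly to run(); these kwargs are converted below into an equivalent output_transform dict.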
_print_legacy_warning = True def _handle_legacy_output_transforms(output_transform, dynamic_kwargs): global _print_legacy_warning legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"] if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs): return output_transform, dynamic_kwargs if _print_legacy_warning: _print_legacy_warning = False print() print("WARNING: Old-style output transformations in Network.run() are deprecated.") print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'") print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.") print() assert output_transform is None new_kwargs = dict(dynamic_kwargs) new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs} new_transform["func"] = _legacy_output_transform_func return new_transform, new_kwargs def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None): if out_mul != 1.0: expr = [x * out_mul for x in expr] if out_add != 0.0: expr = [x + out_add for x in expr] if out_shrink > 1: ksize = [1, 1, out_shrink, out_shrink] expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr] if out_dtype is not None: if tf.as_dtype(out_dtype).is_integer: expr = [tf.round(x) for x in expr] expr = [tf.saturate_cast(x, out_dtype) for x in expr] return expr ================================================ FILE: stylegan_human/dnnlib/tflib/ops/__init__.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. # # This work is made available under the Nvidia Source Code License-NC. # To view a copy of this license, visit # https://nvlabs.github.io/stylegan2/license.html # empty ================================================ FILE: stylegan_human/dnnlib/tflib/ops/fused_bias_act.cu ================================================ // Copyright (c) SenseTime Research. All rights reserved. // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. // // This work is made available under the Nvidia Source Code License-NC. // To view a copy of this license, visit // https://nvlabs.github.io/stylegan2/license.html #define EIGEN_USE_GPU #define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__ #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include <stdio.h> using namespace tensorflow; using namespace tensorflow::shape_inference; #define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false) //------------------------------------------------------------------------ // CUDA kernel. template <class T> struct FusedBiasActKernelParams { const T* x; // [sizeX] const T* b; // [sizeB] or NULL const T* ref; // [sizeX] or NULL T* y; // [sizeX] int grad; int axis; int act; float alpha; float gain; int sizeX; int sizeB; int stepB; int loopX; }; template <class T> static __global__ void FusedBiasActKernel(const FusedBiasActKernelParams<T> p) { const float expRange = 80.0f; const float halfExpRange = 40.0f; const float seluScale = 1.0507009873554804934193349852946f; const float seluAlpha = 1.6732632423543772848170429916717f; // Loop over elements.
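// Each block covers p.loopX * blockDim.x consecutive elements; each thread processes up to p.loopX of them, spaced blockDim.x apart.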
int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) { // Load and apply bias. float x = (float)p.x[xi]; if (p.b) x += (float)p.b[(xi / p.stepB) % p.sizeB]; float ref = (p.ref) ? (float)p.ref[xi] : 0.0f; if (p.gain != 0.0f & p.act != 9) ref /= p.gain; // Evaluate activation func. float y; switch (p.act * 10 + p.grad) { // linear default: case 10: y = x; break; case 11: y = x; break; case 12: y = 0.0f; break; // relu case 20: y = (x > 0.0f) ? x : 0.0f; break; case 21: y = (ref > 0.0f) ? x : 0.0f; break; case 22: y = 0.0f; break; // lrelu case 30: y = (x > 0.0f) ? x : x * p.alpha; break; case 31: y = (ref > 0.0f) ? x : x * p.alpha; break; case 32: y = 0.0f; break; // tanh case 40: { float c = expf(x); float d = 1.0f / c; y = (x < -expRange) ? -1.0f : (x > expRange) ? 1.0f : (c - d) / (c + d); } break; case 41: y = x * (1.0f - ref * ref); break; case 42: y = x * (1.0f - ref * ref) * (-2.0f * ref); break; // sigmoid case 50: y = (x < -expRange) ? 0.0f : 1.0f / (expf(-x) + 1.0f); break; case 51: y = x * ref * (1.0f - ref); break; case 52: y = x * ref * (1.0f - ref) * (1.0f - 2.0f * ref); break; // elu case 60: y = (x >= 0.0f) ? x : expf(x) - 1.0f; break; case 61: y = (ref >= 0.0f) ? x : x * (ref + 1.0f); break; case 62: y = (ref >= 0.0f) ? 0.0f : x * (ref + 1.0f); break; // selu case 70: y = (x >= 0.0f) ? seluScale * x : (seluScale * seluAlpha) * (expf(x) - 1.0f); break; case 71: y = (ref >= 0.0f) ? x * seluScale : x * (ref + seluScale * seluAlpha); break; case 72: y = (ref >= 0.0f) ? 0.0f : x * (ref + seluScale * seluAlpha); break; // softplus case 80: y = (x > expRange) ? x : logf(expf(x) + 1.0f); break; case 81: y = x * (1.0f - expf(-ref)); break; case 82: { float c = expf(-ref); y = x * c * (1.0f - c); } break; // swish case 90: y = (x < -expRange) ? 0.0f : x / (expf(-x) + 1.0f); break; case 91: { float c = expf(ref); float d = c + 1.0f; y = (ref > halfExpRange) ? x : x * c * (ref + d) / (d * d); } break; case 92: { float c = expf(ref); float d = c + 1.0f; y = (ref > halfExpRange) ? 0.0f : x * c * (ref * (2.0f - d) + 2.0f * d) / (d * d * d); } break; } // Apply gain and store. p.y[xi] = (T)(y * p.gain); } } //------------------------------------------------------------------------ // TensorFlow op. template <class T> struct FusedBiasActOp : public OpKernel { FusedBiasActKernelParams<T> m_attribs; FusedBiasActOp(OpKernelConstruction* ctx) : OpKernel(ctx) { memset(&m_attribs, 0, sizeof(m_attribs)); OP_REQUIRES_OK(ctx, ctx->GetAttr("grad", &m_attribs.grad)); OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &m_attribs.axis)); OP_REQUIRES_OK(ctx, ctx->GetAttr("act", &m_attribs.act)); OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &m_attribs.alpha)); OP_REQUIRES_OK(ctx, ctx->GetAttr("gain", &m_attribs.gain)); OP_REQUIRES(ctx, m_attribs.grad >= 0, errors::InvalidArgument("grad must be non-negative")); OP_REQUIRES(ctx, m_attribs.axis >= 0, errors::InvalidArgument("axis must be non-negative")); OP_REQUIRES(ctx, m_attribs.act >= 0, errors::InvalidArgument("act must be non-negative")); } void Compute(OpKernelContext* ctx) { FusedBiasActKernelParams<T> p = m_attribs; cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream(); const Tensor& x = ctx->input(0); // [...] const Tensor& b = ctx->input(1); // [sizeB] or [0] const Tensor& ref = ctx->input(2); // x.shape or [0] p.x = x.flat<T>().data(); p.b = (b.NumElements()) ? b.flat<T>().data() : NULL; p.ref = (ref.NumElements()) ?
ref.flat<T>().data() : NULL; OP_REQUIRES(ctx, b.NumElements() == 0 || m_attribs.axis < x.dims(), errors::InvalidArgument("axis out of bounds")); OP_REQUIRES(ctx, b.dims() == 1, errors::InvalidArgument("b must have rank 1")); OP_REQUIRES(ctx, b.NumElements() == 0 || b.NumElements() == x.dim_size(m_attribs.axis), errors::InvalidArgument("b has wrong number of elements")); OP_REQUIRES(ctx, ref.NumElements() == ((p.grad == 0) ? 0 : x.NumElements()), errors::InvalidArgument("ref has wrong number of elements")); OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("x is too large")); p.sizeX = (int)x.NumElements(); p.sizeB = (int)b.NumElements(); p.stepB = 1; for (int i = m_attribs.axis + 1; i < x.dims(); i++) p.stepB *= (int)x.dim_size(i); Tensor* y = NULL; // x.shape OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y)); p.y = y->flat<T>().data(); p.loopX = 4; int blockSize = 4 * 32; int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; void* args[] = {&p}; OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)FusedBiasActKernel<T>, gridSize, blockSize, args, 0, stream)); } }; REGISTER_OP("FusedBiasAct") .Input ("x: T") .Input ("b: T") .Input ("ref: T") .Output ("y: T") .Attr ("T: {float, half}") .Attr ("grad: int = 0") .Attr ("axis: int = 1") .Attr ("act: int = 0") .Attr ("alpha: float = 0.0") .Attr ("gain: float = 1.0"); REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<float>("T"), FusedBiasActOp<float>); REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), FusedBiasActOp<Eigen::half>); //------------------------------------------------------------------------ ================================================ FILE: stylegan_human/dnnlib/tflib/ops/fused_bias_act.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. # # This work is made available under the Nvidia Source Code License-NC. # To view a copy of this license, visit # https://nvlabs.github.io/stylegan2/license.html """Custom TensorFlow ops for efficient bias and activation.""" import os import numpy as np import tensorflow as tf from ..
import custom_ops from ...util import EasyDict def _get_plugin(): return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu') #---------------------------------------------------------------------------- activation_funcs = { 'linear': EasyDict(func=lambda x, **_: x, def_alpha=None, def_gain=1.0, cuda_idx=1, ref='y', zero_2nd_grad=True), 'relu': EasyDict(func=lambda x, **_: tf.nn.relu(x), def_alpha=None, def_gain=np.sqrt(2), cuda_idx=2, ref='y', zero_2nd_grad=True), 'lrelu': EasyDict(func=lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', zero_2nd_grad=True), 'tanh': EasyDict(func=lambda x, **_: tf.nn.tanh(x), def_alpha=None, def_gain=1.0, cuda_idx=4, ref='y', zero_2nd_grad=False), 'sigmoid': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x), def_alpha=None, def_gain=1.0, cuda_idx=5, ref='y', zero_2nd_grad=False), 'elu': EasyDict(func=lambda x, **_: tf.nn.elu(x), def_alpha=None, def_gain=1.0, cuda_idx=6, ref='y', zero_2nd_grad=False), 'selu': EasyDict(func=lambda x, **_: tf.nn.selu(x), def_alpha=None, def_gain=1.0, cuda_idx=7, ref='y', zero_2nd_grad=False), 'softplus': EasyDict(func=lambda x, **_: tf.nn.softplus(x), def_alpha=None, def_gain=1.0, cuda_idx=8, ref='y', zero_2nd_grad=False), 'swish': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x) * x, def_alpha=None, def_gain=np.sqrt(2), cuda_idx=9, ref='x', zero_2nd_grad=False), } #---------------------------------------------------------------------------- def fused_bias_act(x, b=None, axis=1, act='linear', alpha=None, gain=None, impl='cuda'): r"""Fused bias and activation function. Adds bias `b` to activation tensor `x`, evaluates activation function `act`, and scales the result by `gain`. Each of the steps is optional. In most cases, the fused op is considerably more efficient than performing the same calculation using standard TensorFlow ops. It supports first and second order gradients, but not third order gradients. Args: x: Input activation tensor. Can have any shape, but if `b` is defined, the dimension corresponding to `axis`, as well as the rank, must be known. b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type as `x`. The shape must be known, and it must match the dimension of `x` corresponding to `axis`. axis: The dimension in `x` corresponding to the elements of `b`. The value of `axis` is ignored if `b` is not specified. act: Name of the activation function to evaluate, or `"linear"` to disable. Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. See `activation_funcs` for a full list. `None` is not allowed. alpha: Shape parameter for the activation function, or `None` to use the default. gain: Scaling factor for the output tensor, or `None` to use default. See `activation_funcs` for the default scaling of each activation function. If unsure, consider specifying `1.0`. impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). Returns: Tensor of the same shape and datatype as `x`. """ impl_dict = { 'ref': _fused_bias_act_ref, 'cuda': _fused_bias_act_cuda, } return impl_dict[impl](x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain) #---------------------------------------------------------------------------- def _fused_bias_act_ref(x, b, axis, act, alpha, gain): """Slow reference implementation of `fused_bias_act()` using standard TensorFlow ops.""" # Validate arguments. 
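# Illustrative usage sketch (not part of the original docs): with the default
# gain for 'lrelu', the call
#   y = fused_bias_act(x, b, act='lrelu', alpha=0.2, impl='ref')
# computes tf.nn.leaky_relu(x + b, 0.2) * np.sqrt(2); impl='cuda' returns the
# same values via the custom op.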
x = tf.convert_to_tensor(x) b = tf.convert_to_tensor(b) if b is not None else tf.constant([], dtype=x.dtype) act_spec = activation_funcs[act] assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis]) assert b.shape[0] == 0 or 0 <= axis < x.shape.rank if alpha is None: alpha = act_spec.def_alpha if gain is None: gain = act_spec.def_gain # Add bias. if b.shape[0] != 0: x += tf.reshape(b, [-1 if i == axis else 1 for i in range(x.shape.rank)]) # Evaluate activation function. x = act_spec.func(x, alpha=alpha) # Scale by gain. if gain != 1: x *= gain return x #---------------------------------------------------------------------------- def _fused_bias_act_cuda(x, b, axis, act, alpha, gain): """Fast CUDA implementation of `fused_bias_act()` using custom ops.""" # Validate arguments. x = tf.convert_to_tensor(x) empty_tensor = tf.constant([], dtype=x.dtype) b = tf.convert_to_tensor(b) if b is not None else empty_tensor act_spec = activation_funcs[act] assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis]) assert b.shape[0] == 0 or 0 <= axis < x.shape.rank if alpha is None: alpha = act_spec.def_alpha if gain is None: gain = act_spec.def_gain # Special cases. if act == 'linear' and b is None and gain == 1.0: return x if act_spec.cuda_idx is None: return _fused_bias_act_ref(x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain) # CUDA kernel. cuda_kernel = _get_plugin().fused_bias_act cuda_kwargs = dict(axis=axis, act=act_spec.cuda_idx, alpha=alpha, gain=gain) # Forward pass: y = func(x, b). def func_y(x, b): y = cuda_kernel(x=x, b=b, ref=empty_tensor, grad=0, **cuda_kwargs) y.set_shape(x.shape) return y # Backward pass: dx, db = grad(dy, x, y) def grad_dx(dy, x, y): ref = {'x': x, 'y': y}[act_spec.ref] dx = cuda_kernel(x=dy, b=empty_tensor, ref=ref, grad=1, **cuda_kwargs) dx.set_shape(x.shape) return dx def grad_db(dx): if b.shape[0] == 0: return empty_tensor db = dx if axis < x.shape.rank - 1: db = tf.reduce_sum(db, list(range(axis + 1, x.shape.rank))) if axis > 0: db = tf.reduce_sum(db, list(range(axis))) db.set_shape(b.shape) return db # Second order gradients: d_dy, d_x = grad2(d_dx, d_db, x, y) def grad2_d_dy(d_dx, d_db, x, y): ref = {'x': x, 'y': y}[act_spec.ref] d_dy = cuda_kernel(x=d_dx, b=d_db, ref=ref, grad=1, **cuda_kwargs) d_dy.set_shape(x.shape) return d_dy def grad2_d_x(d_dx, d_db, x, y): ref = {'x': x, 'y': y}[act_spec.ref] d_x = cuda_kernel(x=d_dx, b=d_db, ref=ref, grad=2, **cuda_kwargs) d_x.set_shape(x.shape) return d_x # Fast version for piecewise-linear activation funcs. @tf.custom_gradient def func_zero_2nd_grad(x, b): y = func_y(x, b) @tf.custom_gradient def grad(dy): dx = grad_dx(dy, x, y) db = grad_db(dx) def grad2(d_dx, d_db): d_dy = grad2_d_dy(d_dx, d_db, x, y) return d_dy return (dx, db), grad2 return y, grad # Slow version for general activation funcs. @tf.custom_gradient def func_nonzero_2nd_grad(x, b): y = func_y(x, b) def grad_wrap(dy): @tf.custom_gradient def grad_impl(dy, x): dx = grad_dx(dy, x, y) db = grad_db(dx) def grad2(d_dx, d_db): d_dy = grad2_d_dy(d_dx, d_db, x, y) d_x = grad2_d_x(d_dx, d_db, x, y) return d_dy, d_x return (dx, db), grad2 return grad_impl(dy, x) return y, grad_wrap # Which version to use? 
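# For the piecewise-linear activations (linear, relu, lrelu), activation_funcs
# marks zero_2nd_grad=True: their second derivative vanishes almost everywhere,
# so the fast path below only needs to propagate d_dy and can skip the d_x term
# of the general second-order gradient.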
if act_spec.zero_2nd_grad: return func_zero_2nd_grad(x, b) return func_nonzero_2nd_grad(x, b) #---------------------------------------------------------------------------- ================================================ FILE: stylegan_human/dnnlib/tflib/ops/upfirdn_2d.cu ================================================ // Copyright (c) SenseTime Research. All rights reserved. // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. // // This work is made available under the Nvidia Source Code License-NC. // To view a copy of this license, visit // https://nvlabs.github.io/stylegan2/license.html #define EIGEN_USE_GPU #define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__ #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include <stdio.h> using namespace tensorflow; using namespace tensorflow::shape_inference; //------------------------------------------------------------------------ // Helpers. #define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false) static __host__ __device__ __forceinline__ int floorDiv(int a, int b) { int c = a / b; if (c * b > a) c--; return c; } //------------------------------------------------------------------------ // CUDA kernel params. template <class T> struct UpFirDn2DKernelParams { const T* x; // [majorDim, inH, inW, minorDim] const T* k; // [kernelH, kernelW] T* y; // [majorDim, outH, outW, minorDim] int upx; int upy; int downx; int downy; int padx0; int padx1; int pady0; int pady1; int majorDim; int inH; int inW; int minorDim; int kernelH; int kernelW; int outH; int outW; int loopMajor; int loopX; }; //------------------------------------------------------------------------ // General CUDA implementation for large filter kernels. template <class T> static __global__ void UpFirDn2DKernel_large(const UpFirDn2DKernelParams<T> p) { // Calculate thread index. int minorIdx = blockIdx.x * blockDim.x + threadIdx.x; int outY = minorIdx / p.minorDim; minorIdx -= outY * p.minorDim; int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; int majorIdxBase = blockIdx.z * p.loopMajor; if (outXBase >= p.outW || outY >= p.outH || majorIdxBase >= p.majorDim) return; // Setup Y receptive field. int midY = outY * p.downy + p.upy - 1 - p.pady0; int inY = min(max(floorDiv(midY, p.upy), 0), p.inH); int h = min(max(floorDiv(midY + p.kernelH, p.upy), 0), p.inH) - inY; int kernelY = midY + p.kernelH - (inY + 1) * p.upy; // Loop over majorDim and outX. for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor && majorIdx < p.majorDim; loopMajor++, majorIdx++) for (int loopX = 0, outX = outXBase; loopX < p.loopX && outX < p.outW; loopX++, outX += blockDim.y) { // Setup X receptive field. int midX = outX * p.downx + p.upx - 1 - p.padx0; int inX = min(max(floorDiv(midX, p.upx), 0), p.inW); int w = min(max(floorDiv(midX + p.kernelW, p.upx), 0), p.inW) - inX; int kernelX = midX + p.kernelW - (inX + 1) * p.upx; // Initialize pointers. const T* xp = &p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx]; const T* kp = &p.k[kernelY * p.kernelW + kernelX]; int xpx = p.minorDim; int kpx = -p.upx; int xpy = p.inW * p.minorDim; int kpy = -p.upy * p.kernelW; // Inner loop. float v = 0.0f; for (int y = 0; y < h; y++) { for (int x = 0; x < w; x++) { v += (float)(*xp) * (float)(*kp); xp += xpx; kp += kpx; } xp += xpy - w * xpx; kp += kpy - w * kpx; } // Store result.
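// Tensors are laid out row-major as [majorDim, outH, outW, minorDim]; e.g. with
// outH = outW = 4 and minorDim = 1, the element (majorIdx=0, outY=1, outX=2)
// stores to flat offset ((0 * 4 + 1) * 4 + 2) * 1 + 0 = 6.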
p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v; } } //------------------------------------------------------------------------ // Specialized CUDA implementation for small filter kernels. template <class T, int upx, int upy, int downx, int downy, int kernelW, int kernelH, int tileOutW, int tileOutH> static __global__ void UpFirDn2DKernel_small(const UpFirDn2DKernelParams<T> p) { //assert(kernelW % upx == 0); //assert(kernelH % upy == 0); const int tileInW = ((tileOutW - 1) * downx + kernelW - 1) / upx + 1; const int tileInH = ((tileOutH - 1) * downy + kernelH - 1) / upy + 1; __shared__ volatile float sk[kernelH][kernelW]; __shared__ volatile float sx[tileInH][tileInW]; // Calculate tile index. int minorIdx = blockIdx.x; int tileOutY = minorIdx / p.minorDim; minorIdx -= tileOutY * p.minorDim; tileOutY *= tileOutH; int tileOutXBase = blockIdx.y * p.loopX * tileOutW; int majorIdxBase = blockIdx.z * p.loopMajor; if (tileOutXBase >= p.outW | tileOutY >= p.outH | majorIdxBase >= p.majorDim) return; // Load filter kernel (flipped). for (int tapIdx = threadIdx.x; tapIdx < kernelH * kernelW; tapIdx += blockDim.x) { int ky = tapIdx / kernelW; int kx = tapIdx - ky * kernelW; float v = 0.0f; if (kx < p.kernelW & ky < p.kernelH) v = (float)p.k[(p.kernelH - 1 - ky) * p.kernelW + (p.kernelW - 1 - kx)]; sk[ky][kx] = v; } // Loop over majorDim and outX. for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor & majorIdx < p.majorDim; loopMajor++, majorIdx++) for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outW; loopX++, tileOutX += tileOutW) { // Load input pixels. int tileMidX = tileOutX * downx + upx - 1 - p.padx0; int tileMidY = tileOutY * downy + upy - 1 - p.pady0; int tileInX = floorDiv(tileMidX, upx); int tileInY = floorDiv(tileMidY, upy); __syncthreads(); for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW; inIdx += blockDim.x) { int relInY = inIdx / tileInW; int relInX = inIdx - relInY * tileInW; int inX = relInX + tileInX; int inY = relInY + tileInY; float v = 0.0f; if (inX >= 0 & inY >= 0 & inX < p.inW & inY < p.inH) v = (float)p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx]; sx[relInY][relInX] = v; } // Loop over output pixels. __syncthreads(); for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW; outIdx += blockDim.x) { int relOutY = outIdx / tileOutW; int relOutX = outIdx - relOutY * tileOutW; int outX = relOutX + tileOutX; int outY = relOutY + tileOutY; // Setup receptive field. int midX = tileMidX + relOutX * downx; int midY = tileMidY + relOutY * downy; int inX = floorDiv(midX, upx); int inY = floorDiv(midY, upy); int relInX = inX - tileInX; int relInY = inY - tileInY; int kernelX = (inX + 1) * upx - midX - 1; // flipped int kernelY = (inY + 1) * upy - midY - 1; // flipped // Inner loop. float v = 0.0f; #pragma unroll for (int y = 0; y < kernelH / upy; y++) #pragma unroll for (int x = 0; x < kernelW / upx; x++) v += sx[relInY + y][relInX + x] * sk[kernelY + y * upy][kernelX + x * upx]; // Store result. if (outX < p.outW & outY < p.outH) p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v; } } } //------------------------------------------------------------------------ // TensorFlow op.
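// Shared-memory footprint of the small kernel above: an output tile of
// tileOutW x tileOutH pixels needs an input tile of
// tileInW = ((tileOutW - 1) * downx + kernelW - 1) / upx + 1 columns, e.g.
// (63 * 1 + 6) / 1 + 1 = 70 for tileOutW = 64, downx = upx = 1, kernelW = 7.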
template <class T> struct UpFirDn2DOp : public OpKernel { UpFirDn2DKernelParams<T> m_attribs; UpFirDn2DOp(OpKernelConstruction* ctx) : OpKernel(ctx) { memset(&m_attribs, 0, sizeof(m_attribs)); OP_REQUIRES_OK(ctx, ctx->GetAttr("upx", &m_attribs.upx)); OP_REQUIRES_OK(ctx, ctx->GetAttr("upy", &m_attribs.upy)); OP_REQUIRES_OK(ctx, ctx->GetAttr("downx", &m_attribs.downx)); OP_REQUIRES_OK(ctx, ctx->GetAttr("downy", &m_attribs.downy)); OP_REQUIRES_OK(ctx, ctx->GetAttr("padx0", &m_attribs.padx0)); OP_REQUIRES_OK(ctx, ctx->GetAttr("padx1", &m_attribs.padx1)); OP_REQUIRES_OK(ctx, ctx->GetAttr("pady0", &m_attribs.pady0)); OP_REQUIRES_OK(ctx, ctx->GetAttr("pady1", &m_attribs.pady1)); OP_REQUIRES(ctx, m_attribs.upx >= 1 && m_attribs.upy >= 1, errors::InvalidArgument("upx and upy must be at least 1x1")); OP_REQUIRES(ctx, m_attribs.downx >= 1 && m_attribs.downy >= 1, errors::InvalidArgument("downx and downy must be at least 1x1")); } void Compute(OpKernelContext* ctx) { UpFirDn2DKernelParams<T> p = m_attribs; cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream(); const Tensor& x = ctx->input(0); // [majorDim, inH, inW, minorDim] const Tensor& k = ctx->input(1); // [kernelH, kernelW] p.x = x.flat<T>().data(); p.k = k.flat<T>().data(); OP_REQUIRES(ctx, x.dims() == 4, errors::InvalidArgument("input must have rank 4")); OP_REQUIRES(ctx, k.dims() == 2, errors::InvalidArgument("kernel must have rank 2")); OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("input too large")); OP_REQUIRES(ctx, k.NumElements() <= kint32max, errors::InvalidArgument("kernel too large")); p.majorDim = (int)x.dim_size(0); p.inH = (int)x.dim_size(1); p.inW = (int)x.dim_size(2); p.minorDim = (int)x.dim_size(3); p.kernelH = (int)k.dim_size(0); p.kernelW = (int)k.dim_size(1); OP_REQUIRES(ctx, p.kernelW >= 1 && p.kernelH >= 1, errors::InvalidArgument("kernel must be at least 1x1")); p.outW = (p.inW * p.upx + p.padx0 + p.padx1 - p.kernelW + p.downx) / p.downx; p.outH = (p.inH * p.upy + p.pady0 + p.pady1 - p.kernelH + p.downy) / p.downy; OP_REQUIRES(ctx, p.outW >= 1 && p.outH >= 1, errors::InvalidArgument("output must be at least 1x1")); Tensor* y = NULL; // [majorDim, outH, outW, minorDim] TensorShape ys; ys.AddDim(p.majorDim); ys.AddDim(p.outH); ys.AddDim(p.outW); ys.AddDim(p.minorDim); OP_REQUIRES_OK(ctx, ctx->allocate_output(0, ys, &y)); p.y = y->flat<T>().data(); OP_REQUIRES(ctx, y->NumElements() <= kint32max, errors::InvalidArgument("output too large")); // Choose CUDA kernel to use.
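// Worked example of the outW formula below: 2x upsampling of a 64-pixel-wide
// image with a 4-tap filter and padx0 = 2, padx1 = 1 (the padding chosen by
// upsample_2d() in the Python wrapper) gives
// outW = (64 * 2 + 2 + 1 - 4 + 1) / 1 = 128.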
void* cudaKernel = (void*)UpFirDn2DKernel_large<T>; int tileOutW = -1; int tileOutH = -1; if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 7 && p.kernelH <= 7) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 7,7, 64,16>; tileOutW = 64; tileOutH = 16; } if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 6,6, 64,16>; tileOutW = 64; tileOutH = 16; } if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 5 && p.kernelH <= 5) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 5,5, 64,16>; tileOutW = 64; tileOutH = 16; } if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 4,4, 64,16>; tileOutW = 64; tileOutH = 16; } if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 3 && p.kernelH <= 3) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 3,3, 64,16>; tileOutW = 64; tileOutH = 16; } if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8 && p.kernelH <= 8) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 8,8, 64,16>; tileOutW = 64; tileOutH = 16; } if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 6,6, 64,16>; tileOutW = 64; tileOutH = 16; } if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 4,4, 64,16>; tileOutW = 64; tileOutH = 16; } if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 2 && p.kernelH <= 2) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 2,2, 64,16>; tileOutW = 64; tileOutH = 16; } if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 8 && p.kernelH <= 8) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 8,8, 32,8>; tileOutW = 32; tileOutH = 8; } if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 6,6, 32,8>; tileOutW = 32; tileOutH = 8; } if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 4,4, 32,8>; tileOutW = 32; tileOutH = 8; } if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 2 && p.kernelH <= 2) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 2,2, 32,8>; tileOutW = 32; tileOutH = 8; } // Choose launch params. dim3 blockSize; dim3 gridSize; if (tileOutW > 0 && tileOutH > 0) // small { p.loopMajor = (p.majorDim - 1) / 16384 + 1; p.loopX = 1; blockSize = dim3(32 * 8, 1, 1); gridSize = dim3(((p.outH - 1) / tileOutH + 1) * p.minorDim, (p.outW - 1) / (p.loopX * tileOutW) + 1, (p.majorDim - 1) / p.loopMajor + 1); } else // large { p.loopMajor = (p.majorDim - 1) / 16384 + 1; p.loopX = 4; blockSize = dim3(4, 32, 1); gridSize = dim3((p.outH * p.minorDim - 1) / blockSize.x + 1, (p.outW - 1) / (p.loopX * blockSize.y) + 1, (p.majorDim - 1) / p.loopMajor + 1); } // Launch CUDA kernel.
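// E.g. for the small kernel with outH = outW = 64, minorDim = 1, majorDim = 16,
// tileOutW = 64, tileOutH = 16: loopMajor = 1, blockSize = (256, 1, 1), and
// gridSize = (((64-1)/16 + 1) * 1, (64-1)/(1*64) + 1, (16-1)/1 + 1) = (4, 1, 16).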
void* args[] = {&p}; OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(cudaKernel, gridSize, blockSize, args, 0, stream)); } }; REGISTER_OP("UpFirDn2D") .Input ("x: T") .Input ("k: T") .Output ("y: T") .Attr ("T: {float, half}") .Attr ("upx: int = 1") .Attr ("upy: int = 1") .Attr ("downx: int = 1") .Attr ("downy: int = 1") .Attr ("padx0: int = 0") .Attr ("padx1: int = 0") .Attr ("pady0: int = 0") .Attr ("pady1: int = 0"); REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<float>("T"), UpFirDn2DOp<float>); REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), UpFirDn2DOp<Eigen::half>); //------------------------------------------------------------------------ ================================================ FILE: stylegan_human/dnnlib/tflib/ops/upfirdn_2d.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. # # This work is made available under the Nvidia Source Code License-NC. # To view a copy of this license, visit # https://nvlabs.github.io/stylegan2/license.html """Custom TensorFlow ops for efficient resampling of 2D images.""" import os import numpy as np import tensorflow as tf from .. import custom_ops def _get_plugin(): return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu') #---------------------------------------------------------------------------- def upfirdn_2d(x, k, upx=1, upy=1, downx=1, downy=1, padx0=0, padx1=0, pady0=0, pady1=0, impl='cuda'): r"""Pad, upsample, FIR filter, and downsample a batch of 2D images. Accepts a batch of 2D images of the shape `[majorDim, inH, inW, minorDim]` and performs the following operations for each image, batched across `majorDim` and `minorDim`: 1. Pad the image with zeros by the specified number of pixels on each side (`padx0`, `padx1`, `pady0`, `pady1`). Specifying a negative value corresponds to cropping the image. 2. Upsample the image by inserting zeros after each pixel (`upx`, `upy`). 3. Convolve the image with the specified 2D FIR filter (`k`), shrinking the image so that the footprint of all output pixels lies within the input image. 4. Downsample the image by throwing away pixels (`downx`, `downy`). This sequence of operations bears close resemblance to scipy.signal.upfirdn(). The fused op is considerably more efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary order. Args: x: Input tensor of the shape `[majorDim, inH, inW, minorDim]`. k: 2D FIR filter of the shape `[firH, firW]`. upx: Integer upsampling factor along the X-axis (default: 1). upy: Integer upsampling factor along the Y-axis (default: 1). downx: Integer downsampling factor along the X-axis (default: 1). downy: Integer downsampling factor along the Y-axis (default: 1). padx0: Number of pixels to pad on the left side (default: 0). padx1: Number of pixels to pad on the right side (default: 0). pady0: Number of pixels to pad on the top side (default: 0). pady1: Number of pixels to pad on the bottom side (default: 0). impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). Returns: Tensor of the shape `[majorDim, outH, outW, minorDim]`, and same datatype as `x`.
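For example, with `inH = inW = 4`, `upx = upy = 2`, `downx = downy = 1`, a 4-tap
filter, and padding `padx0 = pady0 = 2`, `padx1 = pady1 = 1`, the output is
`outH = outW = (4 * 2 + 2 + 1 - 4) // 1 + 1 = 8`.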
""" impl_dict = { 'ref': _upfirdn_2d_ref, 'cuda': _upfirdn_2d_cuda, } return impl_dict[impl](x=x, k=k, upx=upx, upy=upy, downx=downx, downy=downy, padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1) #---------------------------------------------------------------------------- def _upfirdn_2d_ref(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1): """Slow reference implementation of `upfirdn_2d()` using standard TensorFlow ops.""" x = tf.convert_to_tensor(x) k = np.asarray(k, dtype=np.float32) assert x.shape.rank == 4 inH = x.shape[1].value inW = x.shape[2].value minorDim = _shape(x, 3) kernelH, kernelW = k.shape assert inW >= 1 and inH >= 1 assert kernelW >= 1 and kernelH >= 1 assert isinstance(upx, int) and isinstance(upy, int) assert isinstance(downx, int) and isinstance(downy, int) assert isinstance(padx0, int) and isinstance(padx1, int) assert isinstance(pady0, int) and isinstance(pady1, int) # Upsample (insert zeros). x = tf.reshape(x, [-1, inH, 1, inW, 1, minorDim]) x = tf.pad(x, [[0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1], [0, 0]]) x = tf.reshape(x, [-1, inH * upy, inW * upx, minorDim]) # Pad (crop if negative). x = tf.pad(x, [[0, 0], [max(pady0, 0), max(pady1, 0)], [max(padx0, 0), max(padx1, 0)], [0, 0]]) x = x[:, max(-pady0, 0) : x.shape[1].value - max(-pady1, 0), max(-padx0, 0) : x.shape[2].value - max(-padx1, 0), :] # Convolve with filter. x = tf.transpose(x, [0, 3, 1, 2]) x = tf.reshape(x, [-1, 1, inH * upy + pady0 + pady1, inW * upx + padx0 + padx1]) w = tf.constant(k[::-1, ::-1, np.newaxis, np.newaxis], dtype=x.dtype) x = tf.nn.conv2d(x, w, strides=[1,1,1,1], padding='VALID', data_format='NCHW') x = tf.reshape(x, [-1, minorDim, inH * upy + pady0 + pady1 - kernelH + 1, inW * upx + padx0 + padx1 - kernelW + 1]) x = tf.transpose(x, [0, 2, 3, 1]) # Downsample (throw away pixels). 
return x[:, ::downy, ::downx, :] #---------------------------------------------------------------------------- def _upfirdn_2d_cuda(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1): """Fast CUDA implementation of `upfirdn_2d()` using custom ops.""" x = tf.convert_to_tensor(x) k = np.asarray(k, dtype=np.float32) majorDim, inH, inW, minorDim = x.shape.as_list() kernelH, kernelW = k.shape assert inW >= 1 and inH >= 1 assert kernelW >= 1 and kernelH >= 1 assert isinstance(upx, int) and isinstance(upy, int) assert isinstance(downx, int) and isinstance(downy, int) assert isinstance(padx0, int) and isinstance(padx1, int) assert isinstance(pady0, int) and isinstance(pady1, int) outW = (inW * upx + padx0 + padx1 - kernelW) // downx + 1 outH = (inH * upy + pady0 + pady1 - kernelH) // downy + 1 assert outW >= 1 and outH >= 1 kc = tf.constant(k, dtype=x.dtype) gkc = tf.constant(k[::-1, ::-1], dtype=x.dtype) gpadx0 = kernelW - padx0 - 1 gpady0 = kernelH - pady0 - 1 gpadx1 = inW * upx - outW * downx + padx0 - upx + 1 gpady1 = inH * upy - outH * downy + pady0 - upy + 1 @tf.custom_gradient def func(x): y = _get_plugin().up_fir_dn2d(x=x, k=kc, upx=upx, upy=upy, downx=downx, downy=downy, padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1) y.set_shape([majorDim, outH, outW, minorDim]) @tf.custom_gradient def grad(dy): dx = _get_plugin().up_fir_dn2d(x=dy, k=gkc, upx=downx, upy=downy, downx=upx, downy=upy, padx0=gpadx0, padx1=gpadx1, pady0=gpady0, pady1=gpady1) dx.set_shape([majorDim, inH, inW, minorDim]) return dx, func return y, grad return func(x) #---------------------------------------------------------------------------- def filter_2d(x, k, gain=1, data_format='NCHW', impl='cuda'): r"""Filter a batch of 2D images with the given FIR filter. Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and filters each image with the given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified `gain`. Pixels outside the image are assumed to be zero. Args: x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). gain: Scaling factor for signal magnitude (default: 1.0). data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`). impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). Returns: Tensor of the same shape and datatype as `x`. """ k = _setup_kernel(k) * gain p = k.shape[0] - 1 return _simple_upfirdn_2d(x, k, pad0=(p+1)//2, pad1=p//2, data_format=data_format, impl=impl) #---------------------------------------------------------------------------- def upsample_2d(x, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'): r"""Upsample a batch of 2D images with the given filter. Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a multiple of the upsampling factor. Args: x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`). 
impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). Returns: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as `x`. """ assert isinstance(factor, int) and factor >= 1 if k is None: k = [1] * factor k = _setup_kernel(k) * (gain * (factor ** 2)) p = k.shape[0] - factor return _simple_upfirdn_2d(x, k, up=factor, pad0=(p+1)//2+factor-1, pad1=p//2, data_format=data_format, impl=impl) #---------------------------------------------------------------------------- def downsample_2d(x, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'): r"""Downsample a batch of 2D images with the given filter. Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a multiple of the downsampling factor. Args: x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`). impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). Returns: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same datatype as `x`. """ assert isinstance(factor, int) and factor >= 1 if k is None: k = [1] * factor k = _setup_kernel(k) * gain p = k.shape[0] - factor return _simple_upfirdn_2d(x, k, down=factor, pad0=(p+1)//2, pad1=p//2, data_format=data_format, impl=impl) #---------------------------------------------------------------------------- def upsample_conv_2d(x, w, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'): r"""Fused `upsample_2d()` followed by `tf.nn.conv2d()`. Padding is performed only once at the beginning, not between the operations. The fused op is considerably more efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary order. Args: x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`). impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). Returns: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as `x`. """ assert isinstance(factor, int) and factor >= 1 # Check weight shape. w = tf.convert_to_tensor(w) assert w.shape.rank == 4 convH = w.shape[0].value convW = w.shape[1].value inC = _shape(w, 2) outC = _shape(w, 3) assert convW == convH # Setup filter kernel. if k is None: k = [1] * factor k = _setup_kernel(k) * (gain * (factor ** 2)) p = (k.shape[0] - factor) - (convW - 1) # Determine data dimensions. 
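# E.g. factor = 2 with a 4-tap FIR filter and a 3x3 conv weight gives
# p = (4 - 2) - (3 - 1) = 0, so the trailing _simple_upfirdn_2d() call pads the
# transposed-convolution output by pad0 = (0+1)//2 + 2 - 1 = 1 and
# pad1 = 0//2 + 1 = 1.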
if data_format == 'NCHW': stride = [1, 1, factor, factor] output_shape = [_shape(x, 0), outC, (_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW] num_groups = _shape(x, 1) // inC else: stride = [1, factor, factor, 1] output_shape = [_shape(x, 0), (_shape(x, 1) - 1) * factor + convH, (_shape(x, 2) - 1) * factor + convW, outC] num_groups = _shape(x, 3) // inC # Transpose weights. w = tf.reshape(w, [convH, convW, inC, num_groups, -1]) w = tf.transpose(w[::-1, ::-1], [0, 1, 4, 3, 2]) w = tf.reshape(w, [convH, convW, -1, num_groups * inC]) # Execute. x = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=stride, padding='VALID', data_format=data_format) return _simple_upfirdn_2d(x, k, pad0=(p+1)//2+factor-1, pad1=p//2+1, data_format=data_format, impl=impl) #---------------------------------------------------------------------------- def conv_downsample_2d(x, w, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'): r"""Fused `tf.nn.conv2d()` followed by `downsample_2d()`. Padding is performed only once at the beginning, not between the operations. The fused op is considerably more efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary order. Args: x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`). impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). Returns: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same datatype as `x`. """ assert isinstance(factor, int) and factor >= 1 w = tf.convert_to_tensor(w) convH, convW, _inC, _outC = w.shape.as_list() assert convW == convH if k is None: k = [1] * factor k = _setup_kernel(k) * gain p = (k.shape[0] - factor) + (convW - 1) if data_format == 'NCHW': s = [1, 1, factor, factor] else: s = [1, factor, factor, 1] x = _simple_upfirdn_2d(x, k, pad0=(p+1)//2, pad1=p//2, data_format=data_format, impl=impl) return tf.nn.conv2d(x, w, strides=s, padding='VALID', data_format=data_format) #---------------------------------------------------------------------------- # Internal helper funcs. 
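# _setup_kernel() below promotes a 1-D filter to 2-D via an outer product and
# normalizes it to unit sum; e.g. [1, 3, 3, 1] becomes the 4x4 separable kernel
# np.outer([1, 3, 3, 1], [1, 3, 3, 1]) / 64.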
def _shape(tf_expr, dim_idx): if tf_expr.shape.rank is not None: dim = tf_expr.shape[dim_idx].value if dim is not None: return dim return tf.shape(tf_expr)[dim_idx] def _setup_kernel(k): k = np.asarray(k, dtype=np.float32) if k.ndim == 1: k = np.outer(k, k) k /= np.sum(k) assert k.ndim == 2 assert k.shape[0] == k.shape[1] return k def _simple_upfirdn_2d(x, k, up=1, down=1, pad0=0, pad1=0, data_format='NCHW', impl='cuda'): assert data_format in ['NCHW', 'NHWC'] assert x.shape.rank == 4 y = x if data_format == 'NCHW': y = tf.reshape(y, [-1, _shape(y, 2), _shape(y, 3), 1]) y = upfirdn_2d(y, k, upx=up, upy=up, downx=down, downy=down, padx0=pad0, padx1=pad1, pady0=pad0, pady1=pad1, impl=impl) if data_format == 'NCHW': y = tf.reshape(y, [-1, _shape(x, 1), _shape(y, 1), _shape(y, 2)]) return y #---------------------------------------------------------------------------- ================================================ FILE: stylegan_human/dnnlib/tflib/optimizer.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. # # This work is made available under the Nvidia Source Code License-NC. # To view a copy of this license, visit # https://nvlabs.github.io/stylegan2/license.html """Helper wrapper for a Tensorflow optimizer.""" import numpy as np import tensorflow as tf from collections import OrderedDict from typing import List, Union from . import autosummary from . import tfutil from .. import util from .tfutil import TfExpression, TfExpressionEx try: # TensorFlow 1.13 from tensorflow.python.ops import nccl_ops except: # Older TensorFlow versions import tensorflow.contrib.nccl as nccl_ops class Optimizer: """A Wrapper for tf.train.Optimizer. Automatically takes care of: - Gradient averaging for multi-GPU training. - Gradient accumulation for arbitrarily large minibatches. - Dynamic loss scaling and typecasts for FP16 training. - Ignoring corrupted gradients that contain NaNs/Infs. - Reporting statistics. - Well-chosen default settings. """ def __init__(self, name: str = "Train", # Name string that will appear in TensorFlow graph. tf_optimizer: str = "tf.train.AdamOptimizer", # Underlying optimizer class. learning_rate: TfExpressionEx = 0.001, # Learning rate. Can vary over time. minibatch_multiplier: TfExpressionEx = None, # Treat N consecutive minibatches as one by accumulating gradients. share: "Optimizer" = None, # Share internal state with a previously created optimizer? use_loss_scaling: bool = False, # Enable dynamic loss scaling for robust mixed-precision training? loss_scaling_init: float = 64.0, # Log2 of initial loss scaling factor. loss_scaling_inc: float = 0.0005, # Log2 of per-minibatch loss scaling increment when there is no overflow. loss_scaling_dec: float = 1.0, # Log2 of per-minibatch loss scaling decrement when there is an overflow. report_mem_usage: bool = False, # Report fine-grained memory usage statistics in TensorBoard? **kwargs): # Public fields. self.name = name self.learning_rate = learning_rate self.minibatch_multiplier = minibatch_multiplier self.id = self.name.replace("/", ".") self.scope = tf.get_default_graph().unique_name(self.id) self.optimizer_class = util.get_obj_by_name(tf_optimizer) self.optimizer_kwargs = dict(kwargs) self.use_loss_scaling = use_loss_scaling self.loss_scaling_init = loss_scaling_init self.loss_scaling_inc = loss_scaling_inc self.loss_scaling_dec = loss_scaling_dec # Private fields. 
self._updates_applied = False self._devices = OrderedDict() # device_name => EasyDict() self._shared_optimizers = OrderedDict() # device_name => optimizer_class self._gradient_shapes = None # [shape, ...] self._report_mem_usage = report_mem_usage # Validate arguments. assert callable(self.optimizer_class) # Share internal state if requested. if share is not None: assert isinstance(share, Optimizer) assert self.optimizer_class is share.optimizer_class assert self.learning_rate is share.learning_rate assert self.optimizer_kwargs == share.optimizer_kwargs self._shared_optimizers = share._shared_optimizers # pylint: disable=protected-access def _get_device(self, device_name: str): """Get internal state for the given TensorFlow device.""" tfutil.assert_tf_initialized() if device_name in self._devices: return self._devices[device_name] # Initialize fields. device = util.EasyDict() device.name = device_name device.optimizer = None # Underlying optimizer: optimizer_class device.loss_scaling_var = None # Log2 of loss scaling: tf.Variable device.grad_raw = OrderedDict() # Raw gradients: var => [grad, ...] device.grad_clean = OrderedDict() # Clean gradients: var => grad device.grad_acc_vars = OrderedDict() # Accumulation sums: var => tf.Variable device.grad_acc_count = None # Accumulation counter: tf.Variable device.grad_acc = OrderedDict() # Accumulated gradients: var => grad # Setup TensorFlow objects. with tfutil.absolute_name_scope(self.scope + "/Devices"), tf.device(device_name), tf.control_dependencies(None): if device_name not in self._shared_optimizers: optimizer_name = self.scope.replace("/", "_") + "_opt%d" % len(self._shared_optimizers) self._shared_optimizers[device_name] = self.optimizer_class(name=optimizer_name, learning_rate=self.learning_rate, **self.optimizer_kwargs) device.optimizer = self._shared_optimizers[device_name] if self.use_loss_scaling: device.loss_scaling_var = tf.Variable(np.float32(self.loss_scaling_init), trainable=False, name="loss_scaling_var") # Register device. self._devices[device_name] = device return device def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None: """Register the gradients of the given loss function with respect to the given variables. Intended to be called once per GPU.""" tfutil.assert_tf_initialized() assert not self._updates_applied device = self._get_device(loss.device) # Validate trainables. if isinstance(trainable_vars, dict): trainable_vars = list(trainable_vars.values()) # allow passing in Network.trainables as vars assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1 assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss]) assert all(var.device == device.name for var in trainable_vars) # Validate shapes. if self._gradient_shapes is None: self._gradient_shapes = [var.shape.as_list() for var in trainable_vars] assert len(trainable_vars) == len(self._gradient_shapes) assert all(var.shape.as_list() == var_shape for var, var_shape in zip(trainable_vars, self._gradient_shapes)) # Report memory usage if requested. deps = [] if self._report_mem_usage: self._report_mem_usage = False try: with tf.name_scope(self.id + '_mem'), tf.device(device.name), tf.control_dependencies([loss]): deps.append(autosummary.autosummary(self.id + "/mem_usage_gb", tf.contrib.memory_stats.BytesInUse() / 2**30)) except tf.errors.NotFoundError: pass # Compute gradients. 
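# With use_loss_scaling enabled, the loss is multiplied by 2**loss_scaling_var
# before compute_gradients() so that small FP16 gradients do not flush to zero;
# undo_loss_scaling() later folds the matching 2**(-loss_scaling_var) factor
# into the gradient scale.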
with tf.name_scope(self.id + "_grad"), tf.device(device.name), tf.control_dependencies(deps): loss = self.apply_loss_scaling(tf.cast(loss, tf.float32)) gate = tf.train.Optimizer.GATE_NONE # disable gating to reduce memory usage grad_list = device.optimizer.compute_gradients(loss=loss, var_list=trainable_vars, gate_gradients=gate) # Register gradients. for grad, var in grad_list: if var not in device.grad_raw: device.grad_raw[var] = [] device.grad_raw[var].append(grad) def apply_updates(self, allow_no_op: bool = False) -> tf.Operation: """Construct training op to update the registered variables based on their gradients.""" tfutil.assert_tf_initialized() assert not self._updates_applied self._updates_applied = True all_ops = [] # Check for no-op. if allow_no_op and len(self._devices) == 0: with tfutil.absolute_name_scope(self.scope): return tf.no_op(name='TrainingOp') # Clean up gradients. for device_idx, device in enumerate(self._devices.values()): with tfutil.absolute_name_scope(self.scope + "/Clean%d" % device_idx), tf.device(device.name): for var, grad in device.grad_raw.items(): # Filter out disconnected gradients and convert to float32. grad = [g for g in grad if g is not None] grad = [tf.cast(g, tf.float32) for g in grad] # Sum within the device. if len(grad) == 0: grad = tf.zeros(var.shape) # No gradients => zero. elif len(grad) == 1: grad = grad[0] # Single gradient => use as is. else: grad = tf.add_n(grad) # Multiple gradients => sum. # Scale as needed. scale = 1.0 / len(device.grad_raw[var]) / len(self._devices) scale = tf.constant(scale, dtype=tf.float32, name="scale") if self.minibatch_multiplier is not None: scale /= tf.cast(self.minibatch_multiplier, tf.float32) scale = self.undo_loss_scaling(scale) device.grad_clean[var] = grad * scale # Sum gradients across devices. if len(self._devices) > 1: with tfutil.absolute_name_scope(self.scope + "/Broadcast"), tf.device(None): for all_vars in zip(*[device.grad_clean.keys() for device in self._devices.values()]): if len(all_vars) > 0 and all(dim > 0 for dim in all_vars[0].shape.as_list()): # NCCL does not support zero-sized tensors. all_grads = [device.grad_clean[var] for device, var in zip(self._devices.values(), all_vars)] all_grads = nccl_ops.all_sum(all_grads) for device, var, grad in zip(self._devices.values(), all_vars, all_grads): device.grad_clean[var] = grad # Apply updates separately on each device. for device_idx, device in enumerate(self._devices.values()): with tfutil.absolute_name_scope(self.scope + "/Apply%d" % device_idx), tf.device(device.name): # pylint: disable=cell-var-from-loop # Accumulate gradients over time. if self.minibatch_multiplier is None: acc_ok = tf.constant(True, name='acc_ok') device.grad_acc = OrderedDict(device.grad_clean) else: # Create variables. with tf.control_dependencies(None): for var in device.grad_clean.keys(): device.grad_acc_vars[var] = tf.Variable(tf.zeros(var.shape), trainable=False, name="grad_acc_var") device.grad_acc_count = tf.Variable(tf.zeros([]), trainable=False, name="grad_acc_count") # Track counter. count_cur = device.grad_acc_count + 1.0 count_inc_op = lambda: tf.assign(device.grad_acc_count, count_cur) count_reset_op = lambda: tf.assign(device.grad_acc_count, tf.zeros([])) acc_ok = (count_cur >= tf.cast(self.minibatch_multiplier, tf.float32)) all_ops.append(tf.cond(acc_ok, count_reset_op, count_inc_op)) # Track gradients. 
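# Each accumulation variable keeps a running sum of the scaled per-minibatch
# gradients. Combined with the 1/minibatch_multiplier factor applied in the
# Clean stage, e.g. minibatch_multiplier = 4 yields the mean gradient over 4
# consecutive minibatches, applied once on every 4th call.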
for var, grad in device.grad_clean.items(): acc_var = device.grad_acc_vars[var] acc_cur = acc_var + grad device.grad_acc[var] = acc_cur with tf.control_dependencies([acc_cur]): acc_inc_op = lambda: tf.assign(acc_var, acc_cur) acc_reset_op = lambda: tf.assign(acc_var, tf.zeros(var.shape)) all_ops.append(tf.cond(acc_ok, acc_reset_op, acc_inc_op)) # No overflow => apply gradients. all_ok = tf.reduce_all(tf.stack([acc_ok] + [tf.reduce_all(tf.is_finite(g)) for g in device.grad_acc.values()])) apply_op = lambda: device.optimizer.apply_gradients([(tf.cast(grad, var.dtype), var) for var, grad in device.grad_acc.items()]) all_ops.append(tf.cond(all_ok, apply_op, tf.no_op)) # Adjust loss scaling. if self.use_loss_scaling: ls_inc_op = lambda: tf.assign_add(device.loss_scaling_var, self.loss_scaling_inc) ls_dec_op = lambda: tf.assign_sub(device.loss_scaling_var, self.loss_scaling_dec) ls_update_op = lambda: tf.group(tf.cond(all_ok, ls_inc_op, ls_dec_op)) all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op)) # Last device => report statistics. if device_idx == len(self._devices) - 1: all_ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate)) all_ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(all_ok, 0, 1), condition=acc_ok)) if self.use_loss_scaling: all_ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", device.loss_scaling_var)) # Initialize variables. self.reset_optimizer_state() if self.use_loss_scaling: tfutil.init_uninitialized_vars([device.loss_scaling_var for device in self._devices.values()]) if self.minibatch_multiplier is not None: tfutil.run([var.initializer for device in self._devices.values() for var in list(device.grad_acc_vars.values()) + [device.grad_acc_count]]) # Group everything into a single op. 
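# Illustrative single-GPU usage sketch (variable names are hypothetical):
#   opt = tflib.Optimizer(learning_rate=0.002)
#   opt.register_gradients(loss, trainable_vars)   # once per GPU
#   train_op = opt.apply_updates()                 # once, after all GPUs
#   tflib.run(train_op)                            # every training iteration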
with tfutil.absolute_name_scope(self.scope): return tf.group(*all_ops, name="TrainingOp") def reset_optimizer_state(self) -> None: """Reset internal state of the underlying optimizer.""" tfutil.assert_tf_initialized() tfutil.run([var.initializer for device in self._devices.values() for var in device.optimizer.variables()]) def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]: """Get or create variable representing log2 of the current dynamic loss scaling factor.""" return self._get_device(device).loss_scaling_var def apply_loss_scaling(self, value: TfExpression) -> TfExpression: """Apply dynamic loss scaling for the given expression.""" assert tfutil.is_tf_expression(value) if not self.use_loss_scaling: return value return value * tfutil.exp2(self.get_loss_scaling_var(value.device)) def undo_loss_scaling(self, value: TfExpression) -> TfExpression: """Undo the effect of dynamic loss scaling for the given expression.""" assert tfutil.is_tf_expression(value) if not self.use_loss_scaling: return value return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type class SimpleAdam: """Simplified version of tf.train.AdamOptimizer that behaves identically when used with dnnlib.tflib.Optimizer.""" def __init__(self, name="Adam", learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8): self.name = name self.learning_rate = learning_rate self.beta1 = beta1 self.beta2 = beta2 self.epsilon = epsilon self.all_state_vars = [] def variables(self): return self.all_state_vars def compute_gradients(self, loss, var_list, gate_gradients=tf.train.Optimizer.GATE_NONE): assert gate_gradients == tf.train.Optimizer.GATE_NONE return list(zip(tf.gradients(loss, var_list), var_list)) def apply_gradients(self, grads_and_vars): with tf.name_scope(self.name): state_vars = [] update_ops = [] # Adjust learning rate to deal with startup bias. with tf.control_dependencies(None): b1pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False) b2pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False) state_vars += [b1pow_var, b2pow_var] b1pow_new = b1pow_var * self.beta1 b2pow_new = b2pow_var * self.beta2 update_ops += [tf.assign(b1pow_var, b1pow_new), tf.assign(b2pow_var, b2pow_new)] lr_new = self.learning_rate * tf.sqrt(1 - b2pow_new) / (1 - b1pow_new) # Construct ops to update each variable. for grad, var in grads_and_vars: with tf.control_dependencies(None): m_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False) v_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False) state_vars += [m_var, v_var] m_new = self.beta1 * m_var + (1 - self.beta1) * grad v_new = self.beta2 * v_var + (1 - self.beta2) * tf.square(grad) var_delta = lr_new * m_new / (tf.sqrt(v_new) + self.epsilon) update_ops += [tf.assign(m_var, m_new), tf.assign(v_var, v_new), tf.assign_sub(var, var_delta)] # Group everything together. self.all_state_vars += state_vars return tf.group(*update_ops) ================================================ FILE: stylegan_human/dnnlib/tflib/tfutil.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. # # This work is made available under the Nvidia Source Code License-NC. 
# To view a copy of this license, visit # https://nvlabs.github.io/stylegan2/license.html """Miscellaneous helper utils for Tensorflow.""" import os import numpy as np import tensorflow as tf # Silence deprecation warnings from TensorFlow 1.13 onwards import logging logging.getLogger('tensorflow').setLevel(logging.ERROR) import tensorflow.contrib # requires TensorFlow 1.x! tf.contrib = tensorflow.contrib from typing import Any, Iterable, List, Union TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation] """A type that represents a valid Tensorflow expression.""" TfExpressionEx = Union[TfExpression, int, float, np.ndarray] """A type that can be converted to a valid Tensorflow expression.""" def run(*args, **kwargs) -> Any: """Run the specified ops in the default session.""" assert_tf_initialized() return tf.get_default_session().run(*args, **kwargs) def is_tf_expression(x: Any) -> bool: """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation.""" return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation)) def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]: """Convert a Tensorflow shape to a list of ints. Retained for backwards compatibility -- use TensorShape.as_list() in new code.""" return [dim.value for dim in shape] def flatten(x: TfExpressionEx) -> TfExpression: """Shortcut function for flattening a tensor.""" with tf.name_scope("Flatten"): return tf.reshape(x, [-1]) def log2(x: TfExpressionEx) -> TfExpression: """Logarithm in base 2.""" with tf.name_scope("Log2"): return tf.log(x) * np.float32(1.0 / np.log(2.0)) def exp2(x: TfExpressionEx) -> TfExpression: """Exponent in base 2.""" with tf.name_scope("Exp2"): return tf.exp(x * np.float32(np.log(2.0))) def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx: """Linear interpolation.""" with tf.name_scope("Lerp"): return a + (b - a) * t def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression: """Linear interpolation with clip.""" with tf.name_scope("LerpClip"): return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0) def absolute_name_scope(scope: str) -> tf.name_scope: """Forcefully enter the specified name scope, ignoring any surrounding scopes.""" return tf.name_scope(scope + "/") def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope: """Forcefully enter the specified variable scope, ignoring any surrounding scopes.""" return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False) def _sanitize_tf_config(config_dict: dict = None) -> dict: # Defaults. cfg = dict() cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is. cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is. cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info. cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used. cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed. # Remove defaults for environment variables that are already set. for key in list(cfg): fields = key.split(".") if fields[0] == "env": assert len(fields) == 2 if fields[1] in os.environ: del cfg[key] # User overrides. 
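# Keys use dotted ConfigProto paths, so callers can override any default, e.g.
#   init_tf({"gpu_options.allow_growth": False, "rnd.np_random_seed": 1000})
# "rnd.*" and "env.*" entries are consumed by init_tf() itself; the rest are
# applied to the tf.ConfigProto in create_session().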
if config_dict is not None: cfg.update(config_dict) return cfg def init_tf(config_dict: dict = None) -> None: """Initialize TensorFlow session using good default settings.""" # Skip if already initialized. if tf.get_default_session() is not None: return # Setup config dict and random seeds. cfg = _sanitize_tf_config(config_dict) np_random_seed = cfg["rnd.np_random_seed"] if np_random_seed is not None: np.random.seed(np_random_seed) tf_random_seed = cfg["rnd.tf_random_seed"] if tf_random_seed == "auto": tf_random_seed = np.random.randint(1 << 31) if tf_random_seed is not None: tf.set_random_seed(tf_random_seed) # Setup environment variables. for key, value in cfg.items(): fields = key.split(".") if fields[0] == "env": assert len(fields) == 2 os.environ[fields[1]] = str(value) # Create default TensorFlow session. create_session(cfg, force_as_default=True) def assert_tf_initialized(): """Check that TensorFlow session has been initialized.""" if tf.get_default_session() is None: raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().") def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session: """Create tf.Session based on config dict.""" # Setup TensorFlow config proto. cfg = _sanitize_tf_config(config_dict) config_proto = tf.ConfigProto() for key, value in cfg.items(): fields = key.split(".") if fields[0] not in ["rnd", "env"]: obj = config_proto for field in fields[:-1]: obj = getattr(obj, field) setattr(obj, fields[-1], value) # Create session. session = tf.Session(config=config_proto) if force_as_default: # pylint: disable=protected-access session._default_session = session.as_default() session._default_session.enforce_nesting = False session._default_session.__enter__() return session def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None: """Initialize all tf.Variables that have not already been initialized. Equivalent to the following, but more efficient and does not bloat the tf graph: tf.variables_initializer(tf.report_uninitialized_variables()).run() """ assert_tf_initialized() if target_vars is None: target_vars = tf.global_variables() test_vars = [] test_ops = [] with tf.control_dependencies(None): # ignore surrounding control_dependencies for var in target_vars: assert is_tf_expression(var) try: tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0")) except KeyError: # Op does not exist => variable may be uninitialized. test_vars.append(var) with absolute_name_scope(var.name.split(":")[0]): test_ops.append(tf.is_variable_initialized(var)) init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited] run([var.initializer for var in init_vars]) def set_vars(var_to_value_dict: dict) -> None: """Set the values of given tf.Variables. 
Equivalent to the following, but more efficient and does not bloat the tf graph: tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()] """ assert_tf_initialized() ops = [] feed_dict = {} for var, value in var_to_value_dict.items(): assert is_tf_expression(var) try: setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op except KeyError: with absolute_name_scope(var.name.split(":")[0]): with tf.control_dependencies(None): # ignore surrounding control_dependencies setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter ops.append(setter) feed_dict[setter.op.inputs[1]] = value run(ops, feed_dict) def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs): """Create tf.Variable with large initial value without bloating the tf graph.""" assert_tf_initialized() assert isinstance(initial_value, np.ndarray) zeros = tf.zeros(initial_value.shape, initial_value.dtype) var = tf.Variable(zeros, *args, **kwargs) set_vars({var: initial_value}) return var def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False): """Convert a minibatch of images from uint8 to float32 with configurable dynamic range. Can be used as an input transformation for Network.run(). """ images = tf.cast(images, tf.float32) if nhwc_to_nchw: images = tf.transpose(images, [0, 3, 1, 2]) return images * ((drange[1] - drange[0]) / 255) + drange[0] def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1): """Convert a minibatch of images from float32 to uint8 with configurable dynamic range. Can be used as an output transformation for Network.run(). """ images = tf.cast(images, tf.float32) if shrink > 1: ksize = [1, 1, shrink, shrink] images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") if nchw_to_nhwc: images = tf.transpose(images, [0, 2, 3, 1]) scale = 255 / (drange[1] - drange[0]) images = images * scale + (0.5 - drange[0] * scale) return tf.saturate_cast(images, tf.uint8) ================================================ FILE: stylegan_human/dnnlib/util.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
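# Overview of this module (dnnlib/util.py): small, dependency-light helpers used
# across the repo: EasyDict (attribute-style dicts), Logger (tees stdout/stderr
# to a file), cache-directory resolution, lookup of modules/objects/functions by
# dotted name, recursive file listing and copying, and open_url() for cached
# HTTP(S) and Google Drive downloads.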
"""Miscellaneous utility classes and functions.""" import ctypes import fnmatch import importlib import inspect import numpy as np import os import shutil import sys import types import io import pickle import re import requests import html import hashlib import glob import tempfile import urllib import urllib.request import uuid from distutils.util import strtobool from typing import Any, List, Tuple, Union # Util classes # ------------------------------------------------------------------------------------------ class EasyDict(dict): """Convenience class that behaves like a dict but allows access with the attribute syntax.""" def __getattr__(self, name: str) -> Any: try: return self[name] except KeyError: raise AttributeError(name) def __setattr__(self, name: str, value: Any) -> None: self[name] = value def __delattr__(self, name: str) -> None: del self[name] class Logger(object): """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True): self.file = None if file_name is not None: self.file = open(file_name, file_mode) self.should_flush = should_flush self.stdout = sys.stdout self.stderr = sys.stderr sys.stdout = self sys.stderr = self def __enter__(self) -> "Logger": return self def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: self.close() def write(self, text: Union[str, bytes]) -> None: """Write text to stdout (and a file) and optionally flush.""" if isinstance(text, bytes): text = text.decode() if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash return if self.file is not None: self.file.write(text) self.stdout.write(text) if self.should_flush: self.flush() def flush(self) -> None: """Flush written text to both stdout and a file, if open.""" if self.file is not None: self.file.flush() self.stdout.flush() def close(self) -> None: """Flush, close possible files, and remove stdout/stderr mirroring.""" self.flush() # if using multiple loggers, prevent closing in wrong order if sys.stdout is self: sys.stdout = self.stdout if sys.stderr is self: sys.stderr = self.stderr if self.file is not None: self.file.close() self.file = None # Cache directories # ------------------------------------------------------------------------------------------ _dnnlib_cache_dir = None def set_cache_dir(path: str) -> None: global _dnnlib_cache_dir _dnnlib_cache_dir = path def make_cache_dir_path(*paths: str) -> str: if _dnnlib_cache_dir is not None: return os.path.join(_dnnlib_cache_dir, *paths) if 'DNNLIB_CACHE_DIR' in os.environ: return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) if 'HOME' in os.environ: return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths) if 'USERPROFILE' in os.environ: return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) # Small util functions # ------------------------------------------------------------------------------------------ def format_time(seconds: Union[int, float]) -> str: """Convert the seconds to human readable string with days, hours, minutes and seconds.""" s = int(np.rint(seconds)) if s < 60: return "{0}s".format(s) elif s < 60 * 60: return "{0}m {1:02}s".format(s // 60, s % 60) elif s < 24 * 60 * 60: return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) else: return "{0}d {1:02}h 
{2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) def ask_yes_no(question: str) -> bool: """Ask the user the question until the user inputs a valid answer.""" while True: try: print("{0} [y/n]".format(question)) return strtobool(input().lower()) except ValueError: pass def tuple_product(t: Tuple) -> Any: """Calculate the product of the tuple elements.""" result = 1 for v in t: result *= v return result _str_to_ctype = { "uint8": ctypes.c_ubyte, "uint16": ctypes.c_uint16, "uint32": ctypes.c_uint32, "uint64": ctypes.c_uint64, "int8": ctypes.c_byte, "int16": ctypes.c_int16, "int32": ctypes.c_int32, "int64": ctypes.c_int64, "float32": ctypes.c_float, "float64": ctypes.c_double } def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" type_str = None if isinstance(type_obj, str): type_str = type_obj elif hasattr(type_obj, "__name__"): type_str = type_obj.__name__ elif hasattr(type_obj, "name"): type_str = type_obj.name else: raise RuntimeError("Cannot infer type name from input") assert type_str in _str_to_ctype.keys() my_dtype = np.dtype(type_str) my_ctype = _str_to_ctype[type_str] assert my_dtype.itemsize == ctypes.sizeof(my_ctype) return my_dtype, my_ctype def is_pickleable(obj: Any) -> bool: try: with io.BytesIO() as stream: pickle.dump(obj, stream) return True except: return False # Functionality to import modules/objects by name, and call functions by name # ------------------------------------------------------------------------------------------ def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: """Searches for the underlying module behind the name to some python object. Returns the module and the object name (original name with module part removed).""" # allow convenience shorthands, substitute them by full names obj_name = re.sub("^np.", "numpy.", obj_name) obj_name = re.sub("^tf.", "tensorflow.", obj_name) # list alternatives for (module_name, local_obj_name) parts = obj_name.split(".") name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] # try each alternative in turn for module_name, local_obj_name in name_pairs: try: module = importlib.import_module(module_name) # may raise ImportError get_obj_from_module(module, local_obj_name) # may raise AttributeError return module, local_obj_name except: pass # maybe some of the modules themselves contain errors? for module_name, _local_obj_name in name_pairs: try: importlib.import_module(module_name) # may raise ImportError except ImportError: if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): raise # maybe the requested attribute is missing? 
for module_name, local_obj_name in name_pairs: try: module = importlib.import_module(module_name) # may raise ImportError get_obj_from_module(module, local_obj_name) # may raise AttributeError except ImportError: pass # we are out of luck, but we have no idea why raise ImportError(obj_name) def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: """Traverses the object name and returns the last (rightmost) python object.""" if obj_name == '': return module obj = module for part in obj_name.split("."): obj = getattr(obj, part) return obj def get_obj_by_name(name: str) -> Any: """Finds the python object with the given name.""" module, obj_name = get_module_from_obj_name(name) return get_obj_from_module(module, obj_name) def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: """Finds the python object with the given name and calls it as a function.""" assert func_name is not None # print('func_name: ', func_name) #'training.dataset.ImageFolderDataset' func_obj = get_obj_by_name(func_name) assert callable(func_obj) return func_obj(*args, **kwargs) def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: """Finds the python class with the given name and constructs it with the given arguments.""" return call_func_by_name(*args, func_name=class_name, **kwargs) def get_module_dir_by_obj_name(obj_name: str) -> str: """Get the directory path of the module containing the given object name.""" module, _ = get_module_from_obj_name(obj_name) return os.path.dirname(inspect.getfile(module)) def is_top_level_function(obj: Any) -> bool: """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ def get_top_level_function_name(obj: Any) -> str: """Return the fully-qualified name of a top-level function.""" assert is_top_level_function(obj) module = obj.__module__ if module == '__main__': module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] return module + "." + obj.__name__ # File system helpers # ------------------------------------------------------------------------------------------ def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: """List all files recursively in a given directory while ignoring given file and directory names. Returns list of tuples containing both absolute and relative paths.""" assert os.path.isdir(dir_path) base_name = os.path.basename(os.path.normpath(dir_path)) if ignores is None: ignores = [] result = [] for root, dirs, files in os.walk(dir_path, topdown=True): for ignore_ in ignores: dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] # dirs need to be edited in-place for d in dirs_to_remove: dirs.remove(d) files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] absolute_paths = [os.path.join(root, f) for f in files] relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] if add_base_to_relative: relative_paths = [os.path.join(base_name, p) for p in relative_paths] assert len(absolute_paths) == len(relative_paths) result += zip(absolute_paths, relative_paths) return result def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: """Takes in a list of tuples of (src, dst) paths and copies files. 
Will create all necessary directories.""" for file in files: target_dir_name = os.path.dirname(file[1]) # will create all intermediate-level directories if not os.path.exists(target_dir_name): os.makedirs(target_dir_name) shutil.copyfile(file[0], file[1]) # URL helpers # ------------------------------------------------------------------------------------------ def is_url(obj: Any, allow_file_urls: bool = False) -> bool: """Determine whether the given object is a valid URL string.""" if not isinstance(obj, str) or not "://" in obj: return False if allow_file_urls and obj.startswith('file://'): return True try: res = requests.compat.urlparse(obj) if not res.scheme or not res.netloc or not "." in res.netloc: return False res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) if not res.scheme or not res.netloc or not "." in res.netloc: return False except: return False return True def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any: """Download the given URL and return a binary-mode file object to access the data.""" assert num_attempts >= 1 assert not (return_filename and (not cache)) # Doesn't look like an URL scheme so interpret it as a local filename. if not re.match('^[a-z]+://', url): return url if return_filename else open(url, "rb") # Handle file URLs. This code handles unusual file:// patterns that # arise on Windows: # # file:///c:/foo.txt # # which would translate to a local '/c:/foo.txt' filename that's # invalid. Drop the forward slash for such pathnames. # # If you touch this code path, you should test it on both Linux and # Windows. # # Some internet resources suggest using urllib.request.url2pathname() but # but that converts forward slashes to backslashes and this causes # its own set of problems. if url.startswith('file://'): filename = urllib.parse.urlparse(url).path if re.match(r'^/[a-zA-Z]:', filename): filename = filename[1:] return filename if return_filename else open(filename, "rb") assert is_url(url) # Lookup from cache. if cache_dir is None: cache_dir = make_cache_dir_path('downloads') url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() if cache: cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) if len(cache_files) == 1: filename = cache_files[0] return filename if return_filename else open(filename, "rb") # Download. url_name = None url_data = None with requests.Session() as session: if verbose: print("Downloading %s ..." % url, end="", flush=True) for attempts_left in reversed(range(num_attempts)): try: with session.get(url) as res: res.raise_for_status() if len(res.content) == 0: raise IOError("No data received") if len(res.content) < 8192: content_str = res.content.decode("utf-8") if "download_warning" in res.headers.get("Set-Cookie", ""): links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] if len(links) == 1: url = requests.compat.urljoin(url, links[0]) raise IOError("Google Drive virus checker nag") if "Google Drive - Quota exceeded" in content_str: raise IOError("Google Drive download quota exceeded -- please try again later") match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) url_name = match[1] if match else url url_data = res.content if verbose: print(" done") break except KeyboardInterrupt: raise except: if not attempts_left: if verbose: print(" failed") raise if verbose: print(".", end="", flush=True) # Save to cache. 
if cache: safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) os.makedirs(cache_dir, exist_ok=True) with open(temp_file, "wb") as f: f.write(url_data) os.replace(temp_file, cache_file) # atomic if return_filename: return cache_file # Return data as file object. assert not return_filename return io.BytesIO(url_data)

================================================ FILE: stylegan_human/docs/Dataset.md ================================================

# SHHQ Dataset

## Overview

SHHQ is a dataset of high-quality full-body human images at a resolution of 1024 × 512. Because each release must pass a rigorous legal review at our institute, we cannot release all of the data at once. For now, SHHQ-1.0 with 40K images is released! More data will be released in later versions.

## Data Sources

Images are collected in two main ways:

1) From the Internet. We developed a crawler tool built on official APIs, mainly downloading images from Flickr, Pixabay and Pexels. You therefore need to comply with all of the following licenses when using the dataset: CC0, the [Pixabay License](https://pixabay.com/service/license/), and the [Pexels License](https://www.pexels.com/license/).
2) From data providers. We purchased images from the databases of individual photographers, modeling agencies and other suppliers. Images were reviewed by our legal team prior to purchase to ensure permission for use in research.

### Note:

The composition of SHHQ-1.0:

1) Images obtained from the above sources.
2) 9991 processed DeepFashion [[1]](#1) images (retaining only full-body images).
3) 1940 African images from the InFashAI [[2]](#2) dataset, added to increase data diversity.

## Data License

We are aware of privacy concerns and take license and privacy issues seriously. All released data is guaranteed to be under the CC0 license and free for research use. In addition, persons in the dataset are anonymized, with no additional private or sensitive metadata.

## Agreement

SHHQ is available for non-commercial research purposes only. You agree not to reproduce, duplicate, copy, sell, trade, resell or exploit any portion of the images or any portion of the derived data for commercial purposes. You agree NOT to further copy, publish or distribute any portion of SHHQ to any third party for any purpose, with one exception: copies of the dataset may be made for internal use at a single site within the same organization. Shanghai AI Lab reserves the right to terminate your access to SHHQ at any time.

## Dataset Preview

For those interested in our dataset, we provide a preview version with 100 images randomly sampled from SHHQ-1.0: [SHHQ-1.0_samples](https://drive.google.com/file/d/1tnNFfmFtzRbYL3qEnNXQ_ShaN9YV5tI5/view?usp=sharing).

In SHHQ-1.0, we provide aligned raw images along with machine-calculated segmentation masks. We plan to release a manually annotated human-parsing version of these 40,000 images later. Please stay tuned. If you want to access the full SHHQ-1.0, please read the download instructions below.

> We also provide the script [bg_white.py](../bg_white.py) to whiten the background of a raw image using its segmentation mask; a rough sketch of the idea follows.
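The script itself is not reproduced in this extract, but the idea is a simple mask composite. A minimal sketch of background whitening, assuming a raw RGB image and a single-channel segmentation mask of the same size (the file names are placeholders, and the real `bg_white.py` may differ in its details):

```python
import cv2
import numpy as np

# Placeholder paths for illustration.
img = cv2.imread("aligned/0001.png")                       # raw image, HxWx3 uint8
mask = cv2.imread("masks/0001.png", cv2.IMREAD_GRAYSCALE)  # segmentation mask, HxW uint8

person = (mask > 127)[..., None]    # boolean HxWx1: True where the person is
white = np.full_like(img, 255)      # all-white background
out = np.where(person, img, white)  # keep person pixels, whiten the rest

cv2.imwrite("bg_white/0001.png", out)
```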
## Models trained using SHHQ-1.0

| Structure | Model (1024x512) | Metric | Score | Model (512x256) | Metric | Score |
| --------- | :--------------: | :----: | :---: | :-------------: | :----: | :---: |
| StyleGAN1 | to be released | - | - | to be released | - | - |
| StyleGAN2 | [SHHQ-1.0_sg2_1024.pkl](https://drive.google.com/file/d/1PuvE72xpc69Zq4y58dohuKbG9dFnnjEX/view?usp=sharing) | fid50k_full | 3.56 | [SHHQ-1.0_sg2_512.pkl](https://drive.google.com/file/d/170t2FRWxR8_TG3_y0nVtDBogLPOClnyf/view?usp=sharing) | fid50k_full | 3.68 |
| StyleGAN3 | to be released | - | - | to be released | - | - |

## Download Instructions

Please download the SHHQ Dataset Release Agreement from this [link](./SHHQ_Dataset_Release_Agreement.pdf). Read it carefully, then complete and sign it appropriately. Send the completed form to Jianglin Fu (arlenefu@outlook.com) and Shikai Li (lishikai@pjlab.org.cn), and cc Wayne Wu (wuwenyan0503@gmail.com), using an institutional email address. The email subject must be "SHHQ Dataset Release Agreement". We will verify your request and contact you with the dataset link and the password to unzip the image data.

Note:

1. We are currently receiving a large number of applications and need to verify each applicant carefully. Please be patient; we will reply to you as soon as possible.
2. The signature in the agreement must be hand-written.

## References

[1] Liu, Ziwei and Luo, Ping and Qiu, Shi and Wang, Xiaogang and Tang, Xiaoou. DeepFashion: Powering Robust Clothes Recognition and Retrieval with Rich Annotations. CVPR (2016)

[2] Hacheme, Gilles and Sayouti, Noureini. Neural fashion image captioning: Accounting for data diversity. arXiv preprint arXiv:2106.12154 (2021)

================================================ FILE: stylegan_human/edit/__init__.py ================================================
# Copyright (c) SenseTime Research. All rights reserved.
# empty

================================================ FILE: stylegan_human/edit/edit_config.py ================================================
# Copyright (c) SenseTime Research. All rights reserved.

attr_dict = dict(
    interface_gan={  # strength
        'upper_length': [-1],  # strength: negative for shorter, positive for longer
        'bottom_length': [1]
    },
    stylespace={  # layer, strength, threshold
        'upper_length': [5, -5, 0.0028],  # strength: negative for shorter, positive for longer
        'bottom_length': [3, 5, 0.003]
    },
    sefa={  # layer, strength
        'upper_length': [[4, 5, 6, 7], 5],  # -5  # strength: negative for longer, positive for shorter
        'bottom_length': [[4, 5, 6, 7], 5]
    }
)

================================================ FILE: stylegan_human/edit/edit_helper.py ================================================ # Copyright (c) SenseTime Research. All rights reserved.
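# This module performs latent-space editing on a StyleGAN2-style generator:
# - conv_warper() re-runs a single styled conv layer with an externally supplied
#   style vector, so edited style-space codes can be injected at any layer.
# - decoder() rebuilds the synthesis pass from per-layer style vectors, a W+
#   latent, and the generator's fixed noise buffers.
# - encoder_ifg() / encoder_ss() / encoder_sefa() compute the edited codes for
#   InterFaceGAN, StyleSpace, and SeFa directions stored under latent_direction/,
#   using the layers/strengths/thresholds from edit/edit_config.py, e.g.
#   layer, strength, threshold = attr_dict['stylespace']['upper_length']  # (5, -5, 0.0028)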
from legacy import save_obj, load_pkl import torch from torch.nn import functional as F import pandas as pd from .edit_config import attr_dict import os def conv_warper(layer, input, style, noise): # the conv should change conv = layer.conv batch, in_channel, height, width = input.shape style = style.view(batch, 1, in_channel, 1, 1) weight = conv.scale * conv.weight * style if conv.demodulate: demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) weight = weight * demod.view(batch, conv.out_channel, 1, 1, 1) weight = weight.view( batch * conv.out_channel, in_channel, conv.kernel_size, conv.kernel_size ) if conv.upsample: input = input.view(1, batch * in_channel, height, width) weight = weight.view( batch, conv.out_channel, in_channel, conv.kernel_size, conv.kernel_size ) weight = weight.transpose(1, 2).reshape( batch * in_channel, conv.out_channel, conv.kernel_size, conv.kernel_size ) out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) _, _, height, width = out.shape out = out.view(batch, conv.out_channel, height, width) out = conv.blur(out) elif conv.downsample: input = conv.blur(input) _, _, height, width = input.shape input = input.view(1, batch * in_channel, height, width) out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) _, _, height, width = out.shape out = out.view(batch, conv.out_channel, height, width) else: input = input.view(1, batch * in_channel, height, width) out = F.conv2d(input, weight, padding=conv.padding, groups=batch) _, _, height, width = out.shape out = out.view(batch, conv.out_channel, height, width) out = layer.noise(out, noise=noise) out = layer.activate(out) return out def decoder(G, style_space, latent, noise): # an decoder warper for G out = G.input(latent) out = conv_warper(G.conv1, out, style_space[0], noise[0]) skip = G.to_rgb1(out, latent[:, 1]) i = 1 for conv1, conv2, noise1, noise2, to_rgb in zip( G.convs[::2], G.convs[1::2], noise[1::2], noise[2::2], G.to_rgbs ): out = conv_warper(conv1, out, style_space[i], noise=noise1) out = conv_warper(conv2, out, style_space[i+1], noise=noise2) skip = to_rgb(out, latent[:, i + 2], skip) i += 2 image = skip return image def encoder_ifg(G, noise, attr_name, truncation=1, truncation_latent=None, latent_dir='latent_direction/ss/', step=0, total=0, real=False): if not real: styles = [noise] styles = [G.style(s) for s in styles] style_space = [] if truncation<1: if not real: style_t = [] for style in styles: style_t.append(truncation_latent + truncation * (style - truncation_latent)) styles = style_t else: # styles are latent (tensor: 1,18,512), for real PTI output truncation_latent = truncation_latent.repeat(18,1).unsqueeze(0) # (1,512) --> (1,18,512) styles = torch.add(truncation_latent,torch.mul(torch.sub(noise,truncation_latent),truncation)) noise = [getattr(G.noises, 'noise_{}'.format(i)) for i in range(G.num_layers)] if not real: inject_index = G.n_latent latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) else: latent=styles style_space.append(G.conv1.conv.modulation(latent[:, 0])) i = 1 for conv1, conv2, noise1, noise2, to_rgb in zip( G.convs[::2], G.convs[1::2], noise[1::2], noise[2::2], G.to_rgbs ): style_space.append(conv1.conv.modulation(latent[:, i])) style_space.append(conv2.conv.modulation(latent[:, i+1])) i += 2 # get layer, strength by dict strength = attr_dict['interface_gan'][attr_name][0] if step != 0 and total != 0: strength = step / total * strength for i in range(15): style_vect = load_pkl(os.path.join(latent_dir, 
'{}/style_vect_mean_{}.pkl'.format(attr_name, i))) style_vect = torch.from_numpy(style_vect).to(latent.device).float() style_space[i] += style_vect * strength return style_space, latent, noise def encoder_ss(G, noise, attr_name, truncation=1, truncation_latent=None, statics_dir="latent_direction/ss_statics", latent_dir="latent_direction/ss/", step=0, total=0,real=False): if not real: styles = [noise] styles = [G.style(s) for s in styles] style_space = [] if truncation<1: if not real: style_t = [] for style in styles: style_t.append( truncation_latent + truncation * (style - truncation_latent) ) styles = style_t else: # styles are latent (tensor: 1,18,512), for real PTI output truncation_latent = truncation_latent.repeat(18,1).unsqueeze(0) # (1,512) --> (1,18,512) styles = torch.add(truncation_latent,torch.mul(torch.sub(noise,truncation_latent),truncation)) noise = [getattr(G.noises, 'noise_{}'.format(i)) for i in range(G.num_layers)] if not real: inject_index = G.n_latent latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) else: latent = styles style_space.append(G.conv1.conv.modulation(latent[:, 0])) i = 1 for conv1, conv2, noise1, noise2, to_rgb in zip( G.convs[::2], G.convs[1::2], noise[1::2], noise[2::2], G.to_rgbs ): style_space.append(conv1.conv.modulation(latent[:, i])) style_space.append(conv2.conv.modulation(latent[:, i+1])) i += 2 # get threshold, layer, strength by dict layer, strength, threshold = attr_dict['stylespace'][attr_name] statis_dir = os.path.join(statics_dir, "{}_statis/{}".format(attr_name, layer)) statis_csv_path = os.path.join(statis_dir, "statis.csv") statis_df = pd.read_csv(statis_csv_path) statis_df = statis_df.sort_values(by='channel', ascending=True) ch_mask = statis_df['strength'].values ch_mask = torch.from_numpy(ch_mask).to(latent.device).float() ch_mask = (ch_mask.abs()>threshold).float() style_vect = load_pkl(os.path.join(latent_dir, '{}/style_vect_mean_{}.pkl'.format(attr_name, layer))) style_vect = torch.from_numpy(style_vect).to(latent.device).float() style_vect = style_vect * ch_mask if step != 0 and total != 0: strength = step / total * strength style_space[layer] += style_vect * strength return style_space, latent, noise def encoder_sefa(G, noise, attr_name, truncation=1, truncation_latent=None, latent_dir='latent_direction/sefa/', step=0, total=0, real=False): if not real: styles = [noise] styles = [G.style(s) for s in styles] if truncation<1: if not real: style_t = [] for style in styles: style_t.append( truncation_latent + truncation * (style - truncation_latent) ) styles = style_t else: truncation_latent = truncation_latent.repeat(18,1).unsqueeze(0) # (1,512) --> (1,18,512) styles = torch.add(truncation_latent,torch.mul(torch.sub(noise,truncation_latent),truncation)) noise = [getattr(G.noises, 'noise_{}'.format(i)) for i in range(G.num_layers)] if not real: inject_index = G.n_latent latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) else: latent = styles layer, strength = attr_dict['sefa'][attr_name] sefa_vect = torch.load(os.path.join(latent_dir, '{}.pt'.format(attr_name))).to(latent.device).float() if step != 0 and total != 0: strength = step / total * strength for l in layer: latent[:, l, :] += (sefa_vect * strength * 2) return latent, noise ================================================ FILE: stylegan_human/edit.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. 
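# Command-line tool that applies the helpers above to edit generated or real
# (PTI-inverted) full-body images with three methods side by side: InterFaceGAN,
# StyleSpace, and SeFa. The results are concatenated into one frame and,
# optionally, rendered into per-step videos with ffmpeg.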
import os import sys import torch import numpy as np sys.path.append(".") from torch_utils.models import Generator import click import cv2 from typing import List, Optional import subprocess import legacy from edit.edit_helper import conv_warper, decoder, encoder_ifg, encoder_ss, encoder_sefa """ Edit generated images with different SOTA methods. Notes: 1. We provide some latent directions in the folder; you can play around with them. 2. ''upper_length'' and ''bottom_length'' are the available values of ''attr_name'' for the demo. 3. The layers to control and the editing strength are set in edit/edit_config.py. Examples: \b # Editing with InterfaceGAN, StyleSpace, and Sefa python edit.py --network pretrained_models/stylegan_human_v2_1024.pkl --attr_name upper_length \\ --seeds 61531,61570,61571,61610 --outdir outputs/edit_results # Editing using inverted latent code python edit.py --network outputs/pti/checkpoints/model_test.pkl --attr_name upper_length \\ --outdir outputs/edit_results --real True --real_w_path outputs/pti/embeddings/test/PTI/test/0.pt --real_img_path aligned_image/test.png """ @click.command() @click.pass_context @click.option('--network', 'ckpt_path', help='Network pickle filename', required=True) @click.option('--attr_name', help='choose one of the attributes: upper_length or bottom_length', type=str, required=True) @click.option('--trunc', 'truncation', type=float, help='Truncation psi', default=0.8, show_default=True) @click.option('--gen_video', type=bool, default=True, help='Whether to generate a video') @click.option('--combine', type=bool, default=True, help='Whether to combine the different editing results in the same frame') @click.option('--seeds', type=legacy.num_range, help='List of random seeds') @click.option('--outdir', help='Where to save the output images', type=str, required=True, default='outputs/editing', metavar='DIR') @click.option('--real', type=bool, help='True for editing a real image', default=False) @click.option('--real_w_path', help='Path of the latent code for the real image') @click.option('--real_img_path', help='Path of the real image; this just concatenates the real image with the inverted and edited results') def main( ctx: click.Context, ckpt_path: str, attr_name: str, truncation: float, gen_video: bool, combine: bool, seeds: Optional[List[int]], outdir: str, real: bool, real_w_path: str, real_img_path: str ): ## convert pkl to pth # if not os.path.exists(ckpt_path.replace('.pkl','.pth')): legacy.convert(ckpt_path, ckpt_path.replace('.pkl','.pth'), G_only=real) ckpt_path = ckpt_path.replace('.pkl','.pth') print("start...", flush=True) config = {"latent" : 512, "n_mlp" : 8, "channel_multiplier": 2} generator = Generator( size = 1024, style_dim=config["latent"], n_mlp=config["n_mlp"], channel_multiplier=config["channel_multiplier"] ) generator.load_state_dict(torch.load(ckpt_path)['g_ema']) generator.eval().cuda() with torch.no_grad(): mean_path = os.path.join('edit','mean_latent.pkl') if not os.path.exists(mean_path): mean_n = 3000 mean_latent = generator.mean_latent(mean_n).detach() legacy.save_obj(mean_latent, mean_path) else: mean_latent = legacy.load_pkl(mean_path).cuda() finals = [] ## -- selected sample seeds -- ## # seeds = [60948,60965,61174,61210,61511,61598,61610] # bottom -> long # [60941,61064,61103,61313,61531,61570,61571] # bottom -> short # [60941,60965,61064,61103,61174,61210,61531,61570,61571,61610] # upper -> long # [60948,61313,61511,61598] # upper -> short if real: seeds = [0] for t in seeds: if real: # currently assumes a single real image is processed if real_img_path:
real_image = cv2.imread(real_img_path) real_image = cv2.cvtColor(real_image, cv2.COLOR_BGR2RGB) import torchvision.transforms as transforms transform = transforms.Compose( # normalize to (-1, 1) [transforms.ToTensor(), transforms.Normalize(mean=(.5,.5,.5), std=(.5,.5,.5))] ) real_image = transform(real_image).unsqueeze(0).cuda() test_input = torch.load(real_w_path) output, _ = generator(test_input, False, truncation=1,input_is_latent=True, real=True) else: # generate image from random seeds test_input = torch.from_numpy(np.random.RandomState(t).randn(1, 512)).float().cuda() # torch.Size([1, 512]) output, _ = generator([test_input], False, truncation=truncation, truncation_latent=mean_latent, real=real) # interfacegan style_space, latent, noise = encoder_ifg(generator, test_input, attr_name, truncation, mean_latent,real=real) image1 = decoder(generator, style_space, latent, noise) # stylespace style_space, latent, noise = encoder_ss(generator, test_input, attr_name, truncation, mean_latent,real=real) image2 = decoder(generator, style_space, latent, noise) # sefa latent, noise = encoder_sefa(generator, test_input, attr_name, truncation, mean_latent,real=real) image3, _ = generator([latent], noise=noise, input_is_latent=True) if real_img_path: final = torch.cat((real_image, output, image1, image2, image3), 3) else: final = torch.cat((output, image1, image2, image3), 3) # legacy.visual(output, f'{outdir}/{attr_name}_{t:05d}_raw.jpg') # legacy.visual(image1, f'{outdir}/{attr_name}_{t:05d}_ifg.jpg') # legacy.visual(image2, f'{outdir}/{attr_name}_{t:05d}_ss.jpg') # legacy.visual(image3, f'{outdir}/{attr_name}_{t:05d}_sefa.jpg') if gen_video: total_step = 90 if real: video_ifg_path = f"{outdir}/video/ifg_{attr_name}_{real_w_path.split('/')[-2]}/" video_ss_path = f"{outdir}/video/ss_{attr_name}_{real_w_path.split('/')[-2]}/" video_sefa_path = f"{outdir}/video/sefa_{attr_name}_{real_w_path.split('/')[-2]}/" else: video_ifg_path = f"{outdir}/video/ifg_{attr_name}_{t:05d}/" video_ss_path = f"{outdir}/video/ss_{attr_name}_{t:05d}/" video_sefa_path = f"{outdir}/video/sefa_{attr_name}_{t:05d}/" video_comb_path = f"{outdir}/video/tmp" if combine: if not os.path.exists(video_comb_path): os.makedirs(video_comb_path) else: if not os.path.exists(video_ifg_path): os.makedirs(video_ifg_path) if not os.path.exists(video_ss_path): os.makedirs(video_ss_path) if not os.path.exists(video_sefa_path): os.makedirs(video_sefa_path) for i in range(total_step): style_space, latent, noise = encoder_ifg(generator, test_input, attr_name, truncation, mean_latent, step=i, total=total_step,real=real) image1 = decoder(generator, style_space, latent, noise) style_space, latent, noise = encoder_ss(generator, test_input, attr_name, truncation, mean_latent, step=i, total=total_step,real=real) image2 = decoder(generator, style_space, latent, noise) latent, noise = encoder_sefa(generator, test_input, attr_name, truncation, mean_latent, step=i, total=total_step,real=real) image3, _ = generator([latent], noise=noise, input_is_latent=True) if combine: if real_img_path: comb_img = torch.cat((real_image, output, image1, image2, image3), 3) else: comb_img = torch.cat((output, image1, image2, image3), 3) legacy.visual(comb_img, os.path.join(video_comb_path, f'{i:05d}.jpg')) else: legacy.visual(image1, os.path.join(video_ifg_path, f'{i:05d}.jpg')) legacy.visual(image2, os.path.join(video_ss_path, f'{i:05d}.jpg')) if combine: cmd=f"ffmpeg -hide_banner -loglevel error -y -r 30 -i {video_comb_path}/%05d.jpg -vcodec libx264 -pix_fmt yuv420p
{video_ifg_path.replace('ifg_', '')[:-1] + '.mp4'}" subprocess.call(cmd, shell=True) else: cmd=f"ffmpeg -hide_banner -loglevel error -y -r 30 -i {video_ifg_path}/%05d.jpg -vcodec libx264 -pix_fmt yuv420p {video_ifg_path[:-1] + '.mp4'}" subprocess.call(cmd, shell=True) cmd=f"ffmpeg -hide_banner -loglevel error -y -r 30 -i {video_ss_path}/%05d.jpg -vcodec libx264 -pix_fmt yuv420p {video_ss_path[:-1] + '.mp4'}" subprocess.call(cmd, shell=True) # interfacegan, stylespace, sefa finals.append(final) final = torch.cat(finals, 2) legacy.visual(final, os.path.join(outdir,'final.jpg')) if __name__ == "__main__": main() ================================================ FILE: stylegan_human/environment.yml ================================================ name: stylehuman channels: - pytorch - nvidia dependencies: - python == 3.8 - pip - numpy>=1.20 - click>=8.0 - pillow=8.3.1 - scipy=1.7.1 - pytorch=1.9.1 - cudatoolkit=11.1 - requests=2.26.0 - tqdm=4.62.2 - ninja=1.10.2 - matplotlib=3.4.2 - imageio=2.9.0 - pip: - imgui==1.3.0 - glfw==2.2.0 - pyopengl==3.1.5 - imageio-ffmpeg==0.4.3 - lpips==0.1.4 - pyspng - dlib - opencv-python - pandas - moviepy - imutils ================================================ FILE: stylegan_human/generate.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. # This work is made available under the Nvidia Source Code License-NC. # To view a copy of this license, visit # https://nvlabs.github.io/stylegan2/license.html ## this script is for generating images from pre-trained network based on StyleGAN1 (TensorFlow) and StyleGAN2-ada (PyTorch) ## import os import click import dnnlib import numpy as np import PIL.Image import legacy from typing import List, Optional """ Generate images using pretrained network pickle. Examples: \b # Generate human full-body images without truncation python generate.py --outdir=outputs/generate/stylegan_human_v2_1024 --trunc=1 --seeds=1,3,5,7 \\ --network=pretrained_models/stylegan_human_v2_1024.pkl --version 2 \b # Generate human full-body images with truncation python generate.py --outdir=outputs/generate/stylegan_human_v2_1024 --trunc=0.8 --seeds=0-100\\ --network=pretrained_models/stylegan_human_v2_1024.pkl --version 2 # \b # Generate human full-body images using stylegan V1 # python generate.py --outdir=outputs/generate/stylegan_human_v1_1024 \\ # --network=pretrained_models/stylegan_human_v1_1024.pkl --version 1 """ @click.command() @click.pass_context @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) @click.option('--seeds', type=legacy.num_range, help='List of random seeds') @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True) @click.option('--outdir', help='Where to save the output images', default= 'outputs/generate/' , type=str, required=True, metavar='DIR') @click.option('--version', help="stylegan version, 1, 2 or 3", type=int, default=2) def generate_images( ctx: click.Context, network_pkl: str, seeds: Optional[List[int]], truncation_psi: float, noise_mode: str, outdir: str, version: int ): print('Loading networks from "%s"...' 
% network_pkl) if version == 1: import dnnlib.tflib as tflib tflib.init_tf() G, D, Gs = legacy.load_pkl(network_pkl) else: import torch device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu') dtype = torch.float32 if device.type == 'mps' else torch.float64 with dnnlib.util.open_url(network_pkl) as f: G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype) # type: ignore os.makedirs(outdir, exist_ok=True) if seeds is None: ctx.fail('--seeds option is required.') # Generate images. target_z = np.array([]) target_w = np.array([]) latent_out = outdir.replace('/images/','') for seed_idx, seed in enumerate(seeds): if seed % 5000 == 0: print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds))) if version == 1: ## stylegan v1 z = np.random.RandomState(seed).randn(1, Gs.input_shape[1]) # Generate image. fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) if noise_mode == 'const': randomize_noise=False else: randomize_noise = True images = Gs.run(z, None, truncation_psi=truncation_psi, randomize_noise=randomize_noise, output_transform=fmt) PIL.Image.fromarray(images[0], 'RGB').save(f'{outdir}/seed{seed:04d}.png') else: ## stylegan v2/v3 label = torch.zeros([1, G.c_dim], device=device) z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device, dtype=dtype) if target_z.size==0: target_z= z.cpu() else: target_z=np.append(target_z, z.cpu(), axis=0) w = G.mapping(z, label,truncation_psi=truncation_psi) img = G.synthesis(w, noise_mode=noise_mode,force_fp32 = True) if target_w.size==0: target_w= w.cpu() else: target_w=np.append(target_w, w.cpu(), axis=0) img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png') # print(target_z) # print(target_z.shape,target_w.shape) #---------------------------------------------------------------------------- if __name__ == "__main__": generate_images() #---------------------------------------------------------------------------- ================================================ FILE: stylegan_human/insetgan.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. 
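# InsetGAN-style joint generation: a face generator (FFHQ) and a full-body
# generator are optimized together so that the generated face can be pasted
# seamlessly into the face region of the generated body. The InsetGAN class
# below loads both generators, locates the body's face with dlib, and
# dual_optimizer() refines both W+ latents with L1 + LPIPS losses on the face
# crop, its border, and the surrounding body.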
import torch import torch.nn.functional as F from tqdm import tqdm from lpips import LPIPS import numpy as np from torch_utils.models import Generator as bodyGAN from torch_utils.models_face import Generator as FaceGAN import dlib from utils.face_alignment import align_face_for_insetgan from utils.util import visual,tensor_to_numpy, numpy_to_tensor import legacy import os import click class InsetGAN(torch.nn.Module): def __init__(self, stylebody_ckpt, styleface_ckpt): super().__init__() ## convert pkl to pth if not os.path.exists(stylebody_ckpt.replace('.pkl','.pth')): legacy.convert(stylebody_ckpt, stylebody_ckpt.replace('.pkl','.pth')) stylebody_ckpt = stylebody_ckpt.replace('.pkl','.pth') if not os.path.exists(styleface_ckpt.replace('.pkl','.pth')): legacy.convert(styleface_ckpt, styleface_ckpt.replace('.pkl','.pth')) styleface_ckpt = styleface_ckpt.replace('.pkl','.pth') # dual generator config = {"latent" : 512, "n_mlp" : 8, "channel_multiplier": 2} self.body_generator = bodyGAN( size = 1024, style_dim=config["latent"], n_mlp=config["n_mlp"], channel_multiplier=config["channel_multiplier"] ) self.body_generator.load_state_dict(torch.load(stylebody_ckpt)['g_ema']) self.body_generator.eval().requires_grad_(False).cuda() self.face_generator = FaceGAN( size = 1024, style_dim=config["latent"], n_mlp=config["n_mlp"], channel_multiplier=config["channel_multiplier"] ) self.face_generator.load_state_dict(torch.load(styleface_ckpt)['g_ema']) self.face_generator.eval().requires_grad_(False).cuda() # crop function self.dlib_predictor = dlib.shape_predictor('./pretrained_models/shape_predictor_68_face_landmarks.dat') self.dlib_cnn_face_detector = dlib.cnn_face_detection_model_v1("pretrained_models/mmod_human_face_detector.dat") # criterion self.lpips_loss = LPIPS(net='alex').cuda().eval() self.l1_loss = torch.nn.L1Loss(reduction='mean') def loss_coarse(self, A_face, B, p1=500, p2=0.05): A_face = F.interpolate(A_face, size=(64, 64), mode='area') B = F.interpolate(B, size=(64, 64), mode='area') loss_l1 = p1 * self.l1_loss(A_face, B) loss_lpips = p2 * self.lpips_loss(A_face, B) return loss_l1 + loss_lpips @staticmethod def get_border_mask(A, x, spec): mask = torch.zeros_like(A) mask[:, :, :x, ] = 1 mask[:, :, -x:, ] = 1 mask[:, :, :, :x ] = 1 mask[:, :, :, -x:] = 1 return mask @staticmethod def get_body_mask(A, crop, padding=4): mask = torch.ones_like(A) mask[:, :, crop[1]-padding:crop[3]+padding, crop[0]-padding:crop[2]+padding] = 0 return mask def loss_border(self, A_face, B, p1=10000, p2=2, spec=None): mask = self.get_border_mask(A_face, 8, spec) loss_l1 = p1 * self.l1_loss(A_face*mask, B*mask) loss_lpips = p2 * self.lpips_loss(A_face*mask, B*mask) return loss_l1 + loss_lpips def loss_body(self, A, B, crop, p1=9000, p2=0.1): padding = int((crop[3] - crop[1]) / 20) mask = self.get_body_mask(A, crop, padding) loss_l1 = p1 * self.l1_loss(A*mask, B*mask) loss_lpips = p2 * self.lpips_loss(A*mask, B*mask) return loss_l1+loss_lpips def loss_face(self, A, B, crop, p1=5000, p2=1.75): mask = 1 - self.get_body_mask(A, crop) loss_l1 = p1 * self.l1_loss(A*mask, B*mask) loss_lpips = p2 * self.lpips_loss(A*mask, B*mask) return loss_l1+loss_lpips def loss_reg(self, w, w_mean, p1, w_plus_delta=None, p2=None): return p1 * torch.mean(((w - w_mean) ** 2)) + p2 * torch.mean(w_plus_delta ** 2) # FFHQ type def detect_face_dlib(self, img): # tensor to numpy array rgb uint8 img = tensor_to_numpy(img) aligned_image, crop, rect = align_face_for_insetgan(img=img, detector=self.dlib_cnn_face_detector, 
predictor=self.dlib_predictor, output_size=256) aligned_image = np.array(aligned_image) aligned_image = numpy_to_tensor(aligned_image) return aligned_image, crop, rect # joint optimization def dual_optimizer(self, face_w, body_w, joint_steps=500, face_initial_learning_rate=0.02, body_initial_learning_rate=0.05, lr_rampdown_length=0.25, lr_rampup_length=0.05, seed=None, output_path=None, video=0): ''' Given a face_w, optimize a body_w with suitable body pose & shape for face_w ''' def visual_(path, synth_body, synth_face, body_crop, step, both=False, init_body_with_face=None): tmp = synth_body.clone().detach() tmp[:, :, body_crop[1]:body_crop[3], body_crop[0]:body_crop[2]] = synth_face if both: tmp = torch.cat([synth_body, tmp], dim=3) save_path = os.path.join(path, f"{step:04d}.jpg") visual(tmp, save_path) def forward(face_w_opt, body_w_opt, face_w_delta, body_w_delta, body_crop, update_crop=False ): if face_w_opt.shape[1] != 18: face_ws = (face_w_opt).repeat([1, 18, 1]) else: face_ws = face_w_opt.clone() face_ws = face_ws + face_w_delta synth_face, _ = self.face_generator([face_ws], input_is_latent=True, randomize_noise=False) body_ws = (body_w_opt).repeat([1, 18, 1]) body_ws = body_ws + body_w_delta synth_body, _ = self.body_generator([body_ws], input_is_latent=True, randomize_noise=False) if update_crop: old_r = (body_crop[3]-body_crop[1]) // 2, (body_crop[2]-body_crop[0]) // 2 _, body_crop, _ = self.detect_face_dlib(synth_body) center = (body_crop[1] + body_crop[3]) // 2, (body_crop[0] + body_crop[2]) // 2 body_crop = (center[1] - old_r[1], center[0] - old_r[0], center[1] + old_r[1], center[0] + old_r[0]) synth_body_face = synth_body[:, :, body_crop[1]:body_crop[3], body_crop[0]:body_crop[2]] if synth_face.shape[2] > body_crop[3]-body_crop[1]: synth_face_resize = F.interpolate(synth_face, size=(body_crop[3]-body_crop[1], body_crop[2]-body_crop[0]), mode='area') return synth_body, synth_body_face, synth_face, synth_face_resize, body_crop def update_lr(init_lr, step, num_steps, lr_rampdown_length, lr_rampup_length): t = step / num_steps lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length) lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi) lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length) lr = init_lr * lr_ramp return lr # update output_path output_path = os.path.join(output_path, seed) os.makedirs(output_path, exist_ok=True) # define optimized params body_w_mean = self.body_generator.mean_latent(10000).detach() face_w_opt = face_w.clone().detach().requires_grad_(True) body_w_opt = body_w.clone().detach().requires_grad_(True) face_w_delta = torch.zeros_like(face_w.repeat([1, 18, 1])).requires_grad_(True) body_w_delta = torch.zeros_like(body_w.repeat([1, 18, 1])).requires_grad_(True) # generate ref face & body ref_body, _ = self.body_generator([body_w.repeat([1, 18, 1])], input_is_latent=True, randomize_noise=False) # for inversion ref_face, _ = self.face_generator([face_w.repeat([1, 18, 1])], input_is_latent=True, randomize_noise=False) # get initilized crop _, body_crop, _ = self.detect_face_dlib(ref_body) _, _, face_crop = self.detect_face_dlib(ref_face) # NOTE: this is face rect only. no FFHQ type. 
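# Roadmap for the three stages below:
#   Stage 1 (25 steps): update only the face latent so the synthesized face
#     roughly matches the body's face crop (coarse + border losses) while
#     staying close to the reference face.
#   Stage 2 (150 steps): update only the body latent to accept the face,
#     re-detecting the face crop every 50 steps and regularizing toward the
#     mean body latent.
#   Stage 3 (joint_steps): alternate between face and body updates every
#     `interval` (50) steps until the composite is coherent.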
# create optimizer face_optimizer = torch.optim.Adam([face_w_opt, face_w_delta], betas=(0.9, 0.999), lr=face_initial_learning_rate) body_optimizer = torch.optim.Adam([body_w_opt, body_w_delta], betas=(0.9, 0.999), lr=body_initial_learning_rate) global_step = 0 # Stage1: remove background of face image face_steps = 25 pbar = tqdm(range(face_steps)) for step in pbar: face_lr = update_lr(face_initial_learning_rate / 2, step, face_steps, lr_rampdown_length, lr_rampup_length) for param_group in face_optimizer.param_groups: param_group['lr'] =face_lr synth_body, synth_body_face, synth_face_raw, synth_face, body_crop = forward(face_w_opt, body_w_opt, face_w_delta, body_w_delta, body_crop) loss_face = self.loss_face(synth_face_raw, ref_face, face_crop, 5000, 1.75) loss_coarse = self.loss_coarse(synth_face, synth_body_face, 50, 0.05) loss_border = self.loss_border(synth_face, synth_body_face, 1000, 0.1) loss = loss_coarse + loss_border + loss_face face_optimizer.zero_grad() loss.backward() face_optimizer.step() # visualization if video: visual_(output_path, synth_body, synth_face, body_crop, global_step) pbar.set_description( ( f"face: {step:.4f}, lr: {face_lr}, loss: {loss.item():.2f}, loss_coarse: {loss_coarse.item():.2f};" f"loss_border: {loss_border.item():.2f}, loss_face: {loss_face.item():.2f};" ) ) global_step += 1 # Stage2: find a suitable body body_steps = 150 pbar = tqdm(range(body_steps)) for step in pbar: body_lr = update_lr(body_initial_learning_rate, step, body_steps, lr_rampdown_length, lr_rampup_length) update_crop = True if (step % 50 == 0) else False # update_crop = False for param_group in body_optimizer.param_groups: param_group['lr'] =body_lr synth_body, synth_body_face, synth_face_raw, synth_face, body_crop = forward(face_w_opt, body_w_opt, face_w_delta, body_w_delta, body_crop, update_crop=update_crop) loss_coarse = self.loss_coarse(synth_face, synth_body_face, 500, 0.05) loss_border = self.loss_border(synth_face, synth_body_face, 2500, 0) loss_body = self.loss_body(synth_body, ref_body, body_crop, 9000, 0.1) loss_reg = self.loss_reg(body_w_opt, body_w_mean, 15000, body_w_delta, 0) loss = loss_coarse + loss_border + loss_body + loss_reg body_optimizer.zero_grad() loss.backward() body_optimizer.step() # visualization if video: visual_(output_path, synth_body, synth_face, body_crop, global_step) pbar.set_description( ( f"body: {step:.4f}, lr: {body_lr}, loss: {loss.item():.2f}, loss_coarse: {loss_coarse.item():.2f};" f"loss_border: {loss_border.item():.2f}, loss_body: {loss_body.item():.2f}, loss_reg: {loss_reg:.2f}" ) ) global_step += 1 # Stage3: joint optimization interval = 50 joint_face_steps = joint_steps // 2 joint_body_steps = joint_steps // 2 face_step = 0 body_step = 0 pbar = tqdm(range(joint_steps)) flag = -1 for step in pbar: if step % interval == 0: flag += 1 text_flag = 'optimize_face' if flag % 2 == 0 else 'optimize_body' synth_body, synth_body_face, synth_face_raw, synth_face, body_crop = forward(face_w_opt, body_w_opt, face_w_delta, body_w_delta, body_crop) if text_flag == 'optimize_face': face_lr = update_lr(face_initial_learning_rate, face_step, joint_face_steps, lr_rampdown_length, lr_rampup_length) for param_group in face_optimizer.param_groups: param_group['lr'] =face_lr loss_face = self.loss_face(synth_face_raw, ref_face, face_crop, 5000, 1.75) loss_coarse = self.loss_coarse(synth_face, synth_body_face, 500, 0.05) loss_border = self.loss_border(synth_face, synth_body_face, 25000, 0) loss = loss_coarse + loss_border + loss_face face_optimizer.zero_grad() 
loss.backward() face_optimizer.step() pbar.set_description( ( f"face: {step}, lr: {face_lr:.4f}, loss: {loss.item():.2f}, loss_coarse: {loss_coarse.item():.2f};" f"loss_border: {loss_border.item():.2f}, loss_face: {loss_face.item():.2f};" ) ) face_step += 1 else: body_lr = update_lr(body_initial_learning_rate, body_step, joint_body_steps, lr_rampdown_length, lr_rampup_length) for param_group in body_optimizer.param_groups: param_group['lr'] =body_lr loss_coarse = self.loss_coarse(synth_face, synth_body_face, 500, 0.05) loss_border = self.loss_border(synth_face, synth_body_face, 2500, 0) loss_body = self.loss_body(synth_body, ref_body, body_crop, 9000, 0.1) loss_reg = self.loss_reg(body_w_opt, body_w_mean, 25000, body_w_delta, 0) loss = loss_coarse + loss_border + loss_body + loss_reg body_optimizer.zero_grad() loss.backward() body_optimizer.step() pbar.set_description( ( f"body: {step}, lr: {body_lr:.4f}, loss: {loss.item():.2f}, loss_coarse: {loss_coarse.item():.2f};" f"loss_border: {loss_border.item():.2f}, loss_body: {loss_body.item():.2f}, loss_reg: {loss_reg:.2f}" ) ) body_step += 1 if video: visual_(output_path, synth_body, synth_face, body_crop, global_step) global_step += 1 return face_w_opt.repeat([1, 18, 1])+face_w_delta, body_w_opt.repeat([1, 18, 1])+body_w_delta, body_crop """ Jointly combine and optimize generated faces and bodies. Examples: \b # Combine the generated human full-body image from the provided StyleGAN-Human pre-trained model # and the generated face image from the FFHQ model, then optimize both latent codes to produce a coherent face-body image python insetgan.py --body_network=pretrained_models/stylegan_human_v2_1024.pkl --face_network=pretrained_models/ffhq.pkl \\ --body_seed=82 --face_seed=43 --trunc=0.6 --outdir=outputs/insetgan/ --video 1 """ @click.command() @click.pass_context @click.option('--face_network', default="./pretrained_models/ffhq.pkl", help='Network pickle filename', required=True) @click.option('--body_network', default='./pretrained_models/stylegan2_1024.pkl', help='Network pickle filename', required=True) @click.option('--face_seed', type=int, default=82, help='selected random seed') @click.option('--body_seed', type=int, default=43, help='selected random seed') @click.option('--joint_steps', type=int, default=500, help='num steps for joint optimization') @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=0.6, show_default=True) @click.option('--outdir', help='Where to save the output images', default="outputs/insetgan/", type=str, required=True, metavar='DIR') @click.option('--video', help="set to 1 to save a video", type=int, default=0) def main( ctx: click.Context, face_network: str, body_network: str, face_seed: int, body_seed: int, joint_steps: int, truncation_psi: float, outdir: str, video: int): device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") insgan = InsetGAN(body_network, face_network) os.makedirs(outdir, exist_ok=True) face_z = np.random.RandomState(face_seed).randn(1, 512).astype(np.float32) face_mean = insgan.face_generator.mean_latent(3000) face_w = insgan.face_generator.get_latent(torch.from_numpy(face_z).to(device)) # [N,
L, C] body_w = truncation_psi * body_w + (1-truncation_psi) * body_mean body_img, _ = insgan.body_generator([body_w], input_is_latent=True) _, body_crop, _ = insgan.detect_face_dlib(body_img) face_img = F.interpolate(face_img, size=(body_crop[3]-body_crop[1], body_crop[2]-body_crop[0]), mode='area') cp_body = body_img.clone() cp_body[:, :, body_crop[1]:body_crop[3], body_crop[0]:body_crop[2]] = face_img optim_face_w, optim_body_w, crop = insgan.dual_optimizer( face_w, body_w, joint_steps=joint_steps, seed=f'{face_seed:04d}_{body_seed:04d}', output_path=outdir, video=video ) if video: ffmpeg_cmd = f"ffmpeg -hide_banner -loglevel error -i ./{outdir}/{face_seed:04d}_{body_seed:04d}/%04d.jpg -c:v libx264 -vf fps=30 -pix_fmt yuv420p ./{outdir}/{face_seed:04d}_{body_seed:04d}.mp4" os.system(ffmpeg_cmd) new_face_img, _ = insgan.face_generator([optim_face_w], input_is_latent=True) new_shape = crop[3] - crop[1], crop[2] - crop[0] new_face_img_crop = F.interpolate(new_face_img, size=new_shape, mode='area') seamless_body, _ = insgan.body_generator([optim_body_w], input_is_latent=True) seamless_body[:, :, crop[1]:crop[3], crop[0]:crop[2]] = new_face_img_crop temp = torch.cat([cp_body, seamless_body], dim=3) visual(temp, f"{outdir}/{face_seed:04d}_{body_seed:04d}.png") if __name__ == "__main__": main() ================================================ FILE: stylegan_human/interpolation.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. ## interpolate between two z code ## score all middle latent code # https://www.aiuai.cn/aifarm1929.html import os import re from typing import List from tqdm import tqdm import click import dnnlib import numpy as np import PIL.Image import torch import click import legacy import random from typing import List, Optional def lerp(code1, code2, alpha): return code1 * alpha + code2 * (1 - alpha) # Taken and adapted from wikipedia's slerp article # https://en.wikipedia.org/wiki/Slerp def slerp(code1, code2, alpha, DOT_THRESHOLD=0.9995): # Spherical linear interpolation code1_copy = np.copy(code1) code2_copy = np.copy(code2) code1 = code1 / np.linalg.norm(code1) code2 = code2 / np.linalg.norm(code2) dot = np.sum(code1 * code2) if np.abs(dot) > DOT_THRESHOLD: return lerp(code1_copy, code2_copy, alpha) # Calculate initial angle between v0 and v1 theta_0 = np.arccos(dot) sin_theta_0 = np.sin(theta_0) # Angle at timestep t theta_t = theta_0 * alpha sin_theta_t = np.sin(theta_t) s0 = np.sin(theta_0 - theta_t) / sin_theta_0 s1 = sin_theta_t / sin_theta_0 code3 = s0 * code1_copy + s1 * code2_copy return code3 def generate_image_from_z(G, z, noise_mode, truncation_psi, device): label = torch.zeros([1, G.c_dim], device=device) w = G.mapping(z, label,truncation_psi=truncation_psi) img = G.synthesis(w, noise_mode=noise_mode,force_fp32 = True) img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB') return img def get_concat_h(im1, im2): dst = PIL.Image.new('RGB', (im1.width + im2.width, im1.height)) dst.paste(im1, (0, 0)) dst.paste(im2, (im1.width, 0)) return dst def make_latent_interp_animation(G, code1, code2, img1, img2, num_interps, noise_mode, save_mid_image, truncation_psi,device, outdir,fps): step_size = 1.0/num_interps all_imgs = [] amounts = np.arange(0, 1, step_size) for seed_idx, alpha in enumerate(tqdm(amounts)): interpolated_latent_code = lerp(code1, code2, alpha) image = generate_image_from_z(G,interpolated_latent_code, noise_mode, 
def generate_image_from_z(G, z, noise_mode, truncation_psi, device): label = torch.zeros([1, G.c_dim], device=device) w = G.mapping(z, label, truncation_psi=truncation_psi) img = G.synthesis(w, noise_mode=noise_mode, force_fp32=True) img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB') return img def get_concat_h(im1, im2): dst = PIL.Image.new('RGB', (im1.width + im2.width, im1.height)) dst.paste(im1, (0, 0)) dst.paste(im2, (im1.width, 0)) return dst def make_latent_interp_animation(G, code1, code2, img1, img2, num_interps, noise_mode, save_mid_image, truncation_psi, device, outdir, fps): step_size = 1.0/num_interps all_imgs = [] amounts = np.arange(0, 1, step_size) for seed_idx, alpha in enumerate(tqdm(amounts)): interpolated_latent_code = lerp(code1, code2, alpha) image = generate_image_from_z(G, interpolated_latent_code, noise_mode, truncation_psi, device) interp_latent_image = image.resize((512, 1024)) os.makedirs(os.path.join(outdir, 'img'), exist_ok=True) if save_mid_image: interp_latent_image.save(f'{outdir}/img/seed{seed_idx:04d}.png') frame = get_concat_h(img2, interp_latent_image) frame = get_concat_h(frame, img1) all_imgs.append(frame) save_name = os.path.join(outdir, 'latent_space_traversal.gif') all_imgs[0].save(save_name, save_all=True, append_images=all_imgs[1:], duration=1000/fps, loop=0) """ Create interpolated images between two given seeds using a pretrained network pickle. Examples: \b python interpolation.py --network=pretrained_models/stylegan_human_v2_1024.pkl --seeds=85,100 --outdir=outputs/inter_gifs """ @click.command() @click.pass_context @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) @click.option('--seeds', type=legacy.num_range, help='List of 2 random seeds, e.g. 1,2') @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=0.8, show_default=True) @click.option('--noise-mode', 'noise_mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True) @click.option('--outdir', default='outputs/inter_gifs', help='Where to save the output images', type=str, required=True, metavar='DIR') @click.option('--save_mid_image', default=True, type=bool, help='set True to save every intermediate interpolated image') @click.option('--fps', default=15, help='FPS for GIF', type=int) @click.option('--num_interps', default=100, help='Number of interpolation images', type=int) def main( ctx: click.Context, network_pkl: str, seeds: Optional[List[int]], truncation_psi: float, noise_mode: str, outdir: str, save_mid_image: bool, fps: int, num_interps: int ): device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu') dtype = torch.float32 if device.type == 'mps' else torch.float64 with dnnlib.util.open_url(network_pkl) as f: G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype) # type: ignore os.makedirs(outdir, exist_ok=True) os.makedirs(os.path.join(outdir, 'img'), exist_ok=True) if seeds is None: # --seeds has no default; fail early with a clear message instead of crashing on len(None) raise click.UsageError('--seeds is required, e.g. --seeds=85,100') if len(seeds) > 2: print("Received more than two seeds; using only the first two.") seeds = seeds[0:2] elif len(seeds) == 1: print('Two seeds required; randomly generating a second one.') seeds = [seeds[0], random.randint(0, 10000)] z1 = torch.from_numpy(np.random.RandomState(seeds[0]).randn(1, G.z_dim)).to(device, dtype=dtype) z2 = torch.from_numpy(np.random.RandomState(seeds[1]).randn(1, G.z_dim)).to(device, dtype=dtype) img1 = generate_image_from_z(G, z1, noise_mode, truncation_psi, device) img2 = generate_image_from_z(G, z2, noise_mode, truncation_psi, device) img1.save(f'{outdir}/seed{seeds[0]:04d}.png') img2.save(f'{outdir}/seed{seeds[1]:04d}.png') make_latent_interp_animation(G, z1, z2, img1, img2, num_interps, noise_mode, save_mid_image, truncation_psi, device, outdir, fps) if __name__ == "__main__": main() ================================================ FILE: stylegan_human/latent_direction/ss_statics/bottom_length_statis/3/statis.csv ================================================ ,channel,strength 401,401,0.0051189866 79,79,0.004417926 499,499,0.0042351373 272,272,0.0033855115 2,2,0.003143758 267,267,0.0025972966 510,510,0.0025229468 130,130,0.0022487796 228,228,0.0021741684 101,101,0.001948409 418,418,0.0018696061 481,481,0.0017976156
88,88,0.0017784507 58,58,0.0017771542 116,116,0.0017733901 282,282,0.0017370607 207,207,0.0017357969 429,429,0.0016839057 284,284,0.0016397897 139,139,0.0016203154 242,242,0.0016087457 319,319,0.0015855831 237,237,0.0015757639 475,475,0.001572773 427,427,0.001526569 276,276,0.0015258243 163,163,0.0014731362 460,460,0.0014698224 268,268,0.0014479166 87,87,0.0013900393 486,486,0.0013423131 367,367,0.0013412643 129,129,0.0013402662 448,448,0.0013169902 438,438,0.0012944449 463,463,0.001292959 109,109,0.0012898487 197,197,0.0012845552 215,215,0.0012809597 419,419,0.0012448858 170,170,0.0012249282 46,46,0.0012235934 191,191,0.0012160796 6,6,0.0012132789 292,292,0.0012097002 174,174,0.0011935516 198,198,0.0011886621 450,450,0.0011825112 334,334,0.0011808838 134,134,0.0011740165 297,297,0.0011682409 388,388,0.0011680947 4,4,0.0011665267 96,96,0.001155285 144,144,0.0011436475 383,383,0.0011424082 472,472,0.0011330546 200,200,0.0011163615 126,126,0.0011128024 7,7,0.0011117855 149,149,0.0011032915 142,142,0.0010992303 108,108,0.0010948912 55,55,0.0010793228 35,35,0.0010671945 156,156,0.0010658612 75,75,0.0010604868 497,497,0.0010573129 333,333,0.0010556887 346,346,0.0010499252 259,259,0.0010397168 33,33,0.0010339752 196,196,0.0010080026 321,321,0.0010073334 169,169,0.001006006 187,187,0.0010003602 421,421,0.0009871207 347,347,0.0009822787 495,495,0.0009788647 235,235,0.00097607024 313,313,0.000972718 316,316,0.00096160895 32,32,0.0009501351 365,365,0.0009465426 50,50,0.0009324631 309,309,0.0009274587 461,461,0.0009274281 439,439,0.0009251979 140,140,0.00091394293 220,220,0.0009082752 482,482,0.0009080755 430,430,0.0009043302 218,218,0.0009004896 143,143,0.00089990336 99,99,0.0008916955 70,70,0.00089066115 168,168,0.0008892674 209,209,0.00088266545 391,391,0.0008787587 137,137,0.00087609346 369,369,0.00087306765 355,355,0.0008672569 354,354,0.0008614661 352,352,0.0008591456 359,359,0.00085570634 258,258,0.00084690325 385,385,0.0008433214 296,296,0.0008431721 153,153,0.0008382941 17,17,0.0008377292 186,186,0.00083413016 162,162,0.00083256833 473,473,0.00082412886 47,47,0.0008205253 89,89,0.0008157718 283,283,0.00080831826 351,351,0.0008064204 124,124,0.00080594653 457,457,0.0008049854 188,188,0.00078665616 154,154,0.00078267895 120,120,0.0007801328 190,190,0.0007771554 85,85,0.0007743246 22,22,0.00076702645 266,266,0.00076506665 227,227,0.0007623983 21,21,0.0007611779 295,295,0.00075573527 476,476,0.00074219145 115,115,0.0007414153 363,363,0.000735065 44,44,0.00073410827 83,83,0.0007331246 118,118,0.0007294682 511,511,0.00072840624 322,322,0.00072799233 483,483,0.0007264888 219,219,0.0007200153 274,274,0.0007194695 455,455,0.0007167979 509,509,0.0007164892 412,412,0.00071247795 239,239,0.00071194395 3,3,0.00070765684 420,420,0.0007068611 53,53,0.0007004219 173,173,0.00069374195 480,480,0.0006935094 189,189,0.0006923614 80,80,0.0006903429 141,141,0.00068950764 208,208,0.00068684825 474,474,0.00068534614 386,386,0.00068098377 107,107,0.00068088406 504,504,0.0006799584 328,328,0.00067885965 307,307,0.00067760365 64,64,0.0006768076 362,362,0.00067168014 86,86,0.000671517 279,279,0.0006705279 361,361,0.00066772755 175,175,0.0006668292 16,16,0.0006653719 345,345,0.00066359126 372,372,0.00066246424 380,380,0.0006578692 330,330,0.00065479317 470,470,0.00065403903 43,43,0.00065164984 205,205,0.0006473879 294,294,0.00063388806 357,357,0.00063241523 36,36,0.00063216814 68,68,0.00063114305 57,57,0.0006299799 213,213,0.00062879996 210,210,0.0006244437 49,49,0.00062352256 241,241,0.00062221487 487,487,0.0006212713 
82,82,0.00062058726 466,466,0.0006202627 395,395,0.0006191936 72,72,0.0006174072 158,158,0.0006166083 437,437,0.00061549625 113,113,0.0006142644 277,277,0.0006109689 157,157,0.00060956663 291,291,0.0006088704 370,370,0.0006068144 104,104,0.00060625107 41,41,0.0006062162 94,94,0.00060481543 493,493,0.00060343795 247,247,0.00060251716 338,338,0.00060242455 204,204,0.00060240383 424,424,0.0005989864 344,344,0.0005951116 360,360,0.0005884679 151,151,0.0005868215 264,264,0.0005865409 293,293,0.00058519834 62,62,0.0005819824 300,300,0.00058131246 238,238,0.00057876547 407,407,0.0005770471 342,342,0.0005726651 5,5,0.00057154294 114,114,0.00057109277 240,240,0.00057103945 452,452,0.0005675462 91,91,0.0005672489 413,413,0.0005642817 119,119,0.00056401774 458,458,0.0005635342 180,180,0.0005617487 10,10,0.0005590609 181,181,0.00055804825 479,479,0.0005575437 29,29,0.00055724994 9,9,0.00055689306 102,102,0.0005544259 399,399,0.00055424945 97,97,0.0005539326 172,172,0.00054995815 31,31,0.0005499472 364,364,0.0005491284 492,492,0.00054688356 164,164,0.00054382335 371,371,0.0005417347 275,275,0.00053873524 308,308,0.00053855526 501,501,0.00053753663 92,92,0.00053742283 506,506,0.0005367725 167,167,0.00053528306 305,305,0.00053263427 485,485,0.00053174875 318,318,0.0005315904 177,177,0.00053113786 166,166,0.0005307643 193,193,0.0005297839 469,469,0.0005261937 25,25,0.000521339 48,48,0.0005209389 128,128,0.00052093086 498,498,0.00052030955 405,405,0.0005189927 201,201,0.00051636016 229,229,0.00051383715 24,24,0.00051297864 123,123,0.0005124072 477,477,0.0005120602 402,402,0.0005115426 377,377,0.00051048637 348,348,0.0005102354 23,23,0.00050849793 451,451,0.00050814415 406,406,0.0005045002 27,27,0.0004999815 350,350,0.0004998393 185,185,0.0004971214 390,390,0.00049634936 375,375,0.0004955703 431,431,0.00049411313 105,105,0.00049172394 411,411,0.0004917152 148,148,0.00049001497 250,250,0.0004884555 392,392,0.00048794085 374,374,0.00048640848 252,252,0.00048480998 269,269,0.00048419714 192,192,0.00048391512 217,217,0.00048157398 263,263,0.00048102875 415,415,0.00047999277 212,212,0.0004762149 417,417,0.00047523607 467,467,0.0004741602 340,340,0.00047381772 397,397,0.00047334703 433,433,0.00047333006 378,378,0.00047203893 206,206,0.0004719441 443,443,0.00047179937 484,484,0.00047088487 434,434,0.0004697333 396,396,0.00046903393 13,13,0.00046736852 379,379,0.00046703266 178,178,0.0004656918 202,202,0.0004656007 341,341,0.00046170983 456,456,0.0004603486 462,462,0.0004575785 67,67,0.00045551857 138,138,0.00045521912 459,459,0.00045479517 358,358,0.000450105 77,77,0.00044913465 146,146,0.00044637956 66,66,0.0004448067 98,98,0.00044425039 442,442,0.00044048973 0,0,0.00044048866 216,216,0.0004404604 18,18,0.0004400146 54,54,0.00043942602 20,20,0.00043839475 508,508,0.0004379366 285,285,0.0004373548 195,195,0.00043511056 155,155,0.00043351707 444,444,0.0004311831 257,257,0.00043021288 287,287,0.00042994966 449,449,0.0004278185 280,280,0.00042747098 255,255,0.000425165 56,56,0.00042424107 404,404,0.0004226035 488,488,0.00042232242 356,356,0.00042173947 244,244,0.0004209792 432,432,0.0004125784 214,214,0.0004114269 393,393,0.0004107277 270,270,0.00041058182 111,111,0.0004104286 324,324,0.00040866464 61,61,0.00040655467 366,366,0.00040608697 147,147,0.00040604445 311,311,0.00040550664 500,500,0.00040497814 211,211,0.00040463882 112,112,0.00040117715 100,100,0.00040099313 234,234,0.00040040378 132,132,0.00039979443 478,478,0.0003981258 221,221,0.0003965061 368,368,0.0003960585 336,336,0.00039551125 339,339,0.00039536165 
19,19,0.0003951851 71,71,0.00039469448 490,490,0.00039201186 253,253,0.0003899332 332,332,0.00038865488 447,447,0.00038850866 223,223,0.0003865917 12,12,0.00038631624 256,256,0.00038542095 303,303,0.00038529842 335,335,0.0003850341 125,125,0.00038496146 52,52,0.00038445802 465,465,0.00037988674 14,14,0.00037270828 445,445,0.0003714122 51,51,0.0003710708 183,183,0.00036938934 435,435,0.00036892455 76,76,0.0003672379 203,203,0.0003666505 74,74,0.00036636737 464,464,0.00036338985 28,28,0.00036282517 376,376,0.00036232817 389,389,0.00036217785 394,394,0.00036191306 30,30,0.00036106672 327,327,0.000358803 73,73,0.0003566918 343,343,0.00035621642 384,384,0.0003560467 440,440,0.00035523146 251,251,0.00035423675 260,260,0.00035293537 265,265,0.00035225553 387,387,0.0003514995 298,298,0.00034635572 306,306,0.00034440306 110,110,0.00034401406 254,254,0.0003433519 505,505,0.00034252735 60,60,0.00034171442 302,302,0.000340328 171,171,0.00033906853 38,38,0.0003379222 59,59,0.00033599927 353,353,0.00033367483 317,317,0.0003322806 337,337,0.0003305284 135,135,0.0003302332 423,423,0.0003287331 310,310,0.00032717254 503,503,0.00032671163 69,69,0.00032358448 145,145,0.00032273383 160,160,0.00032157102 40,40,0.00032081635 400,400,0.00031983448 278,278,0.00031925194 489,489,0.0003178289 199,199,0.00031677147 133,133,0.00031545162 373,373,0.00031506256 331,331,0.0003125472 382,382,0.00031192778 11,11,0.00031145735 494,494,0.00031131672 426,426,0.00031126558 233,233,0.00030971633 290,290,0.000309349 232,232,0.00030930655 262,262,0.00030752877 231,231,0.00030624977 314,314,0.0003054031 502,502,0.00030359824 323,323,0.0003030356 222,222,0.0002989251 428,428,0.0002974012 496,496,0.00029618212 230,230,0.0002957415 121,121,0.00029490213 304,304,0.00029465224 179,179,0.00029258238 248,248,0.00029258014 436,436,0.0002923748 425,425,0.0002921299 236,236,0.0002915477 150,150,0.00029135606 414,414,0.00029009863 286,286,0.00028853436 320,320,0.00028726715 37,37,0.00028702882 131,131,0.0002847645 225,225,0.0002837026 441,441,0.00028333988 326,326,0.0002814045 422,422,0.00028055938 165,165,0.00027847502 471,471,0.00027833483 349,349,0.0002757503 409,409,0.00027481435 103,103,0.00027326684 95,95,0.00027135346 249,249,0.00027121097 90,90,0.0002710454 224,224,0.00027073256 34,34,0.0002699063 65,65,0.00026914835 184,184,0.00026874637 398,398,0.00026665113 301,301,0.00026612534 325,325,0.00026601987 261,261,0.00026447696 246,246,0.00026436363 122,122,0.00026221105 84,84,0.00026163872 78,78,0.00025123646 299,299,0.00024949652 408,408,0.00024725433 161,161,0.00024636975 288,288,0.00024635496 42,42,0.0002452217 106,106,0.00024499706 182,182,0.00024430823 507,507,0.00024347423 271,271,0.00024137288 136,136,0.00023734072 403,403,0.00023719446 453,453,0.0002368316 26,26,0.00023657693 468,468,0.00023344165 127,127,0.00023242869 117,117,0.0002311785 45,45,0.00022815086 1,1,0.0002279681 194,194,0.00022725234 312,312,0.00022653052 410,410,0.00022528754 491,491,0.00022254683 93,93,0.00022081727 63,63,0.00022056459 226,226,0.00021268189 381,381,0.00020935576 329,329,0.00020906379 446,446,0.00020802105 245,245,0.00020745523 15,15,0.0002072561 281,281,0.00020714967 315,315,0.00020467484 152,152,0.00020205516 8,8,0.00019883284 81,81,0.00019411556 273,273,0.00019290911 39,39,0.00019221198 243,243,0.0001919428 416,416,0.00018266037 289,289,0.00016792006 454,454,0.0001655061 176,176,0.00015807906 159,159,0.0001545051 ================================================ FILE: stylegan_human/latent_direction/ss_statics/bottom_length_statis/4/statis.csv 
================================================ ,channel,strength 371,371,0.010705759 130,130,0.007931795 66,66,0.0069621187 411,411,0.0065370337 241,241,0.0061685536 178,178,0.0057360367 422,422,0.0055051707 59,59,0.0054199533 193,193,0.0052992324 405,405,0.00511423 202,202,0.00487821 414,414,0.004596879 347,347,0.0045533227 325,325,0.0042810338 479,479,0.0041018343 234,234,0.003791195 104,104,0.0037741603 437,437,0.0036558367 186,186,0.0036010114 214,214,0.0035913743 472,472,0.0035745492 99,99,0.003559262 13,13,0.003553579 302,302,0.0034689216 428,428,0.0034320198 43,43,0.0032388393 215,215,0.0031643964 346,346,0.0031622355 392,392,0.0031421043 469,469,0.0031391508 185,185,0.0031346607 110,110,0.0031338846 152,152,0.0031238336 255,255,0.0031061403 27,27,0.003093111 494,494,0.0030917446 238,238,0.0030462171 111,111,0.003043536 162,162,0.0030410155 125,125,0.0030364853 51,51,0.0030085724 231,231,0.0029884125 335,335,0.002956904 184,184,0.002923058 80,80,0.0029210225 253,253,0.0029075942 357,357,0.0028416363 180,180,0.0028330602 360,360,0.0027900473 105,105,0.0027881858 33,33,0.0027586774 475,475,0.0027457555 332,332,0.0027370767 220,220,0.0026984583 31,31,0.0026166395 53,53,0.0026120786 106,106,0.0025991872 412,412,0.00259617 382,382,0.0025869396 38,38,0.0025757526 316,316,0.0025433942 389,389,0.0025421656 435,435,0.002534331 225,225,0.002509905 354,354,0.00245047 243,243,0.0024502743 221,221,0.0024180906 218,218,0.0023629142 56,56,0.0023580198 230,230,0.0023460235 8,8,0.0022959905 344,344,0.0022738976 102,102,0.002244589 279,279,0.0022293774 77,77,0.0022287841 404,404,0.0022230886 200,200,0.0022142548 450,450,0.002204902 319,319,0.0021294882 117,117,0.0021246204 6,6,0.0021206948 247,247,0.002117081 297,297,0.0020922043 40,40,0.0020740507 352,352,0.0020036534 239,239,0.0019678425 402,402,0.0019613549 315,315,0.0019569998 195,195,0.0019549306 128,128,0.0019411439 207,207,0.0019369258 432,432,0.0019148714 365,365,0.0019143238 322,322,0.0018905443 24,24,0.0018891945 265,265,0.0018829522 417,417,0.0018819926 334,334,0.0018748969 0,0,0.0018602412 85,85,0.0018551659 126,126,0.0018536468 4,4,0.001846846 232,232,0.0018420395 376,376,0.0018346108 333,333,0.0018218933 250,250,0.0018113112 169,169,0.0018011285 361,361,0.0017686478 25,25,0.0017465113 427,427,0.0017453731 461,461,0.0017188549 182,182,0.0017154247 11,11,0.0016988176 197,197,0.0016953972 20,20,0.0016696001 246,246,0.0016686992 339,339,0.0016661945 205,205,0.0016548207 177,177,0.0016535291 153,153,0.0016440097 356,356,0.0016337134 456,456,0.0016318822 42,42,0.0016255479 378,378,0.0016247402 155,155,0.0016181484 401,401,0.0016147293 16,16,0.001612585 30,30,0.0016062072 377,377,0.0015920534 385,385,0.0015913579 79,79,0.0015898152 135,135,0.0015849494 384,384,0.0015830033 338,338,0.0015758197 98,98,0.0015708887 21,21,0.0015523953 35,35,0.0015421407 364,364,0.0015421154 270,270,0.0015403415 447,447,0.0015402873 485,485,0.0015397645 121,121,0.0015270623 408,408,0.0015214243 32,32,0.0015187038 336,336,0.0014896884 413,413,0.0014843026 499,499,0.0014718628 487,487,0.0014659785 488,488,0.0014608143 122,122,0.0014533442 491,491,0.0014338568 54,54,0.001432836 363,363,0.0014307784 151,151,0.001427959 91,91,0.0014276047 314,314,0.0014251476 161,161,0.0014244247 211,211,0.0014210346 362,362,0.0014053485 216,216,0.0013955731 159,159,0.0013930433 233,233,0.0013912213 449,449,0.0013888216 48,48,0.0013732343 248,248,0.0013702087 299,299,0.0013695534 503,503,0.0013617459 1,1,0.0013607475 237,237,0.0013588071 57,57,0.0013579503 409,409,0.001355633 
483,483,0.0013551097 229,229,0.001353437 19,19,0.0013416156 293,293,0.0013304886 390,390,0.0013271623 168,168,0.0013254886 381,381,0.0013072129 366,366,0.0013022809 288,288,0.0012966645 451,451,0.0012919087 244,244,0.0012890152 292,292,0.0012853605 463,463,0.0012843801 470,470,0.0012821403 416,416,0.0012795452 157,157,0.0012787202 464,464,0.0012758262 329,329,0.0012730482 490,490,0.0012670642 74,74,0.0012638838 170,170,0.001260051 278,278,0.0012583392 22,22,0.0012556936 399,399,0.0012530655 100,100,0.0012508945 355,355,0.001250555 486,486,0.0012472017 506,506,0.0012441961 459,459,0.0012374625 309,309,0.0012292771 113,113,0.0012270019 138,138,0.0012170793 47,47,0.0012084226 65,65,0.0012071957 198,198,0.001206154 196,196,0.001205836 285,285,0.0012046142 49,49,0.0012002704 457,457,0.0011999249 425,425,0.0011967781 175,175,0.0011944658 148,148,0.0011917639 86,86,0.0011908459 94,94,0.0011882967 501,501,0.0011801487 476,476,0.0011663077 156,156,0.0011636167 387,387,0.0011613253 266,266,0.0011546519 496,496,0.0011480699 340,340,0.0011474552 343,343,0.0011469066 52,52,0.0011451634 369,369,0.0011436169 90,90,0.0011377904 386,386,0.0011357777 96,96,0.0011343259 124,124,0.0011335325 460,460,0.0011268322 321,321,0.0011234696 264,264,0.0011168371 287,287,0.0011126697 353,353,0.0011116265 462,462,0.0011107579 154,154,0.0011100196 388,388,0.0011071602 448,448,0.0011041508 187,187,0.0010949599 328,328,0.0010922498 454,454,0.0010910842 306,306,0.0010906332 320,320,0.0010779172 391,391,0.0010766807 318,318,0.0010763333 107,107,0.0010709754 505,505,0.0010685796 206,206,0.0010629565 34,34,0.0010543949 473,473,0.0010541352 173,173,0.0010531493 109,109,0.0010509333 424,424,0.0010466026 430,430,0.0010448382 150,150,0.0010423292 268,268,0.001038994 2,2,0.0010353344 29,29,0.0010323756 504,504,0.0010309685 119,119,0.0010272656 174,174,0.0010258228 260,260,0.0010184883 249,249,0.0010175364 36,36,0.0010174246 137,137,0.001017379 303,303,0.0010154104 163,163,0.0010028904 455,455,0.0010002597 510,510,0.0009994362 245,245,0.00099766 262,262,0.0009885678 281,281,0.0009873835 28,28,0.0009856436 380,380,0.0009821856 228,228,0.0009717986 367,367,0.000969772 286,286,0.00096976873 7,7,0.00096907315 146,146,0.00096689834 139,139,0.000962454 204,204,0.0009589661 3,3,0.000956985 280,280,0.0009566337 509,509,0.00094728253 181,181,0.0009440433 68,68,0.0009432923 269,269,0.0009399444 179,179,0.00093980297 274,274,0.0009315241 95,95,0.0009284373 263,263,0.00092792546 72,72,0.00092775293 277,277,0.0009271326 436,436,0.0009269056 500,500,0.0009265228 87,87,0.00092501455 310,310,0.00092400954 300,300,0.00092247425 144,144,0.0009218417 426,426,0.00091638684 81,81,0.00091634423 188,188,0.00090678554 289,289,0.0009011138 418,418,0.00089773914 397,397,0.00089350634 304,304,0.0008868619 482,482,0.0008865463 495,495,0.0008795051 312,312,0.00087571866 166,166,0.00087193854 183,183,0.00087024615 5,5,0.00086923986 446,446,0.0008691286 212,212,0.00086676405 46,46,0.00086484046 118,118,0.00086480496 254,254,0.00086135446 88,88,0.000861164 219,219,0.0008608658 467,467,0.0008600293 76,76,0.00085848826 331,331,0.00085327355 498,498,0.0008505213 39,39,0.000848981 93,93,0.00084537006 433,433,0.0008452787 410,410,0.00084377866 194,194,0.00083563104 478,478,0.00083321065 272,272,0.00083227345 223,223,0.00083198043 311,311,0.00083184306 431,431,0.0008309719 337,337,0.0008306422 189,189,0.0008290602 341,341,0.0008284594 394,394,0.00082656107 396,396,0.0008221022 92,92,0.0008204752 50,50,0.0008168993 84,84,0.0008081732 64,64,0.0008072594 
123,123,0.00080666045 71,71,0.0008044883 140,140,0.00080113776 120,120,0.0007968432 14,14,0.00079405046 324,324,0.00079051324 115,115,0.0007894897 191,191,0.0007880761 439,439,0.0007866818 393,393,0.000783365 131,131,0.0007829777 327,327,0.0007810872 17,17,0.0007793945 97,97,0.00077890133 295,295,0.0007760637 423,423,0.00077443745 403,403,0.0007743149 497,497,0.0007740729 171,171,0.00077238533 276,276,0.0007655106 452,452,0.00076521165 136,136,0.0007644099 82,82,0.0007621963 142,142,0.000760782 374,374,0.0007555941 444,444,0.0007507936 282,282,0.0007499849 421,421,0.0007498477 375,375,0.00074311404 358,358,0.0007403847 217,217,0.00073654624 165,165,0.00073613157 420,420,0.0007335366 227,227,0.0007332281 330,330,0.00073107216 368,368,0.00072970067 37,37,0.00072445447 149,149,0.00072384684 477,477,0.000723286 407,407,0.00072004553 242,242,0.00071955 134,134,0.0007190485 172,172,0.0007161819 69,69,0.00071376155 372,372,0.00071212306 236,236,0.00071169576 349,349,0.000709231 484,484,0.00070886995 222,222,0.0007068514 9,9,0.00070441724 481,481,0.0007041735 373,373,0.0007033714 323,323,0.00070036115 434,434,0.0006995436 438,438,0.00069369253 359,359,0.0006884251 370,370,0.00068327464 308,308,0.000678813 445,445,0.0006780726 10,10,0.0006776827 127,127,0.0006771573 224,224,0.0006768041 296,296,0.00067331357 256,256,0.00066997507 493,493,0.0006680274 67,67,0.0006640576 116,116,0.0006630596 132,132,0.000661265 62,62,0.00066118705 508,508,0.00065791415 468,468,0.0006579114 440,440,0.00065191026 317,317,0.00065190904 160,160,0.00065158366 492,492,0.00065123267 458,458,0.0006491314 114,114,0.0006431901 58,58,0.00063985295 313,313,0.0006342638 18,18,0.0006290864 261,261,0.0006265827 383,383,0.00062478095 294,294,0.00062198006 143,143,0.0006198618 61,61,0.0006060728 103,103,0.0006045529 419,419,0.00060209975 466,466,0.0006016268 507,507,0.0005986943 273,273,0.0005929426 240,240,0.0005922067 350,350,0.00058880384 429,429,0.000588744 129,129,0.00058819435 267,267,0.00058599794 252,252,0.0005856361 23,23,0.00058518484 400,400,0.0005806418 283,283,0.0005790221 89,89,0.00057516847 12,12,0.0005741658 176,176,0.0005739618 192,192,0.00057157595 101,101,0.00056115567 442,442,0.0005588267 41,41,0.0005569018 63,63,0.0005514146 441,441,0.0005468517 398,398,0.00054553134 307,307,0.0005450811 298,298,0.00054445845 342,342,0.00054111326 502,502,0.0005392242 78,78,0.0005375705 44,44,0.0005357114 73,73,0.00053344626 209,209,0.00053105725 158,158,0.0005261206 70,70,0.0005217334 199,199,0.0005199517 471,471,0.0005012095 511,511,0.0004969943 259,259,0.000493192 235,235,0.00047622804 301,301,0.0004594244 275,275,0.0004556442 167,167,0.00045370075 133,133,0.00044329825 147,147,0.00043812627 348,348,0.00042484334 190,190,0.0004229568 406,406,0.00042175007 480,480,0.000417746 108,108,0.00041762542 395,395,0.000411962 305,305,0.0004101298 290,290,0.00040680933 489,489,0.0004044473 251,251,0.0004041571 164,164,0.00040222614 257,257,0.00039560342 379,379,0.000392651 326,326,0.00038673886 112,112,0.00036090027 83,83,0.000344462 351,351,0.00034133552 210,210,0.00033938422 141,141,0.0003355473 60,60,0.00033554618 226,226,0.00033128157 203,203,0.00032476717 15,15,0.00027671162 208,208,0.0002761039 291,291,0.000266872 213,213,0.0002634147 415,415,0.00026088266 474,474,0.000256801 271,271,0.00024554084 201,201,0.00023913333 443,443,0.00023458686 145,145,0.00022581422 284,284,0.00022399156 258,258,0.00021423028 465,465,0.00021232266 453,453,0.00021226592 75,75,0.00020900984 55,55,0.00020078795 26,26,0.00018182797 45,45,0.00016305555 
345,345,0.00015414672 ================================================ FILE: stylegan_human/latent_direction/ss_statics/bottom_length_statis/5/statis.csv ================================================ ,channel,strength 242,242,0.01746412 134,134,0.011444086 71,71,0.01060778 395,395,0.0062382165 363,363,0.0058679837 175,175,0.004722381 53,53,0.0044826367 112,112,0.0042659384 457,457,0.003703029 288,288,0.003450465 328,328,0.00344062 414,414,0.0032427178 205,205,0.0032254586 321,321,0.003165059 32,32,0.003014796 9,9,0.0025584965 180,180,0.0025317036 452,452,0.0024022916 69,69,0.0023928392 210,210,0.0023827436 385,385,0.0023732155 98,98,0.0023490055 307,307,0.0023184065 418,418,0.002295881 470,470,0.0022638584 341,341,0.0021729823 308,308,0.0021633923 37,37,0.0021602109 440,440,0.0021246527 16,16,0.0020481 10,10,0.0020317773 486,486,0.002021162 150,150,0.0020176854 89,89,0.0019858822 278,278,0.001971183 430,430,0.0019591004 463,463,0.0019105887 434,434,0.0018847745 437,437,0.0017904 127,127,0.0017076981 310,310,0.0016927454 151,151,0.0016283162 224,224,0.0015725751 268,268,0.0015449473 402,402,0.00153654 190,190,0.0014757048 92,92,0.0014610167 117,117,0.0014568182 110,110,0.0014490561 423,423,0.0014475571 161,161,0.0014227595 291,291,0.001410095 225,225,0.0013883268 189,189,0.001364547 157,157,0.0013630674 499,499,0.00135625 274,274,0.0013522932 166,166,0.0013465862 475,475,0.0013449993 300,300,0.0013322539 368,368,0.0012959774 267,267,0.0012800089 36,36,0.0012684392 11,11,0.0012514348 184,184,0.0012406422 453,453,0.0012389477 173,173,0.0012284226 429,429,0.001227794 229,229,0.0011854741 212,212,0.0011837278 295,295,0.0011766667 318,318,0.0011758378 390,390,0.0011577767 67,67,0.0011576377 26,26,0.001157294 256,256,0.0011399041 287,287,0.0011214579 245,245,0.0011204005 118,118,0.001119347 379,379,0.0011182849 412,412,0.0011027583 169,169,0.0010857102 488,488,0.001084579 108,108,0.0010608159 155,155,0.0010607184 465,465,0.0010555563 80,80,0.0010430918 285,285,0.0010419621 191,191,0.001039581 320,320,0.0010336601 489,489,0.0009907437 46,46,0.00098818 359,359,0.0009863363 415,415,0.000984566 438,438,0.0009843353 2,2,0.0009840623 483,483,0.000978259 116,116,0.00096878014 279,279,0.00096704334 391,391,0.0009629472 75,75,0.0009625787 386,386,0.0009506957 213,213,0.00094954396 81,81,0.0009334278 170,170,0.0009309288 459,459,0.0009304561 25,25,0.0009263908 422,422,0.0009251744 316,316,0.0009241467 254,254,0.00092409964 294,294,0.00092359504 322,322,0.00092271896 493,493,0.00092130696 168,168,0.00091688987 361,361,0.0009016815 302,302,0.0008988095 199,199,0.0008969074 42,42,0.00089361327 275,275,0.00089227123 20,20,0.00089052174 197,197,0.00088976097 43,43,0.0008803114 370,370,0.00087934994 436,436,0.00087875564 28,28,0.00087209826 290,290,0.0008675793 330,330,0.00085718994 94,94,0.0008566909 511,511,0.0008561607 77,77,0.0008551629 484,484,0.0008420306 202,202,0.0008376041 78,78,0.00083523884 487,487,0.00083071465 44,44,0.0008302506 456,456,0.00082660967 343,343,0.00082623283 186,186,0.0008204403 428,428,0.0008122731 63,63,0.000809409 371,371,0.0007866993 367,367,0.0007859241 410,410,0.00078440236 129,129,0.00078421313 492,492,0.0007841307 219,219,0.00078324176 181,181,0.0007805426 192,192,0.000759081 348,348,0.00075573416 156,156,0.00075448444 149,149,0.0007483769 497,497,0.0007449612 97,97,0.0007447865 238,238,0.0007345617 427,427,0.0007344842 48,48,0.0007277395 496,496,0.0007251045 468,468,0.0007233464 351,351,0.00071131974 396,396,0.0007099477 240,240,0.00070780434 277,277,0.00070575473 
397,397,0.00070399133 362,362,0.0007016971 122,122,0.0006994947 425,425,0.0006994096 347,347,0.0006979086 502,502,0.00069097895 377,377,0.00068525685 70,70,0.00068207615 481,481,0.0006812176 185,185,0.00067786046 90,90,0.00067573227 472,472,0.0006711317 339,339,0.0006699505 405,405,0.000669038 426,426,0.0006675291 204,204,0.0006652177 296,296,0.0006634654 235,235,0.000660739 141,141,0.0006510568 203,203,0.00064648956 293,293,0.0006462373 508,508,0.00064582145 121,121,0.00064454816 93,93,0.00064454804 406,406,0.00064240245 12,12,0.0006397705 344,344,0.00063504284 19,19,0.0006331822 332,332,0.0006323309 194,194,0.00062925677 313,313,0.0006285695 507,507,0.0006174056 82,82,0.0006161944 239,239,0.00060919294 490,490,0.0006089468 266,266,0.0006075872 128,128,0.0006045006 407,407,0.00060399977 451,451,0.00060152815 137,137,0.0005997921 454,454,0.00059671386 270,270,0.00059550925 404,404,0.00059543905 439,439,0.000591806 460,460,0.0005913006 23,23,0.0005911026 373,373,0.000590888 96,96,0.0005903696 55,55,0.00059018075 365,365,0.000585234 411,411,0.00058369443 374,374,0.00058170315 119,119,0.000581078 458,458,0.00058017287 420,420,0.00057635305 393,393,0.00057374657 261,261,0.0005727315 319,319,0.00057239935 283,283,0.00057194225 233,233,0.0005646504 403,403,0.00056403375 357,357,0.00056290894 241,241,0.00056050875 182,182,0.00056047173 257,257,0.000556821 364,364,0.00055118365 174,174,0.0005500873 111,111,0.00054910127 309,309,0.0005458186 193,193,0.00054326915 5,5,0.0005425164 24,24,0.0005411514 162,162,0.0005377205 432,432,0.0005299737 284,284,0.000529899 491,491,0.0005281161 352,352,0.000527957 39,39,0.0005278845 297,297,0.00052527175 312,312,0.0005249035 443,443,0.0005247692 143,143,0.0005237088 389,389,0.00052194 132,132,0.0005208663 57,57,0.00051607064 85,85,0.00051471085 482,482,0.00051036675 247,247,0.0005102306 477,477,0.00050744385 358,358,0.0005060787 125,125,0.000505312 466,466,0.00050286297 479,479,0.00050196645 323,323,0.0005004967 232,232,0.0004986481 114,114,0.0004957778 87,87,0.0004945035 158,158,0.0004930025 376,376,0.00049192616 384,384,0.00049149833 65,65,0.00049068604 144,144,0.0004868068 41,41,0.00048459516 33,33,0.00048199162 135,135,0.00047957737 72,72,0.00047893118 449,449,0.00047819986 442,442,0.0004762921 130,130,0.00047556465 244,244,0.00047063318 163,163,0.00047034535 292,292,0.00047004907 350,350,0.00046825878 304,304,0.0004671009 474,474,0.0004649635 208,208,0.00046426945 286,286,0.00046057467 178,178,0.00045859 450,450,0.00045825946 226,226,0.00045819176 8,8,0.00045815166 252,252,0.0004559861 154,154,0.00045457872 264,264,0.00045156496 146,146,0.00044669665 220,220,0.00044567848 501,501,0.00044542857 273,273,0.00044488933 171,171,0.0004427107 401,401,0.00044082777 381,381,0.00043972745 140,140,0.00043886708 356,356,0.00043867563 464,464,0.0004386433 353,353,0.00043817703 230,230,0.0004378558 249,249,0.00043769262 15,15,0.0004355795 345,345,0.0004354763 133,133,0.0004346199 120,120,0.00043198353 505,505,0.00043042895 95,95,0.00042860254 378,378,0.0004282037 455,455,0.00042507777 298,298,0.00042234876 145,145,0.0004205409 104,104,0.0004192717 394,394,0.00041797172 45,45,0.00041774593 366,366,0.00041443488 325,325,0.00041310437 62,62,0.00041287072 179,179,0.00041070042 317,317,0.0004105377 338,338,0.00040974608 79,79,0.00040827427 66,66,0.00040683098 139,139,0.0004056471 29,29,0.00040440765 152,152,0.00040363474 214,214,0.00040329635 435,435,0.0004028462 22,22,0.00040165885 346,346,0.0004001219 209,209,0.00039810632 392,392,0.00039726324 315,315,0.00039714077 
433,433,0.00039478665 506,506,0.0003915952 105,105,0.0003900892 446,446,0.00038801826 0,0,0.00038777932 18,18,0.0003875171 424,424,0.00038726558 331,331,0.00038722638 51,51,0.00038692914 417,417,0.00038655673 14,14,0.0003852234 6,6,0.00038509126 281,281,0.0003836144 383,383,0.00038292323 216,216,0.00038181938 262,262,0.00037940088 74,74,0.00037926337 3,3,0.00037809636 387,387,0.00037650907 342,342,0.00037650426 398,398,0.0003733892 136,136,0.00037266538 243,243,0.00037207166 471,471,0.00037091138 73,73,0.00036918517 86,86,0.00036871075 68,68,0.00036807635 372,372,0.00036491535 200,200,0.0003633216 107,107,0.00036330617 480,480,0.00036279066 124,124,0.00036251516 131,131,0.00036231225 269,269,0.00036161783 447,447,0.00036158395 167,167,0.0003561097 38,38,0.00035443963 211,211,0.00035319992 369,369,0.00035081702 248,248,0.00035030014 101,101,0.00034903514 115,115,0.00034696967 413,413,0.0003440623 289,289,0.00034358836 56,56,0.00034278553 195,195,0.0003423341 388,388,0.00034183572 416,416,0.00034126177 206,206,0.0003408197 375,375,0.00033973216 4,4,0.00033931396 494,494,0.00033901844 99,99,0.00033845875 282,282,0.00033784402 172,172,0.00033696217 218,218,0.00033630853 165,165,0.0003358777 196,196,0.0003349965 419,419,0.0003341522 469,469,0.00033384332 327,327,0.0003337259 324,324,0.00033304066 260,260,0.00033025324 478,478,0.00032999786 148,148,0.0003298886 500,500,0.00032941083 83,83,0.00032738544 231,231,0.00032720037 280,280,0.00032718244 498,498,0.00032587873 467,467,0.00032520766 102,102,0.00032370255 84,84,0.00032345828 21,21,0.00032132064 334,334,0.00032061047 237,237,0.00032028827 441,441,0.0003187737 485,485,0.0003181529 159,159,0.00031660515 50,50,0.00031633597 258,258,0.00031522466 164,164,0.00031313743 59,59,0.00031286437 160,160,0.00031265983 355,355,0.00031167237 7,7,0.00031130642 76,76,0.00030938906 476,476,0.0003087692 263,263,0.0003072378 495,495,0.0003052068 27,27,0.00030507136 234,234,0.00030466946 299,299,0.00030342882 113,113,0.00030309713 276,276,0.00030135227 217,217,0.00030106696 61,61,0.00030104464 54,54,0.00030067936 349,349,0.0003001497 380,380,0.00029829858 503,503,0.00029654359 303,303,0.0002964005 88,88,0.00029631294 188,188,0.00029344021 444,444,0.00029314525 329,329,0.0002908486 400,400,0.0002895397 223,223,0.00028937528 123,123,0.0002876151 272,272,0.0002861553 91,91,0.0002857061 354,354,0.00028297797 409,409,0.0002819136 448,448,0.0002804029 509,509,0.00027843454 215,215,0.00027820378 183,183,0.0002780649 253,253,0.0002754507 461,461,0.00027407336 255,255,0.00027380435 251,251,0.00027299367 109,109,0.00027289733 246,246,0.00027286436 259,259,0.00027249 201,201,0.00026961832 207,207,0.00026913724 198,198,0.00026806045 236,236,0.00026771048 326,326,0.00026338472 49,49,0.00026074707 138,138,0.0002606831 147,147,0.00025864976 30,30,0.00025828253 408,408,0.0002574985 177,177,0.00025741325 153,153,0.00025581883 187,187,0.0002544332 34,34,0.0002536471 58,58,0.00025209208 473,473,0.00024838003 221,221,0.00024675552 126,126,0.00024573723 228,228,0.00024463152 306,306,0.00024407305 250,250,0.00024358112 421,421,0.00024149592 176,176,0.00024142586 64,64,0.00023718696 336,336,0.00023363969 40,40,0.00023206352 504,504,0.00023138674 399,399,0.0002291037 305,305,0.00022865152 60,60,0.0002266992 301,301,0.00022585278 222,222,0.00022387604 311,311,0.00022371336 17,17,0.0002226341 52,52,0.00022181227 142,142,0.00021962133 103,103,0.00021821138 382,382,0.00021803642 1,1,0.0002171288 337,337,0.00021468119 445,445,0.00021307352 314,314,0.00021049718 100,100,0.00020906334 
360,360,0.00020833187 340,340,0.00020668289 106,106,0.0002023169 462,462,0.00019855957 271,271,0.00019850428 227,227,0.00019575101 35,35,0.00019263716 510,510,0.00018776124 265,265,0.0001874168 31,31,0.00018595619 333,333,0.00018071612 13,13,0.00017645532 431,431,0.00017538466 47,47,0.00017165719 335,335,0.00016702584 ================================================ FILE: stylegan_human/latent_direction/ss_statics/upper_length_statis/5/statis.csv ================================================ ,channel,strength 423,423,0.004408924 341,341,0.0032079767 379,379,0.0028457695 184,184,0.0028266357 368,368,0.0027224978 486,486,0.0025201144 313,313,0.0025118296 426,426,0.002334192 367,367,0.0022732853 213,213,0.0022067663 436,436,0.0021884514 308,308,0.0021619347 496,496,0.0020120202 393,393,0.001984407 422,422,0.0019403459 425,425,0.0018690284 511,511,0.0017944475 181,181,0.00178793 497,497,0.0016764585 361,361,0.0016183082 2,2,0.0015440682 267,267,0.0015391556 114,114,0.0015294778 170,170,0.001501581 432,432,0.001494758 385,385,0.0014706858 339,339,0.001420463 415,415,0.0013572118 116,116,0.0013073295 373,373,0.0013072236 453,453,0.0013039358 320,320,0.0012937128 256,256,0.0012766926 127,127,0.0012432956 75,75,0.0012031216 287,287,0.0011476731 309,309,0.001147382 456,456,0.0011466638 343,343,0.0011463353 457,457,0.0011426552 98,98,0.0011147967 437,437,0.0011112978 435,435,0.0010973605 182,182,0.0010894896 150,150,0.0010746964 279,279,0.0010726912 189,189,0.0010475953 9,9,0.001045001 371,371,0.001035062 53,53,0.0010211505 168,168,0.0010098461 146,146,0.0010040879 470,470,0.00097405957 69,69,0.00097038125 375,375,0.0009642161 134,134,0.00094965403 475,475,0.0009452465 125,125,0.0009439787 434,434,0.0009384095 288,288,0.0009309878 205,205,0.0009186115 128,128,0.0009111188 328,328,0.0008899651 319,319,0.0008791126 42,42,0.00087770226 458,458,0.0008629476 272,272,0.0008628456 414,414,0.00086089334 261,261,0.00085644424 304,304,0.00083918776 24,24,0.0008195494 161,161,0.0008154048 225,225,0.0008051095 67,67,0.0008038561 482,482,0.0007958309 430,430,0.0007924481 499,499,0.0007831592 390,390,0.00078276 11,11,0.00076811534 332,332,0.0007633267 197,197,0.0007563872 325,325,0.00074650464 322,322,0.00073741924 157,157,0.00071299827 93,93,0.00069571583 108,108,0.0006955461 185,185,0.00068194594 115,115,0.0006798201 80,80,0.0006779811 405,405,0.00067563757 450,450,0.00067155104 105,105,0.00065722043 57,57,0.00065079477 36,36,0.0006450097 454,454,0.0006430749 247,247,0.0006401138 8,8,0.00063832046 102,102,0.0006370239 316,316,0.0006331501 488,488,0.00062760746 364,364,0.0006171263 389,389,0.00060887367 172,172,0.0006087077 352,352,0.00060442963 463,463,0.0006023239 19,19,0.00060143415 210,210,0.00059817376 186,186,0.0005907412 471,471,0.0005869326 140,140,0.00058600784 94,94,0.000580931 81,81,0.000580287 270,270,0.0005789991 417,417,0.0005783043 196,196,0.0005754427 46,46,0.00057365885 464,464,0.00056969275 104,104,0.00056765747 171,171,0.00056728435 487,487,0.0005661972 220,220,0.00056612946 68,68,0.00056457426 33,33,0.0005640128 72,72,0.0005622605 229,229,0.00055803073 41,41,0.0005567271 148,148,0.0005566318 455,455,0.00055565406 349,349,0.00055537344 442,442,0.00055352686 162,162,0.0005533156 191,191,0.00054769375 202,202,0.0005464838 226,226,0.0005416205 466,466,0.00054103765 192,192,0.0005383168 479,479,0.00052559533 143,143,0.00052148907 199,199,0.0005204556 281,281,0.00051968545 410,410,0.0005187148 403,403,0.0005185161 38,38,0.00051575975 429,429,0.00051309256 381,381,0.0005028584 
363,363,0.00050276244 26,26,0.0004997491 502,502,0.0004992033 208,208,0.000498523 109,109,0.0004985198 179,179,0.000497237 433,433,0.00049502135 190,190,0.0004945647 223,223,0.00049233536 50,50,0.00048974034 97,97,0.00048946124 37,37,0.0004884755 85,85,0.00048751346 6,6,0.00048617192 351,351,0.0004857896 212,212,0.00048565448 259,259,0.00048205393 294,294,0.00048184753 118,118,0.00048155375 117,117,0.00047967743 428,428,0.0004783297 110,110,0.0004755637 90,90,0.00047256536 55,55,0.00047060288 233,233,0.0004681879 264,264,0.00046594813 240,240,0.00046441547 310,310,0.00046434483 238,238,0.00046259057 145,145,0.00046210623 22,22,0.00046033954 395,395,0.00045831574 397,397,0.00045133353 18,18,0.00044636318 129,129,0.0004431245 241,241,0.0004430129 465,465,0.00044261583 63,63,0.00044257144 418,418,0.00044039512 507,507,0.00043912584 358,358,0.0004376243 365,365,0.00043710147 242,242,0.00043396762 411,411,0.00043118192 338,338,0.0004302841 122,122,0.00042847547 235,235,0.00042753594 476,476,0.0004272296 3,3,0.00042361638 176,176,0.00041432443 407,407,0.00041412332 391,391,0.00041366852 1,1,0.00041138142 73,73,0.0004112309 493,493,0.00041061096 421,421,0.000409816 459,459,0.00040876624 14,14,0.00040792229 284,284,0.00040753878 61,61,0.0004073927 293,293,0.00040734603 275,275,0.00040710357 494,494,0.0004049803 180,180,0.00040350194 301,301,0.00040267198 198,198,0.00040212894 193,193,0.00040090212 280,280,0.0004008506 396,396,0.00040081202 5,5,0.0004002457 76,76,0.00040004638 350,350,0.00039750861 283,283,0.0003969664 344,344,0.0003949259 230,230,0.00039482582 149,149,0.0003945113 244,244,0.00039450615 357,357,0.00039388 491,491,0.00039187729 45,45,0.00038878893 298,298,0.00038755446 56,56,0.00038721022 503,503,0.00038716424 217,217,0.0003867056 159,159,0.00038530145 468,468,0.00038489333 427,427,0.0003842091 291,291,0.00038256386 500,500,0.0003809668 290,290,0.00038067088 460,460,0.00038007705 331,331,0.00037731865 211,211,0.00037723288 440,440,0.00037687554 74,74,0.00037646503 119,119,0.0003755539 133,133,0.00037539483 353,353,0.00037147343 321,321,0.00037131374 32,32,0.0003699305 16,16,0.0003695323 71,71,0.00036733068 483,483,0.0003673293 131,131,0.00036619935 404,404,0.00036550473 218,218,0.00036407504 424,424,0.00036398123 23,23,0.00036384634 65,65,0.0003628074 111,111,0.00036226492 260,260,0.00036025967 258,258,0.0003602185 112,112,0.00035953286 492,492,0.00035905885 167,167,0.00035840494 399,399,0.00035528265 376,376,0.00035397 326,326,0.00035366518 481,481,0.00035335837 79,79,0.00035233115 101,101,0.00035053334 276,276,0.00034943986 296,296,0.00034848423 152,152,0.00034692476 10,10,0.00034476537 489,489,0.00034423766 438,438,0.0003438288 120,120,0.0003430017 239,239,0.00034234222 317,317,0.00034196727 347,347,0.00034175767 206,206,0.0003414347 84,84,0.00034051857 490,490,0.0003404467 107,107,0.0003396762 495,495,0.00033961574 333,333,0.00033911737 446,446,0.000338417 254,254,0.00033728743 386,386,0.00033677832 250,250,0.00033664028 306,306,0.0003353818 246,246,0.0003352706 372,372,0.00033452015 169,169,0.00033249875 451,451,0.0003304912 173,173,0.0003292484 302,302,0.00032853018 151,151,0.0003258597 263,263,0.0003249236 274,274,0.00032480215 156,156,0.00032411155 307,307,0.00032296553 88,88,0.00032126167 39,39,0.0003210367 91,91,0.00032037924 413,413,0.00032021984 232,232,0.0003186513 366,366,0.00031673085 480,480,0.00031612715 44,44,0.00031571824 462,462,0.00031546177 380,380,0.0003152556 83,83,0.00031272677 132,132,0.0003114493 209,209,0.00030984185 48,48,0.00030914877 
382,382,0.00030887048 195,195,0.00030860244 154,154,0.00030410106 166,166,0.0003024194 245,245,0.00030229168 262,262,0.00030191787 237,237,0.0002995339 443,443,0.0002943032 467,467,0.00029363343 121,121,0.00029333794 416,416,0.00029272414 160,160,0.00029269484 4,4,0.00029229806 92,92,0.00029168854 77,77,0.00028903817 400,400,0.0002876192 278,278,0.00028760926 474,474,0.00028757288 402,402,0.00028493986 506,506,0.0002847649 234,234,0.00028450877 277,277,0.00028409314 447,447,0.0002835903 342,342,0.00028351165 285,285,0.00028341086 345,345,0.00028339337 348,348,0.0002823747 300,300,0.00028156798 383,383,0.00028049652 231,231,0.0002790845 203,203,0.00027895247 355,355,0.00027876275 204,204,0.00027841472 216,216,0.0002779351 508,508,0.00027720784 282,282,0.00027655836 297,297,0.00027502645 292,292,0.00027430354 327,327,0.0002727945 100,100,0.000269865 95,95,0.0002694548 187,187,0.0002689126 408,408,0.0002658863 477,477,0.00026576317 384,384,0.0002645117 54,54,0.00026404977 374,374,0.00026287523 420,420,0.00026245107 509,509,0.00026231605 28,28,0.00026166256 449,449,0.0002611203 336,336,0.0002604421 178,178,0.00026030626 299,299,0.00025961167 103,103,0.00025886018 388,388,0.00025811547 271,271,0.00025790904 207,207,0.00025755883 248,248,0.00025613784 249,249,0.00025567278 138,138,0.00025559164 78,78,0.00025549062 269,269,0.00025442135 273,273,0.00025399768 286,286,0.00025389006 12,12,0.0002534591 478,478,0.00025331602 452,452,0.00025299162 27,27,0.0002521485 82,82,0.00025113745 295,295,0.0002508786 201,201,0.0002506164 409,409,0.00025036978 359,359,0.00024992102 394,394,0.00024895562 330,330,0.00024881997 501,501,0.0002484243 64,64,0.00024702528 52,52,0.00024523196 106,106,0.00024233114 175,175,0.00024187689 135,135,0.00024020251 419,419,0.00023802932 139,139,0.00023566972 25,25,0.00023424833 312,312,0.00023372385 469,469,0.00023334593 7,7,0.00023190145 158,158,0.00023087433 60,60,0.00023052357 441,441,0.00022939079 165,165,0.00022774191 59,59,0.00022733606 147,147,0.00022624452 62,62,0.00022606985 144,144,0.00022451063 370,370,0.00022431195 21,21,0.0002239656 369,369,0.00022302035 314,314,0.00022240904 377,377,0.00022210156 406,406,0.00022187224 255,255,0.00022102552 356,356,0.00022071223 472,472,0.00021941346 484,484,0.00021900832 289,289,0.0002186439 137,137,0.00021708063 17,17,0.00021706238 51,51,0.00021587432 174,174,0.0002145808 124,124,0.00021406145 253,253,0.00021273737 251,251,0.00021224513 485,485,0.00021198599 214,214,0.00021052467 227,227,0.00020948028 126,126,0.00020917962 362,362,0.00020837369 473,473,0.00020753275 311,311,0.00020720006 346,346,0.00020569382 243,243,0.00020554132 439,439,0.00020543092 177,177,0.00020462318 86,86,0.00020450725 43,43,0.00020431104 354,354,0.0002021195 323,323,0.00020008704 378,378,0.00019674652 153,153,0.00019593372 324,324,0.00019575354 194,194,0.0001956694 360,360,0.00019427242 188,188,0.0001923151 265,265,0.00019114434 431,431,0.00019109932 219,219,0.0001904252 315,315,0.00019023697 224,224,0.00018825711 412,412,0.0001871387 89,89,0.00018677485 268,268,0.0001852995 257,257,0.00018430859 392,392,0.0001833855 35,35,0.00018329639 20,20,0.00018285586 222,222,0.00018160306 141,141,0.00018157126 398,398,0.00018143529 461,461,0.0001811381 29,29,0.00018102965 318,318,0.00017966877 448,448,0.00017874948 58,58,0.00017864846 329,329,0.00017672384 401,401,0.0001746514 183,183,0.00017442314 142,142,0.00017366462 498,498,0.00017296121 0,0,0.00017286446 504,504,0.00017194722 444,444,0.00017125413 155,155,0.00016847586 87,87,0.00016776803 96,96,0.00016757102 
99,99,0.00016714378 340,340,0.00016512058 505,505,0.00016355969 266,266,0.00016327428 49,49,0.00016297943 221,221,0.00016184167 15,15,0.00016183533 303,303,0.00016075459 113,113,0.00016014857 130,130,0.00015612668 215,215,0.00015609166 228,228,0.0001551064 305,305,0.00015308998 335,335,0.00015264843 40,40,0.00015258256 66,66,0.00015006233 200,200,0.00015001767 387,387,0.00014855604 252,252,0.00014628381 164,164,0.00014496325 136,136,0.00014322113 445,445,0.00014278234 123,123,0.00014132661 236,236,0.00013684056 31,31,0.00013596549 34,34,0.00013567152 334,334,0.00013561502 337,337,0.00013048301 70,70,0.00012578799 510,510,0.00012369145 47,47,0.00012088309 13,13,0.00011961328 163,163,0.00010987895 30,30,9.594474e-05 ================================================ FILE: stylegan_human/legacy.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. # import pickle import dnnlib import re from typing import List, Optional import torch import copy import numpy as np from torch_utils import misc #---------------------------------------------------------------------------- ## loading torch pkl def load_network_pkl(f, force_fp16=False, G_only=False): data = _LegacyUnpickler(f).load() if G_only: f = open('ori_model_Gonly.txt','a+') else: f = open('ori_model.txt','a+') for key in data.keys(): f.write(str(data[key])) f.close() ## We comment out this part, if you want to convert TF pickle, you can use the original script from StyleGAN2-ada-pytorch # # Legacy TensorFlow pickle => convert. # if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data): # tf_G, tf_D, tf_Gs = data # G = convert_tf_generator(tf_G) # D = convert_tf_discriminator(tf_D) # G_ema = convert_tf_generator(tf_Gs) # data = dict(G=G, D=D, G_ema=G_ema) # Add missing fields. if 'training_set_kwargs' not in data: data['training_set_kwargs'] = None if 'augment_pipe' not in data: data['augment_pipe'] = None # Validate contents. assert isinstance(data['G_ema'], torch.nn.Module) if not G_only: assert isinstance(data['D'], torch.nn.Module) assert isinstance(data['G'], torch.nn.Module) assert isinstance(data['training_set_kwargs'], (dict, type(None))) assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None))) # Force FP16. 
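# Editor's note (illustrative, not part of the original file): the block below rebuilds
# each network with the StyleGAN2-ADA half-precision defaults -- num_fp16_res=4 runs the
# four highest synthesis resolutions in FP16 and conv_clamp=256 clamps activations to
# avoid FP16 overflow -- then copies the old parameters into the rebuilt module.
# Hypothetical usage, assuming a local checkpoint path:
#
#   with dnnlib.util.open_url('pretrained_models/stylegan_human_v2_1024.pkl') as fp:
#       G_ema = load_network_pkl(fp, force_fp16=True)['G_ema']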
if force_fp16: if G_only: convert_list = ['G_ema'] #'G' else: convert_list = ['G', 'D', 'G_ema'] for key in convert_list: old = data[key] kwargs = copy.deepcopy(old.init_kwargs) if key.startswith('G'): kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {})) kwargs.synthesis_kwargs.num_fp16_res = 4 kwargs.synthesis_kwargs.conv_clamp = 256 if key.startswith('D'): kwargs.num_fp16_res = 4 kwargs.conv_clamp = 256 if kwargs != old.init_kwargs: new = type(old)(**kwargs).eval().requires_grad_(False) misc.copy_params_and_buffers(old, new, require_all=True) data[key] = new return data class _TFNetworkStub(dnnlib.EasyDict): pass class _LegacyUnpickler(pickle.Unpickler): def find_class(self, module, name): if module == 'dnnlib.tflib.network' and name == 'Network': return _TFNetworkStub return super().find_class(module, name) #---------------------------------------------------------------------------- def num_range(s: str) -> List[int]: '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.''' range_re = re.compile(r'^(\d+)-(\d+)$') m = range_re.match(s) if m: return list(range(int(m.group(1)), int(m.group(2))+1)) vals = s.split(',') return [int(x) for x in vals] #---------------------------------------------------------------------------- #### loading tf pkl def load_pkl(file_or_url): with open(file_or_url, 'rb') as file: return pickle.load(file, encoding='latin1') #---------------------------------------------------------------------------- ### For editing def visual(output, out_path): import torch import cv2 import numpy as np output = (output + 1)/2 output = torch.clamp(output, 0, 1) if output.shape[1] == 1: output = torch.cat([output, output, output], 1) output = output[0].detach().cpu().permute(1,2,0).numpy() output = (output*255).astype(np.uint8) output = output[:,:,::-1] cv2.imwrite(out_path, output) def save_obj(obj, path): with open(path, 'wb+') as f: pickle.dump(obj, f, protocol=4) #---------------------------------------------------------------------------- ## Converting pkl to pth, change dict info inside pickle def convert_to_rgb(state_ros, state_nv, ros_name, nv_name): state_ros[f"{ros_name}.conv.weight"] = state_nv[f"{nv_name}.torgb.weight"].unsqueeze(0) state_ros[f"{ros_name}.bias"] = state_nv[f"{nv_name}.torgb.bias"].unsqueeze(0).unsqueeze(-1).unsqueeze(-1) state_ros[f"{ros_name}.conv.modulation.weight"] = state_nv[f"{nv_name}.torgb.affine.weight"] state_ros[f"{ros_name}.conv.modulation.bias"] = state_nv[f"{nv_name}.torgb.affine.bias"] def convert_conv(state_ros, state_nv, ros_name, nv_name): state_ros[f"{ros_name}.conv.weight"] = state_nv[f"{nv_name}.weight"].unsqueeze(0) state_ros[f"{ros_name}.activate.bias"] = state_nv[f"{nv_name}.bias"] state_ros[f"{ros_name}.conv.modulation.weight"] = state_nv[f"{nv_name}.affine.weight"] state_ros[f"{ros_name}.conv.modulation.bias"] = state_nv[f"{nv_name}.affine.bias"] state_ros[f"{ros_name}.noise.weight"] = state_nv[f"{nv_name}.noise_strength"].unsqueeze(0) def convert_blur_kernel(state_ros, state_nv, level): """Not quite sure why there is a factor of 4 here""" # They are all the same state_ros[f"convs.{2*level}.conv.blur.kernel"] = 4*state_nv["synthesis.b4.resample_filter"] state_ros[f"to_rgbs.{level}.upsample.kernel"] = 4*state_nv["synthesis.b4.resample_filter"] def determine_config(state_nv): mapping_names = [name for name in state_nv.keys() if "mapping.fc" in name] sythesis_names = [name for name in state_nv.keys() if "synthesis.b" in name] n_mapping = 
max([int(re.findall("(\d+)", n)[0]) for n in mapping_names]) + 1 resolution = max([int(re.findall("(\d+)", n)[0]) for n in sythesis_names]) n_layers = np.log(resolution/2)/np.log(2) return n_mapping, n_layers def convert(network_pkl, output_file, G_only=False): with dnnlib.util.open_url(network_pkl) as f: G_nvidia = load_network_pkl(f,G_only=G_only)['G_ema'] state_nv = G_nvidia.state_dict() n_mapping, n_layers = determine_config(state_nv) state_ros = {} for i in range(n_mapping): state_ros[f"style.{i+1}.weight"] = state_nv[f"mapping.fc{i}.weight"] state_ros[f"style.{i+1}.bias"] = state_nv[f"mapping.fc{i}.bias"] for i in range(int(n_layers)): if i > 0: for conv_level in range(2): convert_conv(state_ros, state_nv, f"convs.{2*i-2+conv_level}", f"synthesis.b{4*(2**i)}.conv{conv_level}") state_ros[f"noises.noise_{2*i-1+conv_level}"] = state_nv[f"synthesis.b{4*(2**i)}.conv{conv_level}.noise_const"].unsqueeze(0).unsqueeze(0) convert_to_rgb(state_ros, state_nv, f"to_rgbs.{i-1}", f"synthesis.b{4*(2**i)}") convert_blur_kernel(state_ros, state_nv, i-1) else: state_ros[f"input.input"] = state_nv[f"synthesis.b{4*(2**i)}.const"].unsqueeze(0) convert_conv(state_ros, state_nv, "conv1", f"synthesis.b{4*(2**i)}.conv1") state_ros[f"noises.noise_{2*i}"] = state_nv[f"synthesis.b{4*(2**i)}.conv1.noise_const"].unsqueeze(0).unsqueeze(0) convert_to_rgb(state_ros, state_nv, "to_rgb1", f"synthesis.b{4*(2**i)}") # https://github.com/yuval-alaluf/restyle-encoder/issues/1#issuecomment-828354736 latent_avg = state_nv['mapping.w_avg'] state_dict = {"g_ema": state_ros, "latent_avg": latent_avg} # if G_only: # f = open('converted_model_Gonly.txt','a+') # else: # f = open('converted_model.txt','a+') # for key in state_dict['g_ema'].keys(): # f.write(str(key)+': '+str(state_dict['g_ema'][key].shape)+'\n') # f.close() torch.save(state_dict, output_file) ================================================ FILE: stylegan_human/openpose/model/.gitkeep ================================================ ================================================ FILE: stylegan_human/openpose/src/__init__.py ================================================ ================================================ FILE: stylegan_human/openpose/src/body.py ================================================ import cv2 import numpy as np import math import time from scipy.ndimage.filters import gaussian_filter import matplotlib.pyplot as plt import matplotlib import torch from torchvision import transforms from openpose.src import util from openpose.src.model import bodypose_model class Body(object): def __init__(self, model_path): self.model = bodypose_model() if torch.cuda.is_available(): self.model = self.model.cuda() model_dict = util.transfer(self.model, torch.load(model_path)) self.model.load_state_dict(model_dict) self.model.eval() def __call__(self, oriImg): # scale_search = [0.5, 1.0, 1.5, 2.0] scale_search = [0.5] boxsize = 368 stride = 8 padValue = 128 thre1 = 0.1 thre2 = 0.05 multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search] heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19)) paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38)) for m in range(len(multiplier)): scale = multiplier[m] imageToTest = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC) imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue) im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5 im = np.ascontiguousarray(im) data = torch.from_numpy(im).float() if 
torch.cuda.is_available(): data = data.cuda() # data = data.permute([2, 0, 1]).unsqueeze(0).float() with torch.no_grad(): Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data) Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy() Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy() # extract outputs, resize, and remove padding # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0)) # output 1 is heatmaps heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0)) # output 1 is heatmaps heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC) heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC) # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0)) # output 0 is PAFs paf = cv2.resize(paf, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC) paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] paf = cv2.resize(paf, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC) heatmap_avg += heatmap_avg + heatmap / len(multiplier) paf_avg += + paf / len(multiplier) all_peaks = [] peak_counter = 0 for part in range(18): map_ori = heatmap_avg[:, :, part] one_heatmap = gaussian_filter(map_ori, sigma=3) map_left = np.zeros(one_heatmap.shape) map_left[1:, :] = one_heatmap[:-1, :] map_right = np.zeros(one_heatmap.shape) map_right[:-1, :] = one_heatmap[1:, :] map_up = np.zeros(one_heatmap.shape) map_up[:, 1:] = one_heatmap[:, :-1] map_down = np.zeros(one_heatmap.shape) map_down[:, :-1] = one_heatmap[:, 1:] peaks_binary = np.logical_and.reduce( (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1)) peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])) # note reverse peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks] peak_id = range(peak_counter, peak_counter + len(peaks)) peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))] all_peaks.append(peaks_with_score_and_id) peak_counter += len(peaks) # find connection in the specified sequence, center 29 is in the position 15 limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ [1, 16], [16, 18], [3, 17], [6, 18]] # the middle joints heatmap correpondence mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \ [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \ [55, 56], [37, 38], [45, 46]] connection_all = [] special_k = [] mid_num = 10 for k in range(len(mapIdx)): score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]] candA = all_peaks[limbSeq[k][0] - 1] candB = all_peaks[limbSeq[k][1] - 1] nA = len(candA) nB = len(candB) indexA, indexB = limbSeq[k] if (nA != 0 and nB != 0): connection_candidate = [] for i in range(nA): for j in range(nB): vec = np.subtract(candB[j][:2], candA[i][:2]) norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1]) norm = max(0.001, norm) vec = np.divide(vec, norm) startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \ np.linspace(candA[i][1], candB[j][1], num=mid_num))) vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \ for I 
in range(len(startend))]) vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \ for I in range(len(startend))]) score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1]) score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min( 0.5 * oriImg.shape[0] / norm - 1, 0) criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts) criterion2 = score_with_dist_prior > 0 if criterion1 and criterion2: connection_candidate.append( [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]]) connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True) connection = np.zeros((0, 5)) for c in range(len(connection_candidate)): i, j, s = connection_candidate[c][0:3] if (i not in connection[:, 3] and j not in connection[:, 4]): connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]]) if (len(connection) >= min(nA, nB)): break connection_all.append(connection) else: special_k.append(k) connection_all.append([]) # last number in each row is the total parts number of that person # the second last number in each row is the score of the overall configuration subset = -1 * np.ones((0, 20)) candidate = np.array([item for sublist in all_peaks for item in sublist]) for k in range(len(mapIdx)): if k not in special_k: partAs = connection_all[k][:, 0] partBs = connection_all[k][:, 1] indexA, indexB = np.array(limbSeq[k]) - 1 for i in range(len(connection_all[k])): # = 1:size(temp,1) found = 0 subset_idx = [-1, -1] for j in range(len(subset)): # 1:size(subset,1): if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]: subset_idx[found] = j found += 1 if found == 1: j = subset_idx[0] if subset[j][indexB] != partBs[i]: subset[j][indexB] = partBs[i] subset[j][-1] += 1 subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2] elif found == 2: # if found 2 and disjoint, merge them j1, j2 = subset_idx membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2] if len(np.nonzero(membership == 2)[0]) == 0: # merge subset[j1][:-2] += (subset[j2][:-2] + 1) subset[j1][-2:] += subset[j2][-2:] subset[j1][-2] += connection_all[k][i][2] subset = np.delete(subset, j2, 0) else: # as like found == 1 subset[j1][indexB] = partBs[i] subset[j1][-1] += 1 subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2] # if find no partA in the subset, create a new subset elif not found and k < 17: row = -1 * np.ones(20) row[indexA] = partAs[i] row[indexB] = partBs[i] row[-1] = 2 row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2] subset = np.vstack([subset, row]) # delete some rows of subset which has few parts occur deleteIdx = [] for i in range(len(subset)): if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4: deleteIdx.append(i) subset = np.delete(subset, deleteIdx, axis=0) # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts # candidate: x, y, score, id return candidate, subset if __name__ == "__main__": body_estimation = Body('../model/body_pose_model.pth') test_image = '../images/ski.jpg' oriImg = cv2.imread(test_image) # B,G,R order candidate, subset = body_estimation(oriImg) canvas = util.draw_bodypose(oriImg, candidate, subset) plt.imshow(canvas[:, :, [2, 1, 0]]) plt.show() ================================================ FILE: stylegan_human/openpose/src/model.py ================================================ import torch from 
collections import OrderedDict import torch.nn as nn def make_layers(block, no_relu_layers): layers = [] for layer_name, v in block.items(): if 'pool' in layer_name: layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2]) layers.append((layer_name, layer)) else: conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1], kernel_size=v[2], stride=v[3], padding=v[4]) layers.append((layer_name, conv2d)) if layer_name not in no_relu_layers: layers.append(('relu_'+layer_name, nn.ReLU(inplace=True))) return nn.Sequential(OrderedDict(layers)) class bodypose_model(nn.Module): def __init__(self): super(bodypose_model, self).__init__() # these layers have no relu layer no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\ 'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\ 'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\ 'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L2'] blocks = {} block0 = OrderedDict([ ('conv1_1', [3, 64, 3, 1, 1]), ('conv1_2', [64, 64, 3, 1, 1]), ('pool1_stage1', [2, 2, 0]), ('conv2_1', [64, 128, 3, 1, 1]), ('conv2_2', [128, 128, 3, 1, 1]), ('pool2_stage1', [2, 2, 0]), ('conv3_1', [128, 256, 3, 1, 1]), ('conv3_2', [256, 256, 3, 1, 1]), ('conv3_3', [256, 256, 3, 1, 1]), ('conv3_4', [256, 256, 3, 1, 1]), ('pool3_stage1', [2, 2, 0]), ('conv4_1', [256, 512, 3, 1, 1]), ('conv4_2', [512, 512, 3, 1, 1]), ('conv4_3_CPM', [512, 256, 3, 1, 1]), ('conv4_4_CPM', [256, 128, 3, 1, 1]) ]) # Stage 1 block1_1 = OrderedDict([ ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]), ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]), ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]), ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]), ('conv5_5_CPM_L1', [512, 38, 1, 1, 0]) ]) block1_2 = OrderedDict([ ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]), ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]), ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]), ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]), ('conv5_5_CPM_L2', [512, 19, 1, 1, 0]) ]) blocks['block1_1'] = block1_1 blocks['block1_2'] = block1_2 self.model0 = make_layers(block0, no_relu_layers) # Stages 2 - 6 for i in range(2, 7): blocks['block%d_1' % i] = OrderedDict([ ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]), ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]), ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]), ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]), ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]), ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]), ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0]) ]) blocks['block%d_2' % i] = OrderedDict([ ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]), ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]), ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]), ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]), ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]), ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]), ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0]) ]) for k in blocks.keys(): blocks[k] = make_layers(blocks[k], no_relu_layers) self.model1_1 = blocks['block1_1'] self.model2_1 = blocks['block2_1'] self.model3_1 = blocks['block3_1'] self.model4_1 = blocks['block4_1'] self.model5_1 = blocks['block5_1'] self.model6_1 = blocks['block6_1'] self.model1_2 = blocks['block1_2'] self.model2_2 = blocks['block2_2'] self.model3_2 = blocks['block3_2'] self.model4_2 = blocks['block4_2'] self.model5_2 = blocks['block5_2'] self.model6_2 = blocks['block6_2'] def forward(self, x): out1 = self.model0(x) out1_1 = self.model1_1(out1) out1_2 = self.model1_2(out1) out2 = torch.cat([out1_1, out1_2, out1], 1) out2_1 = self.model2_1(out2) out2_2 = self.model2_2(out2) out3
= torch.cat([out2_1, out2_2, out1], 1) out3_1 = self.model3_1(out3) out3_2 = self.model3_2(out3) out4 = torch.cat([out3_1, out3_2, out1], 1) out4_1 = self.model4_1(out4) out4_2 = self.model4_2(out4) out5 = torch.cat([out4_1, out4_2, out1], 1) out5_1 = self.model5_1(out5) out5_2 = self.model5_2(out5) out6 = torch.cat([out5_1, out5_2, out1], 1) out6_1 = self.model6_1(out6) out6_2 = self.model6_2(out6) return out6_1, out6_2 class handpose_model(nn.Module): def __init__(self): super(handpose_model, self).__init__() # these layers have no relu layer no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\ 'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6'] # stage 1 block1_0 = OrderedDict([ ('conv1_1', [3, 64, 3, 1, 1]), ('conv1_2', [64, 64, 3, 1, 1]), ('pool1_stage1', [2, 2, 0]), ('conv2_1', [64, 128, 3, 1, 1]), ('conv2_2', [128, 128, 3, 1, 1]), ('pool2_stage1', [2, 2, 0]), ('conv3_1', [128, 256, 3, 1, 1]), ('conv3_2', [256, 256, 3, 1, 1]), ('conv3_3', [256, 256, 3, 1, 1]), ('conv3_4', [256, 256, 3, 1, 1]), ('pool3_stage1', [2, 2, 0]), ('conv4_1', [256, 512, 3, 1, 1]), ('conv4_2', [512, 512, 3, 1, 1]), ('conv4_3', [512, 512, 3, 1, 1]), ('conv4_4', [512, 512, 3, 1, 1]), ('conv5_1', [512, 512, 3, 1, 1]), ('conv5_2', [512, 512, 3, 1, 1]), ('conv5_3_CPM', [512, 128, 3, 1, 1]) ]) block1_1 = OrderedDict([ ('conv6_1_CPM', [128, 512, 1, 1, 0]), ('conv6_2_CPM', [512, 22, 1, 1, 0]) ]) blocks = {} blocks['block1_0'] = block1_0 blocks['block1_1'] = block1_1 # stage 2-6 for i in range(2, 7): blocks['block%d' % i] = OrderedDict([ ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]), ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]), ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]), ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]), ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]), ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]), ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0]) ]) for k in blocks.keys(): blocks[k] = make_layers(blocks[k], no_relu_layers) self.model1_0 = blocks['block1_0'] self.model1_1 = blocks['block1_1'] self.model2 = blocks['block2'] self.model3 = blocks['block3'] self.model4 = blocks['block4'] self.model5 = blocks['block5'] self.model6 = blocks['block6'] def forward(self, x): out1_0 = self.model1_0(x) out1_1 = self.model1_1(out1_0) concat_stage2 = torch.cat([out1_1, out1_0], 1) out_stage2 = self.model2(concat_stage2) concat_stage3 = torch.cat([out_stage2, out1_0], 1) out_stage3 = self.model3(concat_stage3) concat_stage4 = torch.cat([out_stage3, out1_0], 1) out_stage4 = self.model4(concat_stage4) concat_stage5 = torch.cat([out_stage4, out1_0], 1) out_stage5 = self.model5(concat_stage5) concat_stage6 = torch.cat([out_stage5, out1_0], 1) out_stage6 = self.model6(concat_stage6) return out_stage6 ================================================ FILE: stylegan_human/openpose/src/util.py ================================================ import numpy as np import math import cv2 import matplotlib from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas from matplotlib.figure import Figure import numpy as np import matplotlib.pyplot as plt import cv2 def padRightDownCorner(img, stride, padValue): h = img.shape[0] w = img.shape[1] pad = 4 * [None] pad[0] = 0 # up pad[1] = 0 # left pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right img_padded = img pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1)) img_padded = np.concatenate((pad_up, img_padded), axis=0) pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, 
pad[1], 1)) img_padded = np.concatenate((pad_left, img_padded), axis=1) pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1)) img_padded = np.concatenate((img_padded, pad_down), axis=0) pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1)) img_padded = np.concatenate((img_padded, pad_right), axis=1) return img_padded, pad # transfer Caffe model weights to PyTorch by matching layer names def transfer(model, model_weights): transfered_model_weights = {} for weights_name in model.state_dict().keys(): transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])] return transfered_model_weights # draw the body keypoints and limbs def draw_bodypose(canvas, candidate, subset, show_number=False): stickwidth = 4 limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ [1, 16], [16, 18], [3, 17], [6, 18]] colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] for i in range(18): for n in range(len(subset)): index = int(subset[n][i]) if index == -1: continue x, y = candidate[index][0:2] cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1) if show_number: cv2.putText(canvas, f'{index}', (int(x), int(y)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,0), 1, cv2.LINE_AA) # draw each limb as an ellipse between its two joints for i in range(17): for n in range(len(subset)): index = subset[n][np.array(limbSeq[i]) - 1] if -1 in index: continue cur_canvas = canvas.copy() Y = candidate[index.astype(int), 0] X = candidate[index.astype(int), 1] mX = np.mean(X) mY = np.mean(Y) length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) cv2.fillConvexPoly(cur_canvas, polygon, colors[i]) canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0) return canvas # get max index of 2d array def npmax(array): arrayindex = array.argmax(1) arrayvalue = array.max(1) i = arrayvalue.argmax() j = arrayindex[i] return i, j # get max index of 2d array, plus its score def npmax_with_score(array): arrayindex = array.argmax(1) arrayvalue = array.max(1) i = arrayvalue.argmax() j = arrayindex[i] score = array[i][j] return i, j, score ================================================ FILE: stylegan_human/pti/pti_configs/__init__.py ================================================ ================================================ FILE: stylegan_human/pti/pti_configs/global_config.py ================================================ ## Device cuda_visible_devices = '0' device = 'cuda:0' ## Logs training_step = 1 image_rec_result_log_snapshot = 100 pivotal_training_steps = 0 model_snapshot_interval = 400 ## Run name to be updated during PTI run_name = 'exp' ================================================ FILE: stylegan_human/pti/pti_configs/hyperparameters.py ================================================ ## Architecture lpips_type = 'alex' first_inv_type = 'w+' # 'w' or 'w+' optim_type = 'adam' ## Locality regularization latent_ball_num_of_samples = 1 locality_regularization_interval = 1 use_locality_regularization = False regulizer_l2_lambda = 0.1 regulizer_lpips_lambda = 0.1 regulizer_alpha = 30 ## Loss pt_l2_lambda = 1 pt_lpips_lambda = 1 ## Steps
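# PTI schedule, as a reading aid for the step settings below (the early-stop
# behaviour is an assumption based on the upstream PTI code, not shown here):
# a first inversion phase of up to first_inv_steps produces the pivot latent,
# then up to max_pti_steps of generator tuning run per image, stopping once
# the LPIPS reconstruction loss drops below LPIPS_value_threshold.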
LPIPS_value_threshold = 0.04 max_pti_steps = 350 first_inv_steps = 450 max_images_to_invert = 30 ## Optimization pti_learning_rate = 5e-4 first_inv_lr = 8e-3 train_batch_size = 1 use_last_w_pivots = False ================================================ FILE: stylegan_human/pti/pti_configs/paths_config.py ================================================ import os ## Pretrained models paths e4e = './pti/e4e_w+.pt' stylegan2_ada_shhq = './pretrained_models/stylegan_human_v2_1024.pkl' ir_se50 = '' #'./model_ir_se50.pth' ## Dirs for output files checkpoints_dir = './outputs/pti/checkpoints/' embedding_base_dir = './outputs/pti/embeddings' experiments_output_dir = './outputs/pti/' ## Input info ### Input dir, where the images reside input_data_path = 'aligned_image/' ### Inversion identifier, used to keeping track of the inversion results. Both the latent code and the generator input_data_id = 'test' ## Keywords pti_results_keyword = 'PTI' e4e_results_keyword = 'e4e' sg2_results_keyword = 'SG2' sg2_plus_results_keyword = 'SG2_Plus' multi_id_model_type = 'multi_id' ================================================ FILE: stylegan_human/pti/pti_models/__init__.py ================================================ ================================================ FILE: stylegan_human/pti/pti_models/e4e/__init__.py ================================================ ================================================ FILE: stylegan_human/pti/pti_models/e4e/encoders/__init__.py ================================================ ================================================ FILE: stylegan_human/pti/pti_models/e4e/encoders/helpers.py ================================================ from collections import namedtuple import torch import torch.nn.functional as F from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module """ ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) """ class Flatten(Module): def forward(self, input): return input.view(input.size(0), -1) def l2_norm(input, axis=1): norm = torch.norm(input, 2, axis, True) output = torch.div(input, norm) return output class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): """ A named tuple describing a ResNet block. """ def get_block(in_channel, depth, num_units, stride=2): return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] def get_blocks(num_layers): if num_layers == 50: blocks = [ get_block(in_channel=64, depth=64, num_units=3), get_block(in_channel=64, depth=128, num_units=4), get_block(in_channel=128, depth=256, num_units=14), get_block(in_channel=256, depth=512, num_units=3) ] elif num_layers == 100: blocks = [ get_block(in_channel=64, depth=64, num_units=3), get_block(in_channel=64, depth=128, num_units=13), get_block(in_channel=128, depth=256, num_units=30), get_block(in_channel=256, depth=512, num_units=3) ] elif num_layers == 152: blocks = [ get_block(in_channel=64, depth=64, num_units=3), get_block(in_channel=64, depth=128, num_units=8), get_block(in_channel=128, depth=256, num_units=36), get_block(in_channel=256, depth=512, num_units=3) ] else: raise ValueError("Invalid number of layers: {}. 
Must be one of [50, 100, 152]".format(num_layers)) return blocks class SEModule(Module): def __init__(self, channels, reduction): super(SEModule, self).__init__() self.avg_pool = AdaptiveAvgPool2d(1) self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) self.relu = ReLU(inplace=True) self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) self.sigmoid = Sigmoid() def forward(self, x): module_input = x x = self.avg_pool(x) x = self.fc1(x) x = self.relu(x) x = self.fc2(x) x = self.sigmoid(x) return module_input * x class bottleneck_IR(Module): def __init__(self, in_channel, depth, stride): super(bottleneck_IR, self).__init__() if in_channel == depth: self.shortcut_layer = MaxPool2d(1, stride) else: self.shortcut_layer = Sequential( Conv2d(in_channel, depth, (1, 1), stride, bias=False), BatchNorm2d(depth) ) self.res_layer = Sequential( BatchNorm2d(in_channel), Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) ) def forward(self, x): shortcut = self.shortcut_layer(x) res = self.res_layer(x) return res + shortcut class bottleneck_IR_SE(Module): def __init__(self, in_channel, depth, stride): super(bottleneck_IR_SE, self).__init__() if in_channel == depth: self.shortcut_layer = MaxPool2d(1, stride) else: self.shortcut_layer = Sequential( Conv2d(in_channel, depth, (1, 1), stride, bias=False), BatchNorm2d(depth) ) self.res_layer = Sequential( BatchNorm2d(in_channel), Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth), SEModule(depth, 16) ) def forward(self, x): shortcut = self.shortcut_layer(x) res = self.res_layer(x) return res + shortcut def _upsample_add(x, y): """Upsample and add two feature maps. Args: x: (Variable) top feature map to be upsampled. y: (Variable) lateral feature map. Returns: (Variable) added feature map. Note in PyTorch, when input size is odd, the upsampled feature map with `F.upsample(..., scale_factor=2, mode='nearest')` maybe not equal to the lateral feature map size. e.g. original input size: [N,_,15,15] -> conv2d feature map size: [N,_,8,8] -> upsampled feature map size: [N,_,16,16] So we choose bilinear upsample which supports arbitrary output sizes. 
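Example (an illustrative shape check, not part of the original docs):
    >>> import torch
    >>> x = torch.randn(1, 512, 8, 8)    # coarse top-down feature map
    >>> y = torch.randn(1, 512, 16, 16)  # finer lateral feature map
    >>> _upsample_add(x, y).shape
    torch.Size([1, 512, 16, 16])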
""" _, _, H, W = y.size() return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y ================================================ FILE: stylegan_human/pti/pti_models/e4e/encoders/model_irse.py ================================================ from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module from encoder4editing.models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm """ Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) """ class Backbone(Module): def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): super(Backbone, self).__init__() assert input_size in [112, 224], "input_size should be 112 or 224" assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" blocks = get_blocks(num_layers) if mode == 'ir': unit_module = bottleneck_IR elif mode == 'ir_se': unit_module = bottleneck_IR_SE self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64)) if input_size == 112: self.output_layer = Sequential(BatchNorm2d(512), Dropout(drop_ratio), Flatten(), Linear(512 * 7 * 7, 512), BatchNorm1d(512, affine=affine)) else: self.output_layer = Sequential(BatchNorm2d(512), Dropout(drop_ratio), Flatten(), Linear(512 * 14 * 14, 512), BatchNorm1d(512, affine=affine)) modules = [] for block in blocks: for bottleneck in block: modules.append(unit_module(bottleneck.in_channel, bottleneck.depth, bottleneck.stride)) self.body = Sequential(*modules) def forward(self, x): x = self.input_layer(x) x = self.body(x) x = self.output_layer(x) return l2_norm(x) def IR_50(input_size): """Constructs a ir-50 model.""" model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) return model def IR_101(input_size): """Constructs a ir-101 model.""" model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) return model def IR_152(input_size): """Constructs a ir-152 model.""" model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) return model def IR_SE_50(input_size): """Constructs a ir_se-50 model.""" model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) return model def IR_SE_101(input_size): """Constructs a ir_se-101 model.""" model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) return model def IR_SE_152(input_size): """Constructs a ir_se-152 model.""" model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) return model ================================================ FILE: stylegan_human/pti/pti_models/e4e/encoders/psp_encoders.py ================================================ from enum import Enum import math import numpy as np import torch from torch import nn from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module from pti.pti_models.e4e.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add from pti.pti_models.e4e.stylegan2.model import EqualLinear class ProgressiveStage(Enum): WTraining = 0 Delta1Training = 1 Delta2Training = 2 Delta3Training = 3 Delta4Training = 4 Delta5Training = 5 Delta6Training = 6 Delta7Training = 7 Delta8Training = 8 Delta9Training = 9 Delta10Training = 10 Delta11Training = 11 Delta12Training = 12 Delta13Training = 13 Delta14Training = 14 Delta15Training = 15 Delta16Training = 16 
Delta17Training = 17 Inference = 18 class GradualStyleBlock(Module): def __init__(self, in_c, out_c, spatial): super(GradualStyleBlock, self).__init__() self.out_c = out_c self.spatial = spatial num_pools = int(np.log2(spatial)) modules = [] modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()] for i in range(num_pools - 1): modules += [ Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU() ] self.convs = nn.Sequential(*modules) self.linear = EqualLinear(out_c, out_c, lr_mul=1) def forward(self, x): x = self.convs(x) x = x.view(-1, self.out_c) x = self.linear(x) return x class GradualStyleEncoder(Module): def __init__(self, num_layers, mode='ir', opts=None): super(GradualStyleEncoder, self).__init__() assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' blocks = get_blocks(num_layers) if mode == 'ir': unit_module = bottleneck_IR elif mode == 'ir_se': unit_module = bottleneck_IR_SE self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64)) modules = [] for block in blocks: for bottleneck in block: modules.append(unit_module(bottleneck.in_channel, bottleneck.depth, bottleneck.stride)) self.body = Sequential(*modules) self.styles = nn.ModuleList() log_size = int(math.log(opts.stylegan_size, 2)) self.style_count = 2 * log_size - 2 self.coarse_ind = 3 self.middle_ind = 7 for i in range(self.style_count): if i < self.coarse_ind: style = GradualStyleBlock(512, 512, 16) elif i < self.middle_ind: style = GradualStyleBlock(512, 512, 32) else: style = GradualStyleBlock(512, 512, 64) self.styles.append(style) self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) def forward(self, x): x = self.input_layer(x) latents = [] modulelist = list(self.body._modules.values()) for i, l in enumerate(modulelist): x = l(x) if i == 6: c1 = x elif i == 20: c2 = x elif i == 23: c3 = x for j in range(self.coarse_ind): latents.append(self.styles[j](c3)) p2 = _upsample_add(c3, self.latlayer1(c2)) for j in range(self.coarse_ind, self.middle_ind): latents.append(self.styles[j](p2)) p1 = _upsample_add(p2, self.latlayer2(c1)) for j in range(self.middle_ind, self.style_count): latents.append(self.styles[j](p1)) out = torch.stack(latents, dim=1) return out class Encoder4Editing(Module): def __init__(self, num_layers, mode='ir', opts=None): super(Encoder4Editing, self).__init__() assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' blocks = get_blocks(num_layers) if mode == 'ir': unit_module = bottleneck_IR elif mode == 'ir_se': unit_module = bottleneck_IR_SE self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64)) modules = [] for block in blocks: for bottleneck in block: modules.append(unit_module(bottleneck.in_channel, bottleneck.depth, bottleneck.stride)) self.body = Sequential(*modules) self.styles = nn.ModuleList() log_size = int(math.log(opts.stylegan_size, 2)) self.style_count = 2 * log_size - 2 self.coarse_ind = 3 self.middle_ind = 7 for i in range(self.style_count): if i < self.coarse_ind: style = GradualStyleBlock(512, 512, 16) elif i < self.middle_ind: style = GradualStyleBlock(512, 512, 32) else: style = GradualStyleBlock(512, 512, 64) self.styles.append(style) self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) 
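        # 1x1 lateral convolutions in the FPN sense: latlayer1 lifts the
        # 256-channel mid-level backbone feature map to 512 channels, and
        # latlayer2 (below) does the same for the 128-channel fine feature
        # map, so both can be fused with coarser features via _upsample_add
        # in forward().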
self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) self.progressive_stage = ProgressiveStage.Inference def get_deltas_starting_dimensions(self): ''' Get the list of style indices at which each delta is applied ''' return list(range(self.style_count)) # Each dimension has a delta applied to it def set_progressive_stage(self, new_stage: ProgressiveStage): self.progressive_stage = new_stage print('Changed progressive stage to: ', new_stage) def forward(self, x): x = self.input_layer(x) modulelist = list(self.body._modules.values()) for i, l in enumerate(modulelist): x = l(x) if i == 6: c1 = x elif i == 20: c2 = x elif i == 23: c3 = x # Infer main W and duplicate it w0 = self.styles[0](c3) w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2) stage = self.progressive_stage.value features = c3 for i in range(1, min(stage + 1, self.style_count)): # Infer additional deltas if i == self.coarse_ind: p2 = _upsample_add(c3, self.latlayer1(c2)) # FPN's middle features features = p2 elif i == self.middle_ind: p1 = _upsample_add(p2, self.latlayer2(c1)) # FPN's fine features features = p1 delta_i = self.styles[i](features) w[:, i] += delta_i return w ================================================ FILE: stylegan_human/pti/pti_models/e4e/latent_codes_pool.py ================================================ import random import torch class LatentCodesPool: """This class implements a buffer that stores previously generated w latent codes. This buffer enables us to update discriminators using a history of generated w's rather than the ones produced by the latest encoder. """ def __init__(self, pool_size): """Initialize the LatentCodesPool class Parameters: pool_size (int) -- the size of the buffer; if pool_size=0, no buffer will be created """ self.pool_size = pool_size if self.pool_size > 0: # create an empty pool self.num_ws = 0 self.ws = [] def query(self, ws): """Return w's from the pool. Parameters: ws: the latest generated w's from the generator Returns w's from the buffer. Once the pool is full, each w is returned unchanged with probability 1/2; otherwise a previously stored w is returned and the current w is inserted into the buffer.
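        Example (illustrative; assumes 512-dim w codes):
            >>> pool = LatentCodesPool(pool_size=50)
            >>> ws = torch.randn(4, 512)
            >>> pool.query(ws).shape  # same batch size back; mixing begins once the pool fills
            torch.Size([4, 512])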
""" if self.pool_size == 0: # if the buffer size is 0, do nothing return ws return_ws = [] for w in ws: # ws.shape: (batch, 512) or (batch, n_latent, 512) # w = torch.unsqueeze(image.data, 0) if w.ndim == 2: i = random.randint(0, len(w) - 1) # apply a random latent index as a candidate w = w[i] self.handle_w(w, return_ws) return_ws = torch.stack(return_ws, 0) # collect all the images and return return return_ws def handle_w(self, w, return_ws): if self.num_ws < self.pool_size: # if the buffer is not full; keep inserting current codes to the buffer self.num_ws = self.num_ws + 1 self.ws.append(w) return_ws.append(w) else: p = random.uniform(0, 1) if p > 0.5: # by 50% chance, the buffer will return a previously stored latent code, and insert the current code into the buffer random_id = random.randint(0, self.pool_size - 1) # randint is inclusive tmp = self.ws[random_id].clone() self.ws[random_id] = w return_ws.append(tmp) else: # by another 50% chance, the buffer will return the current image return_ws.append(w) ================================================ FILE: stylegan_human/pti/pti_models/e4e/psp.py ================================================ import matplotlib from pti.pti_configs import paths_config matplotlib.use('Agg') import torch from torch import nn from pti.pti_models.e4e.encoders import psp_encoders from pti.pti_models.e4e.stylegan2.model import Generator def get_keys(d, name): if 'state_dict' in d: d = d['state_dict'] d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name} return d_filt class pSp(nn.Module): def __init__(self, opts): super(pSp, self).__init__() self.opts = opts # Define architecture self.encoder = self.set_encoder() self.decoder = Generator(opts.stylegan_size, 512, 8, channel_multiplier=2) self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256 // 2)) # Load weights if needed self.load_weights() def set_encoder(self): if self.opts.encoder_type == 'GradualStyleEncoder': encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts) elif self.opts.encoder_type == 'Encoder4Editing': encoder = psp_encoders.Encoder4Editing(50, 'ir_se', self.opts) elif self.opts.encoder_type == 'SingleStyleCodeEncoder': encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts) else: raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type)) return encoder def load_weights(self): if self.opts.checkpoint_path is not None: print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.checkpoint_path)) ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu') self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True) self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True) self.__load_latent_avg(ckpt) else: print('Loading encoders weights from irse50!') encoder_ckpt = torch.load(model_paths['ir_se50']) self.encoder.load_state_dict(encoder_ckpt, strict=False) print('Loading decoder weights from pretrained!') ckpt = torch.load(self.opts.stylegan_weights) self.decoder.load_state_dict(ckpt['g_ema'], strict=False) self.__load_latent_avg(ckpt, repeat=self.encoder.style_count) def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True, inject_latent=None, return_latents=False, alpha=None): if input_code: codes = x else: codes = self.encoder(x) # normalize with respect to the center of an average face if self.opts.start_from_latent_avg: if codes.ndim == 2: codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :] else: codes = codes + 
self.latent_avg.repeat(codes.shape[0], 1, 1) if latent_mask is not None: for i in latent_mask: if inject_latent is not None: if alpha is not None: codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i] else: codes[:, i] = inject_latent[:, i] else: codes[:, i] = 0 input_is_latent = not input_code images, result_latent = self.decoder([codes], input_is_latent=input_is_latent, randomize_noise=randomize_noise, return_latents=return_latents) if resize: images = self.face_pool(images) if return_latents: return images, result_latent else: return images def __load_latent_avg(self, ckpt, repeat=None): if 'latent_avg' in ckpt: self.latent_avg = ckpt['latent_avg'].to(self.opts.device) if repeat is not None: self.latent_avg = self.latent_avg.repeat(repeat, 1) else: self.latent_avg = None ================================================ FILE: stylegan_human/pti/pti_models/e4e/stylegan2/__init__.py ================================================ ================================================ FILE: stylegan_human/pti/pti_models/e4e/stylegan2/model.py ================================================ import math import random import torch from torch import nn from torch.nn import functional as F from .op.fused_act import FusedLeakyReLU, fused_leaky_relu from .op.upfirdn2d import upfirdn2d class PixelNorm(nn.Module): def __init__(self): super().__init__() def forward(self, input): return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) def make_kernel(k): k = torch.tensor(k, dtype=torch.float32) if k.ndim == 1: k = k[None, :] * k[:, None] k /= k.sum() return k class Upsample(nn.Module): def __init__(self, kernel, factor=2): super().__init__() self.factor = factor kernel = make_kernel(kernel) * (factor ** 2) self.register_buffer('kernel', kernel) p = kernel.shape[0] - factor pad0 = (p + 1) // 2 + factor - 1 pad1 = p // 2 self.pad = (pad0, pad1) def forward(self, input): out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) return out class Downsample(nn.Module): def __init__(self, kernel, factor=2): super().__init__() self.factor = factor kernel = make_kernel(kernel) self.register_buffer('kernel', kernel) p = kernel.shape[0] - factor pad0 = (p + 1) // 2 pad1 = p // 2 self.pad = (pad0, pad1) def forward(self, input): out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) return out class Blur(nn.Module): def __init__(self, kernel, pad, upsample_factor=1): super().__init__() kernel = make_kernel(kernel) if upsample_factor > 1: kernel = kernel * (upsample_factor ** 2) self.register_buffer('kernel', kernel) self.pad = pad def forward(self, input): out = upfirdn2d(input, self.kernel, pad=self.pad) return out class EqualConv2d(nn.Module): def __init__( self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True ): super().__init__() self.weight = nn.Parameter( torch.randn(out_channel, in_channel, kernel_size, kernel_size) ) self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) self.stride = stride self.padding = padding if bias: self.bias = nn.Parameter(torch.zeros(out_channel)) else: self.bias = None def forward(self, input): out = F.conv2d( input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding, ) return out def __repr__(self): return ( f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' ) class EqualLinear(nn.Module): def __init__( self, in_dim, out_dim, bias=True, bias_init=0, 
lr_mul=1, activation=None ): super().__init__() self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) if bias: self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) else: self.bias = None self.activation = activation self.scale = (1 / math.sqrt(in_dim)) * lr_mul self.lr_mul = lr_mul def forward(self, input): if self.activation: out = F.linear(input, self.weight * self.scale) out = fused_leaky_relu(out, self.bias * self.lr_mul) else: out = F.linear( input, self.weight * self.scale, bias=self.bias * self.lr_mul ) return out def __repr__(self): return ( f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' ) class ScaledLeakyReLU(nn.Module): def __init__(self, negative_slope=0.2): super().__init__() self.negative_slope = negative_slope def forward(self, input): out = F.leaky_relu(input, negative_slope=self.negative_slope) return out * math.sqrt(2) class ModulatedConv2d(nn.Module): def __init__( self, in_channel, out_channel, kernel_size, style_dim, demodulate=True, upsample=False, downsample=False, blur_kernel=[1, 3, 3, 1], ): super().__init__() self.eps = 1e-8 self.kernel_size = kernel_size self.in_channel = in_channel self.out_channel = out_channel self.upsample = upsample self.downsample = downsample if upsample: factor = 2 p = (len(blur_kernel) - factor) - (kernel_size - 1) pad0 = (p + 1) // 2 + factor - 1 pad1 = p // 2 + 1 self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) if downsample: factor = 2 p = (len(blur_kernel) - factor) + (kernel_size - 1) pad0 = (p + 1) // 2 pad1 = p // 2 self.blur = Blur(blur_kernel, pad=(pad0, pad1)) fan_in = in_channel * kernel_size ** 2 self.scale = 1 / math.sqrt(fan_in) self.padding = kernel_size // 2 self.weight = nn.Parameter( torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) ) self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) self.demodulate = demodulate def __repr__(self): return ( f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' f'upsample={self.upsample}, downsample={self.downsample})' ) def forward(self, input, style): batch, in_channel, height, width = input.shape style = self.modulation(style).view(batch, 1, in_channel, 1, 1) weight = self.scale * self.weight * style if self.demodulate: demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) weight = weight.view( batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size ) if self.upsample: input = input.view(1, batch * in_channel, height, width) weight = weight.view( batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size ) weight = weight.transpose(1, 2).reshape( batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size ) out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) _, _, height, width = out.shape out = out.view(batch, self.out_channel, height, width) out = self.blur(out) elif self.downsample: input = self.blur(input) _, _, height, width = input.shape input = input.view(1, batch * in_channel, height, width) out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) _, _, height, width = out.shape out = out.view(batch, self.out_channel, height, width) else: input = input.view(1, batch * in_channel, height, width) out = F.conv2d(input, weight, padding=self.padding, groups=batch) _, _, height, width = out.shape out = out.view(batch, self.out_channel, height, width) return out class NoiseInjection(nn.Module): def 
__init__(self): super().__init__() self.weight = nn.Parameter(torch.zeros(1)) def forward(self, image, noise=None): if noise is None: batch, _, height, width = image.shape noise = image.new_empty(batch, 1, height, width).normal_() return image + self.weight * noise class ConstantInput(nn.Module): def __init__(self, channel, size=4): super().__init__() self.input = nn.Parameter(torch.randn(1, channel, size, size // 2)) def forward(self, input): batch = input.shape[0] out = self.input.repeat(batch, 1, 1, 1) return out class StyledConv(nn.Module): def __init__( self, in_channel, out_channel, kernel_size, style_dim, upsample=False, blur_kernel=[1, 3, 3, 1], demodulate=True, ): super().__init__() self.conv = ModulatedConv2d( in_channel, out_channel, kernel_size, style_dim, upsample=upsample, blur_kernel=blur_kernel, demodulate=demodulate, ) self.noise = NoiseInjection() # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) # self.activate = ScaledLeakyReLU(0.2) self.activate = FusedLeakyReLU(out_channel) def forward(self, input, style, noise=None): out = self.conv(input, style) out = self.noise(out, noise=noise) # out = out + self.bias out = self.activate(out) return out class ToRGB(nn.Module): def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): super().__init__() if upsample: self.upsample = Upsample(blur_kernel) self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) def forward(self, input, style, skip=None): out = self.conv(input, style) out = out + self.bias if skip is not None: skip = self.upsample(skip) out = out + skip return out class Generator(nn.Module): def __init__( self, size, style_dim, n_mlp, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], lr_mlp=0.01, ): super().__init__() self.size = size self.style_dim = style_dim layers = [PixelNorm()] for i in range(n_mlp): layers.append( EqualLinear( style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu' ) ) self.style = nn.Sequential(*layers) self.channels = { 4: 512, 8: 512, 16: 512, 32: 512, 64: 256 * channel_multiplier, 128: 128 * channel_multiplier, 256: 64 * channel_multiplier, 512: 32 * channel_multiplier, 1024: 16 * channel_multiplier, } self.input = ConstantInput(self.channels[4]) self.conv1 = StyledConv( self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel ) self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) self.log_size = int(math.log(size, 2)) self.num_layers = (self.log_size - 2) * 2 + 1 self.convs = nn.ModuleList() self.upsamples = nn.ModuleList() self.to_rgbs = nn.ModuleList() self.noises = nn.Module() in_channel = self.channels[4] for layer_idx in range(self.num_layers): res = (layer_idx + 5) // 2 shape = [1, 1, 2 ** res, 2 ** res // 2] self.noises.register_buffer( "noise_{}".format(layer_idx), torch.randn(*shape) ) for i in range(3, self.log_size + 1): out_channel = self.channels[2 ** i] self.convs.append( StyledConv( in_channel, out_channel, 3, style_dim, upsample=True, blur_kernel=blur_kernel, ) ) self.convs.append( StyledConv( out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel ) ) self.to_rgbs.append(ToRGB(out_channel, style_dim)) in_channel = out_channel self.n_latent = self.log_size * 2 - 2 def make_noise(self): device = self.input.input.device noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2 // 2, device=device)] for i in range(3, self.log_size + 1): for _ in range(2): noises.append(torch.randn(1, 1, 2 ** i, 2 ** i // 2, device=device)) return noises def 
mean_latent(self, n_latent): latent_in = torch.randn( n_latent, self.style_dim, device=self.input.input.device ) latent = self.style(latent_in).mean(0, keepdim=True) return latent def get_latent(self, input): return self.style(input) def forward( self, styles, return_latents=False, return_features=False, inject_index=None, truncation=1, truncation_latent=None, input_is_latent=False, noise=None, randomize_noise=True, ): if not input_is_latent: styles = [self.style(s) for s in styles] if noise is None: if randomize_noise: noise = [None] * self.num_layers else: noise = [ getattr(self.noises, f'noise_{i}') for i in range(self.num_layers) ] if truncation < 1: style_t = [] for style in styles: style_t.append( truncation_latent + truncation * (style - truncation_latent) ) styles = style_t if len(styles) < 2: inject_index = self.n_latent if styles[0].ndim < 3: latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) else: latent = styles[0] else: if inject_index is None: inject_index = random.randint(1, self.n_latent - 1) # latent = styles[0].unsqueeze(0) # if latent.shape[1] == 1: # latent = latent.repeat(1, inject_index, 1) # else: # latent = latent[:, :inject_index, :] latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) # latent = styles[0][:, :inject_index, :] # latent2 = styles[1][:, inject_index:, :] latent = torch.cat([latent, latent2], 1) out = self.input(latent) out = self.conv1(out, latent[:, 0], noise=noise[0]) skip = self.to_rgb1(out, latent[:, 1]) i = 1 for conv1, conv2, noise1, noise2, to_rgb in zip( self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs ): out = conv1(out, latent[:, i], noise=noise1) out = conv2(out, latent[:, i + 1], noise=noise2) skip = to_rgb(out, latent[:, i + 2], skip) i += 2 image = skip if return_latents: return image, latent elif return_features: return image, out else: return image, None class ConvLayer(nn.Sequential): def __init__( self, in_channel, out_channel, kernel_size, downsample=False, blur_kernel=[1, 3, 3, 1], bias=True, activate=True, ): layers = [] if downsample: factor = 2 p = (len(blur_kernel) - factor) + (kernel_size - 1) pad0 = (p + 1) // 2 pad1 = p // 2 layers.append(Blur(blur_kernel, pad=(pad0, pad1))) stride = 2 self.padding = 0 else: stride = 1 self.padding = kernel_size // 2 layers.append( EqualConv2d( in_channel, out_channel, kernel_size, padding=self.padding, stride=stride, bias=bias and not activate, ) ) if activate: if bias: layers.append(FusedLeakyReLU(out_channel)) else: layers.append(ScaledLeakyReLU(0.2)) super().__init__(*layers) class ResBlock(nn.Module): def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): super().__init__() self.conv1 = ConvLayer(in_channel, in_channel, 3) self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) self.skip = ConvLayer( in_channel, out_channel, 1, downsample=True, activate=False, bias=False ) def forward(self, input): out = self.conv1(input) out = self.conv2(out) skip = self.skip(input) out = (out + skip) / math.sqrt(2) return out class Discriminator(nn.Module): def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): super().__init__() channels = { 4: 512, 8: 512, 16: 512, 32: 512, 64: 256 * channel_multiplier, 128: 128 * channel_multiplier, 256: 64 * channel_multiplier, 512: 32 * channel_multiplier, 1024: 16 * channel_multiplier, } convs = [ConvLayer(3, channels[size], 1)] log_size = int(math.log(size, 2)) in_channel = channels[size] for i in 
range(log_size, 2, -1): out_channel = channels[2 ** (i - 1)] convs.append(ResBlock(in_channel, out_channel, blur_kernel)) in_channel = out_channel self.convs = nn.Sequential(*convs) self.stddev_group = 4 self.stddev_feat = 1 self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) self.final_linear = nn.Sequential( EqualLinear(channels[4] * 4 * 4 // 2, channels[4], activation='fused_lrelu'), EqualLinear(channels[4], 1), ) def forward(self, input): out = self.convs(input) batch, channel, height, width = out.shape group = min(batch, self.stddev_group) stddev = out.view( group, -1, self.stddev_feat, channel // self.stddev_feat, height, width ) stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) stddev = stddev.repeat(group, 1, height, width) out = torch.cat([out, stddev], 1) out = self.final_conv(out) out = out.view(batch, -1) out = self.final_linear(out) return out ================================================ FILE: stylegan_human/pti/pti_models/e4e/stylegan2/op/__init__.py ================================================ from .fused_act import FusedLeakyReLU, fused_leaky_relu from .upfirdn2d import upfirdn2d ================================================ FILE: stylegan_human/pti/pti_models/e4e/stylegan2/op/fused_act.py ================================================ import os import torch from torch import nn from torch.nn import functional as F from torch.autograd import Function module_path = os.path.dirname(__file__) class FusedLeakyReLU(nn.Module): def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): super().__init__() self.bias = nn.Parameter(torch.zeros(channel)) self.negative_slope = negative_slope self.scale = scale def forward(self, input): return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): rest_dim = [1] * (input.ndim - bias.ndim - 1) input = input.cuda() return ( F.leaky_relu( input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope ) * scale ) ================================================ FILE: stylegan_human/pti/pti_models/e4e/stylegan2/op/fused_bias_act.cpp ================================================ #include torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, int act, int grad, float alpha, float scale); #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, int act, int grad, float alpha, float scale) { CHECK_CUDA(input); CHECK_CUDA(bias); return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); } ================================================ FILE: stylegan_human/pti/pti_models/e4e/stylegan2/op/fused_bias_act_kernel.cu ================================================ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. // // This work is made available under the Nvidia Source Code License-NC. 
// To view a copy of this license, visit // https://nvlabs.github.io/stylegan2/license.html #include #include #include #include #include #include #include template static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; scalar_t zero = 0.0; for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { scalar_t x = p_x[xi]; if (use_bias) { x += p_b[(xi / step_b) % size_b]; } scalar_t ref = use_ref ? p_ref[xi] : zero; scalar_t y; switch (act * 10 + grad) { default: case 10: y = x; break; case 11: y = x; break; case 12: y = 0.0; break; case 30: y = (x > 0.0) ? x : x * alpha; break; case 31: y = (ref > 0.0) ? x : x * alpha; break; case 32: y = 0.0; break; } out[xi] = y * scale; } } torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, int act, int grad, float alpha, float scale) { int curDevice = -1; cudaGetDevice(&curDevice); cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); auto x = input.contiguous(); auto b = bias.contiguous(); auto ref = refer.contiguous(); int use_bias = b.numel() ? 1 : 0; int use_ref = ref.numel() ? 1 : 0; int size_x = x.numel(); int size_b = b.numel(); int step_b = 1; for (int i = 1 + 1; i < x.dim(); i++) { step_b *= x.size(i); } int loop_x = 4; int block_size = 4 * 32; int grid_size = (size_x - 1) / (loop_x * block_size) + 1; auto y = torch::empty_like(x); AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { fused_bias_act_kernel<<>>( y.data_ptr(), x.data_ptr(), b.data_ptr(), ref.data_ptr(), act, grad, alpha, scale, loop_x, size_x, step_b, size_b, use_bias, use_ref ); }); return y; } ================================================ FILE: stylegan_human/pti/pti_models/e4e/stylegan2/op/upfirdn2d.cpp ================================================ #include torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, int up_x, int up_y, int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0, int pad_y1); #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, int up_x, int up_y, int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0, int pad_y1) { CHECK_CUDA(input); CHECK_CUDA(kernel); return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); } ================================================ FILE: stylegan_human/pti/pti_models/e4e/stylegan2/op/upfirdn2d.py ================================================ import os import torch from torch.nn import functional as F module_path = os.path.dirname(__file__) def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): out = upfirdn2d_native( input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] ) return out def upfirdn2d_native( input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 ): _, channel, in_h, in_w = input.shape input = input.reshape(-1, in_h, in_w, 1) _, in_h, in_w, minor = input.shape kernel_h, 
kernel_w = kernel.shape out = input.view(-1, in_h, 1, in_w, 1, minor) out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) out = out.view(-1, in_h * up_y, in_w * up_x, minor) out = F.pad( out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] ) out = out[ :, max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), :, ] out = out.permute(0, 3, 1, 2) out = out.reshape( [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] ) w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) out = F.conv2d(out, w) out = out.reshape( -1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, ) out = out.permute(0, 2, 3, 1) out = out[:, ::down_y, ::down_x, :] out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 return out.view(-1, channel, out_h, out_w) ================================================ FILE: stylegan_human/pti/pti_models/e4e/stylegan2/op/upfirdn2d_kernel.cu ================================================ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. // // This work is made available under the Nvidia Source Code License-NC. // To view a copy of this license, visit // https://nvlabs.github.io/stylegan2/license.html #include #include #include #include #include #include #include static __host__ __device__ __forceinline__ int floor_div(int a, int b) { int c = a / b; if (c * b > a) { c--; } return c; } struct UpFirDn2DKernelParams { int up_x; int up_y; int down_x; int down_y; int pad_x0; int pad_x1; int pad_y0; int pad_y1; int major_dim; int in_h; int in_w; int minor_dim; int kernel_h; int kernel_w; int out_h; int out_w; int loop_major; int loop_x; }; template __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) { const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; __shared__ volatile float sk[kernel_h][kernel_w]; __shared__ volatile float sx[tile_in_h][tile_in_w]; int minor_idx = blockIdx.x; int tile_out_y = minor_idx / p.minor_dim; minor_idx -= tile_out_y * p.minor_dim; tile_out_y *= tile_out_h; int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; int major_idx_base = blockIdx.z * p.loop_major; if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) { return; } for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) { int ky = tap_idx / kernel_w; int kx = tap_idx - ky * kernel_w; scalar_t v = 0.0; if (kx < p.kernel_w & ky < p.kernel_h) { v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; } sk[ky][kx] = v; } for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) { for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) { int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; int tile_in_x = floor_div(tile_mid_x, up_x); int tile_in_y = floor_div(tile_mid_y, up_y); __syncthreads(); for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) { int rel_in_y = in_idx / tile_in_w; int rel_in_x = in_idx - rel_in_y * tile_in_w; int in_x = rel_in_x + tile_in_x; int in_y = rel_in_y + 
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w> __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) { const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; __shared__ volatile float sk[kernel_h][kernel_w]; __shared__ volatile float sx[tile_in_h][tile_in_w]; int minor_idx = blockIdx.x; int tile_out_y = minor_idx / p.minor_dim; minor_idx -= tile_out_y * p.minor_dim; tile_out_y *= tile_out_h; int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; int major_idx_base = blockIdx.z * p.loop_major; if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) { return; } for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) { int ky = tap_idx / kernel_w; int kx = tap_idx - ky * kernel_w; scalar_t v = 0.0; if (kx < p.kernel_w & ky < p.kernel_h) { v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; } sk[ky][kx] = v; } for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) { for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) { int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; int tile_in_x = floor_div(tile_mid_x, up_x); int tile_in_y = floor_div(tile_mid_y, up_y); __syncthreads(); for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) { int rel_in_y = in_idx / tile_in_w; int rel_in_x = in_idx - rel_in_y * tile_in_w; int in_x = rel_in_x + tile_in_x; int in_y = rel_in_y + tile_in_y; scalar_t v = 0.0; if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; } sx[rel_in_y][rel_in_x] = v; } __syncthreads(); for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) { int rel_out_y = out_idx / tile_out_w; int rel_out_x = out_idx - rel_out_y * tile_out_w; int out_x = rel_out_x + tile_out_x; int out_y = rel_out_y + tile_out_y; int mid_x = tile_mid_x + rel_out_x * down_x; int mid_y = tile_mid_y + rel_out_y * down_y; int in_x = floor_div(mid_x, up_x); int in_y = floor_div(mid_y, up_y); int rel_in_x = in_x - tile_in_x; int rel_in_y = in_y - tile_in_y; int kernel_x = (in_x + 1) * up_x - mid_x - 1; int kernel_y = (in_y + 1) * up_y - mid_y - 1; scalar_t v = 0.0; #pragma unroll for (int y = 0; y < kernel_h / up_y; y++) #pragma unroll for (int x = 0; x < kernel_w / up_x; x++) v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x]; if (out_x < p.out_w & out_y < p.out_h) { out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; } } } } }
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, int up_x, int up_y, int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0, int pad_y1) { int curDevice = -1; cudaGetDevice(&curDevice); cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); UpFirDn2DKernelParams p; auto x = input.contiguous(); auto k = kernel.contiguous(); p.major_dim = x.size(0); p.in_h = x.size(1); p.in_w = x.size(2); p.minor_dim = x.size(3); p.kernel_h = k.size(0); p.kernel_w = k.size(1); p.up_x = up_x; p.up_y = up_y; p.down_x = down_x; p.down_y = down_y; p.pad_x0 = pad_x0; p.pad_x1 = pad_x1; p.pad_y0 = pad_y0; p.pad_y1 = pad_y1; p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y; p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x; auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); int mode = -1; int tile_out_h = -1; int tile_out_w = -1; if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { mode = 1; tile_out_h = 16; tile_out_w = 64; } if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) { mode = 2; tile_out_h = 16; tile_out_w = 64; } if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { mode = 3; tile_out_h = 16; tile_out_w = 64; } if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) { mode = 4; tile_out_h = 16; tile_out_w = 64; } if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) { mode = 5; tile_out_h = 8; tile_out_w = 32; } if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) { mode = 6; tile_out_h = 8; tile_out_w = 32; } dim3 block_size; dim3 grid_size; if (tile_out_h > 0 && tile_out_w > 0) { p.loop_major = (p.major_dim - 1) / 16384 + 1; p.loop_x = 1; block_size = dim3(32 * 8, 1, 1); grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, (p.major_dim - 1) / p.loop_major + 1); } AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { switch (mode) { case 1: upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>( out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p ); break; case 2: upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>( out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p ); break; case 3: upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>( out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p ); break; case 4: upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>( out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p ); break; case 5: upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>( out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p ); break; case 6: upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32><<<grid_size, block_size, 0, stream>>>( out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p ); break; } }); return out; }
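upfirdn2d ("upsample, FIR filter, downsample") is the resampling primitive behind StyleGAN2's blur, upsample, and downsample layers: zero-stuff by the up factor, pad, convolve with a flipped separable kernel, then keep every down-th sample. A small sanity-check sketch against the Python wrapper above; the import assumes the repo root is on sys.path, and the pad values follow the Upsample module in torch_utils/models.py further below:

import torch
from stylegan_human.pti.pti_models.e4e.stylegan2.op.upfirdn2d import upfirdn2d

# 2x upsample of a 1x1x4x4 image with a normalized [1, 3, 3, 1] separable
# kernel. For factor 2 and a length-4 kernel: p = 4 - 2 = 2, so
# pad = ((p + 1) // 2 + factor - 1, p // 2) = (2, 1), giving an exact 8x8 output.
k = torch.tensor([1.0, 3.0, 3.0, 1.0])
k = k[None, :] * k[:, None]
k = k / k.sum() * (2 ** 2)                      # scale by factor**2 for upsampling
x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
y = upfirdn2d(x, k, up=2, down=1, pad=(2, 1))
print(y.shape)                                  # torch.Size([1, 1, 8, 8])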
================================================ FILE: stylegan_human/pti/training/__init__.py ================================================ ================================================ FILE: stylegan_human/pti/training/coaches/__init__.py ================================================ ================================================ FILE: stylegan_human/pti/training/coaches/base_coach.py ================================================ import abc import os import pickle from argparse import Namespace import wandb import os.path from .localitly_regulizer import Space_Regulizer, l2_loss import torch from torchvision import transforms from lpips import LPIPS from pti.training.projectors import w_projector from pti.pti_configs import global_config, paths_config, hyperparameters from pti.pti_models.e4e.psp import pSp from utils.log_utils import log_image_from_w from utils.models_utils import toogle_grad, load_old_G # Shared PTI logic: invert each image to a pivot latent, then fine-tune the generator around that pivot. class BaseCoach: def __init__(self, data_loader, use_wandb): self.use_wandb = use_wandb self.data_loader = data_loader self.w_pivots = {} self.image_counter = 0 if hyperparameters.first_inv_type == 'w+': self.initilize_e4e() self.e4e_image_transform = transforms.Compose([ transforms.ToPILImage(), transforms.Resize((256, 128)), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) # Initialize loss self.lpips_loss = LPIPS(net=hyperparameters.lpips_type).to(global_config.device).eval() self.restart_training() # Initialize checkpoint dir self.checkpoint_dir = paths_config.checkpoints_dir os.makedirs(self.checkpoint_dir, exist_ok=True) def restart_training(self): # Initialize networks self.G = load_old_G() toogle_grad(self.G, True) self.original_G = load_old_G() self.space_regulizer = Space_Regulizer(self.original_G, self.lpips_loss) self.optimizer = self.configure_optimizers() def get_inversion(self, w_path_dir, image_name, image): embedding_dir = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}' os.makedirs(embedding_dir, exist_ok=True) w_pivot = None if hyperparameters.use_last_w_pivots: w_pivot = self.load_inversions(w_path_dir, image_name) if not hyperparameters.use_last_w_pivots or w_pivot is None: w_pivot = self.calc_inversions(image, image_name) torch.save(w_pivot, f'{embedding_dir}/0.pt') w_pivot = w_pivot.to(global_config.device) return w_pivot def load_inversions(self, w_path_dir, image_name): if image_name in self.w_pivots: return self.w_pivots[image_name] if hyperparameters.first_inv_type == 'w+': w_potential_path = f'{w_path_dir}/{paths_config.e4e_results_keyword}/{image_name}/0.pt' else: w_potential_path = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}/0.pt' if not os.path.isfile(w_potential_path): return None w = torch.load(w_potential_path).to(global_config.device) self.w_pivots[image_name] = w return w def calc_inversions(self, image, image_name): if hyperparameters.first_inv_type == 'w+': w = self.get_e4e_inversion(image) else: id_image = torch.squeeze((image.to(global_config.device) + 1) / 2) * 255 w = w_projector.project(self.G, id_image, device=torch.device(global_config.device), w_avg_samples=600, num_steps=hyperparameters.first_inv_steps, w_name=image_name, use_wandb=self.use_wandb) return w @abc.abstractmethod def train(self): pass def configure_optimizers(self): optimizer = torch.optim.Adam(self.G.parameters(), lr=hyperparameters.pti_learning_rate) return optimizer def calc_loss(self, generated_images, real_images, log_name, new_G, use_ball_holder, w_batch): loss = 0.0 if hyperparameters.pt_l2_lambda > 0: l2_loss_val = l2_loss(generated_images, real_images) if self.use_wandb: wandb.log({f'MSE_loss_val_{log_name}': l2_loss_val.detach().cpu()}, step=global_config.training_step) loss += l2_loss_val * hyperparameters.pt_l2_lambda if hyperparameters.pt_lpips_lambda > 0: loss_lpips = self.lpips_loss(generated_images, real_images) loss_lpips = torch.squeeze(loss_lpips) if self.use_wandb: wandb.log({f'LPIPS_loss_val_{log_name}': loss_lpips.detach().cpu()}, step=global_config.training_step) loss += loss_lpips * hyperparameters.pt_lpips_lambda if use_ball_holder and hyperparameters.use_locality_regularization: ball_holder_loss_val = self.space_regulizer.space_regulizer_loss(new_G, w_batch, use_wandb=self.use_wandb) loss += ball_holder_loss_val return loss, l2_loss_val, loss_lpips def forward(self, w): generated_images = self.G.synthesis(w, noise_mode='const', force_fp32=True) return generated_images def initilize_e4e(self): ckpt = torch.load(paths_config.e4e, map_location='cpu') opts = ckpt['opts'] opts['batch_size'] = hyperparameters.train_batch_size opts['checkpoint_path'] = paths_config.e4e opts = Namespace(**opts) self.e4e_inversion_net = pSp(opts) self.e4e_inversion_net.eval() self.e4e_inversion_net = self.e4e_inversion_net.to(global_config.device) toogle_grad(self.e4e_inversion_net, False) def get_e4e_inversion(self, image): image = (image + 1) / 2 new_image = self.e4e_image_transform(image[0]).to(global_config.device) _, w = self.e4e_inversion_net(new_image.unsqueeze(0), randomize_noise=False, return_latents=True, resize=False, input_code=False) if self.use_wandb: log_image_from_w(w, self.G, 'First e4e inversion') return w
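The coaches that follow implement pivotal tuning inversion: the pivot latent stays frozen while the generator weights are optimized until G(w_pivot) reproduces the target image. A toy, dependency-free sketch of that inner loop; the linear "generator", sizes, and learning rate are stand-ins, and the real coaches add an LPIPS term and the locality regularizer on top of the L2 term:

import torch

# Toy stand-ins: in the coaches, G is StyleGAN2's synthesis network.
G = torch.nn.Sequential(torch.nn.Linear(512, 3 * 8 * 8))   # hypothetical toy generator
w_pivot = torch.randn(1, 512)                              # frozen pivot latent (no grad)
real = torch.rand(1, 3 * 8 * 8)                            # target image, flattened

opt = torch.optim.Adam(G.parameters(), lr=3e-4)            # tune G's weights, not w
for step in range(100):
    fake = G(w_pivot)
    loss = torch.nn.functional.mse_loss(fake, real)        # the pt_l2_lambda term only
    opt.zero_grad()
    loss.backward()
    opt.step()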
================================================ FILE: stylegan_human/pti/training/coaches/localitly_regulizer.py ================================================ import torch import numpy as np import wandb from pti.pti_configs import hyperparameters, global_config l2_criterion = torch.nn.MSELoss(reduction='mean') def l2_loss(real_images, generated_images): loss = l2_criterion(real_images, generated_images) return loss class Space_Regulizer: def __init__(self, original_G, lpips_net): self.original_G = original_G self.morphing_regulizer_alpha = hyperparameters.regulizer_alpha self.lpips_loss = lpips_net def get_morphed_w_code(self, new_w_code, fixed_w): interpolation_direction = new_w_code - fixed_w interpolation_direction_norm = torch.norm(interpolation_direction, p=2) direction_to_move = hyperparameters.regulizer_alpha * interpolation_direction / interpolation_direction_norm # Take an alpha-scaled unit step from fixed_w toward new_w_code. result_w = fixed_w + direction_to_move return result_w def get_image_from_ws(self, w_codes, G): return torch.cat([G.synthesis(w_code, noise_mode='none', force_fp32=True) for w_code in w_codes]) def ball_holder_loss_lazy(self, new_G, num_of_sampled_latents, w_batch, use_wandb=False): loss = 0.0 z_samples = np.random.randn(num_of_sampled_latents, self.original_G.z_dim) w_samples = self.original_G.mapping(torch.from_numpy(z_samples).to(global_config.device), None, truncation_psi=0.5) territory_indicator_ws = [self.get_morphed_w_code(w_code.unsqueeze(0), w_batch) for w_code in w_samples] for w_code in territory_indicator_ws: new_img = new_G.synthesis(w_code, noise_mode='none', force_fp32=True) with torch.no_grad(): old_img = self.original_G.synthesis(w_code, noise_mode='none', force_fp32=True) if hyperparameters.regulizer_l2_lambda > 0: l2_loss_val = l2_loss(old_img, new_img) # module-level l2_loss defined above if use_wandb: wandb.log({f'space_regulizer_l2_loss_val': l2_loss_val.detach().cpu()}, step=global_config.training_step) loss += l2_loss_val * hyperparameters.regulizer_l2_lambda if hyperparameters.regulizer_lpips_lambda > 0: loss_lpips = self.lpips_loss(old_img, new_img) loss_lpips = torch.mean(torch.squeeze(loss_lpips)) if use_wandb: wandb.log({f'space_regulizer_lpips_loss_val': loss_lpips.detach().cpu()}, step=global_config.training_step) loss += loss_lpips * hyperparameters.regulizer_lpips_lambda return loss / len(territory_indicator_ws) def space_regulizer_loss(self, new_G, w_batch, use_wandb): ret_val = self.ball_holder_loss_lazy(new_G, hyperparameters.latent_ball_num_of_samples, w_batch, use_wandb) return ret_val ================================================ FILE: stylegan_human/pti/training/coaches/multi_id_coach.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. import os import torch from tqdm import tqdm from pti.pti_configs import paths_config, hyperparameters, global_config from pti.training.coaches.base_coach import BaseCoach from utils.log_utils import log_images_from_w class MultiIDCoach(BaseCoach): def __init__(self, data_loader, use_wandb): super().__init__(data_loader, use_wandb) def train(self): self.G.synthesis.train() self.G.mapping.train() w_path_dir = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}' os.makedirs(w_path_dir, exist_ok=True) os.makedirs(f'{w_path_dir}/{paths_config.pti_results_keyword}', exist_ok=True) use_ball_holder = True w_pivots = [] images = [] for fname, image in self.data_loader: if self.image_counter >= hyperparameters.max_images_to_invert: break image_name = fname[0] if hyperparameters.first_inv_type == 'w+': embedding_dir = f'{w_path_dir}/{paths_config.e4e_results_keyword}/{image_name}' else: embedding_dir = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}' os.makedirs(embedding_dir, exist_ok=True) w_pivot = self.get_inversion(w_path_dir, image_name, image) w_pivots.append(w_pivot) images.append((image_name, image)) self.image_counter += 1 for i in tqdm(range(hyperparameters.max_pti_steps)): self.image_counter = 0 for data, w_pivot in zip(images, w_pivots): image_name, image = data if self.image_counter >= hyperparameters.max_images_to_invert: break real_images_batch = image.to(global_config.device) generated_images = self.forward(w_pivot) loss, l2_loss_val, loss_lpips = self.calc_loss(generated_images, real_images_batch, image_name, self.G, use_ball_holder, w_pivot) self.optimizer.zero_grad() loss.backward() self.optimizer.step() use_ball_holder = global_config.training_step % hyperparameters.locality_regularization_interval == 0 global_config.training_step += 1 self.image_counter += 1 if self.use_wandb: log_images_from_w(w_pivots, self.G, [image[0] for image in images]) # torch.save(self.G, # f'{paths_config.checkpoints_dir}/model_{global_config.run_name}_multi_id.pt') snapshot_data = dict() snapshot_data['G_ema'] = self.G import pickle with open(f'{paths_config.checkpoints_dir}/model_{global_config.run_name}_multi_id.pkl', 'wb') as f: pickle.dump(snapshot_data, f)
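Both coaches snapshot the tuned generator as a plain pickle holding a {'G_ema': G} dict, mirroring the StyleGAN training-snapshot format. A minimal loading sketch; the file name is hypothetical, and unpickling requires the repo's torch_utils/dnnlib packages to be importable, since the network modules are persisted by reference:

import pickle
import torch

# Hypothetical path; the coaches write model_<run_name>_multi_id.pkl or
# model_<image_name>.pkl under paths_config.checkpoints_dir.
with open('checkpoints/model_XXXXXXXXXXXX_multi_id.pkl', 'rb') as f:
    G = pickle.load(f)['G_ema']            # same key the coaches write

G = G.eval().requires_grad_(False)
z = torch.randn([1, G.z_dim])
w = G.mapping(z, None, truncation_psi=0.7)
img = G.synthesis(w, noise_mode='const', force_fp32=True)  # image batch in [-1, 1]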
================================================ FILE: stylegan_human/pti/training/coaches/single_id_coach.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. import os import torch from tqdm import tqdm from pti.pti_configs import paths_config, hyperparameters, global_config from pti.training.coaches.base_coach import BaseCoach from utils.log_utils import log_images_from_w from torchvision.utils import save_image class SingleIDCoach(BaseCoach): def __init__(self, data_loader, use_wandb): super().__init__(data_loader, use_wandb) def train(self): w_path_dir = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}' os.makedirs(w_path_dir, exist_ok=True) os.makedirs(f'{w_path_dir}/{paths_config.pti_results_keyword}', exist_ok=True) use_ball_holder = True for fname, image in tqdm(self.data_loader): image_name = fname[0] self.restart_training() if self.image_counter >= hyperparameters.max_images_to_invert: break embedding_dir = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}' os.makedirs(embedding_dir, exist_ok=True) w_pivot = None if hyperparameters.use_last_w_pivots: w_pivot = self.load_inversions(w_path_dir, image_name) if w_pivot is None: # re-invert both when pivots are not reused and when a cached pivot failed to load w_pivot = self.calc_inversions(image, image_name) # w_pivot = w_pivot.detach().clone().to(global_config.device) w_pivot = w_pivot.to(global_config.device) torch.save(w_pivot, f'{embedding_dir}/0.pt') log_images_counter = 0 real_images_batch = image.to(global_config.device) for i in range(hyperparameters.max_pti_steps): generated_images = self.forward(w_pivot) loss, l2_loss_val, loss_lpips = self.calc_loss(generated_images, real_images_batch, image_name, self.G, use_ball_holder, w_pivot) if i == 0: tmp1 = torch.clone(generated_images) if i % 10 == 0: print("pti loss: ", i, loss.data, loss_lpips.data) self.optimizer.zero_grad() if loss_lpips <= hyperparameters.LPIPS_value_threshold: break loss.backward() self.optimizer.step() use_ball_holder = global_config.training_step % hyperparameters.locality_regularization_interval == 0 if self.use_wandb and log_images_counter % global_config.image_rec_result_log_snapshot == 0: log_images_from_w([w_pivot], self.G, [image_name]) global_config.training_step += 1 log_images_counter += 1 # save output image tmp = torch.cat([real_images_batch, tmp1, generated_images], axis=3) save_image(tmp, f"{paths_config.experiments_output_dir}/{image_name}.png", normalize=True) self.image_counter += 1 # torch.save(self.G, # f'{paths_config.checkpoints_dir}/model_{image_name}.pt') #'.pt' snapshot_data = dict() snapshot_data['G_ema'] = self.G import pickle with open(f'{paths_config.checkpoints_dir}/model_{image_name}.pkl', 'wb') as f: pickle.dump(snapshot_data, f) ================================================ FILE: stylegan_human/pti/training/projectors/__init__.py ================================================
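A pivot is produced by BaseCoach.calc_inversions, which for 'w'-type inversions rescales the [-1, 1] image tensor to [0, 255] and hands it to w_projector.project from the projectors package below. A condensed sketch of that call path; invert_one is a name of mine, and the imports assume the scripts run from the stylegan_human directory, as run_pti.py does:

import torch
from pti.training.projectors import w_projector
from pti.pti_configs import global_config, hyperparameters

def invert_one(G, image, image_name):
    # Scale from [-1, 1] to [0, 255], as calc_inversions does before projecting.
    id_image = torch.squeeze((image.to(global_config.device) + 1) / 2) * 255
    return w_projector.project(
        G, id_image,
        device=torch.device(global_config.device),
        w_avg_samples=600,
        num_steps=hyperparameters.first_inv_steps,
        w_name=image_name,
    )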
================================================ FILE: stylegan_human/pti/training/projectors/w_plus_projector.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Project given image to the latent space of pretrained network pickle.""" import copy import wandb import numpy as np import torch import torch.nn.functional as F from tqdm import tqdm from pti.pti_configs import global_config, hyperparameters import dnnlib from utils.log_utils import log_image_from_w def project( G, target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution *, num_steps=1000, w_avg_samples=10000, initial_learning_rate=0.01, initial_noise_factor=0.05, lr_rampdown_length=0.25, lr_rampup_length=0.05, noise_ramp_length=0.75, regularize_noise_weight=1e5, verbose=False, device: torch.device, use_wandb=False, initial_w=None, image_log_step=global_config.image_rec_result_log_snapshot, w_name: str ): print('inside training/projectors/w_plus_projector') print(target.shape, G.img_channels, G.img_resolution * 2, G.img_resolution) assert target.shape == (G.img_channels, G.img_resolution * 2, G.img_resolution) def logprint(*args): if verbose: print(*args) G = copy.deepcopy(G).eval().requires_grad_(False).to(device).float() # type: ignore # Compute w stats. logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...') z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim) w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None) # [N, L, C] w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C] w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C] w_avg_tensor = torch.from_numpy(w_avg).to(global_config.device) w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5 start_w = initial_w if initial_w is not None else w_avg # Setup noise inputs. noise_bufs = {name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name} # Load VGG16 feature detector. url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' with dnnlib.util.open_url(url) as f: vgg16 = torch.jit.load(f).eval().to(device) # Features for target image. target_images = target.unsqueeze(0).to(device).to(torch.float32) if target_images.shape[2] > 256: target_images = F.interpolate(target_images, size=(256, 256), mode='area') target_features = vgg16(target_images, resize_images=False, return_lpips=True) start_w = np.repeat(start_w, G.mapping.num_ws, axis=1) # w+ space: one code per style layer w_opt = torch.tensor(start_w, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=hyperparameters.first_inv_lr) # Init noise. for buf in noise_bufs.values(): buf[:] = torch.randn_like(buf) buf.requires_grad = True for step in tqdm(range(num_steps)): # Learning rate schedule. t = step / num_steps w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2 lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length) lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi) lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length) lr = initial_learning_rate * lr_ramp for param_group in optimizer.param_groups: param_group['lr'] = lr # Synth images from opt_w. w_noise = torch.randn_like(w_opt) * w_noise_scale ws = (w_opt + w_noise) synth_images = G.synthesis(ws, noise_mode='const', force_fp32=True) # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
synth_images = (synth_images + 1) * (255 / 2) if synth_images.shape[2] > 256: synth_images = F.interpolate(synth_images, size=(256, 256), mode='area') # Features for synth images. synth_features = vgg16(synth_images, resize_images=False, return_lpips=True) dist = (target_features - synth_features).square().sum() # Noise regularization. reg_loss = 0.0 for v in noise_bufs.values(): noise = v[None, None, :, :] # must be [1,1,H,W] for F.avg_pool2d() while True: reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2 reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2 if noise.shape[2] <= 8: break noise = F.avg_pool2d(noise, kernel_size=2) loss = dist + reg_loss * regularize_noise_weight if step % image_log_step == 0: with torch.no_grad(): if use_wandb: global_config.training_step += 1 wandb.log({f'first projection _{w_name}': loss.detach().cpu()}, step=global_config.training_step) log_image_from_w(w_opt, G, w_name) # Step optimizer.zero_grad(set_to_none=True) loss.backward() optimizer.step() logprint(f'step {step + 1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}') # Normalize noise. with torch.no_grad(): for buf in noise_bufs.values(): buf -= buf.mean() buf *= buf.square().mean().rsqrt() del G return w_opt # optimized w+ code, shape [1, num_ws, C]
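The noise-regularization term above discourages the per-layer noise maps from absorbing image content: at each scale it penalizes the squared mean correlation between the map and its one-pixel shifts, then halves the resolution, down to 8x8. The same loop reappears in w_projector.py below. A standalone restatement of just that term (the function name is mine):

import torch
import torch.nn.functional as F

def noise_reg(noise_bufs):
    # Accumulate shift-correlation penalties over an average-pooled pyramid,
    # stopping once the map is 8x8 or smaller, exactly as in the loop above.
    reg = 0.0
    for v in noise_bufs.values():
        noise = v[None, None, :, :]                           # [1, 1, H, W]
        while True:
            reg = reg + (noise * torch.roll(noise, 1, dims=3)).mean() ** 2
            reg = reg + (noise * torch.roll(noise, 1, dims=2)).mean() ** 2
            if noise.shape[2] <= 8:
                break
            noise = F.avg_pool2d(noise, kernel_size=2)
    return reg

bufs = {'b0': torch.randn(16, 16)}                            # toy noise buffer
print(float(noise_reg(bufs)))                                 # near zero for white noise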
================================================ FILE: stylegan_human/pti/training/projectors/w_projector.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Project given image to the latent space of pretrained network pickle.""" import copy import wandb import numpy as np import torch import torch.nn.functional as F from tqdm import tqdm from pti.pti_configs import global_config, hyperparameters from utils import log_utils import dnnlib def project( G, target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution *, num_steps=1000, w_avg_samples=10000, initial_learning_rate=0.01, initial_noise_factor=0.05, lr_rampdown_length=0.25, lr_rampup_length=0.05, noise_ramp_length=0.75, regularize_noise_weight=1e5, verbose=False, device: torch.device, use_wandb=False, initial_w=None, image_log_step=global_config.image_rec_result_log_snapshot, w_name: str ): print(target.shape, G.img_channels, G.img_resolution, G.img_resolution // 2) assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution // 2) def logprint(*args): if verbose: print(*args) G = copy.deepcopy(G).eval().requires_grad_(False).to(device).float() # type: ignore # Compute w stats. logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...') z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim) w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None) # [N, L, C] w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C] w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C] w_avg_tensor = torch.from_numpy(w_avg).to(global_config.device) w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5 start_w = initial_w if initial_w is not None else w_avg # Setup noise inputs. noise_bufs = {name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name} # Load VGG16 feature detector. url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' with dnnlib.util.open_url(url) as f: vgg16 = torch.jit.load(f).eval().to(device) # Features for target image. target_images = target.unsqueeze(0).to(device).to(torch.float32) if target_images.shape[2] > 256: target_images = F.interpolate(target_images, size=(256, 256), mode='area') target_features = vgg16(target_images, resize_images=False, return_lpips=True) w_opt = torch.tensor(start_w, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=hyperparameters.first_inv_lr) # Init noise. for buf in noise_bufs.values(): buf[:] = torch.randn_like(buf) buf.requires_grad = True for step in range(num_steps): # Learning rate schedule. t = step / num_steps w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2 lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length) lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi) lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length) lr = initial_learning_rate * lr_ramp for param_group in optimizer.param_groups: param_group['lr'] = lr # Synth images from opt_w. w_noise = torch.randn_like(w_opt) * w_noise_scale ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1]) # single w broadcast to every style layer (w space) synth_images = G.synthesis(ws, noise_mode='const', force_fp32=True) # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images. synth_images = (synth_images + 1) * (255 / 2) if synth_images.shape[2] > 256: synth_images = F.interpolate(synth_images, size=(256, 256), mode='area') # Features for synth images. synth_features = vgg16(synth_images, resize_images=False, return_lpips=True) dist = (target_features - synth_features).square().sum() # Noise regularization. reg_loss = 0.0 for v in noise_bufs.values(): noise = v[None, None, :, :] # must be [1,1,H,W] for F.avg_pool2d() while True: reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2 reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2 if noise.shape[2] <= 8: break noise = F.avg_pool2d(noise, kernel_size=2) loss = dist + reg_loss * regularize_noise_weight if step % 10 == 0: print("project loss", step, loss.data) if step % image_log_step == 0: with torch.no_grad(): if use_wandb: global_config.training_step += 1 wandb.log({f'first projection _{w_name}': loss.detach().cpu()}, step=global_config.training_step) log_utils.log_image_from_w(w_opt.repeat([1, G.mapping.num_ws, 1]), G, w_name) # Step optimizer.zero_grad(set_to_none=True) loss.backward() optimizer.step() logprint(f'step {step + 1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}') # Normalize noise. with torch.no_grad(): for buf in noise_bufs.values(): buf -= buf.mean() buf *= buf.square().mean().rsqrt() del G return w_opt.repeat([1, 18, 1]) # broadcast to all 18 style layers (G.mapping.num_ws at 1024 resolution) ================================================ FILE: stylegan_human/run_pti.py ================================================ # Copyright (c) SenseTime Research. All rights reserved.
from random import choice from string import ascii_uppercase from torch.utils.data import DataLoader from torchvision.transforms import transforms import os from pti.pti_configs import global_config, paths_config import wandb from pti.training.coaches.multi_id_coach import MultiIDCoach from pti.training.coaches.single_id_coach import SingleIDCoach from utils.ImagesDataset import ImagesDataset def run_PTI(run_name='', use_wandb=False, use_multi_id_training=False): os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' os.environ['CUDA_VISIBLE_DEVICES'] = global_config.cuda_visible_devices if run_name == '': global_config.run_name = ''.join(choice(ascii_uppercase) for i in range(12)) else: global_config.run_name = run_name if use_wandb: run = wandb.init(project=paths_config.pti_results_keyword, reinit=True, name=global_config.run_name) global_config.pivotal_training_steps = 1 global_config.training_step = 1 embedding_dir_path = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}/{paths_config.pti_results_keyword}' # print('embedding_dir_path: ', embedding_dir_path) #./embeddings/barcelona/PTI os.makedirs(embedding_dir_path, exist_ok=True) dataset = ImagesDataset(paths_config.input_data_path, transforms.Compose([ transforms.Resize((1024, 512)), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])) dataloader = DataLoader(dataset, batch_size=1, shuffle=False) if use_multi_id_training: coach = MultiIDCoach(dataloader, use_wandb) else: coach = SingleIDCoach(dataloader, use_wandb) coach.train() return global_config.run_name if __name__ == '__main__': run_PTI(run_name='', use_wandb=False, use_multi_id_training=False) ================================================ FILE: stylegan_human/style_mixing.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. # import os import re from typing import List import legacy import click import dnnlib import numpy as np import PIL.Image import torch """ Style mixing using pretrained network pickle. 
Examples: \b python style_mixing.py --network=pretrained_models/stylegan_human_v2_1024.pkl --rows=85,100,75,458,1500 \\ --cols=55,821,1789,293 --styles=0-3 --outdir=outputs/stylemixing """ @click.command() @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) @click.option('--rows', 'row_seeds', type=legacy.num_range, help='Random seeds to use for image rows', required=True) @click.option('--cols', 'col_seeds', type=legacy.num_range, help='Random seeds to use for image columns', required=True) @click.option('--styles', 'col_styles', type=legacy.num_range, help='Style layer range', default='0-6', show_default=True) @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=0.8, show_default=True) @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True) @click.option('--outdir', type=str, required=True, default='outputs/stylemixing') def generate_style_mix( network_pkl: str, row_seeds: List[int], col_seeds: List[int], col_styles: List[int], truncation_psi: float, noise_mode: str, outdir: str ): print('Loading networks from "%s"...' % network_pkl) device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu') dtype = torch.float32 if device.type == 'mps' else torch.float64 with dnnlib.util.open_url(network_pkl) as f: G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype) os.makedirs(outdir, exist_ok=True) print('Generating W vectors...') all_seeds = list(set(row_seeds + col_seeds)) all_z = np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds]) all_w = G.mapping(torch.from_numpy(all_z).to(device, dtype=dtype), None) w_avg = G.mapping.w_avg all_w = w_avg + (all_w - w_avg) * truncation_psi w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))} print('Generating images...') all_images = G.synthesis(all_w, noise_mode=noise_mode) all_images = (all_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy() image_dict = {(seed, seed): image for seed, image in zip(all_seeds, list(all_images))} print('Generating style-mixed images...') for row_seed in row_seeds: for col_seed in col_seeds: w = w_dict[row_seed].clone() w[col_styles] = w_dict[col_seed][col_styles] image = G.synthesis(w[np.newaxis], noise_mode=noise_mode) image = (image.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) image_dict[(row_seed, col_seed)] = image[0].cpu().numpy() os.makedirs(outdir, exist_ok=True) # print('Saving images...') # for (row_seed, col_seed), image in image_dict.items(): # PIL.Image.fromarray(image, 'RGB').save(f'{outdir}/{row_seed}-{col_seed}.png') print('Saving image grid...') W = G.img_resolution // 2 H = G.img_resolution canvas = PIL.Image.new('RGB', (W * (len(col_seeds) + 1), H * (len(row_seeds) + 1)), 'black') for row_idx, row_seed in enumerate([0] + row_seeds): for col_idx, col_seed in enumerate([0] + col_seeds): if row_idx == 0 and col_idx == 0: continue key = (row_seed, col_seed) if row_idx == 0: key = (col_seed, col_seed) if col_idx == 0: key = (row_seed, row_seed) canvas.paste(PIL.Image.fromarray(image_dict[key], 'RGB'), (W * col_idx, H * row_idx)) canvas.save(f'{outdir}/grid.png') #---------------------------------------------------------------------------- if __name__ == "__main__": generate_style_mix() # pylint: disable=no-value-for-parameter #---------------------------------------------------------------------------- 
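Underneath the CLI, the style-mixing operation above is a single indexed copy between w tensors: the chosen style-layer rows of the column seed's latent replace those of the row seed's latent before synthesis. A toy sketch with random stand-in tensors; 18 layers and 512 channels match the 1024-resolution generator:

import torch

num_ws, w_dim = 18, 512
w_row = torch.randn(num_ws, w_dim)      # structure/source latent (row seed)
w_col = torch.randn(num_ws, w_dim)      # style latent (column seed)

col_styles = list(range(0, 7))          # matches the script's default --styles=0-6
w_mix = w_row.clone()
w_mix[col_styles] = w_col[col_styles]   # overwrite the chosen layers with column style

# w_mix[None] (shape [1, 18, 512]) would then be fed to G.synthesis(...)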
================================================ FILE: stylegan_human/stylemixing_video.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. """Here we demo style-mixing results using a pretrained StyleGAN2 model. Script reference: https://github.com/PDillis/stylegan2-fun """ import argparse import legacy import scipy import numpy as np import PIL.Image import dnnlib import dnnlib.tflib as tflib from typing import List import re import sys import os import click import torch os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide" import moviepy.editor """ Generate style mixing video. Examples: \b python stylemixing_video.py --network=pretrained_models/stylegan_human_v2_1024.pkl --row-seed=3859 \\ --col-seeds=3098,31759,3791 --col-styles=8-12 --trunc=0.8 --outdir=outputs/stylemixing_video """ @click.command() @click.option('--network', 'network_pkl', help='Path to network pickle filename', required=True) @click.option('--row-seed', 'src_seed', type=legacy.num_range, help='Random seed to use for image source row', required=True) @click.option('--col-seeds', 'dst_seeds', type=legacy.num_range, help='Random seeds to use for image columns (style)', required=True) @click.option('--col-styles', 'col_styles', type=legacy.num_range, help='Style layer range (default: %(default)s)', default='0-6') @click.option('--only-stylemix', 'only_stylemix', help='Add flag to only show the style-mixed images in the video', default=False) @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi (default: %(default)s)', default=1) @click.option('--duration-sec', 'duration_sec', type=float, help='Duration of video (default: %(default)s)', default=10) @click.option('--fps', 'mp4_fps', type=int, help='FPS of generated video (default: %(default)s)', default=10) @click.option('--indent-range', 'indent_range', type=int, default=30) @click.option('--outdir', help='Root directory for run results (default: %(default)s)', default='outputs/stylemixing_video', metavar='DIR') def style_mixing_video(network_pkl: str, src_seed: List[int], # Seed of the source image style (row) dst_seeds: List[int], # Seeds of the destination image styles (columns) col_styles: List[int], # Styles to transfer from first row to first column truncation_psi: float = 1.0, only_stylemix: bool = False, # True if the user wishes to show only the style-transferred result duration_sec: float = 10.0, smoothing_sec=1.0, mp4_fps: int = 10, mp4_codec="libx264", mp4_bitrate="16M", minibatch_size=8, noise_mode='const', indent_range: int = 30, outdir: str = 'outputs/stylemixing_video'): # Calculate the number of frames: print('col_seeds: ', dst_seeds) num_frames = int(np.rint(duration_sec * mp4_fps)) print('Loading networks from "%s"...'
% network_pkl) device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu') dtype = torch.float32 if device.type == 'mps' else torch.float64 with dnnlib.util.open_url(network_pkl) as f: Gs = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype) print(Gs.num_ws, Gs.w_dim, Gs.img_resolution) max_style = int(2 * np.log2(Gs.img_resolution)) - 3 assert max(col_styles) <= max_style, f"Maximum col-style allowed: {max_style}" # Left col latents print('Generating Source W vectors...') src_shape = [num_frames] + [Gs.z_dim] src_z = np.random.RandomState(*src_seed).randn(*src_shape).astype(np.float32) # [frames, src, component] src_z = scipy.ndimage.gaussian_filter(src_z, [smoothing_sec * mp4_fps] + [0] * (2 - 1), mode="wrap") src_z /= np.sqrt(np.mean(np.square(src_z))) # Map into the disentangled latent space W and apply the truncation trick src_w = Gs.mapping(torch.from_numpy(src_z).to(device, dtype=dtype), None) w_avg = Gs.mapping.w_avg src_w = w_avg + (src_w - w_avg) * truncation_psi # Top row latents (fixed reference) print('Generating Destination W vectors...') dst_z = np.stack([np.random.RandomState(seed).randn(Gs.z_dim) for seed in dst_seeds]) dst_w = Gs.mapping(torch.from_numpy(dst_z).to(device, dtype=dtype), None) dst_w = w_avg + (dst_w - w_avg) * truncation_psi # Get the width and height of each image: H = Gs.img_resolution # 1024 W = Gs.img_resolution // 2 # 512 # Generate ALL the source images: src_images = Gs.synthesis(src_w, noise_mode=noise_mode) src_images = (src_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) # Generate the column images: dst_images = Gs.synthesis(dst_w, noise_mode=noise_mode) dst_images = (dst_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) print('Generating full video (including source and destination images)') # Generate our canvas where we will paste all the generated images: canvas = PIL.Image.new("RGB", ((W - indent_range) * (len(dst_seeds) + 1), H * (len(src_seed) + 1)), "white") # W, H for col, dst_image in enumerate(list(dst_images)): # dst_image: [3, 1024, 512] canvas.paste(PIL.Image.fromarray(dst_image.cpu().numpy(), "RGB"), ((col + 1) * (W - indent_range), 0)) # H # Aux functions: Frame generation func for moviepy.
def make_frame(t): # Get the frame number according to time t: frame_idx = int(np.clip(np.round(t * mp4_fps), 0, num_frames - 1)) # We want the image belonging to the frame at time t: src_image = src_images[frame_idx] # always in the same place canvas.paste(PIL.Image.fromarray(src_image.cpu().numpy(), "RGB"), (0 - indent_range, H)) # Paste it to the lower left # Now, for each of the column images: for col, dst_image in enumerate(list(dst_images)): # Select the pertinent latent w column: w_col = np.stack([dst_w[col].cpu()]) # [18, 512] -> [1, 18, 512] w_col = torch.from_numpy(w_col).to(device, dtype=dtype) # Replace the values defined by col_styles: w_col[:, col_styles] = src_w[frame_idx, col_styles] # .cpu() # Generate these synthesized images: col_images = Gs.synthesis(w_col, noise_mode=noise_mode) col_images = (col_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) # Paste them in their respective spot: for row, image in enumerate(list(col_images)): canvas.paste( PIL.Image.fromarray(image.cpu().numpy(), "RGB"), ((col + 1) * (W - indent_range), (row + 1) * H), ) return np.array(canvas) # Generate video using make_frame: print('Generating style-mixed video...') videoclip = moviepy.editor.VideoClip(make_frame, duration=duration_sec) grid_size = [len(dst_seeds), len(src_seed)] mp4 = "{}x{}-style-mixing_{}_{}.mp4".format(*grid_size, min(col_styles), max(col_styles)) if not os.path.exists(outdir): os.makedirs(outdir) videoclip.write_videofile(os.path.join(outdir, mp4), fps=mp4_fps, codec=mp4_codec, bitrate=mp4_bitrate) if __name__ == "__main__": style_mixing_video() ================================================ FILE: stylegan_human/torch_utils/__init__.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. # empty ================================================ FILE: stylegan_human/torch_utils/custom_ops.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. import os import glob import torch import torch.utils.cpp_extension import importlib import hashlib import shutil from pathlib import Path import re import uuid from torch.utils.file_baton import FileBaton #---------------------------------------------------------------------------- # Global options. verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' #---------------------------------------------------------------------------- # Internal helper funcs.
def _find_compiler_bindir(): patterns = [ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', ] for pattern in patterns: matches = sorted(glob.glob(pattern)) if len(matches): return matches[-1] return None def _get_mangled_gpu_name(): name = torch.cuda.get_device_name().lower() out = [] for c in name: if re.match('[a-z0-9_-]+', c): out.append(c) else: out.append('-') return ''.join(out) #---------------------------------------------------------------------------- # Main entry point for compiling and loading C++/CUDA plugins. _cached_plugins = dict() def get_plugin(module_name, sources, **build_kwargs): assert verbosity in ['none', 'brief', 'full'] # Already cached? if module_name in _cached_plugins: return _cached_plugins[module_name] # Print status. if verbosity == 'full': print(f'Setting up PyTorch plugin "{module_name}"...') elif verbosity == 'brief': print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) try: # pylint: disable=too-many-nested-blocks # Make sure we can find the necessary compiler binaries. if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: compiler_bindir = _find_compiler_bindir() if compiler_bindir is None: raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') os.environ['PATH'] += ';' + compiler_bindir # Compile and load. verbose_build = (verbosity == 'full') # Incremental build md5sum trickery. Copies all the input source files # into a cached build directory under a combined md5 digest of the input # source files. Copying is done only if the combined digest has changed. # This keeps input file timestamps and filenames the same as in previous # extension builds, allowing for fast incremental rebuilds. # # This optimization is done only in case all the source files reside in # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR # environment variable is set (we take this as a signal that the user # actually cares about this.) source_dirs_set = set(os.path.dirname(source) for source in sources) if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ): all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())) # Compute a combined hash digest for all source files in the same # custom op directory (usually .cu, .cpp, .py and .h files). hash_md5 = hashlib.md5() for src in all_source_files: with open(src, 'rb') as f: hash_md5.update(f.read()) build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) if not os.path.isdir(digest_build_dir): os.makedirs(digest_build_dir, exist_ok=True) baton = FileBaton(os.path.join(digest_build_dir, 'lock')) if baton.try_acquire(): try: for src in all_source_files: shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src))) finally: baton.release() else: # Someone else is copying source files under the digest dir, # wait until done and continue. 
baton.wait() digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources] torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir, verbose=verbose_build, sources=digest_sources, **build_kwargs) else: torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) module = importlib.import_module(module_name) except: if verbosity == 'brief': print('Failed!') raise # Print status and add to cache. if verbosity == 'full': print(f'Done setting up PyTorch plugin "{module_name}".') elif verbosity == 'brief': print('Done.') _cached_plugins[module_name] = module return module #---------------------------------------------------------------------------- def get_plugin_v3(module_name, sources, headers=None, source_dir=None, **build_kwargs): assert verbosity in ['none', 'brief', 'full'] if headers is None: headers = [] if source_dir is not None: sources = [os.path.join(source_dir, fname) for fname in sources] headers = [os.path.join(source_dir, fname) for fname in headers] # Already cached? if module_name in _cached_plugins: return _cached_plugins[module_name] # Print status. if verbosity == 'full': print(f'Setting up PyTorch plugin "{module_name}"...') elif verbosity == 'brief': print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) verbose_build = (verbosity == 'full') # Compile and load. try: # pylint: disable=too-many-nested-blocks # Make sure we can find the necessary compiler binaries. if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: compiler_bindir = _find_compiler_bindir() if compiler_bindir is None: raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') os.environ['PATH'] += ';' + compiler_bindir # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either # break the build or unnecessarily restrict what's available to nvcc. # Unset it to let nvcc decide based on what's available on the # machine. os.environ['TORCH_CUDA_ARCH_LIST'] = '' # Incremental build md5sum trickery. Copies all the input source files # into a cached build directory under a combined md5 digest of the input # source files. Copying is done only if the combined digest has changed. # This keeps input file timestamps and filenames the same as in previous # extension builds, allowing for fast incremental rebuilds. # # This optimization is done only in case all the source files reside in # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR # environment variable is set (we take this as a signal that the user # actually cares about this.) # # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work # around the *.cu dependency bug in ninja config. # all_source_files = sorted(sources + headers) all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): # Compute combined hash digest for all source files. hash_md5 = hashlib.md5() for src in all_source_files: with open(src, 'rb') as f: hash_md5.update(f.read()) # Select cached build directory name.
source_digest = hash_md5.hexdigest() build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') if not os.path.isdir(cached_build_dir): tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' os.makedirs(tmpdir) for src in all_source_files: shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) try: os.replace(tmpdir, cached_build_dir) # atomic except OSError: # source directory already exists, delete tmpdir and its contents. shutil.rmtree(tmpdir) if not os.path.isdir(cached_build_dir): raise # Compile. cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, verbose=verbose_build, sources=cached_sources, **build_kwargs) else: torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) # Load. module = importlib.import_module(module_name) except: if verbosity == 'brief': print('Failed!') raise # Print status and add to cache dict. if verbosity == 'full': print(f'Done setting up PyTorch plugin "{module_name}".') elif verbosity == 'brief': print('Done.') _cached_plugins[module_name] = module return module ================================================ FILE: stylegan_human/torch_utils/misc.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. import re import contextlib import numpy as np import torch import warnings import dnnlib #---------------------------------------------------------------------------- # Cached construction of constant tensors. Avoids CPU=>GPU copy when the # same constant is used multiple times. _constant_cache = dict() def constant(value, shape=None, dtype=None, device=None, memory_format=None): value = np.asarray(value) if shape is not None: shape = tuple(shape) if dtype is None: dtype = torch.get_default_dtype() if device is None: device = torch.device('cpu') if memory_format is None: memory_format = torch.contiguous_format key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) tensor = _constant_cache.get(key, None) if tensor is None: tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) if shape is not None: tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) tensor = tensor.contiguous(memory_format=memory_format) _constant_cache[key] = tensor return tensor #---------------------------------------------------------------------------- # Replace NaN/Inf with specified numerical values. 
try: nan_to_num = torch.nan_to_num # 1.8.0a0 except AttributeError: def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin assert isinstance(input, torch.Tensor) if posinf is None: posinf = torch.finfo(input.dtype).max if neginf is None: neginf = torch.finfo(input.dtype).min assert nan == 0 return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) #---------------------------------------------------------------------------- # Symbolic assert. try: symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access except AttributeError: symbolic_assert = torch.Assert # 1.7.0 #---------------------------------------------------------------------------- # Context manager to suppress known warnings in torch.jit.trace(). class suppress_tracer_warnings(warnings.catch_warnings): def __enter__(self): super().__enter__() warnings.simplefilter('ignore', category=torch.jit.TracerWarning) return self #---------------------------------------------------------------------------- # Assert that the shape of a tensor matches the given list of integers. # None indicates that the size of a dimension is allowed to vary. # Performs symbolic assertion when used in torch.jit.trace(). def assert_shape(tensor, ref_shape): if tensor.ndim != len(ref_shape): raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): if ref_size is None: pass elif isinstance(ref_size, torch.Tensor): with suppress_tracer_warnings(): # as_tensor results are registered as constants symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') elif isinstance(size, torch.Tensor): with suppress_tracer_warnings(): # as_tensor results are registered as constants symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') elif size != ref_size: raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') #---------------------------------------------------------------------------- # Function decorator that calls torch.autograd.profiler.record_function(). def profiled_function(fn): def decorator(*args, **kwargs): with torch.autograd.profiler.record_function(fn.__name__): return fn(*args, **kwargs) decorator.__name__ = fn.__name__ return decorator #---------------------------------------------------------------------------- # Sampler for torch.utils.data.DataLoader that loops over the dataset # indefinitely, shuffling items as it goes. 
class InfiniteSampler(torch.utils.data.Sampler): def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): assert len(dataset) > 0 assert num_replicas > 0 assert 0 <= rank < num_replicas assert 0 <= window_size <= 1 super().__init__(dataset) self.dataset = dataset self.rank = rank self.num_replicas = num_replicas self.shuffle = shuffle self.seed = seed self.window_size = window_size def __iter__(self): order = np.arange(len(self.dataset)) rnd = None window = 0 if self.shuffle: rnd = np.random.RandomState(self.seed) rnd.shuffle(order) window = int(np.rint(order.size * self.window_size)) idx = 0 while True: i = idx % order.size if idx % self.num_replicas == self.rank: yield order[i] if window >= 2: j = (i - rnd.randint(window)) % order.size order[i], order[j] = order[j], order[i] idx += 1 #---------------------------------------------------------------------------- # Utilities for operating with torch.nn.Module parameters and buffers. def params_and_buffers(module): assert isinstance(module, torch.nn.Module) return list(module.parameters()) + list(module.buffers()) def named_params_and_buffers(module): assert isinstance(module, torch.nn.Module) return list(module.named_parameters()) + list(module.named_buffers()) def copy_params_and_buffers(src_module, dst_module, require_all=False): assert isinstance(src_module, torch.nn.Module) assert isinstance(dst_module, torch.nn.Module) src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)} for name, tensor in named_params_and_buffers(dst_module): assert (name in src_tensors) or (not require_all) if name in src_tensors: tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) #---------------------------------------------------------------------------- # Context manager for easily enabling/disabling DistributedDataParallel # synchronization. @contextlib.contextmanager def ddp_sync(module, sync): assert isinstance(module, torch.nn.Module) if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): yield else: with module.no_sync(): yield #---------------------------------------------------------------------------- # Check DistributedDataParallel consistency across processes. def check_ddp_consistency(module, ignore_regex=None): assert isinstance(module, torch.nn.Module) for name, tensor in named_params_and_buffers(module): fullname = type(module).__name__ + '.' + name if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): continue tensor = tensor.detach() other = tensor.clone() torch.distributed.broadcast(tensor=other, src=0) assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname #---------------------------------------------------------------------------- # Print summary table of module hierarchy. def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): assert isinstance(module, torch.nn.Module) assert not isinstance(module, torch.jit.ScriptModule) assert isinstance(inputs, (tuple, list)) # Register hooks. 
    entries = []
    nesting = [0]
    def pre_hook(_mod, _inputs):
        nesting[0] += 1
    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]

    # Run module.
    outputs = module(*inputs)
    for hook in hooks:
        hook.remove()

    # Identify unique outputs, parameters, and buffers.
    tensors_seen = set()
    for e in entries:
        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}

    # Filter out redundant entries.
    if skip_redundant:
        entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]

    # Construct table.
    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
    rows += [['---'] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = '<top-level>' if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        output_shapes = [str(list(t.shape)) for t in e.outputs] # shape of each output, not just the first
        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
        rows += [[
            name + (':0' if len(e.outputs) >= 2 else ''),
            str(param_size) if param_size else '-',
            str(buffer_size) if buffer_size else '-',
            (output_shapes + ['-'])[0],
            (output_dtypes + ['-'])[0],
        ]]
        for idx in range(1, len(e.outputs)):
            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
        param_total += param_size
        buffer_total += buffer_size
    rows += [['---'] * len(rows[0])]
    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]

    # Print table.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print('  '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
    print()
    return outputs

#----------------------------------------------------------------------------

================================================
FILE: stylegan_human/torch_utils/models.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
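# A minimal usage sketch for the misc.py helpers above (InfiniteSampler and
# assert_shape); the toy dataset and batch size are illustrative assumptions,
# not part of the repository.
import torch
from torch.utils.data import DataLoader, TensorDataset

toy = TensorDataset(torch.randn(10, 3, 8, 8))                       # hypothetical dataset
sampler = InfiniteSampler(toy, rank=0, num_replicas=1, shuffle=True, seed=0)
loader = iter(DataLoader(toy, sampler=sampler, batch_size=4))
(images,) = next(loader)                                            # sampler loops forever; take one batch
assert_shape(images, [4, 3, None, None])                            # None = this dimension may vary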
# https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py import math import random import functools import operator import torch from torch import nn from torch.nn import functional as F import torch.nn.init as init from torch.autograd import Function from .op_edit import FusedLeakyReLU, fused_leaky_relu, upfirdn2d class PixelNorm(nn.Module): def __init__(self): super().__init__() def forward(self, input): return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) def make_kernel(k): k = torch.tensor(k, dtype=torch.float32) if k.ndim == 1: k = k[None, :] * k[:, None] k /= k.sum() return k class Upsample(nn.Module): def __init__(self, kernel, factor=2): super().__init__() self.factor = factor kernel = make_kernel(kernel) * (factor ** 2) self.register_buffer("kernel", kernel) p = kernel.shape[0] - factor pad0 = (p + 1) // 2 + factor - 1 pad1 = p // 2 self.pad = (pad0, pad1) def forward(self, input): out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) return out class Downsample(nn.Module): def __init__(self, kernel, factor=2): super().__init__() self.factor = factor kernel = make_kernel(kernel) self.register_buffer("kernel", kernel) p = kernel.shape[0] - factor pad0 = (p + 1) // 2 pad1 = p // 2 self.pad = (pad0, pad1) def forward(self, input): out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) return out class Blur(nn.Module): def __init__(self, kernel, pad, upsample_factor=1): super().__init__() kernel = make_kernel(kernel) if upsample_factor > 1: kernel = kernel * (upsample_factor ** 2) self.register_buffer("kernel", kernel) self.pad = pad def forward(self, input): out = upfirdn2d(input, self.kernel, pad=self.pad) return out class EqualConv2d(nn.Module): def __init__( self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True ): super().__init__() self.weight = nn.Parameter( torch.randn(out_channel, in_channel, kernel_size, kernel_size) ) self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) self.stride = stride self.padding = padding if bias: self.bias = nn.Parameter(torch.zeros(out_channel)) else: self.bias = None def forward(self, input): out = F.conv2d( input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding, ) return out def __repr__(self): return ( f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" ) class EqualLinear(nn.Module): def __init__( self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None ): super().__init__() self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) if bias: self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) else: self.bias = None self.activation = activation self.scale = (1 / math.sqrt(in_dim)) * lr_mul self.lr_mul = lr_mul def forward(self, input): if self.activation: out = F.linear(input, self.weight * self.scale) out = fused_leaky_relu(out, self.bias * self.lr_mul) else: out = F.linear( input, self.weight * self.scale, bias=self.bias * self.lr_mul ) return out def __repr__(self): return ( f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" ) class ScaledLeakyReLU(nn.Module): def __init__(self, negative_slope=0.2): super().__init__() self.negative_slope = negative_slope def forward(self, input): out = F.leaky_relu(input, negative_slope=self.negative_slope) return out * math.sqrt(2) class ModulatedConv2d(nn.Module): def __init__( self, in_channel, out_channel, 
kernel_size, style_dim, demodulate=True, upsample=False, downsample=False, blur_kernel=[1, 3, 3, 1], ): super().__init__() self.eps = 1e-8 self.kernel_size = kernel_size self.in_channel = in_channel self.out_channel = out_channel self.upsample = upsample self.downsample = downsample if upsample: factor = 2 p = (len(blur_kernel) - factor) - (kernel_size - 1) pad0 = (p + 1) // 2 + factor - 1 pad1 = p // 2 + 1 self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) if downsample: factor = 2 p = (len(blur_kernel) - factor) + (kernel_size - 1) pad0 = (p + 1) // 2 pad1 = p // 2 self.blur = Blur(blur_kernel, pad=(pad0, pad1)) fan_in = in_channel * kernel_size ** 2 self.scale = 1 / math.sqrt(fan_in) self.padding = kernel_size // 2 self.weight = nn.Parameter( torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) ) self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) self.demodulate = demodulate def __repr__(self): return ( f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " f"upsample={self.upsample}, downsample={self.downsample})" ) def forward(self, input, style): batch, in_channel, height, width = input.shape style = self.modulation(style).view(batch, 1, in_channel, 1, 1) weight = self.scale * self.weight * style if self.demodulate: demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) weight = weight.view( batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size ) if self.upsample: input = input.view(1, batch * in_channel, height, width) weight = weight.view( batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size ) weight = weight.transpose(1, 2).reshape( batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size ) out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) _, _, height, width = out.shape out = out.view(batch, self.out_channel, height, width) out = self.blur(out) elif self.downsample: input = self.blur(input) _, _, height, width = input.shape input = input.view(1, batch * in_channel, height, width) out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) _, _, height, width = out.shape out = out.view(batch, self.out_channel, height, width) else: input = input.view(1, batch * in_channel, height, width) out = F.conv2d(input, weight, padding=self.padding, groups=batch) _, _, height, width = out.shape out = out.view(batch, self.out_channel, height, width) return out class NoiseInjection(nn.Module): def __init__(self): super().__init__() self.weight = nn.Parameter(torch.zeros(1)) def forward(self, image, noise=None): if noise is None: batch, _, height, width = image.shape noise = image.new_empty(batch, 1, height, width).normal_() return image + self.weight * noise class ConstantInput(nn.Module): def __init__(self, channel, size=4): super().__init__() self.input = nn.Parameter(torch.randn(1, channel, size, size // 2)) def forward(self, input): batch = input.shape[0] out = self.input.repeat(batch, 1, 1, 1) return out class StyledConv(nn.Module): def __init__( self, in_channel, out_channel, kernel_size, style_dim, upsample=False, blur_kernel=[1, 3, 3, 1], demodulate=True, ): super().__init__() self.conv = ModulatedConv2d( in_channel, out_channel, kernel_size, style_dim, upsample=upsample, blur_kernel=blur_kernel, demodulate=demodulate, ) self.noise = NoiseInjection() self.activate = FusedLeakyReLU(out_channel) def forward(self, input, style, noise=None): out = 
self.conv(input, style) out = self.noise(out, noise=noise) out = self.activate(out) return out class ToRGB(nn.Module): def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): super().__init__() if upsample: self.upsample = Upsample(blur_kernel) self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) def forward(self, input, style, skip=None): out = self.conv(input, style) out = out + self.bias if skip is not None: skip = self.upsample(skip) out = out + skip return out class Generator(nn.Module): def __init__( self, size, style_dim, n_mlp, channel_multiplier=1, blur_kernel=[1, 3, 3, 1], lr_mlp=0.01, small=False, small_isaac=False, ): super().__init__() self.size = size if small and size > 64: raise ValueError("small only works for sizes <= 64") self.style_dim = style_dim layers = [PixelNorm()] for i in range(n_mlp): layers.append( EqualLinear( style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu" ) ) self.style = nn.Sequential(*layers) if small: self.channels = { 4: 64 * channel_multiplier, 8: 64 * channel_multiplier, 16: 64 * channel_multiplier, 32: 64 * channel_multiplier, 64: 64 * channel_multiplier, } elif small_isaac: self.channels = {4: 256, 8: 256, 16: 256, 32: 256, 64: 128, 128: 128} else: self.channels = { 4: 512, 8: 512, 16: 512, 32: 512, 64: 256 * channel_multiplier, 128: 128 * channel_multiplier, 256: 64 * channel_multiplier, 512: 32 * channel_multiplier, 1024: 16 * channel_multiplier, } self.input = ConstantInput(self.channels[4]) self.conv1 = StyledConv( self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel ) self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) self.log_size = int(math.log(size, 2)) self.num_layers = (self.log_size - 2) * 2 + 1 self.convs = nn.ModuleList() self.upsamples = nn.ModuleList() self.to_rgbs = nn.ModuleList() self.noises = nn.Module() in_channel = self.channels[4] for layer_idx in range(self.num_layers): res = (layer_idx + 5) // 2 shape = [1, 1, 2 ** res, 2 ** res // 2] self.noises.register_buffer( "noise_{}".format(layer_idx), torch.randn(*shape) ) for i in range(3, self.log_size + 1): out_channel = self.channels[2 ** i] self.convs.append( StyledConv( in_channel, out_channel, 3, style_dim, upsample=True, blur_kernel=blur_kernel, ) ) self.convs.append( StyledConv( out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel ) ) self.to_rgbs.append(ToRGB(out_channel, style_dim)) in_channel = out_channel self.n_latent = self.log_size * 2 - 2 def make_noise(self): device = self.input.input.device noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2 // 2, device=device)] for i in range(3, self.log_size + 1): for _ in range(2): noises.append(torch.randn(1, 1, 2 ** i, 2 ** i // 2, device=device)) return noises def mean_latent(self, n_latent): latent_in = torch.randn( n_latent, self.style_dim, device=self.input.input.device ) latent = self.style(latent_in).mean(0, keepdim=True) return latent def get_latent(self, input): return self.style(input) def forward( self, styles, return_latents=False, return_features=False, inject_index=None, truncation=1, truncation_latent=None, input_is_latent=False, noise=None, randomize_noise=True, real=False, ): if not input_is_latent: styles = [self.style(s) for s in styles] if noise is None: if randomize_noise: noise = [None] * self.num_layers else: noise = [ getattr(self.noises, "noise_{}".format(i)) for i in range(self.num_layers) ] if truncation < 1: # print('truncation_latent: ', 
truncation_latent.shape) if not real: #if type(styles) == list: style_t = [] for style in styles: style_t.append( truncation_latent + truncation * (style - truncation_latent) ) # (-1.1162e-03-(-1.0914e-01))*0.8+(-1.0914e-01) styles = style_t else: # styles are latent (tensor: 1,18,512), for real PTI output truncation_latent = truncation_latent.repeat(18,1).unsqueeze(0) # (1,512) --> (1,18,512) styles = torch.add(truncation_latent,torch.mul(torch.sub(styles,truncation_latent),truncation)) # print('now styles after truncation : ', styles) #if type(styles) == list and len(styles) < 2: # this if for input as list of [(1,512)] if not real: if len(styles) < 2: inject_index = self.n_latent if styles[0].ndim < 3: latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) else: latent = styles[0] elif type(styles) == list: if inject_index is None: inject_index = 4 latent = styles[0].unsqueeze(0) if latent.shape[1] == 1: latent = latent.repeat(1, inject_index, 1) else: latent = latent[:, :inject_index, :] latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) latent = torch.cat([latent, latent2], 1) else: # input is tensor of size with torch.Size([1, 18, 512]), for real PTI output latent = styles # print(f'processed latent: {latent.shape}') features = {} out = self.input(latent) features["out_0"] = out out = self.conv1(out, latent[:, 0], noise=noise[0]) features["conv1_0"] = out skip = self.to_rgb1(out, latent[:, 1]) features["skip_0"] = skip i = 1 for conv1, conv2, noise1, noise2, to_rgb in zip( self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs ): out = conv1(out, latent[:, i], noise=noise1) features["conv1_{}".format(i)] = out out = conv2(out, latent[:, i + 1], noise=noise2) features["conv2_{}".format(i)] = out skip = to_rgb(out, latent[:, i + 2], skip) features["skip_{}".format(i)] = skip i += 2 image = skip if return_latents: return image, latent elif return_features: return image, features else: return image, None class ConvLayer(nn.Sequential): def __init__( self, in_channel, out_channel, kernel_size, downsample=False, blur_kernel=[1, 3, 3, 1], bias=True, activate=True, ): layers = [] if downsample: factor = 2 p = (len(blur_kernel) - factor) + (kernel_size - 1) pad0 = (p + 1) // 2 pad1 = p // 2 layers.append(Blur(blur_kernel, pad=(pad0, pad1))) stride = 2 self.padding = 0 else: stride = 1 self.padding = kernel_size // 2 layers.append( EqualConv2d( in_channel, out_channel, kernel_size, padding=self.padding, stride=stride, bias=bias and not activate, ) ) if activate: if bias: layers.append(FusedLeakyReLU(out_channel)) else: layers.append(ScaledLeakyReLU(0.2)) super().__init__(*layers) class ResBlock(nn.Module): def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): super().__init__() self.conv1 = ConvLayer(in_channel, in_channel, 3) self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) self.skip = ConvLayer( in_channel, out_channel, 1, downsample=True, activate=False, bias=False ) def forward(self, input): out = self.conv1(input) out = self.conv2(out) skip = self.skip(input) out = (out + skip) / math.sqrt(2) return out class StyleDiscriminator(nn.Module): def __init__( self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], small=False ): super().__init__() if small: channels = {4: 64, 8: 64, 16: 64, 32: 64, 64: 64} else: channels = { 4: 512, 8: 512, 16: 512, 32: 512, 64: 256 * channel_multiplier, 128: 128 * channel_multiplier, 256: 64 * channel_multiplier, 512: 32 * channel_multiplier, 1024: 16 * channel_multiplier, 
} convs = [ConvLayer(3, channels[size], 1)] log_size = int(math.log(size, 2)) in_channel = channels[size] for i in range(log_size, 2, -1): out_channel = channels[2 ** (i - 1)] convs.append(ResBlock(in_channel, out_channel, blur_kernel)) in_channel = out_channel self.convs = nn.Sequential(*convs) self.stddev_group = 4 self.stddev_feat = 1 self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) self.final_linear = nn.Sequential( EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), EqualLinear(channels[4], 1), ) def forward(self, input): h = input h_list = [] for index, blocklist in enumerate(self.convs): h = blocklist(h) h_list.append(h) out = h batch, channel, height, width = out.shape group = min(batch, self.stddev_group) stddev = out.view( group, -1, self.stddev_feat, channel // self.stddev_feat, height, width ) stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) stddev = stddev.repeat(group, 1, height, width) out = torch.cat([out, stddev], 1) out = self.final_conv(out) h_list.append(out) out = out.view(batch, -1) out = self.final_linear(out) return out, h_list class StyleEncoder(nn.Module): def __init__(self, size, w_dim=512): super().__init__() channels = { 4: 512, 8: 512, 16: 512, 32: 512, 64: 256, 128: 128, 256: 64, 512: 32, 1024: 16 } self.w_dim = w_dim log_size = int(math.log(size, 2)) convs = [ConvLayer(3, channels[size], 1)] in_channel = channels[size] for i in range(log_size, 2, -1): out_channel = channels[2 ** (i - 1)] convs.append(ResBlock(in_channel, out_channel)) in_channel = out_channel convs.append(EqualConv2d(in_channel,2*self.w_dim, 4, padding=0, bias=False)) self.convs = nn.Sequential(*convs) def forward(self, input): out = self.convs(input) # return out.view(len(input), self.n_latents, self.w_dim) reshaped = out.view(len(input), 2*self.w_dim) return reshaped[:,:self.w_dim], reshaped[:,self.w_dim:] def kaiming_init(m): if isinstance(m, (nn.Linear, nn.Conv2d)): init.kaiming_normal_(m.weight) if m.bias is not None: m.bias.data.fill_(0) elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): m.weight.data.fill_(1) if m.bias is not None: m.bias.data.fill_(0) def normal_init(m): if isinstance(m, (nn.Linear, nn.Conv2d)): init.normal_(m.weight, 0, 0.02) if m.bias is not None: m.bias.data.fill_(0) elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): m.weight.data.fill_(1) if m.bias is not None: m.bias.data.fill_(0) ================================================ FILE: stylegan_human/torch_utils/models_face.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. 
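# A minimal sampling sketch for the models.py Generator above. Its ConstantInput
# is (size // 2) wide, so size=1024 yields 1024x512 full-body images. This assumes
# the op_edit CUDA extensions compile on import; all values are illustrative.
import torch

g = Generator(size=1024, style_dim=512, n_mlp=8, channel_multiplier=2).cuda().eval()
with torch.no_grad():
    z = torch.randn(1, 512, device="cuda")
    w_mean = g.mean_latent(4096)                         # average W used for truncation
    img, _ = g([z], truncation=0.7, truncation_latent=w_mean)
print(img.shape)                                         # torch.Size([1, 3, 1024, 512])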
import math import random import functools import operator import torch from torch import nn from torch.nn import functional as F import torch.nn.init as init from torch.autograd import Function from .op_edit import FusedLeakyReLU, fused_leaky_relu, upfirdn2d class PixelNorm(nn.Module): def __init__(self): super().__init__() def forward(self, input): return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) def make_kernel(k): k = torch.tensor(k, dtype=torch.float32) if k.ndim == 1: k = k[None, :] * k[:, None] k /= k.sum() return k class Upsample(nn.Module): def __init__(self, kernel, factor=2): super().__init__() self.factor = factor kernel = make_kernel(kernel) * (factor ** 2) self.register_buffer("kernel", kernel) p = kernel.shape[0] - factor pad0 = (p + 1) // 2 + factor - 1 pad1 = p // 2 self.pad = (pad0, pad1) def forward(self, input): out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) return out class Downsample(nn.Module): def __init__(self, kernel, factor=2): super().__init__() self.factor = factor kernel = make_kernel(kernel) self.register_buffer("kernel", kernel) p = kernel.shape[0] - factor pad0 = (p + 1) // 2 pad1 = p // 2 self.pad = (pad0, pad1) def forward(self, input): out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) return out class Blur(nn.Module): def __init__(self, kernel, pad, upsample_factor=1): super().__init__() kernel = make_kernel(kernel) if upsample_factor > 1: kernel = kernel * (upsample_factor ** 2) self.register_buffer("kernel", kernel) self.pad = pad def forward(self, input): out = upfirdn2d(input, self.kernel, pad=self.pad) return out class EqualConv2d(nn.Module): def __init__( self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True ): super().__init__() self.weight = nn.Parameter( torch.randn(out_channel, in_channel, kernel_size, kernel_size) ) self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) self.stride = stride self.padding = padding if bias: self.bias = nn.Parameter(torch.zeros(out_channel)) else: self.bias = None def forward(self, input): out = F.conv2d( input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding, ) return out def __repr__(self): return ( f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" ) class EqualLinear(nn.Module): def __init__( self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None ): super().__init__() self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) if bias: self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) else: self.bias = None self.activation = activation self.scale = (1 / math.sqrt(in_dim)) * lr_mul self.lr_mul = lr_mul def forward(self, input): if self.activation: out = F.linear(input, self.weight * self.scale) out = fused_leaky_relu(out, self.bias * self.lr_mul) else: out = F.linear( input, self.weight * self.scale, bias=self.bias * self.lr_mul ) return out def __repr__(self): return ( f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" ) class ScaledLeakyReLU(nn.Module): def __init__(self, negative_slope=0.2): super().__init__() self.negative_slope = negative_slope def forward(self, input): out = F.leaky_relu(input, negative_slope=self.negative_slope) return out * math.sqrt(2) class ModulatedConv2d(nn.Module): def __init__( self, in_channel, out_channel, kernel_size, style_dim, demodulate=True, upsample=False, 
downsample=False, blur_kernel=[1, 3, 3, 1], ): super().__init__() self.eps = 1e-8 self.kernel_size = kernel_size self.in_channel = in_channel self.out_channel = out_channel self.upsample = upsample self.downsample = downsample if upsample: factor = 2 p = (len(blur_kernel) - factor) - (kernel_size - 1) pad0 = (p + 1) // 2 + factor - 1 pad1 = p // 2 + 1 self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) if downsample: factor = 2 p = (len(blur_kernel) - factor) + (kernel_size - 1) pad0 = (p + 1) // 2 pad1 = p // 2 self.blur = Blur(blur_kernel, pad=(pad0, pad1)) fan_in = in_channel * kernel_size ** 2 self.scale = 1 / math.sqrt(fan_in) self.padding = kernel_size // 2 self.weight = nn.Parameter( torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) ) self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) self.demodulate = demodulate def __repr__(self): return ( f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " f"upsample={self.upsample}, downsample={self.downsample})" ) def forward(self, input, style): batch, in_channel, height, width = input.shape style = self.modulation(style).view(batch, 1, in_channel, 1, 1) weight = self.scale * self.weight * style if self.demodulate: demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) weight = weight.view( batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size ) if self.upsample: input = input.view(1, batch * in_channel, height, width) weight = weight.view( batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size ) weight = weight.transpose(1, 2).reshape( batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size ) out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) _, _, height, width = out.shape out = out.view(batch, self.out_channel, height, width) out = self.blur(out) elif self.downsample: input = self.blur(input) _, _, height, width = input.shape input = input.view(1, batch * in_channel, height, width) out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) _, _, height, width = out.shape out = out.view(batch, self.out_channel, height, width) else: input = input.view(1, batch * in_channel, height, width) out = F.conv2d(input, weight, padding=self.padding, groups=batch) _, _, height, width = out.shape out = out.view(batch, self.out_channel, height, width) return out class NoiseInjection(nn.Module): def __init__(self): super().__init__() self.weight = nn.Parameter(torch.zeros(1)) def forward(self, image, noise=None): if noise is None: batch, _, height, width = image.shape noise = image.new_empty(batch, 1, height, width).normal_() return image + self.weight * noise class ConstantInput(nn.Module): def __init__(self, channel, size=4): super().__init__() self.input = nn.Parameter(torch.randn(1, channel, size, size)) def forward(self, input): batch = input.shape[0] out = self.input.repeat(batch, 1, 1, 1) return out class StyledConv(nn.Module): def __init__( self, in_channel, out_channel, kernel_size, style_dim, upsample=False, blur_kernel=[1, 3, 3, 1], demodulate=True, ): super().__init__() self.conv = ModulatedConv2d( in_channel, out_channel, kernel_size, style_dim, upsample=upsample, blur_kernel=blur_kernel, demodulate=demodulate, ) self.noise = NoiseInjection() # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) # self.activate = ScaledLeakyReLU(0.2) self.activate = FusedLeakyReLU(out_channel) def forward(self, input, 
style, noise=None): out = self.conv(input, style) out = self.noise(out, noise=noise) # out = out + self.bias out = self.activate(out) return out class ToRGB(nn.Module): def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): super().__init__() if upsample: self.upsample = Upsample(blur_kernel) self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) def forward(self, input, style, skip=None): out = self.conv(input, style) out = out + self.bias if skip is not None: skip = self.upsample(skip) out = out + skip return out class Generator(nn.Module): def __init__( self, size, style_dim, n_mlp, channel_multiplier=1, blur_kernel=[1, 3, 3, 1], lr_mlp=0.01, small=False, small_isaac=False, ): super().__init__() self.size = size if small and size > 64: raise ValueError("small only works for sizes <= 64") self.style_dim = style_dim layers = [PixelNorm()] for i in range(n_mlp): layers.append( EqualLinear( style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu" ) ) self.style = nn.Sequential(*layers) if small: self.channels = { 4: 64 * channel_multiplier, 8: 64 * channel_multiplier, 16: 64 * channel_multiplier, 32: 64 * channel_multiplier, 64: 64 * channel_multiplier, } elif small_isaac: self.channels = {4: 256, 8: 256, 16: 256, 32: 256, 64: 128, 128: 128} else: self.channels = { 4: 512, 8: 512, 16: 512, 32: 512, 64: 256 * channel_multiplier, 128: 128 * channel_multiplier, 256: 64 * channel_multiplier, 512: 32 * channel_multiplier, 1024: 16 * channel_multiplier, } self.input = ConstantInput(self.channels[4]) self.conv1 = StyledConv( self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel ) self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) self.log_size = int(math.log(size, 2)) self.num_layers = (self.log_size - 2) * 2 + 1 self.convs = nn.ModuleList() self.upsamples = nn.ModuleList() self.to_rgbs = nn.ModuleList() self.noises = nn.Module() in_channel = self.channels[4] for layer_idx in range(self.num_layers): res = (layer_idx + 5) // 2 shape = [1, 1, 2 ** res, 2 ** res] self.noises.register_buffer( "noise_{}".format(layer_idx), torch.randn(*shape) ) for i in range(3, self.log_size + 1): out_channel = self.channels[2 ** i] self.convs.append( StyledConv( in_channel, out_channel, 3, style_dim, upsample=True, blur_kernel=blur_kernel, ) ) self.convs.append( StyledConv( out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel ) ) self.to_rgbs.append(ToRGB(out_channel, style_dim)) in_channel = out_channel self.n_latent = self.log_size * 2 - 2 def make_noise(self): device = self.input.input.device noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] for i in range(3, self.log_size + 1): for _ in range(2): noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) return noises def mean_latent(self, n_latent): latent_in = torch.randn( n_latent, self.style_dim, device=self.input.input.device ) latent = self.style(latent_in).mean(0, keepdim=True) return latent def get_latent(self, input): return self.style(input) def forward( self, styles, return_latents=False, return_features=False, inject_index=None, truncation=1, truncation_latent=None, input_is_latent=False, noise=None, randomize_noise=True, ): if not input_is_latent: # print("haha") styles = [self.style(s) for s in styles] if noise is None: if randomize_noise: noise = [None] * self.num_layers else: noise = [ getattr(self.noises, "noise_{}".format(i)) for i in range(self.num_layers) ] if truncation < 1: 
style_t = [] for style in styles: style_t.append( truncation_latent + truncation * (style - truncation_latent) ) styles = style_t # print(styles) if len(styles) < 2: inject_index = self.n_latent if styles[0].ndim < 3: latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) # print("a") else: # print(len(styles)) latent = styles[0] # print("b", latent.shape) else: # print("c") if inject_index is None: inject_index = 4 latent = styles[0].unsqueeze(0) if latent.shape[1] == 1: latent = latent.repeat(1, inject_index, 1) else: latent = latent[:, :inject_index, :] latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) latent = torch.cat([latent, latent2], 1) features = {} out = self.input(latent) features["out_0"] = out out = self.conv1(out, latent[:, 0], noise=noise[0]) features["conv1_0"] = out skip = self.to_rgb1(out, latent[:, 1]) features["skip_0"] = skip i = 1 for conv1, conv2, noise1, noise2, to_rgb in zip( self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs ): out = conv1(out, latent[:, i], noise=noise1) features["conv1_{}".format(i)] = out out = conv2(out, latent[:, i + 1], noise=noise2) features["conv2_{}".format(i)] = out skip = to_rgb(out, latent[:, i + 2], skip) features["skip_{}".format(i)] = skip i += 2 image = skip if return_latents: return image, latent elif return_features: return image, features else: return image, None class ConvLayer(nn.Sequential): def __init__( self, in_channel, out_channel, kernel_size, downsample=False, blur_kernel=[1, 3, 3, 1], bias=True, activate=True, ): layers = [] if downsample: factor = 2 p = (len(blur_kernel) - factor) + (kernel_size - 1) pad0 = (p + 1) // 2 pad1 = p // 2 layers.append(Blur(blur_kernel, pad=(pad0, pad1))) stride = 2 self.padding = 0 else: stride = 1 self.padding = kernel_size // 2 layers.append( EqualConv2d( in_channel, out_channel, kernel_size, padding=self.padding, stride=stride, bias=bias and not activate, ) ) if activate: if bias: layers.append(FusedLeakyReLU(out_channel)) else: layers.append(ScaledLeakyReLU(0.2)) super().__init__(*layers) class ResBlock(nn.Module): def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): super().__init__() self.conv1 = ConvLayer(in_channel, in_channel, 3) self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) self.skip = ConvLayer( in_channel, out_channel, 1, downsample=True, activate=False, bias=False ) def forward(self, input): out = self.conv1(input) out = self.conv2(out) skip = self.skip(input) out = (out + skip) / math.sqrt(2) return out class StyleDiscriminator(nn.Module): def __init__( self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], small=False ): super().__init__() if small: channels = {4: 64, 8: 64, 16: 64, 32: 64, 64: 64} else: channels = { 4: 512, 8: 512, 16: 512, 32: 512, 64: 256 * channel_multiplier, 128: 128 * channel_multiplier, 256: 64 * channel_multiplier, 512: 32 * channel_multiplier, 1024: 16 * channel_multiplier, } convs = [ConvLayer(3, channels[size], 1)] log_size = int(math.log(size, 2)) in_channel = channels[size] for i in range(log_size, 2, -1): out_channel = channels[2 ** (i - 1)] convs.append(ResBlock(in_channel, out_channel, blur_kernel)) in_channel = out_channel self.convs = nn.Sequential(*convs) self.stddev_group = 4 self.stddev_feat = 1 self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) self.final_linear = nn.Sequential( EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), EqualLinear(channels[4], 1), ) # def forward(self, input): # out = 
self.convs(input) # batch, channel, height, width = out.shape # group = min(batch, self.stddev_group) # stddev = out.view( # group, -1, self.stddev_feat, channel // self.stddev_feat, height, width # ) # stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) # stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) # stddev = stddev.repeat(group, 1, height, width) # out = torch.cat([out, stddev], 1) # out = self.final_conv(out) # out = out.view(batch, -1) # out = self.final_linear(out) # return out def forward(self, input): h = input h_list = [] for index, blocklist in enumerate(self.convs): h = blocklist(h) h_list.append(h) out = h batch, channel, height, width = out.shape group = min(batch, self.stddev_group) stddev = out.view( group, -1, self.stddev_feat, channel // self.stddev_feat, height, width ) stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) stddev = stddev.repeat(group, 1, height, width) out = torch.cat([out, stddev], 1) out = self.final_conv(out) h_list.append(out) out = out.view(batch, -1) out = self.final_linear(out) return out, h_list class StyleEncoder(nn.Module): def __init__(self, size, w_dim=512): super().__init__() channels = { 4: 512, 8: 512, 16: 512, 32: 512, 64: 256, 128: 128, 256: 64, 512: 32, 1024: 16 } self.w_dim = w_dim log_size = int(math.log(size, 2)) # self.n_latents = log_size*2 - 2 convs = [ConvLayer(3, channels[size], 1)] in_channel = channels[size] for i in range(log_size, 2, -1): out_channel = channels[2 ** (i - 1)] convs.append(ResBlock(in_channel, out_channel)) in_channel = out_channel # convs.append(EqualConv2d(in_channel, self.n_latents*self.w_dim, 4, padding=0, bias=False)) convs.append(EqualConv2d(in_channel,2*self.w_dim, 4, padding=0, bias=False)) self.convs = nn.Sequential(*convs) def forward(self, input): out = self.convs(input) # return out.view(len(input), self.n_latents, self.w_dim) reshaped = out.view(len(input), 2*self.w_dim) return reshaped[:,:self.w_dim], reshaped[:,self.w_dim:] def kaiming_init(m): if isinstance(m, (nn.Linear, nn.Conv2d)): init.kaiming_normal_(m.weight) if m.bias is not None: m.bias.data.fill_(0) elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): m.weight.data.fill_(1) if m.bias is not None: m.bias.data.fill_(0) def normal_init(m): if isinstance(m, (nn.Linear, nn.Conv2d)): init.normal_(m.weight, 0, 0.02) if m.bias is not None: m.bias.data.fill_(0) elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): m.weight.data.fill_(1) if m.bias is not None: m.bias.data.fill_(0) ================================================ FILE: stylegan_human/torch_utils/op_edit/__init__.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. from .fused_act import FusedLeakyReLU, fused_leaky_relu from .upfirdn2d import upfirdn2d ================================================ FILE: stylegan_human/torch_utils/op_edit/fused_act.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. 
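# Reference semantics for fused_leaky_relu defined in this file:
#   out = leaky_relu(x + bias, 0.2) * sqrt(2)
# The plain-PyTorch check below mirrors the CPU fallback; shapes are illustrative.
import math
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 4, 4)                              # NCHW activations
b = torch.zeros(8)                                        # per-channel bias
ref = F.leaky_relu(x + b.view(1, -1, 1, 1), negative_slope=0.2) * math.sqrt(2)
# fused_leaky_relu(x, b) should match `ref` up to numerical precision.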
import os import torch from torch import nn from torch.nn import functional as F from torch.autograd import Function from torch.utils.cpp_extension import load module_path = os.path.dirname(__file__) fused = load( "fused", sources=[ os.path.join(module_path, "fused_bias_act.cpp"), os.path.join(module_path, "fused_bias_act_kernel.cu"), ], ) class FusedLeakyReLUFunctionBackward(Function): @staticmethod def forward(ctx, grad_output, out, negative_slope, scale): ctx.save_for_backward(out) ctx.negative_slope = negative_slope ctx.scale = scale empty = grad_output.new_empty(0) grad_input = fused.fused_bias_act( grad_output, empty, out, 3, 1, negative_slope, scale ) dim = [0] if grad_input.ndim > 2: dim += list(range(2, grad_input.ndim)) grad_bias = grad_input.sum(dim).detach() return grad_input, grad_bias @staticmethod def backward(ctx, gradgrad_input, gradgrad_bias): (out,) = ctx.saved_tensors gradgrad_out = fused.fused_bias_act( gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale ) return gradgrad_out, None, None, None class FusedLeakyReLUFunction(Function): @staticmethod def forward(ctx, input, bias, negative_slope, scale): empty = input.new_empty(0) out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) ctx.save_for_backward(out) ctx.negative_slope = negative_slope ctx.scale = scale return out @staticmethod def backward(ctx, grad_output): (out,) = ctx.saved_tensors grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( grad_output, out, ctx.negative_slope, ctx.scale ) return grad_input, grad_bias, None, None class FusedLeakyReLU(nn.Module): def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): super().__init__() self.bias = nn.Parameter(torch.zeros(channel)) self.negative_slope = negative_slope self.scale = scale def forward(self, input): return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): if input.device.type == "cpu": rest_dim = [1] * (input.ndim - bias.ndim - 1) return ( F.leaky_relu( input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 ) * scale ) else: return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) ================================================ FILE: stylegan_human/torch_utils/op_edit/fused_bias_act.cpp ================================================ // Copyright (c) SenseTime Research. All rights reserved. #include torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, int act, int grad, float alpha, float scale); #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, int act, int grad, float alpha, float scale) { CHECK_CUDA(input); CHECK_CUDA(bias); return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); } ================================================ FILE: stylegan_human/torch_utils/op_edit/fused_bias_act_kernel.cu ================================================ // Copyright (c) SenseTime Research. All rights reserved. // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 
// // This work is made available under the Nvidia Source Code License-NC. // To view a copy of this license, visit // https://nvlabs.github.io/stylegan2/license.html #include #include #include #include #include #include #include template static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; scalar_t zero = 0.0; for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { scalar_t x = p_x[xi]; if (use_bias) { x += p_b[(xi / step_b) % size_b]; } scalar_t ref = use_ref ? p_ref[xi] : zero; scalar_t y; switch (act * 10 + grad) { default: case 10: y = x; break; case 11: y = x; break; case 12: y = 0.0; break; case 30: y = (x > 0.0) ? x : x * alpha; break; case 31: y = (ref > 0.0) ? x : x * alpha; break; case 32: y = 0.0; break; } out[xi] = y * scale; } } torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, int act, int grad, float alpha, float scale) { int curDevice = -1; cudaGetDevice(&curDevice); cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); auto x = input.contiguous(); auto b = bias.contiguous(); auto ref = refer.contiguous(); int use_bias = b.numel() ? 1 : 0; int use_ref = ref.numel() ? 1 : 0; int size_x = x.numel(); int size_b = b.numel(); int step_b = 1; for (int i = 1 + 1; i < x.dim(); i++) { step_b *= x.size(i); } int loop_x = 4; int block_size = 4 * 32; int grid_size = (size_x - 1) / (loop_x * block_size) + 1; auto y = torch::empty_like(x); AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { fused_bias_act_kernel<<>>( y.data_ptr(), x.data_ptr(), b.data_ptr(), ref.data_ptr(), act, grad, alpha, scale, loop_x, size_x, step_b, size_b, use_bias, use_ref ); }); return y; } ================================================ FILE: stylegan_human/torch_utils/op_edit/upfirdn2d.cpp ================================================ // Copyright (c) SenseTime Research. All rights reserved. #include torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, int up_x, int up_y, int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0, int pad_y1); #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, int up_x, int up_y, int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0, int pad_y1) { CHECK_CUDA(input); CHECK_CUDA(kernel); return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); } ================================================ FILE: stylegan_human/torch_utils/op_edit/upfirdn2d.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. 
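# upfirdn2d defined below fuses upsample (zero insertion), pad, FIR filter, and
# downsample in one pass; the output size follows
#   out = (in * up + pad0 + pad1 - kernel) // down + 1.
# Sketch of a 2x upsample with a 4-tap blur kernel (make_kernel is the helper
# from models.py above; values are illustrative, and importing this module
# assumes the CUDA extension builds, with a native fallback on CPU):
import torch

x = torch.randn(1, 3, 8, 8)
k = make_kernel([1, 3, 3, 1])                            # normalized 4x4 separable blur
y = upfirdn2d(x, k, up=2, down=1, pad=(2, 1))
print(y.shape)                                           # (8*2 + 2 + 1 - 4) + 1 = 16 -> [1, 3, 16, 16]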
import os import torch from torch.nn import functional as F from torch.autograd import Function from torch.utils.cpp_extension import load module_path = os.path.dirname(__file__) upfirdn2d_op = load( "upfirdn2d", sources=[ os.path.join(module_path, "upfirdn2d.cpp"), os.path.join(module_path, "upfirdn2d_kernel.cu"), ], ) class UpFirDn2dBackward(Function): @staticmethod def forward( ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size ): up_x, up_y = up down_x, down_y = down g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) grad_input = upfirdn2d_op.upfirdn2d( grad_output, grad_kernel, down_x, down_y, up_x, up_y, g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1, ) grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) ctx.save_for_backward(kernel) pad_x0, pad_x1, pad_y0, pad_y1 = pad ctx.up_x = up_x ctx.up_y = up_y ctx.down_x = down_x ctx.down_y = down_y ctx.pad_x0 = pad_x0 ctx.pad_x1 = pad_x1 ctx.pad_y0 = pad_y0 ctx.pad_y1 = pad_y1 ctx.in_size = in_size ctx.out_size = out_size return grad_input @staticmethod def backward(ctx, gradgrad_input): (kernel,) = ctx.saved_tensors gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) gradgrad_out = upfirdn2d_op.upfirdn2d( gradgrad_input, kernel, ctx.up_x, ctx.up_y, ctx.down_x, ctx.down_y, ctx.pad_x0, ctx.pad_x1, ctx.pad_y0, ctx.pad_y1, ) # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) gradgrad_out = gradgrad_out.view( ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] ) return gradgrad_out, None, None, None, None, None, None, None, None class UpFirDn2d(Function): @staticmethod def forward(ctx, input, kernel, up, down, pad): up_x, up_y = up down_x, down_y = down pad_x0, pad_x1, pad_y0, pad_y1 = pad kernel_h, kernel_w = kernel.shape batch, channel, in_h, in_w = input.shape ctx.in_size = input.shape input = input.reshape(-1, in_h, in_w, 1) ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 ctx.out_size = (out_h, out_w) ctx.up = (up_x, up_y) ctx.down = (down_x, down_y) ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) g_pad_x0 = kernel_w - pad_x0 - 1 g_pad_y0 = kernel_h - pad_y0 - 1 g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) out = upfirdn2d_op.upfirdn2d( input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 ) # out = out.view(major, out_h, out_w, minor) out = out.view(-1, channel, out_h, out_w) return out @staticmethod def backward(ctx, grad_output): kernel, grad_kernel = ctx.saved_tensors grad_input = UpFirDn2dBackward.apply( grad_output, kernel, grad_kernel, ctx.up, ctx.down, ctx.pad, ctx.g_pad, ctx.in_size, ctx.out_size, ) return grad_input, None, None, None, None def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): if input.device.type == "cpu": out = upfirdn2d_native( input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] ) else: out = UpFirDn2d.apply( input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) ) return out def upfirdn2d_native( input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 ): _, channel, in_h, in_w = input.shape input = input.reshape(-1, in_h, in_w, 1) _, in_h, in_w, minor = input.shape kernel_h, kernel_w = 
kernel.shape out = input.view(-1, in_h, 1, in_w, 1, minor) out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) out = out.view(-1, in_h * up_y, in_w * up_x, minor) out = F.pad( out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] ) out = out[ :, max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), :, ] out = out.permute(0, 3, 1, 2) out = out.reshape( [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] ) w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) out = F.conv2d(out, w) out = out.reshape( -1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, ) out = out.permute(0, 2, 3, 1) out = out[:, ::down_y, ::down_x, :] out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 return out.view(-1, channel, out_h, out_w) ================================================ FILE: stylegan_human/torch_utils/op_edit/upfirdn2d_kernel.cu ================================================ // Copyright (c) SenseTime Research. All rights reserved. // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. // // This work is made available under the Nvidia Source Code License-NC. // To view a copy of this license, visit // https://nvlabs.github.io/stylegan2/license.html #include #include #include #include #include #include #include static __host__ __device__ __forceinline__ int floor_div(int a, int b) { int c = a / b; if (c * b > a) { c--; } return c; } struct UpFirDn2DKernelParams { int up_x; int up_y; int down_x; int down_y; int pad_x0; int pad_x1; int pad_y0; int pad_y1; int major_dim; int in_h; int in_w; int minor_dim; int kernel_h; int kernel_w; int out_h; int out_w; int loop_major; int loop_x; }; template __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, const scalar_t *kernel, const UpFirDn2DKernelParams p) { int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; int out_y = minor_idx / p.minor_dim; minor_idx -= out_y * p.minor_dim; int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; int major_idx_base = blockIdx.z * p.loop_major; if (out_x_base >= p.out_w || out_y >= p.out_h || major_idx_base >= p.major_dim) { return; } int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major && major_idx < p.major_dim; loop_major++, major_idx++) { for (int loop_x = 0, out_x = out_x_base; loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; const scalar_t *x_p = &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; int x_px = p.minor_dim; int k_px = -p.up_x; int x_py = p.in_w * p.minor_dim; int k_py = -p.up_y * p.kernel_w; scalar_t v = 0.0f; for (int y = 0; y < h; y++) { for (int x = 0; x < w; x++) { v += static_cast(*x_p) * static_cast(*k_p); x_p += x_px; k_p += k_px; } x_p += x_py - w * x_px; k_p += k_py - w * k_px; } out[((major_idx * p.out_h + 
out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; } } } template __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, const scalar_t *kernel, const UpFirDn2DKernelParams p) { const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; __shared__ volatile float sk[kernel_h][kernel_w]; __shared__ volatile float sx[tile_in_h][tile_in_w]; int minor_idx = blockIdx.x; int tile_out_y = minor_idx / p.minor_dim; minor_idx -= tile_out_y * p.minor_dim; tile_out_y *= tile_out_h; int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; int major_idx_base = blockIdx.z * p.loop_major; if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) { return; } for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) { int ky = tap_idx / kernel_w; int kx = tap_idx - ky * kernel_w; scalar_t v = 0.0; if (kx < p.kernel_w & ky < p.kernel_h) { v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; } sk[ky][kx] = v; } for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) { for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) { int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; int tile_in_x = floor_div(tile_mid_x, up_x); int tile_in_y = floor_div(tile_mid_y, up_y); __syncthreads(); for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) { int rel_in_y = in_idx / tile_in_w; int rel_in_x = in_idx - rel_in_y * tile_in_w; int in_x = rel_in_x + tile_in_x; int in_y = rel_in_y + tile_in_y; scalar_t v = 0.0; if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; } sx[rel_in_y][rel_in_x] = v; } __syncthreads(); for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) { int rel_out_y = out_idx / tile_out_w; int rel_out_x = out_idx - rel_out_y * tile_out_w; int out_x = rel_out_x + tile_out_x; int out_y = rel_out_y + tile_out_y; int mid_x = tile_mid_x + rel_out_x * down_x; int mid_y = tile_mid_y + rel_out_y * down_y; int in_x = floor_div(mid_x, up_x); int in_y = floor_div(mid_y, up_y); int rel_in_x = in_x - tile_in_x; int rel_in_y = in_y - tile_in_y; int kernel_x = (in_x + 1) * up_x - mid_x - 1; int kernel_y = (in_y + 1) * up_y - mid_y - 1; scalar_t v = 0.0; #pragma unroll for (int y = 0; y < kernel_h / up_y; y++) #pragma unroll for (int x = 0; x < kernel_w / up_x; x++) v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x]; if (out_x < p.out_w & out_y < p.out_h) { out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; } } } } } torch::Tensor upfirdn2d_op(const torch::Tensor &input, const torch::Tensor &kernel, int up_x, int up_y, int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0, int pad_y1) { int curDevice = -1; cudaGetDevice(&curDevice); cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); UpFirDn2DKernelParams p; auto x = input.contiguous(); auto k = kernel.contiguous(); p.major_dim = x.size(0); p.in_h = x.size(1); p.in_w = x.size(2); p.minor_dim = x.size(3); p.kernel_h = k.size(0); p.kernel_w = k.size(1); p.up_x = up_x; p.up_y = up_y; p.down_x = down_x; p.down_y = down_y; p.pad_x0 = 
pad_x0; p.pad_x1 = pad_x1; p.pad_y0 = pad_y0; p.pad_y1 = pad_y1; p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y; p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x; auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); int mode = -1; int tile_out_h = -1; int tile_out_w = -1; if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { mode = 1; tile_out_h = 16; tile_out_w = 64; } if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) { mode = 2; tile_out_h = 16; tile_out_w = 64; } if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { mode = 3; tile_out_h = 16; tile_out_w = 64; } if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) { mode = 4; tile_out_h = 16; tile_out_w = 64; } if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) { mode = 5; tile_out_h = 8; tile_out_w = 32; } if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) { mode = 6; tile_out_h = 8; tile_out_w = 32; } dim3 block_size; dim3 grid_size; if (tile_out_h > 0 && tile_out_w > 0) { p.loop_major = (p.major_dim - 1) / 16384 + 1; p.loop_x = 1; block_size = dim3(32 * 8, 1, 1); grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, (p.major_dim - 1) / p.loop_major + 1); } else { p.loop_major = (p.major_dim - 1) / 16384 + 1; p.loop_x = 4; block_size = dim3(4, 32, 1); grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, (p.out_w - 1) / (p.loop_x * block_size.y) + 1, (p.major_dim - 1) / p.loop_major + 1); } AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { switch (mode) { case 1: upfirdn2d_kernel <<>>(out.data_ptr(), x.data_ptr(), k.data_ptr(), p); break; case 2: upfirdn2d_kernel <<>>(out.data_ptr(), x.data_ptr(), k.data_ptr(), p); break; case 3: upfirdn2d_kernel <<>>(out.data_ptr(), x.data_ptr(), k.data_ptr(), p); break; case 4: upfirdn2d_kernel <<>>(out.data_ptr(), x.data_ptr(), k.data_ptr(), p); break; case 5: upfirdn2d_kernel <<>>(out.data_ptr(), x.data_ptr(), k.data_ptr(), p); break; case 6: upfirdn2d_kernel <<>>(out.data_ptr(), x.data_ptr(), k.data_ptr(), p); break; default: upfirdn2d_kernel_large<<>>( out.data_ptr(), x.data_ptr(), k.data_ptr(), p); } }); return out; } ================================================ FILE: stylegan_human/torch_utils/ops/__init__.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. #empty ================================================ FILE: stylegan_human/torch_utils/ops/bias_act.cpp ================================================ // Copyright (c) SenseTime Research. All rights reserved. // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. 
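//------------------------------------------------------------------------
// Contract of this extension (see bias_act.py for the Python wrapper):
//   y = act(x + b) * gain, optionally clamped to [-clamp, +clamp].
// `act` selects the nonlinearity by integer code (1=linear, 2=relu, 3=lrelu,
// 4=tanh, 5=sigmoid, 6=elu, 7=selu, 8=softplus, 9=swish), and `grad` selects
// the forward pass (0) or its first/second derivative (1/2), as dispatched
// by the kernels in bias_act.cu.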
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "bias_act.h"

//------------------------------------------------------------------------

static bool has_same_layout(torch::Tensor x, torch::Tensor y)
{
    if (x.dim() != y.dim())
        return false;
    for (int64_t i = 0; i < x.dim(); i++)
    {
        if (x.size(i) != y.size(i))
            return false;
        if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
            return false;
    }
    return true;
}

//------------------------------------------------------------------------

static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
{
    // Validate arguments.
    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
    TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
    TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
    TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
    TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
    TORCH_CHECK(b.dim() == 1, "b must have rank 1");
    TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
    TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
    TORCH_CHECK(grad >= 0, "grad must be non-negative");

    // Validate layout.
    TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
    TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
    TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
    TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
    TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");

    // Create output tensor.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
    torch::Tensor y = torch::empty_like(x);
    TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");

    // Initialize CUDA kernel parameters.
    bias_act_kernel_params p;
    p.x     = x.data_ptr();
    p.b     = (b.numel()) ? b.data_ptr() : NULL;
    p.xref  = (xref.numel()) ? xref.data_ptr() : NULL;
    p.yref  = (yref.numel()) ? yref.data_ptr() : NULL;
    p.dy    = (dy.numel()) ? dy.data_ptr() : NULL;
    p.y     = y.data_ptr();
    p.grad  = grad;
    p.act   = act;
    p.alpha = alpha;
    p.gain  = gain;
    p.clamp = clamp;
    p.sizeX = (int)x.numel();
    p.sizeB = (int)b.numel();
    p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;

    // Choose CUDA kernel.
    void* kernel;
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
    {
        kernel = choose_bias_act_kernel<scalar_t>(p);
    });
    TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");

    // Launch CUDA kernel.
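    // [Editor's note] Launch arithmetic sketch (added comment, numbers
    // assumed): each thread handles p.loopX grid-strided elements, so with
    // blockSize = 128 and loopX = 4 one block covers 512 elements, and
    // sizeX = 10000 yields gridSize = (10000 - 1) / 512 + 1 = 20 blocks.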
    p.loopX = 4;
    int blockSize = 4 * 32;
    int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
    void* args[] = {&p};
    AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
    return y;
}

//------------------------------------------------------------------------

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("bias_act", &bias_act);
}

//------------------------------------------------------------------------

================================================
FILE: stylegan_human/torch_utils/ops/bias_act.cu
================================================
// Copyright (c) SenseTime Research. All rights reserved.

// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <c10/util/Half.h>
#include "bias_act.h"

//------------------------------------------------------------------------
// Helpers.

template <class T> struct InternalType;
template <> struct InternalType<double>     { typedef double scalar_t; };
template <> struct InternalType<float>      { typedef float  scalar_t; };
template <> struct InternalType<c10::Half>  { typedef float  scalar_t; };

//------------------------------------------------------------------------
// CUDA kernel.

template <class T, int A>
__global__ void bias_act_kernel(bias_act_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;
    int G                 = p.grad;
    scalar_t alpha        = (scalar_t)p.alpha;
    scalar_t gain         = (scalar_t)p.gain;
    scalar_t clamp        = (scalar_t)p.clamp;
    scalar_t one          = (scalar_t)1;
    scalar_t two          = (scalar_t)2;
    scalar_t expRange     = (scalar_t)80;
    scalar_t halfExpRange = (scalar_t)40;
    scalar_t seluScale    = (scalar_t)1.0507009873554804934193349852946;
    scalar_t seluAlpha    = (scalar_t)1.6732632423543772848170429916717;

    // Loop over elements.
    int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
    for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
    {
        // Load.
        scalar_t x = (scalar_t)((const T*)p.x)[xi];
        scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
        scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
        scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
        scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
        scalar_t yy = (gain != 0) ? yref / gain : 0;
        scalar_t y = 0;

        // Apply bias.
        ((G == 0) ? x : xref) += b;

        // linear
        if (A == 1)
        {
            if (G == 0) y = x;
            if (G == 1) y = x;
        }

        // relu
        if (A == 2)
        {
            if (G == 0) y = (x > 0) ? x : 0;
            if (G == 1) y = (yy > 0) ? x : 0;
        }

        // lrelu
        if (A == 3)
        {
            if (G == 0) y = (x > 0) ? x : x * alpha;
            if (G == 1) y = (yy > 0) ? x : x * alpha;
        }

        // tanh
        if (A == 4)
        {
            if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
            if (G == 1) y = x * (one - yy * yy);
            if (G == 2) y = x * (one - yy * yy) * (-two * yy);
        }

        // sigmoid
        if (A == 5)
        {
            if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
            if (G == 1) y = x * yy * (one - yy);
            if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
        }

        // elu
        if (A == 6)
        {
            if (G == 0) y = (x >= 0) ? x : exp(x) - one;
            if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
        }

        // selu
        if (A == 7)
        {
            if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
            if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
        }

        // softplus
        if (A == 8)
        {
            if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
            if (G == 1) y = x * (one - exp(-yy));
            if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
        }

        // swish
        if (A == 9)
        {
            if (G == 0)
                y = (x < -expRange) ? 0 : x / (exp(-x) + one);
            else
            {
                scalar_t c = exp(xref);
                scalar_t d = c + one;
                if (G == 1)
                    y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
                else
                    y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
                yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
            }
        }

        // Apply gain.
        y *= gain * dy;

        // Clamp.
        if (clamp >= 0)
        {
            if (G == 0)
                y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
            else
                y = (yref > -clamp & yref < clamp) ? y : 0;
        }

        // Store.
        ((T*)p.y)[xi] = (T)y;
    }
}

//------------------------------------------------------------------------
// CUDA kernel selection.

template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
{
    if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
    if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
    if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
    if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
    if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
    if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
    if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
    if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
    if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
    return NULL;
}

//------------------------------------------------------------------------
// Template specializations.

template void* choose_bias_act_kernel<double>    (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel<float>     (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);

//------------------------------------------------------------------------

================================================
FILE: stylegan_human/torch_utils/ops/bias_act.h
================================================
// Copyright (c) SenseTime Research. All rights reserved.

// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

//------------------------------------------------------------------------
// CUDA kernel parameters.

struct bias_act_kernel_params
{
    const void* x;      // [sizeX]
    const void* b;      // [sizeB] or NULL
    const void* xref;   // [sizeX] or NULL
    const void* yref;   // [sizeX] or NULL
    const void* dy;     // [sizeX] or NULL
    void*       y;      // [sizeX]

    int         grad;
    int         act;
    float       alpha;
    float       gain;
    float       clamp;

    int         sizeX;
    int         sizeB;
    int         stepB;
    int         loopX;
};

//------------------------------------------------------------------------
// CUDA kernel selection.

template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);

//------------------------------------------------------------------------

================================================
FILE: stylegan_human/torch_utils/ops/bias_act.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.

# Copyright (c) 2021, NVIDIA CORPORATION.
All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Custom PyTorch ops for efficient bias and activation.""" import os import warnings import numpy as np import torch import dnnlib import traceback from .. import custom_ops from .. import misc #---------------------------------------------------------------------------- activation_funcs = { 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), } #---------------------------------------------------------------------------- _inited = False _plugin = None _null_tensor = torch.empty([0]) def _init(): global _inited, _plugin if not _inited: _inited = True sources = ['bias_act.cpp', 'bias_act.cu'] sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] try: _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) except: warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) return _plugin is not None #---------------------------------------------------------------------------- def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): r"""Fused bias and activation function. Adds bias `b` to activation tensor `x`, evaluates activation function `act`, and scales the result by `gain`. Each of the steps is optional. In most cases, the fused op is considerably more efficient than performing the same calculation using standard PyTorch ops. It supports first and second order gradients, but not third order gradients. Args: x: Input activation tensor. Can be of any shape. b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type as `x`. The shape must be known, and it must match the dimension of `x` corresponding to `dim`. dim: The dimension in `x` corresponding to the elements of `b`. The value of `dim` is ignored if `b` is not specified. act: Name of the activation function to evaluate, or `"linear"` to disable. Can be e.g. 
`"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
                See `activation_funcs` for a full list. `None` is not allowed.
        alpha:  Shape parameter for the activation function, or `None` to use the default.
        gain:   Scaling factor for the output tensor, or `None` to use the default.
                See `activation_funcs` for the default scaling of each activation function.
                If unsure, consider specifying 1.
        clamp:  Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
                the clamping (default).
        impl:   Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the same shape and datatype as `x`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    if impl == 'cuda' and x.device.type == 'cuda' and _init():
        return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
    return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)

#----------------------------------------------------------------------------

@misc.profiled_function
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Slow reference implementation of `bias_act()` using standard PyTorch ops.
    """
    assert isinstance(x, torch.Tensor)
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)

    # Add bias.
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.ndim == 1
        assert 0 <= dim < x.ndim
        assert b.shape[0] == x.shape[dim]
        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])

    # Evaluate activation function.
    alpha = float(alpha)
    x = spec.func(x, alpha=alpha)

    # Scale by gain.
    gain = float(gain)
    if gain != 1:
        x = x * gain

    # Clamp.
    if clamp >= 0:
        x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
    return x

#----------------------------------------------------------------------------

_bias_act_cuda_cache = dict()

def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Fast CUDA implementation of `bias_act()` using custom ops.
    """
    # Parse arguments.
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)

    # Lookup from cache.
    key = (dim, act, alpha, gain, clamp)
    if key in _bias_act_cuda_cache:
        return _bias_act_cuda_cache[key]

    # Forward op.
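    # Editor's note (added): the two autograd.Function subclasses below pair
    # up so that first- and second-order gradients route through the same
    # fused CUDA op. Minimal usage sketch (hypothetical shapes, CUDA device
    # assumed):
    #   x = torch.randn([8, 512, 4, 4], device='cuda', requires_grad=True)
    #   b = torch.randn([512], device='cuda')
    #   y = bias_act(x, b, act='lrelu')   # forward: BiasActCuda.apply
    #   y.sum().backward()                # backward: BiasActCudaGrad.apply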
class BiasActCuda(torch.autograd.Function): @staticmethod def forward(ctx, x, b): # pylint: disable=arguments-differ ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format x = x.contiguous(memory_format=ctx.memory_format) b = b.contiguous() if b is not None else _null_tensor y = x if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) ctx.save_for_backward( x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, y if 'y' in spec.ref else _null_tensor) return y @staticmethod def backward(ctx, dy): # pylint: disable=arguments-differ dy = dy.contiguous(memory_format=ctx.memory_format) x, b, y = ctx.saved_tensors dx = None db = None if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: dx = dy if act != 'linear' or gain != 1 or clamp >= 0: dx = BiasActCudaGrad.apply(dy, x, b, y) if ctx.needs_input_grad[1]: db = dx.sum([i for i in range(dx.ndim) if i != dim]) return dx, db # Backward op. class BiasActCudaGrad(torch.autograd.Function): @staticmethod def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) ctx.save_for_backward( dy if spec.has_2nd_grad else _null_tensor, x, b, y) return dx @staticmethod def backward(ctx, d_dx): # pylint: disable=arguments-differ d_dx = d_dx.contiguous(memory_format=ctx.memory_format) dy, x, b, y = ctx.saved_tensors d_dy = None d_x = None d_b = None d_y = None if ctx.needs_input_grad[0]: d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) if spec.has_2nd_grad and ctx.needs_input_grad[2]: d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) return d_dy, d_x, d_b, d_y # Add to cache. _bias_act_cuda_cache[key] = BiasActCuda return BiasActCuda #---------------------------------------------------------------------------- ================================================ FILE: stylegan_human/torch_utils/ops/conv2d_gradfix.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Custom replacement for `torch.nn.functional.conv2d` that supports arbitrarily high order gradients with zero performance penalty.""" import warnings import contextlib import torch # pylint: disable=redefined-builtin # pylint: disable=arguments-differ # pylint: disable=protected-access #---------------------------------------------------------------------------- enabled = False # Enable the custom op by setting this to true. weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 
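# Editor's note (added): minimal usage sketch, assuming a CUDA build of
# PyTorch in the 1.7-1.9 range that this module targets (tensor shapes are
# hypothetical, and the import path depends on how the repo is installed):
#
#   import torch
#   from stylegan_human.torch_utils.ops import conv2d_gradfix
#   conv2d_gradfix.enabled = True
#   x = torch.randn([1, 3, 64, 64], device='cuda', requires_grad=True)
#   w = torch.randn([8, 3, 3, 3], device='cuda', requires_grad=True)
#   with conv2d_gradfix.no_weight_gradients():
#       y = conv2d_gradfix.conv2d(x, w, padding=1)  # grads w.r.t. w suppressed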
@contextlib.contextmanager def no_weight_gradients(): global weight_gradients_disabled old = weight_gradients_disabled weight_gradients_disabled = True yield weight_gradients_disabled = old #---------------------------------------------------------------------------- def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): if _should_use_custom_op(input): return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): if _should_use_custom_op(input): return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) #---------------------------------------------------------------------------- def _should_use_custom_op(input): assert isinstance(input, torch.Tensor) if (not enabled) or (not torch.backends.cudnn.enabled): return False if input.device.type != 'cuda': return False if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): return True warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') return False def _tuple_of_ints(xs, ndim): xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim assert len(xs) == ndim assert all(isinstance(x, int) for x in xs) return xs #---------------------------------------------------------------------------- _conv2d_gradfix_cache = dict() def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): # Parse arguments. ndim = 2 weight_shape = tuple(weight_shape) stride = _tuple_of_ints(stride, ndim) padding = _tuple_of_ints(padding, ndim) output_padding = _tuple_of_ints(output_padding, ndim) dilation = _tuple_of_ints(dilation, ndim) # Lookup from cache. key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) if key in _conv2d_gradfix_cache: return _conv2d_gradfix_cache[key] # Validate arguments. assert groups >= 1 assert len(weight_shape) == ndim + 2 assert all(stride[i] >= 1 for i in range(ndim)) assert all(padding[i] >= 0 for i in range(ndim)) assert all(dilation[i] >= 0 for i in range(ndim)) if not transpose: assert all(output_padding[i] == 0 for i in range(ndim)) else: # transpose assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) # Helpers. common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) def calc_output_padding(input_shape, output_shape): if transpose: return [0, 0] return [ input_shape[i + 2] - (output_shape[i + 2] - 1) * stride[i] - (1 - 2 * padding[i]) - dilation[i] * (weight_shape[i + 2] - 1) for i in range(ndim) ] # Forward & backward. 
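    # Editor's note (added): calc_output_padding() above recovers the
    # output_padding that makes the transposed convolution in backward()
    # exactly invert the forward shape. Worked example with assumed numbers:
    # input width 7, stride 2, padding 1, 3x3 kernel -> forward output width
    # (7 + 2*1 - 3)//2 + 1 = 4, and the formula then gives
    # 7 - (4-1)*2 - (1 - 2*1) - 1*(3-1) = 0 extra output padding.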
class Conv2d(torch.autograd.Function): @staticmethod def forward(ctx, input, weight, bias): assert weight.shape == weight_shape if not transpose: output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) else: # transpose output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) ctx.save_for_backward(input, weight) return output @staticmethod def backward(ctx, grad_output): input, weight = ctx.saved_tensors grad_input = None grad_weight = None grad_bias = None if ctx.needs_input_grad[0]: p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None) assert grad_input.shape == input.shape if ctx.needs_input_grad[1] and not weight_gradients_disabled: grad_weight = Conv2dGradWeight.apply(grad_output, input) assert grad_weight.shape == weight_shape if ctx.needs_input_grad[2]: grad_bias = grad_output.sum([0, 2, 3]) return grad_input, grad_weight, grad_bias # Gradient with respect to the weights. class Conv2dGradWeight(torch.autograd.Function): @staticmethod def forward(ctx, grad_output, input): op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) assert grad_weight.shape == weight_shape ctx.save_for_backward(grad_output, input) return grad_weight @staticmethod def backward(ctx, grad2_grad_weight): grad_output, input = ctx.saved_tensors grad2_grad_output = None grad2_input = None if ctx.needs_input_grad[0]: grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) assert grad2_grad_output.shape == grad_output.shape if ctx.needs_input_grad[1]: p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None) assert grad2_input.shape == input.shape return grad2_grad_output, grad2_input _conv2d_gradfix_cache[key] = Conv2d return Conv2d #---------------------------------------------------------------------------- ================================================ FILE: stylegan_human/torch_utils/ops/conv2d_resample.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """2D convolution with optional up/downsampling.""" import torch from .. import misc from . import conv2d_gradfix from . 
import upfirdn2d from .upfirdn2d import _parse_padding from .upfirdn2d import _get_filter_size #---------------------------------------------------------------------------- def _get_weight_shape(w): with misc.suppress_tracer_warnings(): # this value will be treated as a constant shape = [int(sz) for sz in w.shape] misc.assert_shape(w, shape) return shape #---------------------------------------------------------------------------- def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. """ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) # Flip weight if requested. if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). w = w.flip([2, 3]) # Workaround performance pitfall in cuDNN 8.0.5, triggered when using # 1x1 kernel + memory_format=channels_last + less than 64 channels. if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: if out_channels <= 4 and groups == 1: in_shape = x.shape x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) else: x = x.to(memory_format=torch.contiguous_format) w = w.to(memory_format=torch.contiguous_format) x = conv2d_gradfix.conv2d(x, w, groups=groups) return x.to(memory_format=torch.channels_last) # Otherwise => execute using conv2d_gradfix. op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d return op(x, w, stride=stride, padding=padding, groups=groups) #---------------------------------------------------------------------------- @misc.profiled_function def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): r"""2D convolution with optional up/downsampling. Padding is performed only once at the beginning, not between the operations. Args: x: Input tensor of shape `[batch_size, in_channels, in_height, in_width]`. w: Weight tensor of shape `[out_channels, in_channels//groups, kernel_height, kernel_width]`. f: Low-pass filter for up/downsampling. Must be prepared beforehand by calling upfirdn2d.setup_filter(). None = identity (default). up: Integer upsampling factor (default: 1). down: Integer downsampling factor (default: 1). padding: Padding with respect to the upsampled image. Can be a single number or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` (default: 0). groups: Split input channels into N groups (default: 1). flip_weight: False = convolution, True = correlation (default: True). flip_filter: False = convolution, True = correlation (default: False). Returns: Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. """ # Validate arguments. assert isinstance(x, torch.Tensor) and (x.ndim == 4) assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) assert isinstance(up, int) and (up >= 1) assert isinstance(down, int) and (down >= 1) assert isinstance(groups, int) and (groups >= 1) out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) fw, fh = _get_filter_size(f) px0, px1, py0, py1 = _parse_padding(padding) # Adjust padding to account for up/downsampling. 
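    # Editor's note (added): a concrete instance of the adjustment below,
    # with assumed numbers: up=2 and a 4-tap filter (fw = fh = 4) adds
    # (4 + 2 - 1) // 2 = 2 pixels before and (4 - 2) // 2 = 1 pixel after,
    # centering the low-pass filter on the upsampled grid before any
    # user-requested padding is applied.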
if up > 1: px0 += (fw + up - 1) // 2 px1 += (fw - up) // 2 py0 += (fh + up - 1) // 2 py1 += (fh - up) // 2 if down > 1: px0 += (fw - down + 1) // 2 px1 += (fw - down) // 2 py0 += (fh - down + 1) // 2 py1 += (fh - down) // 2 # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. if kw == 1 and kh == 1 and (down > 1 and up == 1): x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) return x # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. if kw == 1 and kh == 1 and (up > 1 and down == 1): x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) return x # Fast path: downsampling only => use strided convolution. if down > 1 and up == 1: x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) return x # Fast path: upsampling with optional downsampling => use transpose strided convolution. if up > 1: if groups == 1: w = w.transpose(0, 1) else: w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) w = w.transpose(1, 2) w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) px0 -= kw - 1 px1 -= kw - up py0 -= kh - 1 py1 -= kh - up pxt = max(min(-px0, -px1), 0) pyt = max(min(-py0, -py1), 0) x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) if down > 1: x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) return x # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. if up == 1 and down == 1: if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) # Fallback: Generic reference implementation. x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) if down > 1: x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) return x #---------------------------------------------------------------------------- ================================================ FILE: stylegan_human/torch_utils/ops/filtered_lrelu.cpp ================================================ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. 
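// [Editor's note] Pipeline sketch (added comment): filtered_lrelu fuses
// upsample (fu) -> add bias -> leaky ReLU (slope/gain/clamp) -> downsample
// (fd) into one kernel, optionally recording 2 bits of sign/clamp state per
// activation so the backward pass can replay the nonlinearity without
// keeping the full upsampled pre-activation tensor.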
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "filtered_lrelu.h"

//------------------------------------------------------------------------

static std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
    torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
    int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
{
    // Set CUDA device.
    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

    // Validate arguments.
    TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
    TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
    TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
    TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
    TORCH_CHECK(x.dim() == 4, "x must be rank 4");
    TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
    TORCH_CHECK(x.numel() > 0, "x is empty");
    TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
    TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
    TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
    TORCH_CHECK(fu.numel() > 0, "fu is empty");
    TORCH_CHECK(fd.numel() > 0, "fd is empty");
    TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
    TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");

    // Figure out how much shared memory is available on the device.
    int maxSharedBytes = 0;
    AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
    int sharedKB = maxSharedBytes >> 10;

    // Populate enough launch parameters to check if a CUDA kernel exists.
    filtered_lrelu_kernel_params p;
    p.up      = up;
    p.down    = down;
    p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
    p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
    filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB);
    if (!test_spec.exec)
    {
        // No kernel found - return empty tensors and indicate missing kernel with return code of -1.
        return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
    }

    // Input/output element size.
    int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;

    // Input sizes.
    int64_t xw = (int)x.size(3);
    int64_t xh = (int)x.size(2);
    int64_t fut_w = (int)fu.size(-1) - 1;
    int64_t fut_h = (int)fu.size(0)  - 1;
    int64_t fdt_w = (int)fd.size(-1) - 1;
    int64_t fdt_h = (int)fd.size(0)  - 1;

    // Logical size of upsampled buffer.
    int64_t cw = xw * up + (px0 + px1) - fut_w;
    int64_t ch = xh * up + (py0 + py1) - fut_h;
    TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
    TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");

    // Compute output size and allocate.
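    // [Editor's note] The divisions below are ceiling divides; e.g. (assumed
    // numbers) cw = 68, fdt_w = 5, down = 2 gives yw = (68 - 5 + 1) / 2 = 32.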
    int64_t yw = (cw - fdt_w + (down - 1)) / down;
    int64_t yh = (ch - fdt_h + (down - 1)) / down;
    TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
    TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
    torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());

    // Allocate sign tensor.
    torch::Tensor so;
    torch::Tensor s = si;
    bool readSigns = !!s.numel();
    int64_t sw_active = 0; // Active width of sign tensor.
    if (writeSigns)
    {
        sw_active = yw * down - (down - 1) + fdt_w;  // Active width in elements.
        int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
        int64_t sw = (sw_active + 15) & ~15;         // Width = active width in elements, rounded up to multiple of 16.
        TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
        s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
    }
    else if (readSigns)
        sw_active = s.size(3) << 2;

    // Validate sign tensor if in use.
    if (readSigns || writeSigns)
    {
        TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
        TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
        TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
        TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
        TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
        TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
    }

    // Populate rest of CUDA kernel parameters.
    p.x       = x.data_ptr();
    p.y       = y.data_ptr();
    p.b       = b.data_ptr();
    p.s       = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
    p.fu      = fu.data_ptr<float>();
    p.fd      = fd.data_ptr<float>();
    p.pad0    = make_int2(px0, py0);
    p.gain    = gain;
    p.slope   = slope;
    p.clamp   = clamp;
    p.flip    = (flip_filters) ? 1 : 0;
    p.xShape  = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
    p.yShape  = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
    p.sShape  = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
    p.sOfs    = make_int2(sx, sy);
    p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.

    // x, y, b strides are in bytes.
    p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
    p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
    p.bStride = sz * b.stride(0);

    // fu, fd strides are in elements.
    p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
    p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);

    // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
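    // [Editor's note] The checks below sum the most negative and the most
    // positive reachable byte offsets for each tensor; if either bound can
    // leave int32 range, the kernel is instantiated with 64-bit indexing
    // (index_t = int64_t) at some cost in registers.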
    bool index64b = false;
    if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
    if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
    if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
    if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
    if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
    if (s.numel() > INT_MAX) index64b = true;

    // Choose CUDA kernel.
    filtered_lrelu_kernel_spec spec = { 0 };
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
    {
        if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
        {
            // Choose kernel based on index type, datatype and sign read/write modes.
            if      (!index64b &&  writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, true,  false>(p, sharedKB);
            else if (!index64b && !writeSigns &&  readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true >(p, sharedKB);
            else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(p, sharedKB);
            else if ( index64b &&  writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, true,  false>(p, sharedKB);
            else if ( index64b && !writeSigns &&  readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true >(p, sharedKB);
            else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(p, sharedKB);
        }
    });
    TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found"); // This should not happen because we tested earlier that kernel exists.

    // Launch CUDA kernel.
    void* args[] = {&p};
    int bx = spec.numWarps * 32;
    int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
    int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
    int gz = p.yShape.z * p.yShape.w;

    // Repeat multiple horizontal tiles in a CTA?
    if (spec.xrep)
    {
        p.tilesXrep = spec.xrep;
        p.tilesXdim = gx;

        gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
        std::swap(gx, gy);
    }
    else
    {
        p.tilesXrep = 0;
        p.tilesXdim = 0;
    }

    // Launch filter setup kernel.
    AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));

    // Copy kernels to constant memory.
    if      ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<true,  false>(at::cuda::getCurrentCUDAStream())));
    else if (!writeSigns &&  readSigns) AT_CUDA_CHECK((copy_filters<false, true >(at::cuda::getCurrentCUDAStream())));
    else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<false, false>(at::cuda::getCurrentCUDAStream())));

    // Set cache and shared memory configurations for main kernel.
    AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
    if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
        AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
    AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));

    // Launch main kernel.
    const int maxSubGz = 65535; // CUDA maximum for block z dimension.
    for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
    {
        p.blockZofs = zofs;
        int subGz = std::min(maxSubGz, gz - zofs);
        AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
    }

    // Done.
    return std::make_tuple(y, so, 0);
}

//------------------------------------------------------------------------

static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)
{
    // Set CUDA device.
    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

    // Validate arguments.
    TORCH_CHECK(x.dim() == 4, "x must be rank 4");
    TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
    TORCH_CHECK(x.numel() > 0, "x is empty");
    TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64");

    // Output signs if we don't have sign input.
    torch::Tensor so;
    torch::Tensor s = si;
    bool readSigns = !!s.numel();
    if (writeSigns)
    {
        int64_t sw = x.size(3);
        sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
        s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
    }

    // Validate sign tensor if in use.
    if (readSigns || writeSigns)
    {
        TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
        TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
        TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
        TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
        TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
        TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large");
    }

    // Initialize CUDA kernel parameters.
    filtered_lrelu_act_kernel_params p;
    p.x       = x.data_ptr();
    p.s       = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
    p.gain    = gain;
    p.slope   = slope;
    p.clamp   = clamp;
    p.xShape  = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
    p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
    p.sShape  = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
    p.sOfs    = make_int2(sx, sy);

    // Choose CUDA kernel.
    void* func = 0;
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&]
    {
        if (writeSigns)
            func = choose_filtered_lrelu_act_kernel<scalar_t, true, false>();
        else if (readSigns)
            func = choose_filtered_lrelu_act_kernel<scalar_t, false, true>();
        else
            func = choose_filtered_lrelu_act_kernel<scalar_t, false, false>();
    });
    TORCH_CHECK(func, "internal error - CUDA kernel not found");

    // Launch CUDA kernel.
    void* args[] = {&p};
    int bx = 128; // 4 warps per block.

    // Logical size of launch = writeSigns ? p.s : p.x
    uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
    uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
    uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
    gx = (gx - 1) / bx + 1;

    // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.
    const uint32_t gmax = 65535;
    gy = std::min(gy, gmax);
    gz = std::min(gz, gmax);

    // Launch.
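    // [Editor's note] Grid sizing sketch (added comment): gx covers the
    // logical x extent in blocks of bx = 128 threads, while gy and gz are
    // clamped to the 65535-per-dimension launch limit; the kernel loops
    // internally over the remainder.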
    AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));
    return so;
}

//------------------------------------------------------------------------

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("filtered_lrelu",      &filtered_lrelu);      // The whole thing.
    m.def("filtered_lrelu_act_", &filtered_lrelu_act);  // Activation and sign tensor handling only. Modifies data tensor in-place.
}

//------------------------------------------------------------------------

================================================
FILE: stylegan_human/torch_utils/ops/filtered_lrelu.cu
================================================
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <c10/util/Half.h>
#include "filtered_lrelu.h"
#include <cstdint>

//------------------------------------------------------------------------
// Helpers.

enum // Filter modes.
{
    MODE_SUSD = 0,  // Separable upsampling, separable downsampling.
    MODE_FUSD = 1,  // Full upsampling, separable downsampling.
    MODE_SUFD = 2,  // Separable upsampling, full downsampling.
    MODE_FUFD = 3,  // Full upsampling, full downsampling.
};

template <class T> struct InternalType;
template <> struct InternalType<double>
{
    typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t;
    __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); }
    __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); }
    __device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); }
};
template <> struct InternalType<float>
{
    typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
    __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
    __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
    __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
};
template <> struct InternalType<c10::Half>
{
    typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
    __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
    __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
    __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
};

#define MIN(A, B)       ((A) < (B) ? (A) : (B))
#define MAX(A, B)       ((A) > (B) ? (A) : (B))
#define CEIL_DIV(A, B) (((B)==1) ? (A) : \
                        ((B)==2) ? ((int)((A)+1) >> 1) : \
                        ((B)==4) ? ((int)((A)+3) >> 2) : \
                        (((A) + ((A) > 0 ? (B) - 1 : 0)) / (B)))

// This works only up to blocks of size 256 x 256 and for all N that are powers of two.
template <int N> __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i)
{
    if ((N & (N-1)) && N <= 256)
        y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256.
    else
        y = i/N;

    x = i - y*N;
}

// Type cast stride before reading it.
template <class T> __device__ __forceinline__ T get_stride(const int64_t& x)
{
    return *reinterpret_cast<const T*>(&x);
}

//------------------------------------------------------------------------
// Filters, setup kernel, copying function.
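// [Editor's note] fast_div_mod above trades an integer divide for a
// multiply-shift. Example with assumed values N = 48, i = 100:
// (1<<24)/48 + 1 = 349526, y = (100 * 349526) >> 24 = 2, x = 100 - 2*48 = 4.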
#define MAX_FILTER_SIZE 32

// Combined up/down filter buffers so that transfer can be done with one copy.
__device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel.
__device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel.

// Accessors to combined buffers to index up/down filters individually.
#define c_fu (c_fbuf)
#define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
#define g_fu (g_fbuf)
#define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)

// Set up filters into global memory buffer.
static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p)
{
    for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x)
    {
        int x, y;
        fast_div_mod<MAX_FILTER_SIZE>(x, y, idx);

        int fu_x = p.flip ? x : (p.fuShape.x - 1 - x);
        int fu_y = p.flip ? y : (p.fuShape.y - 1 - y);
        if (p.fuShape.y > 0)
            g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y];
        else
            g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x];

        int fd_x = p.flip ? x : (p.fdShape.x - 1 - x);
        int fd_y = p.flip ? y : (p.fdShape.y - 1 - y);
        if (p.fdShape.y > 0)
            g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y];
        else
            g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x];
    }
}

// Host function to copy filters written by setup kernel into constant buffer for main kernel.
template <bool, bool> static cudaError_t copy_filters(cudaStream_t stream)
{
    void* src = 0;
    cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf);
    if (err) return err;
    return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream);
}

//------------------------------------------------------------------------
// Coordinate spaces:
// - Relative to input tensor:      inX, inY, tileInX, tileInY
// - Relative to input tile:        relInX, relInY, tileInW, tileInH
// - Relative to upsampled tile:    relUpX, relUpY, tileUpW, tileUpH
// - Relative to output tile:       relOutX, relOutY, tileOutW, tileOutH
// - Relative to output tensor:     outX, outY, tileOutX, tileOutY
//
// Relationships between coordinate spaces:
// - inX = tileInX + relInX
// - inY = tileInY + relInY
// - relUpX = relInX * up + phaseInX
// - relUpY = relInY * up + phaseInY
// - relUpX = relOutX * down
// - relUpY = relOutY * down
// - outX = tileOutX + relOutX
// - outY = tileOutY + relOutY

extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer.

template <class T, class index_t, int sharedKB, bool signWrite, bool signRead, int filterMode, int up, int fuSize, int down, int fdSize, int tileOutW, int tileOutH, int threadsPerBlock, bool enableXrep, bool enableWriteSkip>
static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p)
{
    // Check that we don't try to support non-existing filter modes.
    static_assert(up   == 1 || up   == 2 || up   == 4, "only up=1, up=2, up=4 scales supported");
    static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported");
    static_assert(fuSize >= up,   "upsampling filter size must be at least upsampling factor");
    static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor");
    static_assert(fuSize % up   == 0, "upsampling filter size must be divisible with upsampling factor");
    static_assert(fdSize % down == 0, "downsampling filter size must be divisible with downsampling factor");
    static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE");
    static_assert(up   != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters");
    static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters");
    static_assert(!(up   == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4");
    static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4");

    // Static definitions.
    typedef typename InternalType<T>::scalar_t scalar_t;
    typedef typename InternalType<T>::vec2_t vec2_t;
    typedef typename InternalType<T>::vec4_t vec4_t;
    const int tileUpW    = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4.
    const int tileUpH    = tileOutH * down + (fdSize - 1) - (down - 1);            // Upsampled tile height.
    const int tileInW    = CEIL_DIV(tileUpW + (fuSize - 1), up);                   // Input tile width.
    const int tileInH    = CEIL_DIV(tileUpH + (fuSize - 1), up);                   // Input tile height.
    const int tileUpH_up = CEIL_DIV(tileUpH, up) * up;                             // Upsampled tile height rounded up to a multiple of up.
    const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up);                // For allocations only, to avoid shared memory read overruns with up=2 and up=4.

    // Merge 1x1 downsampling into last upsampling step for upf1 and ups2.
    const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD));

    // Sizes of logical buffers.
    const int szIn    = tileInH_up * tileInW;
    const int szUpX   = tileInH_up * tileUpW;
    const int szUpXY  = downInline ? 0 : (tileUpH * tileUpW);
    const int szDownX = tileUpH * tileOutW;

    // Sizes for shared memory arrays.
    const int s_buf0_size_base =
        (filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) :
        (filterMode == MODE_FUSD) ? MAX(szIn, szDownX) :
        (filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) :
        (filterMode == MODE_FUFD) ? szIn :
        -1;
    const int s_buf1_size_base =
        (filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) :
        (filterMode == MODE_FUSD) ? szUpXY :
        (filterMode == MODE_SUFD) ? szUpX :
        (filterMode == MODE_FUFD) ? szUpXY :
        -1;

    // Ensure U128 alignment.
    const int s_buf0_size = (s_buf0_size_base + 3) & ~3;
    const int s_buf1_size = (s_buf1_size_base + 3) & ~3;

    // Check at compile time that we don't use too much shared memory.
    static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow");

    // Declare shared memory arrays.
    scalar_t* s_buf0;
    scalar_t* s_buf1;
    if (sharedKB <= 48)
    {
        // Allocate shared memory arrays here.
        __shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused.
        s_buf0 = s_buf0_st;
        s_buf1 = s_buf0 + s_buf0_size;
    }
    else
    {
        // Use the dynamically allocated shared memory array.
        s_buf0 = (scalar_t*)s_buf_raw;
        s_buf1 = s_buf0 + s_buf0_size;
    }

    // Pointers to the buffers.
    scalar_t* s_tileIn;    // Input tile:                    [relInX * tileInH + relInY]
    scalar_t* s_tileUpX;   // After horizontal upsampling:   [relInY * tileUpW + relUpX]
    scalar_t* s_tileUpXY;  // After upsampling:              [relUpY * tileUpW + relUpX]
    scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX]
    if (filterMode == MODE_SUSD)
    {
        s_tileIn    = s_buf0;
        s_tileUpX   = s_buf1;
        s_tileUpXY  = s_buf0;
        s_tileDownX = s_buf1;
    }
    else if (filterMode == MODE_FUSD)
    {
        s_tileIn    = s_buf0;
        s_tileUpXY  = s_buf1;
        s_tileDownX = s_buf0;
    }
    else if (filterMode == MODE_SUFD)
    {
        s_tileIn    = s_buf0;
        s_tileUpX   = s_buf1;
        s_tileUpXY  = s_buf0;
    }
    else if (filterMode == MODE_FUFD)
    {
        s_tileIn    = s_buf0;
        s_tileUpXY  = s_buf1;
    }

    // Allow large grids in z direction via per-launch offset.
    int channelIdx = blockIdx.z + p.blockZofs;
    int batchIdx = channelIdx / p.yShape.z;
    channelIdx -= batchIdx * p.yShape.z;

    // Offset to output feature map. In bytes.
    index_t mapOfsOut = channelIdx * get_stride<index_t>(p.yStride.z) + batchIdx * get_stride<index_t>(p.yStride.w);

    // Sign shift amount.
    uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6;

    // Inner tile loop.
    #pragma unroll 1
    for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++)
    {
        // Locate output tile.
        int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x;
        int tileOutX = tileX * tileOutW;
        int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH;

        // Locate input tile.
        int tmpX = tileOutX * down - p.pad0.x;
        int tmpY = tileOutY * down - p.pad0.y;
        int tileInX = CEIL_DIV(tmpX, up);
        int tileInY = CEIL_DIV(tmpY, up);
        const int phaseInX = tileInX * up - tmpX;
        const int phaseInY = tileInY * up - tmpY;

        // Extra sync if input and output buffers are the same and we are not on first tile.
        if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline)))
            __syncthreads();

        // Load input tile & apply bias. Unrolled.
        scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride<index_t>(p.bStride)));
        index_t mapOfsIn = channelIdx * get_stride<index_t>(p.xStride.z) + batchIdx * get_stride<index_t>(p.xStride.w);
        int idx = threadIdx.x;
        const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock);
        #pragma unroll
        for (int loop = 0; loop < loopCountIN; loop++)
        {
            int relInX, relInY;
            fast_div_mod<tileInW>(relInX, relInY, idx);
            int inX = tileInX + relInX;
            int inY = tileInY + relInY;
            scalar_t v = 0;

            if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y)
                v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride<index_t>(p.xStride.x) + inY * get_stride<index_t>(p.xStride.y) + mapOfsIn))) + b;

            bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH);
            if (!skip)
                s_tileIn[idx] = v;

            idx += threadsPerBlock;
        }

        if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter.
        {
            // Horizontal upsampling.
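            // [Editor's note] The separable path below upsamples in two 1-D
            // passes; each thread emits `up` adjacent outputs per iteration
            // and picks filter taps by the tile phase (phaseInX/phaseInY),
            // i.e. a polyphase decomposition of the c_fu filter.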
__syncthreads(); if (up == 4) { for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up) { int relUpX0, relInY; fast_div_mod(relUpX0, relInY, idx); int relInX0 = relUpX0 / up; int src0 = relInX0 + tileInW * relInY; int dst = relInY * tileUpW + relUpX0; vec4_t v = InternalType::zero_vec4(); scalar_t a = s_tileIn[src0]; if (phaseInX == 0) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 0]; a = s_tileIn[src0 + step + 1]; v.y += a * (scalar_t)c_fu[step * up + 3]; v.z += a * (scalar_t)c_fu[step * up + 2]; v.w += a * (scalar_t)c_fu[step * up + 1]; } } else if (phaseInX == 1) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 1]; v.y += a * (scalar_t)c_fu[step * up + 0]; a = s_tileIn[src0 + step + 1]; v.z += a * (scalar_t)c_fu[step * up + 3]; v.w += a * (scalar_t)c_fu[step * up + 2]; } } else if (phaseInX == 2) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 2]; v.y += a * (scalar_t)c_fu[step * up + 1]; v.z += a * (scalar_t)c_fu[step * up + 0]; a = s_tileIn[src0 + step + 1]; v.w += a * (scalar_t)c_fu[step * up + 3]; } } else // (phaseInX == 3) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 3]; v.y += a * (scalar_t)c_fu[step * up + 2]; v.z += a * (scalar_t)c_fu[step * up + 1]; v.w += a * (scalar_t)c_fu[step * up + 0]; a = s_tileIn[src0 + step + 1]; } } s_tileUpX[dst+0] = v.x; s_tileUpX[dst+1] = v.y; s_tileUpX[dst+2] = v.z; s_tileUpX[dst+3] = v.w; } } else if (up == 2) { bool p0 = (phaseInX == 0); for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up) { int relUpX0, relInY; fast_div_mod(relUpX0, relInY, idx); int relInX0 = relUpX0 / up; int src0 = relInX0 + tileInW * relInY; int dst = relInY * tileUpW + relUpX0; vec2_t v = InternalType::zero_vec2(); scalar_t a = s_tileIn[src0]; if (p0) // (phaseInX == 0) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 0]; a = s_tileIn[src0 + step + 1]; v.y += a * (scalar_t)c_fu[step * up + 1]; } } else // (phaseInX == 1) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 1]; v.y += a * (scalar_t)c_fu[step * up + 0]; a = s_tileIn[src0 + step + 1]; } } s_tileUpX[dst+0] = v.x; s_tileUpX[dst+1] = v.y; } } // Vertical upsampling & nonlinearity. __syncthreads(); int groupMask = 15 << ((threadIdx.x & 31) & ~3); int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs. int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes. if (up == 4) { minY -= 3; // Adjust according to block height. 
for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x) { int relUpX, relInY0; fast_div_mod(relUpX, relInY0, idx); int relUpY0 = relInY0 * up; int src0 = relInY0 * tileUpW + relUpX; int dst = relUpY0 * tileUpW + relUpX; vec4_t v = InternalType::zero_vec4(); scalar_t a = s_tileUpX[src0]; if (phaseInY == 0) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 0]; a = s_tileUpX[src0 + (step + 1) * tileUpW]; v.y += a * (scalar_t)c_fu[step * up + 3]; v.z += a * (scalar_t)c_fu[step * up + 2]; v.w += a * (scalar_t)c_fu[step * up + 1]; } } else if (phaseInY == 1) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 1]; v.y += a * (scalar_t)c_fu[step * up + 0]; a = s_tileUpX[src0 + (step + 1) * tileUpW]; v.z += a * (scalar_t)c_fu[step * up + 3]; v.w += a * (scalar_t)c_fu[step * up + 2]; } } else if (phaseInY == 2) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 2]; v.y += a * (scalar_t)c_fu[step * up + 1]; v.z += a * (scalar_t)c_fu[step * up + 0]; a = s_tileUpX[src0 + (step + 1) * tileUpW]; v.w += a * (scalar_t)c_fu[step * up + 3]; } } else // (phaseInY == 3) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 3]; v.y += a * (scalar_t)c_fu[step * up + 2]; v.z += a * (scalar_t)c_fu[step * up + 1]; v.w += a * (scalar_t)c_fu[step * up + 0]; a = s_tileUpX[src0 + (step + 1) * tileUpW]; } } int x = tileOutX * down + relUpX; int y = tileOutY * down + relUpY0; int signX = x + p.sOfs.x; int signY = y + p.sOfs.y; int signZ = blockIdx.z + p.blockZofs; int signXb = signX >> 2; index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); index_t si1 = si0 + p.sShape.x; index_t si2 = si0 + p.sShape.x * 2; index_t si3 = si0 + p.sShape.x * 3; v.x *= (scalar_t)((float)up * (float)up * p.gain); v.y *= (scalar_t)((float)up * (float)up * p.gain); v.z *= (scalar_t)((float)up * (float)up * p.gain); v.w *= (scalar_t)((float)up * (float)up * p.gain); if (signWrite) { if (!enableWriteSkip) { // Determine and write signs. int sx = __float_as_uint(v.x) >> 31 << 0; int sy = __float_as_uint(v.y) >> 31 << 8; int sz = __float_as_uint(v.z) >> 31 << 16; int sw = __float_as_uint(v.w) >> 31 << 24; if (sx) v.x *= p.slope; if (sy) v.y *= p.slope; if (sz) v.z *= p.slope; if (sw) v.w *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); } if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); } if ((uint32_t)signXb < p.swLimit && signY >= minY) { // Combine signs. uint32_t s = sx + sy + sw + sz; s <<= (signX & 3) << 1; s |= __shfl_xor_sync(groupMask, s, 1); s |= __shfl_xor_sync(groupMask, s, 2); // Write signs. if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); } if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); } } } else { // Determine and write signs. 
if ((uint32_t)signXb < p.swLimit && signY >= minY) { int sx = __float_as_uint(v.x) >> 31 << 0; int sy = __float_as_uint(v.y) >> 31 << 8; int sz = __float_as_uint(v.z) >> 31 << 16; int sw = __float_as_uint(v.w) >> 31 << 24; if (sx) v.x *= p.slope; if (sy) v.y *= p.slope; if (sz) v.z *= p.slope; if (sw) v.w *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); } if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); } // Combine signs. uint32_t s = sx + sy + sw + sz; s <<= (signX & 3) << 1; s |= __shfl_xor_sync(groupMask, s, 1); s |= __shfl_xor_sync(groupMask, s, 2); // Write signs. if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); } if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); } } else { // Just compute the values. if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); } } } else if (signRead) // Read signs and apply. { if ((uint32_t)signXb < p.swLimit) { int ss = (signX & 3) << 1; if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; } if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; } if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; } if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; } } } else // Forward pass with no sign write. { if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); } s_tileUpXY[dst + 0 * tileUpW] = v.x; if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y; if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z; if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w; } } else if (up == 2) { minY -= 1; // Adjust according to block height. 
for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x) { int relUpX, relInY0; fast_div_mod(relUpX, relInY0, idx); int relUpY0 = relInY0 * up; int src0 = relInY0 * tileUpW + relUpX; int dst = relUpY0 * tileUpW + relUpX; vec2_t v = InternalType::zero_vec2(); scalar_t a = s_tileUpX[src0]; if (phaseInY == 0) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 0]; a = s_tileUpX[src0 + (step + 1) * tileUpW]; v.y += a * (scalar_t)c_fu[step * up + 1]; } } else // (phaseInY == 1) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 1]; v.y += a * (scalar_t)c_fu[step * up + 0]; a = s_tileUpX[src0 + (step + 1) * tileUpW]; } } int x = tileOutX * down + relUpX; int y = tileOutY * down + relUpY0; int signX = x + p.sOfs.x; int signY = y + p.sOfs.y; int signZ = blockIdx.z + p.blockZofs; int signXb = signX >> 2; index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); index_t si1 = si0 + p.sShape.x; v.x *= (scalar_t)((float)up * (float)up * p.gain); v.y *= (scalar_t)((float)up * (float)up * p.gain); if (signWrite) { if (!enableWriteSkip) { // Determine and write signs. int sx = __float_as_uint(v.x) >> 31 << 0; int sy = __float_as_uint(v.y) >> 31 << 8; if (sx) v.x *= p.slope; if (sy) v.y *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } if ((uint32_t)signXb < p.swLimit && signY >= minY) { // Combine signs. int s = sx + sy; s <<= signXo; s |= __shfl_xor_sync(groupMask, s, 1); s |= __shfl_xor_sync(groupMask, s, 2); // Write signs. if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } } } else { // Determine and write signs. if ((uint32_t)signXb < p.swLimit && signY >= minY) { int sx = __float_as_uint(v.x) >> 31 << 0; int sy = __float_as_uint(v.y) >> 31 << 8; if (sx) v.x *= p.slope; if (sy) v.y *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } // Combine signs. int s = sx + sy; s <<= signXo; s |= __shfl_xor_sync(groupMask, s, 1); s |= __shfl_xor_sync(groupMask, s, 2); // Write signs. if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } } else { // Just compute the values. if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); } } } else if (signRead) // Read signs and apply. { if ((uint32_t)signXb < p.swLimit) { if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; } if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; } } } else // Forward pass with no sign write. { if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); } if (!downInline) { // Write into temporary buffer. s_tileUpXY[dst] = v.x; if (relUpY0 < tileUpH - 1) s_tileUpXY[dst + tileUpW] = v.y; } else { // Write directly into output buffer. 
if ((uint32_t)x < p.yShape.x) { int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down); index_t ofs = x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut; if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]); if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]); } } } } } else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD) { // Full upsampling filter. if (up == 2) { // 2 x 2-wide. __syncthreads(); int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs. for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4) { int relUpX0, relUpY0; fast_div_mod(relUpX0, relUpY0, idx); int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up); int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up); int src0 = relInX0 + tileInW * relInY0; int tap0y = (relInY0 * up + phaseInY - relUpY0); #define X_LOOP(TAPY, PX) \ for (int sx = 0; sx < fuSize / up; sx++) \ { \ v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \ v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \ } vec4_t v = InternalType::zero_vec4(); if (tap0y == 0 && phaseInX == 0) #pragma unroll for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; #pragma unroll X_LOOP(0, 0) } if (tap0y == 0 && phaseInX == 1) #pragma unroll for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; #pragma unroll X_LOOP(0, 1) } if (tap0y == 1 && phaseInX == 0) #pragma unroll for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; #pragma unroll X_LOOP(1, 0) } if (tap0y == 1 && phaseInX == 1) #pragma unroll for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; #pragma unroll X_LOOP(1, 1) } #undef X_LOOP int x = tileOutX * down + relUpX0; int y = tileOutY * down + relUpY0; int signX = x + p.sOfs.x; int signY = y + p.sOfs.y; int signZ = blockIdx.z + p.blockZofs; int signXb = signX >> 2; index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); v.x *= (scalar_t)((float)up * (float)up * p.gain); v.y *= (scalar_t)((float)up * (float)up * p.gain); v.z *= (scalar_t)((float)up * (float)up * p.gain); v.w *= (scalar_t)((float)up * (float)up * p.gain); if (signWrite) { if (!enableWriteSkip) { // Determine and write signs. 
int sx = __float_as_uint(v.x) >> 31; int sy = __float_as_uint(v.y) >> 31; int sz = __float_as_uint(v.z) >> 31; int sw = __float_as_uint(v.w) >> 31; if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); } if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); } if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); } if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); } if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) { p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6); } } else { // Determine and write signs. if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) { int sx = __float_as_uint(v.x) >> 31; int sy = __float_as_uint(v.y) >> 31; int sz = __float_as_uint(v.z) >> 31; int sw = __float_as_uint(v.w) >> 31; if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); } if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); } if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); } if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); } p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6); } else { // Just compute the values. if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); } } } else if (signRead) // Read sign and apply. { if ((uint32_t)signY < p.sShape.y) { int s = 0; if ((uint32_t)signXb < p.swLimit) s = p.s[si]; if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8; s >>= (signX & 3) << 1; if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f; if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f; if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f; if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f; } } else // Forward pass with no sign write. { if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); } s_tileUpXY[idx + 0] = v.x; s_tileUpXY[idx + 1] = v.y; s_tileUpXY[idx + 2] = v.z; s_tileUpXY[idx + 3] = v.w; } } else if (up == 1) { __syncthreads(); uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3); int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs. for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x) { int relUpX0, relUpY0; fast_div_mod(relUpX0, relUpY0, idx); scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter. int x = tileOutX * down + relUpX0; int y = tileOutY * down + relUpY0; int signX = x + p.sOfs.x; int signY = y + p.sOfs.y; int signZ = blockIdx.z + p.blockZofs; int signXb = signX >> 2; index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); v *= (scalar_t)((float)up * (float)up * p.gain); if (signWrite) { if (!enableWriteSkip) { // Determine and write sign. 
uint32_t s = 0; uint32_t signXbit = (1u << signXo); if (v < 0.f) { s = signXbit; v *= p.slope; } if (fabsf(v) > p.clamp) { s = signXbit * 2; v = InternalType::clamp(v, p.clamp); } if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) { s += __shfl_xor_sync(groupMask, s, 1); // Coalesce. s += __shfl_xor_sync(groupMask, s, 2); // Coalesce. p.s[si] = s; // Write. } } else { // Determine and write sign. if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) { uint32_t s = 0; uint32_t signXbit = (1u << signXo); if (v < 0.f) { s = signXbit; v *= p.slope; } if (fabsf(v) > p.clamp) { s = signXbit * 2; v = InternalType::clamp(v, p.clamp); } s += __shfl_xor_sync(groupMask, s, 1); // Coalesce. s += __shfl_xor_sync(groupMask, s, 2); // Coalesce. p.s[si] = s; // Write. } else { // Just compute the value. if (v < 0.f) v *= p.slope; v = InternalType::clamp(v, p.clamp); } } } else if (signRead) { // Read sign and apply if within sign tensor bounds. if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y) { int s = p.s[si]; s >>= signXo; if (s & 1) v *= p.slope; if (s & 2) v = 0.f; } } else // Forward pass with no sign write. { if (v < 0.f) v *= p.slope; v = InternalType::clamp(v, p.clamp); } if (!downInline) // Write into temporary buffer. s_tileUpXY[idx] = v; else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer *((T*)((char*)p.y + (x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]); } } } // Downsampling. if (filterMode == MODE_SUSD || filterMode == MODE_FUSD) { // Horizontal downsampling. __syncthreads(); if (down == 4 && tileOutW % 4 == 0) { // Calculate 4 pixels at a time. for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4) { int relOutX0, relUpY; fast_div_mod(relOutX0, relUpY, idx); int relUpX0 = relOutX0 * down; int src0 = relUpY * tileUpW + relUpX0; vec4_t v = InternalType::zero_vec4(); #pragma unroll for (int step = 0; step < fdSize; step++) { v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step]; v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step]; v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step]; v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step]; } s_tileDownX[idx+0] = v.x; s_tileDownX[idx+1] = v.y; s_tileDownX[idx+2] = v.z; s_tileDownX[idx+3] = v.w; } } else if ((down == 2 || down == 4) && (tileOutW % 2 == 0)) { // Calculate 2 pixels at a time. for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2) { int relOutX0, relUpY; fast_div_mod(relOutX0, relUpY, idx); int relUpX0 = relOutX0 * down; int src0 = relUpY * tileUpW + relUpX0; vec2_t v = InternalType::zero_vec2(); #pragma unroll for (int step = 0; step < fdSize; step++) { v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step]; v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step]; } s_tileDownX[idx+0] = v.x; s_tileDownX[idx+1] = v.y; } } else { // Calculate 1 pixel at a time. for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x) { int relOutX0, relUpY; fast_div_mod(relOutX0, relUpY, idx); int relUpX0 = relOutX0 * down; int src = relUpY * tileUpW + relUpX0; scalar_t v = 0.f; #pragma unroll for (int step = 0; step < fdSize; step++) v += s_tileUpXY[src + step] * (scalar_t)c_fd[step]; s_tileDownX[idx] = v; } } // Vertical downsampling & store output tile. 
__syncthreads(); for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x) { int relOutX, relOutY0; fast_div_mod(relOutX, relOutY0, idx); int relUpY0 = relOutY0 * down; int src0 = relUpY0 * tileOutW + relOutX; scalar_t v = 0; #pragma unroll for (int step = 0; step < fdSize; step++) v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step]; int outX = tileOutX + relOutX; int outY = tileOutY + relOutY0; if (outX < p.yShape.x & outY < p.yShape.y) *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v; } } else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD) { // Full downsampling filter. if (down == 2) { // 2-wide. __syncthreads(); for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2) { int relOutX0, relOutY0; fast_div_mod(relOutX0, relOutY0, idx); int relUpX0 = relOutX0 * down; int relUpY0 = relOutY0 * down; int src0 = relUpY0 * tileUpW + relUpX0; vec2_t v = InternalType::zero_vec2(); #pragma unroll for (int sy = 0; sy < fdSize; sy++) #pragma unroll for (int sx = 0; sx < fdSize; sx++) { v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE]; v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE]; } int outX = tileOutX + relOutX0; int outY = tileOutY + relOutY0; if ((uint32_t)outY < p.yShape.y) { index_t ofs = outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut; if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x; if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride(p.yStride.x))) = (T)v.y; } } } else if (down == 1 && !downInline) { // Thread per pixel. __syncthreads(); for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x) { int relOutX0, relOutY0; fast_div_mod(relOutX0, relOutY0, idx); scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter. int outX = tileOutX + relOutX0; int outY = tileOutY + relOutY0; if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y) *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v; } } } if (!enableXrep) break; } } //------------------------------------------------------------------------ // Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. Used for accelerating the generic variant. // Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used. template static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p) { typedef typename InternalType::scalar_t scalar_t; // Indexing. int32_t x = threadIdx.x + blockIdx.x * blockDim.x; int32_t ymax = signWrite ? p.sShape.y : p.xShape.y; int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index. // Loop to accommodate oversized tensors. for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z) for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y) { // Extract z and w (channel, minibatch index). int32_t w = q / p.xShape.z; int32_t z = q - w * p.xShape.z; // Choose behavior based on sign read/write mode. if (signWrite) { // Process value if in p.x. uint32_t s = 0; if (x < p.xShape.x && y < p.xShape.y) { int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w; T* pv = ((T*)p.x) + ix; scalar_t v = (scalar_t)(*pv); // Gain, LReLU, clamp. v *= p.gain; if (v < 0.f) { v *= p.slope; s = 1; // Sign. 
} if (fabsf(v) > p.clamp) { v = InternalType::clamp(v, p.clamp); s = 2; // Clamp. } *pv = (T)v; // Write value. } // Coalesce into threads 0 and 16 of warp. uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu; s <<= ((threadIdx.x & 15) << 1); // Shift into place. s |= __shfl_xor_sync(m, s, 1); // Distribute. s |= __shfl_xor_sync(m, s, 2); s |= __shfl_xor_sync(m, s, 4); s |= __shfl_xor_sync(m, s, 8); // Write signs if leader and in p.s. if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in. { uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous. ((uint32_t*)p.s)[is >> 4] = s; } } else if (signRead) { // Process value if in p.x. if (x < p.xShape.x) // y is always in. { int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w; T* pv = ((T*)p.x) + ix; scalar_t v = (scalar_t)(*pv); v *= p.gain; // Apply sign buffer offset. uint32_t sx = x + p.sOfs.x; uint32_t sy = y + p.sOfs.y; // Read and apply signs if we land inside valid region of sign buffer. if (sx < p.sShape.x && sy < p.sShape.y) { uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous. unsigned char s = p.s[is]; s >>= (sx & 3) << 1; // Shift into place. if (s & 1) // Sign? v *= p.slope; if (s & 2) // Clamp? v = 0.f; } *pv = (T)v; // Write value. } } else { // Forward pass with no sign write. Process value if in p.x. if (x < p.xShape.x) // y is always in. { int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w; T* pv = ((T*)p.x) + ix; scalar_t v = (scalar_t)(*pv); v *= p.gain; if (v < 0.f) v *= p.slope; if (fabsf(v) > p.clamp) v = InternalType::clamp(v, p.clamp); *pv = (T)v; // Write value. } } } } template void* choose_filtered_lrelu_act_kernel(void) { return (void*)filtered_lrelu_act_kernel; } //------------------------------------------------------------------------ // CUDA kernel selection. template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB) { filtered_lrelu_kernel_spec s = { 0 }; // Return the first matching kernel. #define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \ if (sharedKB >= SH) \ if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \ if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \ if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \ { \ static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \ static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \ static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \ s.setup = (void*)setup_filters_kernel; \ s.exec = (void*)filtered_lrelu_kernel; \ s.tileOut = make_int2(TW, TH); \ s.numWarps = W; \ s.xrep = XR; \ s.dynamicSharedKB = (SH == 48) ? 0 : SH; \ return s; \ } // Launch parameters for various kernel specializations. // Small filters must be listed before large filters, otherwise the kernel for larger filter will always match first. // Kernels that use more shared memory must be listed before those that use less, for the same reason. 
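// Illustrative instance of the first-match rule (not part of the original source):
// a call with up = down = 2 and separable 8-tap filters satisfies "p.fuShape.x <= FU"
// for FU = 8, 12, and 16 alike, so the 8-tap row (4t-ups2-downs2) must precede the
// 12- and 16-tap rows below, otherwise a looser match would win and pay the larger
// filter's halo and shared-memory cost for taps it never uses.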
CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/1,1, /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64, 178, 32, 0, 0) // 1t-upf1-downf1 CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95, 16, 0, 0) // 4t-ups2-downf1 CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 22, 16, 0, 0) // 4t-upf1-downs2 CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 29, 16, 11, 0) // 4t-ups2-downs2 CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60, 28, 16, 0, 0) // 4t-upf2-downs2 CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 28, 16, 0, 0) // 4t-ups2-downf2 CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 31, 16, 11, 0) // 4t-ups4-downs2 CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 36, 16, 0, 0) // 4t-ups4-downf2 CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 22, 16, 12, 0) // 4t-ups2-downs4 CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29, 15, 16, 0, 0) // 4t-upf2-downs4 CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96, 150, 28, 0, 0) // 6t-ups2-downf1 CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 35, 24, 0, 0) // 6t-upf1-downs2 CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 16, 10, 0) // 6t-ups2-downs2 CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58, 28, 24, 8, 0) // 6t-upf2-downs2 CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52, 28, 16, 0, 0) // 6t-ups2-downf2 CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 51, 16, 5, 0) // 6t-ups4-downs2 CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 56, 16, 6, 0) // 6t-ups4-downf2 CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 18, 16, 12, 0) // 6t-ups2-downs4 CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 13, 24, 0, 0) // 6t-upf2-downs4 CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89, 24, 0, 0) // 8t-ups2-downf1 CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 31, 16, 5, 0) // 8t-upf1-downs2 CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 41, 16, 9, 0) // 8t-ups2-downs2 CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 26, 24, 0, 0) // 8t-upf2-downs2 CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 40, 16, 0, 0) // 8t-ups2-downf2 CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 24, 5, 0) // 8t-ups4-downs2 CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, 
/*tw,th,warps,xrep,wskip*/32, 50, 16, 0, 0) // 8t-ups4-downf2
    CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB
    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 13, 16, 10, 1) // 8t-ups2-downs4
    CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB
    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 10, 24, 0, 0) // 8t-upf2-downs4

    #undef CASE
    return s; // No kernel found.
}

//------------------------------------------------------------------------

================================================ FILE: stylegan_human/torch_utils/ops/filtered_lrelu.h ================================================

// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <cuda_runtime.h>

//------------------------------------------------------------------------
// CUDA kernel parameters.

struct filtered_lrelu_kernel_params
{
    // These parameters decide which kernel to use.
    int             up;         // upsampling ratio (1, 2, 4)
    int             down;       // downsampling ratio (1, 2, 4)
    int2            fuShape;    // [size, 1] | [size, size]
    int2            fdShape;    // [size, 1] | [size, size]

    int             _dummy;     // Alignment.

    // Rest of the parameters.
    const void*     x;          // Input tensor.
    void*           y;          // Output tensor.
    const void*     b;          // Bias tensor.
    unsigned char*  s;          // Sign tensor in/out. NULL if unused.
    const float*    fu;         // Upsampling filter.
    const float*    fd;         // Downsampling filter.

    int2            pad0;       // Left/top padding.
    float           gain;       // Additional gain factor.
    float           slope;      // Leaky ReLU slope on negative side.
    float           clamp;      // Clamp after nonlinearity.
    int             flip;       // Filter kernel flip for gradient computation.

    int             tilesXdim;  // Original number of horizontal output tiles.
    int             tilesXrep;  // Number of horizontal tiles per CTA.
    int             blockZofs;  // Block z offset to support large minibatch, channel dimensions.

    int4            xShape;     // [width, height, channel, batch]
    int4            yShape;     // [width, height, channel, batch]
    int2            sShape;     // [width, height] - width is in bytes. Contiguous. Zeros if unused.
    int2            sOfs;       // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
    int             swLimit;    // Active width of sign tensor in bytes.

    longlong4       xStride;    // Strides of all tensors except signs, same component order as shapes.
    longlong4       yStride;    //
    int64_t         bStride;    //
    longlong3       fuStride;   //
    longlong3       fdStride;   //
};

struct filtered_lrelu_act_kernel_params
{
    void*           x;          // Input/output, modified in-place.
    unsigned char*  s;          // Sign tensor in/out. NULL if unused.

    float           gain;       // Additional gain factor.
    float           slope;      // Leaky ReLU slope on negative side.
    float           clamp;      // Clamp after nonlinearity.

    int4            xShape;     // [width, height, channel, batch]
    longlong4       xStride;    // Input/output tensor strides, same order as in shape.
    int2            sShape;     // [width, height] - width is in elements. Contiguous. Zeros if unused.
    int2            sOfs;       // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
};

//------------------------------------------------------------------------
// CUDA kernel specialization.

struct filtered_lrelu_kernel_spec
{
    void*   setup;              // Function for filter kernel setup.
    void*   exec;               // Function for main operation.
    int2    tileOut;            // Width/height of launch tile.
    int     numWarps;           // Number of warps per thread block, determines launch block size.
    int     xrep;               // For processing multiple horizontal tiles per thread block.
    int     dynamicSharedKB;    // How much dynamic shared memory the exec kernel wants.
};

//------------------------------------------------------------------------
// CUDA kernel selection.

template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);

//------------------------------------------------------------------------

================================================ FILE: stylegan_human/torch_utils/ops/filtered_lrelu.py ================================================

# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import numpy as np
import torch
import warnings

from .. import custom_ops
from .. import misc
from . import upfirdn2d
from . import bias_act

#----------------------------------------------------------------------------

_plugin = None

def _init():
    global _plugin
    if _plugin is None:
        # sources=['filtered_lrelu.h', 'filtered_lrelu.cu', 'filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu']
        # sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
        # try:
        #     _plugin = custom_ops.get_plugin('filtered_lrelu_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'])
        # except:
        #     warnings.warn('Failed to build CUDA kernels for filtered_lrelu_plugin. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
        _plugin = custom_ops.get_plugin_v3(
            module_name='filtered_lrelu_plugin',
            sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],
            headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
            source_dir=os.path.dirname(__file__),
            extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'],
        )
    return True

def _get_filter_size(f):
    if f is None:
        return 1, 1
    assert isinstance(f, torch.Tensor)
    assert 1 <= f.ndim <= 2
    return f.shape[-1], f.shape[0] # width, height

def _parse_padding(padding):
    if isinstance(padding, int):
        padding = [padding, padding]
    assert isinstance(padding, (list, tuple))
    assert all(isinstance(x, (int, np.integer)) for x in padding)
    padding = [int(x) for x in padding]
    if len(padding) == 2:
        px, py = padding
        padding = [px, px, py, py]
    px0, px1, py0, py1 = padding
    return px0, px1, py0, py1

#----------------------------------------------------------------------------

def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'):
    r"""Filtered leaky ReLU for a batch of 2D images.

    Performs the following sequence of operations for each channel:

    1. Add channel-specific bias if provided (`b`).

    2. Upsample the image by inserting N-1 zeros after each pixel (`up`).
    3. Pad the image with the specified number of zeros on each side (`padding`).
       Negative padding corresponds to cropping the image.

    4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it
       so that the footprint of all output pixels lies within the input image.

    5. Multiply each value by the provided gain factor (`gain`).

    6. Apply leaky ReLU activation function to each value.

    7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided.

    8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
       it so that the footprint of all output pixels lies within the input image.

    9. Downsample the image by keeping every Nth pixel (`down`).

    The fused op is considerably more efficient than performing the same calculation
    using standard PyTorch ops. It supports gradients of arbitrary order.

    Args:
        x:           Float32/float16/float64 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        fu:          Float32 upsampling FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or `None` (identity).
        fd:          Float32 downsampling FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or `None` (identity).
        b:           Bias vector, or `None` to disable. Must be a 1D tensor of the same
                     type as `x`. The length of the vector must match the channel
                     dimension of `x`.
        up:          Integer upsampling factor (default: 1).
        down:        Integer downsampling factor (default: 1).
        padding:     Padding with respect to the upsampled image. Can be a single number
                     or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                     (default: 0).
        gain:        Overall scaling factor for signal magnitude (default: sqrt(2)).
        slope:       Slope on the negative side of leaky ReLU (default: 0.2).
        clamp:       Maximum magnitude for leaky ReLU output (default: None).
        flip_filter: False = convolution, True = correlation (default: False).
        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    if impl == 'cuda' and x.device.type == 'cuda' and _init():
        return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
    return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)

#----------------------------------------------------------------------------

@misc.profiled_function
def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
    """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
    existing `upfirdn2d()` and `bias_act()` ops.
    """
    assert isinstance(x, torch.Tensor) and x.ndim == 4
    fu_w, fu_h = _get_filter_size(fu)
    fd_w, fd_h = _get_filter_size(fd)
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
        misc.assert_shape(b, [x.shape[1]])
    assert isinstance(up, int) and up >= 1
    assert isinstance(down, int) and down >= 1
    px0, px1, py0, py1 = _parse_padding(padding)
    assert gain == float(gain) and gain > 0
    assert slope == float(slope) and slope >= 0
    assert clamp is None or (clamp == float(clamp) and clamp >= 0)

    # Calculate output size.
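    # Worked example (illustrative values, not from the source): for a 64x64 input with
    # up = down = 2, 12-tap separable filters (fu_w = fu_h = fd_w = fd_h = 12), and
    # padding px0 = px1 = py0 = py1 = 5, the formulas below give
    #   out_w = (64*2 + (5+5) - (12-1) - (12-1) + (2-1)) // 2 = 117 // 2 = 58
    # and likewise out_h = 58.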
    batch_size, channels, in_h, in_w = x.shape
    in_dtype = x.dtype
    out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
    out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down

    # Compute using existing ops.
    x = bias_act.bias_act(x=x, b=b) # Apply bias.
    x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
    x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp.
    x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.

    # Check output shape & dtype.
    misc.assert_shape(x, [batch_size, channels, out_h, out_w])
    assert x.dtype == in_dtype
    return x

#----------------------------------------------------------------------------

_filtered_lrelu_cuda_cache = dict()

def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
    """Fast CUDA implementation of `filtered_lrelu()` using custom ops.
    """
    assert isinstance(up, int) and up >= 1
    assert isinstance(down, int) and down >= 1
    px0, px1, py0, py1 = _parse_padding(padding)
    assert gain == float(gain) and gain > 0
    gain = float(gain)
    assert slope == float(slope) and slope >= 0
    slope = float(slope)
    assert clamp is None or (clamp == float(clamp) and clamp >= 0)
    clamp = float(clamp if clamp is not None else 'inf')

    # Lookup from cache.
    key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
    if key in _filtered_lrelu_cuda_cache:
        return _filtered_lrelu_cuda_cache[key]

    # Forward op.
    class FilteredLReluCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ
            assert isinstance(x, torch.Tensor) and x.ndim == 4

            # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).
            if fu is None:
                fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)
            if fd is None:
                fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)
            assert 1 <= fu.ndim <= 2
            assert 1 <= fd.ndim <= 2

            # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.
            if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:
                fu = fu.square()[None]
            if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:
                fd = fd.square()[None]

            # Missing sign input tensor.
            if si is None:
                si = torch.empty([0])

            # Missing bias tensor.
            if b is None:
                b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)

            # Construct internal sign tensor only if gradients are needed.
            write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad)

            # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout.
            strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
            if any(a < b for a, b in zip(strides[:-1], strides[1:])):
                warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning)

            # Call C++/Cuda plugin if datatype is supported.
            if x.dtype in [torch.float16, torch.float32]:
                if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device):
                    warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning)
                y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs)
            else:
                return_code = -1

            # No Cuda kernel found? Fall back to generic implementation. Still more memory
            # efficient than the reference implementation because only the bit-packed sign
            # tensor is retained for gradient computation.
            if return_code < 0:
                warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning)
                y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.
                y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
                so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place.
                y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample.

            # Prepare for gradient computation.
            ctx.save_for_backward(fu, fd, (si if si.numel() else so))
            ctx.x_shape = x.shape
            ctx.y_shape = y.shape
            ctx.s_ofs = sx, sy
            return y

        @staticmethod
        def backward(ctx, dy): # pylint: disable=arguments-differ
            fu, fd, si = ctx.saved_tensors
            _, _, xh, xw = ctx.x_shape
            _, _, yh, yw = ctx.y_shape
            sx, sy = ctx.s_ofs
            dx  = None # 0
            dfu = None; assert not ctx.needs_input_grad[1]
            dfd = None; assert not ctx.needs_input_grad[2]
            db  = None # 3
            dsi = None; assert not ctx.needs_input_grad[4]
            dsx = None; assert not ctx.needs_input_grad[5]
            dsy = None; assert not ctx.needs_input_grad[6]

            if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
                pp = [
                    (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,
                    xw * up - yw * down + px0 - (up - 1),
                    (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,
                    xh * up - yh * down + py0 - (up - 1),
                ]
                gg = gain * (up ** 2) / (down ** 2)
                ff = (not flip_filter)
                sx = sx - (fu.shape[-1] - 1) + px0
                sy = sy - (fu.shape[0] - 1) + py0
                dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy)

            if ctx.needs_input_grad[3]:
                db = dx.sum([0, 2, 3])

            return dx, dfu, dfd, db, dsi, dsx, dsy

    # Add to cache.
    _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
    return FilteredLReluCuda

#----------------------------------------------------------------------------

================================================ FILE: stylegan_human/torch_utils/ops/filtered_lrelu_ns.cu ================================================

// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include "filtered_lrelu.cu"

// Template/kernel specializations for no signs mode (no gradients required).

// Full op, 32-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);

// Full op, 64-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);

// Activation/signs only for generic variant. 64-bit indexing.
template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
template void* choose_filtered_lrelu_act_kernel<float,     false, false>(void);
template void* choose_filtered_lrelu_act_kernel<double,    false, false>(void);

// Copy filters to constant memory.
template cudaError_t copy_filters<false, false>(cudaStream_t stream);

================================================ FILE: stylegan_human/torch_utils/ops/filtered_lrelu_rd.cu ================================================

// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include "filtered_lrelu.cu"

// Template/kernel specializations for sign read mode.

// Full op, 32-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);

// Full op, 64-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);

// Activation/signs only for generic variant. 64-bit indexing.
template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
template void* choose_filtered_lrelu_act_kernel<float,     false, true>(void);
template void* choose_filtered_lrelu_act_kernel<double,    false, true>(void);

// Copy filters to constant memory.
template cudaError_t copy_filters<false, true>(cudaStream_t stream);

================================================ FILE: stylegan_human/torch_utils/ops/filtered_lrelu_wr.cu ================================================

// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include "filtered_lrelu.cu"

// Template/kernel specializations for sign write mode.

// Full op, 32-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);

// Full op, 64-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);

// Activation/signs only for generic variant. 64-bit indexing.
template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
template void* choose_filtered_lrelu_act_kernel<float,     true, false>(void);
template void* choose_filtered_lrelu_act_kernel<double,    true, false>(void);

// Copy filters to constant memory.
template cudaError_t copy_filters<true, false>(cudaStream_t stream);

================================================ FILE: stylegan_human/torch_utils/ops/fma.py ================================================

# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" import torch #---------------------------------------------------------------------------- def fma(a, b, c): # => a * b + c return _FusedMultiplyAdd.apply(a, b, c) #---------------------------------------------------------------------------- class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c @staticmethod def forward(ctx, a, b, c): # pylint: disable=arguments-differ out = torch.addcmul(c, a, b) ctx.save_for_backward(a, b) ctx.c_shape = c.shape return out @staticmethod def backward(ctx, dout): # pylint: disable=arguments-differ a, b = ctx.saved_tensors c_shape = ctx.c_shape da = None db = None dc = None if ctx.needs_input_grad[0]: da = _unbroadcast(dout * b, a.shape) if ctx.needs_input_grad[1]: db = _unbroadcast(dout * a, b.shape) if ctx.needs_input_grad[2]: dc = _unbroadcast(dout, c_shape) return da, db, dc #---------------------------------------------------------------------------- def _unbroadcast(x, shape): extra_dims = x.ndim - len(shape) assert extra_dims >= 0 dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] if len(dim): x = x.sum(dim=dim, keepdim=True) if extra_dims: x = x.reshape(-1, *x.shape[extra_dims+1:]) assert x.shape == shape return x #---------------------------------------------------------------------------- ================================================ FILE: stylegan_human/torch_utils/ops/grid_sample_gradfix.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Custom replacement for `torch.nn.functional.grid_sample` that supports arbitrarily high order gradients between the input and output. Only works on 2D images and assumes `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" import warnings import torch # pylint: disable=redefined-builtin # pylint: disable=arguments-differ # pylint: disable=protected-access #---------------------------------------------------------------------------- enabled = False # Enable the custom op by setting this to true. 
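# Usage sketch (illustrative, kept as a comment so the module is unchanged at import
# time; tensor shapes and the import path are hypothetical). grid_sample() below is a
# drop-in for torch.nn.functional.grid_sample with bilinear/zeros/align_corners=False
# semantics, but with support for higher-order gradients:
#
#   import torch
#   from stylegan_human.torch_utils.ops import grid_sample_gradfix
#
#   grid_sample_gradfix.enabled = True                   # opt in to the custom op
#   img = torch.randn(1, 3, 8, 8, requires_grad=True)    # [N, C, H, W]
#   theta = torch.eye(2, 3).unsqueeze(0)                 # identity affine transform
#   grid = torch.nn.functional.affine_grid(theta, size=(1, 3, 8, 8), align_corners=False)
#   out = grid_sample_gradfix.grid_sample(img, grid)     # matches img up to interpolation
#   out.sum().backward()                                 # gradients flow back to img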
#----------------------------------------------------------------------------

def grid_sample(input, grid):
    if _should_use_custom_op():
        return _GridSample2dForward.apply(input, grid)
    return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)

#----------------------------------------------------------------------------

def _should_use_custom_op():
    if not enabled:
        return False
    if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
        return True
    warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
    return False

#----------------------------------------------------------------------------

class _GridSample2dForward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, grid):
        assert input.ndim == 4
        assert grid.ndim == 4
        output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
        ctx.save_for_backward(input, grid)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, grid = ctx.saved_tensors
        grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
        return grad_input, grad_grid

#----------------------------------------------------------------------------

class _GridSample2dBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, grad_output, input, grid):
        op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
        grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
        ctx.save_for_backward(grid)
        return grad_input, grad_grid

    @staticmethod
    def backward(ctx, grad2_grad_input, grad2_grad_grid):
        _ = grad2_grad_grid # unused
        grid, = ctx.saved_tensors
        grad2_grad_output = None
        grad2_input = None
        grad2_grid = None

        if ctx.needs_input_grad[0]:
            grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)

        assert not ctx.needs_input_grad[2]
        return grad2_grad_output, grad2_input, grad2_grid

#----------------------------------------------------------------------------

================================================ FILE: stylegan_human/torch_utils/ops/upfirdn2d.cpp ================================================

// Copyright (c) SenseTime Research. All rights reserved.
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "upfirdn2d.h"

//------------------------------------------------------------------------

static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
{
    // Validate arguments.
    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
    TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
    TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
    TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
    TORCH_CHECK(x.dim() == 4, "x must be rank 4");
    TORCH_CHECK(f.dim() == 2, "f must be rank 2");
    TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
    TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
    TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");

    // Create output tensor.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
    int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
    int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
    TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
    torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
    TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");

    // Initialize CUDA kernel parameters.
    upfirdn2d_kernel_params p;
    p.x             = x.data_ptr();
    p.f             = f.data_ptr<float>();
    p.y             = y.data_ptr();
    p.up            = make_int2(upx, upy);
    p.down          = make_int2(downx, downy);
    p.pad0          = make_int2(padx0, pady0);
    p.flip          = (flip) ? 1 : 0;
    p.gain          = gain;
    p.inSize        = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
    p.inStride      = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
    p.filterSize    = make_int2((int)f.size(1), (int)f.size(0));
    p.filterStride  = make_int2((int)f.stride(1), (int)f.stride(0));
    p.outSize       = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
    p.outStride     = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
    p.sizeMajor     = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
    p.sizeMinor     = (p.inStride.z == 1) ? p.inSize.z : 1;

    // Choose CUDA kernel.
    upfirdn2d_kernel_spec spec;
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
    {
        spec = choose_upfirdn2d_kernel<scalar_t>(p);
    });

    // Set looping options.
    p.loopMajor     = (p.sizeMajor - 1) / 16384 + 1;
    p.loopMinor     = spec.loopMinor;
    p.loopX         = spec.loopX;
    p.launchMinor   = (p.sizeMinor - 1) / p.loopMinor + 1;
    p.launchMajor   = (p.sizeMajor - 1) / p.loopMajor + 1;

    // Compute grid size.
    dim3 blockSize, gridSize;
    if (spec.tileOutW < 0) // large
    {
        blockSize = dim3(4, 32, 1);
        gridSize = dim3(
            ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
            (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
            p.launchMajor);
    }
    else // small
    {
        blockSize = dim3(256, 1, 1);
        gridSize = dim3(
            ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
            (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
            p.launchMajor);
    }

    // Launch CUDA kernel.
    void* args[] = {&p};
    AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
    return y;
}

//------------------------------------------------------------------------

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("upfirdn2d", &upfirdn2d);
}

//------------------------------------------------------------------------


================================================
FILE: stylegan_human/torch_utils/ops/upfirdn2d.cu
================================================
// Copyright (c) SenseTime Research. All rights reserved.

// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <c10/util/Half.h>
#include "upfirdn2d.h"

//------------------------------------------------------------------------
// Helpers.

template <class T> struct InternalType;
template <> struct InternalType<double>    { typedef double scalar_t; };
template <> struct InternalType<float>     { typedef float  scalar_t; };
template <> struct InternalType<c10::Half> { typedef float  scalar_t; };

static __device__ __forceinline__ int floor_div(int a, int b)
{
    int t = 1 - a / b;
    return (a + t * b) / b - t;
}

//------------------------------------------------------------------------
// Generic CUDA implementation for large filters.

template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;

    // Calculate thread index.
    int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
    int outY = minorBase / p.launchMinor;
    minorBase -= outY * p.launchMinor;
    int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
    int majorBase = blockIdx.z * p.loopMajor;
    if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
        return;

    // Setup Y receptive field.
    int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
    int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
    int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
    int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
    if (p.flip)
        filterY = p.filterSize.y - 1 - filterY;

    // Loop over major, minor, and X.
    for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
    for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
    {
        int nc = major * p.sizeMinor + minor;
        int n = nc / p.inSize.z;
        int c = nc - n * p.inSize.z;
        for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
        {
            // Setup X receptive field.
            int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
            int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
            int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
            int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
            if (p.flip)
                filterX = p.filterSize.x - 1 - filterX;

            // Initialize pointers.
            const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
            const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
            int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
            int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;

            // Inner loop.
            scalar_t v = 0;
            for (int y = 0; y < h; y++)
            {
                for (int x = 0; x < w; x++)
                {
                    v += (scalar_t)(*xp) * (scalar_t)(*fp);
                    xp += p.inStride.x;
                    fp += filterStepX;
                }
                xp += p.inStride.y - w * p.inStride.x;
                fp += filterStepY - w * filterStepX;
            }

            // Store result.
            v *= p.gain;
            ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
        }
    }
}

//------------------------------------------------------------------------
// Specialized CUDA implementation for small filters.
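// Annotation (not in the original source): the specialization below stages the
// filter taps (sf) and an input tile (sx) in shared memory, so each input
// pixel is fetched from global memory once per tile rather than once per
// output tap. Up/down factors, tile size, and maximum filter extent are
// template parameters, which lets the inner loops over the filter footprint
// be fully unrolled at compile time.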
template <class T, int upx, int upy, int downx, int downy, int tileOutW, int tileOutH, int loopMinor, int filterH, int filterW>
static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;
    const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
    const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
    __shared__ volatile scalar_t sf[filterH][filterW];
    __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];

    // Calculate tile index.
    int minorBase = blockIdx.x;
    int tileOutY = minorBase / p.launchMinor;
    minorBase -= tileOutY * p.launchMinor;
    minorBase *= loopMinor;
    tileOutY *= tileOutH;
    int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
    int majorBase = blockIdx.z * p.loopMajor;
    if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
        return;

    // Load filter (flipped).
    for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
    {
        int fy = tapIdx / filterW;
        int fx = tapIdx - fy * filterW;
        scalar_t v = 0;
        if (fx < p.filterSize.x & fy < p.filterSize.y)
        {
            int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
            int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
            v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
        }
        sf[fy][fx] = v;
    }

    // Loop over major and X.
    for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
    {
        int baseNC = major * p.sizeMinor + minorBase;
        int n = baseNC / p.inSize.z;
        int baseC = baseNC - n * p.inSize.z;
        for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
        {
            // Load input pixels.
            int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
            int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
            int tileInX = floor_div(tileMidX, upx);
            int tileInY = floor_div(tileMidY, upy);
            __syncthreads();
            for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
            {
                int relC = inIdx;
                int relInX = relC / loopMinor;
                int relInY = relInX / tileInW;
                relC -= relInX * loopMinor;
                relInX -= relInY * tileInW;
                int c = baseC + relC;
                int inX = tileInX + relInX;
                int inY = tileInY + relInY;
                scalar_t v = 0;
                if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
                    v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
                sx[relInY][relInX][relC] = v;
            }

            // Loop over output pixels.
            __syncthreads();
            for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
            {
                int relC = outIdx;
                int relOutX = relC / loopMinor;
                int relOutY = relOutX / tileOutW;
                relC -= relOutX * loopMinor;
                relOutX -= relOutY * tileOutW;
                int c = baseC + relC;
                int outX = tileOutX + relOutX;
                int outY = tileOutY + relOutY;

                // Setup receptive field.
                int midX = tileMidX + relOutX * downx;
                int midY = tileMidY + relOutY * downy;
                int inX = floor_div(midX, upx);
                int inY = floor_div(midY, upy);
                int relInX = inX - tileInX;
                int relInY = inY - tileInY;
                int filterX = (inX + 1) * upx - midX - 1; // flipped
                int filterY = (inY + 1) * upy - midY - 1; // flipped

                // Inner loop.
                if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
                {
                    scalar_t v = 0;
                    #pragma unroll
                    for (int y = 0; y < filterH / upy; y++)
                        #pragma unroll
                        for (int x = 0; x < filterW / upx; x++)
                            v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
                    v *= p.gain;
                    ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
                }
            }
        }
    }
}

//------------------------------------------------------------------------
// CUDA kernel selection.
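// Annotation (not in the original source): selection below keys off the
// channel stride p.inStride.z — a stride of 1 indicates channels_last layout,
// anything else is treated as contiguous — together with the up/down factors
// and the filter extent. Small separable cases map to compile-time
// instantiations of upfirdn2d_kernel_small; everything else keeps the generic
// upfirdn2d_kernel_large fallback chosen at the top of the function.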
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
{
    int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;

    upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
    if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last

    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
    {
        if (fx <= 7  && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 64,16,1, 7,7>,   64,16,1,  1};
        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 64,16,1, 6,6>,   64,16,1,  1};
        if (fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 64,16,1, 5,5>,   64,16,1,  1};
        if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 64,16,1, 4,4>,   64,16,1,  1};
        if (fx <= 3  && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 64,16,1, 3,3>,   64,16,1,  1};
        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 128,8,1, 1,24>,  128,8,1,  1};
        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 128,8,1, 1,20>,  128,8,1,  1};
        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 128,8,1, 1,16>,  128,8,1,  1};
        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 128,8,1, 1,12>,  128,8,1,  1};
        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 128,8,1, 1,8>,   128,8,1,  1};
        if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 32,32,1, 24,1>,  32,32,1,  1};
        if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 32,32,1, 20,1>,  32,32,1,  1};
        if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 32,32,1, 16,1>,  32,32,1,  1};
        if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 32,32,1, 12,1>,  32,32,1,  1};
        if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 32,32,1, 8,1>,   32,32,1,  1};
    }
    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
    {
        if (fx <= 7  && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16,8, 7,7>,   16,16,8,  1};
        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16,8, 6,6>,   16,16,8,  1};
        if (fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16,8, 5,5>,   16,16,8,  1};
        if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16,8, 4,4>,   16,16,8,  1};
        if (fx <= 3  && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16,8, 3,3>,   16,16,8,  1};
        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 128,1,16, 1,24>, 128,1,16, 1};
        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 128,1,16, 1,20>, 128,1,16, 1};
        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 128,1,16, 1,16>, 128,1,16, 1};
        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 128,1,16, 1,12>, 128,1,16, 1};
        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 128,1,16, 1,8>,  128,1,16, 1};
        if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,128,16, 24,1>, 1,128,16, 1};
        if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,128,16, 20,1>, 1,128,16, 1};
        if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,128,16, 16,1>, 1,128,16, 1};
        if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,128,16, 12,1>, 1,128,16, 1};
        if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,128,16, 8,1>,  1,128,16, 1};
    }
    if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
    {
        if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 64,16,1, 8,8>, 64,16,1, 1};
        if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 64,16,1, 6,6>, 64,16,1, 1};
        if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 64,16,1, 4,4>, 64,16,1, 1};
        if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 64,16,1, 2,2>, 64,16,1, 1};
    }
    if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
    {
        if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16,8, 8,8>, 16,16,8, 1};
        if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16,8, 6,6>, 16,16,8, 1};
        if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16,8, 4,4>, 16,16,8, 1};
        if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16,8, 2,2>, 16,16,8, 1};
    }
    if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
    {
        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 128,8,1, 1,24>, 128,8,1, 1};
        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 128,8,1, 1,20>, 128,8,1, 1};
        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 128,8,1, 1,16>, 128,8,1, 1};
        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 128,8,1, 1,12>, 128,8,1, 1};
        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 128,8,1, 1,8>,  128,8,1, 1};
    }
    if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
    {
        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 128,1,16, 1,24>, 128,1,16, 1};
        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 128,1,16, 1,20>, 128,1,16, 1};
        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 128,1,16, 1,16>, 128,1,16, 1};
        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 128,1,16, 1,12>, 128,1,16, 1};
        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 128,1,16, 1,8>,  128,1,16, 1};
    }
    if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
    {
        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 32,32,1, 24,1>, 32,32,1, 1};
        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 32,32,1, 20,1>, 32,32,1, 1};
        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 32,32,1, 16,1>, 32,32,1, 1};
        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 32,32,1, 12,1>, 32,32,1, 1};
        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 32,32,1, 8,1>,  32,32,1, 1};
    }
    if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
    {
        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,128,16, 24,1>, 1,128,16, 1};
        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,128,16, 20,1>, 1,128,16, 1};
        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,128,16, 16,1>, 1,128,16, 1};
        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,128,16, 12,1>, 1,128,16, 1};
        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,128,16, 8,1>,  1,128,16, 1};
    }
    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
    {
        if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 32,8,1, 8,8>, 32,8,1, 1};
        if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 32,8,1, 6,6>, 32,8,1, 1};
        if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 32,8,1, 4,4>, 32,8,1, 1};
        if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 32,8,1, 2,2>, 32,8,1, 1};
    }
    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
    {
        if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8,8, 8,8>, 8,8,8, 1};
        if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8,8, 6,6>, 8,8,8, 1};
        if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8,8, 4,4>, 8,8,8, 1};
        if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8,8, 2,2>, 8,8,8, 1};
    }
    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
    {
        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 64,8,1, 1,24>, 64,8,1, 1};
        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 64,8,1, 1,20>, 64,8,1, 1};
        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 64,8,1, 1,16>, 64,8,1, 1};
        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 64,8,1, 1,12>, 64,8,1, 1};
        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 64,8,1, 1,8>,  64,8,1, 1};
    }
    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
    {
        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 64,1,8, 1,24>, 64,1,8, 1};
        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 64,1,8, 1,20>, 64,1,8, 1};
        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 64,1,8, 1,16>, 64,1,8, 1};
        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 64,1,8, 1,12>, 64,1,8, 1};
        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 64,1,8, 1,8>,  64,1,8, 1};
    }
    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
    {
        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 32,16,1, 24,1>, 32,16,1, 1};
        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 32,16,1, 20,1>, 32,16,1, 1};
        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 32,16,1, 16,1>, 32,16,1, 1};
        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 32,16,1, 12,1>, 32,16,1, 1};
        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 32,16,1, 8,1>,  32,16,1, 1};
    }
    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
    {
        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,64,8, 24,1>, 1,64,8, 1};
        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,64,8, 20,1>, 1,64,8, 1};
        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,64,8, 16,1>, 1,64,8, 1};
        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,64,8, 12,1>, 1,64,8, 1};
        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,64,8, 8,1>,  1,64,8, 1};
    }
    return spec;
}

//------------------------------------------------------------------------
// Template specializations.

template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double>   (const upfirdn2d_kernel_params& p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float>    (const upfirdn2d_kernel_params& p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);

//------------------------------------------------------------------------


================================================
FILE: stylegan_human/torch_utils/ops/upfirdn2d.h
================================================
// Copyright (c) SenseTime Research. All rights reserved.

// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <cuda_runtime.h>

//------------------------------------------------------------------------
// CUDA kernel parameters.

struct upfirdn2d_kernel_params
{
    const void*     x;
    const float*    f;
    void*           y;

    int2            up;
    int2            down;
    int2            pad0;
    int             flip;
    float           gain;

    int4            inSize;         // [width, height, channel, batch]
    int4            inStride;
    int2            filterSize;     // [width, height]
    int2            filterStride;
    int4            outSize;        // [width, height, channel, batch]
    int4            outStride;
    int             sizeMinor;
    int             sizeMajor;

    int             loopMinor;
    int             loopMajor;
    int             loopX;
    int             launchMinor;
    int             launchMajor;
};

//------------------------------------------------------------------------
// CUDA kernel specialization.

struct upfirdn2d_kernel_spec
{
    void*   kernel;
    int     tileOutW;
    int     tileOutH;
    int     loopMinor;
    int     loopX;
};

//------------------------------------------------------------------------
// CUDA kernel selection.

template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);

//------------------------------------------------------------------------


================================================
FILE: stylegan_human/torch_utils/ops/upfirdn2d.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Custom PyTorch ops for efficient resampling of 2D images.""" import os import warnings import numpy as np import torch import traceback from .. import custom_ops from .. import misc from . import conv2d_gradfix #---------------------------------------------------------------------------- _inited = False _plugin = None def _init(): global _inited, _plugin if not _inited: sources = ['upfirdn2d.cpp', 'upfirdn2d.cu'] sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] try: _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) except: warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) return _plugin is not None def _parse_scaling(scaling): if isinstance(scaling, int): scaling = [scaling, scaling] assert isinstance(scaling, (list, tuple)) assert all(isinstance(x, int) for x in scaling) sx, sy = scaling assert sx >= 1 and sy >= 1 return sx, sy def _parse_padding(padding): if isinstance(padding, int): padding = [padding, padding] assert isinstance(padding, (list, tuple)) assert all(isinstance(x, int) for x in padding) if len(padding) == 2: padx, pady = padding padding = [padx, padx, pady, pady] padx0, padx1, pady0, pady1 = padding return padx0, padx1, pady0, pady1 def _get_filter_size(f): if f is None: return 1, 1 assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] fw = f.shape[-1] fh = f.shape[0] with misc.suppress_tracer_warnings(): fw = int(fw) fh = int(fh) misc.assert_shape(f, [fh, fw][:f.ndim]) assert fw >= 1 and fh >= 1 return fw, fh #---------------------------------------------------------------------------- def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. Args: f: Torch tensor, numpy array, or python list of the shape `[filter_height, filter_width]` (non-separable), `[filter_taps]` (separable), `[]` (impulse), or `None` (identity). device: Result device (default: cpu). normalize: Normalize the filter so that it retains the magnitude for constant input signal (DC)? (default: True). flip_filter: Flip the filter? (default: False). gain: Overall scaling factor for signal magnitude (default: 1). separable: Return a separable filter? (default: select automatically). Returns: Float32 tensor of the shape `[filter_height, filter_width]` (non-separable) or `[filter_taps]` (separable). """ # Validate. if f is None: f = 1 f = torch.as_tensor(f, dtype=torch.float32) assert f.ndim in [0, 1, 2] assert f.numel() > 0 if f.ndim == 0: f = f[np.newaxis] # Separable? if separable is None: separable = (f.ndim == 1 and f.numel() >= 8) if f.ndim == 1 and not separable: f = f.ger(f) assert f.ndim == (1 if separable else 2) # Apply normalize, flip, gain, and device. 
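    # Annotation (not in the original file): the exponent in
    # `gain ** (f.ndim / 2)` below keeps the overall gain layout-independent —
    # a separable (1D) filter is applied twice by upfirdn2d(), once per axis,
    # so each pass carries sqrt(gain) and the two passes compose to the
    # requested total.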
if normalize: f /= f.sum() if flip_filter: f = f.flip(list(range(f.ndim))) f = f * (gain ** (f.ndim / 2)) f = f.to(device=device) return f #---------------------------------------------------------------------------- def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): r"""Pad, upsample, filter, and downsample a batch of 2D images. Performs the following sequence of operations for each channel: 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). 2. Pad the image with the specified number of zeros on each side (`padding`). Negative padding corresponds to cropping the image. 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it so that the footprint of all output pixels lies within the input image. 4. Downsample the image by keeping every Nth pixel (`down`). This sequence of operations bears close resemblance to scipy.signal.upfirdn(). The fused op is considerably more efficient than performing the same calculation using standard PyTorch ops. It supports gradients of arbitrary order. Args: x: Float32/float64/float16 input tensor of the shape `[batch_size, num_channels, in_height, in_width]`. f: Float32 FIR filter of the shape `[filter_height, filter_width]` (non-separable), `[filter_taps]` (separable), or `None` (identity). up: Integer upsampling factor. Can be a single int or a list/tuple `[x, y]` (default: 1). down: Integer downsampling factor. Can be a single int or a list/tuple `[x, y]` (default: 1). padding: Padding with respect to the upsampled image. Can be a single number or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` (default: 0). flip_filter: False = convolution, True = correlation (default: False). gain: Overall scaling factor for signal magnitude (default: 1). impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). Returns: Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. """ assert isinstance(x, torch.Tensor) assert impl in ['ref', 'cuda'] if impl == 'cuda' and x.device.type == 'cuda' and _init(): return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) #---------------------------------------------------------------------------- @misc.profiled_function def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. """ # Validate arguments. assert isinstance(x, torch.Tensor) and x.ndim == 4 if f is None: f = torch.ones([1, 1], dtype=torch.float32, device=x.device) assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] assert f.dtype == torch.float32 and not f.requires_grad batch_size, num_channels, in_height, in_width = x.shape upx, upy = _parse_scaling(up) downx, downy = _parse_scaling(down) padx0, padx1, pady0, pady1 = _parse_padding(padding) # Upsample by inserting zeros. x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) # Pad or crop. x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] # Setup filter. 
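    # Annotation (not in the original file): conv2d computes cross-correlation,
    # so the filter is flipped below when flip_filter=False in order to realize
    # a true convolution; flip_filter=True skips the flip and yields
    # correlation, matching the docstring of upfirdn2d().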
f = f * (gain ** (f.ndim / 2)) f = f.to(x.dtype) if not flip_filter: f = f.flip(list(range(f.ndim))) # Convolve with the filter. f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) if f.ndim == 4: x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) else: x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) # Downsample by throwing away pixels. x = x[:, :, ::downy, ::downx] return x #---------------------------------------------------------------------------- _upfirdn2d_cuda_cache = dict() def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): """Fast CUDA implementation of `upfirdn2d()` using custom ops. """ # Parse arguments. upx, upy = _parse_scaling(up) downx, downy = _parse_scaling(down) padx0, padx1, pady0, pady1 = _parse_padding(padding) # Lookup from cache. key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) if key in _upfirdn2d_cuda_cache: return _upfirdn2d_cuda_cache[key] # Forward op. class Upfirdn2dCuda(torch.autograd.Function): @staticmethod def forward(ctx, x, f): # pylint: disable=arguments-differ assert isinstance(x, torch.Tensor) and x.ndim == 4 if f is None: f = torch.ones([1, 1], dtype=torch.float32, device=x.device) assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] y = x if f.ndim == 2: y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) else: y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain)) y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain)) ctx.save_for_backward(f) ctx.x_shape = x.shape return y @staticmethod def backward(ctx, dy): # pylint: disable=arguments-differ f, = ctx.saved_tensors _, _, ih, iw = ctx.x_shape _, _, oh, ow = dy.shape fw, fh = _get_filter_size(f) p = [ fw - padx0 - 1, iw * upx - ow * downx + padx0 - upx + 1, fh - pady0 - 1, ih * upy - oh * downy + pady0 - upy + 1, ] dx = None df = None if ctx.needs_input_grad[0]: dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) assert not ctx.needs_input_grad[1] return dx, df # Add to cache. _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda return Upfirdn2dCuda #---------------------------------------------------------------------------- def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): r"""Filter a batch of 2D images using the given 2D FIR filter. By default, the result is padded so that its shape matches the input. User-specified padding is applied on top of that, with negative values indicating cropping. Pixels outside the image are assumed to be zero. Args: x: Float32/float64/float16 input tensor of the shape `[batch_size, num_channels, in_height, in_width]`. f: Float32 FIR filter of the shape `[filter_height, filter_width]` (non-separable), `[filter_taps]` (separable), or `None` (identity). padding: Padding with respect to the output. Can be a single number or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` (default: 0). flip_filter: False = convolution, True = correlation (default: False). gain: Overall scaling factor for signal magnitude (default: 1). impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). Returns: Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 
""" padx0, padx1, pady0, pady1 = _parse_padding(padding) fw, fh = _get_filter_size(f) p = [ padx0 + fw // 2, padx1 + (fw - 1) // 2, pady0 + fh // 2, pady1 + (fh - 1) // 2, ] return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) #---------------------------------------------------------------------------- def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): r"""Upsample a batch of 2D images using the given 2D FIR filter. By default, the result is padded so that its shape is a multiple of the input. User-specified padding is applied on top of that, with negative values indicating cropping. Pixels outside the image are assumed to be zero. Args: x: Float32/float64/float16 input tensor of the shape `[batch_size, num_channels, in_height, in_width]`. f: Float32 FIR filter of the shape `[filter_height, filter_width]` (non-separable), `[filter_taps]` (separable), or `None` (identity). up: Integer upsampling factor. Can be a single int or a list/tuple `[x, y]` (default: 1). padding: Padding with respect to the output. Can be a single number or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` (default: 0). flip_filter: False = convolution, True = correlation (default: False). gain: Overall scaling factor for signal magnitude (default: 1). impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). Returns: Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. """ upx, upy = _parse_scaling(up) padx0, padx1, pady0, pady1 = _parse_padding(padding) fw, fh = _get_filter_size(f) p = [ padx0 + (fw + upx - 1) // 2, padx1 + (fw - upx) // 2, pady0 + (fh + upy - 1) // 2, pady1 + (fh - upy) // 2, ] return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl) #---------------------------------------------------------------------------- def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): r"""Downsample a batch of 2D images using the given 2D FIR filter. By default, the result is padded so that its shape is a fraction of the input. User-specified padding is applied on top of that, with negative values indicating cropping. Pixels outside the image are assumed to be zero. Args: x: Float32/float64/float16 input tensor of the shape `[batch_size, num_channels, in_height, in_width]`. f: Float32 FIR filter of the shape `[filter_height, filter_width]` (non-separable), `[filter_taps]` (separable), or `None` (identity). down: Integer downsampling factor. Can be a single int or a list/tuple `[x, y]` (default: 1). padding: Padding with respect to the input. Can be a single number or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` (default: 0). flip_filter: False = convolution, True = correlation (default: False). gain: Overall scaling factor for signal magnitude (default: 1). impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). Returns: Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 
""" downx, downy = _parse_scaling(down) padx0, padx1, pady0, pady1 = _parse_padding(padding) fw, fh = _get_filter_size(f) p = [ padx0 + (fw - downx + 1) // 2, padx1 + (fw - downx) // 2, pady0 + (fh - downy + 1) // 2, pady1 + (fh - downy) // 2, ] return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) #---------------------------------------------------------------------------- ================================================ FILE: stylegan_human/torch_utils/persistence.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Facilities for pickling Python code alongside other data. The pickled code is automatically imported into a separate Python module during unpickling. This way, any previously exported pickles will remain usable even if the original code is no longer available, or if the current version of the code is not consistent with what was originally pickled.""" import sys import pickle import io import inspect import copy import uuid import types import dnnlib #---------------------------------------------------------------------------- _version = 6 # internal version number _decorators = set() # {decorator_class, ...} _import_hooks = [] # [hook_function, ...] _module_to_src_dict = dict() # {module: src, ...} _src_to_module_dict = dict() # {src: module, ...} #---------------------------------------------------------------------------- def persistent_class(orig_class): r"""Class decorator that extends a given class to save its source code when pickled. Example: from torch_utils import persistence @persistence.persistent_class class MyNetwork(torch.nn.Module): def __init__(self, num_inputs, num_outputs): super().__init__() self.fc = MyLayer(num_inputs, num_outputs) ... @persistence.persistent_class class MyLayer(torch.nn.Module): ... When pickled, any instance of `MyNetwork` and `MyLayer` will save its source code alongside other internal state (e.g., parameters, buffers, and submodules). This way, any previously exported pickle will remain usable even if the class definitions have been modified or are no longer available. The decorator saves the source code of the entire Python module containing the decorated class. It does *not* save the source code of any imported modules. Thus, the imported modules must be available during unpickling, also including `torch_utils.persistence` itself. It is ok to call functions defined in the same module from the decorated class. However, if the decorated class depends on other classes defined in the same module, they must be decorated as well. This is illustrated in the above example in the case of `MyLayer`. It is also possible to employ the decorator just-in-time before calling the constructor. For example: cls = MyLayer if want_to_make_it_persistent: cls = persistence.persistent_class(cls) layer = cls(num_inputs, num_outputs) As an additional feature, the decorator also keeps track of the arguments that were used to construct each instance of the decorated class. 
    The arguments can be queried via `obj.init_args` and `obj.init_kwargs`,
    and they are automatically pickled alongside other object state. A
    typical use case is to first unpickle a previous instance of a persistent
    class, and then upgrade it to use the latest version of the source code:

        with open('old_pickle.pkl', 'rb') as f:
            old_net = pickle.load(f)
        new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
        misc.copy_params_and_buffers(old_net, new_net, require_all=True)
    """
    assert isinstance(orig_class, type)
    if is_persistent(orig_class):
        return orig_class

    assert orig_class.__module__ in sys.modules
    orig_module = sys.modules[orig_class.__module__]
    orig_module_src = _module_to_src(orig_module)

    class Decorator(orig_class):
        _orig_module_src = orig_module_src
        _orig_class_name = orig_class.__name__

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self._init_args = copy.deepcopy(args)
            self._init_kwargs = copy.deepcopy(kwargs)
            assert orig_class.__name__ in orig_module.__dict__
            _check_pickleable(self.__reduce__())

        @property
        def init_args(self):
            return copy.deepcopy(self._init_args)

        @property
        def init_kwargs(self):
            return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))

        def __reduce__(self):
            fields = list(super().__reduce__())
            fields += [None] * max(3 - len(fields), 0)
            if fields[0] is not _reconstruct_persistent_obj:
                meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
                fields[0] = _reconstruct_persistent_obj # reconstruct func
                fields[1] = (meta,) # reconstruct args
                fields[2] = None # state dict
            return tuple(fields)

    Decorator.__name__ = orig_class.__name__
    _decorators.add(Decorator)
    return Decorator

#----------------------------------------------------------------------------

def is_persistent(obj):
    r"""Test whether the given object or class is persistent, i.e.,
    whether it will save its source code when pickled.
    """
    try:
        if obj in _decorators:
            return True
    except TypeError:
        pass
    return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck

#----------------------------------------------------------------------------

def import_hook(hook):
    r"""Register an import hook that is called whenever a persistent object
    is being unpickled. A typical use case is to patch the pickled source
    code to avoid errors and inconsistencies when the API of some imported
    module has changed.

    The hook should have the following signature:

        hook(meta) -> modified meta

    `meta` is an instance of `dnnlib.EasyDict` with the following fields:

        type:       Type of the persistent object, e.g. `'class'`.
        version:    Internal version number of `torch_utils.persistence`.
        module_src: Original source code of the Python module.
        class_name: Class name in the original Python module.
        state:      Internal state of the object.

    Example:

        @persistence.import_hook
        def wreck_my_network(meta):
            if meta.class_name == 'MyNetwork':
                print('MyNetwork is being imported. I will wreck it!')
                meta.module_src = meta.module_src.replace("True", "False")
            return meta
    """
    assert callable(hook)
    _import_hooks.append(hook)

#----------------------------------------------------------------------------

def _reconstruct_persistent_obj(meta):
    r"""Hook that is called internally by the `pickle` module to unpickle
    a persistent object.
""" meta = dnnlib.EasyDict(meta) meta.state = dnnlib.EasyDict(meta.state) for hook in _import_hooks: meta = hook(meta) assert meta is not None assert meta.version == _version module = _src_to_module(meta.module_src) assert meta.type == 'class' orig_class = module.__dict__[meta.class_name] decorator_class = persistent_class(orig_class) obj = decorator_class.__new__(decorator_class) setstate = getattr(obj, '__setstate__', None) if callable(setstate): setstate(meta.state) # pylint: disable=not-callable else: obj.__dict__.update(meta.state) return obj #---------------------------------------------------------------------------- def _module_to_src(module): r"""Query the source code of a given Python module. """ src = _module_to_src_dict.get(module, None) if src is None: src = inspect.getsource(module) _module_to_src_dict[module] = src _src_to_module_dict[src] = module return src def _src_to_module(src): r"""Get or create a Python module for the given source code. """ module = _src_to_module_dict.get(src, None) if module is None: module_name = "_imported_module_" + uuid.uuid4().hex module = types.ModuleType(module_name) sys.modules[module_name] = module _module_to_src_dict[module] = src _src_to_module_dict[src] = module exec(src, module.__dict__) # pylint: disable=exec-used return module #---------------------------------------------------------------------------- def _check_pickleable(obj): r"""Check that the given object is pickleable, raising an exception if it is not. This function is expected to be considerably more efficient than actually pickling the object. """ def recurse(obj): if isinstance(obj, (list, tuple, set)): return [recurse(x) for x in obj] if isinstance(obj, dict): return [[recurse(x), recurse(y)] for x, y in obj.items()] if isinstance(obj, (str, int, float, bool, bytes, bytearray)): return None # Python primitive types are pickleable. if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']: return None # NumPy arrays and PyTorch tensors are pickleable. if is_persistent(obj): return None # Persistent objects are pickleable, by virtue of the constructor check. return obj with io.BytesIO() as f: pickle.dump(recurse(obj), f) #---------------------------------------------------------------------------- ================================================ FILE: stylegan_human/torch_utils/training_stats.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Facilities for reporting and collecting training statistics across multiple processes and devices. The interface is designed to minimize synchronization overhead as well as the amount of boilerplate in user code.""" import re import numpy as np import torch import dnnlib from . import misc #---------------------------------------------------------------------------- _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. _counter_dtype = torch.float64 # Data type to use for the internal counters. _rank = 0 # Rank of the current process. 
_sync_device = None # Device to use for multiprocess communication. None = single-process. _sync_called = False # Has _sync() been called yet? _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor #---------------------------------------------------------------------------- def init_multiprocessing(rank, sync_device): r"""Initializes `torch_utils.training_stats` for collecting statistics across multiple processes. This function must be called after `torch.distributed.init_process_group()` and before `Collector.update()`. The call is not necessary if multi-process collection is not needed. Args: rank: Rank of the current process. sync_device: PyTorch device to use for inter-process communication, or None to disable multi-process collection. Typically `torch.device('cuda', rank)`. """ global _rank, _sync_device assert not _sync_called _rank = rank _sync_device = sync_device #---------------------------------------------------------------------------- @misc.profiled_function def report(name, value): r"""Broadcasts the given set of scalars to all interested instances of `Collector`, across device and process boundaries. This function is expected to be extremely cheap and can be safely called from anywhere in the training loop, loss function, or inside a `torch.nn.Module`. Warning: The current implementation expects the set of unique names to be consistent across processes. Please make sure that `report()` is called at least once for each unique name by each process, and in the same order. If a given process has no scalars to broadcast, it can do `report(name, [])` (empty list). Args: name: Arbitrary string specifying the name of the statistic. Averages are accumulated separately for each unique name. value: Arbitrary set of scalars. Can be a list, tuple, NumPy array, PyTorch tensor, or Python scalar. Returns: The same `value` that was passed in. """ if name not in _counters: _counters[name] = dict() elems = torch.as_tensor(value) if elems.numel() == 0: return value elems = elems.detach().flatten().to(_reduce_dtype) moments = torch.stack([ torch.ones_like(elems).sum(), elems.sum(), elems.square().sum(), ]) assert moments.ndim == 1 and moments.shape[0] == _num_moments moments = moments.to(_counter_dtype) device = moments.device if device not in _counters[name]: _counters[name][device] = torch.zeros_like(moments) _counters[name][device].add_(moments) return value #---------------------------------------------------------------------------- def report0(name, value): r"""Broadcasts the given set of scalars by the first process (`rank = 0`), but ignores any scalars provided by the other processes. See `report()` for further details. """ report(name, value if _rank == 0 else []) return value #---------------------------------------------------------------------------- class Collector: r"""Collects the scalars broadcasted by `report()` and `report0()` and computes their long-term averages (mean and standard deviation) over user-defined periods of time. The averages are first collected into internal counters that are not directly visible to the user. They are then copied to the user-visible state as a result of calling `update()` and can then be queried using `mean()`, `std()`, `as_dict()`, etc. 
Calling `update()` also resets the internal counters for the next round, so that the user-visible state effectively reflects averages collected between the last two calls to `update()`. Args: regex: Regular expression defining which statistics to collect. The default is to collect everything. keep_previous: Whether to retain the previous averages if no scalars were collected on a given round (default: True). """ def __init__(self, regex='.*', keep_previous=True): self._regex = re.compile(regex) self._keep_previous = keep_previous self._cumulative = dict() self._moments = dict() self.update() self._moments.clear() def names(self): r"""Returns the names of all statistics broadcasted so far that match the regular expression specified at construction time. """ return [name for name in _counters if self._regex.fullmatch(name)] def update(self): r"""Copies current values of the internal counters to the user-visible state and resets them for the next round. If `keep_previous=True` was specified at construction time, the operation is skipped for statistics that have received no scalars since the last update, retaining their previous averages. This method performs a number of GPU-to-CPU transfers and one `torch.distributed.all_reduce()`. It is intended to be called periodically in the main training loop, typically once every N training steps. """ if not self._keep_previous: self._moments.clear() for name, cumulative in _sync(self.names()): if name not in self._cumulative: self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) delta = cumulative - self._cumulative[name] self._cumulative[name].copy_(cumulative) if float(delta[0]) != 0: self._moments[name] = delta def _get_delta(self, name): r"""Returns the raw moments that were accumulated for the given statistic between the last two calls to `update()`, or zero if no scalars were collected. """ assert self._regex.fullmatch(name) if name not in self._moments: self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) return self._moments[name] def num(self, name): r"""Returns the number of scalars that were accumulated for the given statistic between the last two calls to `update()`, or zero if no scalars were collected. """ delta = self._get_delta(name) return int(delta[0]) def mean(self, name): r"""Returns the mean of the scalars that were accumulated for the given statistic between the last two calls to `update()`, or NaN if no scalars were collected. """ delta = self._get_delta(name) if int(delta[0]) == 0: return float('nan') return float(delta[1] / delta[0]) def std(self, name): r"""Returns the standard deviation of the scalars that were accumulated for the given statistic between the last two calls to `update()`, or NaN if no scalars were collected. """ delta = self._get_delta(name) if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): return float('nan') if int(delta[0]) == 1: return float(0) mean = float(delta[1] / delta[0]) raw_var = float(delta[2] / delta[0]) return np.sqrt(max(raw_var - np.square(mean), 0)) def as_dict(self): r"""Returns the averages accumulated between the last two calls to `update()` as an `dnnlib.EasyDict`. The contents are as follows: dnnlib.EasyDict( NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), ... ) """ stats = dnnlib.EasyDict() for name in self.names(): stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) return stats def __getitem__(self, name): r"""Convenience getter. `collector[name]` is a synonym for `collector.mean(name)`. 
""" return self.mean(name) #---------------------------------------------------------------------------- def _sync(names): r"""Synchronize the global cumulative counters across devices and processes. Called internally by `Collector.update()`. """ if len(names) == 0: return [] global _sync_called _sync_called = True # Collect deltas within current rank. deltas = [] device = _sync_device if _sync_device is not None else torch.device('cpu') for name in names: delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) for counter in _counters[name].values(): delta.add_(counter.to(device)) counter.copy_(torch.zeros_like(counter)) deltas.append(delta) deltas = torch.stack(deltas) # Sum deltas across ranks. if _sync_device is not None: torch.distributed.all_reduce(deltas) # Update cumulative values. deltas = deltas.cpu() for idx, name in enumerate(names): if name not in _cumulative: _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) _cumulative[name].add_(deltas[idx]) # Return name-value pairs. return [(name, _cumulative[name]) for name in names] #---------------------------------------------------------------------------- ================================================ FILE: stylegan_human/training/__init__.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. # empty ================================================ FILE: stylegan_human/training/augment.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Augmentation pipeline from the paper "Training Generative Adversarial Networks with Limited Data". Matches the original implementation by Karras et al. at https://github.com/NVlabs/stylegan2-ada/blob/main/training/augment.py""" import numpy as np import scipy.signal import torch from torch_utils import persistence from torch_utils import misc from torch_utils.ops import upfirdn2d from torch_utils.ops import grid_sample_gradfix from torch_utils.ops import conv2d_gradfix #---------------------------------------------------------------------------- # Coefficients of various wavelet decomposition low-pass filters. 
wavelets = { 'haar': [0.7071067811865476, 0.7071067811865476], 'db1': [0.7071067811865476, 0.7071067811865476], 'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025], 'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569], 'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523], 'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125], 'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017], 'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236], 'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161], 'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025], 'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569], 'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427], 'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728], 'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148], 'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255], 'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609], } #---------------------------------------------------------------------------- # Helpers for constructing transformation matrices. 
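# Usage sketch (annotation, not part of the original file): the helpers below
# build batched homogeneous matrices that AugmentPipe composes into a single
# inverse transform, mixing Python scalars and tensors freely, e.g.:
#
#   theta = torch.rand([8]) * np.pi                          # hypothetical batch of angles
#   G_inv = translate2d_inv(0.5, 0.5) @ rotate2d_inv(theta)  # -> shape [8, 3, 3]
#
# Scalar arguments are broadcast to the batch shape of any tensor arguments
# via misc.constant().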
def matrix(*rows, device=None): assert all(len(row) == len(rows[0]) for row in rows) elems = [x for row in rows for x in row] ref = [x for x in elems if isinstance(x, torch.Tensor)] if len(ref) == 0: return misc.constant(np.asarray(rows), device=device) assert device is None or device == ref[0].device elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems] return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1)) def translate2d(tx, ty, **kwargs): return matrix( [1, 0, tx], [0, 1, ty], [0, 0, 1], **kwargs) def translate3d(tx, ty, tz, **kwargs): return matrix( [1, 0, 0, tx], [0, 1, 0, ty], [0, 0, 1, tz], [0, 0, 0, 1], **kwargs) def scale2d(sx, sy, **kwargs): return matrix( [sx, 0, 0], [0, sy, 0], [0, 0, 1], **kwargs) def scale3d(sx, sy, sz, **kwargs): return matrix( [sx, 0, 0, 0], [0, sy, 0, 0], [0, 0, sz, 0], [0, 0, 0, 1], **kwargs) def rotate2d(theta, **kwargs): return matrix( [torch.cos(theta), torch.sin(-theta), 0], [torch.sin(theta), torch.cos(theta), 0], [0, 0, 1], **kwargs) def rotate3d(v, theta, **kwargs): vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2] s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c return matrix( [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0], [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0], [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0], [0, 0, 0, 1], **kwargs) def translate2d_inv(tx, ty, **kwargs): return translate2d(-tx, -ty, **kwargs) def scale2d_inv(sx, sy, **kwargs): return scale2d(1 / sx, 1 / sy, **kwargs) def rotate2d_inv(theta, **kwargs): return rotate2d(-theta, **kwargs) #---------------------------------------------------------------------------- # Versatile image augmentation pipeline from the paper # "Training Generative Adversarial Networks with Limited Data". # # All augmentations are disabled by default; individual augmentations can # be enabled by setting their probability multipliers to 1. @persistence.persistent_class class AugmentPipe(torch.nn.Module): def __init__(self, xflip=0, rotate90=0, xint=0, xint_max=0.125, scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125, brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, hue_max=1, saturation_std=1, imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1, noise=0, cutout=0, noise_std=0.1, cutout_size=0.5, ): super().__init__() self.register_buffer('p', torch.ones([])) # Overall multiplier for augmentation probability. # Pixel blitting. self.xflip = float(xflip) # Probability multiplier for x-flip. self.rotate90 = float(rotate90) # Probability multiplier for 90 degree rotations. self.xint = float(xint) # Probability multiplier for integer translation. self.xint_max = float(xint_max) # Range of integer translation, relative to image dimensions. # General geometric transformations. self.scale = float(scale) # Probability multiplier for isotropic scaling. self.rotate = float(rotate) # Probability multiplier for arbitrary rotation. self.aniso = float(aniso) # Probability multiplier for anisotropic scaling. self.xfrac = float(xfrac) # Probability multiplier for fractional translation. self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling. self.rotate_max = float(rotate_max) # Range of arbitrary rotation, 1 = full circle. self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling. 
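# Added note: parameters documented as "Log2 standard deviation" are sampled
# as 2 ** N(0, std); e.g. scale_std = 0.2 yields one-sigma scale factors of
# roughly 2 ** (+/-0.2), i.e. about 0.87 to 1.15.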
self.xfrac_std = float(xfrac_std) # Standard deviation of fractional translation, relative to image dimensions. # Color transformations. self.brightness = float(brightness) # Probability multiplier for brightness. self.contrast = float(contrast) # Probability multiplier for contrast. self.lumaflip = float(lumaflip) # Probability multiplier for luma flip. self.hue = float(hue) # Probability multiplier for hue rotation. self.saturation = float(saturation) # Probability multiplier for saturation. self.brightness_std = float(brightness_std) # Standard deviation of brightness. self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast. self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle. self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation. # Image-space filtering. self.imgfilter = float(imgfilter) # Probability multiplier for image-space filtering. self.imgfilter_bands = list(imgfilter_bands) # Probability multipliers for individual frequency bands. self.imgfilter_std = float(imgfilter_std) # Log2 standard deviation of image-space filter amplification. # Image-space corruptions. self.noise = float(noise) # Probability multiplier for additive RGB noise. self.cutout = float(cutout) # Probability multiplier for cutout. self.noise_std = float(noise_std) # Standard deviation of additive RGB noise. self.cutout_size = float(cutout_size) # Size of the cutout rectangle, relative to image dimensions. # Set up the orthogonal lowpass filter for geometric augmentations. self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6'])) # Construct filter bank for image-space filtering. Hz_lo = np.asarray(wavelets['sym2']) # H(z) Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z) Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2 Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2 Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i) for i in range(1, Hz_fbank.shape[0]): Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1] Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2]) Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2 self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32)) def forward(self, images, debug_percentile=None): assert isinstance(images, torch.Tensor) and images.ndim == 4 batch_size, num_channels, height, width = images.shape device = images.device if debug_percentile is not None: debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device) # ------------------------------------- # Select parameters for pixel blitting. # ------------------------------------- # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in I_3 = torch.eye(3, device=device) G_inv = I_3 # Apply x-flip with probability (xflip * strength). if self.xflip > 0: i = torch.floor(torch.rand([batch_size], device=device) * 2) i = torch.where(torch.rand([batch_size], device=device) < self.xflip * self.p, i, torch.zeros_like(i)) if debug_percentile is not None: i = torch.full_like(i, torch.floor(debug_percentile * 2)) G_inv = G_inv @ scale2d_inv(1 - 2 * i, 1) # Apply 90 degree rotations with probability (rotate90 * strength).
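# Added note: this and each selection below share one pattern: draw a random
# parameter per sample, keep it only where a uniform draw falls under
# (multiplier * self.p), and substitute the identity value (zero rotation,
# unit scale, zero translation) everywhere else.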
if self.rotate90 > 0: i = torch.floor(torch.rand([batch_size], device=device) * 4) i = torch.where(torch.rand([batch_size], device=device) < self.rotate90 * self.p, i, torch.zeros_like(i)) if debug_percentile is not None: i = torch.full_like(i, torch.floor(debug_percentile * 4)) G_inv = G_inv @ rotate2d_inv(-np.pi / 2 * i) # Apply integer translation with probability (xint * strength). if self.xint > 0: t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t)) if debug_percentile is not None: t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max) G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height)) # -------------------------------------------------------- # Select parameters for general geometric transformations. # -------------------------------------------------------- # Apply isotropic scaling with probability (scale * strength). if self.scale > 0: s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std) s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s)) if debug_percentile is not None: s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std)) G_inv = G_inv @ scale2d_inv(s, s) # Apply pre-rotation with probability p_rot. p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p if self.rotate > 0: theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta)) if debug_percentile is not None: theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max) G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling. # Apply anisotropic scaling with probability (aniso * strength). if self.aniso > 0: s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std) s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s)) if debug_percentile is not None: s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std)) G_inv = G_inv @ scale2d_inv(s, 1 / s) # Apply post-rotation with probability p_rot. if self.rotate > 0: theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta)) if debug_percentile is not None: theta = torch.zeros_like(theta) G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling. # Apply fractional translation with probability (xfrac * strength). if self.xfrac > 0: t = torch.randn([batch_size, 2], device=device) * self.xfrac_std t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t)) if debug_percentile is not None: t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std) G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height) # ---------------------------------- # Execute geometric transformations. # ---------------------------------- # Execute if the transform is not identity. if G_inv is not I_3: # Calculate padding. 
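# Added note: the padding is derived by pushing the four image corners
# through G_inv and taking the worst-case margin they require, clamped so
# that the reflection padding below stays within the image bounds.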
cx = (width - 1) / 2 cy = (height - 1) / 2 cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz] cp = G_inv @ cp.t() # [batch, xyz, idx] Hz_pad = self.Hz_geom.shape[0] // 4 margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx] margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1] margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device) margin = margin.max(misc.constant([0, 0] * 2, device=device)) margin = margin.min(misc.constant([width-1, height-1] * 2, device=device)) mx0, my0, mx1, my1 = margin.ceil().to(torch.int32) # Pad image and adjust origin. images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect') G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv # Upsample. images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2) G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device) G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device) # Execute transformation. shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2] G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device) grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False) images = grid_sample_gradfix.grid_sample(images, grid) # Downsample and crop. images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True) # -------------------------------------------- # Select parameters for color transformations. # -------------------------------------------- # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out I_4 = torch.eye(4, device=device) C = I_4 # Apply brightness with probability (brightness * strength). if self.brightness > 0: b = torch.randn([batch_size], device=device) * self.brightness_std b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b)) if debug_percentile is not None: b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std) C = translate3d(b, b, b) @ C # Apply contrast with probability (contrast * strength). if self.contrast > 0: c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std) c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c)) if debug_percentile is not None: c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std)) C = scale3d(c, c, c) @ C # Apply luma flip with probability (lumaflip * strength). v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis. if self.lumaflip > 0: i = torch.floor(torch.rand([batch_size, 1, 1], device=device) * 2) i = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.lumaflip * self.p, i, torch.zeros_like(i)) if debug_percentile is not None: i = torch.full_like(i, torch.floor(debug_percentile * 2)) C = (I_4 - 2 * v.ger(v) * i) @ C # Householder reflection. # Apply hue rotation with probability (hue * strength). 
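# Added note: hue rotation is a 3D rotation of the RGB cube around the luma
# axis v = (1, 1, 1) / sqrt(3) defined above, so luma is preserved while the
# chroma components rotate.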
if self.hue > 0 and num_channels > 1: theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta)) if debug_percentile is not None: theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max) C = rotate3d(v, theta) @ C # Rotate around v. # Apply saturation with probability (saturation * strength). if self.saturation > 0 and num_channels > 1: s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std) s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s)) if debug_percentile is not None: s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std)) C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C # ------------------------------ # Execute color transformations. # ------------------------------ # Execute if the transform is not identity. if C is not I_4: images = images.reshape([batch_size, num_channels, height * width]) if num_channels == 3: images = C[:, :3, :3] @ images + C[:, :3, 3:] elif num_channels == 1: C = C[:, :3, :].mean(dim=1, keepdims=True) images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:] else: raise ValueError('Image must be RGB (3 channels) or L (1 channel)') images = images.reshape([batch_size, num_channels, height, width]) # ---------------------- # Image-space filtering. # ---------------------- if self.imgfilter > 0: num_bands = self.Hz_fbank.shape[0] assert len(self.imgfilter_bands) == num_bands expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f). # Apply amplification for each band with probability (imgfilter * strength * band_strength). g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity). for i, band_strength in enumerate(self.imgfilter_bands): t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std) t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i)) if debug_percentile is not None: t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i) t = torch.ones([batch_size, num_bands], device=device) # Temporary gain vector. t[:, i] = t_i # Replace i'th element. t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power. g = g * t # Accumulate into global gain. # Construct combined amplification filter. Hz_prime = g @ self.Hz_fbank # [batch, tap] Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap] Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap] # Apply filter. p = self.Hz_fbank.shape[1] // 2 images = images.reshape([1, batch_size * num_channels, height, width]) images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect') images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels) images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels) images = images.reshape([batch_size, num_channels, height, width]) # ------------------------ # Image-space corruptions. # ------------------------ # Apply additive RGB noise with probability (noise * strength). 
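# Added note: the noise magnitude itself is random, drawn from a half-normal
# distribution via randn().abs() * noise_std, so each image receives a
# different noise strength rather than one fixed level.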
if self.noise > 0: sigma = torch.randn([batch_size, 1, 1, 1], device=device).abs() * self.noise_std sigma = torch.where(torch.rand([batch_size, 1, 1, 1], device=device) < self.noise * self.p, sigma, torch.zeros_like(sigma)) if debug_percentile is not None: sigma = torch.full_like(sigma, torch.erfinv(debug_percentile) * self.noise_std) images = images + torch.randn([batch_size, num_channels, height, width], device=device) * sigma # Apply cutout with probability (cutout * strength). if self.cutout > 0: size = torch.full([batch_size, 2, 1, 1, 1], self.cutout_size, device=device) size = torch.where(torch.rand([batch_size, 1, 1, 1, 1], device=device) < self.cutout * self.p, size, torch.zeros_like(size)) center = torch.rand([batch_size, 2, 1, 1, 1], device=device) if debug_percentile is not None: size = torch.full_like(size, self.cutout_size) center = torch.full_like(center, debug_percentile) coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1]) coord_y = torch.arange(height, device=device).reshape([1, 1, -1, 1]) mask_x = (((coord_x + 0.5) / width - center[:, 0]).abs() >= size[:, 0] / 2) mask_y = (((coord_y + 0.5) / height - center[:, 1]).abs() >= size[:, 1] / 2) mask = torch.logical_or(mask_x, mask_y).to(torch.float32) images = images * mask return images #---------------------------------------------------------------------------- ================================================ FILE: stylegan_human/training/dataset.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Streaming images and labels from datasets created with dataset_tool.py.""" import os import numpy as np import zipfile import PIL.Image import json import torch import dnnlib try: import pyspng except ImportError: pyspng = None #---------------------------------------------------------------------------- class Dataset(torch.utils.data.Dataset): def __init__(self, name, # Name of the dataset. raw_shape, # Shape of the raw image data (NCHW). max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. use_labels = False, # Enable conditioning labels? False = label dimension is zero. xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size. random_seed = 0, # Random seed to use when applying max_size. ): self._name = name self._raw_shape = list(raw_shape) self._use_labels = use_labels self._raw_labels = None self._label_shape = None # Apply max_size. self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64) if (max_size is not None) and (self._raw_idx.size > max_size): np.random.RandomState(random_seed).shuffle(self._raw_idx) self._raw_idx = np.sort(self._raw_idx[:max_size]) # Apply xflip. 
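# Added note: when xflip is enabled the index array is tiled, so item i and
# item i + len(raw) refer to the same raw image; the second copy carries
# _xflip = 1 and is mirrored on the fly in __getitem__.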
self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) if xflip: self._raw_idx = np.tile(self._raw_idx, 2) self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) def _get_raw_labels(self): if self._raw_labels is None: self._raw_labels = self._load_raw_labels() if self._use_labels else None if self._raw_labels is None: self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) assert isinstance(self._raw_labels, np.ndarray) assert self._raw_labels.shape[0] == self._raw_shape[0] assert self._raw_labels.dtype in [np.float32, np.int64] if self._raw_labels.dtype == np.int64: assert self._raw_labels.ndim == 1 assert np.all(self._raw_labels >= 0) return self._raw_labels def close(self): # to be overridden by subclass pass def _load_raw_image(self, raw_idx): # to be overridden by subclass raise NotImplementedError def _load_raw_labels(self): # to be overridden by subclass raise NotImplementedError def __getstate__(self): return dict(self.__dict__, _raw_labels=None) def __del__(self): try: self.close() except: pass def __len__(self): return self._raw_idx.size def __getitem__(self, idx): image = self._load_raw_image(self._raw_idx[idx]) assert isinstance(image, np.ndarray) assert list(image.shape) == self.image_shape assert image.dtype == np.uint8 if self._xflip[idx]: assert image.ndim == 3 # CHW image = image[:, :, ::-1] return image.copy(), self.get_label(idx) def get_label(self, idx): label = self._get_raw_labels()[self._raw_idx[idx]] if label.dtype == np.int64: onehot = np.zeros(self.label_shape, dtype=np.float32) onehot[label] = 1 label = onehot return label.copy() def get_details(self, idx): d = dnnlib.EasyDict() d.raw_idx = int(self._raw_idx[idx]) d.xflip = (int(self._xflip[idx]) != 0) d.raw_label = self._get_raw_labels()[d.raw_idx].copy() return d @property def name(self): return self._name @property def image_shape(self): return list(self._raw_shape[1:]) @property def num_channels(self): assert len(self.image_shape) == 3 # CHW return self.image_shape[0] @property def resolution(self): assert len(self.image_shape) == 3 # CHW assert self.image_shape[1] == self.image_shape[2] return self.image_shape[1] @property def label_shape(self): if self._label_shape is None: raw_labels = self._get_raw_labels() if raw_labels.dtype == np.int64: self._label_shape = [int(np.max(raw_labels)) + 1] else: self._label_shape = raw_labels.shape[1:] return list(self._label_shape) @property def label_dim(self): assert len(self.label_shape) == 1 return self.label_shape[0] @property def has_labels(self): return any(x != 0 for x in self.label_shape) @property def has_onehot_labels(self): return self._get_raw_labels().dtype == np.int64 #---------------------------------------------------------------------------- class ImageFolderDataset(Dataset): def __init__(self, path, # Path to directory or zip. resolution = None, # Ensure specific resolution, None = highest available. **super_kwargs, # Additional arguments for the Dataset base class. 
): self._path = path self._zipfile = None if os.path.isdir(self._path): self._type = 'dir' self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files} elif self._file_ext(self._path) == '.zip': self._type = 'zip' self._all_fnames = set(self._get_zipfile().namelist()) else: raise IOError('Path must point to a directory or zip') PIL.Image.init() self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION) if len(self._image_fnames) == 0: raise IOError('No image files found in the specified path') name = os.path.splitext(os.path.basename(self._path))[0] raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape) if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution): raise IOError('Image files do not match the specified resolution') super().__init__(name=name, raw_shape=raw_shape, **super_kwargs) @staticmethod def _file_ext(fname): return os.path.splitext(fname)[1].lower() def _get_zipfile(self): assert self._type == 'zip' if self._zipfile is None: self._zipfile = zipfile.ZipFile(self._path) return self._zipfile def _open_file(self, fname): if self._type == 'dir': return open(os.path.join(self._path, fname), 'rb') if self._type == 'zip': return self._get_zipfile().open(fname, 'r') return None def close(self): try: if self._zipfile is not None: self._zipfile.close() finally: self._zipfile = None def __getstate__(self): return dict(super().__getstate__(), _zipfile=None) def _load_raw_image(self, raw_idx): fname = self._image_fnames[raw_idx] with self._open_file(fname) as f: if pyspng is not None and self._file_ext(fname) == '.png': image = pyspng.load(f.read()) else: image = np.array(PIL.Image.open(f)) if image.ndim == 2: image = image[:, :, np.newaxis] # HW => HWC image = image.transpose(2, 0, 1) # HWC => CHW return image def _load_raw_labels(self): fname = 'dataset.json' if fname not in self._all_fnames: return None with self._open_file(fname) as f: labels = json.load(f)['labels'] if labels is None: return None labels = dict(labels) labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames] labels = np.array(labels) labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) return labels #---------------------------------------------------------------------------- ================================================ FILE: stylegan_human/training/loss.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
"""Loss functions.""" import numpy as np import torch from torch_utils import training_stats from torch_utils.ops import conv2d_gradfix from torch_utils.ops import upfirdn2d #---------------------------------------------------------------------------- class Loss: def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg): # to be overridden by subclass raise NotImplementedError() #---------------------------------------------------------------------------- class StyleGAN2Loss(Loss): def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10, style_mixing_prob=0, pl_weight=0, pl_batch_shrink=2, pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0, blur_fade_kimg=0): super().__init__() self.device = device self.G = G self.D = D self.augment_pipe = augment_pipe self.r1_gamma = r1_gamma self.style_mixing_prob = style_mixing_prob self.pl_weight = pl_weight self.pl_batch_shrink = pl_batch_shrink self.pl_decay = pl_decay self.pl_no_weight_grad = pl_no_weight_grad self.pl_mean = torch.zeros([], device=device) self.blur_init_sigma = blur_init_sigma self.blur_fade_kimg = blur_fade_kimg def run_G(self, z, c, update_emas=False): ws = self.G.mapping(z, c, update_emas=update_emas) if self.style_mixing_prob > 0: with torch.autograd.profiler.record_function('style_mixing'): cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1]) cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1])) ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:] img = self.G.synthesis(ws, update_emas=update_emas) return img, ws def run_D(self, img, c, blur_sigma=0, update_emas=False): blur_size = np.floor(blur_sigma * 3) if blur_size > 0: with torch.autograd.profiler.record_function('blur'): f = torch.arange(-blur_size, blur_size + 1, device=img.device).div(blur_sigma).square().neg().exp2() img = upfirdn2d.filter2d(img, f / f.sum()) if self.augment_pipe is not None: img = self.augment_pipe(img) logits = self.D(img, c, update_emas=update_emas) return logits def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg): assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth'] if self.pl_weight == 0: phase = {'Greg': 'none', 'Gboth': 'Gmain'}.get(phase, phase) if self.r1_gamma == 0: phase = {'Dreg': 'none', 'Dboth': 'Dmain'}.get(phase, phase) blur_sigma = max(1 - cur_nimg / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma if self.blur_fade_kimg > 0 else 0 # Gmain: Maximize logits for generated images. if phase in ['Gmain', 'Gboth']: with torch.autograd.profiler.record_function('Gmain_forward'): gen_img, _gen_ws = self.run_G(gen_z, gen_c) gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma) training_stats.report('Loss/scores/fake', gen_logits) training_stats.report('Loss/signs/fake', gen_logits.sign()) loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits)) training_stats.report('Loss/G/loss', loss_Gmain) with torch.autograd.profiler.record_function('Gmain_backward'): loss_Gmain.mean().mul(gain).backward() # Gpl: Apply path length regularization. 
if phase in ['Greg', 'Gboth']: with torch.autograd.profiler.record_function('Gpl_forward'): batch_size = gen_z.shape[0] // self.pl_batch_shrink gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size]) pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3]) with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients(self.pl_no_weight_grad): pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0] pl_lengths = pl_grads.square().sum(2).mean(1).sqrt() pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay) self.pl_mean.copy_(pl_mean.detach()) pl_penalty = (pl_lengths - pl_mean).square() training_stats.report('Loss/pl_penalty', pl_penalty) loss_Gpl = pl_penalty * self.pl_weight training_stats.report('Loss/G/reg', loss_Gpl) with torch.autograd.profiler.record_function('Gpl_backward'): loss_Gpl.mean().mul(gain).backward() # Dmain: Minimize logits for generated images. loss_Dgen = 0 if phase in ['Dmain', 'Dboth']: with torch.autograd.profiler.record_function('Dgen_forward'): gen_img, _gen_ws = self.run_G(gen_z, gen_c, update_emas=True) gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma, update_emas=True) training_stats.report('Loss/scores/fake', gen_logits) training_stats.report('Loss/signs/fake', gen_logits.sign()) loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits)) with torch.autograd.profiler.record_function('Dgen_backward'): loss_Dgen.mean().mul(gain).backward() # Dmain: Maximize logits for real images. # Dr1: Apply R1 regularization. if phase in ['Dmain', 'Dreg', 'Dboth']: name = 'Dreal' if phase == 'Dmain' else 'Dr1' if phase == 'Dreg' else 'Dreal_Dr1' with torch.autograd.profiler.record_function(name + '_forward'): real_img_tmp = real_img.detach().requires_grad_(phase in ['Dreg', 'Dboth']) real_logits = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma) training_stats.report('Loss/scores/real', real_logits) training_stats.report('Loss/signs/real', real_logits.sign()) loss_Dreal = 0 if phase in ['Dmain', 'Dboth']: loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits)) training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal) loss_Dr1 = 0 if phase in ['Dreg', 'Dboth']: with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients(): r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0] r1_penalty = r1_grads.square().sum([1,2,3]) loss_Dr1 = r1_penalty * (self.r1_gamma / 2) training_stats.report('Loss/r1_penalty', r1_penalty) training_stats.report('Loss/D/reg', loss_Dr1) with torch.autograd.profiler.record_function(name + '_backward'): (loss_Dreal + loss_Dr1).mean().mul(gain).backward() #---------------------------------------------------------------------------- ================================================ FILE: stylegan_human/training/networks_stylegan2.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
"""Network architectures from the paper "Analyzing and Improving the Image Quality of StyleGAN". Matches the original implementation of configs E-F by Karras et al. at https://github.com/NVlabs/stylegan2/blob/master/training/networks_stylegan2.py""" import numpy as np import torch import torch.nn.functional as F from torch_utils import misc from torch_utils import persistence from torch_utils.ops import conv2d_resample from torch_utils.ops import upfirdn2d from torch_utils.ops import bias_act from torch_utils.ops import fma #---------------------------------------------------------------------------- @misc.profiled_function def normalize_2nd_moment(x, dim=1, eps=1e-8): return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() #---------------------------------------------------------------------------- @misc.profiled_function def modulated_conv2d( x, # Input tensor of shape [batch_size, in_channels, in_height, in_width]. weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width]. styles, # Modulation coefficients of shape [batch_size, in_channels]. noise = None, # Optional noise tensor to add to the output activations. up = 1, # Integer upsampling factor. down = 1, # Integer downsampling factor. padding = 0, # Padding with respect to the upsampled image. resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter(). demodulate = True, # Apply weight demodulation? flip_weight = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d). fused_modconv = True, # Perform modulation, convolution, and demodulation as a single fused operation? ): batch_size = x.shape[0] out_channels, in_channels, kh, kw = weight.shape misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk] misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW] misc.assert_shape(styles, [batch_size, in_channels]) # [NI] # Pre-normalize inputs to avoid FP16 overflow. if x.dtype == torch.float16 and demodulate: weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I # Calculate per-sample weights and demodulation coefficients. w = None dcoefs = None if demodulate or fused_modconv: w = weight.unsqueeze(0) # [NOIkk] w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk] if demodulate: dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO] if demodulate and fused_modconv: w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk] # Execute by scaling the activations before and after the convolution. if not fused_modconv: x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1) x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight) if demodulate and noise is not None: x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype)) elif demodulate: x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1) elif noise is not None: x = x.add_(noise.to(x.dtype)) return x # Execute as one fused op using grouped convolution. 
with misc.suppress_tracer_warnings(): # this value will be treated as a constant batch_size = int(batch_size) misc.assert_shape(x, [batch_size, in_channels, None, None]) x = x.reshape(1, -1, *x.shape[2:]) w = w.reshape(-1, in_channels, kh, kw) x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight) x = x.reshape(batch_size, -1, *x.shape[2:]) if noise is not None: x = x.add_(noise) return x #---------------------------------------------------------------------------- @persistence.persistent_class class FullyConnectedLayer(torch.nn.Module): def __init__(self, in_features, # Number of input features. out_features, # Number of output features. bias = True, # Apply additive bias before the activation function? activation = 'linear', # Activation function: 'relu', 'lrelu', etc. lr_multiplier = 1, # Learning rate multiplier. bias_init = 0, # Initial value for the additive bias. ): super().__init__() self.in_features = in_features self.out_features = out_features self.activation = activation self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None self.weight_gain = lr_multiplier / np.sqrt(in_features) self.bias_gain = lr_multiplier def forward(self, x): w = self.weight.to(x.dtype) * self.weight_gain b = self.bias if b is not None: b = b.to(x.dtype) if self.bias_gain != 1: b = b * self.bias_gain if self.activation == 'linear' and b is not None: x = torch.addmm(b.unsqueeze(0), x, w.t()) else: x = x.matmul(w.t()) x = bias_act.bias_act(x, b, act=self.activation) return x def extra_repr(self): return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}' #---------------------------------------------------------------------------- @persistence.persistent_class class Conv2dLayer(torch.nn.Module): def __init__(self, in_channels, # Number of input channels. out_channels, # Number of output channels. kernel_size, # Width and height of the convolution kernel. bias = True, # Apply additive bias before the activation function? activation = 'linear', # Activation function: 'relu', 'lrelu', etc. up = 1, # Integer upsampling factor. down = 1, # Integer downsampling factor. resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. conv_clamp = None, # Clamp the output to +-X, None = disable clamping. channels_last = False, # Expect the input to have memory_format=channels_last? trainable = True, # Update the weights of this layer during training? 
): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.activation = activation self.up = up self.down = down self.conv_clamp = conv_clamp self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) self.padding = kernel_size // 2 self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) self.act_gain = bias_act.activation_funcs[activation].def_gain memory_format = torch.channels_last if channels_last else torch.contiguous_format weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format) bias = torch.zeros([out_channels]) if bias else None if trainable: self.weight = torch.nn.Parameter(weight) self.bias = torch.nn.Parameter(bias) if bias is not None else None else: self.register_buffer('weight', weight) if bias is not None: self.register_buffer('bias', bias) else: self.bias = None def forward(self, x, gain=1): w = self.weight * self.weight_gain b = self.bias.to(x.dtype) if self.bias is not None else None flip_weight = (self.up == 1) # slightly faster x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight) act_gain = self.act_gain * gain act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp) return x def extra_repr(self): return ' '.join([ f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, activation={self.activation:s},', f'up={self.up}, down={self.down}']) #---------------------------------------------------------------------------- @persistence.persistent_class class MappingNetwork(torch.nn.Module): def __init__(self, z_dim, # Input latent (Z) dimensionality, 0 = no latent. c_dim, # Conditioning label (C) dimensionality, 0 = no label. w_dim, # Intermediate latent (W) dimensionality. num_ws, # Number of intermediate latents to output, None = do not broadcast. num_layers = 8, # Number of mapping layers. embed_features = None, # Label embedding dimensionality, None = same as w_dim. layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim. activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers. w_avg_beta = 0.998, # Decay for tracking the moving average of W during training, None = do not track. ): super().__init__() self.z_dim = z_dim self.c_dim = c_dim self.w_dim = w_dim self.num_ws = num_ws self.num_layers = num_layers self.w_avg_beta = w_avg_beta if embed_features is None: embed_features = w_dim if c_dim == 0: embed_features = 0 if layer_features is None: layer_features = w_dim features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim] if c_dim > 0: self.embed = FullyConnectedLayer(c_dim, embed_features) for idx in range(num_layers): in_features = features_list[idx] out_features = features_list[idx + 1] layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier) setattr(self, f'fc{idx}', layer) if num_ws is not None and w_avg_beta is not None: self.register_buffer('w_avg', torch.zeros([w_dim])) def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False): # Embed, normalize, and concat inputs. 
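# Added note: z and the embedded label are each normalized to unit second
# moment (normalize_2nd_moment) before concatenation, keeping the mapping
# network input at a consistent scale regardless of conditioning.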
x = None with torch.autograd.profiler.record_function('input'): if self.z_dim > 0: misc.assert_shape(z, [None, self.z_dim]) x = normalize_2nd_moment(z.to(torch.float32)) if self.c_dim > 0: misc.assert_shape(c, [None, self.c_dim]) y = normalize_2nd_moment(self.embed(c.to(torch.float32))) x = torch.cat([x, y], dim=1) if x is not None else y # Main layers. for idx in range(self.num_layers): layer = getattr(self, f'fc{idx}') x = layer(x) # Update moving average of W. if update_emas and self.w_avg_beta is not None: with torch.autograd.profiler.record_function('update_w_avg'): self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) # Broadcast. if self.num_ws is not None: with torch.autograd.profiler.record_function('broadcast'): x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) # Apply truncation. if truncation_psi != 1: with torch.autograd.profiler.record_function('truncate'): assert self.w_avg_beta is not None if self.num_ws is None or truncation_cutoff is None: x = self.w_avg.lerp(x, truncation_psi) else: x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) return x def extra_repr(self): return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}' #---------------------------------------------------------------------------- @persistence.persistent_class class SynthesisLayer(torch.nn.Module): def __init__(self, in_channels, # Number of input channels. out_channels, # Number of output channels. w_dim, # Intermediate latent (W) dimensionality. resolution, # Resolution of this layer. kernel_size = 3, # Convolution kernel size. up = 1, # Integer upsampling factor. use_noise = True, # Enable noise input? activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. channels_last = False, # Use channels_last format for the weights? 
): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.w_dim = w_dim self.resolution = resolution self.up = up self.use_noise = use_noise self.activation = activation self.conv_clamp = conv_clamp self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) self.padding = kernel_size // 2 self.act_gain = bias_act.activation_funcs[activation].def_gain self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) memory_format = torch.channels_last if channels_last else torch.contiguous_format self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)) if use_noise: self.register_buffer('noise_const', torch.randn([resolution, resolution])) self.noise_strength = torch.nn.Parameter(torch.zeros([])) self.bias = torch.nn.Parameter(torch.zeros([out_channels])) def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1): assert noise_mode in ['random', 'const', 'none'] in_resolution = self.resolution // self.up misc.assert_shape(x, [None, self.in_channels, in_resolution, in_resolution]) styles = self.affine(w) noise = None if self.use_noise and noise_mode == 'random': noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength if self.use_noise and noise_mode == 'const': noise = self.noise_const * self.noise_strength flip_weight = (self.up == 1) # slightly faster x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up, padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv) act_gain = self.act_gain * gain act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp) return x def extra_repr(self): return ' '.join([ f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d},', f'resolution={self.resolution:d}, up={self.up}, activation={self.activation:s}']) #---------------------------------------------------------------------------- @persistence.persistent_class class ToRGBLayer(torch.nn.Module): def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.w_dim = w_dim self.conv_clamp = conv_clamp self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) memory_format = torch.channels_last if channels_last else torch.contiguous_format self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)) self.bias = torch.nn.Parameter(torch.zeros([out_channels])) self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) def forward(self, x, w, fused_modconv=True): styles = self.affine(w) * self.weight_gain x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv) x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp) return x def extra_repr(self): return f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d}' #---------------------------------------------------------------------------- @persistence.persistent_class class SynthesisBlock(torch.nn.Module): def __init__(self, in_channels, # Number of input channels, 0 = first block. 
out_channels, # Number of output channels. w_dim, # Intermediate latent (W) dimensionality. resolution, # Resolution of this block. img_channels, # Number of output color channels. is_last, # Is this the last block? architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'. resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping. use_fp16 = False, # Use FP16 for this block? fp16_channels_last = False, # Use channels-last memory format with FP16? fused_modconv_default = True, # Default value of fused_modconv. 'inference_only' = True for inference, False for training. **layer_kwargs, # Arguments for SynthesisLayer. ): assert architecture in ['orig', 'skip', 'resnet'] super().__init__() self.in_channels = in_channels self.w_dim = w_dim self.resolution = resolution self.img_channels = img_channels self.is_last = is_last self.architecture = architecture self.use_fp16 = use_fp16 self.channels_last = (use_fp16 and fp16_channels_last) self.fused_modconv_default = fused_modconv_default self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) self.num_conv = 0 self.num_torgb = 0 if in_channels == 0: self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution])) if in_channels != 0: self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2, resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs) self.num_conv += 1 self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution, conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs) self.num_conv += 1 if is_last or architecture == 'skip': self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim, conv_clamp=conv_clamp, channels_last=self.channels_last) self.num_torgb += 1 if in_channels != 0 and architecture == 'resnet': self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2, resample_filter=resample_filter, channels_last=self.channels_last) def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs): _ = update_emas # unused misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim]) w_iter = iter(ws.unbind(dim=1)) if ws.device.type != 'cuda': force_fp32 = True dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format if fused_modconv is None: fused_modconv = self.fused_modconv_default if fused_modconv == 'inference_only': fused_modconv = (not self.training) # Input. if self.in_channels == 0: x = self.const.to(dtype=dtype, memory_format=memory_format) x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1]) else: misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2]) x = x.to(dtype=dtype, memory_format=memory_format) # Main layers. 
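# Added note: the first block (in_channels == 0) starts from the learned
# constant and applies only conv1; resnet blocks scale both the skip branch
# and the conv output by sqrt(0.5) so their sum keeps roughly unit variance.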
if self.in_channels == 0: x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) elif self.architecture == 'resnet': y = self.skip(x, gain=np.sqrt(0.5)) x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs) x = y.add_(x) else: x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) # ToRGB. if img is not None: misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2]) img = upfirdn2d.upsample2d(img, self.resample_filter) if self.is_last or self.architecture == 'skip': y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv) y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format) img = img.add_(y) if img is not None else y assert x.dtype == dtype assert img is None or img.dtype == torch.float32 return x, img def extra_repr(self): return f'resolution={self.resolution:d}, architecture={self.architecture:s}' #---------------------------------------------------------------------------- @persistence.persistent_class class SynthesisNetwork(torch.nn.Module): def __init__(self, w_dim, # Intermediate latent (W) dimensionality. img_resolution, # Output image resolution. img_channels, # Number of color channels. channel_base = 32768, # Overall multiplier for the number of channels. channel_max = 512, # Maximum number of channels in any layer. num_fp16_res = 4, # Use FP16 for the N highest resolutions. **block_kwargs, # Arguments for SynthesisBlock. ): assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0 super().__init__() self.w_dim = w_dim self.img_resolution = img_resolution self.img_resolution_log2 = int(np.log2(img_resolution)) self.img_channels = img_channels self.num_fp16_res = num_fp16_res self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)] channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions} fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) self.num_ws = 0 for res in self.block_resolutions: in_channels = channels_dict[res // 2] if res > 4 else 0 out_channels = channels_dict[res] use_fp16 = (res >= fp16_resolution) is_last = (res == self.img_resolution) block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res, img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs) self.num_ws += block.num_conv if is_last: self.num_ws += block.num_torgb setattr(self, f'b{res}', block) def forward(self, ws, return_feature=False, **block_kwargs): block_ws = [] features = [] with torch.autograd.profiler.record_function('split_ws'): misc.assert_shape(ws, [None, self.num_ws, self.w_dim]) ws = ws.to(torch.float32) w_idx = 0 for res in self.block_resolutions: block = getattr(self, f'b{res}') block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb)) w_idx += block.num_conv x = img = None for res, cur_ws in zip(self.block_resolutions, block_ws): block = getattr(self, f'b{res}') x, img = block(x, img, cur_ws, **block_kwargs) features.append(x) if return_feature: return img, features else: return img def extra_repr(self): return ' '.join([ f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},', f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},', f'num_fp16_res={self.num_fp16_res:d}']) 
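# Added usage sketch (hypothetical, not part of the original file):
#
#   net = SynthesisNetwork(w_dim=512, img_resolution=256, img_channels=3)
#   ws = torch.randn([1, net.num_ws, 512])
#   img, feats = net(ws, return_feature=True)   # feats holds one tensor per block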
#---------------------------------------------------------------------------- @persistence.persistent_class class Generator(torch.nn.Module): def __init__(self, z_dim, # Input latent (Z) dimensionality. c_dim, # Conditioning label (C) dimensionality. w_dim, # Intermediate latent (W) dimensionality. img_resolution, # Output resolution. img_channels, # Number of output color channels. mapping_kwargs = {}, # Arguments for MappingNetwork. synthesis_kwargs = {}, # Arguments for SynthesisNetwork. resize=None, # Output resolution to resize the generated image to, None = disable resizing. # **synthesis_kwargs, # Arguments for SynthesisNetwork. ): super().__init__() self.z_dim = z_dim self.c_dim = c_dim self.w_dim = w_dim self.img_resolution = img_resolution self.img_channels = img_channels self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs) self.num_ws = self.synthesis.num_ws self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs) self.resize = resize def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, input_is_w=False, return_feature=False, **synthesis_kwargs): if input_is_w: ws = z if ws.dim() == 2: ws = ws.unsqueeze(1).repeat([1, self.mapping.num_ws, 1]) else: ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas) img = self.synthesis(ws, update_emas=update_emas, return_feature=return_feature, **synthesis_kwargs) if self.resize is not None: img = imresize(img, [self.resize, self.resize]) return img def imresize(image, size): dim = image.dim() if dim == 3: image = image.unsqueeze(1) b, _, h, w = image.shape if size[0] > h: image = F.interpolate(image, size, mode='bilinear') elif size[0] < h: image = F.interpolate(image, size, mode='area') if dim == 3: image = image.squeeze(1) return image #---------------------------------------------------------------------------- @persistence.persistent_class class DiscriminatorBlock(torch.nn.Module): def __init__(self, in_channels, # Number of input channels, 0 = first block. tmp_channels, # Number of intermediate channels. out_channels, # Number of output channels. resolution, # Resolution of this block. img_channels, # Number of input color channels. first_layer_idx, # Index of the first layer. architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. use_fp16 = False, # Use FP16 for this block? fp16_channels_last = False, # Use channels-last memory format with FP16? freeze_layers = 0, # Freeze-D: Number of layers to freeze.
): assert in_channels in [0, tmp_channels] assert architecture in ['orig', 'skip', 'resnet'] super().__init__() self.in_channels = in_channels self.resolution = resolution self.img_channels = img_channels self.first_layer_idx = first_layer_idx self.architecture = architecture self.use_fp16 = use_fp16 self.channels_last = (use_fp16 and fp16_channels_last) self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) self.num_layers = 0 def trainable_gen(): while True: layer_idx = self.first_layer_idx + self.num_layers trainable = (layer_idx >= freeze_layers) self.num_layers += 1 yield trainable trainable_iter = trainable_gen() if in_channels == 0 or architecture == 'skip': self.fromrgb = Conv2dLayer(img_channels, tmp_channels, kernel_size=1, activation=activation, trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last) self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation, trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last) self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2, trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last) if architecture == 'resnet': self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2, trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last) def forward(self, x, img, force_fp32=False): if (x if x is not None else img).device.type != 'cuda': force_fp32 = True dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format # Input. if x is not None: misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) x = x.to(dtype=dtype, memory_format=memory_format) # FromRGB. if self.in_channels == 0 or self.architecture == 'skip': misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution]) img = img.to(dtype=dtype, memory_format=memory_format) y = self.fromrgb(img) x = x + y if x is not None else y img = upfirdn2d.downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None # Main layers. if self.architecture == 'resnet': y = self.skip(x, gain=np.sqrt(0.5)) x = self.conv0(x) x = self.conv1(x, gain=np.sqrt(0.5)) x = y.add_(x) else: x = self.conv0(x) x = self.conv1(x) assert x.dtype == dtype return x, img def extra_repr(self): return f'resolution={self.resolution:d}, architecture={self.architecture:s}' #---------------------------------------------------------------------------- @persistence.persistent_class class MinibatchStdLayer(torch.nn.Module): def __init__(self, group_size, num_channels=1): super().__init__() self.group_size = group_size self.num_channels = num_channels def forward(self, x): N, C, H, W = x.shape with misc.suppress_tracer_warnings(): # as_tensor results are registered as constants G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N F = self.num_channels c = C // F y = x.reshape(G, -1, F, c, H, W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c. y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group. y = y.square().mean(dim=0) # [nFcHW] Calc variance over group. y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group. 
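        # The steps below reduce this stddev map to one scalar per group and
        # feature, then broadcast it back as an extra input channel: a cheap
        # batch-diversity statistic that lets the discriminator detect the
        # reduced variation typical of mode collapse.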
y = y.mean(dim=[2,3,4]) # [nF] Take average over channels and pixels. y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions. y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels. x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels. return x def extra_repr(self): return f'group_size={self.group_size}, num_channels={self.num_channels:d}' #---------------------------------------------------------------------------- @persistence.persistent_class class DiscriminatorEpilogue(torch.nn.Module): def __init__(self, in_channels, # Number of input channels. cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label. resolution, # Resolution of this block. img_channels, # Number of input color channels. architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, None = entire minibatch. mbstd_num_channels = 1, # Number of features for the minibatch standard deviation layer, 0 = disable. activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. ): assert architecture in ['orig', 'skip', 'resnet'] super().__init__() self.in_channels = in_channels self.cmap_dim = cmap_dim self.resolution = resolution self.img_channels = img_channels self.architecture = architecture if architecture == 'skip': self.fromrgb = Conv2dLayer(img_channels, in_channels, kernel_size=1, activation=activation) self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp) self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), in_channels, activation=activation) self.out = FullyConnectedLayer(in_channels, 1 if cmap_dim == 0 else cmap_dim) def forward(self, x, img, cmap, force_fp32=False): misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) # [NCHW] _ = force_fp32 # unused dtype = torch.float32 memory_format = torch.contiguous_format # FromRGB. x = x.to(dtype=dtype, memory_format=memory_format) if self.architecture == 'skip': misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution]) img = img.to(dtype=dtype, memory_format=memory_format) x = x + self.fromrgb(img) # Main layers. if self.mbstd is not None: x = self.mbstd(x) x = self.conv(x) x = self.fc(x.flatten(1)) x = self.out(x) # Conditioning. if self.cmap_dim > 0: misc.assert_shape(cmap, [None, self.cmap_dim]) x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) assert x.dtype == dtype return x def extra_repr(self): return f'resolution={self.resolution:d}, architecture={self.architecture:s}' #---------------------------------------------------------------------------- @persistence.persistent_class class Discriminator(torch.nn.Module): def __init__(self, c_dim, # Conditioning label (C) dimensionality. img_resolution, # Input resolution. img_channels, # Number of input color channels. architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. channel_base = 32768, # Overall multiplier for the number of channels. channel_max = 512, # Maximum number of channels in any layer. num_fp16_res = 4, # Use FP16 for the N highest resolutions. conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping. 
cmap_dim = None, # Dimensionality of mapped conditioning label, None = default. block_kwargs = {}, # Arguments for DiscriminatorBlock. mapping_kwargs = {}, # Arguments for MappingNetwork. epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue. ): super().__init__() self.c_dim = c_dim self.img_resolution = img_resolution self.img_resolution_log2 = int(np.log2(img_resolution)) self.img_channels = img_channels self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) if cmap_dim is None: cmap_dim = channels_dict[4] if c_dim == 0: cmap_dim = 0 common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp) cur_layer_idx = 0 for res in self.block_resolutions: in_channels = channels_dict[res] if res < img_resolution else 0 tmp_channels = channels_dict[res] out_channels = channels_dict[res // 2] use_fp16 = (res >= fp16_resolution) block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res, first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs) setattr(self, f'b{res}', block) cur_layer_idx += block.num_layers if c_dim > 0: self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs) self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs) def forward(self, img, c, update_emas=False, **block_kwargs): _ = update_emas # unused x = None for res in self.block_resolutions: block = getattr(self, f'b{res}') x, img = block(x, img, **block_kwargs) cmap = None if self.c_dim > 0: cmap = self.mapping(None, c) x = self.b4(x, img, cmap) return x def extra_repr(self): return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}' #---------------------------------------------------------------------------- ================================================ FILE: stylegan_human/training/networks_stylegan3.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Generator architecture from the paper "Alias-Free Generative Adversarial Networks".""" import numpy as np import scipy.signal import scipy.optimize import torch import torch.nn.functional as F from torch_utils import misc from torch_utils import persistence from torch_utils.ops import conv2d_gradfix from torch_utils.ops import filtered_lrelu from torch_utils.ops import bias_act #---------------------------------------------------------------------------- @misc.profiled_function def modulated_conv2d( x, # Input tensor: [batch_size, in_channels, in_height, in_width] w, # Weight tensor: [out_channels, in_channels, kernel_height, kernel_width] s, # Style tensor: [batch_size, in_channels] demodulate = True, # Apply weight demodulation? 
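    # (Demodulation rescales each output channel o by
    #  1 / sqrt(sum_{i,k} (s_i * w_{oik})^2 + 1e-8), as in StyleGAN2.)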
padding = 0, # Padding: int or [padH, padW] input_gain = None, # Optional scale factors for the input channels: [], [in_channels], or [batch_size, in_channels] ): with misc.suppress_tracer_warnings(): # this value will be treated as a constant batch_size = int(x.shape[0]) out_channels, in_channels, kh, kw = w.shape misc.assert_shape(w, [out_channels, in_channels, kh, kw]) # [OIkk] misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW] misc.assert_shape(s, [batch_size, in_channels]) # [NI] # Pre-normalize inputs. if demodulate: w = w * w.square().mean([1,2,3], keepdim=True).rsqrt() s = s * s.square().mean().rsqrt() # Modulate weights. w = w.unsqueeze(0) # [NOIkk] w = w * s.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk] # Demodulate weights. if demodulate: dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO] w = w * dcoefs.unsqueeze(2).unsqueeze(3).unsqueeze(4) # [NOIkk] # Apply input scaling. if input_gain is not None: input_gain = input_gain.expand(batch_size, in_channels) # [NI] w = w * input_gain.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk] # Execute as one fused op using grouped convolution. x = x.reshape(1, -1, *x.shape[2:]) w = w.reshape(-1, in_channels, kh, kw) x = conv2d_gradfix.conv2d(input=x, weight=w.to(x.dtype), padding=padding, groups=batch_size) x = x.reshape(batch_size, -1, *x.shape[2:]) return x #---------------------------------------------------------------------------- @persistence.persistent_class class FullyConnectedLayer(torch.nn.Module): def __init__(self, in_features, # Number of input features. out_features, # Number of output features. activation = 'linear', # Activation function: 'relu', 'lrelu', etc. bias = True, # Apply additive bias before the activation function? lr_multiplier = 1, # Learning rate multiplier. weight_init = 1, # Initial standard deviation of the weight tensor. bias_init = 0, # Initial value of the additive bias. ): super().__init__() self.in_features = in_features self.out_features = out_features self.activation = activation self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) * (weight_init / lr_multiplier)) bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features]) self.bias = torch.nn.Parameter(torch.from_numpy(bias_init / lr_multiplier)) if bias else None self.weight_gain = lr_multiplier / np.sqrt(in_features) self.bias_gain = lr_multiplier def forward(self, x): w = self.weight.to(x.dtype) * self.weight_gain b = self.bias if b is not None: b = b.to(x.dtype) if self.bias_gain != 1: b = b * self.bias_gain if self.activation == 'linear' and b is not None: x = torch.addmm(b.unsqueeze(0), x, w.t()) else: x = x.matmul(w.t()) x = bias_act.bias_act(x, b, act=self.activation) return x def extra_repr(self): return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}' #---------------------------------------------------------------------------- @persistence.persistent_class class MappingNetwork(torch.nn.Module): def __init__(self, z_dim, # Input latent (Z) dimensionality. c_dim, # Conditioning label (C) dimensionality, 0 = no labels. w_dim, # Intermediate latent (W) dimensionality. num_ws, # Number of intermediate latents to output. num_layers = 2, # Number of mapping layers. lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers. w_avg_beta = 0.998, # Decay for tracking the moving average of W during training. 
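        # (w_avg drives the truncation trick at inference:
        #  w' = lerp(w_avg, w, truncation_psi), applied in forward() below.)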
): super().__init__() self.z_dim = z_dim self.c_dim = c_dim self.w_dim = w_dim self.num_ws = num_ws self.num_layers = num_layers self.w_avg_beta = w_avg_beta # Construct layers. self.embed = FullyConnectedLayer(self.c_dim, self.w_dim) if self.c_dim > 0 else None features = [self.z_dim + (self.w_dim if self.c_dim > 0 else 0)] + [self.w_dim] * self.num_layers for idx, in_features, out_features in zip(range(num_layers), features[:-1], features[1:]): layer = FullyConnectedLayer(in_features, out_features, activation='lrelu', lr_multiplier=lr_multiplier) setattr(self, f'fc{idx}', layer) self.register_buffer('w_avg', torch.zeros([w_dim])) def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False): misc.assert_shape(z, [None, self.z_dim]) if truncation_cutoff is None: truncation_cutoff = self.num_ws # Embed, normalize, and concatenate inputs. x = z.to(torch.float32) x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt() if self.c_dim > 0: misc.assert_shape(c, [None, self.c_dim]) y = self.embed(c.to(torch.float32)) y = y * (y.square().mean(1, keepdim=True) + 1e-8).rsqrt() x = torch.cat([x, y], dim=1) if x is not None else y # Execute layers. for idx in range(self.num_layers): x = getattr(self, f'fc{idx}')(x) # Update moving average of W. if update_emas: self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) # Broadcast and apply truncation. x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) if truncation_psi != 1: x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) return x def extra_repr(self): return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}' #---------------------------------------------------------------------------- @persistence.persistent_class class SynthesisInput(torch.nn.Module): def __init__(self, w_dim, # Intermediate latent (W) dimensionality. channels, # Number of output channels. size, # Output spatial size: int or [width, height]. sampling_rate, # Output sampling rate. bandwidth, # Output bandwidth. ): super().__init__() self.w_dim = w_dim self.channels = channels self.size = np.broadcast_to(np.asarray(size), [2]) self.sampling_rate = sampling_rate self.bandwidth = bandwidth # Draw random frequencies from uniform 2D disc. freqs = torch.randn([self.channels, 2]) radii = freqs.square().sum(dim=1, keepdim=True).sqrt() freqs /= radii * radii.square().exp().pow(0.25) freqs *= bandwidth phases = torch.rand([self.channels]) - 0.5 # Setup parameters and buffers. self.weight = torch.nn.Parameter(torch.randn([self.channels, self.channels])) self.affine = FullyConnectedLayer(w_dim, 4, weight_init=0, bias_init=[1,0,0,0]) self.register_buffer('transform', torch.eye(3, 3)) # User-specified inverse transform wrt. resulting image. self.register_buffer('freqs', freqs) self.register_buffer('phases', phases) def forward(self, w): # Introduce batch dimension. transforms = self.transform.unsqueeze(0) # [batch, row, col] freqs = self.freqs.unsqueeze(0) # [batch, channel, xy] phases = self.phases.unsqueeze(0) # [batch, channel] # Apply learned transformation. t = self.affine(w) # t = (r_c, r_s, t_x, t_y) t = t / t[:, :2].norm(dim=1, keepdim=True) # t' = (r'_c, r'_s, t'_x, t'_y) m_r = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse rotation wrt. resulting image. 
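        # m_r in homogeneous coordinates (rows x cols):
        #   [ r'_c  -r'_s  0 ]
        #   [ r'_s   r'_c  0 ]
        #   [  0      0    1 ]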
        m_r[:, 0, 0] = t[:, 0]  # r'_c
        m_r[:, 0, 1] = -t[:, 1] # r'_s
        m_r[:, 1, 0] = t[:, 1]  # r'_s
        m_r[:, 1, 1] = t[:, 0]  # r'_c
        m_t = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse translation wrt. resulting image.
        m_t[:, 0, 2] = -t[:, 2] # t'_x
        m_t[:, 1, 2] = -t[:, 3] # t'_y
        transforms = m_r @ m_t @ transforms # First rotate resulting image, then translate, and finally apply user-specified transform.

        # Transform frequencies.
        phases = phases + (freqs @ transforms[:, :2, 2:]).squeeze(2)
        freqs = freqs @ transforms[:, :2, :2]

        # Dampen out-of-band frequencies that may occur due to the user-specified transform.
        amplitudes = (1 - (freqs.norm(dim=2) - self.bandwidth) / (self.sampling_rate / 2 - self.bandwidth)).clamp(0, 1)

        # Construct sampling grid.
        theta = torch.eye(2, 3, device=w.device)
        theta[0, 0] = 0.5 * self.size[0] / self.sampling_rate
        theta[1, 1] = 0.5 * self.size[1] / self.sampling_rate
        grids = torch.nn.functional.affine_grid(theta.unsqueeze(0), [1, 1, self.size[1], self.size[0]], align_corners=False)

        # Compute Fourier features.
        x = (grids.unsqueeze(3) @ freqs.permute(0, 2, 1).unsqueeze(1).unsqueeze(2)).squeeze(3) # [batch, height, width, channel]
        x = x + phases.unsqueeze(1).unsqueeze(2)
        x = torch.sin(x * (np.pi * 2))
        x = x * amplitudes.unsqueeze(1).unsqueeze(2)

        # Apply trainable mapping.
        weight = self.weight / np.sqrt(self.channels)
        x = x @ weight.t()

        # Ensure correct shape.
        x = x.permute(0, 3, 1, 2) # [batch, channel, height, width]
        misc.assert_shape(x, [w.shape[0], self.channels, int(self.size[1]), int(self.size[0])])
        return x

    def extra_repr(self):
        return '\n'.join([
            f'w_dim={self.w_dim:d}, channels={self.channels:d}, size={list(self.size)},',
            f'sampling_rate={self.sampling_rate:g}, bandwidth={self.bandwidth:g}'])

#----------------------------------------------------------------------------

@persistence.persistent_class
class SynthesisLayer(torch.nn.Module):
    def __init__(self,
        w_dim,                          # Intermediate latent (W) dimensionality.
        is_torgb,                       # Is this the final ToRGB layer?
        is_critically_sampled,          # Does this layer use critical sampling?
        use_fp16,                       # Does this layer use FP16?

        # Input & output specifications.
        in_channels,                    # Number of input channels.
        out_channels,                   # Number of output channels.
        in_size,                        # Input spatial size: int or [width, height].
        out_size,                       # Output spatial size: int or [width, height].
        in_sampling_rate,               # Input sampling rate (s).
        out_sampling_rate,              # Output sampling rate (s).
        in_cutoff,                      # Input cutoff frequency (f_c).
        out_cutoff,                     # Output cutoff frequency (f_c).
        in_half_width,                  # Input transition band half-width (f_h).
        out_half_width,                 # Output transition band half-width (f_h).

        # Hyperparameters.
        conv_kernel         = 3,        # Convolution kernel size. Ignored for the final ToRGB layer.
        filter_size         = 6,        # Low-pass filter size relative to the lower resolution when up/downsampling.
        lrelu_upsampling    = 2,        # Relative sampling rate for leaky ReLU. Ignored for the final ToRGB layer.
        use_radial_filters  = False,    # Use radially symmetric downsampling filter? Ignored for critically sampled layers.
        conv_clamp          = 256,      # Clamp the output to [-X, +X], None = disable clamping.
        magnitude_ema_beta  = 0.999,    # Decay rate for the moving average of input magnitudes.
): super().__init__() self.w_dim = w_dim self.is_torgb = is_torgb self.is_critically_sampled = is_critically_sampled self.use_fp16 = use_fp16 self.in_channels = in_channels self.out_channels = out_channels self.in_size = np.broadcast_to(np.asarray(in_size), [2]) self.out_size = np.broadcast_to(np.asarray(out_size), [2]) self.in_sampling_rate = in_sampling_rate self.out_sampling_rate = out_sampling_rate self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) self.in_cutoff = in_cutoff self.out_cutoff = out_cutoff self.in_half_width = in_half_width self.out_half_width = out_half_width self.conv_kernel = 1 if is_torgb else conv_kernel self.conv_clamp = conv_clamp self.magnitude_ema_beta = magnitude_ema_beta # Setup parameters and buffers. self.affine = FullyConnectedLayer(self.w_dim, self.in_channels, bias_init=1) self.weight = torch.nn.Parameter(torch.randn([self.out_channels, self.in_channels, self.conv_kernel, self.conv_kernel])) self.bias = torch.nn.Parameter(torch.zeros([self.out_channels])) self.register_buffer('magnitude_ema', torch.ones([])) # Design upsampling filter. self.up_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) assert self.in_sampling_rate * self.up_factor == self.tmp_sampling_rate self.up_taps = filter_size * self.up_factor if self.up_factor > 1 and not self.is_torgb else 1 self.register_buffer('up_filter', self.design_lowpass_filter( numtaps=self.up_taps, cutoff=self.in_cutoff, width=self.in_half_width*2, fs=self.tmp_sampling_rate)) # Design downsampling filter. self.down_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) assert self.out_sampling_rate * self.down_factor == self.tmp_sampling_rate self.down_taps = filter_size * self.down_factor if self.down_factor > 1 and not self.is_torgb else 1 self.down_radial = use_radial_filters and not self.is_critically_sampled self.register_buffer('down_filter', self.design_lowpass_filter( numtaps=self.down_taps, cutoff=self.out_cutoff, width=self.out_half_width*2, fs=self.tmp_sampling_rate, radial=self.down_radial)) # Compute padding. pad_total = (self.out_size - 1) * self.down_factor + 1 # Desired output size before downsampling. pad_total -= (self.in_size + self.conv_kernel - 1) * self.up_factor # Input size after upsampling. pad_total += self.up_taps + self.down_taps - 2 # Size reduction caused by the filters. pad_lo = (pad_total + self.up_factor) // 2 # Shift sample locations according to the symmetric interpretation (Appendix C.3). pad_hi = pad_total - pad_lo self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] def forward(self, x, w, noise_mode='random', force_fp32=False, update_emas=False): assert noise_mode in ['random', 'const', 'none'] # unused misc.assert_shape(x, [None, self.in_channels, int(self.in_size[1]), int(self.in_size[0])]) misc.assert_shape(w, [x.shape[0], self.w_dim]) # Track input magnitude. if update_emas: with torch.autograd.profiler.record_function('update_magnitude_ema'): magnitude_cur = x.detach().to(torch.float32).square().mean() self.magnitude_ema.copy_(magnitude_cur.lerp(self.magnitude_ema, self.magnitude_ema_beta)) input_gain = self.magnitude_ema.rsqrt() # Execute affine layer. styles = self.affine(w) if self.is_torgb: weight_gain = 1 / np.sqrt(self.in_channels * (self.conv_kernel ** 2)) styles = styles * weight_gain # Execute modulated conv2d. 
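        # (FP16 is used only if this layer opted in, the caller did not force
        # FP32, and the tensor is on a CUDA device. padding=conv_kernel-1
        # yields a 'full' convolution; the extra border is accounted for by
        # the padding amounts computed in __init__.)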
dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32 x = modulated_conv2d(x=x.to(dtype), w=self.weight, s=styles, padding=self.conv_kernel-1, demodulate=(not self.is_torgb), input_gain=input_gain) # Execute bias, filtered leaky ReLU, and clamping. gain = 1 if self.is_torgb else np.sqrt(2) slope = 1 if self.is_torgb else 0.2 x = filtered_lrelu.filtered_lrelu(x=x, fu=self.up_filter, fd=self.down_filter, b=self.bias.to(x.dtype), up=self.up_factor, down=self.down_factor, padding=self.padding, gain=gain, slope=slope, clamp=self.conv_clamp) # Ensure correct shape and dtype. misc.assert_shape(x, [None, self.out_channels, int(self.out_size[1]), int(self.out_size[0])]) assert x.dtype == dtype return x @staticmethod def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False): assert numtaps >= 1 # Identity filter. if numtaps == 1: return None # Separable Kaiser low-pass filter. if not radial: f = scipy.signal.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) return torch.as_tensor(f, dtype=torch.float32) # Radially symmetric jinc-based filter. x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs r = np.hypot(*np.meshgrid(x, x)) f = scipy.special.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) beta = scipy.signal.kaiser_beta(scipy.signal.kaiser_atten(numtaps, width / (fs / 2))) w = np.kaiser(numtaps, beta) f *= np.outer(w, w) f /= np.sum(f) return torch.as_tensor(f, dtype=torch.float32) def extra_repr(self): return '\n'.join([ f'w_dim={self.w_dim:d}, is_torgb={self.is_torgb},', f'is_critically_sampled={self.is_critically_sampled}, use_fp16={self.use_fp16},', f'in_sampling_rate={self.in_sampling_rate:g}, out_sampling_rate={self.out_sampling_rate:g},', f'in_cutoff={self.in_cutoff:g}, out_cutoff={self.out_cutoff:g},', f'in_half_width={self.in_half_width:g}, out_half_width={self.out_half_width:g},', f'in_size={list(self.in_size)}, out_size={list(self.out_size)},', f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}']) #---------------------------------------------------------------------------- @persistence.persistent_class class SynthesisNetwork(torch.nn.Module): def __init__(self, w_dim, # Intermediate latent (W) dimensionality. img_resolution, # Output image resolution. img_channels, # Number of color channels. channel_base = 32768, # Overall multiplier for the number of channels. channel_max = 512, # Maximum number of channels in any layer. num_layers = 14, # Total number of layers, excluding Fourier features and ToRGB. num_critical = 2, # Number of critically sampled layers at the end. first_cutoff = 2, # Cutoff frequency of the first layer (f_{c,0}). first_stopband = 2**2.1, # Minimum stopband of the first layer (f_{t,0}). last_stopband_rel = 2**0.3, # Minimum stopband of the last layer, expressed relative to the cutoff. margin_size = 10, # Number of additional pixels outside the image. output_scale = 0.25, # Scale factor for the output image. num_fp16_res = 4, # Use FP16 for the N highest resolutions. **layer_kwargs, # Arguments for SynthesisLayer. ): super().__init__() self.w_dim = w_dim self.num_ws = num_layers + 2 self.img_resolution = img_resolution self.img_channels = img_channels self.num_layers = num_layers self.num_critical = num_critical self.margin_size = margin_size self.output_scale = output_scale self.num_fp16_res = num_fp16_res # Geometric progression of layer cutoffs and min. stopbands. 
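        # Worked example with the defaults and img_resolution=1024:
        # last_cutoff = 512, and the exponents ramp linearly from 0 to 1 over
        # the first num_layers - num_critical = 12 layers, so each of those
        # layers scales the cutoff by (512 / 2)**(1/12) = 2**(2/3) ~ 1.59;
        # the remaining (critically sampled and ToRGB) layers stay at f_c = 512.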
last_cutoff = self.img_resolution / 2 # f_{c,N} last_stopband = last_cutoff * last_stopband_rel # f_{t,N} exponents = np.minimum(np.arange(self.num_layers + 1) / (self.num_layers - self.num_critical), 1) cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents # f_c[i] stopbands = first_stopband * (last_stopband / first_stopband) ** exponents # f_t[i] # Compute remaining layer parameters. sampling_rates = np.exp2(np.ceil(np.log2(np.minimum(stopbands * 2, self.img_resolution)))) # s[i] half_widths = np.maximum(stopbands, sampling_rates / 2) - cutoffs # f_h[i] sizes = sampling_rates + self.margin_size * 2 sizes[-2:] = self.img_resolution channels = np.rint(np.minimum((channel_base / 2) / cutoffs, channel_max)) channels[-1] = self.img_channels # Construct layers. self.input = SynthesisInput( w_dim=self.w_dim, channels=int(channels[0]), size=int(sizes[0]), sampling_rate=sampling_rates[0], bandwidth=cutoffs[0]) self.layer_names = [] for idx in range(self.num_layers + 1): prev = max(idx - 1, 0) is_torgb = (idx == self.num_layers) is_critically_sampled = (idx >= self.num_layers - self.num_critical) use_fp16 = (sampling_rates[idx] * (2 ** self.num_fp16_res) > self.img_resolution) layer = SynthesisLayer( w_dim=self.w_dim, is_torgb=is_torgb, is_critically_sampled=is_critically_sampled, use_fp16=use_fp16, in_channels=int(channels[prev]), out_channels= int(channels[idx]), in_size=int(sizes[prev]), out_size=int(sizes[idx]), in_sampling_rate=int(sampling_rates[prev]), out_sampling_rate=int(sampling_rates[idx]), in_cutoff=cutoffs[prev], out_cutoff=cutoffs[idx], in_half_width=half_widths[prev], out_half_width=half_widths[idx], **layer_kwargs) name = f'L{idx}_{layer.out_size[0]}_{layer.out_channels}' setattr(self, name, layer) self.layer_names.append(name) def forward(self, ws, **layer_kwargs): misc.assert_shape(ws, [None, self.num_ws, self.w_dim]) ws = ws.to(torch.float32).unbind(dim=1) # Execute layers. x = self.input(ws[0]) for name, w in zip(self.layer_names, ws[1:]): x = getattr(self, name)(x, w, **layer_kwargs) if self.output_scale != 1: x = x * self.output_scale # Ensure correct shape and dtype. misc.assert_shape(x, [None, self.img_channels, self.img_resolution, self.img_resolution]) x = x.to(torch.float32) return x def extra_repr(self): return '\n'.join([ f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},', f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},', f'num_layers={self.num_layers:d}, num_critical={self.num_critical:d},', f'margin_size={self.margin_size:d}, num_fp16_res={self.num_fp16_res:d}']) #---------------------------------------------------------------------------- @persistence.persistent_class class Generator(torch.nn.Module): def __init__(self, z_dim, # Input latent (Z) dimensionality. c_dim, # Conditioning label (C) dimensionality. w_dim, # Intermediate latent (W) dimensionality. img_resolution, # Output resolution. img_channels, # Number of output color channels. mapping_kwargs = {}, # Arguments for MappingNetwork. resize=None, **synthesis_kwargs, # Arguments for SynthesisNetwork. 
): super().__init__() self.z_dim = z_dim self.c_dim = c_dim self.w_dim = w_dim self.img_resolution = img_resolution self.img_channels = img_channels self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs) self.num_ws = self.synthesis.num_ws self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs) self.resize = resize def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, input_is_w=False, **synthesis_kwargs): if input_is_w: ws = z if ws.dim() == 2: ws = ws.unsqueeze(1).repeat([1, self.mapping.num_ws, 1]) else: ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas) img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs) if self.resize is not None: img = imresize(img, [self.resize, self.resize]) return img #---------------------------------------------------------------------------- def imresize(image, size): dim = image.dim() if dim == 3: image = image.unsqueeze(1) b, _, h, w = image.shape if size[0] > h: image = F.interpolate(image, size, mode='bilinear') elif size[0] < h: image = F.interpolate(image, size, mode='area') if dim == 3: image = image.squeeze(1) return image ================================================ FILE: stylegan_human/training/training_loop.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Main training loop.""" import os import time import copy import json import pickle import psutil import PIL.Image import numpy as np import torch import dnnlib from torch_utils import misc from torch_utils import training_stats from torch_utils.ops import conv2d_gradfix from torch_utils.ops import grid_sample_gradfix import legacy from metrics import metric_main #---------------------------------------------------------------------------- def setup_snapshot_image_grid(training_set, random_seed=0): rnd = np.random.RandomState(random_seed) gw = np.clip(7680 // training_set.image_shape[2], 7, 32) gh = np.clip(4320 // training_set.image_shape[1], 4, 32) # No labels => show random subset of training samples. if not training_set.has_labels: all_indices = list(range(len(training_set))) rnd.shuffle(all_indices) grid_indices = [all_indices[i % len(all_indices)] for i in range(gw * gh)] else: # Group training samples by label. label_groups = dict() # label => [idx, ...] for idx in range(len(training_set)): label = tuple(training_set.get_details(idx).raw_label.flat[::-1]) if label not in label_groups: label_groups[label] = [] label_groups[label].append(idx) # Reorder. label_order = sorted(label_groups.keys()) for label in label_order: rnd.shuffle(label_groups[label]) # Organize into grid. grid_indices = [] for y in range(gh): label = label_order[y % len(label_order)] indices = label_groups[label] grid_indices += [indices[x % len(indices)] for x in range(gw)] label_groups[label] = [indices[(i + gw) % len(indices)] for i in range(len(indices))] # Load data. 
images, labels = zip(*[training_set[i] for i in grid_indices]) return (gw, gh), np.stack(images), np.stack(labels) #---------------------------------------------------------------------------- def save_image_grid(img, fname, drange, grid_size): lo, hi = drange img = np.asarray(img, dtype=np.float32) img = (img - lo) * (255 / (hi - lo)) img = np.rint(img).clip(0, 255).astype(np.uint8) gw, gh = grid_size _N, C, H, W = img.shape img = img.reshape([gh, gw, C, H, W]) img = img.transpose(0, 3, 1, 4, 2) img = img.reshape([gh * H, gw * W, C]) assert C in [1, 3] if C == 1: PIL.Image.fromarray(img[:, :, 0], 'L').save(fname) if C == 3: PIL.Image.fromarray(img, 'RGB').save(fname) #---------------------------------------------------------------------------- def training_loop( run_dir = '.', # Output directory. training_set_kwargs = {}, # Options for training set. data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader. G_kwargs = {}, # Options for generator network. D_kwargs = {}, # Options for discriminator network. G_opt_kwargs = {}, # Options for generator optimizer. D_opt_kwargs = {}, # Options for discriminator optimizer. augment_kwargs = None, # Options for augmentation pipeline. None = disable. loss_kwargs = {}, # Options for loss function. metrics = [], # Metrics to evaluate during training. random_seed = 0, # Global random seed. num_gpus = 1, # Number of GPUs participating in the training. rank = 0, # Rank of the current process in [0, num_gpus[. batch_size = 4, # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus. batch_gpu = 4, # Number of samples processed at a time by one GPU. ema_kimg = 10, # Half-life of the exponential moving average (EMA) of generator weights. ema_rampup = 0.05, # EMA ramp-up coefficient. None = no rampup. G_reg_interval = None, # How often to perform regularization for G? None = disable lazy regularization. D_reg_interval = 16, # How often to perform regularization for D? None = disable lazy regularization. augment_p = 0, # Initial value of augmentation probability. ada_target = None, # ADA target value. None = fixed p. ada_interval = 4, # How often to perform ADA adjustment? ada_kimg = 500, # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit. total_kimg = 25000, # Total length of the training, measured in thousands of real images. kimg_per_tick = 4, # Progress snapshot interval. image_snapshot_ticks = 50, # How often to save image snapshots? None = disable. network_snapshot_ticks = 50, # How often to save network snapshots? None = disable. resume_pkl = None, # Network pickle to resume training from. resume_kimg = 0, # First kimg to report when resuming training. cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark? abort_fn = None, # Callback function for determining whether to abort training. Must return consistent results across ranks. progress_fn = None, # Callback function for updating training progress. Called for all ranks. ): # Initialize. start_time = time.time() device = torch.device('cuda', rank) np.random.seed(random_seed * num_gpus + rank) torch.manual_seed(random_seed * num_gpus + rank) torch.backends.cudnn.benchmark = cudnn_benchmark # Improves training speed. torch.backends.cuda.matmul.allow_tf32 = False # Improves numerical accuracy. torch.backends.cudnn.allow_tf32 = False # Improves numerical accuracy. conv2d_gradfix.enabled = True # Improves training speed. grid_sample_gradfix.enabled = True # Avoids errors with the augmentation pipe. 
# Load training set. if rank == 0: print('Loading training set...') training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs) # subclass of training.dataset.Dataset training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed) training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs)) if rank == 0: print() print('Num images: ', len(training_set)) print('Image shape:', training_set.image_shape) print('Label shape:', training_set.label_shape) print() # Construct networks. if rank == 0: print('Constructing networks...') common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels) G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module G_ema = copy.deepcopy(G).eval() # Resume from existing pickle. if (resume_pkl is not None) and (rank == 0): print(f'Resuming from "{resume_pkl}"') with dnnlib.util.open_url(resume_pkl) as f: resume_data = legacy.load_network_pkl(f) for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]: misc.copy_params_and_buffers(resume_data[name], module, require_all=False) # Print network summary tables. if rank == 0: z = torch.empty([batch_gpu, G.z_dim], device=device) c = torch.empty([batch_gpu, G.c_dim], device=device) img = misc.print_module_summary(G, [z, c]) misc.print_module_summary(D, [img, c]) # Setup augmentation. if rank == 0: print('Setting up augmentation...') augment_pipe = None ada_stats = None if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None): augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module augment_pipe.p.copy_(torch.as_tensor(augment_p)) if ada_target is not None: ada_stats = training_stats.Collector(regex='Loss/signs/real') # Distribute across GPUs. if rank == 0: print(f'Distributing across {num_gpus} GPUs...') for module in [G, D, G_ema, augment_pipe]: if module is not None and num_gpus > 1: for param in misc.params_and_buffers(module): torch.distributed.broadcast(param, src=0) # Setup training phases. if rank == 0: print('Setting up training phases...') loss = dnnlib.util.construct_class_by_name(device=device, G=G, D=D, augment_pipe=augment_pipe, **loss_kwargs) # subclass of training.loss.Loss phases = [] for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval), ('D', D, D_opt_kwargs, D_reg_interval)]: if reg_interval is None: opt = dnnlib.util.construct_class_by_name(params=module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer phases += [dnnlib.EasyDict(name=name+'both', module=module, opt=opt, interval=1)] else: # Lazy regularization. 
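            # When a loss term runs only every reg_interval minibatches, the
            # learning rate and Adam betas are rescaled by
            # mb_ratio = reg_interval / (reg_interval + 1) so the optimizer
            # dynamics match non-lazy training; e.g. D_reg_interval = 16 gives
            # lr * 16/17 and beta ** (16/17).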
mb_ratio = reg_interval / (reg_interval + 1) opt_kwargs = dnnlib.EasyDict(opt_kwargs) opt_kwargs.lr = opt_kwargs.lr * mb_ratio opt_kwargs.betas = [beta ** mb_ratio for beta in opt_kwargs.betas] opt = dnnlib.util.construct_class_by_name(module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer phases += [dnnlib.EasyDict(name=name+'main', module=module, opt=opt, interval=1)] phases += [dnnlib.EasyDict(name=name+'reg', module=module, opt=opt, interval=reg_interval)] for phase in phases: phase.start_event = None phase.end_event = None if rank == 0: phase.start_event = torch.cuda.Event(enable_timing=True) phase.end_event = torch.cuda.Event(enable_timing=True) # Export sample images. grid_size = None grid_z = None grid_c = None if rank == 0: print('Exporting sample images...') grid_size, images, labels = setup_snapshot_image_grid(training_set=training_set) save_image_grid(images, os.path.join(run_dir, 'reals.png'), drange=[0,255], grid_size=grid_size) grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu) grid_c = torch.from_numpy(labels).to(device).split(batch_gpu) images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy() save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size) # Initialize logs. if rank == 0: print('Initializing logs...') stats_collector = training_stats.Collector(regex='.*') stats_metrics = dict() stats_jsonl = None stats_tfevents = None if rank == 0: stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt') try: import torch.utils.tensorboard as tensorboard stats_tfevents = tensorboard.SummaryWriter(run_dir) except ImportError as err: print('Skipping tfevents export:', err) # Train. if rank == 0: print(f'Training for {total_kimg} kimg...') print() cur_nimg = resume_kimg * 1000 cur_tick = 0 tick_start_nimg = cur_nimg tick_start_time = time.time() maintenance_time = tick_start_time - start_time batch_idx = 0 if progress_fn is not None: progress_fn(0, total_kimg) while True: # Fetch training data. with torch.autograd.profiler.record_function('data_fetch'): phase_real_img, phase_real_c = next(training_set_iterator) phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu) phase_real_c = phase_real_c.to(device).split(batch_gpu) all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device) all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size)] all_gen_c = [training_set.get_label(np.random.randint(len(training_set))) for _ in range(len(phases) * batch_size)] all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device) all_gen_c = [phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split(batch_size)] # Execute training phases. for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z, all_gen_c): if batch_idx % phase.interval != 0: continue if phase.start_event is not None: phase.start_event.record(torch.cuda.current_stream(device)) # Accumulate gradients. phase.opt.zero_grad(set_to_none=True) phase.module.requires_grad_(True) for real_img, real_c, gen_z, gen_c in zip(phase_real_img, phase_real_c, phase_gen_z, phase_gen_c): loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z, gen_c=gen_c, gain=phase.interval, cur_nimg=cur_nimg) phase.module.requires_grad_(False) # Update weights. 
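            # Gradients are flattened into one contiguous tensor so multi-GPU
            # training needs a single all_reduce per phase; NaN/Inf entries are
            # scrubbed before the averaged gradients are scattered back and the
            # optimizer steps.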
with torch.autograd.profiler.record_function(phase.name + '_opt'): params = [param for param in phase.module.parameters() if param.grad is not None] if len(params) > 0: flat = torch.cat([param.grad.flatten() for param in params]) if num_gpus > 1: torch.distributed.all_reduce(flat) flat /= num_gpus misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat) grads = flat.split([param.numel() for param in params]) for param, grad in zip(params, grads): param.grad = grad.reshape(param.shape) phase.opt.step() # Phase done. if phase.end_event is not None: phase.end_event.record(torch.cuda.current_stream(device)) # Update G_ema. with torch.autograd.profiler.record_function('Gema'): ema_nimg = ema_kimg * 1000 if ema_rampup is not None: ema_nimg = min(ema_nimg, cur_nimg * ema_rampup) ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8)) for p_ema, p in zip(G_ema.parameters(), G.parameters()): p_ema.copy_(p.lerp(p_ema, ema_beta)) for b_ema, b in zip(G_ema.buffers(), G.buffers()): b_ema.copy_(b) # Update state. cur_nimg += batch_size batch_idx += 1 # Execute ADA heuristic. if (ada_stats is not None) and (batch_idx % ada_interval == 0): ada_stats.update() adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * (batch_size * ada_interval) / (ada_kimg * 1000) augment_pipe.p.copy_((augment_pipe.p + adjust).max(misc.constant(0, device=device))) # Perform maintenance tasks once per tick. done = (cur_nimg >= total_kimg * 1000) if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000): continue # Print status line, accumulating the same information in training_stats. tick_end_time = time.time() fields = [] fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"] fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"] fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"] fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"] fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"] fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"] fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"] fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"] fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"] torch.cuda.reset_peak_memory_stats() fields += [f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"] training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60)) training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60)) if rank == 0: print(' '.join(fields)) # Check for abort. if (not done) and (abort_fn is not None) and abort_fn(): done = True if rank == 0: print() print('Aborting...') # Save image snapshot. 
if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0): images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy() save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'), drange=[-1,1], grid_size=grid_size) # Save network snapshot. snapshot_pkl = None snapshot_data = None if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0): snapshot_data = dict(G=G, D=D, G_ema=G_ema, augment_pipe=augment_pipe, training_set_kwargs=dict(training_set_kwargs)) for key, value in snapshot_data.items(): if isinstance(value, torch.nn.Module): value = copy.deepcopy(value).eval().requires_grad_(False) if num_gpus > 1: misc.check_ddp_consistency(value, ignore_regex=r'.*\.[^.]+_(avg|ema)') for param in misc.params_and_buffers(value): torch.distributed.broadcast(param, src=0) snapshot_data[key] = value.cpu() del value # conserve memory snapshot_pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl') if rank == 0: with open(snapshot_pkl, 'wb') as f: pickle.dump(snapshot_data, f) # Evaluate metrics. if (snapshot_data is not None) and (len(metrics) > 0): if rank == 0: print('Evaluating metrics...') for metric in metrics: result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'], dataset_kwargs=training_set_kwargs, num_gpus=num_gpus, rank=rank, device=device) if rank == 0: metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl) stats_metrics.update(result_dict.results) del snapshot_data # conserve memory # Collect statistics. for phase in phases: value = [] if (phase.start_event is not None) and (phase.end_event is not None): phase.end_event.synchronize() value = phase.start_event.elapsed_time(phase.end_event) training_stats.report0('Timing/' + phase.name, value) stats_collector.update() stats_dict = stats_collector.as_dict() # Update logs. timestamp = time.time() if stats_jsonl is not None: fields = dict(stats_dict, timestamp=timestamp) stats_jsonl.write(json.dumps(fields) + '\n') stats_jsonl.flush() if stats_tfevents is not None: global_step = int(cur_nimg / 1e3) walltime = timestamp - start_time for name, value in stats_dict.items(): stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime) for name, value in stats_metrics.items(): stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime) stats_tfevents.flush() if progress_fn is not None: progress_fn(cur_nimg // 1000, total_kimg) # Update state. cur_tick += 1 tick_start_nimg = cur_nimg tick_start_time = time.time() maintenance_time = tick_start_time - tick_end_time if done: break # Done. if rank == 0: print() print('Exiting...') #---------------------------------------------------------------------------- ================================================ FILE: stylegan_human/training_scripts/sg2/train.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
"""Train a GAN using the techniques described in the paper "Training Generative Adversarial Networks with Limited Data".""" import os import click import re import json import tempfile import torch import dnnlib import ast from training import training_loop from metrics import metric_main from torch_utils import training_stats from torch_utils import custom_ops #---------------------------------------------------------------------------- class UserError(Exception): pass #---------------------------------------------------------------------------- def setup_training_loop_kwargs( # General options (not included in desc). gpus = None, # Number of GPUs: , default = 1 gpu snap = None, # Snapshot interval: , default = 50 ticks metrics = None, # List of metric names: [], ['fid50k_full'] (default), ... seed = None, # Random seed: , default = 0 # Dataset. data = None, # Training dataset (required): cond = None, # Train conditional model based on dataset labels: , default = False subset = None, # Train with only N images: , default = all mirror = None, # Augment dataset with x-flips: , default = False square = None, # Base config. cfg = None, # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar', 'shhq' gamma = None, # Override R1 gamma: kimg = None, # Override training duration: batch = None, # Override batch size: # Discriminator augmentation. aug = None, # Augmentation mode: 'ada' (default), 'noaug', 'fixed' p = None, # Specify p for 'fixed' (required): target = None, # Override ADA target for 'ada': , default = depends on aug augpipe = None, # Augmentation pipeline: 'blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc' (default), ..., 'bgcfnc' # Transfer learning. resume = None, # Load previous network: 'noresume' (default), 'ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', , freezed = None, # Freeze-D: , default = 0 discriminator layers # Performance options (not included in desc). 
    fp32       = None, # Disable mixed-precision training: <bool>, default = False
    nhwc       = None, # Use NHWC memory format with FP16: <bool>, default = False
    allow_tf32 = None, # Allow PyTorch to use TF32 for matmul and convolutions: <bool>, default = False
    nobench    = None, # Disable cuDNN benchmarking: <bool>, default = False
    workers    = None, # Override number of DataLoader workers: <int>, default = 3
):
    args = dnnlib.EasyDict()

    # ------------------------------------------
    # General options: gpus, snap, metrics, seed
    # ------------------------------------------

    if gpus is None:
        gpus = 1
    assert isinstance(gpus, int)
    if not (gpus >= 1 and gpus & (gpus - 1) == 0):
        raise UserError('--gpus must be a power of two')
    args.num_gpus = gpus

    if snap is None:
        snap = 50
    assert isinstance(snap, int)
    if snap < 1:
        raise UserError('--snap must be at least 1')
    args.image_snapshot_ticks = snap
    args.network_snapshot_ticks = snap

    if metrics is None:
        metrics = ['fid50k_full']
    assert isinstance(metrics, list)
    if not all(metric_main.is_valid_metric(metric) for metric in metrics):
        raise UserError('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
    args.metrics = metrics

    if seed is None:
        seed = 0
    assert isinstance(seed, int)
    args.random_seed = seed

    # -------------------------------------------
    # Dataset: data, cond, subset, mirror, square
    # -------------------------------------------

    print('square : ', square)
    assert data is not None
    assert isinstance(data, str)
    args.training_set_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data, use_labels=True, max_size=None, xflip=False, square=square)
    args.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=3, prefetch_factor=2)
    try:
        training_set = dnnlib.util.construct_class_by_name(**args.training_set_kwargs) # subclass of training.dataset.Dataset
        args.training_set_kwargs.resolution = training_set.resolution # be explicit about resolution
        args.training_set_kwargs.use_labels = training_set.has_labels # be explicit about labels
        args.training_set_kwargs.max_size = len(training_set) # be explicit about dataset size
        desc = training_set.name
        print('desc: ', desc)
        del training_set # conserve memory
    except IOError as err:
        raise UserError(f'--data: {err}')

    if square:
        desc += '-square'
    else:
        desc += '-rectangle'

    if cond is None:
        cond = False
    assert isinstance(cond, bool)
    if cond:
        if not args.training_set_kwargs.use_labels:
            raise UserError('--cond=True requires labels specified in dataset.json')
        desc += '-cond'
    else:
        args.training_set_kwargs.use_labels = False

    if subset is not None:
        assert isinstance(subset, int)
        if not 1 <= subset <= args.training_set_kwargs.max_size:
            raise UserError(f'--subset must be between 1 and {args.training_set_kwargs.max_size}')
        desc += f'-subset{subset}'
        if subset < args.training_set_kwargs.max_size:
            args.training_set_kwargs.max_size = subset
            args.training_set_kwargs.random_seed = args.random_seed

    if mirror is None:
        mirror = False
    assert isinstance(mirror, bool)
    if mirror:
        desc += '-mirror'
        args.training_set_kwargs.xflip = True

    # ------------------------------------
    # Base config: cfg, gamma, kimg, batch
    # ------------------------------------

    if cfg is None:
        cfg = 'auto'
    assert isinstance(cfg, str)
    desc += f'-{cfg}'

    cfg_specs = {
        'auto':      dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=0.05, map=2),
        'shhq':      dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=0.05, map=8), # Populated dynamically based on resolution and GPU count.
'stylegan2': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8), # Uses mixed-precision, unlike the original StyleGAN2. 'paper256': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=0.5, lrate=0.0025, gamma=1, ema=20, ramp=None, map=8), 'paper512': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=1, lrate=0.0025, gamma=0.5, ema=20, ramp=None, map=8), 'paper1024': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=2, ema=10, ramp=None, map=8), 'cifar': dict(ref_gpus=2, kimg=100000, mb=64, mbstd=32, fmaps=1, lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=2), } assert cfg in cfg_specs spec = dnnlib.EasyDict(cfg_specs[cfg]) if cfg == 'auto' or cfg == 'shhq': desc += f'{gpus:d}' spec.ref_gpus = gpus res = args.training_set_kwargs.resolution spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus) # keep gpu memory consumption at bay spec.mbstd = min(spec.mb // gpus, 4) # other hyperparams behave more predictably if mbstd group size remains fixed spec.fmaps = 1 if res >= 512 else 0.5 spec.lrate = 0.002 if res >= 1024 else 0.0025 spec.gamma = 0.0002 * (res ** 2) / spec.mb # heuristic formula spec.ema = spec.mb * 10 / 32 args.G_kwargs = dnnlib.EasyDict(class_name='training.networks.Generator', z_dim=512, w_dim=512, mapping_kwargs=dnnlib.EasyDict(), synthesis_kwargs=dnnlib.EasyDict(),square=square) args.D_kwargs = dnnlib.EasyDict(class_name='training.networks.Discriminator', block_kwargs=dnnlib.EasyDict(), mapping_kwargs=dnnlib.EasyDict(), epilogue_kwargs=dnnlib.EasyDict(),square=square) args.G_kwargs.synthesis_kwargs.channel_base = args.D_kwargs.channel_base = int(spec.fmaps * 32768) args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512 args.G_kwargs.mapping_kwargs.num_layers = spec.map args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 4 # enable mixed-precision training args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = 256 # clamp activations to avoid float16 overflow args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd args.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0,0.99], eps=1e-8) args.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0,0.99], eps=1e-8) args.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.StyleGAN2Loss', r1_gamma=spec.gamma) args.total_kimg = spec.kimg args.batch_size = spec.mb args.batch_gpu = spec.mb // spec.ref_gpus args.ema_kimg = spec.ema args.ema_rampup = spec.ramp if cfg == 'cifar': args.loss_kwargs.pl_weight = 0 # disable path length regularization args.loss_kwargs.style_mixing_prob = 0 # disable style mixing args.D_kwargs.architecture = 'orig' # disable residual skip connections if gamma is not None: assert isinstance(gamma, float) if not gamma >= 0: raise UserError('--gamma must be non-negative') desc += f'-gamma{gamma:g}' args.loss_kwargs.r1_gamma = gamma if kimg is not None: assert isinstance(kimg, int) if not kimg >= 1: raise UserError('--kimg must be at least 1') desc += f'-kimg{kimg:d}' args.total_kimg = kimg if batch is not None: assert isinstance(batch, int) if not (batch >= 1 and batch % gpus == 0): raise UserError('--batch must be at least 1 and divisible by --gpus') desc += f'-batch{batch}' args.batch_size = batch args.batch_gpu = batch // gpus # --------------------------------------------------- # Discriminator augmentation: aug, p, target, augpipe # --------------------------------------------------- if aug is None: 
    # ---------------------------------------------------
    # Discriminator augmentation: aug, p, target, augpipe
    # ---------------------------------------------------

    if aug is None:
        aug = 'ada'
    else:
        assert isinstance(aug, str)
        desc += f'-{aug}'

    if aug == 'ada':
        args.ada_target = 0.6
    elif aug == 'noaug':
        pass
    elif aug == 'fixed':
        if p is None:
            raise UserError(f'--aug={aug} requires specifying --p')
    else:
        raise UserError(f'--aug={aug} not supported')

    if p is not None:
        assert isinstance(p, float)
        if aug != 'fixed':
            raise UserError('--p can only be specified with --aug=fixed')
        if not 0 <= p <= 1:
            raise UserError('--p must be between 0 and 1')
        desc += f'-p{p:g}'
        args.augment_p = p

    if target is not None:
        assert isinstance(target, float)
        if aug != 'ada':
            raise UserError('--target can only be specified with --aug=ada')
        if not 0 <= target <= 1:
            raise UserError('--target must be between 0 and 1')
        desc += f'-target{target:g}'
        args.ada_target = target

    assert augpipe is None or isinstance(augpipe, str)
    if augpipe is None:
        augpipe = 'bgc'
    else:
        if aug == 'noaug':
            raise UserError('--augpipe cannot be specified with --aug=noaug')
        desc += f'-{augpipe}'

    augpipe_specs = {
        'blit':   dict(xflip=1, rotate90=1, xint=1),
        'geom':   dict(scale=1, rotate=1, aniso=1, xfrac=1),
        'color':  dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'filter': dict(imgfilter=1),
        'noise':  dict(noise=1),
        'cutout': dict(cutout=1),
        'bg':     dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
        'bgc':    dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'bgcf':   dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1),
        'bgcfn':  dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1),
        'bgcfnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1, cutout=1),
        'body':   dict(xflip=1, rotate90=0, xint=1, scale=1, rotate=0, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
    }

    assert augpipe in augpipe_specs
    if aug != 'noaug':
        args.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', **augpipe_specs[augpipe])
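    # For example, --aug=ada --target=0.7 --augpipe=bgc sets args.ada_target = 0.7 and
    # builds an AugmentPipe from the 'bgc' spec above (pixel blitting + geometric +
    # color transforms). The 'body' spec disables 90-degree rotations, which would not
    # make sense for upright full-body human images.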
    # ----------------------------------
    # Transfer learning: resume, freezed
    # ----------------------------------

    resume_specs = {
        'ffhq256':     'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
        'ffhq512':     'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
        'ffhq1024':    'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
        'celebahq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
        'lsundog256':  'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
    }

    assert resume is None or isinstance(resume, str)
    if resume is None:
        resume = 'noresume'
    elif resume == 'noresume':
        desc += '-noresume'
    elif resume in resume_specs:
        desc += f'-resume{resume}'
        args.resume_pkl = resume_specs[resume] # predefined url
    else:
        desc += '-resumecustom'
        args.resume_pkl = resume # custom path or url

    if resume != 'noresume':
        args.ada_kimg = 100 # make ADA react faster at the beginning
        args.ema_rampup = None # disable EMA rampup

    if freezed is not None:
        assert isinstance(freezed, int)
        if not freezed >= 0:
            raise UserError('--freezed must be non-negative')
        desc += f'-freezed{freezed:d}'
        args.D_kwargs.block_kwargs.freeze_layers = freezed

    # -------------------------------------------------
    # Performance options: fp32, nhwc, nobench, workers
    # -------------------------------------------------

    if fp32 is None:
        fp32 = False
    assert isinstance(fp32, bool)
    if fp32:
        args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
        args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None

    if nhwc is None:
        nhwc = False
    assert isinstance(nhwc, bool)
    if nhwc:
        args.G_kwargs.synthesis_kwargs.fp16_channels_last = args.D_kwargs.block_kwargs.fp16_channels_last = True

    if nobench is None:
        nobench = False
    assert isinstance(nobench, bool)
    if nobench:
        args.cudnn_benchmark = False

    if allow_tf32 is None:
        allow_tf32 = False
    assert isinstance(allow_tf32, bool)
    if allow_tf32:
        args.allow_tf32 = True

    if workers is not None:
        assert isinstance(workers, int)
        if not workers >= 1:
            raise UserError('--workers must be at least 1')
        args.data_loader_kwargs.num_workers = workers

    return desc, args

#----------------------------------------------------------------------------

def subprocess_fn(rank, args, temp_dir):
    dnnlib.util.Logger(file_name=os.path.join(args.run_dir, 'log.txt'), file_mode='a', should_flush=True)

    # Init torch.distributed.
    if args.num_gpus > 1:
        init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
        if os.name == 'nt':
            init_method = 'file:///' + init_file.replace('\\', '/')
            torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
        else:
            init_method = f'file://{init_file}'
            torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)

    # Init torch_utils.
    sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
    training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
    if rank != 0:
        custom_ops.verbosity = 'none'

    # Execute training loop.
    training_loop.training_loop(rank=rank, **args)

#----------------------------------------------------------------------------

class CommaSeparatedList(click.ParamType):
    name = 'list'

    def convert(self, value, param, ctx):
        _ = param, ctx
        if value is None or value.lower() == 'none' or value == '':
            return []
        return value.split(',')

#----------------------------------------------------------------------------

@click.command()
@click.pass_context

# General options.
@click.option('--outdir', help='Where to save the results', required=True, metavar='DIR')
@click.option('--gpus', help='Number of GPUs to use [default: 1]', type=int, metavar='INT')
@click.option('--snap', help='Snapshot interval [default: 50 ticks]', type=int, metavar='INT')
@click.option('--metrics', help='Comma-separated list or "none" [default: fid50k_full]', type=CommaSeparatedList())
@click.option('--seed', help='Random seed [default: 0]', type=int, metavar='INT')
@click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True)

# Dataset.
@click.option('--data', help='Training data (directory or zip)', metavar='PATH', required=True)
@click.option('--cond', help='Train conditional model based on dataset labels [default: false]', type=bool, metavar='BOOL')
@click.option('--subset', help='Train with only N images [default: all]', type=int, metavar='INT')
@click.option('--mirror', help='Enable dataset x-flips [default: false]', type=bool, metavar='BOOL')
@click.option('--square', help='True for square, False for rectangle', type=bool, metavar='BOOL', default=False)

# Base config.
@click.option('--cfg', help='Base config [default: auto]', type=click.Choice(['auto', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar', 'shhq']))
@click.option('--gamma', help='Override R1 gamma', type=float)
@click.option('--kimg', help='Override training duration', type=int, metavar='INT')
@click.option('--batch', help='Override batch size', type=int, metavar='INT')

# Discriminator augmentation.
@click.option('--aug', help='Augmentation mode [default: ada]', type=click.Choice(['noaug', 'ada', 'fixed']))
@click.option('--p', help='Augmentation probability for --aug=fixed', type=float)
@click.option('--target', help='ADA target value for --aug=ada', type=float)
@click.option('--augpipe', help='Augmentation pipeline [default: bgc]', type=click.Choice(['blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc', 'bgcf', 'bgcfn', 'bgcfnc', 'body']))

# Transfer learning.
@click.option('--resume', help='Resume training [default: noresume]', metavar='PKL')
@click.option('--freezed', help='Freeze-D [default: 0 layers]', type=int, metavar='INT')

# Performance options.
@click.option('--fp32', help='Disable mixed-precision training', type=bool, metavar='BOOL')
@click.option('--nhwc', help='Use NHWC memory format with FP16', type=bool, metavar='BOOL')
@click.option('--nobench', help='Disable cuDNN benchmarking', type=bool, metavar='BOOL')
@click.option('--allow-tf32', help='Allow PyTorch to use TF32 internally', type=bool, metavar='BOOL')
@click.option('--workers', help='Override number of DataLoader workers', type=int, metavar='INT')

def main(ctx, outdir, dry_run, **config_kwargs):
    """Train a GAN using the techniques described in the paper
    "Training Generative Adversarial Networks with Limited Data".

    Examples:

    \b
    # Train with custom dataset using 1 GPU.
    python train.py --outdir=~/training-runs --data=~/mydataset.zip --gpus=1

    \b
    # Train class-conditional CIFAR-10 using 2 GPUs.
    python train.py --outdir=~/training-runs --data=~/datasets/cifar10.zip \\
        --gpus=2 --cfg=cifar --cond=1

    \b
    # Transfer learn MetFaces from FFHQ using 4 GPUs.
    python train.py --outdir=~/training-runs --data=~/datasets/metfaces.zip \\
        --gpus=4 --cfg=paper1024 --mirror=1 --resume=ffhq1024 --snap=10

    \b
    # Reproduce original StyleGAN2 config F.
    python train.py --outdir=~/training-runs --data=~/datasets/ffhq.zip \\
        --gpus=8 --cfg=stylegan2 --mirror=1 --aug=noaug

    \b
    Base configs (--cfg):
      auto       Automatically select reasonable defaults based on resolution
                 and GPU count. Good starting point for new datasets.
      stylegan2  Reproduce results for StyleGAN2 config F at 1024x1024.
      paper256   Reproduce results for FFHQ and LSUN Cat at 256x256.
      paper512   Reproduce results for BreCaHAD and AFHQ at 512x512.
      paper1024  Reproduce results for MetFaces at 1024x1024.
      cifar      Reproduce results for CIFAR-10 at 32x32.

    \b
    Transfer learning source networks (--resume):
      ffhq256        FFHQ trained at 256x256 resolution.
      ffhq512        FFHQ trained at 512x512 resolution.
      ffhq1024       FFHQ trained at 1024x1024 resolution.
      celebahq256    CelebA-HQ trained at 256x256 resolution.
      lsundog256     LSUN Dog trained at 256x256 resolution.
      <PATH or URL>  Custom network pickle.
    """
    dnnlib.util.Logger(should_flush=True)

    # Setup training options.
    try:
        run_desc, args = setup_training_loop_kwargs(**config_kwargs)
    except UserError as err:
        ctx.fail(err)

    # Pick output directory.
    prev_run_dirs = []
    if os.path.isdir(outdir):
        prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))]
    prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
    prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
    cur_run_id = max(prev_run_ids, default=-1) + 1
    args.run_dir = os.path.join(outdir, f'{cur_run_id:05d}-{run_desc}')
    assert not os.path.exists(args.run_dir)

    # Print options.
    print()
    print('Training options:')
    print(json.dumps(args, indent=2))
    print()
    print(f'Output directory: {args.run_dir}')
    print(f'Training data: {args.training_set_kwargs.path}')
    print(f'Training duration: {args.total_kimg} kimg')
    print(f'Number of GPUs: {args.num_gpus}')
    print(f'Number of images: {args.training_set_kwargs.max_size}')
    print(f'Image resolution: {args.training_set_kwargs.resolution}')
    print(f'Conditional model: {args.training_set_kwargs.use_labels}')
    print(f'Dataset x-flips: {args.training_set_kwargs.xflip}')
    print()

    # Dry run?
    if dry_run:
        print('Dry run; exiting.')
        return

    # Create output directory.
    print('Creating output directory...')
    os.makedirs(args.run_dir, exist_ok=True)
    with open(os.path.join(args.run_dir, 'training_options.json'), 'wt') as f:
        json.dump(args, f, indent=2)

    # Launch processes.
    print('Launching processes...')
    torch.multiprocessing.set_start_method('spawn')
    with tempfile.TemporaryDirectory() as temp_dir:
        if args.num_gpus == 1:
            subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
        else:
            torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)

#----------------------------------------------------------------------------

if __name__ == "__main__":
    main() # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------


================================================
FILE: stylegan_human/training_scripts/sg2/training/dataset.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import numpy as np
import zipfile
import PIL.Image
import json
import torch
import dnnlib
import cv2
from collections import Counter

try:
    import pyspng
except ImportError:
    pyspng = None

#----------------------------------------------------------------------------

class Dataset(torch.utils.data.Dataset):
    def __init__(self,
        name,                   # Name of the dataset.
        raw_shape,              # Shape of the raw image data (NCHW).
        max_size    = None,     # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
        use_labels  = False,    # Enable conditioning labels? False = label dimension is zero.
        xflip       = False,    # Artificially double the size of the dataset via x-flips. Applied after max_size.
        random_seed = 0,        # Random seed to use when applying max_size.
        square      = False,
    ):
        # print(' Inside Dataset ')
        self._name = name
        self._raw_shape = list(raw_shape)
        self._use_labels = use_labels
        self._raw_labels = None
        self._label_shape = None
        self._square = square

        # Apply max_size.
        self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
        if (max_size is not None) and (self._raw_idx.size > max_size):
            np.random.RandomState(random_seed).shuffle(self._raw_idx)
            self._raw_idx = np.sort(self._raw_idx[:max_size])

        # Apply xflip.
        self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
        if xflip:
            self._raw_idx = np.tile(self._raw_idx, 2)
            self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])

    def _get_raw_labels(self):
        if self._raw_labels is None:
            self._raw_labels = self._load_raw_labels() if self._use_labels else None
            if self._raw_labels is None:
                self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
            assert isinstance(self._raw_labels, np.ndarray)
            assert self._raw_labels.shape[0] == self._raw_shape[0]
            assert self._raw_labels.dtype in [np.float32, np.int64]
            if self._raw_labels.dtype == np.int64:
                assert self._raw_labels.ndim == 1
                assert np.all(self._raw_labels >= 0)
        return self._raw_labels

    def close(self): # to be overridden by subclass
        pass

    def _load_raw_image(self, raw_idx): # to be overridden by subclass
        raise NotImplementedError

    def _load_raw_labels(self): # to be overridden by subclass
        raise NotImplementedError

    def __getstate__(self):
        return dict(self.__dict__, _raw_labels=None)

    def __del__(self):
        try:
            self.close()
        except:
            pass

    def __len__(self):
        return self._raw_idx.size

    def __getitem__(self, idx):
        image = self._load_raw_image(self._raw_idx[idx])
        assert isinstance(image, np.ndarray)
        assert list(image.shape) == self.image_shape
        assert image.dtype == np.uint8
        if self._xflip[idx]:
            assert image.ndim == 3 # CHW
            image = image[:, :, ::-1]
        return image.copy(), self.get_label(idx)

    def get_label(self, idx):
        label = self._get_raw_labels()[self._raw_idx[idx]]
        if label.dtype == np.int64:
            onehot = np.zeros(self.label_shape, dtype=np.float32)
            onehot[label] = 1
            label = onehot
        return label.copy()

    def get_details(self, idx):
        d = dnnlib.EasyDict()
        d.raw_idx = int(self._raw_idx[idx])
        d.xflip = (int(self._xflip[idx]) != 0)
        d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
        return d

    @property
    def name(self):
        return self._name

    @property
    def image_shape(self):
        return list(self._raw_shape[1:])

    @property
    def num_channels(self):
        assert len(self.image_shape) == 3 # CHW
        return self.image_shape[0]

    @property
    def resolution(self):
        assert len(self.image_shape) == 3 # CHW
        if self._square:
            assert self.image_shape[1] == self.image_shape[2]
        else:
            assert self.image_shape[1] == self.image_shape[2] * 2
        return self.image_shape[1]

    @property
    def label_shape(self):
        if self._label_shape is None:
            raw_labels = self._get_raw_labels()
            if raw_labels.dtype == np.int64:
                self._label_shape = [int(np.max(raw_labels)) + 1]
            else:
                self._label_shape = raw_labels.shape[1:]
        return list(self._label_shape)

    @property
    def label_dim(self):
        assert len(self.label_shape) == 1
        return self.label_shape[0]

    @property
    def has_labels(self):
        return any(x != 0 for x in self.label_shape)

    @property
    def has_onehot_labels(self):
        return self._get_raw_labels().dtype == np.int64
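# Note on the rectangle convention enforced by the `resolution` property above:
# full-body images are stored with height = 2 * width, e.g. image_shape == [3, 1024, 512],
# and `resolution` reports the height (1024 in that example).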
#----------------------------------------------------------------------------

class ImageFolderDataset(Dataset):
    def __init__(self,
        path,                   # Path to directory or zip.
        resolution = None,      # Ensure specific resolution, None = highest available.
        square = False,
        **super_kwargs,         # Additional arguments for the Dataset base class.
    ):
        self._path = path
        self._zipfile = None
        self._square = square

        if os.path.isdir(self._path):
            self._type = 'dir'
            self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
        elif self._file_ext(self._path) == '.zip':
            self._type = 'zip'
            self._all_fnames = set(self._get_zipfile().namelist())
        else:
            raise IOError('Path must point to a directory or zip')

        PIL.Image.init()
        self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
        if len(self._image_fnames) == 0:
            raise IOError('No image files found in the specified path')

        name = os.path.splitext(os.path.basename(self._path))[0]
        raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
        # if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
        #     raise IOError('Image files do not match the specified resolution')
        if resolution is not None:
            if self._square:
                raw_shape[2] = raw_shape[3] = resolution
            else:
                raw_shape[2] = resolution
                raw_shape[3] = resolution // 2
        # print(raw_shape)
        super().__init__(name=name, raw_shape=raw_shape, square=square, **super_kwargs)

    @staticmethod
    def _file_ext(fname):
        return os.path.splitext(fname)[1].lower()

    def _get_zipfile(self):
        assert self._type == 'zip'
        if self._zipfile is None:
            self._zipfile = zipfile.ZipFile(self._path)
        return self._zipfile

    def _open_file(self, fname):
        if self._type == 'dir':
            return open(os.path.join(self._path, fname), 'rb')
        if self._type == 'zip':
            return self._get_zipfile().open(fname, 'r')
        return None

    def close(self):
        try:
            if self._zipfile is not None:
                self._zipfile.close()
        finally:
            self._zipfile = None

    def __getstate__(self):
        return dict(super().__getstate__(), _zipfile=None)

    def _load_raw_image(self, raw_idx): # load single image
        fname = self._image_fnames[raw_idx]
        with self._open_file(fname) as f:
            if pyspng is not None and self._file_ext(fname) == '.png':
                image = pyspng.load(f.read())
            else:
                image = np.array(PIL.Image.open(f))
        if image.ndim == 2:
            image = image[:, :, np.newaxis] # HW => HWC
        image = image.transpose(2, 0, 1) # HWC => CHW
        return image

    def _load_raw_labels(self):
        fname = 'dataset.json'
        if fname not in self._all_fnames:
            return None
        with self._open_file(fname) as f:
            labels = json.load(f)['labels']
        if labels is None:
            return None
        labels = dict(labels)
        labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
        labels = np.array(labels)
        labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
        return labels

#----------------------------------------------------------------------------


================================================
FILE: stylegan_human/training_scripts/sg2/training/networks.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import numpy as np
import torch
from torch_utils import misc
from torch_utils import persistence
from torch_utils.ops import conv2d_resample
from torch_utils.ops import upfirdn2d
from torch_utils.ops import bias_act
from torch_utils.ops import fma

#----------------------------------------------------------------------------

@misc.profiled_function
def normalize_2nd_moment(x, dim=1, eps=1e-8):
    return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()

#----------------------------------------------------------------------------

@misc.profiled_function
def modulated_conv2d(
    x,                          # Input tensor of shape [batch_size, in_channels, in_height, in_width].
    weight,                     # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
    styles,                     # Modulation coefficients of shape [batch_size, in_channels].
    noise           = None,     # Optional noise tensor to add to the output activations.
    up              = 1,        # Integer upsampling factor.
    down            = 1,        # Integer downsampling factor.
    padding         = 0,        # Padding with respect to the upsampled image.
    resample_filter = None,     # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
    demodulate      = True,     # Apply weight demodulation?
    flip_weight     = True,     # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
    fused_modconv   = True,     # Perform modulation, convolution, and demodulation as a single fused operation?
):
    batch_size = x.shape[0]
    out_channels, in_channels, kh, kw = weight.shape
    misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk]
    misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
    misc.assert_shape(styles, [batch_size, in_channels]) # [NI]

    # Pre-normalize inputs to avoid FP16 overflow.
    if x.dtype == torch.float16 and demodulate:
        weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk
        styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I

    # Calculate per-sample weights and demodulation coefficients.
    w = None
    dcoefs = None
    if demodulate or fused_modconv:
        w = weight.unsqueeze(0) # [NOIkk]
        w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
    if demodulate:
        dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
    if demodulate and fused_modconv:
        w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]

    # Execute by scaling the activations before and after the convolution.
    if not fused_modconv:
        x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
        x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight)
        if demodulate and noise is not None:
            x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
        elif demodulate:
            x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
        elif noise is not None:
            x = x.add_(noise.to(x.dtype))
        return x
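    # Note: the demodulation coefficients computed above implement per-sample weight
    # normalization, dcoef_o = 1 / sqrt(sum_{i,k,k} (w_oikk * s_i)^2 + 1e-8), so the
    # output feature maps keep roughly unit variance regardless of style magnitudes.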
    # Execute as one fused op using grouped convolution.
    with misc.suppress_tracer_warnings(): # this value will be treated as a constant
        batch_size = int(batch_size)
    misc.assert_shape(x, [batch_size, in_channels, None, None])
    x = x.reshape(1, -1, *x.shape[2:])
    w = w.reshape(-1, in_channels, kh, kw)
    x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight)
    x = x.reshape(batch_size, -1, *x.shape[2:])
    if noise is not None:
        x = x.add_(noise)
    return x

#----------------------------------------------------------------------------

@persistence.persistent_class
class FullyConnectedLayer(torch.nn.Module):
    def __init__(self,
        in_features,                # Number of input features.
        out_features,               # Number of output features.
        bias            = True,     # Apply additive bias before the activation function?
        activation      = 'linear', # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 1,        # Learning rate multiplier.
        bias_init       = 0,        # Initial value for the additive bias.
    ):
        super().__init__()
        self.activation = activation
        self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
        self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier

    def forward(self, x):
        w = self.weight.to(x.dtype) * self.weight_gain
        b = self.bias
        if b is not None:
            b = b.to(x.dtype)
            if self.bias_gain != 1:
                b = b * self.bias_gain

        if self.activation == 'linear' and b is not None:
            x = torch.addmm(b.unsqueeze(0), x, w.t())
        else:
            x = x.matmul(w.t())
            x = bias_act.bias_act(x, b, act=self.activation)
        return x
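# Equalized learning rate, as used by FullyConnectedLayer above: parameters are stored
# scaled down by lr_multiplier and rescaled at runtime by
# weight_gain = lr_multiplier / sqrt(in_features), which leaves the forward statistics
# unchanged while scaling the effective optimizer step size by lr_multiplier.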
#----------------------------------------------------------------------------

@persistence.persistent_class
class Conv2dLayer(torch.nn.Module):
    def __init__(self,
        in_channels,                    # Number of input channels.
        out_channels,                   # Number of output channels.
        kernel_size,                    # Width and height of the convolution kernel.
        bias            = True,         # Apply additive bias before the activation function?
        activation      = 'linear',     # Activation function: 'relu', 'lrelu', etc.
        up              = 1,            # Integer upsampling factor.
        down            = 1,            # Integer downsampling factor.
        resample_filter = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp      = None,         # Clamp the output to +-X, None = disable clamping.
        channels_last   = False,        # Expect the input to have memory_format=channels_last?
        trainable       = True,         # Update the weights of this layer during training?
    ):
        super().__init__()
        self.activation = activation
        self.up = up
        self.down = down
        self.conv_clamp = conv_clamp
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.padding = kernel_size // 2
        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
        self.act_gain = bias_act.activation_funcs[activation].def_gain

        memory_format = torch.channels_last if channels_last else torch.contiguous_format
        weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)
        bias = torch.zeros([out_channels]) if bias else None
        if trainable:
            self.weight = torch.nn.Parameter(weight)
            self.bias = torch.nn.Parameter(bias) if bias is not None else None
        else:
            self.register_buffer('weight', weight)
            if bias is not None:
                self.register_buffer('bias', bias)
            else:
                self.bias = None

    def forward(self, x, gain=1):
        w = self.weight * self.weight_gain
        b = self.bias.to(x.dtype) if self.bias is not None else None
        flip_weight = (self.up == 1) # slightly faster
        x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight)

        act_gain = self.act_gain * gain
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
        x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp)
        return x

#----------------------------------------------------------------------------

@persistence.persistent_class
class MappingNetwork(torch.nn.Module):
    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality, 0 = no latent.
        c_dim,                      # Conditioning label (C) dimensionality, 0 = no label.
        w_dim,                      # Intermediate latent (W) dimensionality.
        num_ws,                     # Number of intermediate latents to output, None = do not broadcast.
        num_layers      = 8,        # Number of mapping layers.
        embed_features  = None,     # Label embedding dimensionality, None = same as w_dim.
        layer_features  = None,     # Number of intermediate features in the mapping layers, None = same as w_dim.
        activation      = 'lrelu',  # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 0.01,     # Learning rate multiplier for the mapping layers.
        w_avg_beta      = 0.995,    # Decay for tracking the moving average of W during training, None = do not track.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta

        if embed_features is None:
            embed_features = w_dim
        if c_dim == 0:
            embed_features = 0
        if layer_features is None:
            layer_features = w_dim
        features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]

        if c_dim > 0:
            self.embed = FullyConnectedLayer(c_dim, embed_features)
        for idx in range(num_layers):
            in_features = features_list[idx]
            out_features = features_list[idx + 1]
            layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)

        if num_ws is not None and w_avg_beta is not None:
            self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
        # Embed, normalize, and concat inputs.
        x = None
        with torch.autograd.profiler.record_function('input'):
            if self.z_dim > 0:
                misc.assert_shape(z, [None, self.z_dim])
                x = normalize_2nd_moment(z.to(torch.float32))
            if self.c_dim > 0:
                misc.assert_shape(c, [None, self.c_dim])
                y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
                x = torch.cat([x, y], dim=1) if x is not None else y
        # Main layers.
        for idx in range(self.num_layers):
            layer = getattr(self, f'fc{idx}')
            x = layer(x)

        # Update moving average of W.
        if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
            with torch.autograd.profiler.record_function('update_w_avg'):
                self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))

        # Broadcast.
        if self.num_ws is not None:
            with torch.autograd.profiler.record_function('broadcast'):
                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # Apply truncation.
        if truncation_psi != 1:
            with torch.autograd.profiler.record_function('truncate'):
                assert self.w_avg_beta is not None
                if self.num_ws is None or truncation_cutoff is None:
                    x = self.w_avg.lerp(x, truncation_psi)
                else:
                    x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
        return x
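# The truncation applied above is the standard truncation trick:
# w' = w_avg + truncation_psi * (w - w_avg), i.e. torch.lerp(w_avg, w, psi).
# psi = 1 disables it; psi < 1 trades diversity for sample quality, optionally only
# for the first `truncation_cutoff` layers.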
#----------------------------------------------------------------------------

@persistence.persistent_class
class SynthesisLayer(torch.nn.Module):
    def __init__(self,
        in_channels,                    # Number of input channels.
        out_channels,                   # Number of output channels.
        w_dim,                          # Intermediate latent (W) dimensionality.
        resolution,                     # Resolution of this layer.
        kernel_size     = 3,            # Convolution kernel size.
        up              = 1,            # Integer upsampling factor.
        use_noise       = True,         # Enable noise input?
        activation      = 'lrelu',      # Activation function: 'relu', 'lrelu', etc.
        resample_filter = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp      = None,         # Clamp the output of convolution layers to +-X, None = disable clamping.
        channels_last   = False,        # Use channels_last format for the weights?
        square          = False,        # default is for rectangle images
    ):
        super().__init__()
        self.resolution = resolution
        self.up = up
        self.use_noise = use_noise
        self.activation = activation
        self.conv_clamp = conv_clamp
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.padding = kernel_size // 2
        self.act_gain = bias_act.activation_funcs[activation].def_gain
        self.square = square

        self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
        memory_format = torch.channels_last if channels_last else torch.contiguous_format
        self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
        if use_noise:
            if self.square:
                self.register_buffer('noise_const', torch.randn([resolution, resolution]))
            else:
                self.register_buffer('noise_const', torch.randn([resolution, resolution // 2]))
            self.noise_strength = torch.nn.Parameter(torch.zeros([]))
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]))

    def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
        assert noise_mode in ['random', 'const', 'none']
        in_resolution = self.resolution // self.up
        if self.square:
            misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution])
        else:
            misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution // 2])
        styles = self.affine(w)

        noise = None
        if self.use_noise and noise_mode == 'random':
            if self.square:
                noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
            else:
                noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution // 2], device=x.device) * self.noise_strength
        if self.use_noise and noise_mode == 'const':
            noise = self.noise_const * self.noise_strength

        flip_weight = (self.up == 1) # slightly faster
        x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
            padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv)

        act_gain = self.act_gain * gain
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
        x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp)
        return x

#----------------------------------------------------------------------------

@persistence.persistent_class
class ToRGBLayer(torch.nn.Module):
    def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False):
        super().__init__()
        self.conv_clamp = conv_clamp
        self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
        memory_format = torch.channels_last if channels_last else torch.contiguous_format
        self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))

    def forward(self, x, w, fused_modconv=True):
        styles = self.affine(w) * self.weight_gain
        x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv)
        x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)
        return x

#----------------------------------------------------------------------------

@persistence.persistent_class
class SynthesisBlock(torch.nn.Module):
    def __init__(self,
        in_channels,                        # Number of input channels, 0 = first block.
        out_channels,                       # Number of output channels.
        w_dim,                              # Intermediate latent (W) dimensionality.
        resolution,                         # Resolution of this block.
        img_channels,                       # Number of output color channels.
        is_last,                            # Is this the last block?
        architecture        = 'skip',       # Architecture: 'orig', 'skip', 'resnet'.
        resample_filter     = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp          = None,         # Clamp the output of convolution layers to +-X, None = disable clamping.
        use_fp16            = False,        # Use FP16 for this block?
        fp16_channels_last  = False,        # Use channels-last memory format with FP16?
        square              = False,        # default is for rectangle images
        **layer_kwargs,                     # Arguments for SynthesisLayer.
    ):
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.w_dim = w_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.is_last = is_last
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.num_conv = 0
        self.num_torgb = 0
        self.square = square

        if in_channels == 0:
            if self.square:
                self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))
            else: # rectangle
                self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution // 2]))

        if in_channels != 0:
            self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2,
                resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, square=square, **layer_kwargs)
            self.num_conv += 1

        self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution,
            conv_clamp=conv_clamp, channels_last=self.channels_last, square=square, **layer_kwargs)
        self.num_conv += 1

        if is_last or architecture == 'skip':
            self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim,
                conv_clamp=conv_clamp, channels_last=self.channels_last)
            self.num_torgb += 1

        if in_channels != 0 and architecture == 'resnet':
            self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
                resample_filter=resample_filter, channels_last=self.channels_last)

    def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, **layer_kwargs):
        misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
        w_iter = iter(ws.unbind(dim=1))
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
        if fused_modconv is None:
            with misc.suppress_tracer_warnings(): # this value will be treated as a constant
                fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)

        # Input.
        if self.in_channels == 0:
            x = self.const.to(dtype=dtype, memory_format=memory_format)
            x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
        else:
            if self.square:
                misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
            else: # rectangle
                misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 4])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if self.in_channels == 0:
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        elif self.architecture == 'resnet':
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
            x = y.add_(x)
        else:
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)

        # ToRGB.
        if img is not None:
            if self.square:
                misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
            else:
                misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 4])
            img = upfirdn2d.upsample2d(img, self.resample_filter)
        if self.is_last or self.architecture == 'skip':
            y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
            y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            img = img.add_(y) if img is not None else y

        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img

#----------------------------------------------------------------------------

@persistence.persistent_class
class SynthesisNetwork(torch.nn.Module):
    def __init__(self,
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_resolution,             # Output image resolution.
        img_channels,               # Number of color channels.
        square,
        channel_base    = 32768,    # Overall multiplier for the number of channels.
        channel_max     = 512,      # Maximum number of channels in any layer.
        num_fp16_res    = 0,        # Use FP16 for the N highest resolutions.
        **block_kwargs,             # Arguments for SynthesisBlock.
    ):
        assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
        super().__init__()
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.square = square
        self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        self.num_ws = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res // 2] if res > 4 else 0
            out_channels = channels_dict[res]
            use_fp16 = (res >= fp16_resolution)
            is_last = (res == self.img_resolution)
            block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
                img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, square=square, **block_kwargs)
            self.num_ws += block.num_conv
            if is_last:
                self.num_ws += block.num_torgb
            setattr(self, f'b{res}', block)

    def forward(self, ws, return_feature=False, **block_kwargs):
        block_ws = []
        features = []
        with torch.autograd.profiler.record_function('split_ws'):
            misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
            ws = ws.to(torch.float32)
            w_idx = 0
            for res in self.block_resolutions:
                block = getattr(self, f'b{res}')
                block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
                w_idx += block.num_conv

        x = img = None
        for res, cur_ws in zip(self.block_resolutions, block_ws):
            block = getattr(self, f'b{res}')
            x, img = block(x, img, cur_ws, **block_kwargs)
            features.append(x)
        if return_feature:
            return img, features
        else:
            return img

#----------------------------------------------------------------------------

@persistence.persistent_class
class Generator(torch.nn.Module):
    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality.
        c_dim,                      # Conditioning label (C) dimensionality.
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_resolution,             # Output resolution.
        square,
        img_channels,               # Number of output color channels.
        mapping_kwargs   = {},      # Arguments for MappingNetwork.
        synthesis_kwargs = {},      # Arguments for SynthesisNetwork.
        padding = False,
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.square = square
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.padding = padding
        self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, square=square, **synthesis_kwargs)
        self.num_ws = self.synthesis.num_ws
        self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, input_is_w=False, return_feature=False, **synthesis_kwargs):
        if input_is_w:
            ws = z
            if ws.dim() == 2:
                ws = ws.unsqueeze(1).repeat([1, self.mapping.num_ws, 1])
        else:
            ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
        img = self.synthesis(ws, return_feature=return_feature, **synthesis_kwargs)
        if return_feature:
            img, feature = img
        if self.padding:
            # Pad rectangle outputs to a square canvas (constant 1 = white for the image).
            pad = (img.size(2) - img.size(3)) // 2
            img = torch.nn.functional.pad(img, (pad, pad), "constant", 1)
            if return_feature:
                for i, feat in enumerate(feature):
                    pad = (feat.size(2) - feat.size(3)) // 2
                    feature[i] = torch.nn.functional.pad(feat, (pad, pad), "constant", 0)
        if return_feature:
            return img, feature
        else:
            return img

#----------------------------------------------------------------------------

@persistence.persistent_class
class DiscriminatorBlock(torch.nn.Module):
    def __init__(self,
        in_channels,                        # Number of input channels, 0 = first block.
        tmp_channels,                       # Number of intermediate channels.
        out_channels,                       # Number of output channels.
        resolution,                         # Resolution of this block.
        img_channels,                       # Number of input color channels.
        first_layer_idx,                    # Index of the first layer.
        architecture        = 'resnet',     # Architecture: 'orig', 'skip', 'resnet'.
        activation          = 'lrelu',      # Activation function: 'relu', 'lrelu', etc.
        resample_filter     = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp          = None,         # Clamp the output of convolution layers to +-X, None = disable clamping.
        use_fp16            = False,        # Use FP16 for this block?
        fp16_channels_last  = False,        # Use channels-last memory format with FP16?
        freeze_layers       = 0,            # Freeze-D: Number of layers to freeze.
        square              = False,
    ):
        assert in_channels in [0, tmp_channels]
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.resolution = resolution
        self.img_channels = img_channels
        self.first_layer_idx = first_layer_idx
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.square = square

        self.num_layers = 0
        def trainable_gen():
            while True:
                layer_idx = self.first_layer_idx + self.num_layers
                trainable = (layer_idx >= freeze_layers)
                self.num_layers += 1
                yield trainable
        trainable_iter = trainable_gen()

        if in_channels == 0 or architecture == 'skip':
            self.fromrgb = Conv2dLayer(img_channels, tmp_channels, kernel_size=1, activation=activation,
                trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)

        self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation,
            trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)

        self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2,
            trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last)

        if architecture == 'resnet':
            self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2,
                trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last)

    def forward(self, x, img, force_fp32=False):
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format

        # Input.
        if x is not None:
            if self.square:
                misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])
            else:
                misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution // 2])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # FromRGB.
        if self.in_channels == 0 or self.architecture == 'skip':
            if self.square:
                misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
            else:
                misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution // 2])
            img = img.to(dtype=dtype, memory_format=memory_format)
            y = self.fromrgb(img)
            x = x + y if x is not None else y
            img = upfirdn2d.downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None

        # Main layers.
        if self.architecture == 'resnet':
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x)
            x = self.conv1(x, gain=np.sqrt(0.5))
            x = y.add_(x)
        else:
            x = self.conv0(x)
            x = self.conv1(x)

        assert x.dtype == dtype
        return x, img
#----------------------------------------------------------------------------

@persistence.persistent_class
class MinibatchStdLayer(torch.nn.Module):
    def __init__(self, group_size, num_channels=1):
        super().__init__()
        self.group_size = group_size
        self.num_channels = num_channels

    def forward(self, x):
        N, C, H, W = x.shape
        with misc.suppress_tracer_warnings(): # as_tensor results are registered as constants
            G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N
        F = self.num_channels
        c = C // F

        y = x.reshape(G, -1, F, c, H, W)    # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
        y = y - y.mean(dim=0)               # [GnFcHW] Subtract mean over group.
        y = y.square().mean(dim=0)          # [nFcHW]  Calc variance over group.
        y = (y + 1e-8).sqrt()               # [nFcHW]  Calc stddev over group.
        y = y.mean(dim=[2,3,4])             # [nF]     Take average over channels and pixels.
        y = y.reshape(-1, F, 1, 1)          # [nF11]   Add missing dimensions.
        y = y.repeat(G, 1, H, W)            # [NFHW]   Replicate over group and pixels.
        x = torch.cat([x, y], dim=1)        # [NCHW]   Append to input as new channels.
        return x
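# Shape walk-through for the minibatch-std layer above, assuming N=8, C=512, H=W=4,
# group_size=4, num_channels=1: the batch is split into n=2 groups of G=4 samples;
# the per-group stddev is averaged down to one scalar per group ([2,1]), broadcast
# back to [8,1,4,4], and concatenated, giving 513 output channels.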
#----------------------------------------------------------------------------

@persistence.persistent_class
class DiscriminatorEpilogue(torch.nn.Module):
    def __init__(self,
        in_channels,                    # Number of input channels.
        cmap_dim,                       # Dimensionality of mapped conditioning label, 0 = no label.
        resolution,                     # Resolution of this block.
        img_channels,                   # Number of input color channels.
        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
        mbstd_group_size    = 4,        # Group size for the minibatch standard deviation layer, None = entire minibatch.
        mbstd_num_channels  = 1,        # Number of features for the minibatch standard deviation layer, 0 = disable.
        activation          = 'lrelu',  # Activation function: 'relu', 'lrelu', etc.
        conv_clamp          = None,     # Clamp the output of convolution layers to +-X, None = disable clamping.
        square              = False,
    ):
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.cmap_dim = cmap_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.architecture = architecture
        self.square = square

        if architecture == 'skip':
            self.fromrgb = Conv2dLayer(img_channels, in_channels, kernel_size=1, activation=activation)
        self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
        self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp)

        if self.square:
            self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), in_channels, activation=activation)
        else:
            self.fc = FullyConnectedLayer(in_channels * (resolution ** 2 // 2), in_channels, activation=activation)

        self.out = FullyConnectedLayer(in_channels, 1 if cmap_dim == 0 else cmap_dim)

    def forward(self, x, img, cmap, force_fp32=False):
        if self.square:
            misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])
        else:
            misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution // 2]) # [NCHW]
        _ = force_fp32 # unused
        dtype = torch.float32
        memory_format = torch.contiguous_format

        # FromRGB.
        x = x.to(dtype=dtype, memory_format=memory_format)
        if self.architecture == 'skip':
            if self.square:
                misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
            else:
                misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution // 2])
            img = img.to(dtype=dtype, memory_format=memory_format)
            x = x + self.fromrgb(img)

        # Main layers.
        if self.mbstd is not None:
            x = self.mbstd(x)
        x = self.conv(x)
        x = self.fc(x.flatten(1))
        x = self.out(x)

        # Conditioning.
        if self.cmap_dim > 0:
            misc.assert_shape(cmap, [None, self.cmap_dim])
            x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))

        assert x.dtype == dtype
        return x

#----------------------------------------------------------------------------

@persistence.persistent_class
class Discriminator(torch.nn.Module):
    def __init__(self,
        c_dim,                          # Conditioning label (C) dimensionality.
        img_resolution,                 # Input resolution.
        img_channels,                   # Number of input color channels.
        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
        channel_base        = 32768,    # Overall multiplier for the number of channels.
        channel_max         = 512,      # Maximum number of channels in any layer.
        num_fp16_res        = 0,        # Use FP16 for the N highest resolutions.
        conv_clamp          = None,     # Clamp the output of convolution layers to +-X, None = disable clamping.
        cmap_dim            = None,     # Dimensionality of mapped conditioning label, None = default.
        square              = False,    # default for rectangle images
        block_kwargs        = {},       # Arguments for DiscriminatorBlock.
        mapping_kwargs      = {},       # Arguments for MappingNetwork.
        epilogue_kwargs     = {},       # Arguments for DiscriminatorEpilogue.
    ):
        super().__init__()
        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.square = square
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                first_layer_idx=cur_layer_idx, use_fp16=use_fp16, square=square, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, square=square, **epilogue_kwargs, **common_kwargs)

    def forward(self, img, c, **block_kwargs):
        x = None
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img = block(x, img, **block_kwargs)

        cmap = None
        if self.c_dim > 0:
            cmap = self.mapping(None, c)
        x = self.b4(x, img, cmap)
        return x

#----------------------------------------------------------------------------


================================================
FILE: stylegan_human/training_scripts/sg3/train.py
================================================
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Train a GAN using the techniques described in the paper
"Alias-Free Generative Adversarial Networks"."""

import os
import click
import re
import json
import tempfile
import torch

import dnnlib
from training import training_loop
from metrics import metric_main
from torch_utils import training_stats
from torch_utils import custom_ops
import ast

#----------------------------------------------------------------------------

def subprocess_fn(rank, c, temp_dir):
    dnnlib.util.Logger(file_name=os.path.join(c.run_dir, 'log.txt'), file_mode='a', should_flush=True)

    # Init torch.distributed.
    if c.num_gpus > 1:
        init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
        if os.name == 'nt':
            init_method = 'file:///' + init_file.replace('\\', '/')
            torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=c.num_gpus)
        else:
            init_method = f'file://{init_file}'
            torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=c.num_gpus)

    # Init torch_utils.
    sync_device = torch.device('cuda', rank) if c.num_gpus > 1 else None
    training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
    if rank != 0:
        custom_ops.verbosity = 'none'

    # Execute training loop.
    training_loop.training_loop(rank=rank, **c)

#----------------------------------------------------------------------------

def launch_training(c, desc, outdir, dry_run):
    dnnlib.util.Logger(should_flush=True)

    # Pick output directory.
    prev_run_dirs = []
    if os.path.isdir(outdir):
        prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))]
    prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
    prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
    cur_run_id = max(prev_run_ids, default=-1) + 1
    c.run_dir = os.path.join(outdir, f'{cur_run_id:05d}-{desc}')
    assert not os.path.exists(c.run_dir)

    # Print options.
    print()
    print('Training options:')
    print(json.dumps(c, indent=2))
    print()
    print(f'Output directory: {c.run_dir}')
    print(f'Number of GPUs: {c.num_gpus}')
    print(f'Batch size: {c.batch_size} images')
    print(f'Training duration: {c.total_kimg} kimg')
    print(f'Dataset path: {c.training_set_kwargs.path}')
    print(f'Dataset size: {c.training_set_kwargs.max_size} images')
    print(f'Dataset resolution: {c.training_set_kwargs.resolution}')
    print(f'Dataset labels: {c.training_set_kwargs.use_labels}')
    print(f'Dataset x-flips: {c.training_set_kwargs.xflip}')
    print()

    # Dry run?
    if dry_run:
        print('Dry run; exiting.')
        return

    # Create output directory.
    print('Creating output directory...')
    os.makedirs(c.run_dir)
    with open(os.path.join(c.run_dir, 'training_options.json'), 'wt') as f:
        json.dump(c, f, indent=2)

    # Launch processes.
    print('Launching processes...')
    torch.multiprocessing.set_start_method('spawn')
    with tempfile.TemporaryDirectory() as temp_dir:
        if c.num_gpus == 1:
            subprocess_fn(rank=0, c=c, temp_dir=temp_dir)
        else:
            torch.multiprocessing.spawn(fn=subprocess_fn, args=(c, temp_dir), nprocs=c.num_gpus)

#----------------------------------------------------------------------------

def init_dataset_kwargs(data, square=False):
    # dataset
    try:
        dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data, use_labels=True, max_size=None, xflip=False, square=square)
        dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # Subclass of training.dataset.Dataset.
        dataset_kwargs.resolution = dataset_obj.resolution # Be explicit about resolution.
        dataset_kwargs.use_labels = dataset_obj.has_labels # Be explicit about labels.
        dataset_kwargs.max_size = len(dataset_obj) # Be explicit about dataset size.
        return dataset_kwargs, dataset_obj.name
    except IOError as err:
        raise click.ClickException(f'--data: {err}')
    print("out of dataset") # unreachable: both branches above return or raise

#----------------------------------------------------------------------------

def parse_comma_separated_list(s):
    if isinstance(s, list):
        return s
    if s is None or s.lower() == 'none' or s == '':
        return []
    return s.split(',')
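# For example, parse_comma_separated_list('fid50k_full,kid50k') returns
# ['fid50k_full', 'kid50k'], while None, '' and 'none' all return [].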
@click.option('--outdir', help='Where to save the results', metavar='DIR', required=True) @click.option('--cfg', help='Base configuration', type=click.Choice(['stylegan3-t', 'stylegan3-r', 'stylegan2']), required=True) @click.option('--data', help='Training data', metavar='PATH', required=True) @click.option('--gpus', help='Number of GPUs to use', metavar='INT', type=click.IntRange(min=1), required=True) @click.option('--batch', help='Total batch size', metavar='INT', type=click.IntRange(min=1), required=True) @click.option('--gamma', help='R1 regularization weight', metavar='FLOAT', type=click.FloatRange(min=0), required=True) @click.option('--square', help='True for square, False for rectangle', type=bool, metavar='BOOL', default=False) # Optional features. @click.option('--cond', help='Train conditional model', metavar='BOOL', type=bool, default=False, show_default=True) @click.option('--mirror', help='Enable dataset x-flips', metavar='BOOL', type=bool, default=False, show_default=True) @click.option('--aug', help='Augmentation mode', type=click.Choice(['noaug', 'ada', 'fixed']), default='ada', show_default=True) @click.option('--resume', help='Resume from given network pickle', metavar='[PATH|URL]', type=str) @click.option('--freezed', help='Freeze first layers of D', metavar='INT', type=click.IntRange(min=0), default=0, show_default=True) # Misc hyperparameters. @click.option('--p', help='Probability for --aug=fixed', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.2, show_default=True) @click.option('--target', help='Target value for --aug=ada', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.6, show_default=True) @click.option('--batch-gpu', help='Limit batch size per GPU', metavar='INT', type=click.IntRange(min=1)) @click.option('--cbase', help='Capacity multiplier', metavar='INT', type=click.IntRange(min=1), default=32768, show_default=True) @click.option('--cmax', help='Max. feature maps', metavar='INT', type=click.IntRange(min=1), default=512, show_default=True) @click.option('--glr', help='G learning rate [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0)) @click.option('--dlr', help='D learning rate', metavar='FLOAT', type=click.FloatRange(min=0), default=0.002, show_default=True) @click.option('--map-depth', help='Mapping network depth [default: varies]', metavar='INT', type=click.IntRange(min=1)) @click.option('--mbstd-group', help='Minibatch std group size', metavar='INT', type=click.IntRange(min=1), default=4, show_default=True) # Misc settings. 
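The `--metrics` option defined just below routes through `parse_comma_separated_list` (defined earlier in this file), which accepts a single metric name, a comma-separated list, or a 'none' sentinel. Illustration of the accepted spellings:

# Behaviour of parse_comma_separated_list for the --metrics option:
assert parse_comma_separated_list('fid50k_full') == ['fid50k_full']
assert parse_comma_separated_list('fid50k_full,kid50k_full') == ['fid50k_full', 'kid50k_full']
assert parse_comma_separated_list('none') == []   # 'none', '' and None all disable metrics
assert parse_comma_separated_list(None) == []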
@click.option('--desc', help='String to include in result dir name', metavar='STR', type=str) @click.option('--metrics', help='Quality metrics', metavar='[NAME|A,B,C|none]', type=parse_comma_separated_list, default='fid50k_full', show_default=True) @click.option('--kimg', help='Total training duration', metavar='KIMG', type=click.IntRange(min=1), default=25000, show_default=True) @click.option('--tick', help='How often to print progress', metavar='KIMG', type=click.IntRange(min=1), default=4, show_default=True) @click.option('--snap', help='How often to save snapshots', metavar='TICKS', type=click.IntRange(min=1), default=50, show_default=True) @click.option('--seed', help='Random seed', metavar='INT', type=click.IntRange(min=0), default=0, show_default=True) @click.option('--fp32', help='Disable mixed-precision', metavar='BOOL', type=bool, default=False, show_default=True) @click.option('--nobench', help='Disable cuDNN benchmarking', metavar='BOOL', type=bool, default=False, show_default=True) @click.option('--workers', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=3, show_default=True) @click.option('-n','--dry-run', help='Print training options and exit', is_flag=True) def main(**kwargs): """Train a GAN using the techniques described in the paper "Alias-Free Generative Adversarial Networks". Examples: \b # Train StyleGAN3-T for AFHQv2 using 8 GPUs. python train.py --outdir=~/training-runs --cfg=stylegan3-t --data=~/datasets/afhqv2-512x512.zip \\ --gpus=8 --batch=32 --gamma=8.2 --mirror=1 \b # Fine-tune StyleGAN3-R for MetFaces-U using 8 GPUs, starting from the pre-trained FFHQ-U pickle. python train.py --outdir=~/training-runs --cfg=stylegan3-r --data=~/datasets/metfacesu-1024x1024.zip \\ --gpus=8 --batch=32 --gamma=6.6 --mirror=1 --kimg=5000 --snap=5 \\ --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl \b # Train StyleGAN2 for FFHQ at 1024x1024 resolution using 8 GPUs. python train.py --outdir=~/training-runs --cfg=stylegan2 --data=~/datasets/ffhq-1024x1024.zip \\ --gpus=8 --batch=32 --gamma=10 --mirror=1 --aug=noaug """ # Initialize config. opts = dnnlib.EasyDict(kwargs) # Command line arguments. c = dnnlib.EasyDict() # Main config dict. print('---- square: ',opts.square) c.G_kwargs = dnnlib.EasyDict(class_name=None, z_dim=512, w_dim=512, mapping_kwargs=dnnlib.EasyDict(),square=opts.square) c.D_kwargs = dnnlib.EasyDict(class_name='training.networks_stylegan2.Discriminator', block_kwargs=dnnlib.EasyDict(), mapping_kwargs=dnnlib.EasyDict(), epilogue_kwargs=dnnlib.EasyDict(),square=opts.square) c.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0.99], eps=1e-8) c.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0.99], eps=1e-8) c.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.StyleGAN2Loss') c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, prefetch_factor=2) # Training set. c.training_set_kwargs, dataset_name = init_dataset_kwargs(data=opts.data, square=opts.square) if opts.cond and not c.training_set_kwargs.use_labels: raise click.ClickException('--cond=True requires labels specified in dataset.json') c.training_set_kwargs.use_labels = opts.cond c.training_set_kwargs.xflip = opts.mirror # Hyperparameters & settings.
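The `--cond=True` check above requires labels in the dataset: `ImageFolderDataset._load_raw_labels` (in dataset.py below) looks for a `dataset.json` pairing file names with integer class indices or float vectors. An illustrative payload, not repository data:

# Shape of the dataset.json consumed by _load_raw_labels in dataset.py:
dataset_json = {
    "labels": [
        ["00000/img00000000.png", 0],   # int labels are one-hot encoded at load time
        ["00000/img00000001.png", 1],
    ]
}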
c.num_gpus = opts.gpus c.batch_size = opts.batch c.batch_gpu = opts.batch_gpu or opts.batch // opts.gpus c.G_kwargs.channel_base = c.D_kwargs.channel_base = opts.cbase c.G_kwargs.channel_max = c.D_kwargs.channel_max = opts.cmax c.G_kwargs.mapping_kwargs.num_layers = (8 if opts.cfg == 'stylegan2' else 2) if opts.map_depth is None else opts.map_depth c.D_kwargs.block_kwargs.freeze_layers = opts.freezed c.D_kwargs.epilogue_kwargs.mbstd_group_size = opts.mbstd_group c.loss_kwargs.r1_gamma = opts.gamma c.G_opt_kwargs.lr = (0.002 if opts.cfg == 'stylegan2' else 0.0025) if opts.glr is None else opts.glr c.D_opt_kwargs.lr = opts.dlr c.metrics = opts.metrics c.total_kimg = opts.kimg c.kimg_per_tick = opts.tick c.image_snapshot_ticks = c.network_snapshot_ticks = opts.snap c.random_seed = c.training_set_kwargs.random_seed = opts.seed c.data_loader_kwargs.num_workers = opts.workers # Sanity checks. if c.batch_size % c.num_gpus != 0: raise click.ClickException('--batch must be a multiple of --gpus') if c.batch_size % (c.num_gpus * c.batch_gpu) != 0: raise click.ClickException('--batch must be a multiple of --gpus times --batch-gpu') if c.batch_gpu < c.D_kwargs.epilogue_kwargs.mbstd_group_size: raise click.ClickException('--batch-gpu cannot be smaller than --mbstd') if any(not metric_main.is_valid_metric(metric) for metric in c.metrics): raise click.ClickException('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics())) # Base configuration. c.ema_kimg = c.batch_size * 10 / 32 if opts.cfg == 'stylegan2': c.G_kwargs.class_name = 'training.networks_stylegan2.Generator' c.loss_kwargs.style_mixing_prob = 0.9 # Enable style mixing regularization. c.loss_kwargs.pl_weight = 2 # Enable path length regularization. c.G_reg_interval = 4 # Enable lazy regularization for G. c.G_kwargs.fused_modconv_default = 'inference_only' # Speed up training by using regular convolutions instead of grouped convolutions. c.loss_kwargs.pl_no_weight_grad = True # Speed up path length regularization by skipping gradient computation wrt. conv2d weights. else: c.G_kwargs.class_name = 'training.networks_stylegan3.Generator' c.G_kwargs.magnitude_ema_beta = 0.5 ** (c.batch_size / (20 * 1e3)) if opts.cfg == 'stylegan3-r': c.G_kwargs.conv_kernel = 1 # Use 1x1 convolutions. c.G_kwargs.channel_base *= 2 # Double the number of feature maps. c.G_kwargs.channel_max *= 2 c.G_kwargs.use_radial_filters = True # Use radially symmetric downsampling filters. c.loss_kwargs.blur_init_sigma = 10 # Blur the images seen by the discriminator. c.loss_kwargs.blur_fade_kimg = c.batch_size * 200 / 32 # Fade out the blur during the first N kimg. # Augmentation. if opts.aug != 'noaug': c.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1) if opts.aug == 'ada': c.ada_target = opts.target if opts.aug == 'fixed': c.augment_p = opts.p # Resume. if opts.resume is not None: c.resume_pkl = opts.resume c.ada_kimg = 100 # Make ADA react faster at the beginning. c.ema_rampup = None # Disable EMA rampup. c.loss_kwargs.blur_init_sigma = 0 # Disable blur rampup. # Performance-related toggles. if opts.fp32: c.G_kwargs.num_fp16_res = c.D_kwargs.num_fp16_res = 0 c.G_kwargs.conv_clamp = c.D_kwargs.conv_clamp = None if opts.nobench: c.cudnn_benchmark = False # Description string. 
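Several of the defaults above scale with batch size; plugging in the reference batch of 32 makes them concrete (worked numbers, not repository code):

batch_size = 32
ema_kimg = batch_size * 10 / 32                        # -> 10.0 kimg half-life for the generator EMA
magnitude_ema_beta = 0.5 ** (batch_size / (20 * 1e3))  # -> ~0.998892 per-batch decay of the input-magnitude EMA
blur_fade_kimg = batch_size * 200 / 32                 # -> 200 kimg discriminator-blur fade-out (stylegan3-r only)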
desc = f'{opts.cfg:s}-{dataset_name:s}-gpus{c.num_gpus:d}-batch{c.batch_size:d}-gamma{c.loss_kwargs.r1_gamma:g}' if opts.desc is not None: desc += f'-{opts.desc}' # Launch. launch_training(c=c, desc=desc, outdir=opts.outdir, dry_run=opts.dry_run) #---------------------------------------------------------------------------- if __name__ == "__main__": main() # pylint: disable=no-value-for-parameter #---------------------------------------------------------------------------- ================================================ FILE: stylegan_human/training_scripts/sg3/training/dataset.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Streaming images and labels from datasets created with dataset_tool.py.""" import os import numpy as np import zipfile import PIL.Image import json import torch import dnnlib try: import pyspng except ImportError: pyspng = None #---------------------------------------------------------------------------- class Dataset(torch.utils.data.Dataset): def __init__(self, name, # Name of the dataset. raw_shape, # Shape of the raw image data (NCHW). max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. use_labels = False, # Enable conditioning labels? False = label dimension is zero. xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size. random_seed = 0, # Random seed to use when applying max_size. square = False, ): print('Inside Dataset') self._name = name self._raw_shape = list(raw_shape) self._use_labels = use_labels self._raw_labels = None self._label_shape = None self._square = square print("inside dataset, _square: ", self._square) # Apply max_size. self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64) if (max_size is not None) and (self._raw_idx.size > max_size): np.random.RandomState(random_seed).shuffle(self._raw_idx) self._raw_idx = np.sort(self._raw_idx[:max_size]) # Apply xflip.
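The x-flip step that follows doubles the dataset without touching the files: the index array is tiled and a parallel flag array marks the second copy for mirroring. In miniature:

import numpy as np

raw_idx = np.arange(4)                                 # 4 real images
xflip = np.zeros(4, dtype=np.uint8)
raw_idx = np.tile(raw_idx, 2)                          # [0 1 2 3 0 1 2 3]
xflip = np.concatenate([xflip, np.ones_like(xflip)])   # [0 0 0 0 1 1 1 1]
# __getitem__(5) then loads raw image 1 and mirrors it: image[:, :, ::-1]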
self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) if xflip: self._raw_idx = np.tile(self._raw_idx, 2) self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) def _get_raw_labels(self): if self._raw_labels is None: self._raw_labels = self._load_raw_labels() if self._use_labels else None if self._raw_labels is None: self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) assert isinstance(self._raw_labels, np.ndarray) assert self._raw_labels.shape[0] == self._raw_shape[0] assert self._raw_labels.dtype in [np.float32, np.int64] if self._raw_labels.dtype == np.int64: assert self._raw_labels.ndim == 1 assert np.all(self._raw_labels >= 0) return self._raw_labels def close(self): # to be overridden by subclass pass def _load_raw_image(self, raw_idx): # to be overridden by subclass raise NotImplementedError def _load_raw_labels(self): # to be overridden by subclass raise NotImplementedError def __getstate__(self): return dict(self.__dict__, _raw_labels=None) def __del__(self): try: self.close() except: pass def __len__(self): return self._raw_idx.size def __getitem__(self, idx): image = self._load_raw_image(self._raw_idx[idx]) assert isinstance(image, np.ndarray) assert list(image.shape) == self.image_shape assert image.dtype == np.uint8 if self._xflip[idx]: assert image.ndim == 3 # CHW image = image[:, :, ::-1] return image.copy(), self.get_label(idx) def get_label(self, idx): label = self._get_raw_labels()[self._raw_idx[idx]] if label.dtype == np.int64: onehot = np.zeros(self.label_shape, dtype=np.float32) onehot[label] = 1 label = onehot return label.copy() def get_details(self, idx): d = dnnlib.EasyDict() d.raw_idx = int(self._raw_idx[idx]) d.xflip = (int(self._xflip[idx]) != 0) d.raw_label = self._get_raw_labels()[d.raw_idx].copy() return d @property def name(self): return self._name @property def image_shape(self): return list(self._raw_shape[1:]) @property def num_channels(self): assert len(self.image_shape) == 3 # CHW return self.image_shape[0] @property def resolution(self): assert len(self.image_shape) == 3 # CHW if self._square: assert self.image_shape[1] == self.image_shape[2] else: assert self.image_shape[1] == self.image_shape[2] * 2 return self.image_shape[1] @property def label_shape(self): if self._label_shape is None: raw_labels = self._get_raw_labels() if raw_labels.dtype == np.int64: self._label_shape = [int(np.max(raw_labels)) + 1] else: self._label_shape = raw_labels.shape[1:] return list(self._label_shape) @property def label_dim(self): assert len(self.label_shape) == 1 return self.label_shape[0] @property def has_labels(self): return any(x != 0 for x in self.label_shape) @property def has_onehot_labels(self): return self._get_raw_labels().dtype == np.int64 #---------------------------------------------------------------------------- class ImageFolderDataset(Dataset): def __init__(self, path, # Path to directory or zip. resolution = None, # Ensure specific resolution, None = highest available. ceph = False, square = False, **super_kwargs, # Additional arguments for the Dataset base class. 
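The `resolution` property above encodes the StyleGAN-Human rectangle convention: a non-square sample must be exactly twice as tall as it is wide (CHW height == 2 x width), and "resolution" means the height. For example:

# Rectangle convention enforced by Dataset.resolution above:
image_shape = [3, 1024, 512]          # CHW for a full-body human crop
assert image_shape[1] == image_shape[2] * 2
resolution = image_shape[1]           # -> 1024; synthesis then grows 4x2 -> ... -> 1024x512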
): self._path = path self._zipfile = None self._square = square if os.path.isdir(self._path): self._type = 'dir' self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files} elif self._file_ext(self._path) == '.zip': self._type = 'zip' self._all_fnames = set(self._get_zipfile().namelist()) else: raise IOError('Path must point to a directory or zip') PIL.Image.init() self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION) if len(self._image_fnames) == 0: raise IOError('No image files found in the specified path') name = os.path.splitext(os.path.basename(self._path))[0] raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape) # if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution): # raise IOError('Image files do not match the specified resolution') if resolution is not None: if self._square: raw_shape[2] = raw_shape[3] = resolution else: raw_shape[2] = resolution raw_shape[3] = resolution // 2 # print(raw_shape) super().__init__(name=name, raw_shape=raw_shape,square=square, **super_kwargs) @staticmethod def _file_ext(fname): return os.path.splitext(fname)[1].lower() def _get_zipfile(self): assert self._type == 'zip' if self._zipfile is None: self._zipfile = zipfile.ZipFile(self._path) return self._zipfile def _open_file(self, fname): if self._type == 'dir': return open(os.path.join(self._path, fname), 'rb') if self._type == 'zip': return self._get_zipfile().open(fname, 'r') return None def close(self): try: if self._zipfile is not None: self._zipfile.close() finally: self._zipfile = None def __getstate__(self): return dict(super().__getstate__(), _zipfile=None) def _load_raw_image(self, raw_idx): fname = self._image_fnames[raw_idx] with self._open_file(fname) as f: if pyspng is not None and self._file_ext(fname) == '.png': image = pyspng.load(f.read()) else: image = np.array(PIL.Image.open(f)) if image.ndim == 2: image = image[:, :, np.newaxis] # HW => HWC image = image.transpose(2, 0, 1) # HWC => CHW return image def _load_raw_labels(self): fname = 'dataset.json' if fname not in self._all_fnames: return None with self._open_file(fname) as f: labels = json.load(f)['labels'] if labels is None: return None labels = dict(labels) labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames] labels = np.array(labels) labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) return labels #---------------------------------------------------------------------------- ================================================ FILE: stylegan_human/training_scripts/sg3/training/networks_stylegan2.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Network architectures from the paper "Analyzing and Improving the Image Quality of StyleGAN". Matches the original implementation of configs E-F by Karras et al. 
at https://github.com/NVlabs/stylegan2/blob/master/training/networks_stylegan2.py""" import numpy as np import torch from torch_utils import misc from torch_utils import persistence from torch_utils.ops import conv2d_resample from torch_utils.ops import upfirdn2d from torch_utils.ops import bias_act from torch_utils.ops import fma #---------------------------------------------------------------------------- @misc.profiled_function def normalize_2nd_moment(x, dim=1, eps=1e-8): return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() #---------------------------------------------------------------------------- @misc.profiled_function def modulated_conv2d( x, # Input tensor of shape [batch_size, in_channels, in_height, in_width]. weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width]. styles, # Modulation coefficients of shape [batch_size, in_channels]. noise = None, # Optional noise tensor to add to the output activations. up = 1, # Integer upsampling factor. down = 1, # Integer downsampling factor. padding = 0, # Padding with respect to the upsampled image. resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter(). demodulate = True, # Apply weight demodulation? flip_weight = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d). fused_modconv = True, # Perform modulation, convolution, and demodulation as a single fused operation? ): batch_size = x.shape[0] out_channels, in_channels, kh, kw = weight.shape misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk] misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW] misc.assert_shape(styles, [batch_size, in_channels]) # [NI] # Pre-normalize inputs to avoid FP16 overflow. if x.dtype == torch.float16 and demodulate: weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I # Calculate per-sample weights and demodulation coefficients. w = None dcoefs = None if demodulate or fused_modconv: w = weight.unsqueeze(0) # [NOIkk] w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk] if demodulate: dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO] if demodulate and fused_modconv: w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk] # Execute by scaling the activations before and after the convolution. if not fused_modconv: x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1) x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight) if demodulate and noise is not None: x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype)) elif demodulate: x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1) elif noise is not None: x = x.add_(noise.to(x.dtype)) return x # Execute as one fused op using grouped convolution. 
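The fused path implemented next gives every sample its own modulated weights in a single conv2d call by folding the batch into the channel axis of a grouped convolution. The trick in isolation (plain PyTorch, independent of the helpers in this file):

import torch
import torch.nn.functional as F

N, I, O, k, H, W = 2, 3, 4, 3, 8, 8
x = torch.randn(N, I, H, W)
w = torch.randn(N, O, I, k, k)                 # per-sample weights, e.g. after style modulation
x = x.reshape(1, N * I, H, W)                  # fold batch into channels
w = w.reshape(N * O, I, k, k)
y = F.conv2d(x, w, padding=k // 2, groups=N)   # group n sees only sample n's channels and weights
y = y.reshape(N, O, H, W)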
with misc.suppress_tracer_warnings(): # this value will be treated as a constant batch_size = int(batch_size) misc.assert_shape(x, [batch_size, in_channels, None, None]) x = x.reshape(1, -1, *x.shape[2:]) w = w.reshape(-1, in_channels, kh, kw) x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight) x = x.reshape(batch_size, -1, *x.shape[2:]) if noise is not None: x = x.add_(noise) return x #---------------------------------------------------------------------------- @persistence.persistent_class class FullyConnectedLayer(torch.nn.Module): def __init__(self, in_features, # Number of input features. out_features, # Number of output features. bias = True, # Apply additive bias before the activation function? activation = 'linear', # Activation function: 'relu', 'lrelu', etc. lr_multiplier = 1, # Learning rate multiplier. bias_init = 0, # Initial value for the additive bias. ): super().__init__() self.in_features = in_features self.out_features = out_features self.activation = activation self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None self.weight_gain = lr_multiplier / np.sqrt(in_features) self.bias_gain = lr_multiplier def forward(self, x): w = self.weight.to(x.dtype) * self.weight_gain b = self.bias if b is not None: b = b.to(x.dtype) if self.bias_gain != 1: b = b * self.bias_gain if self.activation == 'linear' and b is not None: x = torch.addmm(b.unsqueeze(0), x, w.t()) else: x = x.matmul(w.t()) x = bias_act.bias_act(x, b, act=self.activation) return x def extra_repr(self): return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}' #---------------------------------------------------------------------------- @persistence.persistent_class class Conv2dLayer(torch.nn.Module): def __init__(self, in_channels, # Number of input channels. out_channels, # Number of output channels. kernel_size, # Width and height of the convolution kernel. bias = True, # Apply additive bias before the activation function? activation = 'linear', # Activation function: 'relu', 'lrelu', etc. up = 1, # Integer upsampling factor. down = 1, # Integer downsampling factor. resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. conv_clamp = None, # Clamp the output to +-X, None = disable clamping. channels_last = False, # Expect the input to have memory_format=channels_last? trainable = True, # Update the weights of this layer during training? 
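Earlier in `modulated_conv2d`, the demodulation coefficient rescales every modulated output filter back to approximately unit L2 norm, which is what keeps activation magnitudes stable without explicit normalization layers. Condensed sketch:

import torch

N, O, I, k = 2, 4, 3, 3
weight = torch.randn(1, O, I, k, k)
styles = torch.randn(N, 1, I, 1, 1)
w = weight * styles                                      # modulate input channels per sample
dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()  # [N, O], one scale per sample and filter
w = w * dcoefs.reshape(N, O, 1, 1, 1)                    # each output filter now has ~unit norm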
): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.activation = activation self.up = up self.down = down self.conv_clamp = conv_clamp self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) self.padding = kernel_size // 2 self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) self.act_gain = bias_act.activation_funcs[activation].def_gain memory_format = torch.channels_last if channels_last else torch.contiguous_format weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format) bias = torch.zeros([out_channels]) if bias else None if trainable: self.weight = torch.nn.Parameter(weight) self.bias = torch.nn.Parameter(bias) if bias is not None else None else: self.register_buffer('weight', weight) if bias is not None: self.register_buffer('bias', bias) else: self.bias = None def forward(self, x, gain=1): w = self.weight * self.weight_gain b = self.bias.to(x.dtype) if self.bias is not None else None flip_weight = (self.up == 1) # slightly faster x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight) act_gain = self.act_gain * gain act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp) return x def extra_repr(self): return ' '.join([ f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, activation={self.activation:s},', f'up={self.up}, down={self.down}']) #---------------------------------------------------------------------------- @persistence.persistent_class class MappingNetwork(torch.nn.Module): def __init__(self, z_dim, # Input latent (Z) dimensionality, 0 = no latent. c_dim, # Conditioning label (C) dimensionality, 0 = no label. w_dim, # Intermediate latent (W) dimensionality. num_ws, # Number of intermediate latents to output, None = do not broadcast. num_layers = 8, # Number of mapping layers. embed_features = None, # Label embedding dimensionality, None = same as w_dim. layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim. activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers. w_avg_beta = 0.998, # Decay for tracking the moving average of W during training, None = do not track. ): super().__init__() self.z_dim = z_dim self.c_dim = c_dim self.w_dim = w_dim self.num_ws = num_ws self.num_layers = num_layers self.w_avg_beta = w_avg_beta if embed_features is None: embed_features = w_dim if c_dim == 0: embed_features = 0 if layer_features is None: layer_features = w_dim features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim] if c_dim > 0: self.embed = FullyConnectedLayer(c_dim, embed_features) for idx in range(num_layers): in_features = features_list[idx] out_features = features_list[idx + 1] layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier) setattr(self, f'fc{idx}', layer) if num_ws is not None and w_avg_beta is not None: self.register_buffer('w_avg', torch.zeros([w_dim])) def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False): # Embed, normalize, and concat inputs. 
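The truncation applied near the end of this forward pass is a plain linear interpolation toward the tracked average latent: psi=1 is a no-op and psi=0 collapses every sample onto w_avg. In short (sketch):

import torch

w_avg = torch.zeros(512)          # tracked moving average of W
ws = torch.randn(1, 18, 512)      # broadcast latents, one per layer
psi, cutoff = 0.7, 8
ws[:, :cutoff] = w_avg.lerp(ws[:, :cutoff], psi)   # truncate only the first `cutoff` layers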
x = None with torch.autograd.profiler.record_function('input'): if self.z_dim > 0: misc.assert_shape(z, [None, self.z_dim]) x = normalize_2nd_moment(z.to(torch.float32)) if self.c_dim > 0: misc.assert_shape(c, [None, self.c_dim]) y = normalize_2nd_moment(self.embed(c.to(torch.float32))) x = torch.cat([x, y], dim=1) if x is not None else y # Main layers. for idx in range(self.num_layers): layer = getattr(self, f'fc{idx}') x = layer(x) # Update moving average of W. if update_emas and self.w_avg_beta is not None: with torch.autograd.profiler.record_function('update_w_avg'): self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) # Broadcast. if self.num_ws is not None: with torch.autograd.profiler.record_function('broadcast'): x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) # Apply truncation. if truncation_psi != 1: with torch.autograd.profiler.record_function('truncate'): assert self.w_avg_beta is not None if self.num_ws is None or truncation_cutoff is None: x = self.w_avg.lerp(x, truncation_psi) else: x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) return x def extra_repr(self): return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}' #---------------------------------------------------------------------------- @persistence.persistent_class class SynthesisLayer(torch.nn.Module): def __init__(self, in_channels, # Number of input channels. out_channels, # Number of output channels. w_dim, # Intermediate latent (W) dimensionality. resolution, # Resolution of this layer. kernel_size = 3, # Convolution kernel size. up = 1, # Integer upsampling factor. use_noise = True, # Enable noise input? activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. channels_last = False, # Use channels_last format for the weights? 
square = False, # default for rectangle images ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.w_dim = w_dim self.resolution = resolution self.up = up self.use_noise = use_noise self.activation = activation self.conv_clamp = conv_clamp self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) self.padding = kernel_size // 2 self.act_gain = bias_act.activation_funcs[activation].def_gain self.square = square self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) memory_format = torch.channels_last if channels_last else torch.contiguous_format self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)) if use_noise: if self.square: self.register_buffer('noise_const', torch.randn([resolution, resolution])) else: self.register_buffer('noise_const', torch.randn([resolution, resolution // 2])) self.noise_strength = torch.nn.Parameter(torch.zeros([])) self.bias = torch.nn.Parameter(torch.zeros([out_channels])) def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1): assert noise_mode in ['random', 'const', 'none'] in_resolution = self.resolution // self.up if self.square: misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution]) else: misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution // 2]) styles = self.affine(w) noise = None if self.use_noise and noise_mode == 'random': if self.square: noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength else: noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution // 2], device=x.device) * self.noise_strength if self.use_noise and noise_mode == 'const': noise = self.noise_const * self.noise_strength flip_weight = (self.up == 1) # slightly faster x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up, padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv) act_gain = self.act_gain * gain act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp) return x def extra_repr(self): return ' '.join([ f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d},', f'resolution={self.resolution:d}, up={self.up}, activation={self.activation:s}']) #---------------------------------------------------------------------------- @persistence.persistent_class class ToRGBLayer(torch.nn.Module): def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.w_dim = w_dim self.conv_clamp = conv_clamp self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) memory_format = torch.channels_last if channels_last else torch.contiguous_format self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)) self.bias = torch.nn.Parameter(torch.zeros([out_channels])) self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) def forward(self, x, w, fused_modconv=True): styles = self.affine(w) * self.weight_gain x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv) x = bias_act.bias_act(x,
self.bias.to(x.dtype), clamp=self.conv_clamp) return x def extra_repr(self): return f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d}' #---------------------------------------------------------------------------- @persistence.persistent_class class SynthesisBlock(torch.nn.Module): def __init__(self, in_channels, # Number of input channels, 0 = first block. out_channels, # Number of output channels. w_dim, # Intermediate latent (W) dimensionality. resolution, # Resolution of this block. img_channels, # Number of output color channels. is_last, # Is this the last block? architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'. resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping. use_fp16 = False, # Use FP16 for this block? fp16_channels_last = False, # Use channels-last memory format with FP16? square = False, # default is for rectangle images fused_modconv_default = True, # Default value of fused_modconv. 'inference_only' = True for inference, False for training. **layer_kwargs, # Arguments for SynthesisLayer. ): assert architecture in ['orig', 'skip', 'resnet'] super().__init__() self.in_channels = in_channels self.w_dim = w_dim self.resolution = resolution self.img_channels = img_channels self.is_last = is_last self.architecture = architecture self.use_fp16 = use_fp16 self.channels_last = (use_fp16 and fp16_channels_last) self.fused_modconv_default = fused_modconv_default self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) self.num_conv = 0 self.num_torgb = 0 self.square = square if in_channels == 0: if self.square: self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution])) else: # rectangle self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution // 2])) if in_channels != 0: self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2, resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, square=square, **layer_kwargs) self.num_conv += 1 self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution, conv_clamp=conv_clamp, channels_last=self.channels_last, square=square, **layer_kwargs) self.num_conv += 1 if is_last or architecture == 'skip': self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim, conv_clamp=conv_clamp, channels_last=self.channels_last) self.num_torgb += 1 if in_channels != 0 and architecture == 'resnet': self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2, resample_filter=resample_filter, channels_last=self.channels_last) def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs): _ = update_emas # unused misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim]) w_iter = iter(ws.unbind(dim=1)) if ws.device.type != 'cuda': force_fp32 = True dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format if fused_modconv is None: fused_modconv = self.fused_modconv_default if fused_modconv == 'inference_only': fused_modconv = (not self.training) # Input. 
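Note the rectangle branch in `SynthesisBlock.__init__` above: the learned constant that seeds synthesis is `resolution x resolution // 2`, so a human model starts from a 4x2 tensor instead of StyleGAN2's usual 4x4, and every later feature map keeps the 2:1 aspect ratio. Schematically:

import torch

out_channels, resolution = 512, 4
const_square = torch.randn([out_channels, resolution, resolution])      # 512 x 4 x 4
const_rect = torch.randn([out_channels, resolution, resolution // 2])   # 512 x 4 x 2 -> ... -> 1024 x 512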
if self.in_channels == 0: x = self.const.to(dtype=dtype, memory_format=memory_format) x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1]) else: if self.square: misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2]) else: # rectangle misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 4]) x = x.to(dtype=dtype, memory_format=memory_format) # Main layers. if self.in_channels == 0: x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) elif self.architecture == 'resnet': y = self.skip(x, gain=np.sqrt(0.5)) x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs) x = y.add_(x) else: x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) # ToRGB. if img is not None: if self.square: misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2]) else: misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 4]) img = upfirdn2d.upsample2d(img, self.resample_filter) if self.is_last or self.architecture == 'skip': y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv) y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format) img = img.add_(y) if img is not None else y assert x.dtype == dtype assert img is None or img.dtype == torch.float32 return x, img def extra_repr(self): return f'resolution={self.resolution:d}, architecture={self.architecture:s}' #---------------------------------------------------------------------------- @persistence.persistent_class class SynthesisNetwork(torch.nn.Module): def __init__(self, w_dim, # Intermediate latent (W) dimensionality. img_resolution, # Output image resolution. img_channels, # Number of color channels. square, channel_base = 32768, # Overall multiplier for the number of channels. channel_max = 512, # Maximum number of channels in any layer. num_fp16_res = 4, # Use FP16 for the N highest resolutions. **block_kwargs, # Arguments for SynthesisBlock. 
): assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0 super().__init__() self.w_dim = w_dim self.img_resolution = img_resolution self.img_resolution_log2 = int(np.log2(img_resolution)) self.img_channels = img_channels self.square=square self.num_fp16_res = num_fp16_res self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)] channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions} fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) self.num_ws = 0 for res in self.block_resolutions: in_channels = channels_dict[res // 2] if res > 4 else 0 out_channels = channels_dict[res] use_fp16 = (res >= fp16_resolution) is_last = (res == self.img_resolution) block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res, img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, square=square,**block_kwargs) self.num_ws += block.num_conv if is_last: self.num_ws += block.num_torgb setattr(self, f'b{res}', block) def forward(self, ws, **block_kwargs): block_ws = [] with torch.autograd.profiler.record_function('split_ws'): misc.assert_shape(ws, [None, self.num_ws, self.w_dim]) ws = ws.to(torch.float32) w_idx = 0 for res in self.block_resolutions: block = getattr(self, f'b{res}') block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb)) w_idx += block.num_conv x = img = None for res, cur_ws in zip(self.block_resolutions, block_ws): block = getattr(self, f'b{res}') x, img = block(x, img, cur_ws, **block_kwargs) return img def extra_repr(self): return ' '.join([ f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},', f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},', f'num_fp16_res={self.num_fp16_res:d}']) #---------------------------------------------------------------------------- @persistence.persistent_class class Generator(torch.nn.Module): def __init__(self, z_dim, # Input latent (Z) dimensionality. c_dim, # Conditioning label (C) dimensionality. w_dim, # Intermediate latent (W) dimensionality. square, img_resolution, # Output resolution. img_channels, # Number of output color channels. mapping_kwargs = {}, # Arguments for MappingNetwork. **synthesis_kwargs, # Arguments for SynthesisNetwork. ): super().__init__() self.z_dim = z_dim self.c_dim = c_dim self.w_dim = w_dim self.square = square self.img_resolution = img_resolution self.img_channels = img_channels self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, square=square, **synthesis_kwargs) self.num_ws = self.synthesis.num_ws self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs) def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs): ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas) img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs) return img #---------------------------------------------------------------------------- @persistence.persistent_class class DiscriminatorBlock(torch.nn.Module): def __init__(self, in_channels, # Number of input channels, 0 = first block. tmp_channels, # Number of intermediate channels. out_channels, # Number of output channels. resolution, # Resolution of this block. img_channels, # Number of input color channels. first_layer_idx, # Index of the first layer. 
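The `num_ws` bookkeeping in `SynthesisNetwork.__init__` above deserves a worked example. At 1024px there are nine blocks (4, 8, ..., 1024); the first has a single conv, the rest two, and only the final block's ToRGB consumes an extra latent, which recovers StyleGAN2's familiar 18 style inputs:

# num_ws for img_resolution = 1024 (cf. SynthesisNetwork.__init__ above):
block_resolutions = [2 ** i for i in range(2, 11)]   # [4, 8, ..., 1024]
num_ws = 0
for res in block_resolutions:
    num_ws += 1 if res == 4 else 2                   # num_conv per block
num_ws += 1                                          # ToRGB of the last block
assert num_ws == 18
# Intermediate ToRGBs reuse the next block's first w: forward() advances w_idx by num_conv only.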
architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. use_fp16 = False, # Use FP16 for this block? fp16_channels_last = False, # Use channels-last memory format with FP16? freeze_layers = 0, # Freeze-D: Number of layers to freeze. square = False, ): assert in_channels in [0, tmp_channels] assert architecture in ['orig', 'skip', 'resnet'] super().__init__() self.in_channels = in_channels self.resolution = resolution self.img_channels = img_channels self.first_layer_idx = first_layer_idx self.architecture = architecture self.use_fp16 = use_fp16 self.channels_last = (use_fp16 and fp16_channels_last) self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) self.square = square self.num_layers = 0 def trainable_gen(): while True: layer_idx = self.first_layer_idx + self.num_layers trainable = (layer_idx >= freeze_layers) self.num_layers += 1 yield trainable trainable_iter = trainable_gen() if in_channels == 0 or architecture == 'skip': self.fromrgb = Conv2dLayer(img_channels, tmp_channels, kernel_size=1, activation=activation, trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last) self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation, trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last) self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2, trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last) if architecture == 'resnet': self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2, trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last) def forward(self, x, img, force_fp32=False): if (x if x is not None else img).device.type != 'cuda': force_fp32 = True dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format # Input. if x is not None: if self.square: misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) else: misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution // 2]) x = x.to(dtype=dtype, memory_format=memory_format) # FromRGB. if self.in_channels == 0 or self.architecture == 'skip': if self.square: misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution]) else: misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution // 2]) img = img.to(dtype=dtype, memory_format=memory_format) y = self.fromrgb(img) x = x + y if x is not None else y img = upfirdn2d.downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None # Main layers. 
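The `trainable_gen` closure above implements Freeze-D: layers are numbered globally across blocks via `first_layer_idx`, and any layer whose global index falls below `freeze_layers` is registered as a fixed buffer instead of a parameter. The gating logic in isolation (`trainable_flags` is a hypothetical helper, not repository code):

def trainable_flags(first_layer_idx, layers_in_block, freeze_layers):
    # True -> torch.nn.Parameter (trained), False -> register_buffer (frozen)
    return [first_layer_idx + i >= freeze_layers for i in range(layers_in_block)]

print(trainable_flags(0, 3, freeze_layers=2))   # [False, False, True] -> first two layers frozen
print(trainable_flags(3, 3, freeze_layers=2))   # [True, True, True]  -> later blocks fully trainable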
if self.architecture == 'resnet': y = self.skip(x, gain=np.sqrt(0.5)) x = self.conv0(x) x = self.conv1(x, gain=np.sqrt(0.5)) x = y.add_(x) else: x = self.conv0(x) x = self.conv1(x) assert x.dtype == dtype return x, img def extra_repr(self): return f'resolution={self.resolution:d}, architecture={self.architecture:s}' #---------------------------------------------------------------------------- @persistence.persistent_class class MinibatchStdLayer(torch.nn.Module): def __init__(self, group_size, num_channels=1): super().__init__() self.group_size = group_size self.num_channels = num_channels def forward(self, x): N, C, H, W = x.shape with misc.suppress_tracer_warnings(): # as_tensor results are registered as constants G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N F = self.num_channels c = C // F y = x.reshape(G, -1, F, c, H, W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c. y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group. y = y.square().mean(dim=0) # [nFcHW] Calc variance over group. y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group. y = y.mean(dim=[2,3,4]) # [nF] Take average over channels and pixels. y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions. y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels. x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels. return x def extra_repr(self): return f'group_size={self.group_size}, num_channels={self.num_channels:d}' #---------------------------------------------------------------------------- @persistence.persistent_class class DiscriminatorEpilogue(torch.nn.Module): def __init__(self, in_channels, # Number of input channels. cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label. resolution, # Resolution of this block. img_channels, # Number of input color channels. architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, None = entire minibatch. mbstd_num_channels = 1, # Number of features for the minibatch standard deviation layer, 0 = disable. activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. 
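`MinibatchStdLayer` above appends a single channel holding the per-group standard deviation of the features, letting the discriminator detect a collapse in sample diversity. The computation restated on concrete shapes (sketch):

import torch

x = torch.randn(8, 16, 4, 4)                     # N, C, H, W with group size G = 4
y = x.reshape(4, -1, 1, 16, 4, 4)                # [G, n, F, c, H, W]
y = (y - y.mean(dim=0)).square().mean(dim=0)     # variance within each group
y = (y + 1e-8).sqrt().mean(dim=[2, 3, 4])        # stddev, averaged over c, H, W -> [n, F]
y = y.reshape(-1, 1, 1, 1).repeat(4, 1, 4, 4)    # broadcast back over the group and pixels
x = torch.cat([x, y], dim=1)                     # -> [8, 17, 4, 4]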
square = False, ): assert architecture in ['orig', 'skip', 'resnet'] super().__init__() self.in_channels = in_channels self.cmap_dim = cmap_dim self.resolution = resolution self.img_channels = img_channels self.architecture = architecture self.square = square if architecture == 'skip': self.fromrgb = Conv2dLayer(img_channels, in_channels, kernel_size=1, activation=activation) self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp) if self.square: self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), in_channels, activation=activation) else: self.fc = FullyConnectedLayer(in_channels * (resolution ** 2 // 2), in_channels, activation=activation) self.out = FullyConnectedLayer(in_channels, 1 if cmap_dim == 0 else cmap_dim) def forward(self, x, img, cmap, force_fp32=False): if self.square: misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) else: misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution // 2]) # [NCHW] _ = force_fp32 # unused dtype = torch.float32 memory_format = torch.contiguous_format # FromRGB. x = x.to(dtype=dtype, memory_format=memory_format) if self.architecture == 'skip': if self.square: misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution]) else: misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution // 2]) img = img.to(dtype=dtype, memory_format=memory_format) x = x + self.fromrgb(img) # Main layers. if self.mbstd is not None: x = self.mbstd(x) x = self.conv(x) x = self.fc(x.flatten(1)) x = self.out(x) # Conditioning. if self.cmap_dim > 0: misc.assert_shape(cmap, [None, self.cmap_dim]) x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) assert x.dtype == dtype return x def extra_repr(self): return f'resolution={self.resolution:d}, architecture={self.architecture:s}' #---------------------------------------------------------------------------- @persistence.persistent_class class Discriminator(torch.nn.Module): def __init__(self, c_dim, # Conditioning label (C) dimensionality. img_resolution, # Input resolution. img_channels, # Number of input color channels. architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. channel_base = 32768, # Overall multiplier for the number of channels. channel_max = 512, # Maximum number of channels in any layer. num_fp16_res = 4, # Use FP16 for the N highest resolutions. conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping. cmap_dim = None, # Dimensionality of mapped conditioning label, None = default. square = False, # default for rectangle images block_kwargs = {}, # Arguments for DiscriminatorBlock. mapping_kwargs = {}, # Arguments for MappingNetwork. epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue. 
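The `min(channel_base // res, channel_max)` schedule used by the constructor that follows (and by the generator above) halves the channel count with each doubling of resolution once the cap stops binding; the defaults reproduce StyleGAN2's feature-map counts:

# Channel schedule with the default channel_base=32768, channel_max=512:
channel_base, channel_max = 32768, 512
channels = {res: min(channel_base // res, channel_max)
            for res in [4, 8, 16, 32, 64, 128, 256, 512, 1024]}
# {4: 512, 8: 512, 16: 512, 32: 512, 64: 512, 128: 256, 256: 128, 512: 64, 1024: 32}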
): super().__init__() self.c_dim = c_dim self.img_resolution = img_resolution self.img_resolution_log2 = int(np.log2(img_resolution)) self.img_channels = img_channels self.square = square self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) if cmap_dim is None: cmap_dim = channels_dict[4] if c_dim == 0: cmap_dim = 0 common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp) cur_layer_idx = 0 for res in self.block_resolutions: in_channels = channels_dict[res] if res < img_resolution else 0 tmp_channels = channels_dict[res] out_channels = channels_dict[res // 2] use_fp16 = (res >= fp16_resolution) block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res, first_layer_idx=cur_layer_idx, use_fp16=use_fp16, square=square, **block_kwargs, **common_kwargs) setattr(self, f'b{res}', block) cur_layer_idx += block.num_layers if c_dim > 0: self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs) self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, square=square, **epilogue_kwargs, **common_kwargs) def forward(self, img, c, update_emas=False, **block_kwargs): _ = update_emas # unused x = None for res in self.block_resolutions: block = getattr(self, f'b{res}') x, img = block(x, img, **block_kwargs) cmap = None if self.c_dim > 0: cmap = self.mapping(None, c) x = self.b4(x, img, cmap) return x def extra_repr(self): return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}' #---------------------------------------------------------------------------- ================================================ FILE: stylegan_human/training_scripts/sg3/training/networks_stylegan3.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Generator architecture from the paper "Alias-Free Generative Adversarial Networks".""" import numpy as np import scipy.signal import scipy.optimize import torch from torch_utils import misc from torch_utils import persistence from torch_utils.ops import conv2d_gradfix from torch_utils.ops import filtered_lrelu from torch_utils.ops import bias_act #---------------------------------------------------------------------------- @misc.profiled_function def modulated_conv2d( x, # Input tensor: [batch_size, in_channels, in_height, in_width] w, # Weight tensor: [out_channels, in_channels, kernel_height, kernel_width] s, # Style tensor: [batch_size, in_channels] demodulate = True, # Apply weight demodulation? 
padding = 0, # Padding: int or [padH, padW] input_gain = None, # Optional scale factors for the input channels: [], [in_channels], or [batch_size, in_channels] ): with misc.suppress_tracer_warnings(): # this value will be treated as a constant batch_size = int(x.shape[0]) out_channels, in_channels, kh, kw = w.shape misc.assert_shape(w, [out_channels, in_channels, kh, kw]) # [OIkk] misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW] misc.assert_shape(s, [batch_size, in_channels]) # [NI] # Pre-normalize inputs. if demodulate: w = w * w.square().mean([1,2,3], keepdim=True).rsqrt() s = s * s.square().mean().rsqrt() # Modulate weights. w = w.unsqueeze(0) # [NOIkk] w = w * s.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk] # Demodulate weights. if demodulate: dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO] w = w * dcoefs.unsqueeze(2).unsqueeze(3).unsqueeze(4) # [NOIkk] # Apply input scaling. if input_gain is not None: input_gain = input_gain.expand(batch_size, in_channels) # [NI] w = w * input_gain.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk] # Execute as one fused op using grouped convolution. x = x.reshape(1, -1, *x.shape[2:]) w = w.reshape(-1, in_channels, kh, kw) x = conv2d_gradfix.conv2d(input=x, weight=w.to(x.dtype), padding=padding, groups=batch_size) x = x.reshape(batch_size, -1, *x.shape[2:]) return x #---------------------------------------------------------------------------- @persistence.persistent_class class FullyConnectedLayer(torch.nn.Module): def __init__(self, in_features, # Number of input features. out_features, # Number of output features. activation = 'linear', # Activation function: 'relu', 'lrelu', etc. bias = True, # Apply additive bias before the activation function? lr_multiplier = 1, # Learning rate multiplier. weight_init = 1, # Initial standard deviation of the weight tensor. bias_init = 0, # Initial value of the additive bias. ): super().__init__() self.in_features = in_features self.out_features = out_features self.activation = activation self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) * (weight_init / lr_multiplier)) bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features]) self.bias = torch.nn.Parameter(torch.from_numpy(bias_init / lr_multiplier)) if bias else None self.weight_gain = lr_multiplier / np.sqrt(in_features) self.bias_gain = lr_multiplier def forward(self, x): w = self.weight.to(x.dtype) * self.weight_gain b = self.bias if b is not None: b = b.to(x.dtype) if self.bias_gain != 1: b = b * self.bias_gain if self.activation == 'linear' and b is not None: x = torch.addmm(b.unsqueeze(0), x, w.t()) else: x = x.matmul(w.t()) x = bias_act.bias_act(x, b, act=self.activation) return x def extra_repr(self): return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}' #---------------------------------------------------------------------------- @persistence.persistent_class class MappingNetwork(torch.nn.Module): def __init__(self, z_dim, # Input latent (Z) dimensionality. c_dim, # Conditioning label (C) dimensionality, 0 = no labels. w_dim, # Intermediate latent (W) dimensionality. num_ws, # Number of intermediate latents to output. num_layers = 2, # Number of mapping layers. lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers. w_avg_beta = 0.998, # Decay for tracking the moving average of W during training. 
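`FullyConnectedLayer` above uses the equalized-learning-rate trick: weights are stored pre-scaled by `1 / lr_multiplier` and multiplied by `lr_multiplier / sqrt(in_features)` in the forward pass, so the effective initialization is identical for every layer while the optimizer sees a smaller learning rate for the mapping network. Condensed:

import numpy as np
import torch

in_features, out_features, lr_multiplier = 512, 512, 0.01
weight = torch.randn(out_features, in_features) / lr_multiplier   # stored large...
weight_gain = lr_multiplier / np.sqrt(in_features)                # ...rescaled at runtime
x = torch.randn(1, in_features)
y = x @ (weight * weight_gain).t()   # effective weight std is 1/sqrt(in_features) regardless of lr_multiplier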
): super().__init__() self.z_dim = z_dim self.c_dim = c_dim self.w_dim = w_dim self.num_ws = num_ws self.num_layers = num_layers self.w_avg_beta = w_avg_beta # Construct layers. self.embed = FullyConnectedLayer(self.c_dim, self.w_dim) if self.c_dim > 0 else None features = [self.z_dim + (self.w_dim if self.c_dim > 0 else 0)] + [self.w_dim] * self.num_layers for idx, in_features, out_features in zip(range(num_layers), features[:-1], features[1:]): layer = FullyConnectedLayer(in_features, out_features, activation='lrelu', lr_multiplier=lr_multiplier) setattr(self, f'fc{idx}', layer) self.register_buffer('w_avg', torch.zeros([w_dim])) def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False): misc.assert_shape(z, [None, self.z_dim]) if truncation_cutoff is None: truncation_cutoff = self.num_ws # Embed, normalize, and concatenate inputs. x = z.to(torch.float32) x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt() if self.c_dim > 0: misc.assert_shape(c, [None, self.c_dim]) y = self.embed(c.to(torch.float32)) y = y * (y.square().mean(1, keepdim=True) + 1e-8).rsqrt() x = torch.cat([x, y], dim=1) if x is not None else y # Execute layers. for idx in range(self.num_layers): x = getattr(self, f'fc{idx}')(x) # Update moving average of W. if update_emas: self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) # Broadcast and apply truncation. x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) if truncation_psi != 1: x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) return x def extra_repr(self): return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}' #---------------------------------------------------------------------------- @persistence.persistent_class class SynthesisInput(torch.nn.Module): def __init__(self, w_dim, # Intermediate latent (W) dimensionality. channels, # Number of output channels. size, # Output spatial size: int or [width, height]. sampling_rate, # Output sampling rate. bandwidth, # Output bandwidth. square, ): super().__init__() self.w_dim = w_dim self.channels = channels self.square = square if self.square: self.size = np.broadcast_to(np.asarray(size), [2]) else: self.size = np.array([size // 2, size]) # [width, height] self.sampling_rate = sampling_rate self.bandwidth = bandwidth # Draw random frequencies from uniform 2D disc. freqs = torch.randn([self.channels, 2]) radii = freqs.square().sum(dim=1, keepdim=True).sqrt() freqs /= radii * radii.square().exp().pow(0.25) freqs *= bandwidth phases = torch.rand([self.channels]) - 0.5 # Setup parameters and buffers. self.weight = torch.nn.Parameter(torch.randn([self.channels, self.channels])) self.affine = FullyConnectedLayer(w_dim, 4, weight_init=0, bias_init=[1,0,0,0]) self.register_buffer('transform', torch.eye(3, 3)) # User-specified inverse transform wrt. resulting image. self.register_buffer('freqs', freqs) self.register_buffer('phases', phases) def forward(self, w): # Introduce batch dimension. transforms = self.transform.unsqueeze(0) # [batch, row, col] freqs = self.freqs.unsqueeze(0) # [batch, channel, xy] phases = self.phases.unsqueeze(0) # [batch, channel] # Apply learned transformation. t = self.affine(w) # t = (r_c, r_s, t_x, t_y) t = t / t[:, :2].norm(dim=1, keepdim=True) # t' = (r'_c, r'_s, t'_x, t'_y) m_r = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse rotation wrt. resulting image. 
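# The normalized affine output encodes a unit rotation (cos, sin) in t[:, 0:2] and a translation in t[:, 2:4]; the assignments below fill m_r with the inverse 2x2 rotation [[c, -s], [s, c]] embedded in a 3x3 homogeneous matrix.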
m_r[:, 0, 0] = t[:, 0] # r'_c m_r[:, 0, 1] = -t[:, 1] # r'_s m_r[:, 1, 0] = t[:, 1] # r'_s m_r[:, 1, 1] = t[:, 0] # r'_c m_t = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse translation wrt. resulting image. m_t[:, 0, 2] = -t[:, 2] # t'_x m_t[:, 1, 2] = -t[:, 3] # t'_y transforms = m_r @ m_t @ transforms # First rotate resulting image, then translate, and finally apply user-specified transform. # Transform frequencies. phases = phases + (freqs @ transforms[:, :2, 2:]).squeeze(2) freqs = freqs @ transforms[:, :2, :2] # Dampen out-of-band frequencies that may occur due to the user-specified transform. amplitudes = (1 - (freqs.norm(dim=2) - self.bandwidth) / (self.sampling_rate / 2 - self.bandwidth)).clamp(0, 1) # Construct sampling grid. theta = torch.eye(2, 3, device=w.device) theta[0, 0] = 0.5 * self.size[0] / self.sampling_rate theta[1, 1] = 0.5 * self.size[1] / self.sampling_rate grids = torch.nn.functional.affine_grid(theta.unsqueeze(0), [1, 1, self.size[1], self.size[0]], align_corners=False) # Compute Fourier features. x = (grids.unsqueeze(3) @ freqs.permute(0, 2, 1).unsqueeze(1).unsqueeze(2)).squeeze(3) # [batch, height, width, channel] x = x + phases.unsqueeze(1).unsqueeze(2) x = torch.sin(x * (np.pi * 2)) x = x * amplitudes.unsqueeze(1).unsqueeze(2) # Apply trainable mapping. weight = self.weight / np.sqrt(self.channels) x = x @ weight.t() # Ensure correct shape. x = x.permute(0, 3, 1, 2) # [batch, channel, height, width] misc.assert_shape(x, [w.shape[0], self.channels, int(self.size[1]), int(self.size[0])]) return x def extra_repr(self): return '\n'.join([ f'w_dim={self.w_dim:d}, channels={self.channels:d}, size={list(self.size)},', f'sampling_rate={self.sampling_rate:g}, bandwidth={self.bandwidth:g}']) #---------------------------------------------------------------------------- @persistence.persistent_class class SynthesisLayer(torch.nn.Module): def __init__(self, w_dim, # Intermediate latent (W) dimensionality. is_torgb, # Is this the final ToRGB layer? is_critically_sampled, # Does this layer use critical sampling? use_fp16, # Does this layer use FP16? # Input & output specifications. in_channels, # Number of input channels. out_channels, # Number of output channels. in_size, # Input spatial size: int or [width, height]. out_size, # Output spatial size: int or [width, height]. in_sampling_rate, # Input sampling rate (s). out_sampling_rate, # Output sampling rate (s). in_cutoff, # Input cutoff frequency (f_c). out_cutoff, # Output cutoff frequency (f_c). in_half_width, # Input transition band half-width (f_h). out_half_width, # Output transition band half-width (f_h). # Hyperparameters. conv_kernel = 3, # Convolution kernel size. Ignored for the final ToRGB layer. filter_size = 6, # Low-pass filter size relative to the lower resolution when up/downsampling. lrelu_upsampling = 2, # Relative sampling rate for leaky ReLU. Ignored for the final ToRGB layer. use_radial_filters = False, # Use radially symmetric downsampling filter? Ignored for critically sampled layers. conv_clamp = 256, # Clamp the output to [-X, +X], None = disable clamping. magnitude_ema_beta = 0.999, # Decay rate for the moving average of input magnitudes.
square = False, # default is for rectangular images ): super().__init__() self.w_dim = w_dim self.is_torgb = is_torgb self.is_critically_sampled = is_critically_sampled self.use_fp16 = use_fp16 self.in_channels = in_channels self.out_channels = out_channels self.square = square if self.square: self.in_size = np.broadcast_to(np.asarray(in_size), [2]) self.out_size = np.broadcast_to(np.asarray(out_size), [2]) else: # self.in_size = np.array[in_size, in_size//2] self.in_size = np.array([in_size // 2, in_size]) # self.out_size = np.array[out_size, out_size//2] self.out_size = np.array([out_size // 2, out_size]) self.in_sampling_rate = in_sampling_rate self.out_sampling_rate = out_sampling_rate self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) self.in_cutoff = in_cutoff self.out_cutoff = out_cutoff self.in_half_width = in_half_width self.out_half_width = out_half_width self.conv_kernel = 1 if is_torgb else conv_kernel self.conv_clamp = conv_clamp self.magnitude_ema_beta = magnitude_ema_beta # Setup parameters and buffers. self.affine = FullyConnectedLayer(self.w_dim, self.in_channels, bias_init=1) self.weight = torch.nn.Parameter(torch.randn([self.out_channels, self.in_channels, self.conv_kernel, self.conv_kernel])) self.bias = torch.nn.Parameter(torch.zeros([self.out_channels])) self.register_buffer('magnitude_ema', torch.ones([])) # Design upsampling filter. self.up_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) assert self.in_sampling_rate * self.up_factor == self.tmp_sampling_rate self.up_taps = filter_size * self.up_factor if self.up_factor > 1 and not self.is_torgb else 1 self.register_buffer('up_filter', self.design_lowpass_filter( numtaps=self.up_taps, cutoff=self.in_cutoff, width=self.in_half_width*2, fs=self.tmp_sampling_rate)) # Design downsampling filter. self.down_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) assert self.out_sampling_rate * self.down_factor == self.tmp_sampling_rate self.down_taps = filter_size * self.down_factor if self.down_factor > 1 and not self.is_torgb else 1 self.down_radial = use_radial_filters and not self.is_critically_sampled self.register_buffer('down_filter', self.design_lowpass_filter( numtaps=self.down_taps, cutoff=self.out_cutoff, width=self.out_half_width*2, fs=self.tmp_sampling_rate, radial=self.down_radial)) # Compute padding. pad_total = (self.out_size - 1) * self.down_factor + 1 # Desired output size before downsampling. pad_total -= (self.in_size + self.conv_kernel - 1) * self.up_factor # Input size after upsampling. pad_total += self.up_taps + self.down_taps - 2 # Size reduction caused by the filters. pad_lo = (pad_total + self.up_factor) // 2 # Shift sample locations according to the symmetric interpretation (Appendix C.3). pad_hi = pad_total - pad_lo self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] def forward(self, x, w, noise_mode='random', force_fp32=False, update_emas=False): assert noise_mode in ['random', 'const', 'none'] # unused misc.assert_shape(x, [None, self.in_channels, int(self.in_size[1]), int(self.in_size[0])]) misc.assert_shape(w, [x.shape[0], self.w_dim]) # Track input magnitude. if update_emas: with torch.autograd.profiler.record_function('update_magnitude_ema'): magnitude_cur = x.detach().to(torch.float32).square().mean() self.magnitude_ema.copy_(magnitude_cur.lerp(self.magnitude_ema, self.magnitude_ema_beta)) input_gain = self.magnitude_ema.rsqrt() # Execute affine layer.
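# The affine layer produces one modulation scalar per input channel from w; for the final ToRGB layer the styles are additionally scaled by 1 / sqrt(in_channels * kernel^2) below, compensating for the disabled weight demodulation.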
styles = self.affine(w) if self.is_torgb: weight_gain = 1 / np.sqrt(self.in_channels * (self.conv_kernel ** 2)) styles = styles * weight_gain # Execute modulated conv2d. dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32 x = modulated_conv2d(x=x.to(dtype), w=self.weight, s=styles, padding=self.conv_kernel-1, demodulate=(not self.is_torgb), input_gain=input_gain) # Execute bias, filtered leaky ReLU, and clamping. gain = 1 if self.is_torgb else np.sqrt(2) slope = 1 if self.is_torgb else 0.2 x = filtered_lrelu.filtered_lrelu(x=x, fu=self.up_filter, fd=self.down_filter, b=self.bias.to(x.dtype), up=self.up_factor, down=self.down_factor, padding=self.padding, gain=gain, slope=slope, clamp=self.conv_clamp) # Ensure correct shape and dtype. misc.assert_shape(x, [None, self.out_channels, int(self.out_size[1]), int(self.out_size[0])]) assert x.dtype == dtype return x @staticmethod def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False): assert numtaps >= 1 # Identity filter. if numtaps == 1: return None # Separable Kaiser low-pass filter. if not radial: f = scipy.signal.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) return torch.as_tensor(f, dtype=torch.float32) # Radially symmetric jinc-based filter. x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs r = np.hypot(*np.meshgrid(x, x)) f = scipy.special.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) beta = scipy.signal.kaiser_beta(scipy.signal.kaiser_atten(numtaps, width / (fs / 2))) w = np.kaiser(numtaps, beta) f *= np.outer(w, w) f /= np.sum(f) return torch.as_tensor(f, dtype=torch.float32) def extra_repr(self): return '\n'.join([ f'w_dim={self.w_dim:d}, is_torgb={self.is_torgb},', f'is_critically_sampled={self.is_critically_sampled}, use_fp16={self.use_fp16},', f'in_sampling_rate={self.in_sampling_rate:g}, out_sampling_rate={self.out_sampling_rate:g},', f'in_cutoff={self.in_cutoff:g}, out_cutoff={self.out_cutoff:g},', f'in_half_width={self.in_half_width:g}, out_half_width={self.out_half_width:g},', f'in_size={list(self.in_size)}, out_size={list(self.out_size)},', f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}']) #---------------------------------------------------------------------------- @persistence.persistent_class class SynthesisNetwork(torch.nn.Module): def __init__(self, w_dim, # Intermediate latent (W) dimensionality. img_resolution, # Output image resolution. img_channels, # Number of color channels. square, channel_base = 32768, # Overall multiplier for the number of channels. channel_max = 512, # Maximum number of channels in any layer. num_layers = 14, # Total number of layers, excluding Fourier features and ToRGB. num_critical = 2, # Number of critically sampled layers at the end. first_cutoff = 2, # Cutoff frequency of the first layer (f_{c,0}). first_stopband = 2**2.1, # Minimum stopband of the first layer (f_{t,0}). last_stopband_rel = 2**0.3, # Minimum stopband of the last layer, expressed relative to the cutoff. margin_size = 10, # Number of additional pixels outside the image. output_scale = 0.25, # Scale factor for the output image. num_fp16_res = 4, # Use FP16 for the N highest resolutions. **layer_kwargs, # Arguments for SynthesisLayer. 
): super().__init__() self.w_dim = w_dim self.num_ws = num_layers + 2 self.img_resolution = img_resolution self.img_channels = img_channels self.num_layers = num_layers self.num_critical = num_critical self.margin_size = margin_size self.output_scale = output_scale self.num_fp16_res = num_fp16_res self.square = square # Geometric progression of layer cutoffs and min. stopbands. last_cutoff = self.img_resolution / 2 # f_{c,N} last_stopband = last_cutoff * last_stopband_rel # f_{t,N} exponents = np.minimum(np.arange(self.num_layers + 1) / (self.num_layers - self.num_critical), 1) cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents # f_c[i] stopbands = first_stopband * (last_stopband / first_stopband) ** exponents # f_t[i] # Compute remaining layer parameters. sampling_rates = np.exp2(np.ceil(np.log2(np.minimum(stopbands * 2, self.img_resolution)))) # s[i] half_widths = np.maximum(stopbands, sampling_rates / 2) - cutoffs # f_h[i] sizes = sampling_rates + self.margin_size * 2 sizes[-2:] = self.img_resolution channels = np.rint(np.minimum((channel_base / 2) / cutoffs, channel_max)) channels[-1] = self.img_channels # Construct layers. self.input = SynthesisInput( w_dim=self.w_dim, channels=int(channels[0]), size=int(sizes[0]), sampling_rate=sampling_rates[0], bandwidth=cutoffs[0], square=self.square) self.layer_names = [] for idx in range(self.num_layers + 1): prev = max(idx - 1, 0) is_torgb = (idx == self.num_layers) is_critically_sampled = (idx >= self.num_layers - self.num_critical) use_fp16 = (sampling_rates[idx] * (2 ** self.num_fp16_res) > self.img_resolution) layer = SynthesisLayer( w_dim=self.w_dim, is_torgb=is_torgb, is_critically_sampled=is_critically_sampled, use_fp16=use_fp16, in_channels=int(channels[prev]), out_channels= int(channels[idx]), in_size=int(sizes[prev]), out_size=int(sizes[idx]), in_sampling_rate=int(sampling_rates[prev]), out_sampling_rate=int(sampling_rates[idx]), in_cutoff=cutoffs[prev], out_cutoff=cutoffs[idx], in_half_width=half_widths[prev], out_half_width=half_widths[idx], square=self.square, **layer_kwargs) name = f'L{idx}_{layer.out_size[0]}_{layer.out_channels}' setattr(self, name, layer) self.layer_names.append(name) def forward(self, ws, **layer_kwargs): misc.assert_shape(ws, [None, self.num_ws, self.w_dim]) ws = ws.to(torch.float32).unbind(dim=1) # Execute layers. x = self.input(ws[0]) for name, w in zip(self.layer_names, ws[1:]): x = getattr(self, name)(x, w, **layer_kwargs) if self.output_scale != 1: x = x * self.output_scale # Ensure correct shape and dtype. if self.square: misc.assert_shape(x, [None, self.img_channels, self.img_resolution, self.img_resolution]) else: misc.assert_shape(x, [None, self.img_channels, self.img_resolution, self.img_resolution // 2]) x = x.to(torch.float32) return x def extra_repr(self): return '\n'.join([ f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},', f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},', f'num_layers={self.num_layers:d}, num_critical={self.num_critical:d},', f'margin_size={self.margin_size:d}, num_fp16_res={self.num_fp16_res:d}']) #---------------------------------------------------------------------------- @persistence.persistent_class class Generator(torch.nn.Module): def __init__(self, z_dim, # Input latent (Z) dimensionality. c_dim, # Conditioning label (C) dimensionality. w_dim, # Intermediate latent (W) dimensionality. img_resolution, # Output resolution. square, img_channels, # Number of output color channels. 
mapping_kwargs = {}, # Arguments for MappingNetwork. **synthesis_kwargs, # Arguments for SynthesisNetwork. ): super().__init__() self.z_dim = z_dim self.c_dim = c_dim self.w_dim = w_dim self.img_resolution = img_resolution self.img_channels = img_channels self.square = square self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, square=self.square, **synthesis_kwargs) self.num_ws = self.synthesis.num_ws self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs) def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs): ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas) img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs) return img #---------------------------------------------------------------------------- ================================================ FILE: stylegan_human/utils/ImagesDataset.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. import os from torch.utils.data import Dataset from PIL import Image from utils.data_utils import make_dataset class ImagesDataset(Dataset): def __init__(self, source_root, source_transform=None): self.source_paths = sorted(make_dataset(source_root)) self.source_transform = source_transform def __len__(self): return len(self.source_paths) def __getitem__(self, index): fname, from_path = self.source_paths[index] from_im = Image.open(from_path).convert('RGB') if self.source_transform: from_im = self.source_transform(from_im) return fname, from_im ================================================ FILE: stylegan_human/utils/__init__.py ================================================ ================================================ FILE: stylegan_human/utils/data_utils.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. import os from PIL import Image IMG_EXTENSIONS = [ '.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' ] def is_image_file(filename): return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) def tensor2im(var): # var shape: (3, H, W) var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy() var = ((var + 1) / 2) var[var < 0] = 0 var[var > 1] = 1 var = var * 255 return Image.fromarray(var.astype('uint8')) def make_dataset(dir): images = [] assert os.path.isdir(dir), '%s is not a valid directory' % dir for root, _, fnames in sorted(os.walk(dir)): for fname in fnames: if is_image_file(fname): path = os.path.join(root, fname) fname = fname.split('.')[0] images.append((fname, path)) return images ================================================ FILE: stylegan_human/utils/face_alignment.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. 
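# The helpers below implement FFHQ-style face alignment: dlib's 68-point landmarks define an oriented crop quadrilateral from the eye and mouth positions, which is cropped (and, in the projector variant, reflection-padded with blurred borders) and warped to a square of the requested output size.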
import numpy as np import PIL import PIL.Image import scipy import scipy.ndimage import dlib import copy from PIL import Image def get_landmark(img, detector, predictor): """get landmark with dlib :return: np.array shape=(68, 2) """ # detector = dlib.get_frontal_face_detector() # dets, _, _ = detector.run(img, 1, -1) dets = detector(img, 1) for k, d in enumerate(dets): shape = predictor(img, d.rect) t = list(shape.parts()) a = [] for tt in t: a.append([tt.x, tt.y]) lm = np.array(a) # face rect face_rect = [dets[0].rect.left(), dets[0].rect.top(), dets[0].rect.right(), dets[0].rect.bottom()] return lm, face_rect def align_face_for_insetgan(img, detector, predictor, output_size=256): """ :param img: numpy array rgb :return: PIL Image """ img_cp = copy.deepcopy(img) lm, face_rect = get_landmark(img, detector, predictor) lm_chin = lm[0: 17] # left-right lm_eyebrow_left = lm[17: 22] # left-right lm_eyebrow_right = lm[22: 27] # left-right lm_nose = lm[27: 31] # top-down lm_nostrils = lm[31: 36] # top-down lm_eye_left = lm[36: 42] # left-clockwise lm_eye_right = lm[42: 48] # left-clockwise lm_mouth_outer = lm[48: 60] # left-clockwise lm_mouth_inner = lm[60: 68] # left-clockwise # Calculate auxiliary vectors. eye_left = np.mean(lm_eye_left, axis=0) eye_right = np.mean(lm_eye_right, axis=0) eye_avg = (eye_left + eye_right) * 0.5 eye_to_eye = eye_right - eye_left mouth_left = lm_mouth_outer[0] mouth_right = lm_mouth_outer[6] mouth_avg = (mouth_left + mouth_right) * 0.5 eye_to_mouth = mouth_avg - eye_avg # Choose oriented crop rectangle. x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] x /= np.hypot(*x) x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) y = np.flipud(x) * [-1, 1] c = eye_avg + eye_to_mouth * 0.1 quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) qsize = np.hypot(*x) * 2 # read image # opencv to PIL img = PIL.Image.fromarray(img_cp) # img = PIL.Image.open(filepath) transform_size = output_size enable_padding = False # Shrink. # shrink = int(np.floor(qsize / output_size * 0.5)) # if shrink > 1: # rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) # img = img.resize(rsize, PIL.Image.ANTIALIAS) # quad /= shrink # qsize /= shrink # Crop. border = max(int(np.rint(qsize * 0.1)), 3) crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1])))) # crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), # min(crop[3] + border, img.size[1])) # img.save("debug/raw.jpg") if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: img = img.crop(crop) quad -= crop[0:2] # img.save("debug/crop.jpg") # Pad. 
# pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), # int(np.ceil(max(quad[:, 1])))) # pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), # max(pad[3] - img.size[1] + border, 0)) # if enable_padding and max(pad) > border - 4: # pad = np.maximum(pad, int(np.rint(qsize * 0.3))) # img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') # h, w, _ = img.shape # y, x, _ = np.ogrid[:h, :w, :1] # mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), # 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3])) # blur = qsize * 0.02 # img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) # img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) # img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') # quad += pad[:2] # Transform. # crop shape to transform shape # nw = # print(img.size, quad+0.5, np.bound((quad+0.5).flatten())) # assert False # img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR) # img.save("debug/transform.jpg") # if output_size < transform_size: img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) # img.save("debug/resize.jpg") # print((quad+crop[0:2]).flatten()) # assert False # Return aligned image. return img, crop, face_rect def align_face_for_projector(img, detector, predictor, output_size): """ :param img: numpy array rgb :return: PIL Image """ img_cp = copy.deepcopy(img) lm, face_rect = get_landmark(img, detector, predictor) lm_chin = lm[0: 17] # left-right lm_eyebrow_left = lm[17: 22] # left-right lm_eyebrow_right = lm[22: 27] # left-right lm_nose = lm[27: 31] # top-down lm_nostrils = lm[31: 36] # top-down lm_eye_left = lm[36: 42] # left-clockwise lm_eye_right = lm[42: 48] # left-clockwise lm_mouth_outer = lm[48: 60] # left-clockwise lm_mouth_inner = lm[60: 68] # left-clockwise # Calculate auxiliary vectors. eye_left = np.mean(lm_eye_left, axis=0) eye_right = np.mean(lm_eye_right, axis=0) eye_avg = (eye_left + eye_right) * 0.5 eye_to_eye = eye_right - eye_left mouth_left = lm_mouth_outer[0] mouth_right = lm_mouth_outer[6] mouth_avg = (mouth_left + mouth_right) * 0.5 eye_to_mouth = mouth_avg - eye_avg # Choose oriented crop rectangle. x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] x /= np.hypot(*x) x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) y = np.flipud(x) * [-1, 1] c = eye_avg + eye_to_mouth * 0.1 quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) qsize = np.hypot(*x) * 2 # read image img = PIL.Image.fromarray(img_cp) transform_size = output_size enable_padding = True # Shrink. shrink = int(np.floor(qsize / output_size * 0.5)) if shrink > 1: rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) img = img.resize(rsize, PIL.Image.ANTIALIAS) quad /= shrink qsize /= shrink # Crop. border = max(int(np.rint(qsize * 0.1)), 3) crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1])))) crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1])) if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: img = img.crop(crop) quad -= crop[0:2] # Pad.
pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1])))) pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0)) if enable_padding and max(pad) > border - 4: pad = np.maximum(pad, int(np.rint(qsize * 0.3))) img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') h, w, _ = img.shape y, x, _ = np.ogrid[:h, :w, :1] mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3])) blur = qsize * 0.02 img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') quad += pad[:2] # Transform. img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR) if output_size < transform_size: img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) # Return aligned image. return img def reverse_quad_transform(image, quad_to_map_to, alpha): # forward mapping, for simplicity result = Image.new("RGBA",image.size) result_pixels = result.load() width, height = result.size for y in range(height): for x in range(width): result_pixels[x,y] = (0,0,0,0) p1 = (quad_to_map_to[0],quad_to_map_to[1]) p2 = (quad_to_map_to[2],quad_to_map_to[3]) p3 = (quad_to_map_to[4],quad_to_map_to[5]) p4 = (quad_to_map_to[6],quad_to_map_to[7]) p1_p2_vec = (p2[0] - p1[0],p2[1] - p1[1]) p4_p3_vec = (p3[0] - p4[0],p3[1] - p4[1]) for y in range(height): for x in range(width): pixel = image.getpixel((x,y)) y_percentage = y / float(height) x_percentage = x / float(width) # interpolate vertically pa = (p1[0] + p1_p2_vec[0] * y_percentage, p1[1] + p1_p2_vec[1] * y_percentage) pb = (p4[0] + p4_p3_vec[0] * y_percentage, p4[1] + p4_p3_vec[1] * y_percentage) pa_to_pb_vec = (pb[0] - pa[0],pb[1] - pa[1]) # interpolate horizontally p = (pa[0] + pa_to_pb_vec[0] * x_percentage, pa[1] + pa_to_pb_vec[1] * x_percentage) try: result_pixels[p[0],p[1]] = (pixel[0],pixel[1],pixel[2],min(int(alpha * 255),pixel[3])) except Exception: pass return result ================================================ FILE: stylegan_human/utils/log_utils.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. 
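# Logging helpers for PTI inversion: get_image_from_w() synthesizes an image from a latent w via G.synthesis with constant noise and converts it to uint8 RGB; the wrappers either log it to wandb, show it with matplotlib, or save it to disk.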
import numpy as np from PIL import Image import wandb from pti.pti_configs import global_config import torch import matplotlib.pyplot as plt def log_image_from_w(w, G, name): img = get_image_from_w(w, G) pillow_image = Image.fromarray(img) wandb.log( {f"{name}": [ wandb.Image(pillow_image, caption=f"current inversion {name}")]}, step=global_config.training_step) def log_images_from_w(ws, G, names): for name, w in zip(names, ws): w = w.to(global_config.device) log_image_from_w(w, G, name) def plot_image_from_w(w, G): img = get_image_from_w(w, G) pillow_image = Image.fromarray(img) plt.imshow(pillow_image) plt.show() def plot_image(img): img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).detach().cpu().numpy() pillow_image = Image.fromarray(img[0]) plt.imshow(pillow_image) plt.show() def save_image(name, method_type, results_dir, image, run_id=None): # run_id made optional: save_w() below calls this without one image.save(f'{results_dir}/{method_type}_{name}_{run_id}.jpg' if run_id is not None else f'{results_dir}/{method_type}_{name}.jpg') def save_w(w, G, name, method_type, results_dir): im = get_image_from_w(w, G) im = Image.fromarray(im, mode='RGB') save_image(name, method_type, results_dir, im) def save_concat_image(base_dir, image_latents, new_inv_image_latent, new_G, old_G, file_name, extra_image=None): images_to_save = [] if extra_image is not None: images_to_save.append(extra_image) for latent in image_latents: images_to_save.append(get_image_from_w(latent, old_G)) images_to_save.append(get_image_from_w(new_inv_image_latent, new_G)) result_image = create_alongside_images(images_to_save) result_image.save(f'{base_dir}/{file_name}.jpg') def save_single_image(base_dir, image_latent, G, file_name): image_to_save = get_image_from_w(image_latent, G) image_to_save = Image.fromarray(image_to_save, mode='RGB') image_to_save.save(f'{base_dir}/{file_name}.jpg') def create_alongside_images(images): res = np.concatenate([np.array(image) for image in images], axis=1) return Image.fromarray(res, mode='RGB') def get_image_from_w(w, G): if len(w.size()) <= 2: w = w.unsqueeze(0) with torch.no_grad(): img = G.synthesis(w, noise_mode='const') img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).detach().cpu().numpy() return img[0] ================================================ FILE: stylegan_human/utils/models_utils.py ================================================ # Copyright (c) SenseTime Research. All rights reserved. import pickle import functools import torch from pti.pti_configs import paths_config, global_config def toogle_grad(model, flag=True): for p in model.parameters(): p.requires_grad = flag def load_tuned_G(run_id, type): new_G_path = f'{paths_config.checkpoints_dir}/model_{run_id}_{type}.pt' with open(new_G_path, 'rb') as f: new_G = torch.load(f).to(global_config.device).eval() new_G = new_G.float() toogle_grad(new_G, False) return new_G def load_old_G(): with open(paths_config.stylegan2_ada_shhq, 'rb') as f: old_G = pickle.load(f)['G_ema'].to(global_config.device).eval() old_G = old_G.float() return old_G ================================================ FILE: stylegan_human/utils/util.py ================================================ # Copyright (c) SenseTime Research. All rights reserved.
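# Miscellaneous tensor/image helpers. visual() maps a [-1, 1] tensor to a BGR uint8 image for cv2.imwrite(); get_lr() is a cosine ramp-down with linear warm-up, e.g. get_lr(0.5, 0.1) == 0.1 at mid-run, decaying to roughly 0.35 * initial_lr by t = 0.9 (illustrative values).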
import torch import cv2 from torchvision import transforms import numpy as np import math def visual(output, out_path): output = (output + 1)/2 output = torch.clamp(output, 0, 1) if output.shape[1] == 1: output = torch.cat([output, output, output], 1) output = output[0].detach().cpu().permute(1,2,0).numpy() output = (output*255).astype(np.uint8) output = output[:,:,::-1] cv2.imwrite(out_path, output) def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): lr_ramp = min(1, (1 - t) / rampdown) lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) lr_ramp = lr_ramp * min(1, t / rampup) return initial_lr * lr_ramp def latent_noise(latent, strength): noise = torch.randn_like(latent) * strength return latent + noise def noise_regularize_(noises): loss = 0 for noise in noises: size = noise.shape[2] while True: loss = ( loss + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2) + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2) ) if size <= 8: break noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2]) noise = noise.mean([3, 5]) size //= 2 return loss def noise_normalize_(noises): for noise in noises: mean = noise.mean() std = noise.std() noise.data.add_(-mean).div_(std) def tensor_to_numpy(x): x = x[0].permute(1, 2, 0) x = torch.clamp(x, -1 ,1) x = (x+1) * 127.5 x = x.cpu().detach().numpy().astype(np.uint8) return x def numpy_to_tensor(x): x = (x / 255 - 0.5) * 2 x = torch.from_numpy(x).unsqueeze(0).permute(0, 3, 1, 2) x = x.cuda().float() return x def tensor_to_pil(x): x = torch.clamp(x, -1 ,1) x = (x+1) * 127.5 return transforms.ToPILImage()(x.squeeze_(0)) ================================================ FILE: torch_utils/__init__.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. # empty ================================================ FILE: torch_utils/custom_ops.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. import glob import hashlib import importlib import os import re import shutil import uuid import torch import torch.utils.cpp_extension from torch.utils.file_baton import FileBaton #---------------------------------------------------------------------------- # Global options. verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' #---------------------------------------------------------------------------- # Internal helper funcs. 
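# _find_compiler_bindir() globs the standard Visual Studio install locations for an MSVC host-x64 toolchain; it is only consulted on Windows, when cl.exe is not already on the PATH (see get_plugin() below).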
def _find_compiler_bindir(): patterns = [ 'C:/Program Files*/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 'C:/Program Files*/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 'C:/Program Files*/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 'C:/Program Files*/Microsoft Visual Studio */vc/bin', ] for pattern in patterns: matches = sorted(glob.glob(pattern)) if len(matches): return matches[-1] return None #---------------------------------------------------------------------------- def _get_mangled_gpu_name(): name = torch.cuda.get_device_name().lower() out = [] for c in name: if re.match('[a-z0-9_-]+', c): out.append(c) else: out.append('-') return ''.join(out) #---------------------------------------------------------------------------- # Main entry point for compiling and loading C++/CUDA plugins. _cached_plugins = dict() def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs): assert verbosity in ['none', 'brief', 'full'] if headers is None: headers = [] if source_dir is not None: sources = [os.path.join(source_dir, fname) for fname in sources] headers = [os.path.join(source_dir, fname) for fname in headers] # Already cached? if module_name in _cached_plugins: return _cached_plugins[module_name] # Print status. if verbosity == 'full': print(f'Setting up PyTorch plugin "{module_name}"...') elif verbosity == 'brief': print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) verbose_build = (verbosity == 'full') # Compile and load. try: # pylint: disable=too-many-nested-blocks # Make sure we can find the necessary compiler binaries. if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: compiler_bindir = _find_compiler_bindir() if compiler_bindir is None: raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') os.environ['PATH'] += ';' + compiler_bindir # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either # break the build or unnecessarily restrict what's available to nvcc. # Unset it to let nvcc decide based on what's available on the # machine. os.environ['TORCH_CUDA_ARCH_LIST'] = '' # Incremental build md5sum trickery. Copies all the input source files # into a cached build directory under a combined md5 digest of the input # source files. Copying is done only if the combined digest has changed. # This keeps input file timestamps and filenames the same as in previous # extension builds, allowing for fast incremental rebuilds. # # This optimization is done only in case all the source files reside in # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR # environment variable is set (we take this as a signal that the user # actually cares about this.) # # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work # around the *.cu dependency bug in ninja config. # all_source_files = sorted(sources + headers) all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): # Compute combined hash digest for all source files. hash_md5 = hashlib.md5() for src in all_source_files: with open(src, 'rb') as f: hash_md5.update(f.read()) # Select cached build directory name.
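# The md5 digest of all source/header contents, combined with the GPU name, keys a per-build cache directory; sources are copied into a temp dir first and moved into place with the atomic os.replace(), so a concurrent process can never observe a half-populated build directory.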
source_digest = hash_md5.hexdigest() build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') if not os.path.isdir(cached_build_dir): tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' os.makedirs(tmpdir) for src in all_source_files: shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) try: os.replace(tmpdir, cached_build_dir) # atomic except OSError: # source directory already exists, delete tmpdir and its contents. shutil.rmtree(tmpdir) if not os.path.isdir(cached_build_dir): raise # Compile. cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, verbose=verbose_build, sources=cached_sources, **build_kwargs) else: torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) # Load. module = importlib.import_module(module_name) except: if verbosity == 'brief': print('Failed!') raise # Print status and add to cache dict. if verbosity == 'full': print(f'Done setting up PyTorch plugin "{module_name}".') elif verbosity == 'brief': print('Done.') _cached_plugins[module_name] = module return module #---------------------------------------------------------------------------- ================================================ FILE: torch_utils/misc.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. import re import contextlib import numpy as np import torch import warnings import dnnlib #---------------------------------------------------------------------------- # Cached construction of constant tensors. Avoids CPU=>GPU copy when the # same constant is used multiple times. _constant_cache = dict() def constant(value, shape=None, dtype=None, device=None, memory_format=None): value = np.asarray(value) if shape is not None: shape = tuple(shape) if dtype is None: dtype = torch.get_default_dtype() if device is None: device = torch.device('cpu') if memory_format is None: memory_format = torch.contiguous_format key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) tensor = _constant_cache.get(key, None) if tensor is None: tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) if shape is not None: tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) tensor = tensor.contiguous(memory_format=memory_format) _constant_cache[key] = tensor return tensor #---------------------------------------------------------------------------- # Replace NaN/Inf with specified numerical values. 
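# Fallback for PyTorch < 1.8, where torch.nan_to_num() does not exist: NaNs are zeroed via nansum() over a singleton dimension and infinities are clamped to the dtype's range; only nan=0 is supported (asserted below).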
try: nan_to_num = torch.nan_to_num # 1.8.0a0 except AttributeError: def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin assert isinstance(input, torch.Tensor) if posinf is None: posinf = torch.finfo(input.dtype).max if neginf is None: neginf = torch.finfo(input.dtype).min assert nan == 0 return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) #---------------------------------------------------------------------------- # Symbolic assert. try: symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access except AttributeError: symbolic_assert = torch.Assert # 1.7.0 #---------------------------------------------------------------------------- # Context manager to temporarily suppress known warnings in torch.jit.trace(). # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 @contextlib.contextmanager def suppress_tracer_warnings(): flt = ('ignore', None, torch.jit.TracerWarning, None, 0) warnings.filters.insert(0, flt) yield warnings.filters.remove(flt) #---------------------------------------------------------------------------- # Assert that the shape of a tensor matches the given list of integers. # None indicates that the size of a dimension is allowed to vary. # Performs symbolic assertion when used in torch.jit.trace(). def assert_shape(tensor, ref_shape): if tensor.ndim != len(ref_shape): raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): if ref_size is None: pass elif isinstance(ref_size, torch.Tensor): with suppress_tracer_warnings(): # as_tensor results are registered as constants symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') elif isinstance(size, torch.Tensor): with suppress_tracer_warnings(): # as_tensor results are registered as constants symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') elif size != ref_size: raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') #---------------------------------------------------------------------------- # Function decorator that calls torch.autograd.profiler.record_function(). def profiled_function(fn): def decorator(*args, **kwargs): with torch.autograd.profiler.record_function(fn.__name__): return fn(*args, **kwargs) decorator.__name__ = fn.__name__ return decorator #---------------------------------------------------------------------------- # Sampler for torch.utils.data.DataLoader that loops over the dataset # indefinitely, shuffling items as it goes. 
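# Each replica draws every num_replicas-th index from an endlessly repeated, gradually reshuffled ordering: after position i is visited, it is swapped with a random position up to `window` steps behind it, so the stream never pauses at an epoch boundary.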
class InfiniteSampler(torch.utils.data.Sampler): def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): assert len(dataset) > 0 assert num_replicas > 0 assert 0 <= rank < num_replicas assert 0 <= window_size <= 1 super().__init__(dataset) self.dataset = dataset self.rank = rank self.num_replicas = num_replicas self.shuffle = shuffle self.seed = seed self.window_size = window_size def __iter__(self): order = np.arange(len(self.dataset)) rnd = None window = 0 if self.shuffle: rnd = np.random.RandomState(self.seed) rnd.shuffle(order) window = int(np.rint(order.size * self.window_size)) idx = 0 while True: i = idx % order.size if idx % self.num_replicas == self.rank: yield order[i] if window >= 2: j = (i - rnd.randint(window)) % order.size order[i], order[j] = order[j], order[i] idx += 1 #---------------------------------------------------------------------------- # Utilities for operating with torch.nn.Module parameters and buffers. def params_and_buffers(module): assert isinstance(module, torch.nn.Module) return list(module.parameters()) + list(module.buffers()) def named_params_and_buffers(module): assert isinstance(module, torch.nn.Module) return list(module.named_parameters()) + list(module.named_buffers()) def copy_params_and_buffers(src_module, dst_module, require_all=False): assert isinstance(src_module, torch.nn.Module) assert isinstance(dst_module, torch.nn.Module) src_tensors = dict(named_params_and_buffers(src_module)) for name, tensor in named_params_and_buffers(dst_module): assert (name in src_tensors) or (not require_all) if name in src_tensors: tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) #---------------------------------------------------------------------------- # Context manager for easily enabling/disabling DistributedDataParallel # synchronization. @contextlib.contextmanager def ddp_sync(module, sync): assert isinstance(module, torch.nn.Module) if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): yield else: with module.no_sync(): yield #---------------------------------------------------------------------------- # Check DistributedDataParallel consistency across processes. def check_ddp_consistency(module, ignore_regex=None): assert isinstance(module, torch.nn.Module) for name, tensor in named_params_and_buffers(module): fullname = type(module).__name__ + '.' + name if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): continue tensor = tensor.detach() if tensor.is_floating_point(): tensor = nan_to_num(tensor) other = tensor.clone() torch.distributed.broadcast(tensor=other, src=0) assert (tensor == other).all(), fullname #---------------------------------------------------------------------------- # Print summary table of module hierarchy. def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): assert isinstance(module, torch.nn.Module) assert not isinstance(module, torch.jit.ScriptModule) assert isinstance(inputs, (tuple, list)) # Register hooks. 
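# A forward pre-hook tracks nesting depth and a post-hook records each submodule's output tensors up to max_nesting levels deep; the table is assembled from a single forward pass below.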
entries = [] nesting = [0] def pre_hook(_mod, _inputs): nesting[0] += 1 def post_hook(mod, _inputs, outputs): nesting[0] -= 1 if nesting[0] <= max_nesting: outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] outputs = [t for t in outputs if isinstance(t, torch.Tensor)] entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] # Run module. outputs = module(*inputs) for hook in hooks: hook.remove() # Identify unique outputs, parameters, and buffers. tensors_seen = set() for e in entries: e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} # Filter out redundant entries. if skip_redundant: entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] # Construct table. rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] rows += [['---'] * len(rows[0])] param_total = 0 buffer_total = 0 submodule_names = {mod: name for name, mod in module.named_modules()} for e in entries: name = '' if e.mod is module else submodule_names[e.mod] param_size = sum(t.numel() for t in e.unique_params) buffer_size = sum(t.numel() for t in e.unique_buffers) output_shapes = [str(list(t.shape)) for t in e.outputs] output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] rows += [[ name + (':0' if len(e.outputs) >= 2 else ''), str(param_size) if param_size else '-', str(buffer_size) if buffer_size else '-', (output_shapes + ['-'])[0], (output_dtypes + ['-'])[0], ]] for idx in range(1, len(e.outputs)): rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] param_total += param_size buffer_total += buffer_size rows += [['---'] * len(rows[0])] rows += [['Total', str(param_total), str(buffer_total), '-', '-']] # Print table. widths = [max(len(cell) for cell in column) for column in zip(*rows)] print() for row in rows: print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) print() return outputs #---------------------------------------------------------------------------- ================================================ FILE: torch_utils/ops/__init__.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. # empty ================================================ FILE: torch_utils/ops/bias_act.cpp ================================================ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. 
Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. #include <torch/extension.h> #include <ATen/cuda/CUDAContext.h> #include <c10/cuda/CUDAGuard.h> #include "bias_act.h" //------------------------------------------------------------------------ static bool has_same_layout(torch::Tensor x, torch::Tensor y) { if (x.dim() != y.dim()) return false; for (int64_t i = 0; i < x.dim(); i++) { if (x.size(i) != y.size(i)) return false; if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) return false; } return true; } //------------------------------------------------------------------------ static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) { // Validate arguments. TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); TORCH_CHECK(b.dim() == 1, "b must have rank 1"); TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); TORCH_CHECK(grad >= 0, "grad must be non-negative"); // Validate layout. TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); // Create output tensor. const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); torch::Tensor y = torch::empty_like(x); TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); // Initialize CUDA kernel parameters. bias_act_kernel_params p; p.x = x.data_ptr(); p.b = (b.numel()) ? b.data_ptr() : NULL; p.xref = (xref.numel()) ? xref.data_ptr() : NULL; p.yref = (yref.numel()) ? yref.data_ptr() : NULL; p.dy = (dy.numel()) ? dy.data_ptr() : NULL; p.y = y.data_ptr(); p.grad = grad; p.act = act; p.alpha = alpha; p.gain = gain; p.clamp = clamp; p.sizeX = (int)x.numel(); p.sizeB = (int)b.numel(); p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; // Choose CUDA kernel. void* kernel; AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { kernel = choose_bias_act_kernel<scalar_t>(p); }); TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); // Launch CUDA kernel.
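// Each thread handles loopX = 4 elements with a block size of 128 threads; gridSize is chosen so that gridSize * blockSize * loopX covers all sizeX elements.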
p.loopX = 4; int blockSize = 4 * 32; int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; void* args[] = {&p}; AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); return y; } //------------------------------------------------------------------------ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("bias_act", &bias_act); } //------------------------------------------------------------------------ ================================================ FILE: torch_utils/ops/bias_act.cu ================================================ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. #include <c10/util/Half.h> #include "bias_act.h" //------------------------------------------------------------------------ // Helpers. template <class T> struct InternalType; template <> struct InternalType<double> { typedef double scalar_t; }; template <> struct InternalType<float> { typedef float scalar_t; }; template <> struct InternalType<c10::Half> { typedef float scalar_t; }; //------------------------------------------------------------------------ // CUDA kernel. template <class T, int A> __global__ void bias_act_kernel(bias_act_kernel_params p) { typedef typename InternalType<T>::scalar_t scalar_t; int G = p.grad; scalar_t alpha = (scalar_t)p.alpha; scalar_t gain = (scalar_t)p.gain; scalar_t clamp = (scalar_t)p.clamp; scalar_t one = (scalar_t)1; scalar_t two = (scalar_t)2; scalar_t expRange = (scalar_t)80; scalar_t halfExpRange = (scalar_t)40; scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; // Loop over elements. int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) { // Load. scalar_t x = (scalar_t)((const T*)p.x)[xi]; scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; scalar_t yy = (gain != 0) ? yref / gain : 0; scalar_t y = 0; // Apply bias. ((G == 0) ? x : xref) += b; // linear if (A == 1) { if (G == 0) y = x; if (G == 1) y = x; } // relu if (A == 2) { if (G == 0) y = (x > 0) ? x : 0; if (G == 1) y = (yy > 0) ? x : 0; } // lrelu if (A == 3) { if (G == 0) y = (x > 0) ? x : x * alpha; if (G == 1) y = (yy > 0) ? x : x * alpha; } // tanh if (A == 4) { if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } if (G == 1) y = x * (one - yy * yy); if (G == 2) y = x * (one - yy * yy) * (-two * yy); } // sigmoid if (A == 5) { if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); if (G == 1) y = x * yy * (one - yy); if (G == 2) y = x * yy * (one - yy) * (one - two * yy); } // elu if (A == 6) { if (G == 0) y = (x >= 0) ? x : exp(x) - one; if (G == 1) y = (yy >= 0) ? x : x * (yy + one); if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); } // selu if (A == 7) { if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); if (G == 1) y = (yy >= 0) ?
x * seluScale : x * (yy + seluScale * seluAlpha); if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); } // softplus if (A == 8) { if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); if (G == 1) y = x * (one - exp(-yy)); if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } } // swish if (A == 9) { if (G == 0) y = (x < -expRange) ? 0 : x / (exp(-x) + one); else { scalar_t c = exp(xref); scalar_t d = c + one; if (G == 1) y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); else y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; } } // Apply gain. y *= gain * dy; // Clamp. if (clamp >= 0) { if (G == 0) y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; else y = (yref > -clamp & yref < clamp) ? y : 0; } // Store. ((T*)p.y)[xi] = (T)y; } } //------------------------------------------------------------------------ // CUDA kernel selection. template void* choose_bias_act_kernel(const bias_act_kernel_params& p) { if (p.act == 1) return (void*)bias_act_kernel; if (p.act == 2) return (void*)bias_act_kernel; if (p.act == 3) return (void*)bias_act_kernel; if (p.act == 4) return (void*)bias_act_kernel; if (p.act == 5) return (void*)bias_act_kernel; if (p.act == 6) return (void*)bias_act_kernel; if (p.act == 7) return (void*)bias_act_kernel; if (p.act == 8) return (void*)bias_act_kernel; if (p.act == 9) return (void*)bias_act_kernel; return NULL; } //------------------------------------------------------------------------ // Template specializations. template void* choose_bias_act_kernel (const bias_act_kernel_params& p); template void* choose_bias_act_kernel (const bias_act_kernel_params& p); template void* choose_bias_act_kernel (const bias_act_kernel_params& p); //------------------------------------------------------------------------ ================================================ FILE: torch_utils/ops/bias_act.h ================================================ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. //------------------------------------------------------------------------ // CUDA kernel parameters. struct bias_act_kernel_params { const void* x; // [sizeX] const void* b; // [sizeB] or NULL const void* xref; // [sizeX] or NULL const void* yref; // [sizeX] or NULL const void* dy; // [sizeX] or NULL void* y; // [sizeX] int grad; int act; float alpha; float gain; float clamp; int sizeX; int sizeB; int stepB; int loopX; }; //------------------------------------------------------------------------ // CUDA kernel selection. template void* choose_bias_act_kernel(const bias_act_kernel_params& p); //------------------------------------------------------------------------ ================================================ FILE: torch_utils/ops/bias_act.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. 
Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Custom PyTorch ops for efficient bias and activation.""" import os import numpy as np import torch import dnnlib from .. import custom_ops from .. import misc #---------------------------------------------------------------------------- activation_funcs = { 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), } #---------------------------------------------------------------------------- _plugin = None _null_tensor = torch.empty([0]) def _init(): global _plugin if _plugin is None: _plugin = custom_ops.get_plugin( module_name='bias_act_plugin', sources=['bias_act.cpp', 'bias_act.cu'], headers=['bias_act.h'], source_dir=os.path.dirname(__file__), extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'], ) return True #---------------------------------------------------------------------------- def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): r"""Fused bias and activation function. Adds bias `b` to activation tensor `x`, evaluates activation function `act`, and scales the result by `gain`. Each of the steps is optional. In most cases, the fused op is considerably more efficient than performing the same calculation using standard PyTorch ops. It supports first and second order gradients, but not third order gradients. Args: x: Input activation tensor. Can be of any shape. b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type as `x`. The shape must be known, and it must match the dimension of `x` corresponding to `dim`. dim: The dimension in `x` corresponding to the elements of `b`. The value of `dim` is ignored if `b` is not specified. act: Name of the activation function to evaluate, or `"linear"` to disable. Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. See `activation_funcs` for a full list. `None` is not allowed. alpha: Shape parameter for the activation function, or `None` to use the default. gain: Scaling factor for the output tensor, or `None` to use default. See `activation_funcs` for the default scaling of each activation function. If unsure, consider specifying 1. 
        clamp:  Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
                the clamping (default).
        impl:   Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the same shape and datatype as `x`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    if impl == 'cuda' and x.device.type == 'cuda' and _init():
        return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
    return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)

#----------------------------------------------------------------------------

@misc.profiled_function
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Slow reference implementation of `bias_act()` using standard PyTorch ops.
    """
    assert isinstance(x, torch.Tensor)
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)

    # Add bias.
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.ndim == 1
        assert 0 <= dim < x.ndim
        assert b.shape[0] == x.shape[dim]
        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])

    # Evaluate activation function.
    alpha = float(alpha)
    x = spec.func(x, alpha=alpha)

    # Scale by gain.
    gain = float(gain)
    if gain != 1:
        x = x * gain

    # Clamp.
    if clamp >= 0:
        x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
    return x

#----------------------------------------------------------------------------

_bias_act_cuda_cache = dict()

def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Fast CUDA implementation of `bias_act()` using custom ops.
    """
    # Parse arguments.
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)

    # Lookup from cache.
    key = (dim, act, alpha, gain, clamp)
    if key in _bias_act_cuda_cache:
        return _bias_act_cuda_cache[key]

    # Forward op.
    class BiasActCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, b): # pylint: disable=arguments-differ
            ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
            x = x.contiguous(memory_format=ctx.memory_format)
            b = b.contiguous() if b is not None else _null_tensor
            y = x
            if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
                y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
            ctx.save_for_backward(
                x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
                b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
                y if 'y' in spec.ref else _null_tensor)
            return y

        @staticmethod
        def backward(ctx, dy): # pylint: disable=arguments-differ
            dy = dy.contiguous(memory_format=ctx.memory_format)
            x, b, y = ctx.saved_tensors
            dx = None
            db = None

            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                dx = dy
                if act != 'linear' or gain != 1 or clamp >= 0:
                    dx = BiasActCudaGrad.apply(dy, x, b, y)

            if ctx.needs_input_grad[1]:
                db = dx.sum([i for i in range(dx.ndim) if i != dim])

            return dx, db

    # Backward op.
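    # Note (added comment): BiasActCudaGrad below computes dx from dy via a
    # grad=1 kernel pass. Because that mapping is linear in dy, its own
    # backward can reuse BiasActCudaGrad for d_dy, and only the second-order
    # terms w.r.t. x need an extra grad=2 kernel pass.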
class BiasActCudaGrad(torch.autograd.Function): @staticmethod def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) ctx.save_for_backward( dy if spec.has_2nd_grad else _null_tensor, x, b, y) return dx @staticmethod def backward(ctx, d_dx): # pylint: disable=arguments-differ d_dx = d_dx.contiguous(memory_format=ctx.memory_format) dy, x, b, y = ctx.saved_tensors d_dy = None d_x = None d_b = None d_y = None if ctx.needs_input_grad[0]: d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) if spec.has_2nd_grad and ctx.needs_input_grad[2]: d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) return d_dy, d_x, d_b, d_y # Add to cache. _bias_act_cuda_cache[key] = BiasActCuda return BiasActCuda #---------------------------------------------------------------------------- ================================================ FILE: torch_utils/ops/conv2d_gradfix.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Custom replacement for `torch.nn.functional.conv2d` that supports arbitrarily high order gradients with zero performance penalty.""" import contextlib import torch # pylint: disable=redefined-builtin # pylint: disable=arguments-differ # pylint: disable=protected-access #---------------------------------------------------------------------------- enabled = False # Enable the custom op by setting this to true. weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 
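# Illustrative usage sketch (added by the editor; tensors x and w are
# hypothetical placeholders, not part of this file):
#
#   from torch_utils.ops import conv2d_gradfix
#   conv2d_gradfix.enabled = True                    # opt in to the custom op
#   y = conv2d_gradfix.conv2d(x, w, padding=1)       # differentiable to any order
#   with conv2d_gradfix.no_weight_gradients():       # e.g. R1 regularization paths
#       y2 = conv2d_gradfix.conv2d(x, w, padding=1)  # skips grads w.r.t. w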
@contextlib.contextmanager def no_weight_gradients(disable=True): global weight_gradients_disabled old = weight_gradients_disabled if disable: weight_gradients_disabled = True yield weight_gradients_disabled = old #---------------------------------------------------------------------------- def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): if _should_use_custom_op(input): return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): if _should_use_custom_op(input): return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) #---------------------------------------------------------------------------- def _should_use_custom_op(input): assert isinstance(input, torch.Tensor) if (not enabled) or (not torch.backends.cudnn.enabled): return False if input.device.type != 'cuda': return False return True def _tuple_of_ints(xs, ndim): xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim assert len(xs) == ndim assert all(isinstance(x, int) for x in xs) return xs #---------------------------------------------------------------------------- _conv2d_gradfix_cache = dict() _null_tensor = torch.empty([0]) def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): # Parse arguments. ndim = 2 weight_shape = tuple(weight_shape) stride = _tuple_of_ints(stride, ndim) padding = _tuple_of_ints(padding, ndim) output_padding = _tuple_of_ints(output_padding, ndim) dilation = _tuple_of_ints(dilation, ndim) # Lookup from cache. key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) if key in _conv2d_gradfix_cache: return _conv2d_gradfix_cache[key] # Validate arguments. assert groups >= 1 assert len(weight_shape) == ndim + 2 assert all(stride[i] >= 1 for i in range(ndim)) assert all(padding[i] >= 0 for i in range(ndim)) assert all(dilation[i] >= 0 for i in range(ndim)) if not transpose: assert all(output_padding[i] == 0 for i in range(ndim)) else: # transpose assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) # Helpers. common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) def calc_output_padding(input_shape, output_shape): if transpose: return [0, 0] return [ input_shape[i + 2] - (output_shape[i + 2] - 1) * stride[i] - (1 - 2 * padding[i]) - dilation[i] * (weight_shape[i + 2] - 1) for i in range(ndim) ] # Forward & backward. class Conv2d(torch.autograd.Function): @staticmethod def forward(ctx, input, weight, bias): assert weight.shape == weight_shape ctx.save_for_backward( input if weight.requires_grad else _null_tensor, weight if input.requires_grad else _null_tensor, ) ctx.input_shape = input.shape # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere). 
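    # Note (added comment): the capability check below restricts this matmul
    # shortcut to pre-Ampere GPUs (compute capability < 8.0); on newer devices
    # the general cuDNN path is taken even for 1x1 kernels.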
if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0): a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1]) b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1) c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2) c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1) c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3) return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) # General case => cuDNN. if transpose: return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) @staticmethod def backward(ctx, grad_output): input, weight = ctx.saved_tensors input_shape = ctx.input_shape grad_input = None grad_weight = None grad_bias = None if ctx.needs_input_grad[0]: p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape) op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) grad_input = op.apply(grad_output, weight, None) assert grad_input.shape == input_shape if ctx.needs_input_grad[1] and not weight_gradients_disabled: grad_weight = Conv2dGradWeight.apply(grad_output, input) assert grad_weight.shape == weight_shape if ctx.needs_input_grad[2]: grad_bias = grad_output.sum([0, 2, 3]) return grad_input, grad_weight, grad_bias # Gradient with respect to the weights. class Conv2dGradWeight(torch.autograd.Function): @staticmethod def forward(ctx, grad_output, input): ctx.save_for_backward( grad_output if input.requires_grad else _null_tensor, input if grad_output.requires_grad else _null_tensor, ) ctx.grad_output_shape = grad_output.shape ctx.input_shape = input.shape # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere). if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0): a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape) return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) # General case => cuDNN. 
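        # Hedged note (added comment): this resolves a private cuDNN
        # backward-weight op via torch._C._jit_get_operation. These
        # aten::cudnn_convolution*_backward_weight entry points exist only in
        # older PyTorch releases, so a failure here is typically a PyTorch
        # version mismatch rather than a usage error.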
name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight' flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) @staticmethod def backward(ctx, grad2_grad_weight): grad_output, input = ctx.saved_tensors grad_output_shape = ctx.grad_output_shape input_shape = ctx.input_shape grad2_grad_output = None grad2_input = None if ctx.needs_input_grad[0]: grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) assert grad2_grad_output.shape == grad_output_shape if ctx.needs_input_grad[1]: p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape) op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) grad2_input = op.apply(grad_output, grad2_grad_weight, None) assert grad2_input.shape == input_shape return grad2_grad_output, grad2_input _conv2d_gradfix_cache[key] = Conv2d return Conv2d #---------------------------------------------------------------------------- ================================================ FILE: torch_utils/ops/conv2d_resample.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """2D convolution with optional up/downsampling.""" import torch from .. import misc from . import conv2d_gradfix from . import upfirdn2d from .upfirdn2d import _parse_padding from .upfirdn2d import _get_filter_size #---------------------------------------------------------------------------- def _get_weight_shape(w): with misc.suppress_tracer_warnings(): # this value will be treated as a constant shape = [int(sz) for sz in w.shape] misc.assert_shape(w, shape) return shape #---------------------------------------------------------------------------- def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. """ _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w) # Flip weight if requested. # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). if not flip_weight and (kw > 1 or kh > 1): w = w.flip([2, 3]) # Execute using conv2d_gradfix. op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d return op(x, w, stride=stride, padding=padding, groups=groups) #---------------------------------------------------------------------------- @misc.profiled_function def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): r"""2D convolution with optional up/downsampling. Padding is performed only once at the beginning, not between the operations. Args: x: Input tensor of shape `[batch_size, in_channels, in_height, in_width]`. w: Weight tensor of shape `[out_channels, in_channels//groups, kernel_height, kernel_width]`. f: Low-pass filter for up/downsampling. Must be prepared beforehand by calling upfirdn2d.setup_filter(). 
None = identity (default). up: Integer upsampling factor (default: 1). down: Integer downsampling factor (default: 1). padding: Padding with respect to the upsampled image. Can be a single number or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` (default: 0). groups: Split input channels into N groups (default: 1). flip_weight: False = convolution, True = correlation (default: True). flip_filter: False = convolution, True = correlation (default: False). Returns: Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. """ # Validate arguments. assert isinstance(x, torch.Tensor) and (x.ndim == 4) assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) assert isinstance(up, int) and (up >= 1) assert isinstance(down, int) and (down >= 1) assert isinstance(groups, int) and (groups >= 1) out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) fw, fh = _get_filter_size(f) px0, px1, py0, py1 = _parse_padding(padding) # Adjust padding to account for up/downsampling. if up > 1: px0 += (fw + up - 1) // 2 px1 += (fw - up) // 2 py0 += (fh + up - 1) // 2 py1 += (fh - up) // 2 if down > 1: px0 += (fw - down + 1) // 2 px1 += (fw - down) // 2 py0 += (fh - down + 1) // 2 py1 += (fh - down) // 2 # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. if kw == 1 and kh == 1 and (down > 1 and up == 1): x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) return x # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. if kw == 1 and kh == 1 and (up > 1 and down == 1): x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) return x # Fast path: downsampling only => use strided convolution. if down > 1 and up == 1: x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) return x # Fast path: upsampling with optional downsampling => use transpose strided convolution. if up > 1: if groups == 1: w = w.transpose(0, 1) else: w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) w = w.transpose(1, 2) w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) px0 -= kw - 1 px1 -= kw - up py0 -= kh - 1 py1 -= kh - up pxt = max(min(-px0, -px1), 0) pyt = max(min(-py0, -py1), 0) x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) if down > 1: x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) return x # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. if up == 1 and down == 1: if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) # Fallback: Generic reference implementation. 
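    # Note (added comment): the fallback realizes conv2d_resample as an
    # explicit pipeline: upsample+pad via upfirdn2d (gain=up**2 preserves
    # magnitude), then a plain grouped convolution, then a final upfirdn2d
    # downsample when down > 1.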
    x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
    x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
    if down > 1:
        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
    return x

#----------------------------------------------------------------------------

================================================
FILE: torch_utils/ops/filtered_lrelu.cpp
================================================
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "filtered_lrelu.h"

//------------------------------------------------------------------------

static std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
    torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
    int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
{
    // Set CUDA device.
    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

    // Validate arguments.
    TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
    TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
    TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
    TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
    TORCH_CHECK(x.dim() == 4, "x must be rank 4");
    TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
    TORCH_CHECK(x.numel() > 0, "x is empty");
    TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
    TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
    TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
    TORCH_CHECK(fu.numel() > 0, "fu is empty");
    TORCH_CHECK(fd.numel() > 0, "fd is empty");
    TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
    TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");

    // Figure out how much shared memory is available on the device.
    int maxSharedBytes = 0;
    AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
    int sharedKB = maxSharedBytes >> 10;

    // Populate enough launch parameters to check if a CUDA kernel exists.
    filtered_lrelu_kernel_params p;
    p.up      = up;
    p.down    = down;
    p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
    p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
    filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB);
    if (!test_spec.exec)
    {
        // No kernel found - return empty tensors and indicate missing kernel with return code of -1.
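        // Hedged note (added comment): the Python wrapper is expected to treat
        // this -1 return code as "no fused kernel for this configuration" and
        // fall back to its slower reference path built from upfirdn2d and
        // bias_act.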
return std::make_tuple(torch::Tensor(), torch::Tensor(), -1); } // Input/output element size. int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4; // Input sizes. int64_t xw = (int)x.size(3); int64_t xh = (int)x.size(2); int64_t fut_w = (int)fu.size(-1) - 1; int64_t fut_h = (int)fu.size(0) - 1; int64_t fdt_w = (int)fd.size(-1) - 1; int64_t fdt_h = (int)fd.size(0) - 1; // Logical size of upsampled buffer. int64_t cw = xw * up + (px0 + px1) - fut_w; int64_t ch = xh * up + (py0 + py1) - fut_h; TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter"); TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large"); // Compute output size and allocate. int64_t yw = (cw - fdt_w + (down - 1)) / down; int64_t yh = (ch - fdt_h + (down - 1)) / down; TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1"); TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large"); torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format()); // Allocate sign tensor. torch::Tensor so; torch::Tensor s = si; bool readSigns = !!s.numel(); int64_t sw_active = 0; // Active width of sign tensor. if (writeSigns) { sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements. int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height. int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16. TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large"); s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous); } else if (readSigns) sw_active = s.size(3) << 2; // Validate sign tensor if in use. if (readSigns || writeSigns) { TORCH_CHECK(s.is_contiguous(), "signs must be contiguous"); TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8"); TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x"); TORCH_CHECK(s.dim() == 4, "signs must be rank 4"); TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x"); TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large"); } // Populate rest of CUDA kernel parameters. p.x = x.data_ptr(); p.y = y.data_ptr(); p.b = b.data_ptr(); p.s = (readSigns || writeSigns) ? s.data_ptr() : 0; p.fu = fu.data_ptr(); p.fd = fd.data_ptr(); p.pad0 = make_int2(px0, py0); p.gain = gain; p.slope = slope; p.clamp = clamp; p.flip = (flip_filters) ? 1 : 0; p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous. p.sOfs = make_int2(sx, sy); p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes. // x, y, b strides are in bytes. p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0)); p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0)); p.bStride = sz * b.stride(0); // fu, fd strides are in elements. p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0); p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0); // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those. 
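    // Note (added comment): the checks below bound the most negative and most
    // positive reachable element offsets by summing, per dimension,
    // min(size*stride, 0) and max(size*stride, 0); 32-bit indexing is used
    // only when both extremes fit in INT_MAX.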
    bool index64b = false;
    if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
    if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
    if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
    if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
    if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
    if (s.numel() > INT_MAX) index64b = true;

    // Choose CUDA kernel.
    filtered_lrelu_kernel_spec spec = { 0 };
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
    {
        if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
        {
            // Choose kernel based on index type, datatype and sign read/write modes.
            if      (!index64b &&  writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, true,  false>(p, sharedKB);
            else if (!index64b && !writeSigns &&  readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true >(p, sharedKB);
            else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(p, sharedKB);
            else if ( index64b &&  writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, true,  false>(p, sharedKB);
            else if ( index64b && !writeSigns &&  readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true >(p, sharedKB);
            else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(p, sharedKB);
        }
    });
    TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found"); // This should not happen because we tested earlier that kernel exists.

    // Launch CUDA kernel.
    void* args[] = {&p};
    int bx = spec.numWarps * 32;
    int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
    int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
    int gz = p.yShape.z * p.yShape.w;

    // Repeat multiple horizontal tiles in a CTA?
    if (spec.xrep)
    {
        p.tilesXrep = spec.xrep;
        p.tilesXdim = gx;

        gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
        std::swap(gx, gy);
    }
    else
    {
        p.tilesXrep = 0;
        p.tilesXdim = 0;
    }

    // Launch filter setup kernel.
    AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));

    // Copy kernels to constant memory.
    if      ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<true,  false>(at::cuda::getCurrentCUDAStream())));
    else if (!writeSigns &&  readSigns) AT_CUDA_CHECK((copy_filters<false, true >(at::cuda::getCurrentCUDAStream())));
    else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<false, false>(at::cuda::getCurrentCUDAStream())));

    // Set cache and shared memory configurations for main kernel.
    AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
    if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
        AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
    AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));

    // Launch main kernel.
    const int maxSubGz = 65535; // CUDA maximum for block z dimension.
    for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
{ p.blockZofs = zofs; int subGz = std::min(maxSubGz, gz - zofs); AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream())); } // Done. return std::make_tuple(y, so, 0); } //------------------------------------------------------------------------ static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns) { // Set CUDA device. TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); // Validate arguments. TORCH_CHECK(x.dim() == 4, "x must be rank 4"); TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large"); TORCH_CHECK(x.numel() > 0, "x is empty"); TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64"); // Output signs if we don't have sign input. torch::Tensor so; torch::Tensor s = si; bool readSigns = !!s.numel(); if (writeSigns) { int64_t sw = x.size(3); sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing. s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous); } // Validate sign tensor if in use. if (readSigns || writeSigns) { TORCH_CHECK(s.is_contiguous(), "signs must be contiguous"); TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8"); TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x"); TORCH_CHECK(s.dim() == 4, "signs must be rank 4"); TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x"); TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large"); } // Initialize CUDA kernel parameters. filtered_lrelu_act_kernel_params p; p.x = x.data_ptr(); p.s = (readSigns || writeSigns) ? s.data_ptr() : 0; p.gain = gain; p.slope = slope; p.clamp = clamp; p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0)); p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous. p.sOfs = make_int2(sx, sy); // Choose CUDA kernel. void* func = 0; AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&] { if (writeSigns) func = choose_filtered_lrelu_act_kernel(); else if (readSigns) func = choose_filtered_lrelu_act_kernel(); else func = choose_filtered_lrelu_act_kernel(); }); TORCH_CHECK(func, "internal error - CUDA kernel not found"); // Launch CUDA kernel. void* args[] = {&p}; int bx = 128; // 4 warps per block. // Logical size of launch = writeSigns ? p.s : p.x uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x; uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y; uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use. gx = (gx - 1) / bx + 1; // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest. const uint32_t gmax = 65535; gy = std::min(gy, gmax); gz = std::min(gz, gmax); // Launch. 
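    // Note (added comment): gy and gz were clamped to the 65535-per-dimension
    // grid limit above; the kernel loops internally over any remaining rows
    // and channel/batch slices.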
    AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));
    return so;
}

//------------------------------------------------------------------------

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
    m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
}

//------------------------------------------------------------------------

================================================
FILE: torch_utils/ops/filtered_lrelu.cu
================================================
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <c10/util/Half.h>
#include "filtered_lrelu.h"
#include <cstdint>

//------------------------------------------------------------------------
// Helpers.

enum // Filter modes.
{
    MODE_SUSD = 0,  // Separable upsampling, separable downsampling.
    MODE_FUSD = 1,  // Full upsampling, separable downsampling.
    MODE_SUFD = 2,  // Separable upsampling, full downsampling.
    MODE_FUFD = 3,  // Full upsampling, full downsampling.
};

template <class T> struct InternalType;
template <> struct InternalType<double>
{
    typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t;
    __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); }
    __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); }
    __device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); }
};
template <> struct InternalType<float>
{
    typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
    __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
    __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
    __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
};
template <> struct InternalType<c10::Half>
{
    typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
    __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
    __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
    __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
};

#define MIN(A, B)       ((A) < (B) ? (A) : (B))
#define MAX(A, B)       ((A) > (B) ? (A) : (B))
#define CEIL_DIV(A, B)  (((B)==1) ? (A) : \
                         ((B)==2) ? ((int)((A)+1) >> 1) : \
                         ((B)==4) ? ((int)((A)+3) >> 2) : \
                         (((A) + ((A) > 0 ? (B) - 1 : 0)) / (B)))

// This works only up to blocks of size 256 x 256 and for all N that are powers of two.
template <int N> __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i)
{
    if ((N & (N-1)) && N <= 256)
        y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256.
    else
        y = i/N;

    x = i - y*N;
}

// Type cast stride before reading it.
template <class T> __device__ __forceinline__ T get_stride(const int64_t& x)
{
    return *reinterpret_cast<const T*>(&x);
}

//------------------------------------------------------------------------
// Filters, setup kernel, copying function.
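// Note (added comment): filters are staged in two steps because device code
// cannot write __constant__ memory directly. setup_filters_kernel() writes the
// (optionally flipped, zero-padded) taps into the global g_fbuf, and
// copy_filters() then mirrors that buffer into c_fbuf with
// cudaMemcpyToSymbolAsync so the main kernel reads taps from constant memory.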
#define MAX_FILTER_SIZE 32 // Combined up/down filter buffers so that transfer can be done with one copy. __device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel. __device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel. // Accessors to combined buffers to index up/down filters individually. #define c_fu (c_fbuf) #define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE) #define g_fu (g_fbuf) #define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE) // Set up filters into global memory buffer. static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p) { for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x) { int x, y; fast_div_mod(x, y, idx); int fu_x = p.flip ? x : (p.fuShape.x - 1 - x); int fu_y = p.flip ? y : (p.fuShape.y - 1 - y); if (p.fuShape.y > 0) g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y]; else g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x]; int fd_x = p.flip ? x : (p.fdShape.x - 1 - x); int fd_y = p.flip ? y : (p.fdShape.y - 1 - y); if (p.fdShape.y > 0) g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y]; else g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x]; } } // Host function to copy filters written by setup kernel into constant buffer for main kernel. template static cudaError_t copy_filters(cudaStream_t stream) { void* src = 0; cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf); if (err) return err; return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream); } //------------------------------------------------------------------------ // Coordinate spaces: // - Relative to input tensor: inX, inY, tileInX, tileInY // - Relative to input tile: relInX, relInY, tileInW, tileInH // - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH // - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH // - Relative to output tensor: outX, outY, tileOutX, tileOutY // // Relationships between coordinate spaces: // - inX = tileInX + relInX // - inY = tileInY + relInY // - relUpX = relInX * up + phaseInX // - relUpY = relInY * up + phaseInY // - relUpX = relOutX * down // - relUpY = relOutY * down // - outX = tileOutX + relOutX // - outY = tileOutY + relOutY extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer. template static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) { // Check that we don't try to support non-existing filter modes. 
static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported"); static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported"); static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor"); static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor"); static_assert(fuSize % up == 0, "upsampling filter size must be divisible with upsampling factor"); static_assert(fdSize % down == 0, "downsampling filter size must be divisible with downsampling factor"); static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE"); static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters"); static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters"); static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4"); static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4"); // Static definitions. typedef typename InternalType::scalar_t scalar_t; typedef typename InternalType::vec2_t vec2_t; typedef typename InternalType::vec4_t vec4_t; const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4. const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height. const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width. const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height. const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up. const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4. // Merge 1x1 downsampling into last upsampling step for upf1 and ups2. const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD)); // Sizes of logical buffers. const int szIn = tileInH_up * tileInW; const int szUpX = tileInH_up * tileUpW; const int szUpXY = downInline ? 0 : (tileUpH * tileUpW); const int szDownX = tileUpH * tileOutW; // Sizes for shared memory arrays. const int s_buf0_size_base = (filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) : (filterMode == MODE_FUSD) ? MAX(szIn, szDownX) : (filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) : (filterMode == MODE_FUFD) ? szIn : -1; const int s_buf1_size_base = (filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) : (filterMode == MODE_FUSD) ? szUpXY : (filterMode == MODE_SUFD) ? szUpX : (filterMode == MODE_FUFD) ? szUpXY : -1; // Ensure U128 alignment. const int s_buf0_size = (s_buf0_size_base + 3) & ~3; const int s_buf1_size = (s_buf1_size_base + 3) & ~3; // Check at compile time that we don't use too much shared memory. static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow"); // Declare shared memory arrays. scalar_t* s_buf0; scalar_t* s_buf1; if (sharedKB <= 48) { // Allocate shared memory arrays here. __shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused. 
s_buf0 = s_buf0_st; s_buf1 = s_buf0 + s_buf0_size; } else { // Use the dynamically allocated shared memory array. s_buf0 = (scalar_t*)s_buf_raw; s_buf1 = s_buf0 + s_buf0_size; } // Pointers to the buffers. scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY] scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX] scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX] scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX] if (filterMode == MODE_SUSD) { s_tileIn = s_buf0; s_tileUpX = s_buf1; s_tileUpXY = s_buf0; s_tileDownX = s_buf1; } else if (filterMode == MODE_FUSD) { s_tileIn = s_buf0; s_tileUpXY = s_buf1; s_tileDownX = s_buf0; } else if (filterMode == MODE_SUFD) { s_tileIn = s_buf0; s_tileUpX = s_buf1; s_tileUpXY = s_buf0; } else if (filterMode == MODE_FUFD) { s_tileIn = s_buf0; s_tileUpXY = s_buf1; } // Allow large grids in z direction via per-launch offset. int channelIdx = blockIdx.z + p.blockZofs; int batchIdx = channelIdx / p.yShape.z; channelIdx -= batchIdx * p.yShape.z; // Offset to output feature map. In bytes. index_t mapOfsOut = channelIdx * get_stride(p.yStride.z) + batchIdx * get_stride(p.yStride.w); // Sign shift amount. uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6; // Inner tile loop. #pragma unroll 1 for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++) { // Locate output tile. int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x; int tileOutX = tileX * tileOutW; int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH; // Locate input tile. int tmpX = tileOutX * down - p.pad0.x; int tmpY = tileOutY * down - p.pad0.y; int tileInX = CEIL_DIV(tmpX, up); int tileInY = CEIL_DIV(tmpY, up); const int phaseInX = tileInX * up - tmpX; const int phaseInY = tileInY * up - tmpY; // Extra sync if input and output buffers are the same and we are not on first tile. if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline))) __syncthreads(); // Load input tile & apply bias. Unrolled. scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride(p.bStride))); index_t mapOfsIn = channelIdx * get_stride(p.xStride.z) + batchIdx * get_stride(p.xStride.w); int idx = threadIdx.x; const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock); #pragma unroll for (int loop = 0; loop < loopCountIN; loop++) { int relInX, relInY; fast_div_mod(relInX, relInY, idx); int inX = tileInX + relInX; int inY = tileInY + relInY; scalar_t v = 0; if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y) v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride(p.xStride.x) + inY * get_stride(p.xStride.y) + mapOfsIn))) + b; bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH); if (!skip) s_tileIn[idx] = v; idx += threadsPerBlock; } if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter. { // Horizontal upsampling. 
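            // Note (added comment): the branches below implement polyphase
            // interpolation. For a factor of `up`, each thread produces `up`
            // adjacent upsampled outputs, and phaseInX selects which subset of
            // taps (c_fu[step * up + phase]) multiplies each loaded sample.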
__syncthreads(); if (up == 4) { for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up) { int relUpX0, relInY; fast_div_mod(relUpX0, relInY, idx); int relInX0 = relUpX0 / up; int src0 = relInX0 + tileInW * relInY; int dst = relInY * tileUpW + relUpX0; vec4_t v = InternalType::zero_vec4(); scalar_t a = s_tileIn[src0]; if (phaseInX == 0) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 0]; a = s_tileIn[src0 + step + 1]; v.y += a * (scalar_t)c_fu[step * up + 3]; v.z += a * (scalar_t)c_fu[step * up + 2]; v.w += a * (scalar_t)c_fu[step * up + 1]; } } else if (phaseInX == 1) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 1]; v.y += a * (scalar_t)c_fu[step * up + 0]; a = s_tileIn[src0 + step + 1]; v.z += a * (scalar_t)c_fu[step * up + 3]; v.w += a * (scalar_t)c_fu[step * up + 2]; } } else if (phaseInX == 2) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 2]; v.y += a * (scalar_t)c_fu[step * up + 1]; v.z += a * (scalar_t)c_fu[step * up + 0]; a = s_tileIn[src0 + step + 1]; v.w += a * (scalar_t)c_fu[step * up + 3]; } } else // (phaseInX == 3) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 3]; v.y += a * (scalar_t)c_fu[step * up + 2]; v.z += a * (scalar_t)c_fu[step * up + 1]; v.w += a * (scalar_t)c_fu[step * up + 0]; a = s_tileIn[src0 + step + 1]; } } s_tileUpX[dst+0] = v.x; s_tileUpX[dst+1] = v.y; s_tileUpX[dst+2] = v.z; s_tileUpX[dst+3] = v.w; } } else if (up == 2) { bool p0 = (phaseInX == 0); for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up) { int relUpX0, relInY; fast_div_mod(relUpX0, relInY, idx); int relInX0 = relUpX0 / up; int src0 = relInX0 + tileInW * relInY; int dst = relInY * tileUpW + relUpX0; vec2_t v = InternalType::zero_vec2(); scalar_t a = s_tileIn[src0]; if (p0) // (phaseInX == 0) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 0]; a = s_tileIn[src0 + step + 1]; v.y += a * (scalar_t)c_fu[step * up + 1]; } } else // (phaseInX == 1) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 1]; v.y += a * (scalar_t)c_fu[step * up + 0]; a = s_tileIn[src0 + step + 1]; } } s_tileUpX[dst+0] = v.x; s_tileUpX[dst+1] = v.y; } } // Vertical upsampling & nonlinearity. __syncthreads(); int groupMask = 15 << ((threadIdx.x & 31) & ~3); int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs. int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes. if (up == 4) { minY -= 3; // Adjust according to block height. 
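            // Note (added comment): the vertical pass below mirrors the
            // horizontal one but additionally fuses the leaky-ReLU slope,
            // optional clamping, and the packed 2-bit-per-value sign
            // read/write bookkeeping.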
for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x) { int relUpX, relInY0; fast_div_mod(relUpX, relInY0, idx); int relUpY0 = relInY0 * up; int src0 = relInY0 * tileUpW + relUpX; int dst = relUpY0 * tileUpW + relUpX; vec4_t v = InternalType::zero_vec4(); scalar_t a = s_tileUpX[src0]; if (phaseInY == 0) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 0]; a = s_tileUpX[src0 + (step + 1) * tileUpW]; v.y += a * (scalar_t)c_fu[step * up + 3]; v.z += a * (scalar_t)c_fu[step * up + 2]; v.w += a * (scalar_t)c_fu[step * up + 1]; } } else if (phaseInY == 1) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 1]; v.y += a * (scalar_t)c_fu[step * up + 0]; a = s_tileUpX[src0 + (step + 1) * tileUpW]; v.z += a * (scalar_t)c_fu[step * up + 3]; v.w += a * (scalar_t)c_fu[step * up + 2]; } } else if (phaseInY == 2) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 2]; v.y += a * (scalar_t)c_fu[step * up + 1]; v.z += a * (scalar_t)c_fu[step * up + 0]; a = s_tileUpX[src0 + (step + 1) * tileUpW]; v.w += a * (scalar_t)c_fu[step * up + 3]; } } else // (phaseInY == 3) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 3]; v.y += a * (scalar_t)c_fu[step * up + 2]; v.z += a * (scalar_t)c_fu[step * up + 1]; v.w += a * (scalar_t)c_fu[step * up + 0]; a = s_tileUpX[src0 + (step + 1) * tileUpW]; } } int x = tileOutX * down + relUpX; int y = tileOutY * down + relUpY0; int signX = x + p.sOfs.x; int signY = y + p.sOfs.y; int signZ = blockIdx.z + p.blockZofs; int signXb = signX >> 2; index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); index_t si1 = si0 + p.sShape.x; index_t si2 = si0 + p.sShape.x * 2; index_t si3 = si0 + p.sShape.x * 3; v.x *= (scalar_t)((float)up * (float)up * p.gain); v.y *= (scalar_t)((float)up * (float)up * p.gain); v.z *= (scalar_t)((float)up * (float)up * p.gain); v.w *= (scalar_t)((float)up * (float)up * p.gain); if (signWrite) { if (!enableWriteSkip) { // Determine and write signs. int sx = __float_as_uint(v.x) >> 31 << 0; int sy = __float_as_uint(v.y) >> 31 << 8; int sz = __float_as_uint(v.z) >> 31 << 16; int sw = __float_as_uint(v.w) >> 31 << 24; if (sx) v.x *= p.slope; if (sy) v.y *= p.slope; if (sz) v.z *= p.slope; if (sw) v.w *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); } if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); } if ((uint32_t)signXb < p.swLimit && signY >= minY) { // Combine signs. uint32_t s = sx + sy + sw + sz; s <<= (signX & 3) << 1; s |= __shfl_xor_sync(groupMask, s, 1); s |= __shfl_xor_sync(groupMask, s, 2); // Write signs. if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); } if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); } } } else { // Determine and write signs. 
if ((uint32_t)signXb < p.swLimit && signY >= minY) { int sx = __float_as_uint(v.x) >> 31 << 0; int sy = __float_as_uint(v.y) >> 31 << 8; int sz = __float_as_uint(v.z) >> 31 << 16; int sw = __float_as_uint(v.w) >> 31 << 24; if (sx) v.x *= p.slope; if (sy) v.y *= p.slope; if (sz) v.z *= p.slope; if (sw) v.w *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); } if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); } // Combine signs. uint32_t s = sx + sy + sw + sz; s <<= (signX & 3) << 1; s |= __shfl_xor_sync(groupMask, s, 1); s |= __shfl_xor_sync(groupMask, s, 2); // Write signs. if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); } if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); } } else { // Just compute the values. if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); } } } else if (signRead) // Read signs and apply. { if ((uint32_t)signXb < p.swLimit) { int ss = (signX & 3) << 1; if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; } if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; } if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; } if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; } } } else // Forward pass with no sign write. { if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); } s_tileUpXY[dst + 0 * tileUpW] = v.x; if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y; if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z; if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w; } } else if (up == 2) { minY -= 1; // Adjust according to block height. 
for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x) { int relUpX, relInY0; fast_div_mod(relUpX, relInY0, idx); int relUpY0 = relInY0 * up; int src0 = relInY0 * tileUpW + relUpX; int dst = relUpY0 * tileUpW + relUpX; vec2_t v = InternalType::zero_vec2(); scalar_t a = s_tileUpX[src0]; if (phaseInY == 0) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 0]; a = s_tileUpX[src0 + (step + 1) * tileUpW]; v.y += a * (scalar_t)c_fu[step * up + 1]; } } else // (phaseInY == 1) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 1]; v.y += a * (scalar_t)c_fu[step * up + 0]; a = s_tileUpX[src0 + (step + 1) * tileUpW]; } } int x = tileOutX * down + relUpX; int y = tileOutY * down + relUpY0; int signX = x + p.sOfs.x; int signY = y + p.sOfs.y; int signZ = blockIdx.z + p.blockZofs; int signXb = signX >> 2; index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); index_t si1 = si0 + p.sShape.x; v.x *= (scalar_t)((float)up * (float)up * p.gain); v.y *= (scalar_t)((float)up * (float)up * p.gain); if (signWrite) { if (!enableWriteSkip) { // Determine and write signs. int sx = __float_as_uint(v.x) >> 31 << 0; int sy = __float_as_uint(v.y) >> 31 << 8; if (sx) v.x *= p.slope; if (sy) v.y *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } if ((uint32_t)signXb < p.swLimit && signY >= minY) { // Combine signs. int s = sx + sy; s <<= signXo; s |= __shfl_xor_sync(groupMask, s, 1); s |= __shfl_xor_sync(groupMask, s, 2); // Write signs. if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } } } else { // Determine and write signs. if ((uint32_t)signXb < p.swLimit && signY >= minY) { int sx = __float_as_uint(v.x) >> 31 << 0; int sy = __float_as_uint(v.y) >> 31 << 8; if (sx) v.x *= p.slope; if (sy) v.y *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } // Combine signs. int s = sx + sy; s <<= signXo; s |= __shfl_xor_sync(groupMask, s, 1); s |= __shfl_xor_sync(groupMask, s, 2); // Write signs. if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } } else { // Just compute the values. if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); } } } else if (signRead) // Read signs and apply. { if ((uint32_t)signXb < p.swLimit) { if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; } if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; } } } else // Forward pass with no sign write. { if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); } if (!downInline) { // Write into temporary buffer. s_tileUpXY[dst] = v.x; if (relUpY0 < tileUpH - 1) s_tileUpXY[dst + tileUpW] = v.y; } else { // Write directly into output buffer. 
if ((uint32_t)x < p.yShape.x) { int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down); index_t ofs = x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut; if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]); if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]); } } } } } else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD) { // Full upsampling filter. if (up == 2) { // 2 x 2-wide. __syncthreads(); int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs. for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4) { int relUpX0, relUpY0; fast_div_mod(relUpX0, relUpY0, idx); int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up); int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up); int src0 = relInX0 + tileInW * relInY0; int tap0y = (relInY0 * up + phaseInY - relUpY0); #define X_LOOP(TAPY, PX) \ for (int sx = 0; sx < fuSize / up; sx++) \ { \ v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \ v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \ } vec4_t v = InternalType::zero_vec4(); if (tap0y == 0 && phaseInX == 0) #pragma unroll for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; #pragma unroll X_LOOP(0, 0) } if (tap0y == 0 && phaseInX == 1) #pragma unroll for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; #pragma unroll X_LOOP(0, 1) } if (tap0y == 1 && phaseInX == 0) #pragma unroll for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; #pragma unroll X_LOOP(1, 0) } if (tap0y == 1 && phaseInX == 1) #pragma unroll for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; #pragma unroll X_LOOP(1, 1) } #undef X_LOOP int x = tileOutX * down + relUpX0; int y = tileOutY * down + relUpY0; int signX = x + p.sOfs.x; int signY = y + p.sOfs.y; int signZ = blockIdx.z + p.blockZofs; int signXb = signX >> 2; index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); v.x *= (scalar_t)((float)up * (float)up * p.gain); v.y *= (scalar_t)((float)up * (float)up * p.gain); v.z *= (scalar_t)((float)up * (float)up * p.gain); v.w *= (scalar_t)((float)up * (float)up * p.gain); if (signWrite) { if (!enableWriteSkip) { // Determine and write signs. 
int sx = __float_as_uint(v.x) >> 31; int sy = __float_as_uint(v.y) >> 31; int sz = __float_as_uint(v.z) >> 31; int sw = __float_as_uint(v.w) >> 31; if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); } if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); } if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); } if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); } if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) { p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6); } } else { // Determine and write signs. if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) { int sx = __float_as_uint(v.x) >> 31; int sy = __float_as_uint(v.y) >> 31; int sz = __float_as_uint(v.z) >> 31; int sw = __float_as_uint(v.w) >> 31; if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); } if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); } if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); } if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); } p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6); } else { // Just compute the values. if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); } } } else if (signRead) // Read sign and apply. { if ((uint32_t)signY < p.sShape.y) { int s = 0; if ((uint32_t)signXb < p.swLimit) s = p.s[si]; if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8; s >>= (signX & 3) << 1; if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f; if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f; if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f; if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f; } } else // Forward pass with no sign write. { if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); } s_tileUpXY[idx + 0] = v.x; s_tileUpXY[idx + 1] = v.y; s_tileUpXY[idx + 2] = v.z; s_tileUpXY[idx + 3] = v.w; } } else if (up == 1) { __syncthreads(); uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3); int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs. for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x) { int relUpX0, relUpY0; fast_div_mod(relUpX0, relUpY0, idx); scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter. int x = tileOutX * down + relUpX0; int y = tileOutY * down + relUpY0; int signX = x + p.sOfs.x; int signY = y + p.sOfs.y; int signZ = blockIdx.z + p.blockZofs; int signXb = signX >> 2; index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); v *= (scalar_t)((float)up * (float)up * p.gain); if (signWrite) { if (!enableWriteSkip) { // Determine and write sign. 
uint32_t s = 0; uint32_t signXbit = (1u << signXo); if (v < 0.f) { s = signXbit; v *= p.slope; } if (fabsf(v) > p.clamp) { s = signXbit * 2; v = InternalType::clamp(v, p.clamp); } if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) { s += __shfl_xor_sync(groupMask, s, 1); // Coalesce. s += __shfl_xor_sync(groupMask, s, 2); // Coalesce. p.s[si] = s; // Write. } } else { // Determine and write sign. if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) { uint32_t s = 0; uint32_t signXbit = (1u << signXo); if (v < 0.f) { s = signXbit; v *= p.slope; } if (fabsf(v) > p.clamp) { s = signXbit * 2; v = InternalType::clamp(v, p.clamp); } s += __shfl_xor_sync(groupMask, s, 1); // Coalesce. s += __shfl_xor_sync(groupMask, s, 2); // Coalesce. p.s[si] = s; // Write. } else { // Just compute the value. if (v < 0.f) v *= p.slope; v = InternalType::clamp(v, p.clamp); } } } else if (signRead) { // Read sign and apply if within sign tensor bounds. if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y) { int s = p.s[si]; s >>= signXo; if (s & 1) v *= p.slope; if (s & 2) v = 0.f; } } else // Forward pass with no sign write. { if (v < 0.f) v *= p.slope; v = InternalType::clamp(v, p.clamp); } if (!downInline) // Write into temporary buffer. s_tileUpXY[idx] = v; else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer *((T*)((char*)p.y + (x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]); } } } // Downsampling. if (filterMode == MODE_SUSD || filterMode == MODE_FUSD) { // Horizontal downsampling. __syncthreads(); if (down == 4 && tileOutW % 4 == 0) { // Calculate 4 pixels at a time. for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4) { int relOutX0, relUpY; fast_div_mod(relOutX0, relUpY, idx); int relUpX0 = relOutX0 * down; int src0 = relUpY * tileUpW + relUpX0; vec4_t v = InternalType::zero_vec4(); #pragma unroll for (int step = 0; step < fdSize; step++) { v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step]; v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step]; v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step]; v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step]; } s_tileDownX[idx+0] = v.x; s_tileDownX[idx+1] = v.y; s_tileDownX[idx+2] = v.z; s_tileDownX[idx+3] = v.w; } } else if ((down == 2 || down == 4) && (tileOutW % 2 == 0)) { // Calculate 2 pixels at a time. for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2) { int relOutX0, relUpY; fast_div_mod(relOutX0, relUpY, idx); int relUpX0 = relOutX0 * down; int src0 = relUpY * tileUpW + relUpX0; vec2_t v = InternalType::zero_vec2(); #pragma unroll for (int step = 0; step < fdSize; step++) { v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step]; v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step]; } s_tileDownX[idx+0] = v.x; s_tileDownX[idx+1] = v.y; } } else { // Calculate 1 pixel at a time. for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x) { int relOutX0, relUpY; fast_div_mod(relOutX0, relUpY, idx); int relUpX0 = relOutX0 * down; int src = relUpY * tileUpW + relUpX0; scalar_t v = 0.f; #pragma unroll for (int step = 0; step < fdSize; step++) v += s_tileUpXY[src + step] * (scalar_t)c_fd[step]; s_tileDownX[idx] = v; } } // Vertical downsampling & store output tile. 
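// Taken together, the horizontal pass above and the vertical pass below
// implement one separable 2-D FIR downsampling filter:
//   tmp[ox, uy] = sum_s s_tileUpXY[ox*down + s, uy] * c_fd[s]   (horizontal)
//   out[ox, oy] = sum_s tmp[ox, oy*down + s] * c_fd[s]          (vertical)
// so every output pixel is an fdSize-by-fdSize weighted sum of the upsampled
// tile, computed with on the order of 2*fdSize multiplies per pixel instead
// of fdSize^2.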
__syncthreads(); for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x) { int relOutX, relOutY0; fast_div_mod(relOutX, relOutY0, idx); int relUpY0 = relOutY0 * down; int src0 = relUpY0 * tileOutW + relOutX; scalar_t v = 0; #pragma unroll for (int step = 0; step < fdSize; step++) v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step]; int outX = tileOutX + relOutX; int outY = tileOutY + relOutY0; if (outX < p.yShape.x & outY < p.yShape.y) *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v; } } else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD) { // Full downsampling filter. if (down == 2) { // 2-wide. __syncthreads(); for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2) { int relOutX0, relOutY0; fast_div_mod(relOutX0, relOutY0, idx); int relUpX0 = relOutX0 * down; int relUpY0 = relOutY0 * down; int src0 = relUpY0 * tileUpW + relUpX0; vec2_t v = InternalType::zero_vec2(); #pragma unroll for (int sy = 0; sy < fdSize; sy++) #pragma unroll for (int sx = 0; sx < fdSize; sx++) { v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE]; v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE]; } int outX = tileOutX + relOutX0; int outY = tileOutY + relOutY0; if ((uint32_t)outY < p.yShape.y) { index_t ofs = outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut; if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x; if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride(p.yStride.x))) = (T)v.y; } } } else if (down == 1 && !downInline) { // Thread per pixel. __syncthreads(); for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x) { int relOutX0, relOutY0; fast_div_mod(relOutX0, relOutY0, idx); scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter. int outX = tileOutX + relOutX0; int outY = tileOutY + relOutY0; if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y) *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v; } } } if (!enableXrep) break; } } //------------------------------------------------------------------------ // Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. Used for accelerating the generic variant. // Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used. template static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p) { typedef typename InternalType::scalar_t scalar_t; // Indexing. int32_t x = threadIdx.x + blockIdx.x * blockDim.x; int32_t ymax = signWrite ? p.sShape.y : p.xShape.y; int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index. // Loop to accommodate oversized tensors. for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z) for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y) { // Extract z and w (channel, minibatch index). int32_t w = q / p.xShape.z; int32_t z = q - w * p.xShape.z; // Choose behavior based on sign read/write mode. if (signWrite) { // Process value if in p.x. uint32_t s = 0; if (x < p.xShape.x && y < p.xShape.y) { int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w; T* pv = ((T*)p.x) + ix; scalar_t v = (scalar_t)(*pv); // Gain, LReLU, clamp. v *= p.gain; if (v < 0.f) { v *= p.slope; s = 1; // Sign. 
} if (fabsf(v) > p.clamp) { v = InternalType::clamp(v, p.clamp); s = 2; // Clamp. } *pv = (T)v; // Write value. } // Coalesce into threads 0 and 16 of warp. uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu; s <<= ((threadIdx.x & 15) << 1); // Shift into place. s |= __shfl_xor_sync(m, s, 1); // Distribute. s |= __shfl_xor_sync(m, s, 2); s |= __shfl_xor_sync(m, s, 4); s |= __shfl_xor_sync(m, s, 8); // Write signs if leader and in p.s. if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in. { uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous. ((uint32_t*)p.s)[is >> 4] = s; } } else if (signRead) { // Process value if in p.x. if (x < p.xShape.x) // y is always in. { int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w; T* pv = ((T*)p.x) + ix; scalar_t v = (scalar_t)(*pv); v *= p.gain; // Apply sign buffer offset. uint32_t sx = x + p.sOfs.x; uint32_t sy = y + p.sOfs.y; // Read and apply signs if we land inside valid region of sign buffer. if (sx < p.sShape.x && sy < p.sShape.y) { uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous. unsigned char s = p.s[is]; s >>= (sx & 3) << 1; // Shift into place. if (s & 1) // Sign? v *= p.slope; if (s & 2) // Clamp? v = 0.f; } *pv = (T)v; // Write value. } } else { // Forward pass with no sign write. Process value if in p.x. if (x < p.xShape.x) // y is always in. { int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w; T* pv = ((T*)p.x) + ix; scalar_t v = (scalar_t)(*pv); v *= p.gain; if (v < 0.f) v *= p.slope; if (fabsf(v) > p.clamp) v = InternalType::clamp(v, p.clamp); *pv = (T)v; // Write value. } } } } template void* choose_filtered_lrelu_act_kernel(void) { return (void*)filtered_lrelu_act_kernel; } //------------------------------------------------------------------------ // CUDA kernel selection. template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB) { filtered_lrelu_kernel_spec s = { 0 }; // Return the first matching kernel. #define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \ if (sharedKB >= SH) \ if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \ if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \ if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \ { \ static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \ static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \ static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \ s.setup = (void*)setup_filters_kernel; \ s.exec = (void*)filtered_lrelu_kernel; \ s.tileOut = make_int2(TW, TH); \ s.numWarps = W; \ s.xrep = XR; \ s.dynamicSharedKB = (SH == 48) ? 0 : SH; \ return s; \ } // Launch parameters for various kernel specializations. // Small filters must be listed before large filters, otherwise the kernel for larger filter will always match first. // Kernels that use more shared memory must be listed before those that use less, for the same reason. 
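// Illustration of the first-match rule: a call with up == 2 and an 8-tap
// upsampling filter matches the "4t-ups2-..." entries below (fuShape.x <= 8),
// but would also satisfy the looser 12- and 16-tap entries further down;
// listing the tightest specializations first therefore selects the kernel
// compiled for the smallest sufficient filter size.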
CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/1,1, /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64, 178, 32, 0, 0) // 1t-upf1-downf1 CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95, 16, 0, 0) // 4t-ups2-downf1 CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 22, 16, 0, 0) // 4t-upf1-downs2 CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 29, 16, 11, 0) // 4t-ups2-downs2 CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60, 28, 16, 0, 0) // 4t-upf2-downs2 CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 28, 16, 0, 0) // 4t-ups2-downf2 CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 31, 16, 11, 0) // 4t-ups4-downs2 CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 36, 16, 0, 0) // 4t-ups4-downf2 CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 22, 16, 12, 0) // 4t-ups2-downs4 CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29, 15, 16, 0, 0) // 4t-upf2-downs4 CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96, 150, 28, 0, 0) // 6t-ups2-downf1 CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 35, 24, 0, 0) // 6t-upf1-downs2 CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 16, 10, 0) // 6t-ups2-downs2 CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58, 28, 24, 8, 0) // 6t-upf2-downs2 CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52, 28, 16, 0, 0) // 6t-ups2-downf2 CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 51, 16, 5, 0) // 6t-ups4-downs2 CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 56, 16, 6, 0) // 6t-ups4-downf2 CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 18, 16, 12, 0) // 6t-ups2-downs4 CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 13, 24, 0, 0) // 6t-upf2-downs4 CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89, 24, 0, 0) // 8t-ups2-downf1 CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 31, 16, 5, 0) // 8t-upf1-downs2 CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 41, 16, 9, 0) // 8t-ups2-downs2 CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 26, 24, 0, 0) // 8t-upf2-downs2 CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 40, 16, 0, 0) // 8t-ups2-downf2 CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 24, 5, 0) // 8t-ups4-downs2 CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, 
/*tw,th,warps,xrep,wskip*/32, 50, 16, 0, 0) // 8t-ups4-downf2 CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 13, 16, 10, 1) // 8t-ups2-downs4 CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 10, 24, 0, 0) // 8t-upf2-downs4 #undef CASE return s; // No kernel found. } //------------------------------------------------------------------------ ================================================ FILE: torch_utils/ops/filtered_lrelu.h ================================================ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. #include //------------------------------------------------------------------------ // CUDA kernel parameters. struct filtered_lrelu_kernel_params { // These parameters decide which kernel to use. int up; // upsampling ratio (1, 2, 4) int down; // downsampling ratio (1, 2, 4) int2 fuShape; // [size, 1] | [size, size] int2 fdShape; // [size, 1] | [size, size] int _dummy; // Alignment. // Rest of the parameters. const void* x; // Input tensor. void* y; // Output tensor. const void* b; // Bias tensor. unsigned char* s; // Sign tensor in/out. NULL if unused. const float* fu; // Upsampling filter. const float* fd; // Downsampling filter. int2 pad0; // Left/top padding. float gain; // Additional gain factor. float slope; // Leaky ReLU slope on negative side. float clamp; // Clamp after nonlinearity. int flip; // Filter kernel flip for gradient computation. int tilesXdim; // Original number of horizontal output tiles. int tilesXrep; // Number of horizontal tiles per CTA. int blockZofs; // Block z offset to support large minibatch, channel dimensions. int4 xShape; // [width, height, channel, batch] int4 yShape; // [width, height, channel, batch] int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused. int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. int swLimit; // Active width of sign tensor in bytes. longlong4 xStride; // Strides of all tensors except signs, same component order as shapes. longlong4 yStride; // int64_t bStride; // longlong3 fuStride; // longlong3 fdStride; // }; struct filtered_lrelu_act_kernel_params { void* x; // Input/output, modified in-place. unsigned char* s; // Sign tensor in/out. NULL if unused. float gain; // Additional gain factor. float slope; // Leaky ReLU slope on negative side. float clamp; // Clamp after nonlinearity. int4 xShape; // [width, height, channel, batch] longlong4 xStride; // Input/output tensor strides, same order as in shape. int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused. int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. }; //------------------------------------------------------------------------ // CUDA kernel specialization. 
struct filtered_lrelu_kernel_spec { void* setup; // Function for filter kernel setup. void* exec; // Function for main operation. int2 tileOut; // Width/height of launch tile. int numWarps; // Number of warps per thread block, determines launch block size. int xrep; // For processing multiple horizontal tiles per thread block. int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants. }; //------------------------------------------------------------------------ // CUDA kernel selection. template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); template void* choose_filtered_lrelu_act_kernel(void); template cudaError_t copy_filters(cudaStream_t stream); //------------------------------------------------------------------------ ================================================ FILE: torch_utils/ops/filtered_lrelu.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. import os import numpy as np import torch import warnings from .. import custom_ops from .. import misc from . import upfirdn2d from . import bias_act #---------------------------------------------------------------------------- _plugin = None def _init(): global _plugin if _plugin is None: _plugin = custom_ops.get_plugin( module_name='filtered_lrelu_plugin', sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'], headers=['filtered_lrelu.h', 'filtered_lrelu.cu'], source_dir=os.path.dirname(__file__), extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'], ) return True def _get_filter_size(f): if f is None: return 1, 1 assert isinstance(f, torch.Tensor) assert 1 <= f.ndim <= 2 return f.shape[-1], f.shape[0] # width, height def _parse_padding(padding): if isinstance(padding, int): padding = [padding, padding] assert isinstance(padding, (list, tuple)) assert all(isinstance(x, (int, np.integer)) for x in padding) padding = [int(x) for x in padding] if len(padding) == 2: px, py = padding padding = [px, px, py, py] px0, px1, py0, py1 = padding return px0, px1, py0, py1 #---------------------------------------------------------------------------- def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'): r"""Filtered leaky ReLU for a batch of 2D images. Performs the following sequence of operations for each channel: 1. Add channel-specific bias if provided (`b`). 2. Upsample the image by inserting N-1 zeros after each pixel (`up`). 3. Pad the image with the specified number of zeros on each side (`padding`). Negative padding corresponds to cropping the image. 4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it so that the footprint of all output pixels lies within the input image. 5. Multiply each value by the provided gain factor (`gain`). 6. Apply leaky ReLU activation function to each value. 7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided. 8. 
        Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
        it so that the footprint of all output pixels lies within the input image.
     9. Downsample the image by keeping every Nth pixel (`down`).

    The fused op is considerably more efficient than performing the same calculation
    using standard PyTorch ops. It supports gradients of arbitrary order.

    Args:
        x:           Float32/float16/float64 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        fu:          Float32 upsampling FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or `None` (identity).
        fd:          Float32 downsampling FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or `None` (identity).
        b:           Bias vector, or `None` to disable. Must be a 1D tensor of the same
                     type as `x`. The length of the vector must match the channel
                     dimension of `x`.
        up:          Integer upsampling factor (default: 1).
        down:        Integer downsampling factor (default: 1).
        padding:     Padding with respect to the upsampled image. Can be a single number
                     or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                     (default: 0).
        gain:        Overall scaling factor for signal magnitude (default: sqrt(2)).
        slope:       Slope on the negative side of leaky ReLU (default: 0.2).
        clamp:       Maximum magnitude for leaky ReLU output (default: None).
        flip_filter: False = convolution, True = correlation (default: False).
        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    if impl == 'cuda' and x.device.type == 'cuda' and _init():
        return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
    return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)

#----------------------------------------------------------------------------

@misc.profiled_function
def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
    """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
    existing `upfirdn2d()` and `bias_act()` ops.
    """
    assert isinstance(x, torch.Tensor) and x.ndim == 4
    fu_w, fu_h = _get_filter_size(fu)
    fd_w, fd_h = _get_filter_size(fd)
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
        misc.assert_shape(b, [x.shape[1]])
    assert isinstance(up, int) and up >= 1
    assert isinstance(down, int) and down >= 1
    px0, px1, py0, py1 = _parse_padding(padding)
    assert gain == float(gain) and gain > 0
    assert slope == float(slope) and slope >= 0
    assert clamp is None or (clamp == float(clamp) and clamp >= 0)

    # Calculate output size.
    batch_size, channels, in_h, in_w = x.shape
    in_dtype = x.dtype
    out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
    out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down

    # Compute using existing ops.
    x = bias_act.bias_act(x=x, b=b) # Apply bias.
    x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
    x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp.
    x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.
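    # Worked example of the size arithmetic above, with hypothetical values
    # chosen purely for illustration: in_w = 64, up = down = 2, 12-tap filters
    # (fu_w = fd_w = 12), and padding px0 = px1 = 11 gives
    #   out_w = (64*2 + (11+11) - (12-1) - (12-1) + (2-1)) // 2
    #         = (128 + 22 - 11 - 11 + 1) // 2 = 129 // 2 = 64,
    # i.e. this padding choice keeps the spatial size unchanged.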
# Check output shape & dtype. misc.assert_shape(x, [batch_size, channels, out_h, out_w]) assert x.dtype == in_dtype return x #---------------------------------------------------------------------------- _filtered_lrelu_cuda_cache = dict() def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): """Fast CUDA implementation of `filtered_lrelu()` using custom ops. """ assert isinstance(up, int) and up >= 1 assert isinstance(down, int) and down >= 1 px0, px1, py0, py1 = _parse_padding(padding) assert gain == float(gain) and gain > 0 gain = float(gain) assert slope == float(slope) and slope >= 0 slope = float(slope) assert clamp is None or (clamp == float(clamp) and clamp >= 0) clamp = float(clamp if clamp is not None else 'inf') # Lookup from cache. key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter) if key in _filtered_lrelu_cuda_cache: return _filtered_lrelu_cuda_cache[key] # Forward op. class FilteredLReluCuda(torch.autograd.Function): @staticmethod def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ assert isinstance(x, torch.Tensor) and x.ndim == 4 # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable). if fu is None: fu = torch.ones([1, 1], dtype=torch.float32, device=x.device) if fd is None: fd = torch.ones([1, 1], dtype=torch.float32, device=x.device) assert 1 <= fu.ndim <= 2 assert 1 <= fd.ndim <= 2 # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1. if up == 1 and fu.ndim == 1 and fu.shape[0] == 1: fu = fu.square()[None] if down == 1 and fd.ndim == 1 and fd.shape[0] == 1: fd = fd.square()[None] # Missing sign input tensor. if si is None: si = torch.empty([0]) # Missing bias tensor. if b is None: b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device) # Construct internal sign tensor only if gradients are needed. write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad) # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout. strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1] if any(a < b for a, b in zip(strides[:-1], strides[1:])): warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning) # Call C++/Cuda plugin if datatype is supported. if x.dtype in [torch.float16, torch.float32]: if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device): warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning) y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs) else: return_code = -1 # No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because # only the bit-packed sign tensor is retained for gradient computation. if return_code < 0: warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning) y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias. y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample. so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place. y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample. # Prepare for gradient computation. 
ctx.save_for_backward(fu, fd, (si if si.numel() else so)) ctx.x_shape = x.shape ctx.y_shape = y.shape ctx.s_ofs = sx, sy return y @staticmethod def backward(ctx, dy): # pylint: disable=arguments-differ fu, fd, si = ctx.saved_tensors _, _, xh, xw = ctx.x_shape _, _, yh, yw = ctx.y_shape sx, sy = ctx.s_ofs dx = None # 0 dfu = None; assert not ctx.needs_input_grad[1] dfd = None; assert not ctx.needs_input_grad[2] db = None # 3 dsi = None; assert not ctx.needs_input_grad[4] dsx = None; assert not ctx.needs_input_grad[5] dsy = None; assert not ctx.needs_input_grad[6] if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]: pp = [ (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0, xw * up - yw * down + px0 - (up - 1), (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0, xh * up - yh * down + py0 - (up - 1), ] gg = gain * (up ** 2) / (down ** 2) ff = (not flip_filter) sx = sx - (fu.shape[-1] - 1) + px0 sy = sy - (fu.shape[0] - 1) + py0 dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy) if ctx.needs_input_grad[3]: db = dx.sum([0, 2, 3]) return dx, dfu, dfd, db, dsi, dsx, dsy # Add to cache. _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda return FilteredLReluCuda #---------------------------------------------------------------------------- ================================================ FILE: torch_utils/ops/filtered_lrelu_ns.cu ================================================ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. #include "filtered_lrelu.cu" // Template/kernel specializations for no signs mode (no gradients required). // Full op, 32-bit indexing. template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); // Full op, 64-bit indexing. template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); // Activation/signs only for generic variant. 64-bit indexing. template void* choose_filtered_lrelu_act_kernel(void); template void* choose_filtered_lrelu_act_kernel(void); template void* choose_filtered_lrelu_act_kernel(void); // Copy filters to constant memory. template cudaError_t copy_filters(cudaStream_t stream); ================================================ FILE: torch_utils/ops/filtered_lrelu_rd.cu ================================================ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. #include "filtered_lrelu.cu" // Template/kernel specializations for sign read mode. 
// Full op, 32-bit indexing. template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); // Full op, 64-bit indexing. template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); // Activation/signs only for generic variant. 64-bit indexing. template void* choose_filtered_lrelu_act_kernel(void); template void* choose_filtered_lrelu_act_kernel(void); template void* choose_filtered_lrelu_act_kernel(void); // Copy filters to constant memory. template cudaError_t copy_filters(cudaStream_t stream); ================================================ FILE: torch_utils/ops/filtered_lrelu_wr.cu ================================================ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. #include "filtered_lrelu.cu" // Template/kernel specializations for sign write mode. // Full op, 32-bit indexing. template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); // Full op, 64-bit indexing. template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); // Activation/signs only for generic variant. 64-bit indexing. template void* choose_filtered_lrelu_act_kernel(void); template void* choose_filtered_lrelu_act_kernel(void); template void* choose_filtered_lrelu_act_kernel(void); // Copy filters to constant memory. template cudaError_t copy_filters(cudaStream_t stream); ================================================ FILE: torch_utils/ops/fma.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" import torch #---------------------------------------------------------------------------- def fma(a, b, c): # => a * b + c return _FusedMultiplyAdd.apply(a, b, c) #---------------------------------------------------------------------------- class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c @staticmethod def forward(ctx, a, b, c): # pylint: disable=arguments-differ out = torch.addcmul(c, a, b) ctx.save_for_backward(a, b) ctx.c_shape = c.shape return out @staticmethod def backward(ctx, dout): # pylint: disable=arguments-differ a, b = ctx.saved_tensors c_shape = ctx.c_shape da = None db = None dc = None if ctx.needs_input_grad[0]: da = _unbroadcast(dout * b, a.shape) if ctx.needs_input_grad[1]: db = _unbroadcast(dout * a, b.shape) if ctx.needs_input_grad[2]: dc = _unbroadcast(dout, c_shape) return da, db, dc #---------------------------------------------------------------------------- def _unbroadcast(x, shape): extra_dims = x.ndim - len(shape) assert extra_dims >= 0 dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] if len(dim): x = x.sum(dim=dim, keepdim=True) if extra_dims: x = x.reshape(-1, *x.shape[extra_dims+1:]) assert x.shape == shape return x #---------------------------------------------------------------------------- ================================================ FILE: torch_utils/ops/grid_sample_gradfix.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Custom replacement for `torch.nn.functional.grid_sample` that supports arbitrarily high order gradients between the input and output. Only works on 2D images and assumes `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" import torch # pylint: disable=redefined-builtin # pylint: disable=arguments-differ # pylint: disable=protected-access #---------------------------------------------------------------------------- enabled = False # Enable the custom op by setting this to true. 
#----------------------------------------------------------------------------

def grid_sample(input, grid):
    if _should_use_custom_op():
        return _GridSample2dForward.apply(input, grid)
    return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)

#----------------------------------------------------------------------------

def _should_use_custom_op():
    return enabled

#----------------------------------------------------------------------------

class _GridSample2dForward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, grid):
        assert input.ndim == 4
        assert grid.ndim == 4
        output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
        ctx.save_for_backward(input, grid)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, grid = ctx.saved_tensors
        grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
        return grad_input, grad_grid

#----------------------------------------------------------------------------

class _GridSample2dBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, grad_output, input, grid):
        op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
        grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
        ctx.save_for_backward(grid)
        return grad_input, grad_grid

    @staticmethod
    def backward(ctx, grad2_grad_input, grad2_grad_grid):
        _ = grad2_grad_grid # unused
        grid, = ctx.saved_tensors
        grad2_grad_output = None
        grad2_input = None
        grad2_grid = None

        if ctx.needs_input_grad[0]:
            grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)

        assert not ctx.needs_input_grad[2]
        return grad2_grad_output, grad2_input, grad2_grid

#----------------------------------------------------------------------------

================================================
FILE: torch_utils/ops/upfirdn2d.cpp
================================================
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "upfirdn2d.h"

//------------------------------------------------------------------------

static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
{
    // Validate arguments.
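    // Note on the INT_MAX checks below: the CUDA kernels index tensors with
    // 32-bit ints, so both the element counts and the largest element offset
    // implied by each tensor's strides must fit in an int; oversized tensors
    // are rejected here instead of risking silent overflow inside the kernel.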
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); TORCH_CHECK(x.numel() > 0, "x has zero size"); TORCH_CHECK(f.numel() > 0, "f has zero size"); TORCH_CHECK(x.dim() == 4, "x must be rank 4"); TORCH_CHECK(f.dim() == 2, "f must be rank 2"); TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large"); TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); // Create output tensor. const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large"); // Initialize CUDA kernel parameters. upfirdn2d_kernel_params p; p.x = x.data_ptr(); p.f = f.data_ptr(); p.y = y.data_ptr(); p.up = make_int2(upx, upy); p.down = make_int2(downx, downy); p.pad0 = make_int2(padx0, pady0); p.flip = (flip) ? 1 : 0; p.gain = gain; p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; // Choose CUDA kernel. upfirdn2d_kernel_spec spec; AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { spec = choose_upfirdn2d_kernel(p); }); // Set looping options. p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; p.loopMinor = spec.loopMinor; p.loopX = spec.loopX; p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; // Compute grid size. dim3 blockSize, gridSize; if (spec.tileOutW < 0) // large { blockSize = dim3(4, 32, 1); gridSize = dim3( ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, p.launchMajor); } else // small { blockSize = dim3(256, 1, 1); gridSize = dim3( ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, p.launchMajor); } // Launch CUDA kernel. 
void* args[] = {&p}; AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); return y; } //------------------------------------------------------------------------ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("upfirdn2d", &upfirdn2d); } //------------------------------------------------------------------------ ================================================ FILE: torch_utils/ops/upfirdn2d.cu ================================================ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. #include #include "upfirdn2d.h" //------------------------------------------------------------------------ // Helpers. template struct InternalType; template <> struct InternalType { typedef double scalar_t; }; template <> struct InternalType { typedef float scalar_t; }; template <> struct InternalType { typedef float scalar_t; }; static __device__ __forceinline__ int floor_div(int a, int b) { int t = 1 - a / b; return (a + t * b) / b - t; } //------------------------------------------------------------------------ // Generic CUDA implementation for large filters. template static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) { typedef typename InternalType::scalar_t scalar_t; // Calculate thread index. int minorBase = blockIdx.x * blockDim.x + threadIdx.x; int outY = minorBase / p.launchMinor; minorBase -= outY * p.launchMinor; int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; int majorBase = blockIdx.z * p.loopMajor; if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor) return; // Setup Y receptive field. int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y; int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y); int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY; int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y; if (p.flip) filterY = p.filterSize.y - 1 - filterY; // Loop over major, minor, and X. for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor) { int nc = major * p.sizeMinor + minor; int n = nc / p.inSize.z; int c = nc - n * p.inSize.z; for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y) { // Setup X receptive field. int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x; int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x); int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX; int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x; if (p.flip) filterX = p.filterSize.x - 1 - filterX; // Initialize pointers. const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y]; int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x; int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y; // Inner loop. 
            scalar_t v = 0;
            for (int y = 0; y < h; y++)
            {
                for (int x = 0; x < w; x++)
                {
                    v += (scalar_t)(*xp) * (scalar_t)(*fp);
                    xp += p.inStride.x;
                    fp += filterStepX;
                }
                xp += p.inStride.y - w * p.inStride.x;
                fp += filterStepY - w * filterStepX;
            }

            // Store result.
            v *= p.gain;
            ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
        }
    }
}

//------------------------------------------------------------------------
// Specialized CUDA implementation for small filters.

template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;
    const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
    const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
    __shared__ volatile scalar_t sf[filterH][filterW];
    __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];

    // Calculate tile index.
    int minorBase = blockIdx.x;
    int tileOutY = minorBase / p.launchMinor;
    minorBase -= tileOutY * p.launchMinor;
    minorBase *= loopMinor;
    tileOutY *= tileOutH;
    int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
    int majorBase = blockIdx.z * p.loopMajor;
    if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
        return;

    // Load filter (flipped).
    for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
    {
        int fy = tapIdx / filterW;
        int fx = tapIdx - fy * filterW;
        scalar_t v = 0;
        if (fx < p.filterSize.x & fy < p.filterSize.y)
        {
            int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
            int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
            v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
        }
        sf[fy][fx] = v;
    }

    // Loop over major and X.
    for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
    {
        int baseNC = major * p.sizeMinor + minorBase;
        int n = baseNC / p.inSize.z;
        int baseC = baseNC - n * p.inSize.z;
        for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
        {
            // Load input pixels.
            int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
            int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
            int tileInX = floor_div(tileMidX, upx);
            int tileInY = floor_div(tileMidY, upy);
            __syncthreads();
            for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
            {
                int relC = inIdx;
                int relInX = relC / loopMinor;
                int relInY = relInX / tileInW;
                relC -= relInX * loopMinor;
                relInX -= relInY * tileInW;
                int c = baseC + relC;
                int inX = tileInX + relInX;
                int inY = tileInY + relInY;
                scalar_t v = 0;
                if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
                    v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
                sx[relInY][relInX][relC] = v;
            }

            // Loop over output pixels.
            __syncthreads();
            for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
            {
                int relC = outIdx;
                int relOutX = relC / loopMinor;
                int relOutY = relOutX / tileOutW;
                relC -= relOutX * loopMinor;
                relOutX -= relOutY * tileOutW;
                int c = baseC + relC;
                int outX = tileOutX + relOutX;
                int outY = tileOutY + relOutY;

                // Setup receptive field.
                int midX = tileMidX + relOutX * downx;
                int midY = tileMidY + relOutY * downy;
                int inX = floor_div(midX, upx);
                int inY = floor_div(midY, upy);
                int relInX = inX - tileInX;
                int relInY = inY - tileInY;
                int filterX = (inX + 1) * upx - midX - 1; // flipped
                int filterY = (inY + 1) * upy - midY - 1; // flipped

                // Inner loop.
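                // Fully unrolled accumulation over the filter footprint: both the
                // input tile (sx) and the pre-flipped filter (sf) already live in
                // shared memory, and filterW/upx, filterH/upy are compile-time
                // constants, so the compiler can unroll this into straight-line
                // multiply-adds.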
                if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
                {
                    scalar_t v = 0;
                    #pragma unroll
                    for (int y = 0; y < filterH / upy; y++)
                        #pragma unroll
                        for (int x = 0; x < filterW / upx; x++)
                            v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
                    v *= p.gain;
                    ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
                }
            }
        }
    }
}

//------------------------------------------------------------------------
// CUDA kernel selection.

template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
{
    int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
    upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
    if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last

    // No up/downsampling.
    if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 24,24, 64,32,1>, 64,32,1, 1};
        if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 16,16, 64,32,1>, 64,32,1, 1};
        if (s != 1 && fx <= 7  && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 7,7,  64,16,1>, 64,16,1, 1};
        if (s != 1 && fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 6,6,  64,16,1>, 64,16,1, 1};
        if (s != 1 && fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 5,5,  64,16,1>, 64,16,1, 1};
        if (s != 1 && fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 4,4,  64,16,1>, 64,16,1, 1};
        if (s != 1 && fx <= 3  && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 3,3,  64,16,1>, 64,16,1, 1};
        if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 24,1, 128,8,1>, 128,8,1, 1};
        if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 16,1, 128,8,1>, 128,8,1, 1};
        if (s != 1 && fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 8,1,  128,8,1>, 128,8,1, 1};
        if (s != 1 && fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 1,24, 32,32,1>, 32,32,1, 1};
        if (s != 1 && fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 1,16, 32,32,1>, 32,32,1, 1};
        if (s != 1 && fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 1,8,  32,32,1>, 32,32,1, 1};
        // channels_last
        if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 24,24, 32,32,1>,  32,32,1,  1};
        if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 16,16, 32,32,1>,  32,32,1,  1};
        if (s == 1 && fx <= 7  && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 7,7,  16,16,8>,  16,16,8,  1};
        if (s == 1 && fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 6,6,  16,16,8>,  16,16,8,  1};
        if (s == 1 && fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 5,5,  16,16,8>,  16,16,8,  1};
        if (s == 1 && fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 4,4,  16,16,8>,  16,16,8,  1};
        if (s == 1 && fx <= 3  && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 3,3,  16,16,8>,  16,16,8,  1};
        if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 24,1, 128,1,16>, 128,1,16, 1};
        if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 16,1, 128,1,16>, 128,1,16, 1};
        if (s == 1 && fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 8,1,  128,1,16>, 128,1,16, 1};
        if (s == 1 && fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 1,24, 1,128,16>, 1,128,16, 1};
        if (s == 1 && fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 1,16, 1,128,16>, 1,128,16, 1};
        if (s == 1 && fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 1,8,  1,128,16>, 1,128,16, 1};
    }

    // 2x upsampling.
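    // Note: within each branch the checks run from the loosest filter-size cap to
    // the tightest, and a later match overwrites `spec`, so the smallest (and
    // therefore cheapest) specialization that still fits the filter wins.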
    if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 24,24, 64,32,1>, 64,32,1, 1};
        if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 16,16, 64,32,1>, 64,32,1, 1};
        if (s != 1 && fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 8,8,  64,16,1>, 64,16,1, 1};
        if (s != 1 && fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 6,6,  64,16,1>, 64,16,1, 1};
        if (s != 1 && fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 4,4,  64,16,1>, 64,16,1, 1};
        if (s != 1 && fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 2,2,  64,16,1>, 64,16,1, 1};
        // channels_last
        if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 24,24, 32,32,1>, 32,32,1, 1};
        if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 16,16, 32,32,1>, 32,32,1, 1};
        if (s == 1 && fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 8,8,  16,16,8>, 16,16,8, 1};
        if (s == 1 && fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 6,6,  16,16,8>, 16,16,8, 1};
        if (s == 1 && fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 4,4,  16,16,8>, 16,16,8, 1};
        if (s == 1 && fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 2,2,  16,16,8>, 16,16,8, 1};
    }
    if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 24,1, 128,8,1>, 128,8,1, 1};
        if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 16,1, 128,8,1>, 128,8,1, 1};
        if (s != 1 && fx <= 8  && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 8,1,  128,8,1>, 128,8,1, 1};
        // channels_last
        if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 24,1, 128,1,16>, 128,1,16, 1};
        if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 16,1, 128,1,16>, 128,1,16, 1};
        if (s == 1 && fx <= 8  && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 8,1,  128,1,16>, 128,1,16, 1};
    }
    if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 1,24, 32,32,1>, 32,32,1, 1};
        if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 1,16, 32,32,1>, 32,32,1, 1};
        if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 1,8,  32,32,1>, 32,32,1, 1};
        // channels_last
        if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 1,24, 1,128,16>, 1,128,16, 1};
        if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 1,16, 1,128,16>, 1,128,16, 1};
        if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 1,8,  1,128,16>, 1,128,16, 1};
    }

    // 2x downsampling.
    if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2)
    {
        // contiguous
        if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 24,24, 32,16,1>, 32,16,1, 1};
        if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 16,16, 32,16,1>, 32,16,1, 1};
        if (s != 1 && fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 8,8,  32,8,1>, 32,8,1, 1};
        if (s != 1 && fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 6,6,  32,8,1>, 32,8,1, 1};
        if (s != 1 && fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 4,4,  32,8,1>, 32,8,1, 1};
        if (s != 1 && fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 2,2,  32,8,1>, 32,8,1, 1};
        // channels_last
        if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 24,24, 16,16,1>, 16,16,1, 1};
        if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 16,16, 16,16,1>, 16,16,1, 1};
        if (s == 1 && fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 8,8,  8,8,8>, 8,8,8, 1};
        if (s == 1 && fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 6,6,  8,8,8>, 8,8,8, 1};
        if (s == 1 && fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 4,4,  8,8,8>, 8,8,8, 1};
        if (s == 1 && fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 2,2,  8,8,8>, 8,8,8, 1};
    }
    if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 24,1, 64,8,1>, 64,8,1, 1};
        if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 16,1, 64,8,1>, 64,8,1, 1};
        if (s != 1 && fx <= 8  && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 8,1,  64,8,1>, 64,8,1, 1};
        // channels_last
        if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 24,1, 64,1,8>, 64,1,8, 1};
        if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 16,1, 64,1,8>, 64,1,8, 1};
        if (s == 1 && fx <= 8  && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 8,1,  64,1,8>, 64,1,8, 1};
    }
    if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2)
    {
        // contiguous
        if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 1,24, 32,16,1>, 32,16,1, 1};
        if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 1,16, 32,16,1>, 32,16,1, 1};
        if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 1,8,  32,16,1>, 32,16,1, 1};
        // channels_last
        if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 1,24, 1,64,8>, 1,64,8, 1};
        if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 1,16, 1,64,8>, 1,64,8, 1};
        if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 1,8,  1,64,8>, 1,64,8, 1};
    }

    // 4x upsampling.
    if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4,1,1, 48,48, 64,32,1>, 64,32,1, 1};
        if (s != 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4,1,1, 32,32, 64,32,1>, 64,32,1, 1};
        // channels_last
        if (s == 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4,1,1, 48,48, 32,32,1>, 32,32,1, 1};
        if (s == 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4,1,1, 32,32, 32,32,1>, 32,32,1, 1};
    }
    if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1,1,1, 48,1, 128,8,1>, 128,8,1, 1};
        if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1,1,1, 32,1, 128,8,1>, 128,8,1, 1};
        // channels_last
        if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1,1,1, 48,1, 128,1,16>, 128,1,16, 1};
        if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1,1,1, 32,1, 128,1,16>, 128,1,16, 1};
    }
    if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4,1,1, 1,48, 32,32,1>, 32,32,1, 1};
        if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4,1,1, 1,32, 32,32,1>, 32,32,1, 1};
        // channels_last
        if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4,1,1, 1,48, 1,128,16>, 1,128,16, 1};
        if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4,1,1, 1,32, 1,128,16>, 1,128,16, 1};
    }

    // 4x downsampling (inefficient).
    if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,4,1, 48,1, 32,8,1>, 32,8,1, 1};
        if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,4,1, 32,1, 32,8,1>, 32,8,1, 1};
        // channels_last
        if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,4,1, 48,1, 32,1,8>, 32,1,8, 1};
        if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,4,1, 32,1, 32,1,8>, 32,1,8, 1};
    }
    if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4)
    {
        // contiguous
        if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,4, 1,48, 32,8,1>, 32,8,1, 1};
        if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,4, 1,32, 32,8,1>, 32,8,1, 1};
        // channels_last
        if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,4, 1,48, 1,32,8>, 1,32,8, 1};
        if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,4, 1,32, 1,32,8>, 1,32,8, 1};
    }
    return spec;
}

//------------------------------------------------------------------------
// Template specializations.

template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double>   (const upfirdn2d_kernel_params& p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float>    (const upfirdn2d_kernel_params& p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);

//------------------------------------------------------------------------

================================================
FILE: torch_utils/ops/upfirdn2d.h
================================================
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <cuda_runtime.h>

//------------------------------------------------------------------------
// CUDA kernel parameters.
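// All tensor geometry is passed as plain ints: the int4 fields are packed in
// [width, height, channel, batch] order (x, y, z, w), i.e. the reverse of the
// NCHW layout used on the Python side, so that the innermost loop dimension
// comes first.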
struct upfirdn2d_kernel_params
{
    const void*  x;
    const float* f;
    void*        y;

    int2         up;
    int2         down;
    int2         pad0;
    int          flip;
    float        gain;

    int4         inSize;       // [width, height, channel, batch]
    int4         inStride;
    int2         filterSize;   // [width, height]
    int2         filterStride;
    int4         outSize;      // [width, height, channel, batch]
    int4         outStride;
    int          sizeMinor;
    int          sizeMajor;

    int          loopMinor;
    int          loopMajor;
    int          loopX;
    int          launchMinor;
    int          launchMajor;
};

//------------------------------------------------------------------------
// CUDA kernel specialization.

struct upfirdn2d_kernel_spec
{
    void* kernel;
    int   tileOutW;
    int   tileOutH;
    int   loopMinor;
    int   loopX;
};

//------------------------------------------------------------------------
// CUDA kernel selection.

template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);

//------------------------------------------------------------------------

================================================
FILE: torch_utils/ops/upfirdn2d.py
================================================
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Custom PyTorch ops for efficient resampling of 2D images."""

import os
import numpy as np
import torch

from .. import custom_ops
from .. import misc
from . import conv2d_gradfix

#----------------------------------------------------------------------------

_plugin = None

def _init():
    global _plugin
    if _plugin is None:
        _plugin = custom_ops.get_plugin(
            module_name='upfirdn2d_plugin',
            sources=['upfirdn2d.cpp', 'upfirdn2d.cu'],
            headers=['upfirdn2d.h'],
            source_dir=os.path.dirname(__file__),
            extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'],
        )
    return True

def _parse_scaling(scaling):
    if isinstance(scaling, int):
        scaling = [scaling, scaling]
    assert isinstance(scaling, (list, tuple))
    assert all(isinstance(x, int) for x in scaling)
    sx, sy = scaling
    assert sx >= 1 and sy >= 1
    return sx, sy

def _parse_padding(padding):
    if isinstance(padding, int):
        padding = [padding, padding]
    assert isinstance(padding, (list, tuple))
    assert all(isinstance(x, int) for x in padding)
    if len(padding) == 2:
        padx, pady = padding
        padding = [padx, padx, pady, pady]
    padx0, padx1, pady0, pady1 = padding
    return padx0, padx1, pady0, pady1

def _get_filter_size(f):
    if f is None:
        return 1, 1
    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
    fw = f.shape[-1]
    fh = f.shape[0]
    with misc.suppress_tracer_warnings():
        fw = int(fw)
        fh = int(fh)
    misc.assert_shape(f, [fh, fw][:f.ndim])
    assert fw >= 1 and fh >= 1
    return fw, fh

#----------------------------------------------------------------------------

def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
    r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.

    Args:
        f:           Torch tensor, numpy array, or python list of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), `[]` (impulse), or
                     `None` (identity).
        device:      Result device (default: cpu).
        normalize:   Normalize the filter so that it retains the magnitude
                     for constant input signal (DC)? (default: True).
        flip_filter: Flip the filter? (default: False).
gain: Overall scaling factor for signal magnitude (default: 1). separable: Return a separable filter? (default: select automatically). Returns: Float32 tensor of the shape `[filter_height, filter_width]` (non-separable) or `[filter_taps]` (separable). """ # Validate. if f is None: f = 1 f = torch.as_tensor(f, dtype=torch.float32) assert f.ndim in [0, 1, 2] assert f.numel() > 0 if f.ndim == 0: f = f[np.newaxis] # Separable? if separable is None: separable = (f.ndim == 1 and f.numel() >= 8) if f.ndim == 1 and not separable: f = f.ger(f) assert f.ndim == (1 if separable else 2) # Apply normalize, flip, gain, and device. if normalize: f /= f.sum() if flip_filter: f = f.flip(list(range(f.ndim))) f = f * (gain ** (f.ndim / 2)) f = f.to(device=device) return f #---------------------------------------------------------------------------- def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): r"""Pad, upsample, filter, and downsample a batch of 2D images. Performs the following sequence of operations for each channel: 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). 2. Pad the image with the specified number of zeros on each side (`padding`). Negative padding corresponds to cropping the image. 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it so that the footprint of all output pixels lies within the input image. 4. Downsample the image by keeping every Nth pixel (`down`). This sequence of operations bears close resemblance to scipy.signal.upfirdn(). The fused op is considerably more efficient than performing the same calculation using standard PyTorch ops. It supports gradients of arbitrary order. Args: x: Float32/float64/float16 input tensor of the shape `[batch_size, num_channels, in_height, in_width]`. f: Float32 FIR filter of the shape `[filter_height, filter_width]` (non-separable), `[filter_taps]` (separable), or `None` (identity). up: Integer upsampling factor. Can be a single int or a list/tuple `[x, y]` (default: 1). down: Integer downsampling factor. Can be a single int or a list/tuple `[x, y]` (default: 1). padding: Padding with respect to the upsampled image. Can be a single number or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` (default: 0). flip_filter: False = convolution, True = correlation (default: False). gain: Overall scaling factor for signal magnitude (default: 1). impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). Returns: Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. """ assert isinstance(x, torch.Tensor) assert impl in ['ref', 'cuda'] if impl == 'cuda' and x.device.type == 'cuda' and _init(): return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) #---------------------------------------------------------------------------- @misc.profiled_function def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. """ # Validate arguments. 
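    # Worked example of the output-size arithmetic used throughout this file
    # (illustrative only): with in_width=4, up=2, padding=(1, 1), a 4-tap filter,
    # and down=2, the upsampled-and-padded width is 4*2 + 1 + 1 = 10, convolving
    # shrinks it to 10 - 4 + 1 = 7, and downsampling keeps ceil(7/2) = 4 pixels --
    # matching outW = (in*up + pad0 + pad1 - fw + down) // down = (8+2-4+2)//2 = 4.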
assert isinstance(x, torch.Tensor) and x.ndim == 4 if f is None: f = torch.ones([1, 1], dtype=torch.float32, device=x.device) assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] assert f.dtype == torch.float32 and not f.requires_grad batch_size, num_channels, in_height, in_width = x.shape upx, upy = _parse_scaling(up) downx, downy = _parse_scaling(down) padx0, padx1, pady0, pady1 = _parse_padding(padding) # Check that upsampled buffer is not smaller than the filter. upW = in_width * upx + padx0 + padx1 upH = in_height * upy + pady0 + pady1 assert upW >= f.shape[-1] and upH >= f.shape[0] # Upsample by inserting zeros. x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) # Pad or crop. x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] # Setup filter. f = f * (gain ** (f.ndim / 2)) f = f.to(x.dtype) if not flip_filter: f = f.flip(list(range(f.ndim))) # Convolve with the filter. f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) if f.ndim == 4: x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) else: x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) # Downsample by throwing away pixels. x = x[:, :, ::downy, ::downx] return x #---------------------------------------------------------------------------- _upfirdn2d_cuda_cache = dict() def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): """Fast CUDA implementation of `upfirdn2d()` using custom ops. """ # Parse arguments. upx, upy = _parse_scaling(up) downx, downy = _parse_scaling(down) padx0, padx1, pady0, pady1 = _parse_padding(padding) # Lookup from cache. key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) if key in _upfirdn2d_cuda_cache: return _upfirdn2d_cuda_cache[key] # Forward op. class Upfirdn2dCuda(torch.autograd.Function): @staticmethod def forward(ctx, x, f): # pylint: disable=arguments-differ assert isinstance(x, torch.Tensor) and x.ndim == 4 if f is None: f = torch.ones([1, 1], dtype=torch.float32, device=x.device) if f.ndim == 1 and f.shape[0] == 1: f = f.square().unsqueeze(0) # Convert separable-1 into full-1x1. assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] y = x if f.ndim == 2: y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) else: y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, 1.0) y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, gain) ctx.save_for_backward(f) ctx.x_shape = x.shape return y @staticmethod def backward(ctx, dy): # pylint: disable=arguments-differ f, = ctx.saved_tensors _, _, ih, iw = ctx.x_shape _, _, oh, ow = dy.shape fw, fh = _get_filter_size(f) p = [ fw - padx0 - 1, iw * upx - ow * downx + padx0 - upx + 1, fh - pady0 - 1, ih * upy - oh * downy + pady0 - upy + 1, ] dx = None df = None if ctx.needs_input_grad[0]: dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) assert not ctx.needs_input_grad[1] return dx, df # Add to cache. 
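    # Caching note: a distinct Upfirdn2dCuda class is created per unique
    # (up, down, padding, flip_filter, gain) tuple because torch.autograd.Function
    # carries no per-call state -- the parameters are baked into the class via
    # closure, and reusing cached classes keeps repeated calls with the same
    # geometry cheap.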
_upfirdn2d_cuda_cache[key] = Upfirdn2dCuda return Upfirdn2dCuda #---------------------------------------------------------------------------- def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): r"""Filter a batch of 2D images using the given 2D FIR filter. By default, the result is padded so that its shape matches the input. User-specified padding is applied on top of that, with negative values indicating cropping. Pixels outside the image are assumed to be zero. Args: x: Float32/float64/float16 input tensor of the shape `[batch_size, num_channels, in_height, in_width]`. f: Float32 FIR filter of the shape `[filter_height, filter_width]` (non-separable), `[filter_taps]` (separable), or `None` (identity). padding: Padding with respect to the output. Can be a single number or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` (default: 0). flip_filter: False = convolution, True = correlation (default: False). gain: Overall scaling factor for signal magnitude (default: 1). impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). Returns: Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. """ padx0, padx1, pady0, pady1 = _parse_padding(padding) fw, fh = _get_filter_size(f) p = [ padx0 + fw // 2, padx1 + (fw - 1) // 2, pady0 + fh // 2, pady1 + (fh - 1) // 2, ] return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) #---------------------------------------------------------------------------- def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): r"""Upsample a batch of 2D images using the given 2D FIR filter. By default, the result is padded so that its shape is a multiple of the input. User-specified padding is applied on top of that, with negative values indicating cropping. Pixels outside the image are assumed to be zero. Args: x: Float32/float64/float16 input tensor of the shape `[batch_size, num_channels, in_height, in_width]`. f: Float32 FIR filter of the shape `[filter_height, filter_width]` (non-separable), `[filter_taps]` (separable), or `None` (identity). up: Integer upsampling factor. Can be a single int or a list/tuple `[x, y]` (default: 1). padding: Padding with respect to the output. Can be a single number or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` (default: 0). flip_filter: False = convolution, True = correlation (default: False). gain: Overall scaling factor for signal magnitude (default: 1). impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). Returns: Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. """ upx, upy = _parse_scaling(up) padx0, padx1, pady0, pady1 = _parse_padding(padding) fw, fh = _get_filter_size(f) p = [ padx0 + (fw + upx - 1) // 2, padx1 + (fw - upx) // 2, pady0 + (fh + upy - 1) // 2, pady1 + (fh - upy) // 2, ] return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl) #---------------------------------------------------------------------------- def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): r"""Downsample a batch of 2D images using the given 2D FIR filter. By default, the result is padded so that its shape is a fraction of the input. User-specified padding is applied on top of that, with negative values indicating cropping. Pixels outside the image are assumed to be zero. Args: x: Float32/float64/float16 input tensor of the shape `[batch_size, num_channels, in_height, in_width]`. 
f: Float32 FIR filter of the shape `[filter_height, filter_width]` (non-separable), `[filter_taps]` (separable), or `None` (identity). down: Integer downsampling factor. Can be a single int or a list/tuple `[x, y]` (default: 1). padding: Padding with respect to the input. Can be a single number or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` (default: 0). flip_filter: False = convolution, True = correlation (default: False). gain: Overall scaling factor for signal magnitude (default: 1). impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). Returns: Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. """ downx, downy = _parse_scaling(down) padx0, padx1, pady0, pady1 = _parse_padding(padding) fw, fh = _get_filter_size(f) p = [ padx0 + (fw - downx + 1) // 2, padx1 + (fw - downx) // 2, pady0 + (fh - downy + 1) // 2, pady1 + (fh - downy) // 2, ] return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) #---------------------------------------------------------------------------- ================================================ FILE: torch_utils/persistence.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Facilities for pickling Python code alongside other data. The pickled code is automatically imported into a separate Python module during unpickling. This way, any previously exported pickles will remain usable even if the original code is no longer available, or if the current version of the code is not consistent with what was originally pickled.""" import sys import pickle import io import inspect import copy import uuid import types import dnnlib #---------------------------------------------------------------------------- _version = 6 # internal version number _decorators = set() # {decorator_class, ...} _import_hooks = [] # [hook_function, ...] _module_to_src_dict = dict() # {module: src, ...} _src_to_module_dict = dict() # {src: module, ...} #---------------------------------------------------------------------------- def persistent_class(orig_class): r"""Class decorator that extends a given class to save its source code when pickled. Example: from torch_utils import persistence @persistence.persistent_class class MyNetwork(torch.nn.Module): def __init__(self, num_inputs, num_outputs): super().__init__() self.fc = MyLayer(num_inputs, num_outputs) ... @persistence.persistent_class class MyLayer(torch.nn.Module): ... When pickled, any instance of `MyNetwork` and `MyLayer` will save its source code alongside other internal state (e.g., parameters, buffers, and submodules). This way, any previously exported pickle will remain usable even if the class definitions have been modified or are no longer available. The decorator saves the source code of the entire Python module containing the decorated class. It does *not* save the source code of any imported modules. Thus, the imported modules must be available during unpickling, also including `torch_utils.persistence` itself. It is ok to call functions defined in the same module from the decorated class. 
However, if the decorated class depends on other classes defined in the same module, they must be decorated as well. This is illustrated in the above example in the case of `MyLayer`. It is also possible to employ the decorator just-in-time before calling the constructor. For example: cls = MyLayer if want_to_make_it_persistent: cls = persistence.persistent_class(cls) layer = cls(num_inputs, num_outputs) As an additional feature, the decorator also keeps track of the arguments that were used to construct each instance of the decorated class. The arguments can be queried via `obj.init_args` and `obj.init_kwargs`, and they are automatically pickled alongside other object state. A typical use case is to first unpickle a previous instance of a persistent class, and then upgrade it to use the latest version of the source code: with open('old_pickle.pkl', 'rb') as f: old_net = pickle.load(f) new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs) misc.copy_params_and_buffers(old_net, new_net, require_all=True) """ assert isinstance(orig_class, type) if is_persistent(orig_class): return orig_class assert orig_class.__module__ in sys.modules orig_module = sys.modules[orig_class.__module__] orig_module_src = _module_to_src(orig_module) class Decorator(orig_class): _orig_module_src = orig_module_src _orig_class_name = orig_class.__name__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._init_args = copy.deepcopy(args) self._init_kwargs = copy.deepcopy(kwargs) assert orig_class.__name__ in orig_module.__dict__ _check_pickleable(self.__reduce__()) @property def init_args(self): return copy.deepcopy(self._init_args) @property def init_kwargs(self): return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) def __reduce__(self): fields = list(super().__reduce__()) fields += [None] * max(3 - len(fields), 0) if fields[0] is not _reconstruct_persistent_obj: meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) fields[0] = _reconstruct_persistent_obj # reconstruct func fields[1] = (meta,) # reconstruct args fields[2] = None # state dict return tuple(fields) Decorator.__name__ = orig_class.__name__ _decorators.add(Decorator) return Decorator #---------------------------------------------------------------------------- def is_persistent(obj): r"""Test whether the given object or class is persistent, i.e., whether it will save its source code when pickled. """ try: if obj in _decorators: return True except TypeError: pass return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck #---------------------------------------------------------------------------- def import_hook(hook): r"""Register an import hook that is called whenever a persistent object is being unpickled. A typical use case is to patch the pickled source code to avoid errors and inconsistencies when the API of some imported module has changed. The hook should have the following signature: hook(meta) -> modified meta `meta` is an instance of `dnnlib.EasyDict` with the following fields: type: Type of the persistent object, e.g. `'class'`. version: Internal version number of `torch_utils.persistence`. module_src Original source code of the Python module. class_name: Class name in the original Python module. state: Internal state of the object. Example: @persistence.import_hook def wreck_my_network(meta): if meta.class_name == 'MyNetwork': print('MyNetwork is being imported. 
I will wreck it!') meta.module_src = meta.module_src.replace("True", "False") return meta """ assert callable(hook) _import_hooks.append(hook) #---------------------------------------------------------------------------- def _reconstruct_persistent_obj(meta): r"""Hook that is called internally by the `pickle` module to unpickle a persistent object. """ meta = dnnlib.EasyDict(meta) meta.state = dnnlib.EasyDict(meta.state) for hook in _import_hooks: meta = hook(meta) assert meta is not None assert meta.version == _version module = _src_to_module(meta.module_src) assert meta.type == 'class' orig_class = module.__dict__[meta.class_name] decorator_class = persistent_class(orig_class) obj = decorator_class.__new__(decorator_class) setstate = getattr(obj, '__setstate__', None) if callable(setstate): setstate(meta.state) # pylint: disable=not-callable else: obj.__dict__.update(meta.state) return obj #---------------------------------------------------------------------------- def _module_to_src(module): r"""Query the source code of a given Python module. """ src = _module_to_src_dict.get(module, None) if src is None: src = inspect.getsource(module) _module_to_src_dict[module] = src _src_to_module_dict[src] = module return src def _src_to_module(src): r"""Get or create a Python module for the given source code. """ module = _src_to_module_dict.get(src, None) if module is None: module_name = "_imported_module_" + uuid.uuid4().hex module = types.ModuleType(module_name) sys.modules[module_name] = module _module_to_src_dict[module] = src _src_to_module_dict[src] = module exec(src, module.__dict__) # pylint: disable=exec-used return module #---------------------------------------------------------------------------- def _check_pickleable(obj): r"""Check that the given object is pickleable, raising an exception if it is not. This function is expected to be considerably more efficient than actually pickling the object. """ def recurse(obj): if isinstance(obj, (list, tuple, set)): return [recurse(x) for x in obj] if isinstance(obj, dict): return [[recurse(x), recurse(y)] for x, y in obj.items()] if isinstance(obj, (str, int, float, bool, bytes, bytearray)): return None # Python primitive types are pickleable. if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: return None # NumPy arrays and PyTorch tensors are pickleable. if is_persistent(obj): return None # Persistent objects are pickleable, by virtue of the constructor check. return obj with io.BytesIO() as f: pickle.dump(recurse(obj), f) #---------------------------------------------------------------------------- ================================================ FILE: torch_utils/training_stats.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Facilities for reporting and collecting training statistics across multiple processes and devices. The interface is designed to minimize synchronization overhead as well as the amount of boilerplate in user code.""" import re import numpy as np import torch import dnnlib from . 
import misc #---------------------------------------------------------------------------- _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. _counter_dtype = torch.float64 # Data type to use for the internal counters. _rank = 0 # Rank of the current process. _sync_device = None # Device to use for multiprocess communication. None = single-process. _sync_called = False # Has _sync() been called yet? _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor #---------------------------------------------------------------------------- def init_multiprocessing(rank, sync_device): r"""Initializes `torch_utils.training_stats` for collecting statistics across multiple processes. This function must be called after `torch.distributed.init_process_group()` and before `Collector.update()`. The call is not necessary if multi-process collection is not needed. Args: rank: Rank of the current process. sync_device: PyTorch device to use for inter-process communication, or None to disable multi-process collection. Typically `torch.device('cuda', rank)`. """ global _rank, _sync_device assert not _sync_called _rank = rank _sync_device = sync_device #---------------------------------------------------------------------------- @misc.profiled_function def report(name, value): r"""Broadcasts the given set of scalars to all interested instances of `Collector`, across device and process boundaries. This function is expected to be extremely cheap and can be safely called from anywhere in the training loop, loss function, or inside a `torch.nn.Module`. Warning: The current implementation expects the set of unique names to be consistent across processes. Please make sure that `report()` is called at least once for each unique name by each process, and in the same order. If a given process has no scalars to broadcast, it can do `report(name, [])` (empty list). Args: name: Arbitrary string specifying the name of the statistic. Averages are accumulated separately for each unique name. value: Arbitrary set of scalars. Can be a list, tuple, NumPy array, PyTorch tensor, or Python scalar. Returns: The same `value` that was passed in. """ if name not in _counters: _counters[name] = dict() elems = torch.as_tensor(value) if elems.numel() == 0: return value elems = elems.detach().flatten().to(_reduce_dtype) moments = torch.stack([ torch.ones_like(elems).sum(), elems.sum(), elems.square().sum(), ]) assert moments.ndim == 1 and moments.shape[0] == _num_moments moments = moments.to(_counter_dtype) device = moments.device if device not in _counters[name]: _counters[name][device] = torch.zeros_like(moments) _counters[name][device].add_(moments) return value #---------------------------------------------------------------------------- def report0(name, value): r"""Broadcasts the given set of scalars by the first process (`rank = 0`), but ignores any scalars provided by the other processes. See `report()` for further details. """ report(name, value if _rank == 0 else []) return value #---------------------------------------------------------------------------- class Collector: r"""Collects the scalars broadcasted by `report()` and `report0()` and computes their long-term averages (mean and standard deviation) over user-defined periods of time. 
The averages are first collected into internal counters that are not directly visible to the user. They are then copied to the user-visible state as a result of calling `update()` and can then be queried using `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the internal counters for the next round, so that the user-visible state effectively reflects averages collected between the last two calls to `update()`. Args: regex: Regular expression defining which statistics to collect. The default is to collect everything. keep_previous: Whether to retain the previous averages if no scalars were collected on a given round (default: True). """ def __init__(self, regex='.*', keep_previous=True): self._regex = re.compile(regex) self._keep_previous = keep_previous self._cumulative = dict() self._moments = dict() self.update() self._moments.clear() def names(self): r"""Returns the names of all statistics broadcasted so far that match the regular expression specified at construction time. """ return [name for name in _counters if self._regex.fullmatch(name)] def update(self): r"""Copies current values of the internal counters to the user-visible state and resets them for the next round. If `keep_previous=True` was specified at construction time, the operation is skipped for statistics that have received no scalars since the last update, retaining their previous averages. This method performs a number of GPU-to-CPU transfers and one `torch.distributed.all_reduce()`. It is intended to be called periodically in the main training loop, typically once every N training steps. """ if not self._keep_previous: self._moments.clear() for name, cumulative in _sync(self.names()): if name not in self._cumulative: self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) delta = cumulative - self._cumulative[name] self._cumulative[name].copy_(cumulative) if float(delta[0]) != 0: self._moments[name] = delta def _get_delta(self, name): r"""Returns the raw moments that were accumulated for the given statistic between the last two calls to `update()`, or zero if no scalars were collected. """ assert self._regex.fullmatch(name) if name not in self._moments: self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) return self._moments[name] def num(self, name): r"""Returns the number of scalars that were accumulated for the given statistic between the last two calls to `update()`, or zero if no scalars were collected. """ delta = self._get_delta(name) return int(delta[0]) def mean(self, name): r"""Returns the mean of the scalars that were accumulated for the given statistic between the last two calls to `update()`, or NaN if no scalars were collected. """ delta = self._get_delta(name) if int(delta[0]) == 0: return float('nan') return float(delta[1] / delta[0]) def std(self, name): r"""Returns the standard deviation of the scalars that were accumulated for the given statistic between the last two calls to `update()`, or NaN if no scalars were collected. """ delta = self._get_delta(name) if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): return float('nan') if int(delta[0]) == 1: return float(0) mean = float(delta[1] / delta[0]) raw_var = float(delta[2] / delta[0]) return np.sqrt(max(raw_var - np.square(mean), 0)) def as_dict(self): r"""Returns the averages accumulated between the last two calls to `update()` as an `dnnlib.EasyDict`. The contents are as follows: dnnlib.EasyDict( NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), ... 
) """ stats = dnnlib.EasyDict() for name in self.names(): stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) return stats def __getitem__(self, name): r"""Convenience getter. `collector[name]` is a synonym for `collector.mean(name)`. """ return self.mean(name) #---------------------------------------------------------------------------- def _sync(names): r"""Synchronize the global cumulative counters across devices and processes. Called internally by `Collector.update()`. """ if len(names) == 0: return [] global _sync_called _sync_called = True # Collect deltas within current rank. deltas = [] device = _sync_device if _sync_device is not None else torch.device('cpu') for name in names: delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) for counter in _counters[name].values(): delta.add_(counter.to(device)) counter.copy_(torch.zeros_like(counter)) deltas.append(delta) deltas = torch.stack(deltas) # Sum deltas across ranks. if _sync_device is not None: torch.distributed.all_reduce(deltas) # Update cumulative values. deltas = deltas.cpu() for idx, name in enumerate(names): if name not in _cumulative: _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) _cumulative[name].add_(deltas[idx]) # Return name-value pairs. return [(name, _cumulative[name]) for name in names] #---------------------------------------------------------------------------- ================================================ FILE: training/__init__.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. # empty ================================================ FILE: training/augment.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Augmentation pipeline from the paper "Training Generative Adversarial Networks with Limited Data". Matches the original implementation by Karras et al. at https://github.com/NVlabs/stylegan2-ada/blob/main/training/augment.py""" import numpy as np import scipy.signal import torch from torch_utils import persistence from torch_utils import misc from torch_utils.ops import upfirdn2d from torch_utils.ops import grid_sample_gradfix from torch_utils.ops import conv2d_gradfix #---------------------------------------------------------------------------- # Coefficients of various wavelet decomposition low-pass filters. 
wavelets = { 'haar': [0.7071067811865476, 0.7071067811865476], 'db1': [0.7071067811865476, 0.7071067811865476], 'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025], 'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569], 'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523], 'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125], 'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017], 'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236], 'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161], 'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025], 'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569], 'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427], 'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728], 'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148], 'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255], 'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609], } #---------------------------------------------------------------------------- # Helpers for constructing transformation matrices. 
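# These helpers build homogeneous transforms that can be chained with `@`.
# For example, translate2d(tx, ty) maps a column vector [x, y, 1]^T to
# [x + tx, y + ty, 1]^T, and scale2d(sx, sy) maps it to [sx*x, sy*y, 1]^T;
# batched tensor arguments broadcast so that each sample gets its own matrix.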
def matrix(*rows, device=None): assert all(len(row) == len(rows[0]) for row in rows) elems = [x for row in rows for x in row] ref = [x for x in elems if isinstance(x, torch.Tensor)] if len(ref) == 0: return misc.constant(np.asarray(rows), device=device) assert device is None or device == ref[0].device elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems] return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1)) def translate2d(tx, ty, **kwargs): return matrix( [1, 0, tx], [0, 1, ty], [0, 0, 1], **kwargs) def translate3d(tx, ty, tz, **kwargs): return matrix( [1, 0, 0, tx], [0, 1, 0, ty], [0, 0, 1, tz], [0, 0, 0, 1], **kwargs) def scale2d(sx, sy, **kwargs): return matrix( [sx, 0, 0], [0, sy, 0], [0, 0, 1], **kwargs) def scale3d(sx, sy, sz, **kwargs): return matrix( [sx, 0, 0, 0], [0, sy, 0, 0], [0, 0, sz, 0], [0, 0, 0, 1], **kwargs) def rotate2d(theta, **kwargs): return matrix( [torch.cos(theta), torch.sin(-theta), 0], [torch.sin(theta), torch.cos(theta), 0], [0, 0, 1], **kwargs) def rotate3d(v, theta, **kwargs): vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2] s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c return matrix( [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0], [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0], [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0], [0, 0, 0, 1], **kwargs) def translate2d_inv(tx, ty, **kwargs): return translate2d(-tx, -ty, **kwargs) def scale2d_inv(sx, sy, **kwargs): return scale2d(1 / sx, 1 / sy, **kwargs) def rotate2d_inv(theta, **kwargs): return rotate2d(-theta, **kwargs) #---------------------------------------------------------------------------- # Versatile image augmentation pipeline from the paper # "Training Generative Adversarial Networks with Limited Data". # # All augmentations are disabled by default; individual augmentations can # be enabled by setting their probability multipliers to 1. @persistence.persistent_class class AugmentPipe(torch.nn.Module): def __init__(self, xflip=0, rotate90=0, xint=0, xint_max=0.125, scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125, brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, hue_max=1, saturation_std=1, imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1, noise=0, cutout=0, noise_std=0.1, cutout_size=0.5, ): super().__init__() self.register_buffer('p', torch.ones([])) # Overall multiplier for augmentation probability. # Pixel blitting. self.xflip = float(xflip) # Probability multiplier for x-flip. self.rotate90 = float(rotate90) # Probability multiplier for 90 degree rotations. self.xint = float(xint) # Probability multiplier for integer translation. self.xint_max = float(xint_max) # Range of integer translation, relative to image dimensions. # General geometric transformations. self.scale = float(scale) # Probability multiplier for isotropic scaling. self.rotate = float(rotate) # Probability multiplier for arbitrary rotation. self.aniso = float(aniso) # Probability multiplier for anisotropic scaling. self.xfrac = float(xfrac) # Probability multiplier for fractional translation. self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling. self.rotate_max = float(rotate_max) # Range of arbitrary rotation, 1 = full circle. self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling. 
self.xfrac_std = float(xfrac_std) # Standard deviation of frational translation, relative to image dimensions. # Color transformations. self.brightness = float(brightness) # Probability multiplier for brightness. self.contrast = float(contrast) # Probability multiplier for contrast. self.lumaflip = float(lumaflip) # Probability multiplier for luma flip. self.hue = float(hue) # Probability multiplier for hue rotation. self.saturation = float(saturation) # Probability multiplier for saturation. self.brightness_std = float(brightness_std) # Standard deviation of brightness. self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast. self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle. self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation. # Image-space filtering. self.imgfilter = float(imgfilter) # Probability multiplier for image-space filtering. self.imgfilter_bands = list(imgfilter_bands) # Probability multipliers for individual frequency bands. self.imgfilter_std = float(imgfilter_std) # Log2 standard deviation of image-space filter amplification. # Image-space corruptions. self.noise = float(noise) # Probability multiplier for additive RGB noise. self.cutout = float(cutout) # Probability multiplier for cutout. self.noise_std = float(noise_std) # Standard deviation of additive RGB noise. self.cutout_size = float(cutout_size) # Size of the cutout rectangle, relative to image dimensions. # Setup orthogonal lowpass filter for geometric augmentations. self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6'])) # Construct filter bank for image-space filtering. Hz_lo = np.asarray(wavelets['sym2']) # H(z) Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z) Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2 Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2 Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i) for i in range(1, Hz_fbank.shape[0]): Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1] Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2]) Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2 self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32)) def forward(self, images, debug_percentile=None): assert isinstance(images, torch.Tensor) and images.ndim == 4 batch_size, num_channels, height, width = images.shape device = images.device if debug_percentile is not None: debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device) # ------------------------------------- # Select parameters for pixel blitting. # ------------------------------------- # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in I_3 = torch.eye(3, device=device) G_inv = I_3 # Apply x-flip with probability (xflip * strength). if self.xflip > 0: i = torch.floor(torch.rand([batch_size], device=device) * 2) i = torch.where(torch.rand([batch_size], device=device) < self.xflip * self.p, i, torch.zeros_like(i)) if debug_percentile is not None: i = torch.full_like(i, torch.floor(debug_percentile * 2)) G_inv = G_inv @ scale2d_inv(1 - 2 * i, 1) # Apply 90 degree rotations with probability (rotate90 * strength). 
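        # Each augmentation below follows the same pattern as x-flip above:
        # sample a parameter for every image in the batch, then use torch.where()
        # with a per-image Bernoulli draw so that, with probability 1 - p, the
        # parameter collapses to the identity (zero translation, unit scale, etc.).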
if self.rotate90 > 0: i = torch.floor(torch.rand([batch_size], device=device) * 4) i = torch.where(torch.rand([batch_size], device=device) < self.rotate90 * self.p, i, torch.zeros_like(i)) if debug_percentile is not None: i = torch.full_like(i, torch.floor(debug_percentile * 4)) G_inv = G_inv @ rotate2d_inv(-np.pi / 2 * i) # Apply integer translation with probability (xint * strength). if self.xint > 0: t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t)) if debug_percentile is not None: t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max) G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height)) # -------------------------------------------------------- # Select parameters for general geometric transformations. # -------------------------------------------------------- # Apply isotropic scaling with probability (scale * strength). if self.scale > 0: s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std) s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s)) if debug_percentile is not None: s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std)) G_inv = G_inv @ scale2d_inv(s, s) # Apply pre-rotation with probability p_rot. p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p if self.rotate > 0: theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta)) if debug_percentile is not None: theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max) G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling. # Apply anisotropic scaling with probability (aniso * strength). if self.aniso > 0: s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std) s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s)) if debug_percentile is not None: s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std)) G_inv = G_inv @ scale2d_inv(s, 1 / s) # Apply post-rotation with probability p_rot. if self.rotate > 0: theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta)) if debug_percentile is not None: theta = torch.zeros_like(theta) G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling. # Apply fractional translation with probability (xfrac * strength). if self.xfrac > 0: t = torch.randn([batch_size, 2], device=device) * self.xfrac_std t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t)) if debug_percentile is not None: t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std) G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height) # ---------------------------------- # Execute geometric transformations. # ---------------------------------- # Execute if the transform is not identity. if G_inv is not I_3: # Calculate padding. 
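# Commentary (not from the original source): the padding step below pushes the four image corners through G_inv to bound how far transformed content can reach, then reflect-pads by that margin so rotation and scaling never sample undefined pixels.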
cx = (width - 1) / 2 cy = (height - 1) / 2 cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz] cp = G_inv @ cp.t() # [batch, xyz, idx] Hz_pad = self.Hz_geom.shape[0] // 4 margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx] margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1] margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device) margin = margin.max(misc.constant([0, 0] * 2, device=device)) margin = margin.min(misc.constant([width-1, height-1] * 2, device=device)) mx0, my0, mx1, my1 = margin.ceil().to(torch.int32) # Pad image and adjust origin. images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect') G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv # Upsample. images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2) G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device) G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device) # Execute transformation. shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2] G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device) grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False) images = grid_sample_gradfix.grid_sample(images, grid) # Downsample and crop. images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True) # -------------------------------------------- # Select parameters for color transformations. # -------------------------------------------- # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out I_4 = torch.eye(4, device=device) C = I_4 # Apply brightness with probability (brightness * strength). if self.brightness > 0: b = torch.randn([batch_size], device=device) * self.brightness_std b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b)) if debug_percentile is not None: b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std) C = translate3d(b, b, b) @ C # Apply contrast with probability (contrast * strength). if self.contrast > 0: c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std) c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c)) if debug_percentile is not None: c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std)) C = scale3d(c, c, c) @ C # Apply luma flip with probability (lumaflip * strength). v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis. if self.lumaflip > 0: i = torch.floor(torch.rand([batch_size, 1, 1], device=device) * 2) i = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.lumaflip * self.p, i, torch.zeros_like(i)) if debug_percentile is not None: i = torch.full_like(i, torch.floor(debug_percentile * 2)) C = (I_4 - 2 * v.ger(v) * i) @ C # Householder reflection. # Apply hue rotation with probability (hue * strength). 
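# Commentary (not from the original source): hue rotation below is a 3D rotation of RGB space around the luma axis v = [1, 1, 1] / sqrt(3) defined above, so gray values stay fixed; the luma flip and saturation branches reuse the same axis v for their reflection and projection matrices.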
if self.hue > 0 and num_channels > 1: theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta)) if debug_percentile is not None: theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max) C = rotate3d(v, theta) @ C # Rotate around v. # Apply saturation with probability (saturation * strength). if self.saturation > 0 and num_channels > 1: s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std) s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s)) if debug_percentile is not None: s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std)) C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C # ------------------------------ # Execute color transformations. # ------------------------------ # Execute if the transform is not identity. if C is not I_4: images = images.reshape([batch_size, num_channels, height * width]) if num_channels == 3: images = C[:, :3, :3] @ images + C[:, :3, 3:] elif num_channels == 1: C = C[:, :3, :].mean(dim=1, keepdims=True) images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:] else: raise ValueError('Image must be RGB (3 channels) or L (1 channel)') images = images.reshape([batch_size, num_channels, height, width]) # ---------------------- # Image-space filtering. # ---------------------- if self.imgfilter > 0: num_bands = self.Hz_fbank.shape[0] assert len(self.imgfilter_bands) == num_bands expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f). # Apply amplification for each band with probability (imgfilter * strength * band_strength). g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity). for i, band_strength in enumerate(self.imgfilter_bands): t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std) t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i)) if debug_percentile is not None: t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i) t = torch.ones([batch_size, num_bands], device=device) # Temporary gain vector. t[:, i] = t_i # Replace i'th element. t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power. g = g * t # Accumulate into global gain. # Construct combined amplification filter. Hz_prime = g @ self.Hz_fbank # [batch, tap] Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap] Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap] # Apply filter. p = self.Hz_fbank.shape[1] // 2 images = images.reshape([1, batch_size * num_channels, height, width]) images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect') images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels) images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels) images = images.reshape([batch_size, num_channels, height, width]) # ------------------------ # Image-space corruptions. # ------------------------ # Apply additive RGB noise with probability (noise * strength). 
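# Commentary (not from the original source): in the noise branch below, sigma is drawn per sample from a half-normal distribution |N(0, noise_std^2)| and then scales an independent standard-normal value at every pixel and channel of that sample.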
if self.noise > 0: sigma = torch.randn([batch_size, 1, 1, 1], device=device).abs() * self.noise_std sigma = torch.where(torch.rand([batch_size, 1, 1, 1], device=device) < self.noise * self.p, sigma, torch.zeros_like(sigma)) if debug_percentile is not None: sigma = torch.full_like(sigma, torch.erfinv(debug_percentile) * self.noise_std) images = images + torch.randn([batch_size, num_channels, height, width], device=device) * sigma # Apply cutout with probability (cutout * strength). if self.cutout > 0: size = torch.full([batch_size, 2, 1, 1, 1], self.cutout_size, device=device) size = torch.where(torch.rand([batch_size, 1, 1, 1, 1], device=device) < self.cutout * self.p, size, torch.zeros_like(size)) center = torch.rand([batch_size, 2, 1, 1, 1], device=device) if debug_percentile is not None: size = torch.full_like(size, self.cutout_size) center = torch.full_like(center, debug_percentile) coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1]) coord_y = torch.arange(height, device=device).reshape([1, 1, -1, 1]) mask_x = (((coord_x + 0.5) / width - center[:, 0]).abs() >= size[:, 0] / 2) mask_y = (((coord_y + 0.5) / height - center[:, 1]).abs() >= size[:, 1] / 2) mask = torch.logical_or(mask_x, mask_y).to(torch.float32) images = images * mask return images #---------------------------------------------------------------------------- ================================================ FILE: training/dataset.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Streaming images and labels from datasets created with dataset_tool.py.""" import os import numpy as np import zipfile import PIL.Image import json import torch import dnnlib try: import pyspng except ImportError: pyspng = None #---------------------------------------------------------------------------- class Dataset(torch.utils.data.Dataset): def __init__(self, name, # Name of the dataset. raw_shape, # Shape of the raw image data (NCHW). max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. use_labels = False, # Enable conditioning labels? False = label dimension is zero. xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size. random_seed = 0, # Random seed to use when applying max_size. ): self._name = name self._raw_shape = list(raw_shape) self._use_labels = use_labels self._raw_labels = None self._label_shape = None # Apply max_size. self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64) if (max_size is not None) and (self._raw_idx.size > max_size): np.random.RandomState(random_seed).shuffle(self._raw_idx) self._raw_idx = np.sort(self._raw_idx[:max_size]) # Apply xflip. 
self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) if xflip: self._raw_idx = np.tile(self._raw_idx, 2) self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) def _get_raw_labels(self): if self._raw_labels is None: self._raw_labels = self._load_raw_labels() if self._use_labels else None if self._raw_labels is None: self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) assert isinstance(self._raw_labels, np.ndarray) assert self._raw_labels.shape[0] == self._raw_shape[0] assert self._raw_labels.dtype in [np.float32, np.int64] if self._raw_labels.dtype == np.int64: assert self._raw_labels.ndim == 1 assert np.all(self._raw_labels >= 0) return self._raw_labels def close(self): # to be overridden by subclass pass def _load_raw_image(self, raw_idx): # to be overridden by subclass raise NotImplementedError def _load_raw_labels(self): # to be overridden by subclass raise NotImplementedError def __getstate__(self): return dict(self.__dict__, _raw_labels=None) def __del__(self): try: self.close() except: pass def __len__(self): return self._raw_idx.size def __getitem__(self, idx): image = self._load_raw_image(self._raw_idx[idx]) assert isinstance(image, np.ndarray) assert list(image.shape) == self.image_shape assert image.dtype == np.uint8 if self._xflip[idx]: assert image.ndim == 3 # CHW image = image[:, :, ::-1] return image.copy(), self.get_label(idx) def get_label(self, idx): label = self._get_raw_labels()[self._raw_idx[idx]] if label.dtype == np.int64: onehot = np.zeros(self.label_shape, dtype=np.float32) onehot[label] = 1 label = onehot return label.copy() def get_details(self, idx): d = dnnlib.EasyDict() d.raw_idx = int(self._raw_idx[idx]) d.xflip = (int(self._xflip[idx]) != 0) d.raw_label = self._get_raw_labels()[d.raw_idx].copy() return d @property def name(self): return self._name @property def image_shape(self): return list(self._raw_shape[1:]) @property def num_channels(self): assert len(self.image_shape) == 3 # CHW return self.image_shape[0] @property def resolution(self): assert len(self.image_shape) == 3 # CHW assert self.image_shape[1] == self.image_shape[2] return self.image_shape[1] @property def label_shape(self): if self._label_shape is None: raw_labels = self._get_raw_labels() if raw_labels.dtype == np.int64: self._label_shape = [int(np.max(raw_labels)) + 1] else: self._label_shape = raw_labels.shape[1:] return list(self._label_shape) @property def label_dim(self): assert len(self.label_shape) == 1 return self.label_shape[0] @property def has_labels(self): return any(x != 0 for x in self.label_shape) @property def has_onehot_labels(self): return self._get_raw_labels().dtype == np.int64 #---------------------------------------------------------------------------- class ImageFolderDataset(Dataset): def __init__(self, path, # Path to directory or zip. resolution = None, # Ensure specific resolution, None = highest available. **super_kwargs, # Additional arguments for the Dataset base class. 
): self._path = path self._zipfile = None if os.path.isdir(self._path): self._type = 'dir' self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files} elif self._file_ext(self._path) == '.zip': self._type = 'zip' self._all_fnames = set(self._get_zipfile().namelist()) else: raise IOError('Path must point to a directory or zip') PIL.Image.init() self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION) if len(self._image_fnames) == 0: raise IOError('No image files found in the specified path') name = os.path.splitext(os.path.basename(self._path))[0] raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape) if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution): raise IOError('Image files do not match the specified resolution') super().__init__(name=name, raw_shape=raw_shape, **super_kwargs) @staticmethod def _file_ext(fname): return os.path.splitext(fname)[1].lower() def _get_zipfile(self): assert self._type == 'zip' if self._zipfile is None: self._zipfile = zipfile.ZipFile(self._path) return self._zipfile def _open_file(self, fname): if self._type == 'dir': return open(os.path.join(self._path, fname), 'rb') if self._type == 'zip': return self._get_zipfile().open(fname, 'r') return None def close(self): try: if self._zipfile is not None: self._zipfile.close() finally: self._zipfile = None def __getstate__(self): return dict(super().__getstate__(), _zipfile=None) def _load_raw_image(self, raw_idx): fname = self._image_fnames[raw_idx] with self._open_file(fname) as f: if pyspng is not None and self._file_ext(fname) == '.png': image = pyspng.load(f.read()) else: image = np.array(PIL.Image.open(f)) if image.ndim == 2: image = image[:, :, np.newaxis] # HW => HWC image = image.transpose(2, 0, 1) # HWC => CHW return image def _load_raw_labels(self): fname = 'dataset.json' if fname not in self._all_fnames: return None with self._open_file(fname) as f: labels = json.load(f)['labels'] if labels is None: return None labels = dict(labels) labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames] labels = np.array(labels) labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) return labels #---------------------------------------------------------------------------- ================================================ FILE: training/loss.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
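#----------------------------------------------------------------------------
# Illustrative sketch, not part of the original files: the ImageFolderDataset
# defined in training/dataset.py above plugs directly into a standard PyTorch
# DataLoader. The zip path below is hypothetical.
#
#   dataset = ImageFolderDataset(path='datasets/example.zip', use_labels=False, xflip=True)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
#   images, labels = next(iter(loader))   # uint8 images of shape [4, C, H, W]
#----------------------------------------------------------------------------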
"""Loss functions.""" import numpy as np import torch from torch_utils import training_stats from torch_utils.ops import conv2d_gradfix from torch_utils.ops import upfirdn2d #---------------------------------------------------------------------------- class Loss: def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg): # to be overridden by subclass raise NotImplementedError() #---------------------------------------------------------------------------- class StyleGAN2Loss(Loss): def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10, style_mixing_prob=0, pl_weight=0, pl_batch_shrink=2, pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0, blur_fade_kimg=0): super().__init__() self.device = device self.G = G self.D = D self.augment_pipe = augment_pipe self.r1_gamma = r1_gamma self.style_mixing_prob = style_mixing_prob self.pl_weight = pl_weight self.pl_batch_shrink = pl_batch_shrink self.pl_decay = pl_decay self.pl_no_weight_grad = pl_no_weight_grad self.pl_mean = torch.zeros([], device=device) self.blur_init_sigma = blur_init_sigma self.blur_fade_kimg = blur_fade_kimg def run_G(self, z, c, update_emas=False): ws = self.G.mapping(z, c, update_emas=update_emas) if self.style_mixing_prob > 0: with torch.autograd.profiler.record_function('style_mixing'): cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1]) cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1])) ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:] img = self.G.synthesis(ws, update_emas=update_emas) return img, ws def run_D(self, img, c, blur_sigma=0, update_emas=False): blur_size = np.floor(blur_sigma * 3) if blur_size > 0: with torch.autograd.profiler.record_function('blur'): f = torch.arange(-blur_size, blur_size + 1, device=img.device).div(blur_sigma).square().neg().exp2() img = upfirdn2d.filter2d(img, f / f.sum()) if self.augment_pipe is not None: img = self.augment_pipe(img) logits = self.D(img, c, update_emas=update_emas) return logits def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg): assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth'] if self.pl_weight == 0: phase = {'Greg': 'none', 'Gboth': 'Gmain'}.get(phase, phase) if self.r1_gamma == 0: phase = {'Dreg': 'none', 'Dboth': 'Dmain'}.get(phase, phase) blur_sigma = max(1 - cur_nimg / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma if self.blur_fade_kimg > 0 else 0 # Gmain: Maximize logits for generated images. if phase in ['Gmain', 'Gboth']: with torch.autograd.profiler.record_function('Gmain_forward'): gen_img, _gen_ws = self.run_G(gen_z, gen_c) gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma) training_stats.report('Loss/scores/fake', gen_logits) training_stats.report('Loss/signs/fake', gen_logits.sign()) loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits)) training_stats.report('Loss/G/loss', loss_Gmain) with torch.autograd.profiler.record_function('Gmain_backward'): loss_Gmain.mean().mul(gain).backward() # Gpl: Apply path length regularization. 
if phase in ['Greg', 'Gboth']: with torch.autograd.profiler.record_function('Gpl_forward'): batch_size = gen_z.shape[0] // self.pl_batch_shrink gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size]) pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3]) with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients(self.pl_no_weight_grad): pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0] pl_lengths = pl_grads.square().sum(2).mean(1).sqrt() pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay) self.pl_mean.copy_(pl_mean.detach()) pl_penalty = (pl_lengths - pl_mean).square() training_stats.report('Loss/pl_penalty', pl_penalty) loss_Gpl = pl_penalty * self.pl_weight training_stats.report('Loss/G/reg', loss_Gpl) with torch.autograd.profiler.record_function('Gpl_backward'): loss_Gpl.mean().mul(gain).backward() # Dmain: Minimize logits for generated images. loss_Dgen = 0 if phase in ['Dmain', 'Dboth']: with torch.autograd.profiler.record_function('Dgen_forward'): gen_img, _gen_ws = self.run_G(gen_z, gen_c, update_emas=True) gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma, update_emas=True) training_stats.report('Loss/scores/fake', gen_logits) training_stats.report('Loss/signs/fake', gen_logits.sign()) loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits)) with torch.autograd.profiler.record_function('Dgen_backward'): loss_Dgen.mean().mul(gain).backward() # Dmain: Maximize logits for real images. # Dr1: Apply R1 regularization. if phase in ['Dmain', 'Dreg', 'Dboth']: name = 'Dreal' if phase == 'Dmain' else 'Dr1' if phase == 'Dreg' else 'Dreal_Dr1' with torch.autograd.profiler.record_function(name + '_forward'): real_img_tmp = real_img.detach().requires_grad_(phase in ['Dreg', 'Dboth']) real_logits = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma) training_stats.report('Loss/scores/real', real_logits) training_stats.report('Loss/signs/real', real_logits.sign()) loss_Dreal = 0 if phase in ['Dmain', 'Dboth']: loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits)) training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal) loss_Dr1 = 0 if phase in ['Dreg', 'Dboth']: with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients(): r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0] r1_penalty = r1_grads.square().sum([1,2,3]) loss_Dr1 = r1_penalty * (self.r1_gamma / 2) training_stats.report('Loss/r1_penalty', r1_penalty) training_stats.report('Loss/D/reg', loss_Dr1) with torch.autograd.profiler.record_function(name + '_backward'): (loss_Dreal + loss_Dr1).mean().mul(gain).backward() #---------------------------------------------------------------------------- ================================================ FILE: training/networks_stylegan2.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
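#----------------------------------------------------------------------------
# Illustrative sketch, not part of the original files: one simplified
# discriminator step with the StyleGAN2Loss defined in training/loss.py above.
# G, D, z, c, real_img, and opt_D are hypothetical stand-ins for the objects
# that the training loop provides.
#
#   loss = StyleGAN2Loss(device=device, G=G, D=D, r1_gamma=10)
#   loss.accumulate_gradients(phase='Dmain', real_img=real_img, real_c=c,
#                             gen_z=z, gen_c=c, gain=1, cur_nimg=0)
#   opt_D.step()   # gradients were accumulated by backward() inside the loss
#----------------------------------------------------------------------------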
"""Network architectures from the paper "Analyzing and Improving the Image Quality of StyleGAN". Matches the original implementation of configs E-F by Karras et al. at https://github.com/NVlabs/stylegan2/blob/master/training/networks_stylegan2.py""" import numpy as np import torch import torch.nn.functional as F from torch_utils import misc from torch_utils import persistence from torch_utils.ops import conv2d_resample from torch_utils.ops import upfirdn2d from torch_utils.ops import bias_act from torch_utils.ops import fma #---------------------------------------------------------------------------- @misc.profiled_function def normalize_2nd_moment(x, dim=1, eps=1e-8): return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() #---------------------------------------------------------------------------- @misc.profiled_function def modulated_conv2d( x, # Input tensor of shape [batch_size, in_channels, in_height, in_width]. weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width]. styles, # Modulation coefficients of shape [batch_size, in_channels]. noise = None, # Optional noise tensor to add to the output activations. up = 1, # Integer upsampling factor. down = 1, # Integer downsampling factor. padding = 0, # Padding with respect to the upsampled image. resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter(). demodulate = True, # Apply weight demodulation? flip_weight = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d). fused_modconv = True, # Perform modulation, convolution, and demodulation as a single fused operation? ): batch_size = x.shape[0] out_channels, in_channels, kh, kw = weight.shape misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk] misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW] misc.assert_shape(styles, [batch_size, in_channels]) # [NI] # Pre-normalize inputs to avoid FP16 overflow. if x.dtype == torch.float16 and demodulate: weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I # Calculate per-sample weights and demodulation coefficients. w = None dcoefs = None if demodulate or fused_modconv: w = weight.unsqueeze(0) # [NOIkk] w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk] if demodulate: dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO] if demodulate and fused_modconv: w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk] # Execute by scaling the activations before and after the convolution. if not fused_modconv: x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1) x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight) if demodulate and noise is not None: x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype)) elif demodulate: x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1) elif noise is not None: x = x.add_(noise.to(x.dtype)) return x # Execute as one fused op using grouped convolution. 
with misc.suppress_tracer_warnings(): # this value will be treated as a constant batch_size = int(batch_size) misc.assert_shape(x, [batch_size, in_channels, None, None]) x = x.reshape(1, -1, *x.shape[2:]) w = w.reshape(-1, in_channels, kh, kw) x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight) x = x.reshape(batch_size, -1, *x.shape[2:]) if noise is not None: x = x.add_(noise) return x #---------------------------------------------------------------------------- @persistence.persistent_class class FullyConnectedLayer(torch.nn.Module): def __init__(self, in_features, # Number of input features. out_features, # Number of output features. bias = True, # Apply additive bias before the activation function? activation = 'linear', # Activation function: 'relu', 'lrelu', etc. lr_multiplier = 1, # Learning rate multiplier. bias_init = 0, # Initial value for the additive bias. ): super().__init__() self.in_features = in_features self.out_features = out_features self.activation = activation self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None self.weight_gain = lr_multiplier / np.sqrt(in_features) self.bias_gain = lr_multiplier def forward(self, x): w = self.weight.to(x.dtype) * self.weight_gain b = self.bias if b is not None: b = b.to(x.dtype) if self.bias_gain != 1: b = b * self.bias_gain if self.activation == 'linear' and b is not None: x = torch.addmm(b.unsqueeze(0), x, w.t()) else: x = x.matmul(w.t()) x = bias_act.bias_act(x, b, act=self.activation) return x def extra_repr(self): return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}' #---------------------------------------------------------------------------- @persistence.persistent_class class Conv2dLayer(torch.nn.Module): def __init__(self, in_channels, # Number of input channels. out_channels, # Number of output channels. kernel_size, # Width and height of the convolution kernel. bias = True, # Apply additive bias before the activation function? activation = 'linear', # Activation function: 'relu', 'lrelu', etc. up = 1, # Integer upsampling factor. down = 1, # Integer downsampling factor. resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. conv_clamp = None, # Clamp the output to +-X, None = disable clamping. channels_last = False, # Expect the input to have memory_format=channels_last? trainable = True, # Update the weights of this layer during training? 
): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.activation = activation self.up = up self.down = down self.conv_clamp = conv_clamp self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) self.padding = kernel_size // 2 self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) self.act_gain = bias_act.activation_funcs[activation].def_gain memory_format = torch.channels_last if channels_last else torch.contiguous_format weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format) bias = torch.zeros([out_channels]) if bias else None if trainable: self.weight = torch.nn.Parameter(weight) self.bias = torch.nn.Parameter(bias) if bias is not None else None else: self.register_buffer('weight', weight) if bias is not None: self.register_buffer('bias', bias) else: self.bias = None def forward(self, x, gain=1): w = self.weight * self.weight_gain b = self.bias.to(x.dtype) if self.bias is not None else None flip_weight = (self.up == 1) # slightly faster x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight) act_gain = self.act_gain * gain act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp) return x def extra_repr(self): return ' '.join([ f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, activation={self.activation:s},', f'up={self.up}, down={self.down}']) #---------------------------------------------------------------------------- @persistence.persistent_class class MappingNetwork(torch.nn.Module): def __init__(self, z_dim, # Input latent (Z) dimensionality, 0 = no latent. c_dim, # Conditioning label (C) dimensionality, 0 = no label. w_dim, # Intermediate latent (W) dimensionality. num_ws, # Number of intermediate latents to output, None = do not broadcast. num_layers = 8, # Number of mapping layers. embed_features = None, # Label embedding dimensionality, None = same as w_dim. layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim. activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers. w_avg_beta = 0.998, # Decay for tracking the moving average of W during training, None = do not track. ): super().__init__() self.z_dim = z_dim self.c_dim = c_dim self.w_dim = w_dim self.num_ws = num_ws self.num_layers = num_layers self.w_avg_beta = w_avg_beta if embed_features is None: embed_features = w_dim if c_dim == 0: embed_features = 0 if layer_features is None: layer_features = w_dim features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim] if c_dim > 0: self.embed = FullyConnectedLayer(c_dim, embed_features) for idx in range(num_layers): in_features = features_list[idx] out_features = features_list[idx + 1] layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier) setattr(self, f'fc{idx}', layer) if num_ws is not None and w_avg_beta is not None: self.register_buffer('w_avg', torch.zeros([w_dim])) def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False): # Embed, normalize, and concat inputs. 
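# Commentary (not from the original source): z and the embedded label are each normalized to unit second moment before being concatenated, so neither input dominates the mapping MLP by sheer magnitude.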
x = None with torch.autograd.profiler.record_function('input'): if self.z_dim > 0: misc.assert_shape(z, [None, self.z_dim]) x = normalize_2nd_moment(z.to(torch.float32)) if self.c_dim > 0: misc.assert_shape(c, [None, self.c_dim]) y = normalize_2nd_moment(self.embed(c.to(torch.float32))) x = torch.cat([x, y], dim=1) if x is not None else y # Main layers. for idx in range(self.num_layers): layer = getattr(self, f'fc{idx}') x = layer(x) # Update moving average of W. if update_emas and self.w_avg_beta is not None: with torch.autograd.profiler.record_function('update_w_avg'): self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) # Broadcast. if self.num_ws is not None: with torch.autograd.profiler.record_function('broadcast'): x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) # Apply truncation. if truncation_psi != 1: with torch.autograd.profiler.record_function('truncate'): assert self.w_avg_beta is not None if self.num_ws is None or truncation_cutoff is None: x = self.w_avg.lerp(x, truncation_psi) else: x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) return x def extra_repr(self): return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}' #---------------------------------------------------------------------------- @persistence.persistent_class class SynthesisLayer(torch.nn.Module): def __init__(self, in_channels, # Number of input channels. out_channels, # Number of output channels. w_dim, # Intermediate latent (W) dimensionality. resolution, # Resolution of this layer. kernel_size = 3, # Convolution kernel size. up = 1, # Integer upsampling factor. use_noise = True, # Enable noise input? activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. channels_last = False, # Use channels_last format for the weights? 
): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.w_dim = w_dim self.resolution = resolution self.up = up self.use_noise = use_noise self.activation = activation self.conv_clamp = conv_clamp self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) self.padding = kernel_size // 2 self.act_gain = bias_act.activation_funcs[activation].def_gain self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) memory_format = torch.channels_last if channels_last else torch.contiguous_format self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)) if use_noise: self.register_buffer('noise_const', torch.randn([resolution, resolution])) self.noise_strength = torch.nn.Parameter(torch.zeros([])) self.bias = torch.nn.Parameter(torch.zeros([out_channels])) def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1): assert noise_mode in ['random', 'const', 'none'] in_resolution = self.resolution // self.up misc.assert_shape(x, [None, self.in_channels, in_resolution, in_resolution]) styles = self.affine(w) noise = None if self.use_noise and noise_mode == 'random': noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength if self.use_noise and noise_mode == 'const': noise = self.noise_const * self.noise_strength flip_weight = (self.up == 1) # slightly faster x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up, padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv) act_gain = self.act_gain * gain act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp) return x def extra_repr(self): return ' '.join([ f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d},', f'resolution={self.resolution:d}, up={self.up}, activation={self.activation:s}']) #---------------------------------------------------------------------------- @persistence.persistent_class class ToRGBLayer(torch.nn.Module): def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.w_dim = w_dim self.conv_clamp = conv_clamp self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) memory_format = torch.channels_last if channels_last else torch.contiguous_format self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)) self.bias = torch.nn.Parameter(torch.zeros([out_channels])) self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) def forward(self, x, w, fused_modconv=True): styles = self.affine(w) * self.weight_gain x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv) x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp) return x def extra_repr(self): return f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d}' #---------------------------------------------------------------------------- @persistence.persistent_class class SynthesisBlock(torch.nn.Module): def __init__(self, in_channels, # Number of input channels, 0 = first block. 
out_channels, # Number of output channels. w_dim, # Intermediate latent (W) dimensionality. resolution, # Resolution of this block. img_channels, # Number of output color channels. is_last, # Is this the last block? architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'. resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping. use_fp16 = False, # Use FP16 for this block? fp16_channels_last = False, # Use channels-last memory format with FP16? fused_modconv_default = True, # Default value of fused_modconv. 'inference_only' = True for inference, False for training. **layer_kwargs, # Arguments for SynthesisLayer. ): assert architecture in ['orig', 'skip', 'resnet'] super().__init__() self.in_channels = in_channels self.w_dim = w_dim self.resolution = resolution self.img_channels = img_channels self.is_last = is_last self.architecture = architecture self.use_fp16 = use_fp16 self.channels_last = (use_fp16 and fp16_channels_last) self.fused_modconv_default = fused_modconv_default self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) self.num_conv = 0 self.num_torgb = 0 if in_channels == 0: self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution])) if in_channels != 0: self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2, resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs) self.num_conv += 1 self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution, conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs) self.num_conv += 1 if is_last or architecture == 'skip': self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim, conv_clamp=conv_clamp, channels_last=self.channels_last) self.num_torgb += 1 if in_channels != 0 and architecture == 'resnet': self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2, resample_filter=resample_filter, channels_last=self.channels_last) def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs): _ = update_emas # unused misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim]) w_iter = iter(ws.unbind(dim=1)) if ws.device.type != 'cuda': force_fp32 = True dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format if fused_modconv is None: fused_modconv = self.fused_modconv_default if fused_modconv == 'inference_only': fused_modconv = (not self.training) # Input. if self.in_channels == 0: x = self.const.to(dtype=dtype, memory_format=memory_format) x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1]) else: misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2]) x = x.to(dtype=dtype, memory_format=memory_format) # Main layers. 
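# Commentary (not from the original source): in the main layers below, the first block (in_channels == 0) starts from the learned constant and applies only conv1; later blocks apply the upsampling conv0 followed by conv1, with an additional 1x1 skip branch when architecture == 'resnet'.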
if self.in_channels == 0: x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) elif self.architecture == 'resnet': y = self.skip(x, gain=np.sqrt(0.5)) x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs) x = y.add_(x) else: x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) # ToRGB. if img is not None: misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2]) img = upfirdn2d.upsample2d(img, self.resample_filter) if self.is_last or self.architecture == 'skip': y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv) y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format) img = img.add_(y) if img is not None else y assert x.dtype == dtype assert img is None or img.dtype == torch.float32 return x, img def extra_repr(self): return f'resolution={self.resolution:d}, architecture={self.architecture:s}' #---------------------------------------------------------------------------- @persistence.persistent_class class SynthesisNetwork(torch.nn.Module): def __init__(self, w_dim, # Intermediate latent (W) dimensionality. img_resolution, # Output image resolution. img_channels, # Number of color channels. channel_base = 32768, # Overall multiplier for the number of channels. channel_max = 512, # Maximum number of channels in any layer. num_fp16_res = 4, # Use FP16 for the N highest resolutions. **block_kwargs, # Arguments for SynthesisBlock. ): assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0 super().__init__() self.w_dim = w_dim self.img_resolution = img_resolution self.img_resolution_log2 = int(np.log2(img_resolution)) self.img_channels = img_channels self.num_fp16_res = num_fp16_res self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)] channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions} fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) self.num_ws = 0 for res in self.block_resolutions: in_channels = channels_dict[res // 2] if res > 4 else 0 out_channels = channels_dict[res] use_fp16 = (res >= fp16_resolution) is_last = (res == self.img_resolution) block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res, img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs) self.num_ws += block.num_conv if is_last: self.num_ws += block.num_torgb setattr(self, f'b{res}', block) def forward(self, ws, return_feature=False, **block_kwargs): block_ws = [] features = [] with torch.autograd.profiler.record_function('split_ws'): misc.assert_shape(ws, [None, self.num_ws, self.w_dim]) ws = ws.to(torch.float32) w_idx = 0 for res in self.block_resolutions: block = getattr(self, f'b{res}') block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb)) w_idx += block.num_conv x = img = None for res, cur_ws in zip(self.block_resolutions, block_ws): block = getattr(self, f'b{res}') x, img = block(x, img, cur_ws, **block_kwargs) features.append(x) if return_feature: return img, features else: return img def extra_repr(self): return ' '.join([ f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},', f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},', f'num_fp16_res={self.num_fp16_res:d}']) 
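#----------------------------------------------------------------------------
# Illustrative sketch, not part of the original file: how the mapping and
# synthesis networks above compose. All sizes and the helper name are
# hypothetical; the Generator wrapper below packages the same flow with
# extra options such as input_is_w and resizing.

def _example_mapping_plus_synthesis():
    G_syn = SynthesisNetwork(w_dim=512, img_resolution=256, img_channels=3)
    G_map = MappingNetwork(z_dim=512, c_dim=0, w_dim=512, num_ws=G_syn.num_ws)
    z = torch.randn([2, 512])
    ws = G_map(z, None, truncation_psi=0.7)                 # [2, num_ws, 512]
    img, features = G_syn(ws, return_feature=True, noise_mode='const')
    return img.shape, len(features)                         # [2, 3, 256, 256], one entry per block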
#---------------------------------------------------------------------------- @persistence.persistent_class class Generator(torch.nn.Module): def __init__(self, z_dim, # Input latent (Z) dimensionality. c_dim, # Conditioning label (C) dimensionality. w_dim, # Intermediate latent (W) dimensionality. img_resolution, # Output resolution. img_channels, # Number of output color channels. mapping_kwargs = {}, # Arguments for MappingNetwork. synthesis_kwargs = {}, # Arguments for SynthesisNetwork. resize=None, **synthesis_kwargs2, # Arguments for SynthesisNetwork. ): super().__init__() self.z_dim = z_dim self.c_dim = c_dim self.w_dim = w_dim self.img_resolution = img_resolution self.img_channels = img_channels if len(synthesis_kwargs) == 0: synthesis_kwargs = synthesis_kwargs2 self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs) self.num_ws = self.synthesis.num_ws self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs) self.resize = resize def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, input_is_w=False, return_feature=False, **synthesis_kwargs): if input_is_w: ws = z if ws.dim() == 2: ws = ws.unsqueeze(1).repeat([1, self.mapping.num_ws, 1]) else: ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas) img = self.synthesis(ws, update_emas=update_emas, return_feature=return_feature, **synthesis_kwargs) if return_feature: img, feature = img if self.resize is not None: img = imresize(img, [self.resize, self.resize]) if return_feature: return img, feature else: return img def imresize(image, size): dim = image.dim() if dim == 3: image = image.unsqueeze(1) b, _, h, w = image.shape if size[0] > h: image = F.interpolate(image, size, mode='bilinear') elif size[0] < h: image = F.interpolate(image, size, mode='area') if dim == 3: image = image.squeeze(1) return image #---------------------------------------------------------------------------- @persistence.persistent_class class DiscriminatorBlock(torch.nn.Module): def __init__(self, in_channels, # Number of input channels, 0 = first block. tmp_channels, # Number of intermediate channels. out_channels, # Number of output channels. resolution, # Resolution of this block. img_channels, # Number of input color channels. first_layer_idx, # Index of the first layer. architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. use_fp16 = False, # Use FP16 for this block? fp16_channels_last = False, # Use channels-last memory format with FP16? freeze_layers = 0, # Freeze-D: Number of layers to freeze. 
): assert in_channels in [0, tmp_channels] assert architecture in ['orig', 'skip', 'resnet'] super().__init__() self.in_channels = in_channels self.resolution = resolution self.img_channels = img_channels self.first_layer_idx = first_layer_idx self.architecture = architecture self.use_fp16 = use_fp16 self.channels_last = (use_fp16 and fp16_channels_last) self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) self.num_layers = 0 def trainable_gen(): while True: layer_idx = self.first_layer_idx + self.num_layers trainable = (layer_idx >= freeze_layers) self.num_layers += 1 yield trainable trainable_iter = trainable_gen() if in_channels == 0 or architecture == 'skip': self.fromrgb = Conv2dLayer(img_channels, tmp_channels, kernel_size=1, activation=activation, trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last) self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation, trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last) self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2, trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last) if architecture == 'resnet': self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2, trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last) def forward(self, x, img, force_fp32=False): if (x if x is not None else img).device.type != 'cuda': force_fp32 = True dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format # Input. if x is not None: misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) x = x.to(dtype=dtype, memory_format=memory_format) # FromRGB. if self.in_channels == 0 or self.architecture == 'skip': misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution]) img = img.to(dtype=dtype, memory_format=memory_format) y = self.fromrgb(img) x = x + y if x is not None else y img = upfirdn2d.downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None # Main layers. if self.architecture == 'resnet': y = self.skip(x, gain=np.sqrt(0.5)) x = self.conv0(x) x = self.conv1(x, gain=np.sqrt(0.5)) x = y.add_(x) else: x = self.conv0(x) x = self.conv1(x) assert x.dtype == dtype return x, img def extra_repr(self): return f'resolution={self.resolution:d}, architecture={self.architecture:s}' #---------------------------------------------------------------------------- @persistence.persistent_class class MinibatchStdLayer(torch.nn.Module): def __init__(self, group_size, num_channels=1): super().__init__() self.group_size = group_size self.num_channels = num_channels def forward(self, x): N, C, H, W = x.shape with misc.suppress_tracer_warnings(): # as_tensor results are registered as constants G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N F = self.num_channels c = C // F y = x.reshape(G, -1, F, c, H, W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c. y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group. y = y.square().mean(dim=0) # [nFcHW] Calc variance over group. y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group. 
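# Commentary (not from the original source): the remaining steps reduce this stddev map to one scalar per group and feature, then tile it back over the spatial dimensions as extra channels, giving the discriminator a direct signal about minibatch diversity.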
y = y.mean(dim=[2,3,4]) # [nF] Take average over channels and pixels. y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions. y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels. x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels. return x def extra_repr(self): return f'group_size={self.group_size}, num_channels={self.num_channels:d}' #---------------------------------------------------------------------------- @persistence.persistent_class class DiscriminatorEpilogue(torch.nn.Module): def __init__(self, in_channels, # Number of input channels. cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label. resolution, # Resolution of this block. img_channels, # Number of input color channels. architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, None = entire minibatch. mbstd_num_channels = 1, # Number of features for the minibatch standard deviation layer, 0 = disable. activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. ): assert architecture in ['orig', 'skip', 'resnet'] super().__init__() self.in_channels = in_channels self.cmap_dim = cmap_dim self.resolution = resolution self.img_channels = img_channels self.architecture = architecture if architecture == 'skip': self.fromrgb = Conv2dLayer(img_channels, in_channels, kernel_size=1, activation=activation) self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp) self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), in_channels, activation=activation) self.out = FullyConnectedLayer(in_channels, 1 if cmap_dim == 0 else cmap_dim) def forward(self, x, img, cmap, force_fp32=False): misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) # [NCHW] _ = force_fp32 # unused dtype = torch.float32 memory_format = torch.contiguous_format # FromRGB. x = x.to(dtype=dtype, memory_format=memory_format) if self.architecture == 'skip': misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution]) img = img.to(dtype=dtype, memory_format=memory_format) x = x + self.fromrgb(img) # Main layers. if self.mbstd is not None: x = self.mbstd(x) x = self.conv(x) x = self.fc(x.flatten(1)) x = self.out(x) # Conditioning. if self.cmap_dim > 0: misc.assert_shape(cmap, [None, self.cmap_dim]) x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) assert x.dtype == dtype return x def extra_repr(self): return f'resolution={self.resolution:d}, architecture={self.architecture:s}' #---------------------------------------------------------------------------- @persistence.persistent_class class Discriminator(torch.nn.Module): def __init__(self, c_dim, # Conditioning label (C) dimensionality. img_resolution, # Input resolution. img_channels, # Number of input color channels. architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. channel_base = 32768, # Overall multiplier for the number of channels. channel_max = 512, # Maximum number of channels in any layer. num_fp16_res = 4, # Use FP16 for the N highest resolutions. conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping. 
cmap_dim = None, # Dimensionality of mapped conditioning label, None = default. block_kwargs = {}, # Arguments for DiscriminatorBlock. mapping_kwargs = {}, # Arguments for MappingNetwork. epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue. ): super().__init__() self.c_dim = c_dim self.img_resolution = img_resolution self.img_resolution_log2 = int(np.log2(img_resolution)) self.img_channels = img_channels self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) if cmap_dim is None: cmap_dim = channels_dict[4] if c_dim == 0: cmap_dim = 0 common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp) cur_layer_idx = 0 for res in self.block_resolutions: in_channels = channels_dict[res] if res < img_resolution else 0 tmp_channels = channels_dict[res] out_channels = channels_dict[res // 2] use_fp16 = (res >= fp16_resolution) block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res, first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs) setattr(self, f'b{res}', block) cur_layer_idx += block.num_layers if c_dim > 0: self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs) self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs) def forward(self, img, c, update_emas=False, **block_kwargs): _ = update_emas # unused x = None for res in self.block_resolutions: block = getattr(self, f'b{res}') x, img = block(x, img, **block_kwargs) cmap = None if self.c_dim > 0: cmap = self.mapping(None, c) x = self.b4(x, img, cmap) return x def extra_repr(self): return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}' #---------------------------------------------------------------------------- ================================================ FILE: training/networks_stylegan3.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Generator architecture from the paper "Alias-Free Generative Adversarial Networks".""" import numpy as np import scipy.signal import scipy.optimize import torch import torch.nn.functional as F from torch_utils import misc from torch_utils import persistence from torch_utils.ops import conv2d_gradfix from torch_utils.ops import filtered_lrelu from torch_utils.ops import bias_act #---------------------------------------------------------------------------- @misc.profiled_function def modulated_conv2d( x, # Input tensor: [batch_size, in_channels, in_height, in_width] w, # Weight tensor: [out_channels, in_channels, kernel_height, kernel_width] s, # Style tensor: [batch_size, in_channels] demodulate = True, # Apply weight demodulation? 
padding = 0, # Padding: int or [padH, padW] input_gain = None, # Optional scale factors for the input channels: [], [in_channels], or [batch_size, in_channels] ): with misc.suppress_tracer_warnings(): # this value will be treated as a constant batch_size = int(x.shape[0]) out_channels, in_channels, kh, kw = w.shape misc.assert_shape(w, [out_channels, in_channels, kh, kw]) # [OIkk] misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW] misc.assert_shape(s, [batch_size, in_channels]) # [NI] # Pre-normalize inputs. if demodulate: w = w * w.square().mean([1,2,3], keepdim=True).rsqrt() s = s * s.square().mean().rsqrt() # Modulate weights. w = w.unsqueeze(0) # [NOIkk] w = w * s.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk] # Demodulate weights. if demodulate: dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO] w = w * dcoefs.unsqueeze(2).unsqueeze(3).unsqueeze(4) # [NOIkk] # Apply input scaling. if input_gain is not None: input_gain = input_gain.expand(batch_size, in_channels) # [NI] w = w * input_gain.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk] # Execute as one fused op using grouped convolution. x = x.reshape(1, -1, *x.shape[2:]) w = w.reshape(-1, in_channels, kh, kw) x = conv2d_gradfix.conv2d(input=x, weight=w.to(x.dtype), padding=padding, groups=batch_size) x = x.reshape(batch_size, -1, *x.shape[2:]) return x #---------------------------------------------------------------------------- @persistence.persistent_class class FullyConnectedLayer(torch.nn.Module): def __init__(self, in_features, # Number of input features. out_features, # Number of output features. activation = 'linear', # Activation function: 'relu', 'lrelu', etc. bias = True, # Apply additive bias before the activation function? lr_multiplier = 1, # Learning rate multiplier. weight_init = 1, # Initial standard deviation of the weight tensor. bias_init = 0, # Initial value of the additive bias. ): super().__init__() self.in_features = in_features self.out_features = out_features self.activation = activation self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) * (weight_init / lr_multiplier)) bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features]) self.bias = torch.nn.Parameter(torch.from_numpy(bias_init / lr_multiplier)) if bias else None self.weight_gain = lr_multiplier / np.sqrt(in_features) self.bias_gain = lr_multiplier def forward(self, x): w = self.weight.to(x.dtype) * self.weight_gain b = self.bias if b is not None: b = b.to(x.dtype) if self.bias_gain != 1: b = b * self.bias_gain if self.activation == 'linear' and b is not None: x = torch.addmm(b.unsqueeze(0), x, w.t()) else: x = x.matmul(w.t()) x = bias_act.bias_act(x, b, act=self.activation) return x def extra_repr(self): return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}' #---------------------------------------------------------------------------- @persistence.persistent_class class MappingNetwork(torch.nn.Module): def __init__(self, z_dim, # Input latent (Z) dimensionality. c_dim, # Conditioning label (C) dimensionality, 0 = no labels. w_dim, # Intermediate latent (W) dimensionality. num_ws, # Number of intermediate latents to output. num_layers = 2, # Number of mapping layers. lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers. w_avg_beta = 0.998, # Decay for tracking the moving average of W during training. 
): super().__init__() self.z_dim = z_dim self.c_dim = c_dim self.w_dim = w_dim self.num_ws = num_ws self.num_layers = num_layers self.w_avg_beta = w_avg_beta # Construct layers. self.embed = FullyConnectedLayer(self.c_dim, self.w_dim) if self.c_dim > 0 else None features = [self.z_dim + (self.w_dim if self.c_dim > 0 else 0)] + [self.w_dim] * self.num_layers for idx, in_features, out_features in zip(range(num_layers), features[:-1], features[1:]): layer = FullyConnectedLayer(in_features, out_features, activation='lrelu', lr_multiplier=lr_multiplier) setattr(self, f'fc{idx}', layer) self.register_buffer('w_avg', torch.zeros([w_dim])) def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False): misc.assert_shape(z, [None, self.z_dim]) if truncation_cutoff is None: truncation_cutoff = self.num_ws # Embed, normalize, and concatenate inputs. x = z.to(torch.float32) x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt() if self.c_dim > 0: misc.assert_shape(c, [None, self.c_dim]) y = self.embed(c.to(torch.float32)) y = y * (y.square().mean(1, keepdim=True) + 1e-8).rsqrt() x = torch.cat([x, y], dim=1) if x is not None else y # Execute layers. for idx in range(self.num_layers): x = getattr(self, f'fc{idx}')(x) # Update moving average of W. if update_emas: self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) # Broadcast and apply truncation. x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) if truncation_psi != 1: x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) return x def extra_repr(self): return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}' #---------------------------------------------------------------------------- @persistence.persistent_class class SynthesisInput(torch.nn.Module): def __init__(self, w_dim, # Intermediate latent (W) dimensionality. channels, # Number of output channels. size, # Output spatial size: int or [width, height]. sampling_rate, # Output sampling rate. bandwidth, # Output bandwidth. ): super().__init__() self.w_dim = w_dim self.channels = channels self.size = np.broadcast_to(np.asarray(size), [2]) self.sampling_rate = sampling_rate self.bandwidth = bandwidth # Draw random frequencies from uniform 2D disc. freqs = torch.randn([self.channels, 2]) radii = freqs.square().sum(dim=1, keepdim=True).sqrt() freqs /= radii * radii.square().exp().pow(0.25) freqs *= bandwidth phases = torch.rand([self.channels]) - 0.5 # Setup parameters and buffers. self.weight = torch.nn.Parameter(torch.randn([self.channels, self.channels])) self.affine = FullyConnectedLayer(w_dim, 4, weight_init=0, bias_init=[1,0,0,0]) self.register_buffer('transform', torch.eye(3, 3)) # User-specified inverse transform wrt. resulting image. self.register_buffer('freqs', freqs) self.register_buffer('phases', phases) def forward(self, w): # Introduce batch dimension. transforms = self.transform.unsqueeze(0) # [batch, row, col] freqs = self.freqs.unsqueeze(0) # [batch, channel, xy] phases = self.phases.unsqueeze(0) # [batch, channel] # Apply learned transformation. t = self.affine(w) # t = (r_c, r_s, t_x, t_y) t = t / t[:, :2].norm(dim=1, keepdim=True) # t' = (r'_c, r'_s, t'_x, t'_y) m_r = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse rotation wrt. resulting image. 
        m_r[:, 0, 0] = t[:, 0]  # r'_c
        m_r[:, 0, 1] = -t[:, 1] # r'_s
        m_r[:, 1, 0] = t[:, 1]  # r'_s
        m_r[:, 1, 1] = t[:, 0]  # r'_c
        m_t = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse translation wrt. resulting image.
        m_t[:, 0, 2] = -t[:, 2] # t'_x
        m_t[:, 1, 2] = -t[:, 3] # t'_y
        transforms = m_r @ m_t @ transforms # First rotate resulting image, then translate, and finally apply user-specified transform.

        # Transform frequencies.
        phases = phases + (freqs @ transforms[:, :2, 2:]).squeeze(2)
        freqs = freqs @ transforms[:, :2, :2]

        # Dampen out-of-band frequencies that may occur due to the user-specified transform.
        amplitudes = (1 - (freqs.norm(dim=2) - self.bandwidth) / (self.sampling_rate / 2 - self.bandwidth)).clamp(0, 1)

        # Construct sampling grid.
        theta = torch.eye(2, 3, device=w.device)
        theta[0, 0] = 0.5 * self.size[0] / self.sampling_rate
        theta[1, 1] = 0.5 * self.size[1] / self.sampling_rate
        grids = torch.nn.functional.affine_grid(theta.unsqueeze(0), [1, 1, self.size[1], self.size[0]], align_corners=False)

        # Compute Fourier features.
        x = (grids.unsqueeze(3) @ freqs.permute(0, 2, 1).unsqueeze(1).unsqueeze(2)).squeeze(3) # [batch, height, width, channel]
        x = x + phases.unsqueeze(1).unsqueeze(2)
        x = torch.sin(x * (np.pi * 2))
        x = x * amplitudes.unsqueeze(1).unsqueeze(2)

        # Apply trainable mapping.
        weight = self.weight / np.sqrt(self.channels)
        x = x @ weight.t()

        # Ensure correct shape.
        x = x.permute(0, 3, 1, 2) # [batch, channel, height, width]
        misc.assert_shape(x, [w.shape[0], self.channels, int(self.size[1]), int(self.size[0])])
        return x

    def extra_repr(self):
        return '\n'.join([
            f'w_dim={self.w_dim:d}, channels={self.channels:d}, size={list(self.size)},',
            f'sampling_rate={self.sampling_rate:g}, bandwidth={self.bandwidth:g}'])

#----------------------------------------------------------------------------

@persistence.persistent_class
class SynthesisLayer(torch.nn.Module):
    def __init__(self,
        w_dim,                          # Intermediate latent (W) dimensionality.
        is_torgb,                       # Is this the final ToRGB layer?
        is_critically_sampled,          # Does this layer use critical sampling?
        use_fp16,                       # Does this layer use FP16?

        # Input & output specifications.
        in_channels,                    # Number of input channels.
        out_channels,                   # Number of output channels.
        in_size,                        # Input spatial size: int or [width, height].
        out_size,                       # Output spatial size: int or [width, height].
        in_sampling_rate,               # Input sampling rate (s).
        out_sampling_rate,              # Output sampling rate (s).
        in_cutoff,                      # Input cutoff frequency (f_c).
        out_cutoff,                     # Output cutoff frequency (f_c).
        in_half_width,                  # Input transition band half-width (f_h).
        out_half_width,                 # Output transition band half-width (f_h).

        # Hyperparameters.
        conv_kernel         = 3,        # Convolution kernel size. Ignored for the final ToRGB layer.
        filter_size         = 6,        # Low-pass filter size relative to the lower resolution when up/downsampling.
        lrelu_upsampling    = 2,        # Relative sampling rate for leaky ReLU. Ignored for the final ToRGB layer.
        use_radial_filters  = False,    # Use radially symmetric downsampling filter? Ignored for critically sampled layers.
        conv_clamp          = 256,      # Clamp the output to [-X, +X], None = disable clamping.
        magnitude_ema_beta  = 0.999,    # Decay rate for the moving average of input magnitudes.
): super().__init__() self.w_dim = w_dim self.is_torgb = is_torgb self.is_critically_sampled = is_critically_sampled self.use_fp16 = use_fp16 self.in_channels = in_channels self.out_channels = out_channels self.in_size = np.broadcast_to(np.asarray(in_size), [2]) self.out_size = np.broadcast_to(np.asarray(out_size), [2]) self.in_sampling_rate = in_sampling_rate self.out_sampling_rate = out_sampling_rate self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) self.in_cutoff = in_cutoff self.out_cutoff = out_cutoff self.in_half_width = in_half_width self.out_half_width = out_half_width self.conv_kernel = 1 if is_torgb else conv_kernel self.conv_clamp = conv_clamp self.magnitude_ema_beta = magnitude_ema_beta # Setup parameters and buffers. self.affine = FullyConnectedLayer(self.w_dim, self.in_channels, bias_init=1) self.weight = torch.nn.Parameter(torch.randn([self.out_channels, self.in_channels, self.conv_kernel, self.conv_kernel])) self.bias = torch.nn.Parameter(torch.zeros([self.out_channels])) self.register_buffer('magnitude_ema', torch.ones([])) # Design upsampling filter. self.up_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) assert self.in_sampling_rate * self.up_factor == self.tmp_sampling_rate self.up_taps = filter_size * self.up_factor if self.up_factor > 1 and not self.is_torgb else 1 self.register_buffer('up_filter', self.design_lowpass_filter( numtaps=self.up_taps, cutoff=self.in_cutoff, width=self.in_half_width*2, fs=self.tmp_sampling_rate)) # Design downsampling filter. self.down_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) assert self.out_sampling_rate * self.down_factor == self.tmp_sampling_rate self.down_taps = filter_size * self.down_factor if self.down_factor > 1 and not self.is_torgb else 1 self.down_radial = use_radial_filters and not self.is_critically_sampled self.register_buffer('down_filter', self.design_lowpass_filter( numtaps=self.down_taps, cutoff=self.out_cutoff, width=self.out_half_width*2, fs=self.tmp_sampling_rate, radial=self.down_radial)) # Compute padding. pad_total = (self.out_size - 1) * self.down_factor + 1 # Desired output size before downsampling. pad_total -= (self.in_size + self.conv_kernel - 1) * self.up_factor # Input size after upsampling. pad_total += self.up_taps + self.down_taps - 2 # Size reduction caused by the filters. pad_lo = (pad_total + self.up_factor) // 2 # Shift sample locations according to the symmetric interpretation (Appendix C.3). pad_hi = pad_total - pad_lo self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] def forward(self, x, w, noise_mode='random', force_fp32=False, update_emas=False): assert noise_mode in ['random', 'const', 'none'] # unused misc.assert_shape(x, [None, self.in_channels, int(self.in_size[1]), int(self.in_size[0])]) misc.assert_shape(w, [x.shape[0], self.w_dim]) # Track input magnitude. if update_emas: with torch.autograd.profiler.record_function('update_magnitude_ema'): magnitude_cur = x.detach().to(torch.float32).square().mean() self.magnitude_ema.copy_(magnitude_cur.lerp(self.magnitude_ema, self.magnitude_ema_beta)) input_gain = self.magnitude_ema.rsqrt() # Execute affine layer. styles = self.affine(w) if self.is_torgb: weight_gain = 1 / np.sqrt(self.in_channels * (self.conv_kernel ** 2)) styles = styles * weight_gain # Execute modulated conv2d. 
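For readers new to StyleGAN-style modulation, the following is a minimal, unfused reference sketch of what modulated_conv2d() computes: the style scales the weight's input channels per sample, demodulation optionally rescales each output filter to unit norm, and an ordinary convolution follows. This is illustrative only; it omits the pre-normalization and the grouped-convolution fusion of the real implementation above.

import torch
import torch.nn.functional as F

def modulated_conv2d_reference(x, w, s, demodulate=True, padding=1):
    # x: [N, I, H, W], w: [O, I, kh, kw], s: [N, I]
    out = []
    for xi, si in zip(x, s):                       # per-sample loop for clarity
        wi = w * si.reshape(1, -1, 1, 1)           # modulate input channels by the style
        if demodulate:                             # rescale each output filter to unit norm
            d = (wi.square().sum(dim=[1, 2, 3]) + 1e-8).rsqrt()
            wi = wi * d.reshape(-1, 1, 1, 1)
        out.append(F.conv2d(xi.unsqueeze(0), wi, padding=padding))
    return torch.cat(out, dim=0)

x, w, s = torch.randn(2, 8, 16, 16), torch.randn(4, 8, 3, 3), torch.randn(2, 8)
print(modulated_conv2d_reference(x, w, s).shape)   # torch.Size([2, 4, 16, 16])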
dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32 x = modulated_conv2d(x=x.to(dtype), w=self.weight, s=styles, padding=self.conv_kernel-1, demodulate=(not self.is_torgb), input_gain=input_gain) # Execute bias, filtered leaky ReLU, and clamping. gain = 1 if self.is_torgb else np.sqrt(2) slope = 1 if self.is_torgb else 0.2 x = filtered_lrelu.filtered_lrelu(x=x, fu=self.up_filter, fd=self.down_filter, b=self.bias.to(x.dtype), up=self.up_factor, down=self.down_factor, padding=self.padding, gain=gain, slope=slope, clamp=self.conv_clamp) # Ensure correct shape and dtype. misc.assert_shape(x, [None, self.out_channels, int(self.out_size[1]), int(self.out_size[0])]) assert x.dtype == dtype return x @staticmethod def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False): assert numtaps >= 1 # Identity filter. if numtaps == 1: return None # Separable Kaiser low-pass filter. if not radial: f = scipy.signal.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) return torch.as_tensor(f, dtype=torch.float32) # Radially symmetric jinc-based filter. x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs r = np.hypot(*np.meshgrid(x, x)) f = scipy.special.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) beta = scipy.signal.kaiser_beta(scipy.signal.kaiser_atten(numtaps, width / (fs / 2))) w = np.kaiser(numtaps, beta) f *= np.outer(w, w) f /= np.sum(f) return torch.as_tensor(f, dtype=torch.float32) def extra_repr(self): return '\n'.join([ f'w_dim={self.w_dim:d}, is_torgb={self.is_torgb},', f'is_critically_sampled={self.is_critically_sampled}, use_fp16={self.use_fp16},', f'in_sampling_rate={self.in_sampling_rate:g}, out_sampling_rate={self.out_sampling_rate:g},', f'in_cutoff={self.in_cutoff:g}, out_cutoff={self.out_cutoff:g},', f'in_half_width={self.in_half_width:g}, out_half_width={self.out_half_width:g},', f'in_size={list(self.in_size)}, out_size={list(self.out_size)},', f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}']) #---------------------------------------------------------------------------- @persistence.persistent_class class SynthesisNetwork(torch.nn.Module): def __init__(self, w_dim, # Intermediate latent (W) dimensionality. img_resolution, # Output image resolution. img_channels, # Number of color channels. channel_base = 32768, # Overall multiplier for the number of channels. channel_max = 512, # Maximum number of channels in any layer. num_layers = 14, # Total number of layers, excluding Fourier features and ToRGB. num_critical = 2, # Number of critically sampled layers at the end. first_cutoff = 2, # Cutoff frequency of the first layer (f_{c,0}). first_stopband = 2**2.1, # Minimum stopband of the first layer (f_{t,0}). last_stopband_rel = 2**0.3, # Minimum stopband of the last layer, expressed relative to the cutoff. margin_size = 10, # Number of additional pixels outside the image. output_scale = 0.25, # Scale factor for the output image. num_fp16_res = 4, # Use FP16 for the N highest resolutions. **layer_kwargs, # Arguments for SynthesisLayer. ): super().__init__() self.w_dim = w_dim self.num_ws = num_layers + 2 self.img_resolution = img_resolution self.img_channels = img_channels self.num_layers = num_layers self.num_critical = num_critical self.margin_size = margin_size self.output_scale = output_scale self.num_fp16_res = num_fp16_res # Geometric progression of layer cutoffs and min. stopbands. 
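Evaluating the schedule constructed just below with the default hyperparameters (14 layers, 2 critical, 512x512 output) makes the geometric progression concrete:

import numpy as np

num_layers, num_critical, img_resolution = 14, 2, 512
first_cutoff, first_stopband, last_stopband_rel = 2, 2**2.1, 2**0.3
last_cutoff = img_resolution / 2                      # f_{c,N} = 256
last_stopband = last_cutoff * last_stopband_rel       # f_{t,N}
exponents = np.minimum(np.arange(num_layers + 1) / (num_layers - num_critical), 1)
cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents
print(np.round(cutoffs, 1))  # geometric ramp 2 -> 256; the exponent clips at 1,
                             # so the last num_critical+1 layers share f_c = 256.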
last_cutoff = self.img_resolution / 2 # f_{c,N} last_stopband = last_cutoff * last_stopband_rel # f_{t,N} exponents = np.minimum(np.arange(self.num_layers + 1) / (self.num_layers - self.num_critical), 1) cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents # f_c[i] stopbands = first_stopband * (last_stopband / first_stopband) ** exponents # f_t[i] # Compute remaining layer parameters. sampling_rates = np.exp2(np.ceil(np.log2(np.minimum(stopbands * 2, self.img_resolution)))) # s[i] half_widths = np.maximum(stopbands, sampling_rates / 2) - cutoffs # f_h[i] sizes = sampling_rates + self.margin_size * 2 sizes[-2:] = self.img_resolution channels = np.rint(np.minimum((channel_base / 2) / cutoffs, channel_max)) channels[-1] = self.img_channels # Construct layers. self.input = SynthesisInput( w_dim=self.w_dim, channels=int(channels[0]), size=int(sizes[0]), sampling_rate=sampling_rates[0], bandwidth=cutoffs[0]) self.layer_names = [] for idx in range(self.num_layers + 1): prev = max(idx - 1, 0) is_torgb = (idx == self.num_layers) is_critically_sampled = (idx >= self.num_layers - self.num_critical) use_fp16 = (sampling_rates[idx] * (2 ** self.num_fp16_res) > self.img_resolution) layer = SynthesisLayer( w_dim=self.w_dim, is_torgb=is_torgb, is_critically_sampled=is_critically_sampled, use_fp16=use_fp16, in_channels=int(channels[prev]), out_channels= int(channels[idx]), in_size=int(sizes[prev]), out_size=int(sizes[idx]), in_sampling_rate=int(sampling_rates[prev]), out_sampling_rate=int(sampling_rates[idx]), in_cutoff=cutoffs[prev], out_cutoff=cutoffs[idx], in_half_width=half_widths[prev], out_half_width=half_widths[idx], **layer_kwargs) name = f'L{idx}_{layer.out_size[0]}_{layer.out_channels}' setattr(self, name, layer) self.layer_names.append(name) def forward(self, ws, return_feature=False, **layer_kwargs): features = [] misc.assert_shape(ws, [None, self.num_ws, self.w_dim]) ws = ws.to(torch.float32).unbind(dim=1) # Execute layers. x = self.input(ws[0]) for name, w in zip(self.layer_names, ws[1:]): x = getattr(self, name)(x, w, **layer_kwargs) features.append(x) if self.output_scale != 1: x = x * self.output_scale # Ensure correct shape and dtype. misc.assert_shape(x, [None, self.img_channels, self.img_resolution, self.img_resolution]) x = x.to(torch.float32) if return_feature: return x, features else: return x def extra_repr(self): return '\n'.join([ f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},', f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},', f'num_layers={self.num_layers:d}, num_critical={self.num_critical:d},', f'margin_size={self.margin_size:d}, num_fp16_res={self.num_fp16_res:d}']) #---------------------------------------------------------------------------- @persistence.persistent_class class Generator(torch.nn.Module): def __init__(self, z_dim, # Input latent (Z) dimensionality. c_dim, # Conditioning label (C) dimensionality. w_dim, # Intermediate latent (W) dimensionality. img_resolution, # Output resolution. img_channels, # Number of output color channels. mapping_kwargs = {}, # Arguments for MappingNetwork. resize=None, **synthesis_kwargs, # Arguments for SynthesisNetwork. 
): super().__init__() self.z_dim = z_dim self.c_dim = c_dim self.w_dim = w_dim self.img_resolution = img_resolution self.img_channels = img_channels self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs) self.num_ws = self.synthesis.num_ws self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs) self.resize = resize def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, input_is_w=False, return_feature=False, **synthesis_kwargs): if input_is_w: ws = z if ws.dim() == 2: ws = ws.unsqueeze(1).repeat([1, self.mapping.num_ws, 1]) else: ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas) img = self.synthesis(ws, update_emas=update_emas, return_feature=return_feature, **synthesis_kwargs) if return_feature: img, feature = img if self.resize is not None: img = imresize(img, [self.resize, self.resize]) if return_feature: return img, feature else: return img #---------------------------------------------------------------------------- def imresize(image, size): dim = image.dim() if dim == 3: image = image.unsqueeze(1) b, _, h, w = image.shape if size[0] > h: image = F.interpolate(image, size, mode='bilinear') elif size[0] < h: image = F.interpolate(image, size, mode='area') if dim == 3: image = image.squeeze(1) return image ================================================ FILE: training/training_loop.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Main training loop.""" import os import time import copy import json import pickle import psutil import PIL.Image import numpy as np import torch import dnnlib from torch_utils import misc from torch_utils import training_stats from torch_utils.ops import conv2d_gradfix from torch_utils.ops import grid_sample_gradfix import legacy from metrics import metric_main #---------------------------------------------------------------------------- def setup_snapshot_image_grid(training_set, random_seed=0): rnd = np.random.RandomState(random_seed) gw = np.clip(7680 // training_set.image_shape[2], 7, 32) gh = np.clip(4320 // training_set.image_shape[1], 4, 32) # No labels => show random subset of training samples. if not training_set.has_labels: all_indices = list(range(len(training_set))) rnd.shuffle(all_indices) grid_indices = [all_indices[i % len(all_indices)] for i in range(gw * gh)] else: # Group training samples by label. label_groups = dict() # label => [idx, ...] for idx in range(len(training_set)): label = tuple(training_set.get_details(idx).raw_label.flat[::-1]) if label not in label_groups: label_groups[label] = [] label_groups[label].append(idx) # Reorder. label_order = sorted(label_groups.keys()) for label in label_order: rnd.shuffle(label_groups[label]) # Organize into grid. 
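The indices selected here are later rendered and written out by save_image_grid() (defined a little further below); the heart of that function is a reshape/transpose that tiles a batch of [N, C, H, W] images into a single montage. A toy check of the shape bookkeeping:

import numpy as np

gw, gh, C, H, W = 3, 2, 3, 4, 4                    # 2 rows x 3 columns of 4x4 RGB tiles
batch = np.zeros([gh * gw, C, H, W], dtype=np.uint8)
grid = batch.reshape(gh, gw, C, H, W).transpose(0, 3, 1, 4, 2).reshape(gh * H, gw * W, C)
print(grid.shape)                                  # (8, 12, 3): one (gh*H, gw*W, C) image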
grid_indices = [] for y in range(gh): label = label_order[y % len(label_order)] indices = label_groups[label] grid_indices += [indices[x % len(indices)] for x in range(gw)] label_groups[label] = [indices[(i + gw) % len(indices)] for i in range(len(indices))] # Load data. images, labels = zip(*[training_set[i] for i in grid_indices]) return (gw, gh), np.stack(images), np.stack(labels) #---------------------------------------------------------------------------- def save_image_grid(img, fname, drange, grid_size): lo, hi = drange img = np.asarray(img, dtype=np.float32) img = (img - lo) * (255 / (hi - lo)) img = np.rint(img).clip(0, 255).astype(np.uint8) gw, gh = grid_size _N, C, H, W = img.shape img = img.reshape([gh, gw, C, H, W]) img = img.transpose(0, 3, 1, 4, 2) img = img.reshape([gh * H, gw * W, C]) assert C in [1, 3] if C == 1: PIL.Image.fromarray(img[:, :, 0], 'L').save(fname) if C == 3: PIL.Image.fromarray(img, 'RGB').save(fname) #---------------------------------------------------------------------------- def training_loop( run_dir = '.', # Output directory. training_set_kwargs = {}, # Options for training set. data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader. G_kwargs = {}, # Options for generator network. D_kwargs = {}, # Options for discriminator network. G_opt_kwargs = {}, # Options for generator optimizer. D_opt_kwargs = {}, # Options for discriminator optimizer. augment_kwargs = None, # Options for augmentation pipeline. None = disable. loss_kwargs = {}, # Options for loss function. metrics = [], # Metrics to evaluate during training. random_seed = 0, # Global random seed. num_gpus = 1, # Number of GPUs participating in the training. rank = 0, # Rank of the current process in [0, num_gpus[. batch_size = 4, # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus. batch_gpu = 4, # Number of samples processed at a time by one GPU. ema_kimg = 10, # Half-life of the exponential moving average (EMA) of generator weights. ema_rampup = 0.05, # EMA ramp-up coefficient. None = no rampup. G_reg_interval = None, # How often to perform regularization for G? None = disable lazy regularization. D_reg_interval = 16, # How often to perform regularization for D? None = disable lazy regularization. augment_p = 0, # Initial value of augmentation probability. ada_target = None, # ADA target value. None = fixed p. ada_interval = 4, # How often to perform ADA adjustment? ada_kimg = 500, # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit. total_kimg = 25000, # Total length of the training, measured in thousands of real images. kimg_per_tick = 4, # Progress snapshot interval. image_snapshot_ticks = 50, # How often to save image snapshots? None = disable. network_snapshot_ticks = 50, # How often to save network snapshots? None = disable. resume_pkl = None, # Network pickle to resume training from. resume_kimg = 0, # First kimg to report when resuming training. cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark? abort_fn = None, # Callback function for determining whether to abort training. Must return consistent results across ranks. progress_fn = None, # Callback function for updating training progress. Called for all ranks. ): # Initialize. start_time = time.time() device = torch.device('cuda', rank) np.random.seed(random_seed * num_gpus + rank) torch.manual_seed(random_seed * num_gpus + rank) torch.backends.cudnn.benchmark = cudnn_benchmark # Improves training speed. 
torch.backends.cuda.matmul.allow_tf32 = False # Improves numerical accuracy. torch.backends.cudnn.allow_tf32 = False # Improves numerical accuracy. conv2d_gradfix.enabled = True # Improves training speed. grid_sample_gradfix.enabled = True # Avoids errors with the augmentation pipe. # Load training set. if rank == 0: print('Loading training set...') training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs) # subclass of training.dataset.Dataset training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed) training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs)) if rank == 0: print() print('Num images: ', len(training_set)) print('Image shape:', training_set.image_shape) print('Label shape:', training_set.label_shape) print() # Construct networks. if rank == 0: print('Constructing networks...') common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels) G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module G_ema = copy.deepcopy(G).eval() # Resume from existing pickle. if (resume_pkl is not None) and (rank == 0): print(f'Resuming from "{resume_pkl}"') with dnnlib.util.open_url(resume_pkl) as f: resume_data = legacy.load_network_pkl(f) for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]: misc.copy_params_and_buffers(resume_data[name], module, require_all=False) # Print network summary tables. if rank == 0: z = torch.empty([batch_gpu, G.z_dim], device=device) c = torch.empty([batch_gpu, G.c_dim], device=device) img = misc.print_module_summary(G, [z, c]) misc.print_module_summary(D, [img, c]) # Setup augmentation. if rank == 0: print('Setting up augmentation...') augment_pipe = None ada_stats = None if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None): augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module augment_pipe.p.copy_(torch.as_tensor(augment_p)) if ada_target is not None: ada_stats = training_stats.Collector(regex='Loss/signs/real') # Distribute across GPUs. if rank == 0: print(f'Distributing across {num_gpus} GPUs...') for module in [G, D, G_ema, augment_pipe]: if module is not None and num_gpus > 1: for param in misc.params_and_buffers(module): torch.distributed.broadcast(param, src=0) # Setup training phases. if rank == 0: print('Setting up training phases...') loss = dnnlib.util.construct_class_by_name(device=device, G=G, D=D, augment_pipe=augment_pipe, **loss_kwargs) # subclass of training.loss.Loss phases = [] for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval), ('D', D, D_opt_kwargs, D_reg_interval)]: if reg_interval is None: opt = dnnlib.util.construct_class_by_name(params=module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer phases += [dnnlib.EasyDict(name=name+'both', module=module, opt=opt, interval=1)] else: # Lazy regularization. 
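For reference, the rescaling in the lazy-regularization branch follows the StyleGAN2 scheme: when the regularization terms run only every reg_interval steps, the optimizer's learning rate and Adam betas are adjusted so the effective dynamics match running them every step. Numerically, with a typical discriminator configuration (lr and betas here are illustrative values, not taken from this file):

reg_interval = 16
mb_ratio = reg_interval / (reg_interval + 1)       # 16/17 ~ 0.941
lr, betas = 0.002, [0.0, 0.99]
print(lr * mb_ratio)                               # ~ 0.00188
print([beta ** mb_ratio for beta in betas])        # [0.0, ~0.9906]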
mb_ratio = reg_interval / (reg_interval + 1) opt_kwargs = dnnlib.EasyDict(opt_kwargs) opt_kwargs.lr = opt_kwargs.lr * mb_ratio opt_kwargs.betas = [beta ** mb_ratio for beta in opt_kwargs.betas] opt = dnnlib.util.construct_class_by_name(module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer phases += [dnnlib.EasyDict(name=name+'main', module=module, opt=opt, interval=1)] phases += [dnnlib.EasyDict(name=name+'reg', module=module, opt=opt, interval=reg_interval)] for phase in phases: phase.start_event = None phase.end_event = None if rank == 0: phase.start_event = torch.cuda.Event(enable_timing=True) phase.end_event = torch.cuda.Event(enable_timing=True) # Export sample images. grid_size = None grid_z = None grid_c = None if rank == 0: print('Exporting sample images...') grid_size, images, labels = setup_snapshot_image_grid(training_set=training_set) save_image_grid(images, os.path.join(run_dir, 'reals.png'), drange=[0,255], grid_size=grid_size) grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu) grid_c = torch.from_numpy(labels).to(device).split(batch_gpu) images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy() save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size) # Initialize logs. if rank == 0: print('Initializing logs...') stats_collector = training_stats.Collector(regex='.*') stats_metrics = dict() stats_jsonl = None stats_tfevents = None if rank == 0: stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt') try: import torch.utils.tensorboard as tensorboard stats_tfevents = tensorboard.SummaryWriter(run_dir) except ImportError as err: print('Skipping tfevents export:', err) # Train. if rank == 0: print(f'Training for {total_kimg} kimg...') print() cur_nimg = resume_kimg * 1000 cur_tick = 0 tick_start_nimg = cur_nimg tick_start_time = time.time() maintenance_time = tick_start_time - start_time batch_idx = 0 if progress_fn is not None: progress_fn(0, total_kimg) while True: # Fetch training data. with torch.autograd.profiler.record_function('data_fetch'): phase_real_img, phase_real_c = next(training_set_iterator) phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu) phase_real_c = phase_real_c.to(device).split(batch_gpu) all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device) all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size)] all_gen_c = [training_set.get_label(np.random.randint(len(training_set))) for _ in range(len(phases) * batch_size)] all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device) all_gen_c = [phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split(batch_size)] # Execute training phases. for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z, all_gen_c): if batch_idx % phase.interval != 0: continue if phase.start_event is not None: phase.start_event.record(torch.cuda.current_stream(device)) # Accumulate gradients. phase.opt.zero_grad(set_to_none=True) phase.module.requires_grad_(True) for real_img, real_c, gen_z, gen_c in zip(phase_real_img, phase_real_c, phase_gen_z, phase_gen_c): loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z, gen_c=gen_c, gain=phase.interval, cur_nimg=cur_nimg) phase.module.requires_grad_(False) # Update weights. 
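The weight-update step below gathers every gradient into one flat buffer so that, under multi-GPU training, a single all_reduce call averages all of them at once; the buffer is then split back and reshaped per parameter. A single-process toy round-trip of that flatten/split pattern:

import torch

params = [torch.randn(3, 4, requires_grad=True), torch.randn(5, requires_grad=True)]
for p in params:
    p.grad = torch.randn_like(p)                    # stand-in gradients
flat = torch.cat([p.grad.flatten() for p in params])
# (torch.distributed.all_reduce(flat) would go here when num_gpus > 1)
for p, g in zip(params, flat.split([p.numel() for p in params])):
    assert torch.equal(p.grad, g.reshape(p.shape))  # round-trip is lossless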
with torch.autograd.profiler.record_function(phase.name + '_opt'): params = [param for param in phase.module.parameters() if param.grad is not None] if len(params) > 0: flat = torch.cat([param.grad.flatten() for param in params]) if num_gpus > 1: torch.distributed.all_reduce(flat) flat /= num_gpus misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat) grads = flat.split([param.numel() for param in params]) for param, grad in zip(params, grads): param.grad = grad.reshape(param.shape) phase.opt.step() # Phase done. if phase.end_event is not None: phase.end_event.record(torch.cuda.current_stream(device)) # Update G_ema. with torch.autograd.profiler.record_function('Gema'): ema_nimg = ema_kimg * 1000 if ema_rampup is not None: ema_nimg = min(ema_nimg, cur_nimg * ema_rampup) ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8)) for p_ema, p in zip(G_ema.parameters(), G.parameters()): p_ema.copy_(p.lerp(p_ema, ema_beta)) for b_ema, b in zip(G_ema.buffers(), G.buffers()): b_ema.copy_(b) # Update state. cur_nimg += batch_size batch_idx += 1 # Execute ADA heuristic. if (ada_stats is not None) and (batch_idx % ada_interval == 0): ada_stats.update() adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * (batch_size * ada_interval) / (ada_kimg * 1000) augment_pipe.p.copy_((augment_pipe.p + adjust).max(misc.constant(0, device=device))) # Perform maintenance tasks once per tick. done = (cur_nimg >= total_kimg * 1000) if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000): continue # Print status line, accumulating the same information in training_stats. tick_end_time = time.time() fields = [] fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"] fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"] fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"] fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"] fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"] fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"] fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"] fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"] fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"] torch.cuda.reset_peak_memory_stats() fields += [f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"] training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60)) training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60)) if rank == 0: print(' '.join(fields)) # Check for abort. if (not done) and (abort_fn is not None) and abort_fn(): done = True if rank == 0: print() print('Aborting...') # Save image snapshot. 
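The G_ema update above chooses ema_beta so that the moving average of the generator weights has a half-life of ema_kimg thousand images, independent of batch size. A quick numeric check (batch_size 32 is an arbitrary example):

ema_kimg, batch_size = 10, 32
ema_beta = 0.5 ** (batch_size / (ema_kimg * 1000))
print(ema_beta)                                    # ~ 0.99778 per optimizer step
steps_per_half_life = ema_kimg * 1000 / batch_size
print(ema_beta ** steps_per_half_life)             # 0.5: old weights halve every 10 kimg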
if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0): images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy() save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'), drange=[-1,1], grid_size=grid_size) # Save network snapshot. snapshot_pkl = None snapshot_data = None if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0): snapshot_data = dict(G=G, D=D, G_ema=G_ema, augment_pipe=augment_pipe, training_set_kwargs=dict(training_set_kwargs)) for key, value in snapshot_data.items(): if isinstance(value, torch.nn.Module): value = copy.deepcopy(value).eval().requires_grad_(False) if num_gpus > 1: misc.check_ddp_consistency(value, ignore_regex=r'.*\.[^.]+_(avg|ema)') for param in misc.params_and_buffers(value): torch.distributed.broadcast(param, src=0) snapshot_data[key] = value.cpu() del value # conserve memory snapshot_pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl') if rank == 0: with open(snapshot_pkl, 'wb') as f: pickle.dump(snapshot_data, f) # Evaluate metrics. if (snapshot_data is not None) and (len(metrics) > 0): if rank == 0: print('Evaluating metrics...') for metric in metrics: result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'], dataset_kwargs=training_set_kwargs, num_gpus=num_gpus, rank=rank, device=device) if rank == 0: metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl) stats_metrics.update(result_dict.results) del snapshot_data # conserve memory # Collect statistics. for phase in phases: value = [] if (phase.start_event is not None) and (phase.end_event is not None): phase.end_event.synchronize() value = phase.start_event.elapsed_time(phase.end_event) training_stats.report0('Timing/' + phase.name, value) stats_collector.update() stats_dict = stats_collector.as_dict() # Update logs. timestamp = time.time() if stats_jsonl is not None: fields = dict(stats_dict, timestamp=timestamp) stats_jsonl.write(json.dumps(fields) + '\n') stats_jsonl.flush() if stats_tfevents is not None: global_step = int(cur_nimg / 1e3) walltime = timestamp - start_time for name, value in stats_dict.items(): stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime) for name, value in stats_metrics.items(): stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime) stats_tfevents.flush() if progress_fn is not None: progress_fn(cur_nimg // 1000, total_kimg) # Update state. cur_tick += 1 tick_start_nimg = cur_nimg tick_start_time = time.time() maintenance_time = tick_start_time - tick_end_time if done: break # Done. if rank == 0: print() print('Exiting...') #---------------------------------------------------------------------------- ================================================ FILE: visualizer_drag.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
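A typical invocation of this GUI, inferred from the click command at the bottom of the file (the checkpoint path is illustrative):

    python visualizer_drag.py checkpoints/stylegan2_lions_512_pytorch.pkl --capture-dir captures

With no PATH argument, the recent-pickles list is instead pre-populated with the NVIDIA StyleGAN2 model URLs listed in main().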
import click import os import multiprocessing import numpy as np import torch import imgui import dnnlib from gui_utils import imgui_window from gui_utils import imgui_utils from gui_utils import gl_utils from gui_utils import text_utils from viz import renderer from viz import pickle_widget from viz import latent_widget from viz import drag_widget from viz import capture_widget #---------------------------------------------------------------------------- class Visualizer(imgui_window.ImguiWindow): def __init__(self, capture_dir=None): super().__init__(title='DragGAN', window_width=3840, window_height=2160) # Internals. self._last_error_print = None self._async_renderer = AsyncRenderer() self._defer_rendering = 0 self._tex_img = None self._tex_obj = None self._mask_obj = None self._image_area = None self._status = dnnlib.EasyDict() # Widget interface. self.args = dnnlib.EasyDict() self.result = dnnlib.EasyDict() self.pane_w = 0 self.label_w = 0 self.button_w = 0 self.image_w = 0 self.image_h = 0 # Widgets. self.pickle_widget = pickle_widget.PickleWidget(self) self.latent_widget = latent_widget.LatentWidget(self) self.drag_widget = drag_widget.DragWidget(self) self.capture_widget = capture_widget.CaptureWidget(self) if capture_dir is not None: self.capture_widget.path = capture_dir # Initialize window. self.set_position(0, 0) self._adjust_font_size() self.skip_frame() # Layout may change after first frame. def close(self): super().close() if self._async_renderer is not None: self._async_renderer.close() self._async_renderer = None def add_recent_pickle(self, pkl, ignore_errors=False): self.pickle_widget.add_recent(pkl, ignore_errors=ignore_errors) def load_pickle(self, pkl, ignore_errors=False): self.pickle_widget.load(pkl, ignore_errors=ignore_errors) def print_error(self, error): error = str(error) if error != self._last_error_print: print('\n' + error + '\n') self._last_error_print = error def defer_rendering(self, num_frames=1): self._defer_rendering = max(self._defer_rendering, num_frames) def clear_result(self): self._async_renderer.clear_result() def set_async(self, is_async): if is_async != self._async_renderer.is_async: self._async_renderer.set_async(is_async) self.clear_result() if 'image' in self.result: self.result.message = 'Switching rendering process...' self.defer_rendering() def _adjust_font_size(self): old = self.font_size self.set_font_size(min(self.content_width / 120, self.content_height / 60)) if self.font_size != old: self.skip_frame() # Layout changed. 
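For orientation, each widget used by this window follows a small implicit protocol: it is constructed with a back-reference to the Visualizer, and draw_frame() invokes it once per frame with the collapsing header's expanded flag; widgets communicate with the renderer by writing into viz.args, which draw_frame() forwards via set_args(). A hypothetical minimal widget (names are illustrative, not from the repo):

class ToyWidget:
    def __init__(self, viz):
        self.viz = viz                        # back-reference to the Visualizer

    def __call__(self, show):
        if show:
            pass                              # imgui controls would be drawn here
        self.viz.args.toy_option = True       # picked up by the renderer via set_args()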
def check_update_mask(self, **args): update_mask = False if 'pkl' in self._status: if self._status.pkl != args['pkl']: update_mask = True self._status.pkl = args['pkl'] if 'w0_seed' in self._status: if self._status.w0_seed != args['w0_seed']: update_mask = True self._status.w0_seed = args['w0_seed'] return update_mask def capture_image_frame(self): self.capture_next_frame() captured_frame = self.pop_captured_frame() captured_image = None if captured_frame is not None: x1, y1, w, h = self._image_area captured_image = captured_frame[y1:y1+h, x1:x1+w, :] return captured_image def get_drag_info(self): seed = self.latent_widget.seed points = self.drag_widget.points targets = self.drag_widget.targets mask = self.drag_widget.mask w = self._async_renderer._renderer_obj.w return seed, points, targets, mask, w def draw_frame(self): self.begin_frame() self.args = dnnlib.EasyDict() self.pane_w = self.font_size * 18 self.button_w = self.font_size * 5 self.label_w = round(self.font_size * 4.5) # Detect mouse dragging in the result area. if self._image_area is not None: if not hasattr(self.drag_widget, 'width'): self.drag_widget.init_mask(self.image_w, self.image_h) clicked, down, img_x, img_y = imgui_utils.click_hidden_window( '##image_area', self._image_area[0], self._image_area[1], self._image_area[2], self._image_area[3], self.image_w, self.image_h) self.drag_widget.action(clicked, down, img_x, img_y) # Begin control pane. imgui.set_next_window_position(0, 0) imgui.set_next_window_size(self.pane_w, self.content_height) imgui.begin('##control_pane', closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE)) # Widgets. expanded, _visible = imgui_utils.collapsing_header('Network & latent', default=True) self.pickle_widget(expanded) self.latent_widget(expanded) expanded, _visible = imgui_utils.collapsing_header('Drag', default=True) self.drag_widget(expanded) expanded, _visible = imgui_utils.collapsing_header('Capture', default=True) self.capture_widget(expanded) # Render. if self.is_skipping_frames(): pass elif self._defer_rendering > 0: self._defer_rendering -= 1 elif self.args.pkl is not None: self._async_renderer.set_args(**self.args) result = self._async_renderer.get_result() if result is not None: self.result = result if 'stop' in self.result and self.result.stop: self.drag_widget.stop_drag() if 'points' in self.result: self.drag_widget.set_points(self.result.points) if 'init_net' in self.result: if self.result.init_net: self.drag_widget.reset_point() # Display. max_w = self.content_width - self.pane_w max_h = self.content_height pos = np.array([self.pane_w + max_w / 2, max_h / 2]) if 'image' in self.result: # Reset mask after loading a new pickle or changing seed. 
if self.check_update_mask(**self.args): h, w, _ = self.result.image.shape self.drag_widget.init_mask(w, h) if self._tex_img is not self.result.image: self._tex_img = self.result.image if self._tex_obj is None or not self._tex_obj.is_compatible(image=self._tex_img): self._tex_obj = gl_utils.Texture(image=self._tex_img, bilinear=False, mipmap=False) else: self._tex_obj.update(self._tex_img) self.image_h, self.image_w = self._tex_obj.height, self._tex_obj.width zoom = min(max_w / self._tex_obj.width, max_h / self._tex_obj.height) zoom = np.floor(zoom) if zoom >= 1 else zoom self._tex_obj.draw(pos=pos, zoom=zoom, align=0.5, rint=True) if self.drag_widget.show_mask and hasattr(self.drag_widget, 'mask'): mask = ((1-self.drag_widget.mask.unsqueeze(-1)) * 255).to(torch.uint8) if self._mask_obj is None or not self._mask_obj.is_compatible(image=self._tex_img): self._mask_obj = gl_utils.Texture(image=mask, bilinear=False, mipmap=False) else: self._mask_obj.update(mask) self._mask_obj.draw(pos=pos, zoom=zoom, align=0.5, rint=True, alpha=0.15) if self.drag_widget.mode in ['flexible', 'fixed']: posx, posy = imgui.get_mouse_pos() if posx >= self.pane_w: pos_c = np.array([posx, posy]) gl_utils.draw_circle(center=pos_c, radius=self.drag_widget.r_mask * zoom, alpha=0.5) rescale = self._tex_obj.width / 512 * zoom for point in self.drag_widget.targets: pos_x = self.pane_w + max_w / 2 + (point[1] - self.image_w//2) * zoom pos_y = max_h / 2 + (point[0] - self.image_h//2) * zoom gl_utils.draw_circle(center=np.array([pos_x, pos_y]), color=[0,0,1], radius=9 * rescale) for point in self.drag_widget.points: pos_x = self.pane_w + max_w / 2 + (point[1] - self.image_w//2) * zoom pos_y = max_h / 2 + (point[0] - self.image_h//2) * zoom gl_utils.draw_circle(center=np.array([pos_x, pos_y]), color=[1,0,0], radius=9 * rescale) for point, target in zip(self.drag_widget.points, self.drag_widget.targets): t_x = self.pane_w + max_w / 2 + (target[1] - self.image_w//2) * zoom t_y = max_h / 2 + (target[0] - self.image_h//2) * zoom p_x = self.pane_w + max_w / 2 + (point[1] - self.image_w//2) * zoom p_y = max_h / 2 + (point[0] - self.image_h//2) * zoom gl_utils.draw_arrow(p_x, p_y, t_x, t_y, l=8 * rescale, width = 3 * rescale) imshow_w = int(self._tex_obj.width * zoom) imshow_h = int(self._tex_obj.height * zoom) self._image_area = [int(self.pane_w + max_w / 2 - imshow_w / 2), int(max_h / 2 - imshow_h / 2), imshow_w, imshow_h] if 'error' in self.result: self.print_error(self.result.error) if 'message' not in self.result: self.result.message = str(self.result.error) if 'message' in self.result: tex = text_utils.get_texture(self.result.message, size=self.font_size, max_width=max_w, max_height=max_h, outline=2) tex.draw(pos=pos, align=0.5, rint=True, color=1) # End frame. 
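The handle and target circles above are positioned by mapping image-space points (stored row-first) into window coordinates: offset by the control-pane width, centered in the remaining viewport, and scaled by the zoom factor, which is floored to an integer when upscaling so pixels stay sharp. A standalone check of that mapping (all sizes are example values):

def to_screen(point, pane_w, max_w, max_h, image_w, image_h, zoom):
    y, x = point                                    # points are stored (row, col)
    screen_x = pane_w + max_w / 2 + (x - image_w // 2) * zoom
    screen_y = max_h / 2 + (y - image_h // 2) * zoom
    return screen_x, screen_y

# The image center lands at the center of the viewport right of the pane:
print(to_screen((256, 256), 300, 1540, 2160, 512, 512, 2.0))  # (1070.0, 1080.0)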
self._adjust_font_size() imgui.end() self.end_frame() #---------------------------------------------------------------------------- class AsyncRenderer: def __init__(self): self._closed = False self._is_async = False self._cur_args = None self._cur_result = None self._cur_stamp = 0 self._renderer_obj = None self._args_queue = None self._result_queue = None self._process = None def close(self): self._closed = True self._renderer_obj = None if self._process is not None: self._process.terminate() self._process = None self._args_queue = None self._result_queue = None @property def is_async(self): return self._is_async def set_async(self, is_async): self._is_async = is_async def set_args(self, **args): assert not self._closed args2 = args.copy() args_mask = args2.pop('mask') if self._cur_args: _cur_args = self._cur_args.copy() cur_args_mask = _cur_args.pop('mask') else: _cur_args = self._cur_args # if args != self._cur_args: if args2 != _cur_args: if self._is_async: self._set_args_async(**args) else: self._set_args_sync(**args) self._cur_args = args def _set_args_async(self, **args): if self._process is None: self._args_queue = multiprocessing.Queue() self._result_queue = multiprocessing.Queue() try: multiprocessing.set_start_method('spawn') except RuntimeError: pass self._process = multiprocessing.Process(target=self._process_fn, args=(self._args_queue, self._result_queue), daemon=True) self._process.start() self._args_queue.put([args, self._cur_stamp]) def _set_args_sync(self, **args): if self._renderer_obj is None: self._renderer_obj = renderer.Renderer() self._cur_result = self._renderer_obj.render(**args) def get_result(self): assert not self._closed if self._result_queue is not None: while self._result_queue.qsize() > 0: result, stamp = self._result_queue.get() if stamp == self._cur_stamp: self._cur_result = result return self._cur_result def clear_result(self): assert not self._closed self._cur_args = None self._cur_result = None self._cur_stamp += 1 @staticmethod def _process_fn(args_queue, result_queue): renderer_obj = renderer.Renderer() cur_args = None cur_stamp = None while True: args, stamp = args_queue.get() while args_queue.qsize() > 0: args, stamp = args_queue.get() if args != cur_args or stamp != cur_stamp: result = renderer_obj.render(**args) if 'error' in result: result.error = renderer.CapturedException(result.error) result_queue.put([result, stamp]) cur_args = args cur_stamp = stamp #---------------------------------------------------------------------------- @click.command() @click.argument('pkls', metavar='PATH', nargs=-1) @click.option('--capture-dir', help='Where to save screenshot captures', metavar='PATH', default=None) @click.option('--browse-dir', help='Specify model path for the \'Browse...\' button', metavar='PATH') def main( pkls, capture_dir, browse_dir ): """Interactive model visualizer. Optional PATH argument can be used specify which .pkl file to load. """ viz = Visualizer(capture_dir=capture_dir) if browse_dir is not None: viz.pickle_widget.search_dirs = [browse_dir] # List pickles. 
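Both _process_fn and get_result() in AsyncRenderer above drain their queues before acting, so a backlog of stale requests collapses to the most recent one and the renderer never wastes work on out-of-date arguments. The pattern in miniature:

import queue

q = queue.Queue()
for i in range(5):
    q.put(i)                    # five requests arrive while a render is in flight
args = q.get()
while q.qsize() > 0:            # drop everything but the newest
    args = q.get()
print(args)                     # 4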
if len(pkls) > 0: for pkl in pkls: viz.add_recent_pickle(pkl) viz.load_pickle(pkls[0]) else: pretrained = [ 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqcat-512x512.pkl', 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqdog-512x512.pkl', 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqv2-512x512.pkl', 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqwild-512x512.pkl', 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-brecahad-512x512.pkl', 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-celebahq-256x256.pkl', 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-cifar10-32x32.pkl', 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-1024x1024.pkl', 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-256x256.pkl', 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-512x512.pkl', 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhqu-1024x1024.pkl', 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhqu-256x256.pkl', 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-lsundog-256x256.pkl', 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-metfaces-1024x1024.pkl', 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-metfacesu-1024x1024.pkl' ] # Populate recent pickles list with pretrained model URLs. for url in pretrained: viz.add_recent_pickle(url) # Run. while not viz.should_close(): viz.draw_frame() viz.close() #---------------------------------------------------------------------------- if __name__ == "__main__": main() #---------------------------------------------------------------------------- ================================================ FILE: visualizer_drag_gradio.py ================================================ import os import os.path as osp from argparse import ArgumentParser from functools import partial import gradio as gr import numpy as np import torch from PIL import Image import dnnlib from gradio_utils import (ImageMask, draw_mask_on_image, draw_points_on_image, get_latest_points_pair, get_valid_mask, on_change_single_global_state) from viz.renderer import Renderer, add_watermark_np parser = ArgumentParser() parser.add_argument('--share', action='store_true',default='True') parser.add_argument('--cache-dir', type=str, default='./checkpoints') parser.add_argument( "--listen", action="store_true", help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests", ) args = parser.parse_args() cache_dir = args.cache_dir device = 'cuda' def reverse_point_pairs(points): new_points = [] for p in points: new_points.append([p[1], p[0]]) return new_points def clear_state(global_state, target=None): """Clear target history state from global_state If target is not defined, points and mask will be both removed. 1. set global_state['points'] as empty dict 2. set global_state['mask'] as full-one mask. 
""" if target is None: target = ['point', 'mask'] if not isinstance(target, list): target = [target] if 'point' in target: global_state['points'] = dict() print('Clear Points State!') if 'mask' in target: image_raw = global_state["images"]["image_raw"] global_state['mask'] = np.ones((image_raw.size[1], image_raw.size[0]), dtype=np.uint8) print('Clear mask State!') return global_state def init_images(global_state): """This function is called only ones with Gradio App is started. 0. pre-process global_state, unpack value from global_state of need 1. Re-init renderer 2. run `renderer._render_drag_impl` with `is_drag=False` to generate new image 3. Assign images to global state and re-generate mask """ if isinstance(global_state, gr.State): state = global_state.value else: state = global_state state['renderer'].init_network( state['generator_params'], # res valid_checkpoints_dict[state['pretrained_weight']], # pkl state['params']['seed'], # w0_seed, None, # w_load state['params']['latent_space'] == 'w+', # w_plus 'const', state['params']['trunc_psi'], # trunc_psi, state['params']['trunc_cutoff'], # trunc_cutoff, None, # input_transform state['params']['lr'] # lr, ) state['renderer']._render_drag_impl(state['generator_params'], is_drag=False, to_pil=True) init_image = state['generator_params'].image state['images']['image_orig'] = init_image state['images']['image_raw'] = init_image state['images']['image_show'] = Image.fromarray( add_watermark_np(np.array(init_image))) state['mask'] = np.ones((init_image.size[1], init_image.size[0]), dtype=np.uint8) return global_state def update_image_draw(image, points, mask, show_mask, global_state=None): image_draw = draw_points_on_image(image, points) if show_mask and mask is not None and not (mask == 0).all() and not ( mask == 1).all(): image_draw = draw_mask_on_image(image_draw, mask) image_draw = Image.fromarray(add_watermark_np(np.array(image_draw))) if global_state is not None: global_state['images']['image_show'] = image_draw return image_draw def preprocess_mask_info(global_state, image): """Function to handle mask information. 1. last_mask is None: Do not need to change mask, return mask 2. last_mask is not None: 2.1 global_state is remove_mask: 2.2 global_state is add_mask: """ if isinstance(image, dict): last_mask = get_valid_mask(image['mask']) else: last_mask = None mask = global_state['mask'] # mask in global state is a placeholder with all 1. 
if (mask == 1).all(): mask = last_mask # last_mask = global_state['last_mask'] editing_mode = global_state['editing_state'] if last_mask is None: return global_state if editing_mode == 'remove_mask': updated_mask = np.clip(mask - last_mask, 0, 1) print(f'Last editing_state is {editing_mode}, do remove.') elif editing_mode == 'add_mask': updated_mask = np.clip(mask + last_mask, 0, 1) print(f'Last editing_state is {editing_mode}, do add.') else: updated_mask = mask print(f'Last editing_state is {editing_mode}, ' 'do nothing to mask.') global_state['mask'] = updated_mask # global_state['last_mask'] = None # clear buffer return global_state valid_checkpoints_dict = { f.split('/')[-1].split('.')[0]: osp.join(cache_dir, f) for f in os.listdir(cache_dir) if (f.endswith('pkl') and osp.exists(osp.join(cache_dir, f))) } print(f'Files under cache_dir ({cache_dir}):') print(os.listdir(cache_dir)) print('Valid checkpoint files:') print(valid_checkpoints_dict) init_pkl = 'stylegan2_lions_512_pytorch' with gr.Blocks() as app: # renderer = Renderer() global_state = gr.State({ "images": { # image_orig: the original image, changes when the seed/model changes # image_raw: image with mask and points, updated during optimization # image_show: image shown on screen }, "temporal_params": { # holds the 'stop' flag during optimization }, 'mask': None, # mask for visualization, 1 for editable and 0 for unchanged 'last_mask': None, # last edited mask 'show_mask': True, # whether to overlay the mask on the shown image "generator_params": dnnlib.EasyDict(), "params": { "seed": 0, "motion_lambda": 20, "r1_in_pixels": 3, "r2_in_pixels": 12, "magnitude_direction_in_pixels": 1.0, "latent_space": "w+", "trunc_psi": 0.7, "trunc_cutoff": None, "lr": 0.001, }, "device": device, "draw_interval": 1, "renderer": Renderer(disable_timing=True), "points": {}, "curr_point": None, "curr_type_point": "start", 'editing_state': 'add_points', 'pretrained_weight': init_pkl }) # init image global_state = init_images(global_state) with gr.Row(): with gr.Row(): # Left --> tools with gr.Column(scale=3): # Pickle with gr.Row(): with gr.Column(scale=1, min_width=10): gr.Markdown(value='Pickle', show_label=False) with gr.Column(scale=4, min_width=10): form_pretrained_dropdown = gr.Dropdown( choices=list(valid_checkpoints_dict.keys()), label="Pretrained Model", value=init_pkl, ) # Latent with gr.Row(): with gr.Column(scale=1, min_width=10): gr.Markdown(value='Latent', show_label=False) with gr.Column(scale=4, min_width=10): form_seed_number = gr.Number( value=global_state.value['params']['seed'], interactive=True, label="Seed", ) form_lr_number = gr.Number( value=global_state.value["params"]["lr"], interactive=True, label="Step Size") with gr.Row(): with gr.Column(scale=2, min_width=10): form_reset_image = gr.Button("Reset Image") with gr.Column(scale=3, min_width=10): form_latent_space = gr.Radio( ['w', 'w+'], value=global_state.value['params'] ['latent_space'], interactive=True, label='Latent space to optimize', show_label=False, ) # Drag with gr.Row(): with gr.Column(scale=1, min_width=10): gr.Markdown(value='Drag', show_label=False) with gr.Column(scale=4, min_width=10): with gr.Row(): with gr.Column(scale=1, min_width=10): enable_add_points = gr.Button('Add Points') with gr.Column(scale=1, min_width=10): undo_points = gr.Button('Reset Points') with gr.Row(): with gr.Column(scale=1, min_width=10): form_start_btn = gr.Button("Start") with gr.Column(scale=1, min_width=10): form_stop_btn = gr.Button("Stop") form_steps_number = gr.Number(value=0, label="Steps", interactive=False) # Mask with gr.Row(): with gr.Column(scale=1, min_width=10): gr.Markdown(value='Mask', show_label=False) with gr.Column(scale=4, min_width=10): enable_add_mask = gr.Button('Edit Flexible Area') with gr.Row(): with gr.Column(scale=1, min_width=10): form_reset_mask_btn = gr.Button("Reset mask") with gr.Column(scale=1, min_width=10): show_mask = gr.Checkbox( label='Show Mask', value=global_state.value['show_mask'], show_label=False) with gr.Row(): form_lambda_number = gr.Number( value=global_state.value["params"] ["motion_lambda"], interactive=True, label="Lambda", ) form_draw_interval_number = gr.Number( value=global_state.value["draw_interval"], label="Draw Interval (steps)", interactive=True, visible=False) # Right --> Image with gr.Column(scale=8): form_image = ImageMask( value=global_state.value['images']['image_show'], brush_radius=20).style( width=768, height=768) # NOTE: the image size is hard-coded here. gr.Markdown(""" ## Quick Start 1. Select the desired `Pretrained Model` and adjust `Seed` to generate an initial image. 2. Click on the image to add control points. 3. Click `Start` and enjoy it! ## Advanced Usage 1. Change `Step Size` to adjust the learning rate of the drag optimization. 2. Select `w` or `w+` to change the latent space to optimize: * Optimizing in `w` space may influence the image more strongly. * Optimizing in `w+` space is usually slower than `w`, but tends to give better results. * Note that changing the latent space will reset the image, points, and mask (same effect as the `Reset Image` button). 3. Click `Edit Flexible Area` to create a mask and constrain the unmasked region to remain unchanged. """) gr.HTML("""
Gradio demo supported by OpenMMLab MMagic
""") # Network & latents tab listeners def on_change_pretrained_dropdown(pretrained_value, global_state): """Function to handle model change. 1. Set pretrained value to global_state 2. Re-init images and clear all states """ global_state['pretrained_weight'] = pretrained_value init_images(global_state) clear_state(global_state) return global_state, global_state["images"]['image_show'] form_pretrained_dropdown.change( on_change_pretrained_dropdown, inputs=[form_pretrained_dropdown, global_state], outputs=[global_state, form_image], ) def on_click_reset_image(global_state): """Reset image to the original one and clear all states 1. Re-init images 2. Clear all states """ init_images(global_state) clear_state(global_state) return global_state, global_state['images']['image_show'] form_reset_image.click( on_click_reset_image, inputs=[global_state], outputs=[global_state, form_image], ) # Update parameters def on_change_update_image_seed(seed, global_state): """Function to handle generation seed change. 1. Set seed to global_state 2. Re-init images and clear all states """ global_state["params"]["seed"] = int(seed) init_images(global_state) clear_state(global_state) return global_state, global_state['images']['image_show'] form_seed_number.change( on_change_update_image_seed, inputs=[form_seed_number, global_state], outputs=[global_state, form_image], ) def on_click_latent_space(latent_space, global_state): """Function to reset latent space to optimize. NOTE: this function we reset the image and all controls 1. Set latent-space to global_state 2. Re-init images and clear all state """ global_state['params']['latent_space'] = latent_space init_images(global_state) clear_state(global_state) return global_state, global_state['images']['image_show'] form_latent_space.change(on_click_latent_space, inputs=[form_latent_space, global_state], outputs=[global_state, form_image]) # ==== Params form_lambda_number.change( partial(on_change_single_global_state, ["params", "motion_lambda"]), inputs=[form_lambda_number, global_state], outputs=[global_state], ) def on_change_lr(lr, global_state): if lr == 0: print('lr is 0, do nothing.') return global_state else: global_state["params"]["lr"] = lr renderer = global_state['renderer'] renderer.update_lr(lr) print('New optimizer: ') print(renderer.w_optim) return global_state form_lr_number.change( on_change_lr, inputs=[form_lr_number, global_state], outputs=[global_state], ) def on_click_start(global_state, image): p_in_pixels = [] t_in_pixels = [] valid_points = [] # handle of start drag in mask editing mode global_state = preprocess_mask_info(global_state, image) # Prepare the points for the inference if len(global_state["points"]) == 0: # yield on_click_start_wo_points(global_state, image) image_raw = global_state['images']['image_raw'] update_image_draw( image_raw, global_state['points'], global_state['mask'], global_state['show_mask'], global_state, ) yield ( global_state, 0, global_state['images']['image_show'], # gr.File.update(visible=False), gr.Button.update(interactive=True), gr.Button.update(interactive=True), gr.Button.update(interactive=True), gr.Button.update(interactive=True), gr.Button.update(interactive=True), # latent space gr.Radio.update(interactive=True), gr.Button.update(interactive=True), # NOTE: disable stop button gr.Button.update(interactive=False), # update other comps gr.Dropdown.update(interactive=True), gr.Number.update(interactive=True), gr.Number.update(interactive=True), gr.Button.update(interactive=True), 
gr.Button.update(interactive=True), gr.Checkbox.update(interactive=True), # gr.Number.update(interactive=True), gr.Number.update(interactive=True), ) else: # Transform the points into torch tensors for key_point, point in global_state["points"].items(): try: p_start = point.get("start_temp", point["start"]) p_end = point["target"] if p_start is None or p_end is None: continue except KeyError: continue p_in_pixels.append(p_start) t_in_pixels.append(p_end) valid_points.append(key_point) mask = torch.tensor(global_state['mask']).float() drag_mask = 1 - mask renderer: Renderer = global_state["renderer"] global_state['temporal_params']['stop'] = False global_state['editing_state'] = 'running' # reverse points order p_to_opt = reverse_point_pairs(p_in_pixels) t_to_opt = reverse_point_pairs(t_in_pixels) print('Running with:') print(f' Source: {p_in_pixels}') print(f' Target: {t_in_pixels}') step_idx = 0 while True: if global_state["temporal_params"]["stop"]: break # do drage here! renderer._render_drag_impl( global_state['generator_params'], p_to_opt, # point t_to_opt, # target drag_mask, # mask, global_state['params']['motion_lambda'], # lambda_mask reg=0, feature_idx=5, # NOTE: do not support change for now r1=global_state['params']['r1_in_pixels'], # r1 r2=global_state['params']['r2_in_pixels'], # r2 # random_seed = 0, # noise_mode = 'const', trunc_psi=global_state['params']['trunc_psi'], # force_fp32 = False, # layer_name = None, # sel_channels = 3, # base_channel = 0, # img_scale_db = 0, # img_normalize = False, # untransform = False, is_drag=True, to_pil=True) if step_idx % global_state['draw_interval'] == 0: print('Current Source:') for key_point, p_i, t_i in zip(valid_points, p_to_opt, t_to_opt): global_state["points"][key_point]["start_temp"] = [ p_i[1], p_i[0], ] global_state["points"][key_point]["target"] = [ t_i[1], t_i[0], ] start_temp = global_state["points"][key_point][ "start_temp"] print(f' {start_temp}') image_result = global_state['generator_params']['image'] image_draw = update_image_draw( image_result, global_state['points'], global_state['mask'], global_state['show_mask'], global_state, ) global_state['images']['image_raw'] = image_result yield ( global_state, step_idx, global_state['images']['image_show'], # gr.File.update(visible=False), gr.Button.update(interactive=False), gr.Button.update(interactive=False), gr.Button.update(interactive=False), gr.Button.update(interactive=False), gr.Button.update(interactive=False), # latent space gr.Radio.update(interactive=False), gr.Button.update(interactive=False), # enable stop button in loop gr.Button.update(interactive=True), # update other comps gr.Dropdown.update(interactive=False), gr.Number.update(interactive=False), gr.Number.update(interactive=False), gr.Button.update(interactive=False), gr.Button.update(interactive=False), gr.Checkbox.update(interactive=False), # gr.Number.update(interactive=False), gr.Number.update(interactive=False), ) # increate step step_idx += 1 image_result = global_state['generator_params']['image'] global_state['images']['image_raw'] = image_result image_draw = update_image_draw(image_result, global_state['points'], global_state['mask'], global_state['show_mask'], global_state) # fp = NamedTemporaryFile(suffix=".png", delete=False) # image_result.save(fp, "PNG") global_state['editing_state'] = 'add_points' yield ( global_state, 0, # reset step to 0 after stop. 
global_state['images']['image_show'], # gr.File.update(visible=True, value=fp.name), gr.Button.update(interactive=True), gr.Button.update(interactive=True), gr.Button.update(interactive=True), gr.Button.update(interactive=True), gr.Button.update(interactive=True), # latent space gr.Radio.update(interactive=True), gr.Button.update(interactive=True), # NOTE: disable stop button with loop finish gr.Button.update(interactive=False), # update other comps gr.Dropdown.update(interactive=True), gr.Number.update(interactive=True), gr.Number.update(interactive=True), gr.Checkbox.update(interactive=True), gr.Number.update(interactive=True), ) form_start_btn.click( on_click_start, inputs=[global_state, form_image], outputs=[ global_state, form_steps_number, form_image, # form_download_result_file, # >>> buttons form_reset_image, enable_add_points, enable_add_mask, undo_points, form_reset_mask_btn, form_latent_space, form_start_btn, form_stop_btn, # <<< buttonm # >>> inputs comps form_pretrained_dropdown, form_seed_number, form_lr_number, show_mask, form_lambda_number, ], ) def on_click_stop(global_state): """Function to handle stop button is clicked. 1. send a stop signal by set global_state["temporal_params"]["stop"] as True 2. Disable Stop button """ global_state["temporal_params"]["stop"] = True return global_state, gr.Button.update(interactive=False) form_stop_btn.click(on_click_stop, inputs=[global_state], outputs=[global_state, form_stop_btn]) form_draw_interval_number.change( partial( on_change_single_global_state, "draw_interval", map_transform=lambda x: int(x), ), inputs=[form_draw_interval_number, global_state], outputs=[global_state], ) def on_click_remove_point(global_state): choice = global_state["curr_point"] del global_state["points"][choice] choices = list(global_state["points"].keys()) if len(choices) > 0: global_state["curr_point"] = choices[0] return ( gr.Dropdown.update(choices=choices, value=choices[0]), global_state, ) # Mask def on_click_reset_mask(global_state): global_state['mask'] = np.ones( ( global_state["images"]["image_raw"].size[1], global_state["images"]["image_raw"].size[0], ), dtype=np.uint8, ) image_draw = update_image_draw(global_state['images']['image_raw'], global_state['points'], global_state['mask'], global_state['show_mask'], global_state) return global_state, image_draw form_reset_mask_btn.click( on_click_reset_mask, inputs=[global_state], outputs=[global_state, form_image], ) # Image def on_click_enable_draw(global_state, image): """Function to start add mask mode. 1. Preprocess mask info from last state 2. Change editing state to add_mask 3. Set curr image with points and mask """ global_state = preprocess_mask_info(global_state, image) global_state['editing_state'] = 'add_mask' image_raw = global_state['images']['image_raw'] image_draw = update_image_draw(image_raw, global_state['points'], global_state['mask'], True, global_state) return (global_state, gr.Image.update(value=image_draw, interactive=True)) def on_click_remove_draw(global_state, image): """Function to start remove mask mode. 1. Preprocess mask info from last state 2. Change editing state to remove_mask 3. 
Set curr image with points and mask """ global_state = preprocess_mask_info(global_state, image) global_state['edinting_state'] = 'remove_mask' image_raw = global_state['images']['image_raw'] image_draw = update_image_draw(image_raw, global_state['points'], global_state['mask'], True, global_state) return (global_state, gr.Image.update(value=image_draw, interactive=True)) enable_add_mask.click(on_click_enable_draw, inputs=[global_state, form_image], outputs=[ global_state, form_image, ]) def on_click_add_point(global_state, image: dict): """Function switch from add mask mode to add points mode. 1. Updaste mask buffer if need 2. Change global_state['editing_state'] to 'add_points' 3. Set current image with mask """ global_state = preprocess_mask_info(global_state, image) global_state['editing_state'] = 'add_points' mask = global_state['mask'] image_raw = global_state['images']['image_raw'] image_draw = update_image_draw(image_raw, global_state['points'], mask, global_state['show_mask'], global_state) return (global_state, gr.Image.update(value=image_draw, interactive=False)) enable_add_points.click(on_click_add_point, inputs=[global_state, form_image], outputs=[global_state, form_image]) def on_click_image(global_state, evt: gr.SelectData): """This function only support click for point selection """ xy = evt.index if global_state['editing_state'] != 'add_points': print(f'In {global_state["editing_state"]} state. ' 'Do not add points.') return global_state, global_state['images']['image_show'] points = global_state["points"] point_idx = get_latest_points_pair(points) if point_idx is None: points[0] = {'start': xy, 'target': None} print(f'Click Image - Start - {xy}') elif points[point_idx].get('target', None) is None: points[point_idx]['target'] = xy print(f'Click Image - Target - {xy}') else: points[point_idx + 1] = {'start': xy, 'target': None} print(f'Click Image - Start - {xy}') image_raw = global_state['images']['image_raw'] image_draw = update_image_draw( image_raw, global_state['points'], global_state['mask'], global_state['show_mask'], global_state, ) return global_state, image_draw form_image.select( on_click_image, inputs=[global_state], outputs=[global_state, form_image], ) def on_click_clear_points(global_state): """Function to handle clear all control points 1. clear global_state['points'] (clear_state) 2. re-init network 2. 
re-draw image """ clear_state(global_state, target='point') renderer: Renderer = global_state["renderer"] renderer.feat_refs = None image_raw = global_state['images']['image_raw'] image_draw = update_image_draw(image_raw, {}, global_state['mask'], global_state['show_mask'], global_state) return global_state, image_draw undo_points.click(on_click_clear_points, inputs=[global_state], outputs=[global_state, form_image]) def on_click_show_mask(global_state, show_mask): """Function to control whether show mask on image.""" global_state['show_mask'] = show_mask image_raw = global_state['images']['image_raw'] image_draw = update_image_draw( image_raw, global_state['points'], global_state['mask'], global_state['show_mask'], global_state, ) return global_state, image_draw show_mask.change( on_click_show_mask, inputs=[global_state, show_mask], outputs=[global_state, form_image], ) gr.close_all() app.queue(concurrency_count=3, max_size=20) app.launch(share=args.share, server_name="0.0.0.0" if args.listen else "127.0.0.1") ================================================ FILE: viz/__init__.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. # empty ================================================ FILE: viz/capture_widget.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. import os import re import numpy as np import imgui import PIL.Image from gui_utils import imgui_utils from . 
import renderer import torch import torchvision #---------------------------------------------------------------------------- class CaptureWidget: def __init__(self, viz): self.viz = viz self.path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '_screenshots')) self.dump_image = False self.dump_gui = False self.defer_frames = 0 self.disabled_time = 0 def dump_png(self, image): viz = self.viz try: _height, _width, channels = image.shape print(viz.result) assert image.dtype == np.uint8 os.makedirs(self.path, exist_ok=True) file_id = 0 for entry in os.scandir(self.path): if entry.is_file(): match = re.fullmatch(r'(\d+).*', entry.name) if match: file_id = max(file_id, int(match.group(1)) + 1) if channels == 1: pil_image = PIL.Image.fromarray(image[:, :, 0], 'L') else: pil_image = PIL.Image.fromarray(image[:, :, :3], 'RGB') pil_image.save(os.path.join(self.path, f'{file_id:05d}.png')) np.save(os.path.join(self.path, f'{file_id:05d}.npy'), viz.result.w) except: viz.result.error = renderer.CapturedException() @imgui_utils.scoped_by_object_id def __call__(self, show=True): viz = self.viz if show: with imgui_utils.grayed_out(self.disabled_time != 0): imgui.text('Capture') imgui.same_line(viz.label_w) _changed, self.path = imgui_utils.input_text('##path', self.path, 1024, flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE), width=(-1), help_text='PATH') if imgui.is_item_hovered() and not imgui.is_item_active() and self.path != '': imgui.set_tooltip(self.path) imgui.text(' ') imgui.same_line(viz.label_w) if imgui_utils.button('Save image', width=viz.button_w, enabled=(self.disabled_time == 0 and 'image' in viz.result)): self.dump_image = True self.defer_frames = 2 self.disabled_time = 0.5 imgui.same_line() if imgui_utils.button('Save GUI', width=viz.button_w, enabled=(self.disabled_time == 0)): self.dump_gui = True self.defer_frames = 2 self.disabled_time = 0.5 self.disabled_time = max(self.disabled_time - viz.frame_delta, 0) if self.defer_frames > 0: self.defer_frames -= 1 elif self.dump_image: if 'image' in viz.result: self.dump_png(viz.result.image) self.dump_image = False elif self.dump_gui: viz.capture_next_frame() self.dump_gui = False captured_frame = viz.pop_captured_frame() if captured_frame is not None: self.dump_png(captured_frame) #---------------------------------------------------------------------------- ================================================ FILE: viz/drag_widget.py ================================================ import os import torch import numpy as np import imgui import dnnlib from gui_utils import imgui_utils #---------------------------------------------------------------------------- class DragWidget: def __init__(self, viz): self.viz = viz self.point = [-1, -1] self.points = [] self.targets = [] self.is_point = True self.last_click = False self.is_drag = False self.iteration = 0 self.mode = 'point' self.r_mask = 50 self.show_mask = False self.mask = torch.ones(256, 256) self.lambda_mask = 20 self.feature_idx = 5 self.r1 = 3 self.r2 = 12 self.path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '_screenshots')) self.defer_frames = 0 self.disabled_time = 0 def action(self, click, down, x, y): if self.mode == 'point': self.add_point(click, x, y) elif down: self.draw_mask(x, y) def add_point(self, click, x, y): if click: self.point = [y, x] elif self.last_click: if self.is_drag: self.stop_drag() if self.is_point: self.points.append(self.point) self.is_point = False else: self.targets.append(self.point) 
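# Mouse releases alternate between the two roles: one release records a handle
# point (is_point=True), the next records its target, so self.points and
# self.targets stay index-aligned once a drag is started.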
self.is_point = True self.last_click = click def init_mask(self, w, h): self.width, self.height = w, h self.mask = torch.ones(h, w) def draw_mask(self, x, y): X = torch.linspace(0, self.width, self.width) Y = torch.linspace(0, self.height, self.height) yy, xx = torch.meshgrid(Y, X) circle = (xx - x)**2 + (yy - y)**2 < self.r_mask**2 if self.mode == 'flexible': self.mask[circle] = 0 elif self.mode == 'fixed': self.mask[circle] = 1 def stop_drag(self): self.is_drag = False self.iteration = 0 def set_points(self, points): self.points = points def reset_point(self): self.points = [] self.targets = [] self.is_point = True def load_points(self, suffix): points = [] point_path = self.path + f'_{suffix}.txt' try: with open(point_path, "r") as f: for line in f.readlines(): y, x = line.split() points.append([int(y), int(x)]) except: print(f'Wrong point file path: {point_path}') return points @imgui_utils.scoped_by_object_id def __call__(self, show=True): viz = self.viz reset = False if show: with imgui_utils.grayed_out(self.disabled_time != 0): imgui.text('Drag') imgui.same_line(viz.label_w) if imgui_utils.button('Add point', width=viz.button_w, enabled='image' in viz.result): self.mode = 'point' imgui.same_line() reset = False if imgui_utils.button('Reset point', width=viz.button_w, enabled='image' in viz.result): self.reset_point() reset = True imgui.text(' ') imgui.same_line(viz.label_w) if imgui_utils.button('Start', width=viz.button_w, enabled='image' in viz.result): self.is_drag = True if len(self.points) > len(self.targets): self.points = self.points[:len(self.targets)] imgui.same_line() if imgui_utils.button('Stop', width=viz.button_w, enabled='image' in viz.result): self.stop_drag() imgui.text(' ') imgui.same_line(viz.label_w) imgui.text(f'Steps: {self.iteration}') imgui.text('Mask') imgui.same_line(viz.label_w) if imgui_utils.button('Flexible area', width=viz.button_w, enabled='image' in viz.result): self.mode = 'flexible' self.show_mask = True imgui.same_line() if imgui_utils.button('Fixed area', width=viz.button_w, enabled='image' in viz.result): self.mode = 'fixed' self.show_mask = True imgui.text(' ') imgui.same_line(viz.label_w) if imgui_utils.button('Reset mask', width=viz.button_w, enabled='image' in viz.result): self.mask = torch.ones(self.height, self.width) imgui.same_line() _clicked, self.show_mask = imgui.checkbox('Show mask', self.show_mask) imgui.text(' ') imgui.same_line(viz.label_w) with imgui_utils.item_width(viz.font_size * 6): changed, self.r_mask = imgui.input_int('Radius', self.r_mask) imgui.text(' ') imgui.same_line(viz.label_w) with imgui_utils.item_width(viz.font_size * 6): changed, self.lambda_mask = imgui.input_int('Lambda', self.lambda_mask) self.disabled_time = max(self.disabled_time - viz.frame_delta, 0) if self.defer_frames > 0: self.defer_frames -= 1 viz.args.is_drag = self.is_drag if self.is_drag: self.iteration += 1 viz.args.iteration = self.iteration viz.args.points = [point for point in self.points] viz.args.targets = [point for point in self.targets] viz.args.mask = self.mask viz.args.lambda_mask = self.lambda_mask viz.args.feature_idx = self.feature_idx viz.args.r1 = self.r1 viz.args.r2 = self.r2 viz.args.reset = reset #---------------------------------------------------------------------------- ================================================ FILE: viz/latent_widget.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. import os import numpy as np import imgui import dnnlib import torch from gui_utils import imgui_utils #---------------------------------------------------------------------------- class LatentWidget: def __init__(self, viz): self.viz = viz self.seed = 0 self.w_plus = True self.reg = 0 self.lr = 0.001 self.w_path = '' self.w_load = None self.defer_frames = 0 self.disabled_time = 0 @imgui_utils.scoped_by_object_id def __call__(self, show=True): viz = self.viz if show: with imgui_utils.grayed_out(self.disabled_time != 0): imgui.text('Latent') imgui.same_line(viz.label_w) with imgui_utils.item_width(viz.font_size * 8.75): changed, seed = imgui.input_int('Seed', self.seed) if changed: self.seed = seed # reset latent code self.w_load = None # load latent code imgui.text(' ') imgui.same_line(viz.label_w) _changed, self.w_path = imgui_utils.input_text('##path', self.w_path, 1024, flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE), width=(-1), help_text='Path to latent code') if imgui.is_item_hovered() and not imgui.is_item_active() and self.w_path != '': imgui.set_tooltip(self.w_path) imgui.text(' ') imgui.same_line(viz.label_w) if imgui_utils.button('Load latent', width=viz.button_w, enabled=(self.disabled_time == 0 and 'image' in viz.result)): assert os.path.isfile(self.w_path), f"{self.w_path} does not exist!" self.w_load = torch.load(self.w_path) self.defer_frames = 2 self.disabled_time = 0.5 imgui.text(' ') imgui.same_line(viz.label_w) with imgui_utils.item_width(viz.button_w): changed, lr = imgui.input_float('Step Size', self.lr) if changed: self.lr = lr # imgui.text(' ') # imgui.same_line(viz.label_w) # with imgui_utils.item_width(viz.button_w): # changed, reg = imgui.input_float('Regularize', self.reg) # if changed: # self.reg = reg imgui.text(' ') imgui.same_line(viz.label_w) reset_w = imgui_utils.button('Reset', width=viz.button_w, enabled='image' in viz.result) imgui.same_line() _clicked, w = imgui.checkbox('w', not self.w_plus) if w: self.w_plus = False imgui.same_line() _clicked, self.w_plus = imgui.checkbox('w+', self.w_plus) self.disabled_time = max(self.disabled_time - viz.frame_delta, 0) if self.defer_frames > 0: self.defer_frames -= 1 viz.args.w0_seed = self.seed viz.args.w_load = self.w_load viz.args.reg = self.reg viz.args.w_plus = self.w_plus viz.args.reset_w = reset_w viz.args.lr = lr #---------------------------------------------------------------------------- ================================================ FILE: viz/pickle_widget.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. import glob import os import re import dnnlib import imgui import numpy as np from gui_utils import imgui_utils from . 
import renderer #---------------------------------------------------------------------------- def _locate_results(pattern): return pattern #---------------------------------------------------------------------------- class PickleWidget: def __init__(self, viz): self.viz = viz self.search_dirs = [] self.cur_pkl = None self.user_pkl = '' self.recent_pkls = [] self.browse_cache = dict() # {tuple(path, ...): [dnnlib.EasyDict(), ...], ...} self.browse_refocus = False self.load('', ignore_errors=True) def add_recent(self, pkl, ignore_errors=False): try: resolved = self.resolve_pkl(pkl) if resolved not in self.recent_pkls: self.recent_pkls.append(resolved) except: if not ignore_errors: raise def load(self, pkl, ignore_errors=False): viz = self.viz viz.clear_result() viz.skip_frame() # The input field will change on next frame. try: resolved = self.resolve_pkl(pkl) name = resolved.replace('\\', '/').split('/')[-1] self.cur_pkl = resolved self.user_pkl = resolved viz.result.message = f'Loading {name}...' viz.defer_rendering() if resolved in self.recent_pkls: self.recent_pkls.remove(resolved) self.recent_pkls.insert(0, resolved) except: self.cur_pkl = None self.user_pkl = pkl if pkl == '': viz.result = dnnlib.EasyDict(message='No network pickle loaded') else: viz.result = dnnlib.EasyDict(error=renderer.CapturedException()) if not ignore_errors: raise @imgui_utils.scoped_by_object_id def __call__(self, show=True): viz = self.viz recent_pkls = [pkl for pkl in self.recent_pkls if pkl != self.user_pkl] if show: imgui.text('Pickle') imgui.same_line(viz.label_w) idx = self.user_pkl.rfind('/') changed, self.user_pkl = imgui_utils.input_text('##pkl', self.user_pkl[idx+1:], 1024, flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE), width=(-1), help_text='<PATH> | <URL> | <RUN_DIR> | <RUN_ID> | <RUN_ID>/<KIMG>.pkl') if changed: self.load(self.user_pkl, ignore_errors=True) if imgui.is_item_hovered() and not imgui.is_item_active() and self.user_pkl != '': imgui.set_tooltip(self.user_pkl) # imgui.same_line() imgui.text(' ') imgui.same_line(viz.label_w) if imgui_utils.button('Recent...', width=viz.button_w, enabled=(len(recent_pkls) != 0)): imgui.open_popup('recent_pkls_popup') imgui.same_line() if imgui_utils.button('Browse...', enabled=len(self.search_dirs) > 0, width=viz.button_w): imgui.open_popup('browse_pkls_popup') self.browse_cache.clear() self.browse_refocus = True if imgui.begin_popup('recent_pkls_popup'): for pkl in recent_pkls: clicked, _state = imgui.menu_item(pkl) if clicked: self.load(pkl, ignore_errors=True) imgui.end_popup() if imgui.begin_popup('browse_pkls_popup'): def recurse(parents): key = tuple(parents) items = self.browse_cache.get(key, None) if items is None: items = self.list_runs_and_pkls(parents) self.browse_cache[key] = items for item in items: if item.type == 'run' and imgui.begin_menu(item.name): recurse([item.path]) imgui.end_menu() if item.type == 'pkl': clicked, _state = imgui.menu_item(item.name) if clicked: self.load(item.path, ignore_errors=True) if len(items) == 0: with imgui_utils.grayed_out(): imgui.menu_item('No results found') recurse(self.search_dirs) if self.browse_refocus: imgui.set_scroll_here() viz.skip_frame() # Focus will change on next frame.
self.browse_refocus = False imgui.end_popup() paths = viz.pop_drag_and_drop_paths() if paths is not None and len(paths) >= 1: self.load(paths[0], ignore_errors=True) viz.args.pkl = self.cur_pkl def list_runs_and_pkls(self, parents): items = [] run_regex = re.compile(r'\d+-.*') pkl_regex = re.compile(r'network-snapshot-\d+\.pkl') for parent in set(parents): if os.path.isdir(parent): for entry in os.scandir(parent): if entry.is_dir() and run_regex.fullmatch(entry.name): items.append(dnnlib.EasyDict(type='run', name=entry.name, path=os.path.join(parent, entry.name))) if entry.is_file() and pkl_regex.fullmatch(entry.name): items.append(dnnlib.EasyDict(type='pkl', name=entry.name, path=os.path.join(parent, entry.name))) items = sorted(items, key=lambda item: (item.name.replace('_', ' '), item.path)) return items def resolve_pkl(self, pattern): assert isinstance(pattern, str) assert pattern != '' # URL => return as is. if dnnlib.util.is_url(pattern): return pattern # Short-hand pattern => locate. path = _locate_results(pattern) # Run dir => pick the last saved snapshot. if os.path.isdir(path): pkl_files = sorted(glob.glob(os.path.join(path, 'network-snapshot-*.pkl'))) if len(pkl_files) == 0: raise IOError(f'No network pickle found in "{path}"') path = pkl_files[-1] # Normalize. path = os.path.abspath(path) return path #---------------------------------------------------------------------------- ================================================ FILE: viz/renderer.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
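# Module overview (added for readability): Renderer caches loaded pickles and
# rebuilt Generator modules, and times each render with CUDA events unless
# disable_timing is set. init_network() derives w from a seeded z via G.mapping
# (or from w_load) and attaches an Adam optimizer to it; w_plus selects between
# optimizing the full per-layer w+ tensor and a single 512-d code.
# _render_drag_impl() synthesizes an image together with intermediate features
# and, when is_drag is set, runs point tracking (nearest-neighbor feature
# matching within an r2 patch) followed by motion supervision (an L1 loss
# pulling the r1 neighborhood of each handle point toward its target), plus an
# optional lambda_mask-weighted term that penalizes feature changes where the
# supplied mask is 1. add_watermark_np() stamps an 'AI Generated' watermark
# onto every image returned to the GUI.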
from socket import has_dualstack_ipv6 import sys import copy import traceback import math import numpy as np from PIL import Image, ImageDraw, ImageFont import torch import torch.fft import torch.nn as nn import torch.nn.functional as F import matplotlib.cm import dnnlib from torch_utils.ops import upfirdn2d import legacy # pylint: disable=import-error #---------------------------------------------------------------------------- class CapturedException(Exception): def __init__(self, msg=None): if msg is None: _type, value, _traceback = sys.exc_info() assert value is not None if isinstance(value, CapturedException): msg = str(value) else: msg = traceback.format_exc() assert isinstance(msg, str) super().__init__(msg) #---------------------------------------------------------------------------- class CaptureSuccess(Exception): def __init__(self, out): super().__init__() self.out = out #---------------------------------------------------------------------------- def add_watermark_np(input_image_array, watermark_text="AI Generated"): image = Image.fromarray(np.uint8(input_image_array)).convert("RGBA") # Initialize text image txt = Image.new('RGBA', image.size, (255, 255, 255, 0)) font = ImageFont.truetype('arial.ttf', round(25/512*image.size[0])) d = ImageDraw.Draw(txt) text_width, text_height = font.getsize(watermark_text) text_position = (image.size[0] - text_width - 10, image.size[1] - text_height - 10) text_color = (255, 255, 255, 128) # white color with the alpha channel set to semi-transparent # Draw the text onto the text canvas d.text(text_position, watermark_text, font=font, fill=text_color) # Combine the image with the watermark watermarked = Image.alpha_composite(image, txt) watermarked_array = np.array(watermarked) return watermarked_array #---------------------------------------------------------------------------- class Renderer: def __init__(self, disable_timing=False): self._device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu') self._dtype = torch.float32 if self._device.type == 'mps' else torch.float64 self._pkl_data = dict() # {pkl: dict | CapturedException, ...} self._networks = dict() # {cache_key: torch.nn.Module, ...} self._pinned_bufs = dict() # {(shape, dtype): torch.Tensor, ...} self._cmaps = dict() # {name: torch.Tensor, ...} self._is_timing = False if not disable_timing: self._start_event = torch.cuda.Event(enable_timing=True) self._end_event = torch.cuda.Event(enable_timing=True) self._disable_timing = disable_timing self._net_layers = dict() # {cache_key: [dnnlib.EasyDict, ...], ...} def render(self, **args): if self._disable_timing: self._is_timing = False else: self._start_event.record(torch.cuda.current_stream(self._device)) self._is_timing = True res = dnnlib.EasyDict() try: init_net = False if not hasattr(self, 'G'): init_net = True if hasattr(self, 'pkl'): if self.pkl != args['pkl']: init_net = True if hasattr(self, 'w_load'): if self.w_load is not args['w_load']: init_net = True if hasattr(self, 'w0_seed'): if self.w0_seed != args['w0_seed']: init_net = True if hasattr(self, 'w_plus'): if self.w_plus != args['w_plus']: init_net = True if args['reset_w']: init_net = True res.init_net = init_net if init_net: self.init_network(res, **args) self._render_drag_impl(res, **args) except: res.error = CapturedException() if not self._disable_timing: self._end_event.record(torch.cuda.current_stream(self._device)) if 'image' in res: res.image = self.to_cpu(res.image).detach().numpy() res.image = 
add_watermark_np(res.image, 'AI Generated') if 'stats' in res: res.stats = self.to_cpu(res.stats).detach().numpy() if 'error' in res: res.error = str(res.error) # if 'stop' in res and res.stop: if self._is_timing and not self._disable_timing: self._end_event.synchronize() res.render_time = self._start_event.elapsed_time(self._end_event) * 1e-3 self._is_timing = False return res def get_network(self, pkl, key, **tweak_kwargs): data = self._pkl_data.get(pkl, None) if data is None: print(f'Loading "{pkl}"... ', end='', flush=True) try: with dnnlib.util.open_url(pkl, verbose=False) as f: data = legacy.load_network_pkl(f) print('Done.') except: data = CapturedException() print('Failed!') self._pkl_data[pkl] = data self._ignore_timing() if isinstance(data, CapturedException): raise data orig_net = data[key] cache_key = (orig_net, self._device, tuple(sorted(tweak_kwargs.items()))) net = self._networks.get(cache_key, None) if net is None: try: if 'stylegan2' in pkl: from training.networks_stylegan2 import Generator elif 'stylegan3' in pkl: from training.networks_stylegan3 import Generator elif 'stylegan_human' in pkl: from stylegan_human.training_scripts.sg2.training.networks import Generator else: raise NameError('Cannot infer model type from pkl name!') print(data[key].init_args) print(data[key].init_kwargs) if 'stylegan_human' in pkl: net = Generator(*data[key].init_args, **data[key].init_kwargs, square=False, padding=True) else: net = Generator(*data[key].init_args, **data[key].init_kwargs) net.load_state_dict(data[key].state_dict()) net.to(self._device) except: net = CapturedException() self._networks[cache_key] = net self._ignore_timing() if isinstance(net, CapturedException): raise net return net def _get_pinned_buf(self, ref): key = (tuple(ref.shape), ref.dtype) buf = self._pinned_bufs.get(key, None) if buf is None: buf = torch.empty(ref.shape, dtype=ref.dtype).pin_memory() self._pinned_bufs[key] = buf return buf def to_device(self, buf): return self._get_pinned_buf(buf).copy_(buf).to(self._device) def to_cpu(self, buf): return self._get_pinned_buf(buf).copy_(buf).clone() def _ignore_timing(self): self._is_timing = False def _apply_cmap(self, x, name='viridis'): cmap = self._cmaps.get(name, None) if cmap is None: cmap = matplotlib.cm.get_cmap(name) cmap = cmap(np.linspace(0, 1, num=1024), bytes=True)[:, :3] cmap = self.to_device(torch.from_numpy(cmap)) self._cmaps[name] = cmap hi = cmap.shape[0] - 1 x = (x * hi + 0.5).clamp(0, hi).to(torch.int64) x = torch.nn.functional.embedding(x, cmap) return x def init_network(self, res, pkl = None, w0_seed = 0, w_load = None, w_plus = True, noise_mode = 'const', trunc_psi = 0.7, trunc_cutoff = None, input_transform = None, lr = 0.001, **kwargs ): # Dig up network details. self.pkl = pkl G = self.get_network(pkl, 'G_ema') self.G = G res.img_resolution = G.img_resolution res.num_ws = G.num_ws res.has_noise = any('noise_const' in name for name, _buf in G.synthesis.named_buffers()) res.has_input_transform = (hasattr(G.synthesis, 'input') and hasattr(G.synthesis.input, 'transform')) # Set input transform. if res.has_input_transform: m = np.eye(3) try: if input_transform is not None: m = np.linalg.inv(np.asarray(input_transform)) except np.linalg.LinAlgError: res.error = CapturedException() G.synthesis.input.transform.copy_(torch.from_numpy(m)) # Generate random latents. self.w0_seed = w0_seed self.w_load = w_load if self.w_load is None: # Generate random latents. 
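# w0_seed seeds a NumPy RNG, so the same seed always reproduces the same
# z ~ N(0, I); G.mapping then yields the (1, num_ws, 512) w tensor, truncated
# toward the mean by trunc_psi/trunc_cutoff. Only w is optimized during
# dragging: the full w+ tensor when w_plus=True, otherwise a single 512-d code
# that _render_drag_impl re-expands over the early layers.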
z = torch.from_numpy(np.random.RandomState(w0_seed).randn(1, 512)).to(self._device, dtype=self._dtype) # Run mapping network. label = torch.zeros([1, G.c_dim], device=self._device) w = G.mapping(z, label, truncation_psi=trunc_psi, truncation_cutoff=trunc_cutoff) else: w = self.w_load.clone().to(self._device) self.w0 = w.detach().clone() self.w_plus = w_plus if w_plus: self.w = w.detach() else: self.w = w[:, 0, :].detach() self.w.requires_grad = True self.w_optim = torch.optim.Adam([self.w], lr=lr) self.feat_refs = None self.points0_pt = None def update_lr(self, lr): del self.w_optim self.w_optim = torch.optim.Adam([self.w], lr=lr) print(f'Rebuild optimizer with lr: {lr}') print(' Remain feat_refs and points0_pt') def _render_drag_impl(self, res, points = [], targets = [], mask = None, lambda_mask = 10, reg = 0, feature_idx = 5, r1 = 3, r2 = 12, random_seed = 0, noise_mode = 'const', trunc_psi = 0.7, force_fp32 = False, layer_name = None, sel_channels = 3, base_channel = 0, img_scale_db = 0, img_normalize = False, untransform = False, is_drag = False, reset = False, to_pil = False, **kwargs ): G = self.G ws = self.w if ws.dim() == 2: ws = ws.unsqueeze(1).repeat(1,6,1) ws = torch.cat([ws[:,:6,:], self.w0[:,6:,:]], dim=1) if hasattr(self, 'points'): if len(points) != len(self.points): reset = True if reset: self.feat_refs = None self.points0_pt = None self.points = points # Run synthesis network. label = torch.zeros([1, G.c_dim], device=self._device) img, feat = G(ws, label, truncation_psi=trunc_psi, noise_mode=noise_mode, input_is_w=True, return_feature=True) h, w = G.img_resolution, G.img_resolution if is_drag: X = torch.linspace(0, h, h) Y = torch.linspace(0, w, w) xx, yy = torch.meshgrid(X, Y) feat_resize = F.interpolate(feat[feature_idx], [h, w], mode='bilinear') if self.feat_refs is None: self.feat0_resize = F.interpolate(feat[feature_idx].detach(), [h, w], mode='bilinear') self.feat_refs = [] for point in points: py, px = round(point[0]), round(point[1]) self.feat_refs.append(self.feat0_resize[:,:,py,px]) self.points0_pt = torch.Tensor(points).unsqueeze(0).to(self._device) # 1, N, 2 # Point tracking with feature matching with torch.no_grad(): for j, point in enumerate(points): r = round(r2 / 512 * h) up = max(point[0] - r, 0) down = min(point[0] + r + 1, h) left = max(point[1] - r, 0) right = min(point[1] + r + 1, w) feat_patch = feat_resize[:,:,up:down,left:right] L2 = torch.linalg.norm(feat_patch - self.feat_refs[j].reshape(1,-1,1,1), dim=1) _, idx = torch.min(L2.view(1,-1), -1) width = right - left point = [idx.item() // width + up, idx.item() % width + left] points[j] = point res.points = [[point[0], point[1]] for point in points] # Motion supervision loss_motion = 0 res.stop = True for j, point in enumerate(points): direction = torch.Tensor([targets[j][1] - point[1], targets[j][0] - point[0]]) if torch.linalg.norm(direction) > max(2 / 512 * h, 2): res.stop = False if torch.linalg.norm(direction) > 1: distance = ((xx.to(self._device) - point[0])**2 + (yy.to(self._device) - point[1])**2)**0.5 relis, reljs = torch.where(distance < round(r1 / 512 * h)) direction = direction / (torch.linalg.norm(direction) + 1e-7) gridh = (relis+direction[1]) / (h-1) * 2 - 1 gridw = (reljs+direction[0]) / (w-1) * 2 - 1 grid = torch.stack([gridw,gridh], dim=-1).unsqueeze(0).unsqueeze(0) target = F.grid_sample(feat_resize.float(), grid, align_corners=True).squeeze(2) loss_motion += F.l1_loss(feat_resize[:,:,relis,reljs].detach(), target) loss = loss_motion if mask is not None: if mask.min() == 0 and 
mask.max() == 1: mask_usq = mask.to(self._device).unsqueeze(0).unsqueeze(0) loss_fix = F.l1_loss(feat_resize * mask_usq, self.feat0_resize * mask_usq) loss += lambda_mask * loss_fix loss += reg * F.l1_loss(ws, self.w0) # latent code regularization if not res.stop: self.w_optim.zero_grad() loss.backward() self.w_optim.step() # Scale and convert to uint8. img = img[0] if img_normalize: img = img / img.norm(float('inf'), dim=[1,2], keepdim=True).clip(1e-8, 1e8) img = img * (10 ** (img_scale_db / 20)) img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0) if to_pil: from PIL import Image img = img.cpu().numpy() img = Image.fromarray(img) res.image = img res.w = ws.detach().cpu().numpy() #----------------------------------------------------------------------------
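#----------------------------------------------------------------------------
# A minimal, self-contained sketch (not part of the repository) of the two core
# operations in Renderer._render_drag_impl above: motion supervision and point
# tracking, applied to a toy feature map so the logic can be read in isolation.
# Function names, the toy tensors, and the unscaled pixel radii (r1=3, r2=12)
# are illustrative assumptions; the real code also rescales radii by h/512 and
# back-propagates into the latent w rather than into the feature map itself.

import torch
import torch.nn.functional as F

def motion_supervision_loss(feat, point, target, r1=3):
    """L1 loss that pulls features around `point` one unit toward `target`.

    feat: (1, C, H, W) feature map with requires_grad=True.
    point/target: (y, x) pixel coordinates.
    """
    _, _, h, w = feat.shape
    # Unit direction (dx, dy) from the handle point toward the target.
    direction = torch.tensor(
        [target[1] - point[1], target[0] - point[0]], dtype=torch.float32)
    direction = direction / (direction.norm() + 1e-7)
    # Select all pixels within radius r1 of the handle point.
    yy, xx = torch.meshgrid(
        torch.arange(h, dtype=torch.float32),
        torch.arange(w, dtype=torch.float32), indexing='ij')
    dist = ((yy - point[0])**2 + (xx - point[1])**2).sqrt()
    relis, reljs = torch.where(dist < r1)
    # Sample the feature map one unit step along `direction` via grid_sample;
    # grid coordinates are normalized to [-1, 1] in (x, y) order.
    gridw = (reljs + direction[0]) / (w - 1) * 2 - 1
    gridh = (relis + direction[1]) / (h - 1) * 2 - 1
    grid = torch.stack([gridw, gridh], dim=-1).view(1, 1, -1, 2)
    shifted = F.grid_sample(feat, grid, align_corners=True).squeeze(2)
    # The source patch is detached: only the shifted sample carries gradient,
    # which is what nudges the content toward the target.
    return F.l1_loss(feat[:, :, relis, reljs].detach(), shifted)

def track_point(feat, feat_ref, point, r2=12):
    """Relocate `point` to the nearest feature match within an r2 patch."""
    _, c, h, w = feat.shape
    up, down = max(point[0] - r2, 0), min(point[0] + r2 + 1, h)
    left, right = max(point[1] - r2, 0), min(point[1] + r2 + 1, w)
    patch = feat[:, :, up:down, left:right]
    # L2 distance between the reference feature and every pixel in the patch.
    l2 = torch.linalg.norm(patch - feat_ref.view(1, c, 1, 1), dim=1)
    idx = torch.argmin(l2.view(-1))
    width = right - left
    return [idx.item() // width + up, idx.item() % width + left]

if __name__ == '__main__':
    feat = torch.randn(1, 8, 64, 64, requires_grad=True)
    point, target = [32, 32], [32, 40]
    loss = motion_supervision_loss(feat, point, target)
    loss.backward()  # in DragGAN the gradient flows into w, not into feat
    with torch.no_grad():
        new_point = track_point(feat.detach(),
                                feat[0, :, 32, 32].detach(), point)
    print(loss.item(), new_point)

# Design note: detaching the source patch while letting the shifted sample
# carry gradient is what makes the optimizer move image content rather than
# simply blurring features, and tracking re-anchors each handle point after
# every latent update so the next supervision step starts from the new
# location.
#----------------------------------------------------------------------------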