Repository: YotamNitzan/ID-disentanglement
Branch: main
Commit: 342a782364fd
Files: 32
Total size: 129.6 KB

Directory structure:
ID-disentanglement/
├── .gitignore
├── LICENSE
├── README.md
├── arglib/
│   ├── __init__.py
│   └── arglib.py
├── data_loader/
│   ├── __init__.py
│   └── data_loader.py
├── docs/
│   ├── index.html
│   ├── mainpage.css
│   └── setup.md
├── environment.yml
├── inference.py
├── main.py
├── model/
│   ├── __init__.py
│   ├── arcface/
│   │   ├── arcface.py
│   │   ├── inference.py
│   │   └── resnet.py
│   ├── attr_encoder.py
│   ├── discriminator.py
│   ├── face_detector.py
│   ├── generator.py
│   ├── id_encoder.py
│   ├── landmarks.py
│   ├── latent_mapping.py
│   ├── model.py
│   └── stylegan.py
├── test.py
├── trainer.py
├── utils/
│   ├── __init__.py
│   ├── general_utils.py
│   └── generate_fake_data.py
└── writer.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
.idea
.png
.jpg

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2020 YotamNitzan

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

================================================
FILE: README.md
================================================
# Face Identity Disentanglement via Latent Space Mapping

## Description

Official implementation of the paper *Face Identity Disentanglement via Latent Space Mapping*, for both training and evaluation.

> **Face Identity Disentanglement via Latent Space Mapping**
> Yotam Nitzan<sup>1</sup>, Amit Bermano<sup>1</sup>, Yangyan Li<sup>2</sup>, Daniel Cohen-Or<sup>1</sup>
> <sup>1</sup>Tel-Aviv University, <sup>2</sup>Alibaba
> https://arxiv.org/abs/2005.07728

> **Abstract:** Learning disentangled representations of data is a fundamental problem in artificial intelligence. Specifically, disentangled latent representations allow generative models to control and compose the disentangled factors in the synthesis process. Current methods, however, require extensive supervision and training, or instead, noticeably compromise quality. In this paper, we present a method that learns how to represent data in a disentangled way, with minimal supervision, manifested solely using available pre-trained networks. Our key insight is to decouple the processes of disentanglement and synthesis, by employing a leading pre-trained unconditional image generator, such as StyleGAN. By learning to map into its latent space, we leverage both its state-of-the-art quality, and its rich and expressive latent space, without the burden of training it. We demonstrate our approach on the complex and high dimensional domain of human heads. We evaluate our method qualitatively and quantitatively, and exhibit its success with de-identification operations and with temporal identity coherency in image sequences. Through extensive experimentation, we show that our method successfully disentangles identity from other facial attributes, surpassing existing methods, even though they require more training and supervision.

## Setup

To set up everything you need, check out the [setup instructions](docs/setup.md).

## Training

### Preparing the Dataset

The dataset is comprised of StyleGAN-generated images and W latent codes, both generated from a single StyleGAN model. We also use real images from FFHQ to evaluate quality at test time.

The dataset is assumed to be in the following structure:

| Path | Description |
| :--- | :--- |
| base directory | Directory for all datasets |
| ├  real | FFHQ image dataset |
| ├  dataset_N | dataset for resolution NxN |
| │  ├  images | images generated by StyleGAN |
| │  └  ws | W latent codes generated by StyleGAN |

To generate the `dataset_N` directory, run:

```
cd utils
python generate_fake_data.py \
    --resolution N \
    --batch_size BATCH_SIZE \
    --output_path OUTPUT_PATH \
    --pretrained_models_path PRETRAINED_MODELS_PATH \
    --num_images NUM_IMAGES \
    --gpu GPU
```

It will generate an image dataset in a similar format to FFHQ.

### Start training

To train the model as done in the paper:

```
python main.py NAME \
    --resolution N \
    --pretrained_models_path PRETRAINED_MODELS_PATH \
    --dataset_path BASE_DATASET_DIR \
    --batch_size BATCH_SIZE \
    --cross_frequency 3 \
    --train_data_size 70000 \
    --results_dir RESULTS_DIR
```

Please run `python main.py -h` for more details.
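Before launching a run, it can help to sanity-check the generated data. The loader in `data_loader/data_loader.py` expects images and latents to share the same zero-padded index, bucketed into subdirectories of 1,000 samples. A minimal sketch of that contract (the base path and index below are placeholders):

```python
from pathlib import Path

import numpy as np
from PIL import Image

base = Path('/path/to/datasets')  # placeholder: the base directory above
ind = 1234                        # placeholder: any sample index

# Samples are bucketed by thousands, so 01234 lives under 01000/.
bucket = f'{int(ind - ind % 1e3):05d}'
img = Image.open(base / 'dataset_256' / 'images' / bucket / f'{ind:05d}.png')
w = np.load(base / 'dataset_256' / 'ws' / bucket / f'{ind:05d}.npy')

# As in the loader, keep a single row of the latent while preserving its rank.
w = w[np.newaxis, 0]
print(img.size, w.shape)
```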

## Inference

For convenience, there are a few inference functions, each serving a different use case. The desired function is selected by passing its name via `--test_func`.

### All possible combinations in dirs

**Input data: Two directories, one for identity inputs and another for attribute inputs.**
Runs over all N*M combinations of images in the two directories.

```
python test.py NAME --pretrained_models_path PRETRAINED_MODELS_PATH \
                    --load_checkpoint PATH_TO_WEIGHTS \
                    --id_dir DIR_OF_IMAGES_FOR_ID \
                    --attr_dir DIR_OF_IMAGES_FOR_ATTR \
                    --output_dir DIR_FOR_OUTPUTS \
                    --test_func infer_on_dirs
```
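For orientation, `infer_on_dirs` (see `inference.py`) groups its outputs by attribute image, roughly like this (file names follow the code; counts depend on your inputs):

```
DIR_FOR_OUTPUTS/
├── attr_0/
│   ├── attr_image.png      # the attribute input
│   ├── id_0.png            # identity input 0
│   ├── prediction_0.png    # identity 0 composed with attribute 0
│   ├── id_1.png
│   └── prediction_1.png
└── attr_1/
    └── ...
```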
### Paired data

**Input data: Two directories, one for identity inputs and another for attribute inputs.**
The two directories are assumed to be paired. Inference runs on images with the same names.

```
python test.py NAME --pretrained_models_path PRETRAINED_MODELS_PATH \
                    --load_checkpoint PATH_TO_WEIGHTS \
                    --id_dir DIR_OF_IMAGES_FOR_ID \
                    --attr_dir DIR_OF_IMAGES_FOR_ATTR \
                    --output_dir DIR_FOR_OUTPUTS \
                    --test_func infer_pairs
```

### Disentangled interpolation

#### Interpolating attributes

#### Interpolating identity

**Input data: A directory with any number of subdirectories; each subdirectory contains three images.**
All images should have exactly one of *attr* or *id* in their name. If there are two *attr* images and one *id* image, the attributes are interpolated. If there are two *id* images and one *attr* image, the identity is interpolated.

```
python test.py NAME --pretrained_models_path PRETRAINED_MODELS_PATH \
                    --load_checkpoint PATH_TO_WEIGHTS \
                    --input_dir PARENT_DIR \
                    --output_dir DIR_FOR_OUTPUTS \
                    --test_func interpolate
```

## Checkpoints

Our pretrained 256x256 [checkpoint](https://drive.google.com/drive/folders/1lVizq4hCq-zTf8Q3fDqqfSnV6jIYEgY_?usp=sharing) is also available.

## Citation

If you use this code for your research, please cite our paper:

```
@article{Nitzan2020FaceID,
  title={Face identity disentanglement via latent space mapping},
  author={Yotam Nitzan and A. Bermano and Yangyan Li and D. Cohen-Or},
  journal={ACM Transactions on Graphics (TOG)},
  year={2020},
  volume={39},
  pages={1 - 14}
}
```

================================================
FILE: arglib/__init__.py
================================================

================================================
FILE: arglib/arglib.py
================================================
import math import shutil import logging import argparse from pathlib import Path from abc import ABC, abstractmethod class BaseArgs(ABC): def __init__(self): self.args = None self.parser = argparse.ArgumentParser() self.logger = logging.getLogger(self.__class__.__name__) self.add_args() self.parse() self.validate() self.process() self.str_args = self.log() @abstractmethod def add_args(self): # Hardware self.parser.add_argument('--gpu', type=str, default='0') # Model self.parser.add_argument('--face_detection', action='store_true') self.parser.add_argument('--resolution', type=int, default=256, choices=[256, 1024]) self.parser.add_argument('--load_checkpoint') self.parser.add_argument('--pretrained_models_path', type=Path, required=True) BaseArgs.add_bool_arg(self.parser, 'const_noise') # Data self.parser.add_argument('--batch_size', type=int, default=6) self.parser.add_argument('--reals', action='store_true', help='Use real inputs') BaseArgs.add_bool_arg(self.parser, 'test_real_attr') # Log & Results self.parser.add_argument('name', type=str, help='Name under which run will be saved') self.parser.add_argument('--results_dir', type=str, default='../results') self.parser.add_argument('--log_debug', action='store_true') # Other self.parser.add_argument('--debug', action='store_true') def parse(self): self.args = self.parser.parse_args() def log(self): out_str = 'The arguments are:\n' for k, v in self.args.__dict__.items(): out_str += f'{k}: {v}\n' return out_str @staticmethod def add_bool_arg(parser, name, default=True): group = parser.add_mutually_exclusive_group(required=False) group.add_argument('--' + name, dest=name, action='store_true') group.add_argument('--no_' + name, dest=name, action='store_false') parser.set_defaults(**{name: default}) @abstractmethod def validate(self): if self.args.load_checkpoint and not Path(self.args.load_checkpoint).exists(): raise ValueError(f'Checkpoint directory {self.args.load_checkpoint} does not exist') @abstractmethod def process(self): # Log & Results self.args.results_dir = Path(self.args.results_dir).joinpath(self.args.name) if self.args.debug: self.args.log_debug = True if self.args.debug or not self.args.train: shutil.rmtree(self.args.results_dir, ignore_errors=True)
self.args.results_dir.mkdir(parents=True, exist_ok=True) self.args.images_results = self.args.results_dir.joinpath('images') self.args.images_results.mkdir(exist_ok=True) # Model if self.args.load_checkpoint: self.args.load_checkpoint = Path(self.args.load_checkpoint) class TrainArgs(BaseArgs): def __init__(self): super().__init__() def add_args(self): super().add_args() self.parser.add_argument('--dataset_path', type=str, default='../my_dataset') self.parser.add_argument('--num_epochs', type=int, default=math.inf) self.parser.add_argument('--cross_frequency', type=int, default=3, help='Once in how many epochs to perform cross-train epoch (0 for never)') self.parser.add_argument('--unified', action='store_true') # Data BaseArgs.add_bool_arg(self.parser, 'train_real_attr', default=False) self.parser.add_argument('--train_data_size', type=int, default=70000, help='How many images to use for training. Others are used as validation') # Losses BaseArgs.add_bool_arg(self.parser, 'id_loss') BaseArgs.add_bool_arg(self.parser, 'landmarks_loss') BaseArgs.add_bool_arg(self.parser, 'pixel_loss') BaseArgs.add_bool_arg(self.parser, 'W_D_loss') BaseArgs.add_bool_arg(self.parser, 'gp') self.parser.add_argument('--pixel_mask_type', choices=['uniform', 'gaussian'], default='gaussian') self.parser.add_argument('--pixel_loss_type', choices=['L1', 'mix'], default='mix') # Test During training self.parser.add_argument('--test_frequency', type=int, default=1000, help='Once in how many epochs to perform a test') self.parser.add_argument('--test_size', type=int, default=50, help='How many mini-batches should be used for a test') self.parser.add_argument('--not_improved_exit', type=int, default=math.inf, help='After how many not-improved test to exit') BaseArgs.add_bool_arg(self.parser, 'test_with_arcface') def validate(self): super().validate() if not Path(self.args.dataset_path).exists(): raise ValueError(f'Dataset at path: {self.args.dataset_path} does not exist') def process(self): self.args.train = True super().process() # Dataset self.args.dataset_path = Path(self.args.dataset_path) self.args.weights_dir = self.args.results_dir.joinpath('weights') self.args.weights_dir.mkdir(exist_ok=True) backup_code_dir = self.args.results_dir.joinpath('code') code_dir = Path().cwd() shutil.copytree(code_dir, backup_code_dir) class TestArgs(BaseArgs): def __init__(self): super().__init__() def add_args(self): super().add_args() self.parser.set_defaults(batch_size=1) self.parser.add_argument('--id_dir', type=Path) self.parser.add_argument('--attr_dir', type=Path) self.parser.add_argument('--output_dir', type=Path) self.parser.add_argument('--input_dir', type=Path) self.parser.add_argument('--real_id', action='store_true') self.parser.add_argument('--real_attr', action='store_true') BaseArgs.add_bool_arg(self.parser, 'loop_fake') self.parser.add_argument('--img_suffixes', type=list, default=['png', 'jpg', 'jpeg']) self.parser.add_argument('--test_func', type=str, choices=['infer_on_dirs', 'infer_pairs', 'interpolate']) self.parser.add_argument('--input', type=str) def validate(self): super().validate() # if not self.args.input: # raise ValueError('Input needed for inference') # if not Path(self.args.input).exists(): # raise ValueError(f'Input {self.args.input} does not exist') def process(self): self.args.train = False super().process() self.args.output_dir.mkdir(exist_ok=True, parents=True) # self.args.input = Path(self.args.input) # Split frame sit alongside input, so not every run needs to preprocess # input_name = 
self.args.input.stem # self.args.extracted_frames_dir = self.args.input.parent.joinpath(f'{input_name}_frames') # self.args.extracted_frames_dir.mkdir(exist_ok=True) ================================================ FILE: data_loader/__init__.py ================================================ ================================================ FILE: data_loader/data_loader.py ================================================ import logging import numpy as np import tensorflow as tf from utils.general_utils import read_image class DataLoader(object): def __init__(self, args): super().__init__() self.args = args self.logger = logging.getLogger(self.__class__.__name__) self.real_dataset = args.dataset_path.joinpath(f'real') dataset = args.dataset_path.joinpath(f'dataset_{args.resolution}') self.ws_dataset = dataset.joinpath('ws') self.image_dataset = dataset.joinpath('images') max_dir = max([x.name for x in self.image_dataset.iterdir()]) self.max_ind = max([int(x.stem) for x in self.image_dataset.joinpath(max_dir).iterdir()]) self.train_max_ind = args.train_data_size if self.train_max_ind >= self.max_ind: self.logger.warning('There is no validation data... using training data') self.min_val_ind = 0 self.train_max_ind = self.max_ind else: self.min_val_ind = self.train_max_ind + 1 def get_image(self, is_train, black_list=None, is_real=False): # Default should be non-mutable if black_list is None: black_list = [] max_fails = 10 curr_fail = 0 if is_train: min_ind, max_ind = 0, self.train_max_ind else: min_ind, max_ind = self.min_val_ind, self.max_ind while True: ind = np.random.randint(min_ind, max_ind) if ind in black_list: continue img_name = f'{ind:05d}.png' dir_name = f'{int(ind - ind % 1e3):05d}' if is_real: img_path = self.real_dataset.joinpath(dir_name, img_name) else: img_path = self.image_dataset.joinpath(dir_name, img_name) try: img = read_image(img_path, self.args.resolution) break except Exception as e: self.logger.warning(f'Failed reading image at {ind}. Error: {e}') # Try again with a different image... 
curr_fail += 1 if curr_fail > max_fails: raise IOError('Failed reading multiples images') continue return ind, img def get_w_by_ind(self, ind): dir_name = f'{int(ind - ind % 1e3):05d}' img_name = f'{ind:05d}.npy' w_path = self.ws_dataset.joinpath(dir_name, img_name) w = np.load(w_path) # Take one row while keeping dimension w = w[np.newaxis, 0] return w def get_real_w(self, is_train, black_list=None, is_real=False): ind = np.random.randint(0, self.max_ind) w = self.get_w_by_ind(ind) return ind, w def batch_samples(self, get_sample_func, is_train, black_list=None, is_real=False): batch = [] indices = [] if not black_list: black_list = [] for i in range(self.args.batch_size): ind, sample = get_sample_func(is_train, black_list, is_real) batch.append(sample) indices.append(ind) batch = tf.concat(batch, 0) return indices, batch def get_batch(self, is_train=True, is_cross=False, ws=True): black_list = [] id_imgs_indices, id_img = self.batch_samples(self.get_image, is_train) matching_ws = None self.logger.debug(f'ID images read: {id_imgs_indices}') black_list.extend(id_imgs_indices) if is_cross: # Use real attr when args say so or when testing is_real_attr = (is_train and self.args.train_real_attr) or (not is_train and self.args.test_real_attr) black_list = [] if is_real_attr else black_list attr_imgs_indices, attr_img = self.batch_samples(self.get_image, is_train, black_list=black_list, is_real=is_real_attr) self.logger.debug(f'Attr images read: {attr_imgs_indices}') else: if is_train: attr_img = id_img matching_ws = [self.get_w_by_ind(ind) for ind in id_imgs_indices] matching_ws = tf.concat(matching_ws, 0) else: attr_img = id_img if not is_train: return attr_img, id_img # Only for training real_img = None real_ws = None if self.args.train and self.args.reals: real_imgs_indices, real_img = self.batch_samples(self.get_image, is_train, black_list=[], is_real=True) self.logger.debug(f'Real images read: {real_imgs_indices}') if ws: _, real_ws = self.batch_samples(self.get_real_w, is_train) return attr_img, id_img, real_ws, real_img, matching_ws ================================================ FILE: docs/index.html ================================================ ID disentanglement

Face Identity Disentanglement via Latent Space Mapping

SIGGRAPH ASIA 2020

Yotam Nitzan¹   Amit Bermano¹   Yangyan Li²   Daniel Cohen-Or¹
¹Tel-Aviv University     ²Alibaba Cloud Intelligence Business Group

Paper

Code

Abstract

Learning disentangled representations of data is a fundamental problem in artificial intelligence. Specifically, disentangled latent representations allow generative models to control and compose the disentangled factors in the synthesis process. Current methods, however, require extensive supervision and training, or instead, noticeably compromise quality. In this paper, we present a method that learns how to represent data in a disentangled way, with minimal supervision, manifested solely using available pre-trained networks. Our key insight is to decouple the processes of disentanglement and synthesis, by employing a leading pre-trained unconditional image generator, such as StyleGAN. By learning to map into its latent space, we leverage both its state-of-the-art quality, and its rich and expressive latent space, without the burden of training it. We demonstrate our approach on the complex and high dimensional domain of human heads. We evaluate our method qualitatively and quantitatively, and exhibit its success with de-identification operations and with temporal identity coherency in image sequences. Through extensive experimentation, we show that our method successfully disentangles identity from other facial attributes, surpassing existing methods, even though they require more training and supervision.

Motivation

Learning disentangled representations and image synthesis are different tasks; however, it is common practice to solve both simultaneously. This way, the image generator learns the semantics of the representations and can take multiple representations from different sources and mix them to generate novel images. But this comes at a price: one now needs to solve two difficult tasks at once, which often requires devising dedicated architectures and, even then, yields sub-optimal visual quality.

We propose a different approach. Unconditional generators have recently achieved amazing image quality, and we take advantage of this fact to avoid solving the synthesis task ourselves. Instead, we suggest using a pretrained generator, such as StyleGAN. But how can a generator that is pretrained and unconditional make sense of the disentangled representations?
We suggest mapping the disentangled representations directly into the latent space of the generator. In a single feed-forward pass, the mapping produces a new, never-before-seen latent code that corresponds to a novel image.
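In code (following `model/generator.py` in this repository), the composition is just a concatenation of the two embeddings, a small mapping network, and StyleGAN synthesis. A toy sketch of the data flow, with untrained stand-ins for the pretrained networks:

```python
import tensorflow as tf

# Untrained stand-ins, only to show shapes and data flow; the real encoders
# (VGGFace2, InceptionV3) and the StyleGAN generator are loaded from disk.
id_encoder = tf.keras.layers.Dense(512)     # identity embedding
attr_encoder = tf.keras.layers.Dense(2048)  # attribute embedding
mapping = tf.keras.layers.Dense(512)        # maps into StyleGAN's W space

id_input = tf.random.uniform([1, 1024])     # placeholder image features
attr_input = tf.random.uniform([1, 1024])

# z_tag is 2560-dim (512 id + 2048 attr), matching latent_mapping.py's input.
z_tag = tf.concat([id_encoder(id_input), attr_encoder(attr_input)], -1)
w = mapping(z_tag)  # a new, never-before-seen W latent
# stylegan_synthesis(w) would now decode w into the composed image
```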

Composition Results

We demonstrate our method on the domain of human faces - specifically disentangling identity from all other attributes.
In the following tables, the identity is taken from the image on top and the attributes are taken from the leftmost image. In this figure, the inputs themselves are StyleGAN-generated images.
More results, but this time, the input images are real.

Disentangled Interpolation

Thanks to our disentangled representations, we are able to interpolate only a single feature (identity or attributes) in the generator's latent space. This enables more control and opens the door for new disentangled editing capabilities.
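Concretely, the `interpolate` function in `inference.py` blends linearly in W space between two codes that differ in only one factor. A minimal sketch with random stand-in latents:

```python
import numpy as np

s_w = np.random.randn(1, 512)  # stand-in: W code of the start image
e_w = np.random.randn(1, 512)  # stand-in: W code of the end image
num_jumps = 9

for i in range(num_jumps):
    t = i / num_jumps
    inter_w = (1 - t) * s_w + t * e_w  # fed to StyleGAN synthesis in the repo
```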

Contact

yotamnitzan at gmail dot com
================================================
FILE: docs/mainpage.css
================================================
body { font-family: 'Lato', sans-serif; font-weight: 300; color: #333; font-size: 16px; }
h1 { font-size: 40px; color: #555; font-weight: 400; text-align: center; margin: 0; padding: 0; margin-top: 30px; margin-bottom: 10px; }
.authors { color: #222; font-size: 24px; font-weight: 300; text-align: center; margin: 0; padding: 0; margin-bottom: 0px; }
.logoimg { text-align: center; margin-bottom: 30px; }
.container-fluid { margin-top: 5px; margin-bottom: 5px; }
.container { margin-top: 10px; }
#footer { margin-bottom: 100px; }
.thumbs { -webkit-box-shadow: 1px 1px 3px #999; -moz-box-shadow: 1px 1px 3px #999; box-shadow: 1px 1px 3px #999; margin-bottom: 20px; }
h2 { font-size: 24px; font-weight: 900; border-bottom: 1px solid #999; margin-bottom: 20px; }
.space { margin-bottom: 1.5cm; }
.text-primary { color: #5da2d5 !important; }
.text-primary:hover { color: #f3d250 !important; opacity: 1.0; }

================================================
FILE: docs/setup.md
================================================
# Setup

## Environment

The code is designed for TensorFlow 2.x on Python 3.7, with CUDA 10.1 and cuDNN 7.6.5. Run `conda env create -f environment.yml` to create a conda environment with the needed dependencies.

Tested with TensorFlow 2.0.0, Python 3.7.9, Ubuntu 14.04.

## Third-party pretrained networks

Our method relies on several pretrained networks. Some are needed only for training and some also for inference; download according to your intention. Put all downloaded files/directories under a single directory, which will be the base path for all pretrained networks.

| Name | Training | Inference | Description |
| :--- | :----------: | :----------: | :---------- |
| [FFHQ StyleGAN 256x256](https://drive.google.com/drive/folders/1OgLvUhd9FX9_mPXrfqAWaLZsceQzE9l4?usp=sharing) | :heavy_check_mark: | :heavy_check_mark: | StyleGAN model pretrained on FFHQ at 256x256 resolution. Converted using [StyleGAN-Tensorflow2](https://github.com/YotamNitzan/StyleGAN-Tensorflow2). |
| [FFHQ StyleGAN 1024x1024](https://drive.google.com/drive/folders/1jQxJsmapu6SjygvJfvP4-YVxZ9f5Hu_N?usp=sharing) | :heavy_check_mark: | :heavy_check_mark: | StyleGAN model pretrained on FFHQ at 1024x1024 resolution. Converted using [StyleGAN-Tensorflow2](https://github.com/YotamNitzan/StyleGAN-Tensorflow2). |
| [VGGFace2](https://drive.google.com/file/d/1I_JyR7LH-30hEIpD4OSFVg2TOf9Q8cqU/view?usp=sharing) | :heavy_check_mark: | :heavy_check_mark: | Pretrained VGGFace2 model taken from [WeidiXie](https://github.com/WeidiXie/Keras-VGGFace2-ResNet50). |
| [dlib landmarks model](http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2) | | :heavy_check_mark: | dlib landmarks model, used to align images. |
| [ArcFace](https://drive.google.com/drive/folders/1F-Ll9Nw7I1FGP61cpQxOdhs2nxi0E5mg?usp=sharing) | :heavy_check_mark: | | Pretrained ArcFace model taken from [dmonterom](https://github.com/dmonterom/face_recognition_TF2). |
| [Face & Landmarks Detection](https://drive.google.com/drive/folders/1D__J9UMwzBNR9eVrQGYuL9ueYGi7G4qh?usp=sharing) | :heavy_check_mark: | | Pretrained face detection and differentiable facial landmarks detection from [610265158](https://github.com/610265158/face_landmark). |
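Going by the paths resolved in `main.py`, the base directory is expected to look roughly like this for the 256x256 setting (the exact checkpoint file names inside `arcface_weights` depend on the download):

```
PRETRAINED_MODELS_PATH/
├── vggface2.h5
├── stylegan_G_256x256_synthesis/           # converted StyleGAN weights
├── arcface_weights/
│   └── weights-b...                        # ArcFace checkpoint prefix
├── face_utils/
│   ├── keypoints/                          # landmarks SavedModel
│   └── detector/                           # face detector SavedModel
└── shape_predictor_68_face_landmarks.dat
```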
### Other StyleGANs To try out our method with other checkpoints of StyleGAN, first obtain a trained StyleGAN pkl file using the [original StyleGAN repository](https://github.com/NVlabs/stylegan) Next, convert it to Tensorflow-2.0 using this [repository](https://github.com/YotamNitzan/StyleGAN-Tensorflow2). ================================================ FILE: environment.yml ================================================ name: id_disen channels: - conda-forge - defaults dependencies: - _libgcc_mutex=0.1=main - _tflow_select=2.1.0=gpu - absl-py=0.11.0=py37h06a4308_0 - aiohttp=3.6.3=py37h7b6447c_0 - astor=0.8.1=py37_0 - async-timeout=3.0.1=py37_0 - attrs=20.3.0=pyhd3eb1b0_0 - blas=1.0=mkl - blinker=1.4=py37_0 - blosc=1.20.1=he1b5a44_0 - brotli=1.0.9=he1b5a44_3 - brotlipy=0.7.0=py37h27cfd23_1003 - bzip2=1.0.8=h516909a_3 - c-ares=1.17.1=h27cfd23_0 - ca-certificates=2020.11.8=ha878542_0 - cachetools=4.1.1=py_0 - certifi=2020.11.8=py37h89c1867_0 - cffi=1.14.3=py37h261ae71_2 - chardet=3.0.4=py37h06a4308_1003 - charls=2.1.0=he1b5a44_2 - click=7.1.2=py_0 - cloudpickle=1.6.0=py_0 - cryptography=3.2.1=py37h3c74f83_1 - cudatoolkit=10.0.130=0 - cudnn=7.6.5=cuda10.0_0 - cupti=10.0.130=0 - cycler=0.10.0=py_2 - cytoolz=0.11.0=py37h8f50634_1 - dask-core=2.30.0=py_0 - decorator=4.4.2=py_0 - freetype=2.10.4=h7ca028e_0 - gast=0.2.2=py37_0 - giflib=5.2.1=h36c2ea0_2 - google-auth=1.23.0=pyhd3eb1b0_0 - google-auth-oauthlib=0.4.2=pyhd3eb1b0_2 - google-pasta=0.2.0=py_0 - grpcio=1.31.0=py37hf8bcb03_0 - h5py=2.10.0=py37hd6299e0_1 - hdf5=1.10.6=hb1b8bf9_0 - idna=2.10=py_0 - imagecodecs=2020.5.30=py37hda6ee5b_1 - imageio=2.9.0=py_0 - importlib-metadata=2.0.0=py_1 - intel-openmp=2020.2=254 - jpeg=9d=h36c2ea0_0 - jxrlib=1.1=h516909a_2 - keras-applications=1.0.8=py_1 - keras-preprocessing=1.1.0=py_1 - kiwisolver=1.3.1=py37hc928c03_0 - lcms2=2.11=hcbb858e_1 - ld_impl_linux-64=2.33.1=h53a641e_7 - libaec=1.0.4=he1b5a44_1 - libedit=3.1.20191231=h14c3975_1 - libffi=3.3=he6710b0_2 - libgcc-ng=9.1.0=hdf63c60_0 - libgfortran-ng=7.3.0=hdf63c60_0 - libpng=1.6.37=h21135ba_2 - libprotobuf=3.13.0.1=hd408876_0 - libstdcxx-ng=9.1.0=hdf63c60_0 - libtiff=4.1.0=h4f3a223_6 - libwebp-base=1.1.0=h36c2ea0_3 - libzopfli=1.0.3=he1b5a44_0 - lz4-c=1.9.2=he1b5a44_3 - markdown=3.3.3=py37h06a4308_0 - matplotlib-base=3.3.3=py37h4f6019d_0 - mkl=2020.2=256 - mkl-service=2.3.0=py37he904b0f_0 - mkl_fft=1.2.0=py37h23d657b_0 - mkl_random=1.1.1=py37h0573a6f_0 - multidict=4.7.6=py37h7b6447c_1 - ncurses=6.2=he6710b0_1 - networkx=2.5=py_0 - numpy=1.19.2=py37h54aff64_0 - numpy-base=1.19.2=py37hfa32c7d_0 - oauthlib=3.1.0=py_0 - olefile=0.46=pyh9f0ad1d_1 - openjpeg=2.3.1=h981e76c_3 - openssl=1.1.1h=h516909a_0 - opt_einsum=3.1.0=py_0 - pillow=8.0.1=py37h63a5d19_0 - pip=20.2.4=py37h06a4308_0 - protobuf=3.13.0.1=py37he6710b0_1 - pyasn1=0.4.8=py_0 - pyasn1-modules=0.2.8=py_0 - pycparser=2.20=py_2 - pyjwt=1.7.1=py37_0 - pyopenssl=19.1.0=pyhd3eb1b0_1 - pyparsing=2.4.7=pyh9f0ad1d_0 - pysocks=1.7.1=py37_1 - python=3.7.9=h7579374_0 - python-dateutil=2.8.1=py_0 - python_abi=3.7=1_cp37m - pywavelets=1.1.1=py37h161383b_3 - readline=8.0=h7b6447c_0 - requests=2.24.0=py_0 - requests-oauthlib=1.3.0=py_0 - rsa=4.6=py_0 - scikit-image=0.17.2=py37h10a2094_4 - scipy=1.5.2=py37h0b6359f_0 - setuptools=50.3.1=py37h06a4308_1 - six=1.15.0=py37h06a4308_0 - snappy=1.1.8=he1b5a44_3 - sqlite=3.33.0=h62c20be_0 - tensorboard-plugin-wit=1.6.0=py_0 - tensorflow=2.0.0=gpu_py37h768510d_0 - tensorflow-base=2.0.0=gpu_py37h0ec5d1f_0 - tensorflow-estimator=2.0.0=pyh2649769_0 - 
termcolor=1.1.0=py37_1 - tifffile=2020.11.18=pyhd8ed1ab_0 - tk=8.6.10=hbc83047_0 - toolz=0.11.1=py_0 - tornado=6.1=py37h4abf009_0 - urllib3=1.25.11=py_0 - werkzeug=0.16.1=py_0 - wheel=0.35.1=pyhd3eb1b0_0 - wrapt=1.12.1=py37h7b6447c_1 - xz=5.2.5=h7b6447c_0 - yaml=0.2.5=h516909a_0 - yarl=1.6.2=py37h7b6447c_0 - zipp=3.4.0=pyhd3eb1b0_0 - zlib=1.2.11=h7b6447c_3 - zstd=1.4.5=h6597ccf_2 - pip: - dlib==19.21.0 - keras==2.3.1 - mtcnn==0.1.0 - opencv-python==4.4.0.46 - pyyaml==5.3.1 - tensorboard==2.0.2 - tensorflow-addons==0.6.0 - tensorflow-gpu==2.0.0 - tqdm==4.53.0 ================================================ FILE: inference.py ================================================ from pathlib import Path from tqdm import tqdm import tensorflow as tf from writer import Writer from utils import general_utils as utils class Inference(object): def __init__(self, args, model): self.args = args self.G = model.G def infer_pairs(self): names = [f for f in self.args.id_dir.iterdir() if f.suffix[1:] in self.args.img_suffixes] names.extend([f for f in self.args.attr_dir.iterdir() if f.suffix[1:] in self.args.img_suffixes]) for img_name in tqdm(names): id_path = utils.find_file_by_str(self.args.id_dir, img_name.stem) attr_path = utils.find_file_by_str(self.args.attr_dir, img_name.stem) if len(id_path) != 1 or len(attr_path) != 1: print(f'Could not find a single pair with name: {img_name.stem}') continue id_img = utils.read_image(id_path, self.args.resolution, self.args.reals) attr_img = utils.read_image(attr_path, self.args.resolution, self.args.reals) out_img = self.G(id_img, attr_img)[0] utils.save_image(out_img, self.args.output_dir.joinpath(f'{img_name.name}')) def infer_on_dirs(self): attr_paths = list(self.args.attr_dir.iterdir()) attr_paths.sort() id_paths = list(self.args.id_dir.iterdir()) id_paths.sort() for attr_num, attr_img_path in tqdm(enumerate(attr_paths)): if not attr_img_path.is_file() or attr_img_path.suffix[1:] not in self.args.img_suffixes: continue attr_img = utils.read_image(attr_img_path, self.args.resolution, self.args.reals) attr_dir = self.args.output_dir.joinpath(f'attr_{attr_num}') attr_dir.mkdir(exist_ok=True) utils.save_image(attr_img, attr_dir.joinpath(f'attr_image.png')) for id_num, id_img_path in enumerate(id_paths): if not id_img_path.is_file() or id_img_path.suffix[1:] not in self.args.img_suffixes: continue id_img = utils.read_image(id_img_path, self.args.resolution, self.args.reals) pred = self.G(id_img, attr_img)[0] utils.save_image(pred, attr_dir.joinpath(f'prediction_{id_num}.png')) utils.save_image(id_img, attr_dir.joinpath(f'id_{id_num}.png')) def interpolate(self, w_space=True): # Change to 0,1 for interpolation extra_start = 0 extra_end = 1 L = extra_end - extra_start # Extrapolation values include the 0,1 iff # N-1 is divisible by L if including endpoint # N is divisble by L o.w # where L is the length of the extrapolation range ( L = b-a for [a,b] ) # and N is number of jumps num_jumps = 8 * L + 1 for d in self.args.input_dir.iterdir(): out_d = self.args.output_dir.joinpath(d.name) out_d.mkdir(exist_ok=True) ids = list(d.glob('*id*')) attrs = list(d.glob('*attr*')) if len(ids) == 1 and len(attrs) == 2: const = 'id' elif len(ids) == 2 and len(attrs) == 1: const = 'attr' else: print(f'Wrong data format for {d.name}') continue if const == 'id': start_img = utils.read_image(attrs[0], self.args.resolution, self.args.real_attr) end_img = utils.read_image(attrs[1], self.args.resolution, self.args.real_attr) const_img = utils.read_image(ids[0], self.args.resolution, 
self.args.real_id) if self.args.loop_fake: if not self.args.real_attr: start_img = self.G(start_img, start_img) end_img = self.G(end_img, end_img) if not self.args.real_id: const_img = self.G(const_img, const_img) const_id = self.G.id_encoder(const_img) start_attr = self.G.attr_encoder(start_img) end_attr = self.G.attr_encoder(end_img) s_z = tf.concat([const_id, start_attr], -1) e_z = tf.concat([const_id, end_attr], -1) elif const == 'attr': start_img = utils.read_image(ids[0], self.args.resolution, self.args.real_id) end_img = utils.read_image(ids[1], self.args.resolution, self.args.real_id) const_img = utils.read_image(attrs[0], self.args.resolution, self.args.real_attr) if self.args.loop_fake: if not self.args.real_attr: const_img = self.G(const_img, const_img)[0] if not self.args.real_id: start_img = self.G(start_img, start_img)[0] end_img = self.G(end_img, end_img)[0] start_id = self.G.id_encoder(start_img) end_id = self.G.id_encoder(end_img) const_attr = self.G.attr_encoder(const_img) s_z = tf.concat([start_id, const_attr], -1) e_z = tf.concat([end_id, const_attr], -1) utils.save_image(const_img, out_d.joinpath(f'const_{const}.png')) utils.save_image(start_img, out_d.joinpath(f'start.png')) utils.save_image(end_img, out_d.joinpath(f'end.png')) if w_space: s_w = self.G.latent_spaces_mapping(s_z) e_w = self.G.latent_spaces_mapping(e_z) for i in range(num_jumps): inter_w = (1 - i / num_jumps) * s_w + (i / num_jumps) * e_w out = self.G.stylegan_s(inter_w) out = (out + 1) / 2 utils.save_image(out[0], out_d.joinpath(f'inter_{i:03}.png')) else: for i in range(num_jumps): inter_z = (1 - i / num_jumps) * s_z + (i / num_jumps) * e_z inter_w = self.G.latent_spaces_mapping(inter_z) out = self.G.stylegan_s(inter_w) out = (out + 1) / 2 utils.save_image(out[0], out_d.joinpath(f'inter_{i:03}.png')) ================================================ FILE: main.py ================================================ import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' os.environ['OMP_NUM_THREADS'] = '1' os.environ['USE_SIMPLE_THREADED_LEVEL3'] = '1' os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' import sys import logging from model.stylegan import StyleGAN_G_synthesis from model.model import Network from data_loader.data_loader import DataLoader from writer import Writer from trainer import Trainer from arglib import arglib from utils import general_utils as utils sys.path.insert(0, 'model/face_utils') def init_logger(args): root_logger = logging.getLogger() level = logging.DEBUG if args.log_debug else logging.INFO root_logger.setLevel(level) file_handler = logging.FileHandler(f'{args.results_dir}/log.txt') console_handler = logging.StreamHandler() datefmt = '%Y-%m-%d %H:%M:%S' formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt) file_handler.setLevel(level) console_handler.setLevel(level) file_handler.setFormatter(formatter) console_handler.setFormatter(formatter) root_logger.addHandler(file_handler) root_logger.addHandler(console_handler) pil_logger = logging.getLogger('PIL.PngImagePlugin') pil_logger.setLevel(logging.INFO) def main(): train_args = arglib.TrainArgs() args, str_args = train_args.args, train_args.str_args os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu init_logger(args) logger = logging.getLogger('main') cmd_line = ' '.join(sys.argv) logger.info(f'cmd line is: \n {cmd_line}') logger.info(str_args) logger.debug('Copying src to results dir') Writer.set_writer(args.results_dir) if not args.debug: description = input('Please write a short description of 
this run\n') desc_file = args.results_dir.joinpath('description.txt') with desc_file.open('w') as f: f.write(description) id_model_path = args.pretrained_models_path.joinpath('vggface2.h5') stylegan_G_synthesis_path = str( args.pretrained_models_path.joinpath(f'stylegan_G_{args.resolution}x{args.resolution}_synthesis')) landmarks_model_path = str(args.pretrained_models_path.joinpath('face_utils/keypoints')) face_detection_model_path = str(args.pretrained_models_path.joinpath('face_utils/detector')) arcface_model_path = str(args.pretrained_models_path.joinpath('arcface_weights/weights-b')) utils.landmarks_model_path = str(args.pretrained_models_path.joinpath('shape_predictor_68_face_landmarks.dat')) stylegan_G_synthesis = StyleGAN_G_synthesis(resolution=args.resolution, is_const_noise=args.const_noise) stylegan_G_synthesis.load_weights(stylegan_G_synthesis_path) network = Network(args, id_model_path, stylegan_G_synthesis, landmarks_model_path, face_detection_model_path, arcface_model_path) data_loader = DataLoader(args) trainer = Trainer(args, network, data_loader) trainer.train() if __name__ == '__main__': main() ================================================ FILE: model/__init__.py ================================================ ================================================ FILE: model/arcface/arcface.py ================================================ import tensorflow as tf import math num_classes = 85742 # 10572 initializer = 'glorot_normal' # initializer = tf.keras.initializers.TruncatedNormal( # mean=0.0, stddev=0.05, seed=None) # initializer = tf.keras.initializers.VarianceScaling( # scale=0.05, mode='fan_avg', distribution='normal', seed=None) class Arcfacelayer(tf.keras.layers.Layer): def __init__(self, output_dim=num_classes, s=64., m=0.50): self.output_dim = output_dim self.s = s self.m = m super(Arcfacelayer, self).__init__() def build(self, input_shape): self.kernel = self.add_weight(name='kernel', shape=(input_shape[-1], self.output_dim), initializer=initializer, regularizer=tf.keras.regularizers.l2( l=5e-4), trainable=True) super(Arcfacelayer, self).build(input_shape) def call(self, embedding, labels): cos_m = math.cos(self.m) sin_m = math.sin(self.m) mm = sin_m * self.m # issue 1 threshold = math.cos(math.pi - self.m) # inputs and weights norm embedding_norm = tf.norm(embedding, axis=1, keepdims=True) embedding = embedding / embedding_norm weights_norm = tf.norm(self.kernel, axis=0, keepdims=True) weights = self.kernel / weights_norm # cos(theta+m) cos_t = tf.matmul(embedding, weights, name='cos_t') cos_t2 = tf.square(cos_t, name='cos_2') sin_t2 = tf.subtract(1., cos_t2, name='sin_2') sin_t = tf.sqrt(sin_t2, name='sin_t') cos_mt = self.s * tf.subtract(tf.multiply(cos_t, cos_m), tf.multiply(sin_t, sin_m), name='cos_mt') # this condition controls the theta+m should in range [0, pi] # 0<=theta+m<=pi # -m<=theta<=pi-m cond_v = cos_t - threshold cond = tf.cast(tf.nn.relu(cond_v, name='if_else'), dtype=tf.bool) keep_val = self.s * (cos_t - mm) cos_mt_temp = tf.where(cond, cos_mt, keep_val) mask = tf.one_hot(labels, depth=self.output_dim, name='one_hot_mask') # mask = tf.squeeze(mask, 1) inv_mask = tf.subtract(1., mask, name='inverse_mask') s_cos_t = tf.multiply(self.s, cos_t, name='scalar_cos_t') output = tf.add(tf.multiply(s_cos_t, inv_mask), tf.multiply( cos_mt_temp, mask), name='arcface_loss_output') return output def compute_output_shape(self, input_shape): return (input_shape[0], self.output_dim) ================================================ FILE: 
model/arcface/inference.py ================================================ import tensorflow as tf import tensorflow_addons as tfa import numpy as np import cv2 from model.arcface.resnet import ResNet50, train_model from mtcnn import MTCNN from skimage import transform as trans class MyArcFace: def __init__(self, path_to_weights): self.model = train_model() self.model.load_weights(path_to_weights) self.model_resnet = self.model.resnet self.model.resnet.trainable = False self.mtcnn = MTCNN(min_face_size=80) def get_best_face(self, faces, resolution): if len(faces) == 0: raise IndexError('No faces found') if len(faces) == 1: return faces[0] print('Found more than one face') indices = list(range(len(faces))) # filter low confidence new_indices = [ind for ind in indices if faces[ind]['confidence'] > 0.99] # print(f'after confidence filtering: {len(new_indices)}') if len(new_indices) == 1: return faces[new_indices[0]] elif len(new_indices) > 1: indices = new_indices # filter not centered, distance between x and y must relatively small new_indices = [ind for ind in indices if np.abs(faces[ind]['box'][0] - faces[ind]['box'][1]) < resolution / 2.5] # print(f'after center filtering: {len(new_indices)}') if len(new_indices) == 1: return faces[new_indices[0]] elif len(new_indices) > 1: indices = new_indices # Take box with biggest height ind = max(indices, key=lambda ind: faces[ind]['box'][-1]) return faces[ind] def __detect_face(self, img): # The assumption is that the image is RGB faces = self.mtcnn.detect_faces(img) face_obj = self.get_best_face(faces, img.shape[0]) face_box_obj = face_obj['box'] face_landmarks_obj = face_obj['keypoints'] face_landmarks = np.zeros((5, 2)) face_landmarks[0] = [face_landmarks_obj['left_eye'][0], face_landmarks_obj['right_eye'][1]] face_landmarks[1] = [face_landmarks_obj['right_eye'][0], face_landmarks_obj['left_eye'][1]] face_landmarks[2] = [face_landmarks_obj['nose'][0], face_landmarks_obj['nose'][1]] face_landmarks[3] = [face_landmarks_obj['mouth_left'][0], face_landmarks_obj['mouth_right'][1]] face_landmarks[4] = [face_landmarks_obj['mouth_right'][0], face_landmarks_obj['mouth_left'][1]] x = face_box_obj[0] y = face_box_obj[1] w = face_box_obj[2] h = face_box_obj[3] face_box = [x, y, x + w, y + h] return face_box, face_landmarks def __preprocess(self, img, bbox=None, landmark=None): M = None image_size = [112, 112] assert landmark is not None src = np.array([ [30.2946, 51.6963], [65.5318, 51.5014], [48.0252, 71.7366], [33.5493, 92.3655], [62.7299, 92.2041]], dtype=np.float32) if image_size[1] == 112: src[:, 0] += 8.0 dst = landmark.astype(np.float32) tform = trans.SimilarityTransform() tform.estimate(src, dst) M = tform.params assert M is not None transforms = np.array(M).flatten()[:-1] tf_transforms = tf.constant([transforms], tf.float32) img_tensor = tf.convert_to_tensor(img.astype(np.float32)) batch = tf.stack([img_tensor]) output = tfa.image.transform(batch, tf_transforms, interpolation='BILINEAR', output_shape=image_size) return output def process_image(self, img): if (isinstance(img, tf.Tensor) and img.dtype != tf.dtypes.uint8) or img.dtype != np.uint8: img = np.uint8(img * 255) face_box, face_landmarks = self.__detect_face(img) aligned_face = self.__preprocess(img, face_box, face_landmarks) aligned_face -= 127.5 aligned_face *= 0.0078125 embeddings = self.model_resnet(aligned_face) normelized_embeddings = tf.math.l2_normalize(embeddings) return normelized_embeddings def __call__(self, img): if img.ndim == 4: embedding_list = [] for x in img: 
norm_embedding = self.process_image(x) embedding_list.append(norm_embedding) return np.array(embedding_list) else: return self.process_image(img) ================================================ FILE: model/arcface/resnet.py ================================================ import tensorflow as tf import os from model.arcface.arcface import Arcfacelayer bn_axis = -1 initializer = 'glorot_normal' def residual_unit_v3(input, num_filter, stride, dim_match, name): x = tf.keras.layers.BatchNormalization(axis=bn_axis, scale=True, momentum=0.9, epsilon=2e-5, # beta_regularizer=tf.keras.regularizers.l2( # l=5e-4), gamma_regularizer=tf.keras.regularizers.l2( l=5e-4), name=name + '_bn1')(input) x = tf.keras.layers.ZeroPadding2D( padding=(1, 1), name=name + '_conv1_pad')(x) x = tf.keras.layers.Conv2D(num_filter, (3, 3), strides=(1, 1), padding='valid', kernel_initializer=initializer, use_bias=False, kernel_regularizer=tf.keras.regularizers.l2( l=5e-4), name=name + '_conv1')(x) x = tf.keras.layers.BatchNormalization(axis=bn_axis, scale=True, momentum=0.9, epsilon=2e-5, # beta_regularizer=tf.keras.regularizers.l2( # l=5e-4), gamma_regularizer=tf.keras.regularizers.l2( l=5e-4), name=name + '_bn2')(x) x = tf.keras.layers.PReLU(name=name + '_relu1', alpha_regularizer=tf.keras.regularizers.l2( l=5e-4))(x) x = tf.keras.layers.ZeroPadding2D( padding=(1, 1), name=name + '_conv2_pad')(x) x = tf.keras.layers.Conv2D(num_filter, (3, 3), strides=stride, padding='valid', kernel_initializer=initializer, use_bias=False, kernel_regularizer=tf.keras.regularizers.l2( l=5e-4), name=name + '_conv2')(x) x = tf.keras.layers.BatchNormalization(axis=bn_axis, scale=True, momentum=0.9, epsilon=2e-5, # beta_regularizer=tf.keras.regularizers.l2( # l=5e-4), gamma_regularizer=tf.keras.regularizers.l2( l=5e-4), name=name + '_bn3')(x) if (dim_match): shortcut = input else: shortcut = tf.keras.layers.Conv2D(num_filter, (1, 1), strides=stride, padding='valid', kernel_initializer=initializer, use_bias=False, kernel_regularizer=tf.keras.regularizers.l2( l=5e-4), name=name + '_conv1sc')(input) shortcut = tf.keras.layers.BatchNormalization(axis=bn_axis, scale=True, momentum=0.9, epsilon=2e-5, # beta_regularizer=tf.keras.regularizers.l2( # l=5e-4), gamma_regularizer=tf.keras.regularizers.l2( l=5e-4), name=name + '_sc')(shortcut) return x + shortcut def get_fc1(input): x = tf.keras.layers.BatchNormalization(axis=bn_axis, scale=True, momentum=0.9, epsilon=2e-5, # beta_regularizer=tf.keras.regularizers.l2( # l=5e-4), gamma_regularizer=tf.keras.regularizers.l2( l=5e-4), name='bn1')(input) x = tf.keras.layers.Dropout(0.4)(x) resnet_shape = input.shape x = tf.keras.layers.Reshape( [resnet_shape[1] * resnet_shape[2] * resnet_shape[3]], name='reshapelayer')(x) x = tf.keras.layers.Dense(512, name='E_DenseLayer', kernel_initializer=initializer, kernel_regularizer=tf.keras.regularizers.l2( l=5e-4), bias_regularizer=tf.keras.regularizers.l2( l=5e-4))(x) x = tf.keras.layers.BatchNormalization(axis=-1, scale=False, momentum=0.9, epsilon=2e-5, # beta_regularizer=tf.keras.regularizers.l2( # l=5e-4), name='fc1')(x) return x def ResNet50(): input_shape = [112, 112, 3] filter_list = [64, 64, 128, 256, 512] units = [3, 4, 14, 3] num_stages = 4 img_input = tf.keras.layers.Input(shape=input_shape) x = tf.keras.layers.ZeroPadding2D( padding=(1, 1), name='conv0_pad')(img_input) x = tf.keras.layers.Conv2D(64, (3, 3), strides=(1, 1), padding='valid', kernel_initializer=initializer, use_bias=False, kernel_regularizer=tf.keras.regularizers.l2( l=5e-4), name='conv0')(x) 
x = tf.keras.layers.BatchNormalization(axis=bn_axis, scale=True, momentum=0.9, epsilon=2e-5, # beta_regularizer=tf.keras.regularizers.l2( # l=5e-4), gamma_regularizer=tf.keras.regularizers.l2( l=5e-4), name='bn0')(x) # x = tf.keras.layers.Activation('prelu')(x) x = tf.keras.layers.PReLU( name='prelu0', alpha_regularizer=tf.keras.regularizers.l2( l=5e-4))(x) for i in range(num_stages): x = residual_unit_v3(x, filter_list[i + 1], (2, 2), False, name='stage%d_unit%d' % (i + 1, 1)) for j in range(units[i] - 1): x = residual_unit_v3(x, filter_list[i + 1], (1, 1), True, name='stage%d_unit%d' % (i + 1, j + 2)) x = get_fc1(x) # Create model. model = tf.keras.models.Model(img_input, x, name='resnet50') model.trainable = True for i in range(len(model.layers)): model.layers[i].trainable = True # if ('conv0' in model.layers[i].name): # model.layers[i].trainable = False # if ('bn0' in model.layers[i].name): # model.layers[i].trainable = False # if ('prelu0' in model.layers[i].name): # model.layers[i].trainable = False # if ('stage1' in model.layers[i].name): # model.layers[i].trainable = False # if ('stage2' in model.layers[i].name): # model.layers[i].trainable = False # if ('stage3' in model.layers[i].name): # model.layers[i].trainable = False # if ('stage4' in model.layers[i].name): # model.layers[i].trainable = False return model class train_model(tf.keras.Model): def __init__(self): super(train_model, self).__init__() self.resnet = ResNet50() self.arcface = Arcfacelayer() def call(self, x, y): x = self.resnet(x) return self.arcface(x, y) ================================================ FILE: model/attr_encoder.py ================================================ import logging import tensorflow as tf from tensorflow.keras import Model from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input class AttrEncoder(Model): def __init__(self, args): super().__init__() self.args = args self.logger = logging.getLogger(__class__.__name__) attr_encoder = InceptionV3(include_top=False, pooling='avg') self.model = attr_encoder if self.args.load_checkpoint: self.model.load_weights(str(self.args.load_checkpoint.joinpath(self.__class__.__name__ + '.h5'))) @tf.function def call(self, input_x): x = tf.image.resize(input_x, (299, 299)) x = preprocess_input(255 * x) x = self.model(x) x = tf.expand_dims(x, 1) return x def my_save(self, reason=''): self.model.save_weights(str(self.args.weights_dir.joinpath(self.__class__.__name__ + reason + '.h5'))) ================================================ FILE: model/discriminator.py ================================================ import tensorflow as tf from tensorflow.keras import layers, Model from utils.general_utils import get_weights # Discriminate between my w's and StyleGAN's w's class W_D(Model): def __init__(self, args): super().__init__() self.args = args slope = 0.2 # self.linear1 = layers.Dense(512, kernel_initializer=get_weights(slope), input_shape=(512,)) self.linear2 = layers.Dense(256, kernel_initializer=get_weights(slope), input_shape=(512,)) self.linear3 = layers.Dense(128, kernel_initializer=get_weights(slope)) self.linear4 = layers.Dense(64, kernel_initializer=get_weights(slope)) self.linear5 = layers.Dense(1, kernel_initializer=get_weights(slope)) self.relu = layers.LeakyReLU(slope) if self.args.load_checkpoint: self.build(input_shape=(1, 1, 512)) self.load_weights(str(self.args.load_checkpoint.joinpath(self.__class__.__name__ + '.h5'))) @tf.function def call(self, x): # x = self.linear1(x) # x = self.relu(x) x = self.linear2(x) 
x = self.relu(x) x = self.linear3(x) x = self.relu(x) x = self.linear4(x) x = self.relu(x) x = self.linear5(x) return x def my_save(self, reason=''): self.save_weights(str(self.args.weights_dir.joinpath(self.__class__.__name__ + reason + '.h5'))) ================================================ FILE: model/face_detector.py ================================================ import tensorflow as tf class FaceDetector(object): def __init__(self, args, model_path): super().__init__() self.args = args self.model_path = model_path self.model = None def _build(self): if not self.model: self.model = tf.saved_model.load(self.model_path) def __call__(self, input_x): """ Given a batch of images, return the face bounding box in (x1,y1,x2,y2) format """ if not self.model: self._build() boxes = [] for sample in input_x: boxes.append(self.sample_call(sample)) boxes = tf.stack(boxes, axis=0) boxes = boxes * self.args.resolution return boxes def sample_call(self, input_x): boxes = self.model.inference(tf.expand_dims(input_x, axis=0)) boxes = tf.squeeze(boxes) indices, scores = \ tf.image.non_max_suppression_with_scores(boxes[..., :4], boxes[..., 4], max_output_size=1, iou_threshold=0.3, score_threshold=0.5) i = indices.numpy()[0] box = boxes[i, :4] return box ================================================ FILE: model/generator.py ================================================ import logging from model import id_encoder from model import attr_encoder from model import latent_mapping from model import landmarks from model.arcface.inference import MyArcFace import tensorflow as tf from tensorflow.keras import layers, Model class G(Model): def __init__(self, args, id_model_path, image_G, landmarks_net_path, face_detection_model_path, test_id_model_path): super().__init__() self.args = args self.logger = logging.getLogger(__class__.__name__) self.id_encoder = id_encoder.IDEncoder(args, id_model_path) self.id_encoder.trainable = False self.attr_encoder = attr_encoder.AttrEncoder(args) self.latent_spaces_mapping = latent_mapping.LatentMappingNetwork(args) self.stylegan_s = image_G self.stylegan_s.trainable = False if args.train: self.test_id_encoder = MyArcFace(test_id_model_path) self.test_id_encoder.trainable = False self.landmarks = landmarks.LandmarksDetector(args, landmarks_net_path, face_detection_model_path) self.landmarks.trainable = False @tf.function def call(self, x1, x2): id_embedding = self.id_encoder(x1) if self.args.train: lnds = self.landmarks(x2) else: lnds = None attr_input = x2 attr_out = self.attr_encoder(attr_input) attr_embedding = attr_out z_tag = tf.concat([id_embedding, attr_embedding], -1) w = self.latent_spaces_mapping(z_tag) out = self.stylegan_s(w) # Move to roughly [0,1] out = (out + 1) / 2 return out, id_embedding, attr_out, w[:, 0, :], lnds def my_save(self, reason=''): self.attr_encoder.my_save(reason) self.latent_spaces_mapping.my_save(reason) ================================================ FILE: model/id_encoder.py ================================================ import tensorflow as tf import numpy as np from tensorflow.keras import Model class IDEncoder(Model): def __init__(self, args, model_path, intermediate_layers_names=None): super().__init__() self.args = args self.mean = (91.4953, 103.8827, 131.0912) base_model = tf.keras.models.load_model(model_path) if intermediate_layers_names: outputs = [base_model.get_layer(name).output for name in intermediate_layers_names] else: outputs = [] # Add output of the network in any case outputs.append(base_model.layers[-2].output) 
self.model = tf.keras.Model(base_model.inputs, outputs) def crop_faces(self, img): ps = [] for i in range(img.shape[0]): oneimg = img[i] try: box = tf.numpy_function(self.mtcnn.detect_faces, [oneimg], np.uint8) box = [z.numpy() for z in box[:4]] x1, y1, w, h = box x_expand = w * 0.3 y_expand = h * 0.3 x1 = int(np.maximum(x1 - x_expand // 2, 0)) y1 = int(np.maximum(y1 - y_expand // 2, 0)) x2 = int(np.minimum(x1 + w + x_expand // 2, self.args.resolution)) y2 = int(np.minimum(y1 + h + y_expand // 2, self.args.resolution)) except Exception as e: x1, y1, x2, y2 = 24, 50, 224, 250 p = oneimg[y1:y2, x1:x2, :] p = tf.convert_to_tensor(p) p = tf.image.resize(p, (self.args.resolution, self.args.resolution)) ps.append(p) ps = tf.stack(ps, 0) return ps def preprocess(self, img): """ In VGGFace2 The preprocessing is: 1. Face detection 2. Expand bbox by factor of 0.3 3. Resize so shorter side is 256 4. Crop center 224x224 In StyleGAN faces are not in-the-wild, we get an image of the head. Just cropping a loose center instead of face detection """ # Go from [0, 1] to [0, 255] img = 255 * img min_x = int(0.1 * self.args.resolution) max_x = int(0.9 * self.args.resolution) min_y = int(0.1 * self.args.resolution) max_y = int(0.9 * self.args.resolution) img = img[:, min_x:max_x, min_y:max_y, :] img = tf.image.resize(img, (256, 256)) start = (256 - 224) // 2 img = img[:, start: 224 + start, start: 224 + start, :] img = img[:, :, :, ::-1] - self.mean return img @tf.function def call(self, input_x, get_intermediate=False): x = self.preprocess(input_x) x = self.model(x) if isinstance(x, list): embedding = x[-1] intermediates = x[:-1] else: embedding = x intermediates = None embedding = tf.math.l2_normalize(embedding, axis=-1) embedding = tf.expand_dims(embedding, 1) if get_intermediate and intermediates: return embedding, intermediates else: return embedding ================================================ FILE: model/landmarks.py ================================================ import cv2 import tensorflow as tf import numpy as np from tensorflow.keras import Model from utils import general_utils as utils from model.face_detector import FaceDetector class LandmarksDetector(Model): def __init__(self, args, model_path, face_detection_model_path): super().__init__() self.args = args self.face_detector = FaceDetector(args, face_detection_model_path) self.expand_ratio = 0.2 # Load without source code self.model = tf.saved_model.load(model_path) # Preprocess def preprocess(self, imgs, face_detection=False): imgs *= 255 if face_detection: imgs, details = self.hard_preprocess(imgs) else: imgs, details = self.lazy_preprocess(imgs) return imgs, details def lazy_preprocess(self, imgs): imgs = tf.image.resize(imgs, (160, 160)) return imgs, 160 def hard_preprocess(self, imgs): bboxes = self.face_detector(imgs) centers = np.array([bboxes[:, 0] + bboxes[:, 2], bboxes[:, 1] + bboxes[:, 3]]).T // 2 # Duplicate center point into column order of x,x,y,y centers = np.repeat(centers, repeats=2, axis=1) # Permute columns order into x,y,x,y centers[:] = utils.np_permute(centers, [0, 2, 1, 3]) # Calculate widths of current bboxes widths = np.transpose([bboxes[:, 2] - bboxes[:, 0]]) # Calculate the maximal expansion max_expand = int(np.ceil(np.max(widths) * self.expand_ratio)) # Pad the image with the maximal expansion. 
# Useful in case an expanded bounding box goes outside image paddings = tf.constant([[0, 0], [max_expand, max_expand], [max_expand, max_expand], [0, 0]]) pad_imgs = tf.pad(imgs, paddings, mode='CONSTANT', constant_values=127.) # The size of the new square bounding box new_scales = np.floor((1 + 2 * self.expand_ratio) * widths) # Size of step from the center new_half_scales = new_scales // 2 # Repeat step in all directions # Decrease in start point, Increase in end point new_half_scales = np.repeat(new_half_scales, repeats=4, axis=1) * [-1, -1, 1, 1] # Bounding boxes in respect to padded image new_bboxes = centers + new_half_scales + max_expand # tf.image.crop_and_resize requires bounding boxes to be normalized # i.e., between [0,1] and also in order (y,x) normed_bboxes = utils.np_permute(new_bboxes, [1, 0, 3, 2]) / pad_imgs.shape[1] cropped_imgs = tf.image.crop_and_resize(pad_imgs, normed_bboxes, box_indices=range(self.args.batch_size), crop_size=(160, 160)) details = (new_scales, new_bboxes[:,:2], max_expand) return cropped_imgs, details # Postprocess def postprocess(self, landmarks, details, face_detection=False): landmarks = tf.reshape(landmarks, [-1, 68, 2]) if face_detection: return self.hard_postprocess(landmarks, details) else: return self.lazy_postprocess(landmarks, details) def lazy_postprocess(self, batch_lnds, details): scale = details return scale * batch_lnds def hard_postprocess(self, batch_lnds, details): scale, from_origin, pad = details scale = tf.broadcast_to(scale, [scale.shape[0], 2]) scale = tf.expand_dims(scale, axis=1) from_origin = tf.expand_dims(from_origin, axis=1) from_origin = tf.cast(from_origin, tf.dtypes.float32) lnds = batch_lnds * scale + from_origin - pad return lnds @tf.function def call(self, input_x, face_detection=False): # The network input format is a uint8 image (0-255) but in float32 dtype. 
^__('')__^ x, details = self.preprocess(input_x, face_detection) batch_lnds = self.model.inference(x)['landmark'] batch_lnds = self.postprocess(batch_lnds, details, face_detection) return batch_lnds[:, 17:, :] ================================================ FILE: model/latent_mapping.py ================================================ from utils.general_utils import get_weights import tensorflow as tf import numpy as np from tensorflow.keras import layers, Model class LatentMappingNetwork(Model): def __init__(self, args): super().__init__() self.args = args input_shape = (2560,) self.linear1 = layers.Dense(2048, input_shape=input_shape) self.linear2 = layers.Dense(1024) self.linear3 = layers.Dense(512, kernel_initializer=get_weights()) self.linear4 = layers.Dense(512, kernel_initializer=get_weights()) self.linears = [self.linear1, self.linear2, self.linear3, self.linear4] self.relu = layers.LeakyReLU(0.2) self.num_styles = int(np.log2(self.args.resolution)) * 2 - 2 if self.args.load_checkpoint: self.build(input_shape=(1, 1, 2560)) self.load_weights(str(self.args.load_checkpoint.joinpath(self.__class__.__name__ + '.h5'))) @tf.function def call(self, x): first = True for layer in self.linears: if not first: x = self.relu(x) x = layer(x) first = False s = list(x.shape) # Duplicate the column vector w along columns for each AdaIN entry s[1] = self.num_styles x = tf.broadcast_to(x, s) return x def my_save(self, reason=''): self.save_weights(str(self.args.weights_dir.joinpath(self.__class__.__name__ + reason + '.h5'))) ================================================ FILE: model/model.py ================================================ import time import sys sys.path.append('..') from utils import general_utils as utils from model import id_encoder, latent_mapping, attr_encoder,\ generator, discriminator, landmarks from model.stylegan import StyleGAN_G, StyleGAN_D import tensorflow as tf from tensorflow.keras import layers, Model class Network(Model): def __init__(self, args, id_net_path, base_generator, landmarks_net_path=None, face_detection_model_path=None, test_id_net_path=None): super().__init__() self.args = args self.G = generator.G(args, id_net_path, base_generator, landmarks_net_path, face_detection_model_path, test_id_net_path) if self.args.train: self.W_D = discriminator.W_D(args) def call(self): raise NotImplementedError() def my_save(self, reason): self.G.my_save(reason) if self.args.W_D_loss: self.W_D.my_save(reason) def my_load(self): raise NotImplementedError() def train(self): self._set_trainable_behavior(True) def test(self): self._set_trainable_behavior(False) def _set_trainable_behavior(self, trainable): self.G.attr_encoder.trainable = trainable self.G.latent_spaces_mapping.trainable = trainable ================================================ FILE: model/stylegan.py ================================================ import sys import math import numpy as np import tensorflow as tf import matplotlib.pyplot as plt from tensorflow.keras import Model, Sequential from tensorflow.keras.layers import Layer, InputLayer, Multiply, Lambda, Flatten, Dense, Conv2D, Conv2DTranspose from tensorflow.keras.initializers import VarianceScaling from tensorflow.python.keras import backend from tensorflow.python.ops import array_ops def nf(stage, fmap_base=8192, fmap_decay=1.0, fmap_max=512): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max) def LeakyReLU(alpha, name): def lrelu(x, alpha): alpha = tf.constant(alpha, dtype=x.dtype, name='alpha') return tf.maximum(x, x * alpha) return 
Lambda(lambda x: lrelu(x, alpha), name=name) def GetWeights(gain=math.sqrt(2)): return VarianceScaling(gain) def runtime_coef(kernel_size, gain, fmaps_in, fmaps_out, lrmul=1.0): # Equalized learning rate and custom learning rate multiplier. shape = [kernel_size[0], kernel_size[1], fmaps_in, fmaps_out] fan_in = np.prod(shape[:-1]) # [kernel, kernel, fmaps_in, fmaps_out] or [in, out] he_std = gain / np.sqrt(fan_in) # He init init_std = 1.0 / lrmul return he_std * lrmul def pixel_norm(x, epsilon=1e-8): epsilon = tf.constant(epsilon, dtype=x.dtype, name='epsilon') return x * tf.math.rsqrt(tf.reduce_mean(tf.square(x), axis=1, keepdims=True) + epsilon) class PixelNorm(Layer): def __init__(self, name): super(PixelNorm, self).__init__(name=name) def call(self, inputs): return pixel_norm(inputs) class InstanceNorm(Layer): def __init__(self, name): super(InstanceNorm, self).__init__(name=name) def call(self, x): epsilon=1e-8 orig_dtype = x.dtype x = tf.cast(x, tf.float32) x -= tf.reduce_mean(x, axis=[2,3], keepdims=True) epsilon = tf.constant(epsilon, dtype=x.dtype, name='epsilon') x *= tf.math.rsqrt(tf.reduce_mean(tf.square(x), axis=[2,3], keepdims=True) + epsilon) x = tf.cast(x, orig_dtype) return x def Identity(name): return Lambda(lambda x: x, name=name) def Broadcast(name, dlatent_broadcast=18): def broadcast(x): return tf.tile(x[:, np.newaxis], [1, dlatent_broadcast, 1]) return Lambda(lambda x: broadcast(x), name=name) class Truncation(Layer): def __init__(self, name, num_layers=18, truncation_psi=0.7, truncation_cutoff=8): super(Truncation, self).__init__(name=name) self.num_layers = num_layers self.truncation_psi = truncation_psi self.truncation_cutoff = truncation_cutoff def build(self, input_shape): self.dlatent_avg = self.add_variable('dlatent_avg', shape=[int(input_shape[-1])]) def call(self, inputs): layer_idx = np.arange(self.num_layers)[np.newaxis, :, np.newaxis] ones = np.ones(layer_idx.shape, dtype=np.float32) coefs = tf.where(layer_idx < self.truncation_cutoff, self.truncation_psi * ones, ones) def lerp(a,b,t): return a + (b - a) * t return lerp(self.dlatent_avg, inputs, coefs) class DenseLayer(Dense): def __init__(self, units, name, kernel_initializer=GetWeights(), gain=math.sqrt(2), lrmul=1.0): super(DenseLayer, self).__init__(units=units, kernel_initializer=kernel_initializer, name=name) self.gain = gain self.lrmul = lrmul def call(self, inputs): x, b, w = inputs, self.bias * self.lrmul, self.kernel * runtime_coef([1,1], self.gain, inputs.shape[1], self.units, lrmul=self.lrmul) # Input x kernel if len(x.shape) > 2: x = tf.reshape(x, [-1, np.prod([d for d in x.shape[1:]])]) x = tf.matmul(x, w) # Bias if len(x.shape) == 2: return x + b return x + tf.reshape(b, [1, -1, 1, 1]) class Conv2d(Conv2D): def __init__(self, filters, kernel_size, name, gain=math.sqrt(2), lrmul=1.0, kernel_modifier=None, strides=1, use_bias=True): super(Conv2d, self).__init__(filters=filters, kernel_size=kernel_size, kernel_initializer=GetWeights(gain), use_bias=use_bias, padding='same', data_format='channels_first', name=name, strides=strides) self.gain = gain self.lrmul = lrmul self.kernel_modifier = kernel_modifier # Perform convolution with modified kernel then add bias def call(self, inputs): if self.kernel_modifier is None: w = self.kernel else: w = self.kernel_modifier(self.kernel) outputs = self._convolution_op(inputs, w * runtime_coef(self.kernel_size, self.gain, inputs.shape[1], self.filters)) if self.use_bias: b = self.bias * self.lrmul if self.data_format == 'channels_first': outputs = 
tf.nn.bias_add(outputs, b, data_format='NCHW') else: outputs = tf.nn.bias_add(outputs, b, data_format='NHWC') return outputs class Const(Layer): def __init__(self, name): super(Const, self).__init__(name=name) def build(self, input_shape): self.const = self.add_variable('const', shape=[1,512,4,4]) def call(self, inputs): return tf.tile(self.const, [tf.shape(inputs)[0], 1, 1, 1]) class RandomNoise(Layer): def __init__(self, name, layer_idx): super(RandomNoise, self).__init__(name=name) res = layer_idx // 2 + 2 self.layer_idx = layer_idx self.noise_shape = [1, 1, 2**res, 2**res] def build(self, input_shape): self.noise = self.add_variable('noise', shape=self.noise_shape, initializer=tf.initializers.zeros(), trainable=False) def call(self, inputs): return self.noise class ApplyNoise(Layer): def __init__(self, name, is_const_noise): super(ApplyNoise, self).__init__(name=name) self.is_const_noise = is_const_noise def build(self, input_shape): input_shape = input_shape[0] self.weight = self.add_variable('weight', shape=[input_shape[1]], initializer=tf.initializers.zeros()) def call(self, inputs): x, noise = inputs if not self.is_const_noise: noise = tf.random.normal([tf.shape(x)[0], 1, x.shape[2], x.shape[3]], dtype=x.dtype) return x + noise * tf.reshape(self.weight, [1, -1, 1, 1]) class ApplyBias(Layer): def __init__(self, name, lrmul=1.0): super(ApplyBias, self).__init__(name=name) self.lrmul = lrmul def build(self, input_shape): self.bias = self.add_variable('bias', shape=[input_shape[1]]) def call(self, x): b = self.bias * self.lrmul if len(x.shape) == 2: return x + b return x + tf.reshape(b, [1, -1, 1, 1]) class StridedSlice(Layer): def __init__(self, layer_idx, name): super(StridedSlice, self).__init__(name=name) self.layer_idx = layer_idx def call(self, inputs): return inputs[:, self.layer_idx] class StyleModApply(Layer): def __init__(self, name): super(StyleModApply, self).__init__(name=name) def call(self, inputs): x, style = inputs style = tf.reshape(style, [-1, 2, x.shape[1]] + [1] * (len(x.shape) - 2)) return x * (style[:,0] + 1) + style[:,1] def _blur2d(x, f=[1,2,1], normalize=True, flip=False, stride=1): assert x.shape.ndims == 4 and all(dim is not None for dim in x.shape[1:]) assert isinstance(stride, int) and stride >= 1 # Finalize filter kernel. f = np.array(f, dtype=np.float32) if f.ndim == 1: f = f[:, np.newaxis] * f[np.newaxis, :] assert f.ndim == 2 if normalize: f /= np.sum(f) if flip: f = f[::-1, ::-1] f = f[:, :, np.newaxis, np.newaxis] f = np.tile(f, [1, 1, int(x.shape[1]), 1]) # No-op => early exit. if f.shape == (1, 1) and f[0,0] == 1: return x # Convolve using depthwise_conv2d. orig_dtype = x.dtype x = tf.cast(x, tf.float32) # tf.nn.depthwise_conv2d() doesn't support fp16 f = tf.constant(f, dtype=x.dtype, name='filter') strides = [1, 1, stride, stride] x = tf.nn.depthwise_conv2d(x, f, strides=strides, padding='SAME', data_format='NCHW') x = tf.cast(x, orig_dtype) return x def Blur(name, blur_filter=[1,2,1]): def blur2d(x, f=[1,2,1], normalize=True): return _blur2d(x, f, normalize) return Lambda(lambda x: blur2d(x, blur_filter), name=name) def _downscale2d(x, factor=2, gain=1): assert x.shape.ndims == 4 and all(dim is not None for dim in x.shape[1:]) assert isinstance(factor, int) and factor >= 1 # 2x2, float32 => downscale using _blur2d(). if factor == 2 and x.dtype == tf.float32: f = [np.sqrt(gain) / factor] * factor return _blur2d(x, f=f, normalize=False, stride=factor) # Apply gain. if gain != 1: x *= gain # No-op => early exit. 
if factor == 1: return x # Large factor => downscale using tf.nn.avg_pool(). # NOTE: Requires tf_config['graph_options.place_pruned_graph']=True to work. ksize = [1, 1, factor, factor] return tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding='VALID', data_format='NCHW') def _upscale2d(x, factor=2, gain=1): assert x.shape.ndims == 4 and all(dim is not None for dim in x.shape[1:]) assert isinstance(factor, int) and factor >= 1 # Apply gain. if gain != 1: x *= gain # No-op => early exit. if factor == 1: return x # Upscale using tf.tile(). s = x.shape x = tf.reshape(x, [-1, s[1], s[2], 1, s[3], 1]) x = tf.tile(x, [1, 1, 1, factor, 1, factor]) x = tf.reshape(x, [-1, s[1], s[2] * factor, s[3] * factor]) return x def Downscaled2d(name, factor=2, gain=1): return Lambda(lambda x: _downscale2d(x, factor, gain), name=name+'/Downscaled2d') def Upscaled2d(name, factor=2, gain=1): return Lambda(lambda x: _upscale2d(x, factor, gain), name=name+'/Upscaled2d') def Conv2d_downscale2d(model, filters, kernel_size, name, gain=math.sqrt(2), fused_scale='auto'): if fused_scale == 'auto': x = model.layers[-1].output fused_scale = min(x.shape[2:]) >= 128 if not fused_scale: # Not fused => call the individual ops directly. model.add( Conv2d(filters, kernel_size, name, gain) ) model.add( Downscaled2d(name) ) else: # Fused => perform both ops simultaneously using tf.nn.conv2d(). def fused_op(w): w = tf.pad(w, [[1,1], [1,1], [0,0], [0,0]], mode='CONSTANT') w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]]) * 0.25 return w model.add( Conv2d(filters, kernel_size, name, gain, kernel_modifier=fused_op, strides=2) ) def Upscale2d_conv2d(x, filters, kernel_size, name, use_bias, gain=math.sqrt(2), fused_scale='auto'): if fused_scale == 'auto': fused_scale = min(x.shape[2:]) * 2 >= 128 if not fused_scale: x = Upscaled2d(name)(x) x = Conv2d(filters, kernel_size, name=name, gain=gain, use_bias=use_bias)(x) return x return Conv2d_transpose(filters, kernel_size, name, gain, strides=2)(x) class Conv2d_transpose(Conv2DTranspose): def __init__(self, filters, kernel_size, name, gain=math.sqrt(2), lrmul=1.0, kernel_modifier=None, strides=2, use_bias=False): super(Conv2d_transpose, self).__init__(filters=filters, kernel_size=kernel_size, kernel_initializer=GetWeights(gain), use_bias=use_bias, padding='same', data_format='channels_first', name=name, strides=strides) self.gain = gain self.lrmul = lrmul self.kernel_modifier = kernel_modifier def build(self, input_shape): shape = [self.kernel_size[0], self.kernel_size[1], input_shape[1], self.filters] self.kernel = self.add_variable('weight', shape=shape, initializer=tf.initializers.zeros()) def call(self, inputs): # Fused => perform both ops simultaneously using tf.nn.conv2d_transpose(). 
def fused_op(w): w = tf.transpose(w, [0, 1, 3, 2]) # [kernel, kernel, fmaps_out, fmaps_in] w = tf.pad(w, [[1,1], [1,1], [0,0], [0,0]], mode='CONSTANT') w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]]) return w x, w = inputs, fused_op(self.kernel * runtime_coef(self.kernel_size, self.gain, inputs.shape[1], self.filters, lrmul=self.lrmul)) os = [tf.shape(inputs)[0], self.filters, inputs.shape[2] * 2, inputs.shape[3] * 2] outputs = tf.nn.conv2d_transpose(x, w, os, strides=[1,1,2,2], padding='SAME', data_format='NCHW') return outputs class MinibatchStddevLayer(tf.keras.layers.Layer): def __init__(self, group_size =4, num_new_features=1): super().__init__() self.group_size = group_size self.num_new_features = num_new_features def __call__(self, x, *args, **kwargs): group_size = tf.minimum(self.group_size, tf.shape(x)[0]) # Minibatch must be divisible by (or smaller than) group_size. s = x.shape # [NCHW] Input shape. y = tf.reshape(x, [group_size, -1, self.num_new_features, s[1] // self.num_new_features, s[2], s[ 3]]) # [GMncHW] Split minibatch into M groups of size G. Split channels into n channel groups c. y = tf.cast(y, tf.float32) # [GMncHW] Cast to FP32. y -= tf.reduce_mean(y, axis=0, keepdims=True) # [GMncHW] Subtract mean over group. y = tf.reduce_mean(tf.square(y), axis=0) # [MncHW] Calc variance over group. y = tf.sqrt(y + 1e-8) # [MncHW] Calc stddev over group. y = tf.reduce_mean(y, axis=[2, 3, 4], keepdims=True) # [Mn111] Take average over fmaps and pixels. y = tf.reduce_mean(y, axis=[2]) # [Mn11] Split channels into c channel groups y = tf.cast(y, x.dtype) # [Mn11] Cast back to original data type. y = tf.tile(y, [group_size, 1, s[2], s[3]]) # [NnHW] Replicate over group and pixels. return tf.concat([x, y], axis=1) # [NCHW] Append as new fmap. def minibatch_stddev_layer(x, group_size=4, num_new_features=1): with tf.compat.v1.variable_scope('MinibatchStddev'): group_size = tf.minimum(group_size, tf.shape(x)[0]) # Minibatch must be divisible by (or smaller than) group_size. s = x.shape # [NCHW] Input shape. y = tf.reshape(x, [group_size, -1, num_new_features, s[1]//num_new_features, s[2], s[3]]) # [GMncHW] Split minibatch into M groups of size G. Split channels into n channel groups c. y = tf.cast(y, tf.float32) # [GMncHW] Cast to FP32. y -= tf.reduce_mean(y, axis=0, keepdims=True) # [GMncHW] Subtract mean over group. y = tf.reduce_mean(tf.square(y), axis=0) # [MncHW] Calc variance over group. y = tf.sqrt(y + 1e-8) # [MncHW] Calc stddev over group. y = tf.reduce_mean(y, axis=[2,3,4], keepdims=True) # [Mn111] Take average over fmaps and pixels. y = tf.reduce_mean(y, axis=[2]) # [Mn11] Split channels into c channel groups y = tf.cast(y, x.dtype) # [Mn11] Cast back to original data type. y = tf.tile(y, [group_size, 1, s[2], s[3]]) # [NnHW] Replicate over group and pixels. return tf.concat([x, y], axis=1) # [NCHW] Append as new fmap. def StyleGAN_G_mapping( latent_size=512, dlatent_size=512, mapping_layers=8, mapping_fmaps=512, mapping_lrmul=0.01, truncation_psi=1, resolution=1024): resolution_log2 = int(np.log2(resolution)) num_layers = resolution_log2 * 2 - 2 model = Sequential(name='G_mapping') model.add( InputLayer(input_shape=[latent_size], name='G_mapping/latents_in') ) # Normalize latents. model.add( PixelNorm(name='G_mapping/PixelNorm') ) # Mapping layers. 
for layer_idx in range(mapping_layers): name = 'G_mapping/Dense{}'.format(layer_idx) fmaps = dlatent_size if layer_idx == mapping_layers - 1 else mapping_fmaps model.add( DenseLayer(units=fmaps, kernel_initializer=GetWeights(), name=name, lrmul=mapping_lrmul) ) model.add( LeakyReLU(alpha=0.2, name=name+'/LeakyReLU') ) # Broadcast. model.add( Broadcast(name='G_mapping/Broadcast', dlatent_broadcast=num_layers) ) # Output. model.add( Identity(name='G_mapping/dlatents_out') ) # Apply truncation trick. model.add( Truncation(name='Truncation', num_layers=num_layers, truncation_psi=truncation_psi )) return model def StyleGAN_G_synthesis(dlatent_size=512, resolution=1024, is_const_noise=True): # General parameters num_channels = 3 resolution_log2 = int(np.log2(resolution)) num_layers = resolution_log2 * 2 - 2 # Primary inputs. dlatents_in = tf.keras.layers.Input(shape=[num_layers, dlatent_size], name='G_synthesis/dlatents_in') # Noise inputs. noise_inputs = [] for layer_idx in range(num_layers): noise_inputs.append( RandomNoise(name='G_synthesis/noise%d'%layer_idx, layer_idx=layer_idx)(dlatents_in) ) # Things to do at the end of each layer. def layer_epilogue(x, layer_idx, name): name = 'G_synthesis/{}x{}/{}/'.format(x.shape[2], x.shape[2], name) x = ApplyNoise(name=name+'Noise', is_const_noise=is_const_noise)([x, noise_inputs[layer_idx]]) x = ApplyBias(name=name+'bias')(x) x = LeakyReLU(alpha=0.2, name=name+'LeakyReLU')(x) x = InstanceNorm(name=name+'InstanceNorm')(x) style = DenseLayer(units=x.shape[1]*2, gain=1, name=name+'StyleMod') (StridedSlice(layer_idx, name=name+'StridedSlice')(dlatents_in)) x = StyleModApply(name=name+'StyleModApply')([x, style]) return x # Building blocks for remaining layers. def block(res, x): # res = 3..resolution_log2 name, name0, name1 = '%dx%d' % (2**res, 2**res), 'Conv0_up', 'Conv1' # Conv0_up upscaled = Upscale2d_conv2d(x, name='G_synthesis/{}/{}'.format(name, name0), filters=nf(res-1), kernel_size=3, use_bias=False) x = layer_epilogue( Blur(name='G_synthesis/{}/{}/Blur'.format(name, name0))(upscaled), res*2-4, name0 ) # Conv1 x = layer_epilogue( Conv2d(name='G_synthesis/{}/{}'.format(name, name1), filters=nf(res-1), kernel_size=3, use_bias=False)(x), res*2-3, name1 ) return x def torgb(res, x): # res = 2..resolution_log2 lod = resolution_log2 - res return Conv2d(name='G_synthesis/ToRGB_lod%d' % lod, filters=num_channels, kernel_size=1, gain=1, use_bias=True)(x) # Early layers. x = layer_epilogue(Const(name='G_synthesis/4x4/Const')(dlatents_in), 0, name='Const') x = layer_epilogue(Conv2d(name='G_synthesis/4x4/Conv', filters=nf(1), kernel_size=3, use_bias=False)(x), 1, 'Conv') # Fixed structure: simple and efficient, but does not support progressive growing. 
for res in range(3, resolution_log2 + 1): x = block(res, x) x = torgb(resolution_log2, x) # change output to the default NHWC format, so it will be compatible with other networks x = tf.transpose(x, (0, 2, 3, 1)) return Model(inputs=dlatents_in, outputs=x, name='G_synthesis') class StyleGAN_G(Model): def __init__(self, resolution=1024, latent_size=512, dlatent_size=512, mapping_layers=8, mapping_fmaps=512, mapping_lrmul=0.01, truncation_psi=1): super(StyleGAN_G, self).__init__() self.model_mapping = StyleGAN_G_mapping(latent_size, dlatent_size, mapping_layers, mapping_fmaps, mapping_lrmul, truncation_psi, resolution) self.model_synthesis = StyleGAN_G_synthesis(dlatent_size, resolution) print('Model created.') def call(self, inputs): x = self.model_mapping(inputs) x = self.model_synthesis(x) return x def generate_sample(self, seed=5, is_visualize=False): rnd = np.random.RandomState(seed) latents = rnd.randn(1, 512) y = self.predict(latents) images = y # the synthesis network already outputs NHWC (see the transpose above), so no further transpose is needed images = np.clip((images+1)*0.5, 0, 1) # print(images.shape, np.min(images), np.max(images)) plt.figure(figsize=(10, 10)) plt.imshow(images[0]) if is_visualize: plt.show() return images class StyleGAN_D(Model): def __init__(self, resolution=1024, mbstd_group_size=4, mbstd_num_features=1): super(StyleGAN_D, self).__init__() resolution_log2 = int(math.log2(resolution)) model = Sequential(name='Discriminator') model.add(InputLayer(input_shape=[3, resolution, resolution])) def fromrgb(res): name = 'FromRGB_lod%d' % (resolution_log2 - res) model.add( Conv2d(filters=nf(res-1), kernel_size=1, name=name) ) model.add( LeakyReLU(alpha=0.2, name=name+'/LeakyReLU') ) def block(res): name = '%dx%d' % (2**res, 2**res) if res >= 3: # 8x8 and up model.add( Conv2d(filters=nf(res-1), kernel_size=3, name=name+'/Conv0') ) model.add( LeakyReLU(alpha=0.2, name=name+'/Conv0/LeakyReLU') ) model.add( Blur(name=name+'/Blur') ) Conv2d_downscale2d(model=model, filters=nf(res-2), kernel_size=3, name=name+'/Conv1_down') model.add( LeakyReLU(alpha=0.2, name=name+'/Conv1_down/LeakyReLU') ) else: # 4x4 if mbstd_group_size > 1: model.add( Lambda(lambda x: minibatch_stddev_layer(x, mbstd_group_size, mbstd_num_features), name=name+'/MinibatchStddev') ) model.add( Conv2d(filters=nf(res-1), kernel_size=3, name=name+'/Conv') ) model.add( LeakyReLU(alpha=0.2, name=name+'/Conv/LeakyReLU') ) model.add( Flatten() ) model.add( DenseLayer(units=nf(res-2), kernel_initializer=GetWeights(), name=name+'/Dense0') ) model.add( LeakyReLU(alpha=0.2, name=name+'/Dense0/LeakyReLU') ) model.add( DenseLayer(units=1, kernel_initializer=GetWeights(1), gain=1, name=name+'/Dense1') ) # Blocks fromrgb(resolution_log2) for res in range(resolution_log2, 2, -1): block(res) block(2) self.model = model def call(self, inputs): inputs = tf.transpose(inputs, (0, 3, 1, 2)) return self.model(inputs) def copy_weights_to_keras_model(model, all_weights): c = 0 od = all_weights for l in model.layers: try: values = l.get_weights() weights = list(map(lambda x: x.shape, values)) if not len(weights): continue num_params = values[0].size # Special weights if len(weights) == 1: weights_list = [] # The learned constant variable if l.name == 'G_synthesis/4x4/Const': weights_list.append( od[l.name+'/const'] ) # Truncation trick variable if 'Truncation' in l.name: weights_list.append( od['dlatent_avg'] ) # Input noise if 'G_synthesis/noise' in l.name: weights_list.append( od[l.name] ) # Noise variables if 'Noise' in l.name: weights_list.append( od[l.name+'/weight'] ) # Bias variables if 'bias' in l.name: 
weights_list.append( od[l.name] ) # Conv with no bias if l.name.endswith('Conv') or l.name.endswith('Conv1') or l.name.endswith('Conv0_up'): weights_list.append( od[l.name+'/weight'] ) if len(weights_list) > 0: l.set_weights( weights_list ) print('.', end='') c = c + num_params else: print('WARNING: weights not found for ', l.name, ' of size', weights[0]) else: # Standard weights (weight + bias) assert len(weights) == 2 num_params = num_params + values[1].size layer_name = l.name var_names = ['{}/{}'.format(layer_name, 'weight'), '{}/{}'.format(layer_name, 'bias')] if var_names[0] in od and var_names[1] in od: weight = od[var_names[0]] bias = od[var_names[1]] l.set_weights( [ weight, bias ] ) print('.', end='') c = c + num_params else: print('WARNING: not found', var_names) except Exception as e: print(e) print('skipping...') print('Total number of parameters copied:', c) ================================================ FILE: test.py ================================================ import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' os.environ['OMP_NUM_THREADS'] = '1' os.environ['USE_SIMPLE_THREADED_LEVEL3'] = '1' os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' import sys from model.stylegan import StyleGAN_G_synthesis from model.model import Network from writer import Writer from inference import Inference from arglib import arglib import utils sys.path.insert(0, 'model/face_utils') def main(): test_args = arglib.TestArgs() args, str_args = test_args.args, test_args.str_args os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu Writer.set_writer(args.results_dir) id_model_path = args.pretrained_models_path.joinpath('vggface2.h5') stylegan_G_synthesis_path = str( args.pretrained_models_path.joinpath(f'stylegan_G_{args.resolution}x{args.resolution}_synthesis')) utils.landmarks_model_path = str(args.pretrained_models_path.joinpath('shape_predictor_68_face_landmarks.dat')) stylegan_G_synthesis = StyleGAN_G_synthesis(resolution=args.resolution, is_const_noise=args.const_noise) stylegan_G_synthesis.load_weights(stylegan_G_synthesis_path) network = Network(args, id_model_path, stylegan_G_synthesis) network.test() inference = Inference(args, network) test_func = getattr(inference, args.test_func) test_func() if __name__ == '__main__': main() ================================================ FILE: trainer.py ================================================ import logging import numpy as np import tensorflow as tf from writer import Writer from utils import general_utils as utils def id_loss_func(y_gt, y_pred): return tf.reduce_mean(tf.keras.losses.MAE(y_gt, y_pred)) class Trainer(object): def __init__(self, args, model, data_loader): self.args = args self.logger = logging.getLogger(__class__.__name__) self.model = model self.data_loader = data_loader # lrs & optimizers lr = 5e-5 if self.args.resolution == 256 else 1e-5 self.g_optimizer = tf.keras.optimizers.Adam(learning_rate=lr) self.g_gan_optimizer = tf.keras.optimizers.Adam(learning_rate=0.1 * lr) self.w_d_optimizer = tf.keras.optimizers.Adam(learning_rate=0.4 * lr) self.im_d_optimizer = tf.keras.optimizers.Adam(learning_rate=0.4 * lr) # Losses self.gan_loss_func = tf.keras.losses.BinaryCrossentropy(from_logits=True) self.pixel_loss_func = tf.keras.losses.MeanAbsoluteError(tf.keras.losses.Reduction.SUM) self.id_loss_func = id_loss_func if args.pixel_mask_type == 'gaussian': sigma = int(80 * (self.args.resolution / 256)) self.pixel_mask = utils.inverse_gaussian_image(self.args.resolution, sigma) else: self.pixel_mask = tf.ones([self.args.resolution, 
self.args.resolution]) self.pixel_mask = self.pixel_mask / tf.reduce_sum(self.pixel_mask) self.pixel_mask = tf.broadcast_to(self.pixel_mask, [self.args.batch_size, *self.pixel_mask.shape]) self.num_epoch = 0 self.is_cross_epoch = False # Lambdas if args.unified: self.lambda_gan = 0.5 else: self.lambda_gan = 1 self.lambda_pixel = 0.02 self.lambda_id = 1 self.lambda_attr_id = 1 self.lambda_landmarks = 0.001 self.r1_gamma = 10 # Test self.test_not_improved = 0 self.max_id_preserve = 0 self.min_lnd_dist = np.inf def train(self): while self.num_epoch <= self.args.num_epochs: self.logger.info('---------------------------------------') self.logger.info(f'Start training epoch: {self.num_epoch}') if self.args.cross_frequency and (self.num_epoch % self.args.cross_frequency == 0): self.is_cross_epoch = True self.logger.info('This epoch is cross-face') else: self.is_cross_epoch = False self.logger.info('This epoch is same-face') try: if self.num_epoch % self.args.test_frequency == 0: self.test() self.train_epoch() except Exception as e: self.logger.exception(e) raise if self.test_not_improved > self.args.not_improved_exit: self.logger.info(f'Test has not improved for {self.args.not_improved_exit} epochs. Exiting...') break self.num_epoch += 1 def train_epoch(self): id_loss = 0 landmarks_loss = 0 g_w_gan_loss = 0 pixel_loss = 0 w_d_loss = 0 w_loss = 0 self.logger.info(f'train in epoch: {self.num_epoch}') self.model.train() use_w_d = self.args.W_D_loss # if use_w_d and use_im_d and not self.args.unified: if not self.args.unified: if self.num_epoch % 2 == 0: # This epoch is not using image_D use_im_d = False # self.logger.info(f'Not using Image D in epoch: {self.num_epoch}') if self.num_epoch % 2 != 0: # This epoch is not using W_D use_w_d = False # self.logger.info(f'Not using W_d in epoch: {self.num_epoch}') attr_img, id_img, real_w, real_img, matching_ws = self.data_loader.get_batch(is_cross=self.is_cross_epoch) # Forward that does not require grads id_embedding = self.model.G.id_encoder(id_img) src_landmarks = self.model.G.landmarks(attr_img) attr_input = attr_img with tf.GradientTape(persistent=True) as g_tape: attr_out = self.model.G.attr_encoder(attr_input) attr_embedding = attr_out self.logger.info(f'attr embedding stats - mean: {tf.reduce_mean(tf.abs(attr_embedding)):.5f},' f' variance: {tf.math.reduce_variance(attr_embedding):.5f}') z_tag = tf.concat([id_embedding, attr_embedding], -1) w = self.model.G.latent_spaces_mapping(z_tag) fake_w = w[:, 0, :] self.logger.info( f'w stats - mean: {tf.reduce_mean(tf.abs(fake_w)):.5f}, variance: {tf.math.reduce_variance(fake_w):.5f}') pred = self.model.G.stylegan_s(w) # Move to roughly [0,1] pred = (pred + 1) / 2 if use_w_d: with tf.GradientTape() as w_d_tape: fake_w_logit = self.model.W_D(fake_w) g_w_gan_loss = self.generator_gan_loss(fake_w_logit) self.logger.info(f'g W loss is {g_w_gan_loss:.3f}') self.logger.info(f'fake W logit: {tf.squeeze(fake_w_logit)}') with g_tape.stop_recording(): real_w_logit = self.model.W_D(real_w) w_d_loss = self.discriminator_loss(fake_w_logit, real_w_logit) w_d_total_loss = w_d_loss if self.args.gp: w_d_gp = self.R1_gp(self.model.W_D, real_w) w_d_total_loss += w_d_gp self.logger.info(f'w_d_gp : {w_d_gp}') self.logger.info(f'W_D loss is {w_d_loss:.3f}') self.logger.info(f'real W logit: {tf.squeeze(real_w_logit)}') if self.args.id_loss: pred_id_embedding = self.model.G.id_encoder(pred) id_loss = self.lambda_id * id_loss_func(pred_id_embedding, tf.stop_gradient(id_embedding)) self.logger.info(f'id loss is {id_loss:.3f}') if 
self.args.landmarks_loss: try: dst_landmarks = self.model.G.landmarks(pred) except Exception as e: self.logger.warning(f'Failed finding landmarks on prediction. Not using landmarks loss. Error:{e}') dst_landmarks = None if dst_landmarks is None or src_landmarks is None: landmarks_loss = 0 else: landmarks_loss = self.lambda_landmarks * \ tf.reduce_mean(tf.keras.losses.MSE(src_landmarks, dst_landmarks)) self.logger.info(f'landmarks loss is: {landmarks_loss:.3f}') # if landmarks_loss > 5: # landmarks_loss = 0 # id_loss = 0 if not self.is_cross_epoch and self.args.pixel_loss: l1_loss = self.pixel_loss_func(attr_img, pred, sample_weight=self.pixel_mask) self.logger.info(f'L1 pixel loss is {l1_loss:.3f}') if self.args.pixel_loss_type == 'mix': mssim = tf.reduce_mean(1 - tf.image.ssim_multiscale(attr_img, pred, 1.0)) self.logger.info(f'mssim loss is {mssim:.3f}') pixel_loss = self.lambda_pixel * (0.84 * mssim + 0.16 * l1_loss) else: pixel_loss = self.lambda_pixel * l1_loss self.logger.info(f'pixel loss is {pixel_loss:.3f}') g_gan_loss = g_w_gan_loss total_g_not_gan_loss = id_loss \ + landmarks_loss \ + pixel_loss \ + w_loss self.logger.info(f'total G (not gan) loss is {total_g_not_gan_loss:.3f}') self.logger.info(f'G gan loss is {g_gan_loss:.3f}') Writer.add_scalar('loss/landmarks_loss', landmarks_loss, step=self.num_epoch) Writer.add_scalar('loss/total_g_not_gan_loss', total_g_not_gan_loss, step=self.num_epoch) Writer.add_scalar('loss/id_loss', id_loss, step=self.num_epoch) if use_w_d: Writer.add_scalar('loss/g_w_gan_loss', g_w_gan_loss, step=self.num_epoch) Writer.add_scalar('loss/W_D_loss', w_d_loss, step=self.num_epoch) if self.args.gp: Writer.add_scalar('loss/w_d_gp', w_d_gp, step=self.num_epoch) if not self.is_cross_epoch: Writer.add_scalar('loss/pixel_loss', pixel_loss, step=self.num_epoch) Writer.add_scalar('loss/w_loss', w_loss, step=self.num_epoch) if self.args.debug or \ (self.num_epoch < 1e3 and self.num_epoch % 1e2 == 0) or \ (self.num_epoch < 1e4 and self.num_epoch % 1e3 == 0) or \ (self.num_epoch % 1e4 == 0): utils.save_image(pred[0], self.args.images_results.joinpath(f'{self.num_epoch}_prediction_step.png')) utils.save_image(id_img[0], self.args.images_results.joinpath(f'{self.num_epoch}_id_step.png')) utils.save_image(attr_img[0], self.args.images_results.joinpath(f'{self.num_epoch}_attr_step.png')) Writer.add_image('input/id image', tf.expand_dims(id_img[0], 0), step=self.num_epoch) Writer.add_image('Prediction', tf.expand_dims(pred[0], 0), step=self.num_epoch) if total_g_not_gan_loss != 0: g_grads = g_tape.gradient(total_g_not_gan_loss, self.model.G.trainable_variables) g_grads_global_norm = tf.linalg.global_norm(g_grads) self.logger.info(f'global norm G not gan grad: {g_grads_global_norm}') self.g_optimizer.apply_gradients(zip(g_grads, self.model.G.trainable_variables)) if use_w_d: g_gan_grads = g_tape.gradient(g_gan_loss, self.model.G.trainable_variables) g_gan_grad_global_norm = tf.linalg.global_norm(g_gan_grads) self.logger.info(f'global norm G gan grad: {g_gan_grad_global_norm}') self.g_gan_optimizer.apply_gradients(zip(g_gan_grads, self.model.G.trainable_variables)) w_d_grads = w_d_tape.gradient(w_d_total_loss, self.model.W_D.trainable_variables) self.logger.info(f'global W_D gan grad: {tf.linalg.global_norm(w_d_grads)}') self.w_d_optimizer.apply_gradients(zip(w_d_grads, self.model.W_D.trainable_variables)) del g_tape # Common # Test def test(self): self.logger.info(f'Testing in epoch: {self.num_epoch}') self.model.test() similarities = {'id_to_pred': [], 'id_to_attr': 
[], 'attr_to_pred': []} fake_reconstruction = {'MSE': [], 'PSNR': [], 'ID': []} real_reconstruction = {'MSE': [], 'PSNR': [], 'ID': []} if self.args.test_with_arcface: test_similarities = {'id_to_pred': [], 'id_to_attr': [], 'attr_to_pred': []} lnd_dist = [] for i in range(self.args.test_size): attr_img, id_img = self.data_loader.get_batch(is_train=False, is_cross=True) pred, id_embedding, w, attr_embedding, src_lnds = self.model.G(id_img, attr_img) image = tf.clip_by_value(pred, 0, 1) pred_id = self.model.G.id_encoder(image) attr_id = self.model.G.id_encoder(attr_img) similarities['id_to_pred'].extend(tf.keras.losses.cosine_similarity(id_embedding, pred_id).numpy()) similarities['id_to_attr'].extend(tf.keras.losses.cosine_similarity(id_embedding, attr_id).numpy()) similarities['attr_to_pred'].extend(tf.keras.losses.cosine_similarity(attr_id, pred_id).numpy()) if self.args.test_with_arcface: try: arc_id_embedding = self.model.G.test_id_encoder(id_img) arc_pred_id = self.model.G.test_id_encoder(image) arc_attr_id = self.model.G.test_id_encoder(attr_img) test_similarities['id_to_attr'].extend( tf.keras.losses.cosine_similarity(arc_id_embedding, arc_attr_id).numpy()) test_similarities['id_to_pred'].extend( tf.keras.losses.cosine_similarity(arc_id_embedding, arc_pred_id).numpy()) test_similarities['attr_to_pred'].extend( tf.keras.losses.cosine_similarity(arc_attr_id, arc_pred_id).numpy()) except Exception as e: self.logger.warning(f'Not calculating test similarities for iteration: {i} because: {e}') # Landmarks dst_lnds = self.model.G.landmarks(image) lnd_dist.extend(tf.reduce_mean(tf.keras.losses.MSE(src_lnds, dst_lnds), axis=-1).numpy()) # Fake Reconstruction self.test_reconstruction(id_img, fake_reconstruction, display=(i==0), display_name='id_img') if self.args.test_real_attr: # Real Reconstruction self.test_reconstruction(attr_img, real_reconstruction, display=(i==0), display_name='attr_img') if i == 0: utils.save_image(image[0], self.args.images_results.joinpath(f'test_prediction_{self.num_epoch}.png')) utils.save_image(id_img[0], self.args.images_results.joinpath(f'test_id_{self.num_epoch}.png')) utils.save_image(attr_img[0], self.args.images_results.joinpath(f'test_attr_{self.num_epoch}.png')) Writer.add_image('test/prediction', image, step=self.num_epoch) Writer.add_image('test input/id image', id_img, step=self.num_epoch) Writer.add_image('test input/attr image', attr_img, step=self.num_epoch) for j in range(np.minimum(3, src_lnds.shape[0])): src_xy = src_lnds[j] # GT dst_xy = dst_lnds[j] # pred attr_marked = utils.mark_landmarks(attr_img[j], src_xy, color=(0, 0, 0)) pred_marked = utils.mark_landmarks(pred[j], src_xy, color=(0, 0, 0)) pred_marked = utils.mark_landmarks(pred_marked, dst_xy, color=(255, 112, 112)) Writer.add_image(f'landmarks/overlay-{j}', pred_marked, step=self.num_epoch) Writer.add_image(f'landmarks/src-{j}', attr_marked, step=self.num_epoch) # Similarity self.logger.info('Similarities:') for k, v in similarities.items(): self.logger.info(f'{k}: MEAN: {np.mean(v)}, STD: {np.std(v)}') mean_lnd_dist = np.mean(lnd_dist) self.logger.info(f'Mean landmarks L2: {mean_lnd_dist}') id_to_pred = np.mean(similarities['id_to_pred']) attr_to_pred = np.mean(similarities['attr_to_pred']) mean_disen = attr_to_pred - id_to_pred Writer.add_scalar('similarity/score', mean_disen, step=self.num_epoch) Writer.add_scalar('similarity/id_to_pred', id_to_pred, step=self.num_epoch) Writer.add_scalar('similarity/attr_to_pred', attr_to_pred, step=self.num_epoch) if self.args.test_with_arcface: 
arc_id_to_pred = np.mean(test_similarities['id_to_pred']) arc_attr_to_pred = np.mean(test_similarities['attr_to_pred']) arc_mean_disen = arc_attr_to_pred - arc_id_to_pred Writer.add_scalar('arc_similarity/score', arc_mean_disen, step=self.num_epoch) Writer.add_scalar('arc_similarity/id_to_pred', arc_id_to_pred, step=self.num_epoch) Writer.add_scalar('arc_similarity/attr_to_pred', arc_attr_to_pred, step=self.num_epoch) self.logger.info(f'Mean disentanglement score is {mean_disen}') Writer.add_scalar('landmarks/L2', np.mean(lnd_dist), step=self.num_epoch) # Reconstruction if self.args.test_real_attr: Writer.add_scalar('reconstruction/real_MSE', np.mean(real_reconstruction['MSE']), step=self.num_epoch) Writer.add_scalar('reconstruction/real_PSNR', np.mean(real_reconstruction['PSNR']), step=self.num_epoch) Writer.add_scalar('reconstruction/real_ID', np.mean(real_reconstruction['ID']), step=self.num_epoch) Writer.add_scalar('reconstruction/fake_MSE', np.mean(fake_reconstruction['MSE']), step=self.num_epoch) Writer.add_scalar('reconstruction/fake_PSNR', np.mean(fake_reconstruction['PSNR']), step=self.num_epoch) Writer.add_scalar('reconstruction/fake_ID', np.mean(fake_reconstruction['ID']), step=self.num_epoch) if mean_lnd_dist < self.min_lnd_dist: self.logger.info('Minimum landmarks dist achieved. Saving checkpoint') self.test_not_improved = 0 self.min_lnd_dist = mean_lnd_dist self.model.my_save(f'_best_landmarks_epoch_{self.num_epoch}') if np.abs(id_to_pred) > self.max_id_preserve: self.logger.info('Max ID preservation achieved! Saving checkpoint') self.test_not_improved = 0 self.max_id_preserve = np.abs(id_to_pred) self.model.my_save(f'_best_id_epoch_{self.num_epoch}') else: self.test_not_improved += 1 def test_reconstruction(self, img, errors_dict, display=False, display_name=None): pred, id_embedding, w, attr_embedding, src_lnds = self.model.G(img, img) recon_image = tf.clip_by_value(pred, 0, 1) recon_pred_id = self.model.G.id_encoder(recon_image) mse = tf.reduce_mean((img - recon_image) ** 2, axis=[1, 2, 3]).numpy() psnr = tf.image.psnr(img, recon_image, 1).numpy() errors_dict['MSE'].extend(mse) errors_dict['PSNR'].extend(psnr) errors_dict['ID'].extend(tf.keras.losses.cosine_similarity(id_embedding, recon_pred_id).numpy()) if display: Writer.add_image(f'reconstruction/{display_name}', pred, step=self.num_epoch) # Helpers def generator_gan_loss(self, fake_logit): """ G logistic non-saturating loss, to be minimized """ g_gan_loss = self.gan_loss_func(tf.ones_like(fake_logit), fake_logit) return self.lambda_gan * g_gan_loss def discriminator_loss(self, fake_logit, real_logit): """ D logistic loss, to be minimized verified as identical to StyleGAN's loss.D_logistic """ fake_gt = tf.zeros_like(fake_logit) real_gt = tf.ones_like(real_logit) d_fake_loss = self.gan_loss_func(fake_gt, fake_logit) d_real_loss = self.gan_loss_func(real_gt, real_logit) d_loss = d_real_loss + d_fake_loss return self.lambda_gan * d_loss def R1_gp(self, D, x): with tf.GradientTape() as t: t.watch(x) pred = D(x) pred_sum = tf.reduce_sum(pred) grad = t.gradient(pred_sum, x) # Reshape as a vector norm = tf.norm(tf.reshape(grad, [tf.shape(grad)[0], -1]), axis=1) gp = tf.reduce_mean(norm ** 2) gp = 0.5 * self.r1_gamma * gp return gp ================================================ FILE: utils/__init__.py ================================================ ================================================ FILE: utils/general_utils.py ================================================ from pathlib import Path import cv2 import numpy 
as np from PIL import Image import tensorflow as tf from tensorflow.keras.initializers import VarianceScaling import numbers import scipy import dlib landmarks_model_path = None def read_image(img_path, resolution, align=False): if align: img = read_and_align_image(img_path, resolution) else: img = read_SG_image(img_path, resolution) return img def find_file_by_str(search_dir, s): files = [f for f in search_dir.iterdir() if s in f.name] return files def read_SG_image(img_path, size=256, resize=True): img = Image.open(str(img_path)) img = img.convert('RGB') if img.size != (size, size) and resize: img = img.resize((size, size)) img = np.asarray(img) img = np.expand_dims(img, axis=0) # Images in [0, 1] img = np.float32(img) / 255 return img def read_and_align_image(img_path, output_size=1024): global landmarks_model_path if not landmarks_model_path: raise ValueError('Please init the landmarks model path') transform_size = 4096 enable_padding = True img = Image.open(img_path) img = img.convert('RGB') npimg = np.asarray(img) # states is a 4x1 array with confidence for : [left eye closed, right eye closed, mouth closed, mouth open big] face_detector = dlib.get_frontal_face_detector() landmarks_network = dlib.shape_predictor(landmarks_model_path) try: bbox = face_detector(npimg, 0)[0] except: print('face not found!') raise # rect = np.array([det.left(), det.top(), det.right(), det.bottom()]) shape = landmarks_network(npimg, bbox) lm = np.array([[shape.part(n).x + 0.5, shape.part(n).y + 0.5] for n in range(shape.num_parts)]) lm = np.round(lm) + 0.5 lm_chin = lm[0: 17] # left-right lm_eyebrow_left = lm[17: 22] # left-right lm_eyebrow_right = lm[22: 27] # left-right lm_nose = lm[27: 31] # top-down lm_nostrils = lm[31: 36] # top-down lm_eye_left = lm[36: 42] # left-clockwise lm_eye_right = lm[42: 48] # left-clockwise lm_mouth_outer = lm[48: 60] # left-clockwise lm_mouth_inner = lm[60: 68] # left-clockwise # Calculate auxiliary vectors. eye_left = np.mean(lm_eye_left, axis=0) eye_right = np.mean(lm_eye_right, axis=0) eye_avg = (eye_left + eye_right) * 0.5 # nose_mock_avg = (lm_nose[0] + lm_nose[1]) * 0.5 # eye_avg = nose_mock_avg eye_to_eye = eye_right - eye_left mouth_left = lm_mouth_outer[0] mouth_right = lm_mouth_outer[6] mouth_avg = (mouth_left + mouth_right) * 0.5 eye_to_mouth = mouth_avg - eye_avg # Choose oriented crop rectangle. x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] x /= np.hypot(*x) x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) y = np.flipud(x) * [-1, 1] c = eye_avg + eye_to_mouth * 0.1 quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) qsize = np.hypot(*x) * 2 # Shrink. shrink = int(np.floor(qsize / output_size * 0.5)) if shrink > 1: rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) img = img.resize(rsize, Image.ANTIALIAS) quad /= shrink qsize /= shrink # Crop. border = max(int(np.rint(qsize * 0.1)), 3) crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1])))) crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1])) crop = np.array(crop) if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: img = img.crop(tuple(crop)) quad -= crop[0:2] # Pad. 
pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1])))) pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0)) if enable_padding and max(pad) > border - 4: pad = np.maximum(pad, int(np.rint(qsize * 0.3))) img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') h, w, _ = img.shape y, x, _ = np.ogrid[:h, :w, :1] mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3])) blur = qsize * 0.02 img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') quad += pad[:2] # Transform. img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(), Image.BILINEAR) if output_size < transform_size: img = img.resize((output_size, output_size), Image.ANTIALIAS) img = np.asarray(img) img = np.expand_dims(img, axis=0) img = np.float32(img) / 255 return img def gaussian_image(size, sigma, dim=2): if isinstance(size, numbers.Number): size = [size] * dim if isinstance(sigma, numbers.Number): sigma = [sigma] * dim weight_kernel = 1 meshgrids = np.meshgrid(*[np.arange(size, dtype=np.float32) for size in size]) # The gaussian kernel is the product of the # gaussian function of each dimension. for size, std, mgrid in zip(size, sigma, meshgrids): mean = (size - 1) / 2 weight_kernel *= 1 / (std * np.sqrt(2 * np.pi)) * np.exp(-((mgrid - mean) / std) ** 2 / 2) weight_kernel = weight_kernel / np.sum(weight_kernel) return weight_kernel def inverse_gaussian_image(size, sigma, dim=2): gauss = gaussian_image(size, sigma, dim) # Inversion achieved by max - gauss, but adding min as well to # prevent regions of zeros which don't exist in normal gaussian inv_gauss = np.max(gauss) + np.min(gauss) - gauss inv_gauss = inv_gauss / np.sum(inv_gauss) return inv_gauss def is_float(tensor): """ Check if input tensor is float32, tensor maybe tf.Tensor or np.array """ return (isinstance(tensor, tf.Tensor) and tensor.dtype != tf.dtypes.uint8) or tensor.dtype != np.uint8 def convert_tensor_to_image(tensor): """ Converts tensor to image, and saturate output's range :param tensor: tf.Tensor, dtype float32, range [0,1] :return: np.array, dtype uint8, range [0, 255] """ if is_float(tensor): tensor = tf.clip_by_value(tensor, 0., 1.) 
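# Float inputs are assumed to be in [0, 1]; after clipping, the value is scaled to [0, 255] and rounded to uint8 below.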
tensor = 255 * tensor if tensor.ndim == 4 and tensor.shape[0] == 1: tensor = tf.squeeze(tensor) tensor = np.uint8(np.round(tensor)) return tensor def save_image(img, file_path): """ :param img: Could be either tf tensor or numpy array :param file_path: """ if isinstance(file_path, Path): file_path = str(file_path) img = convert_tensor_to_image(img) img = Image.fromarray(img) img.save(file_path) def mark_landmarks(img, lnd, color=None): """ landmarks in (x,y) format """ img = convert_tensor_to_image(img) radius = int(img.shape[0] / 256) lnd = (img.shape[0] / 160) * lnd if not color: color = (255, 255, 255) for i in range(lnd.shape[0]): x_y = lnd[i] img = cv2.circle(img, center=(int(x_y[0]), int(x_y[1])), color=color, radius=radius, thickness=-1) return img def get_weights(slope=0.2): """ The scale is calculated according to: https://pytorch.org/docs/stable/nn.init.html and https://towardsdatascience.com/weight-initialization-in-neural-networks-a-journey-from-the-basics-to-kaiming-954fb9b47c79 For ReLU and LeakyReLU activations, the preferred initialization is Kaiming. In PyTorch, the gain for LeakyReLU is calculated by: sqrt(2 / ( 1 + leaky_relu_slope ^ 2)) and the weights are sampled from N(0, std^2) where std = gain / sqrt(fan_in) To mimic this in TF, I am using VarianceScaling. The weights are sampled from N(0, std^2) where std = sqrt(scale / fan_in). Therefore, scale = gain^2 """ scale = 2 / (1 + slope ** 2) return VarianceScaling(scale) def np_permute(tensor, permute): idx = np.empty_like(permute) idx[permute] = np.arange(len(permute)) return tensor[:, idx] ================================================ FILE: utils/generate_fake_data.py ================================================ import sys from pathlib import Path import os sys.path.append('..') import argparse import tensorflow as tf import numpy as np from tqdm import tqdm from utils.general_utils import save_image from model.stylegan import StyleGAN_G def main(args): base_dir = Path(args.output_path).joinpath(f'dataset_{args.resolution}') base_w_dir = base_dir.joinpath('ws') base_w_dir.mkdir(parents=True, exist_ok=True) base_im_dir = base_dir.joinpath('images') base_im_dir.mkdir(parents=True, exist_ok=True) existing_files = list(base_dir.joinpath('images').iterdir()) if existing_files: max_exist = max([int(x.name) for x in existing_files]) max_exist = int(max_exist - max_exist % 1e3 + 1e3) else: max_exist = 0 stylegan_G_path = args.pretrained_models_path.joinpath(f'stylegan_G_{args.resolution}x{args.resolution}.h5') stylegan_G = StyleGAN_G(resolution=args.resolution, truncation_psi=args.truncation) stylegan_G.load_weights(str(stylegan_G_path)) num_samples = args.num_images batch_size = args.batch_size num_batches = int(num_samples / batch_size) curr_ind = max_exist for _ in tqdm(range(num_batches)): z = tf.random.normal((batch_size, 512)) w = stylegan_G.model_mapping(z) images = stylegan_G.model_synthesis(w) images = (images + 1) / 2 if curr_ind % 1000 == 0: curr_w_dir = base_w_dir.joinpath(f'{curr_ind:05d}') curr_w_dir.mkdir(exist_ok=True) curr_im_dir = base_im_dir.joinpath(f'{curr_ind:05d}') curr_im_dir.mkdir(exist_ok=True) for j in range(batch_size): w_path = curr_w_dir.joinpath(f'{curr_ind:05d}.npy') np.save(str(w_path), w[j], allow_pickle=False) im_path = curr_im_dir.joinpath(f'{curr_ind:05d}.png') save_image(images[j], im_path) curr_ind += 1 if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--resolution', type=int, choices=[256, 1024], default=256) parser.add_argument('--batch_size', 
type=int, default=50) parser.add_argument('--truncation', type=float, default=0.7) parser.add_argument('--output_path', required=True) parser.add_argument('--pretrained_models_path', type=Path, required=True) parser.add_argument('--num_images', type=int, default=10000) parser.add_argument('--gpu', default='0') args = parser.parse_args() os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu assert args.num_images % 1e3 == 0 main(args) ================================================ FILE: writer.py ================================================ from utils.general_utils import convert_tensor_to_image from pathlib import Path import tensorflow as tf class Writer(object): writer = None @staticmethod def set_writer(results_dir): if isinstance(results_dir, str): results_dir = Path(results_dir) results_dir.mkdir(exist_ok=True, parents=True) Writer.writer = tf.summary.create_file_writer(str(results_dir)) @staticmethod def add_scalar(tag, val, step): with Writer.writer.as_default(): tf.summary.scalar(tag, val, step=step) @staticmethod def add_image(tag, val, step): val = convert_tensor_to_image(val) if tf.rank(val) == 3: val = tf.expand_dims(val, 0) with Writer.writer.as_default(): tf.summary.image(tag, val, step) @staticmethod def flush(): with Writer.writer.as_default(): Writer.writer.flush()
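================================================ APPENDIX: initialization sanity check (illustrative sketch, not a repository file) ================================================ The get_weights docstring in utils/general_utils.py claims that Keras VarianceScaling with scale = 2 / (1 + slope^2) reproduces PyTorch's Kaiming initialization for LeakyReLU, where std = gain / sqrt(fan_in) and gain = sqrt(2 / (1 + slope^2)). The following minimal sketch checks that equivalence numerically; it assumes TensorFlow 2.x and NumPy, and the 1024x512 kernel shape is chosen to mirror the third Dense layer of LatentMappingNetwork, the first layer that actually uses get_weights().

import numpy as np
import tensorflow as tf

slope = 0.2  # LeakyReLU slope used throughout the repository
fan_in, fan_out = 1024, 512  # mirrors the LatentMappingNetwork.linear3 kernel

# PyTorch-style Kaiming init: std = gain / sqrt(fan_in), gain = sqrt(2 / (1 + slope^2))
gain = np.sqrt(2 / (1 + slope ** 2))
expected_std = gain / np.sqrt(fan_in)  # ~0.0433 for these sizes

# get_weights() equivalent: VarianceScaling samples with variance scale / fan_in
init = tf.keras.initializers.VarianceScaling(scale=2 / (1 + slope ** 2))
kernel = init(shape=(fan_in, fan_out)).numpy()

print(f'expected std:  {expected_std:.5f}')
print(f'empirical std: {kernel.std():.5f}')  # should agree to ~2-3 decimal places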