Repository: YotamNitzan/ID-disentanglement
Branch: main
Commit: 342a782364fd
Files: 32
Total size: 129.6 KB
Directory structure:
ID-disentanglement/
├── .gitignore
├── LICENSE
├── README.md
├── arglib/
│ ├── __init__.py
│ └── arglib.py
├── data_loader/
│ ├── __init__.py
│ └── data_loader.py
├── docs/
│ ├── index.html
│ ├── mainpage.css
│ └── setup.md
├── environment.yml
├── inference.py
├── main.py
├── model/
│ ├── __init__.py
│ ├── arcface/
│ │ ├── arcface.py
│ │ ├── inference.py
│ │ └── resnet.py
│ ├── attr_encoder.py
│ ├── discriminator.py
│ ├── face_detector.py
│ ├── generator.py
│ ├── id_encoder.py
│ ├── landmarks.py
│ ├── latent_mapping.py
│ ├── model.py
│ └── stylegan.py
├── test.py
├── trainer.py
├── utils/
│ ├── __init__.py
│ ├── general_utils.py
│ └── generate_fake_data.py
└── writer.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
.idea
*.png
*.jpg
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2020 YotamNitzan
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# Face Identity Disentanglement via Latent Space Mapping
SIGGRAPH ASIA 2020
## Abstract
Learning disentangled representations of data is a fundamental problem in
artificial intelligence. Specifically, disentangled latent representations allow
generative models to control and compose the disentangled factors in the
synthesis process. Current methods, however, require extensive supervision
and training, or instead, noticeably compromise quality.
In this paper, we present a method that learns how to represent data
in a disentangled way, with minimal supervision, manifested solely using
available pre-trained networks. Our key insight is to decouple the processes
of disentanglement and synthesis, by employing a leading pre-trained unconditional image generator, such as
StyleGAN. By learning to map into its
latent space, we leverage both its state-of-the-art quality, and its rich and
expressive latent space, without the burden of training it.
We demonstrate our approach on the complex and high-dimensional
domain of human heads. We evaluate our method qualitatively and quantitatively, and exhibit its success with
de-identification operations and with
temporal identity coherency in image sequences. Through extensive experimentation, we show that our method
successfully disentangles identity
from other facial attributes, surpassing existing methods, even though they
require more training and supervision.
## Motivation

Learning disentangled representations and synthesizing images are different tasks,
yet it is common practice to solve both simultaneously. The image generator then
learns the semantics of the representations, so it can take representations from
different sources and mix them to generate novel images. But this comes at a price:
one now has to solve two difficult tasks at once, which often requires devising
dedicated architectures and, even then, yields sub-optimal visual quality.

We propose a different approach. Unconditional generators have recently achieved
remarkable image quality, so we take advantage of this fact and avoid solving the
synthesis task ourselves. Instead, we use a pretrained generator, such as StyleGAN.
But how can a generator that is pretrained and unconditional make sense of the
disentangled representations? We suggest mapping them directly into the latent
space of the generator. In a single feed-forward pass, the mapping produces a new,
never-before-seen latent code that corresponds to a novel image, as sketched below.
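The following minimal sketch paraphrases the forward pass in `model/generator.py`; the encoder, mapping, and synthesis objects stand in for the pretrained components that `main.py` wires up:

```python
import tensorflow as tf

def compose(id_encoder, attr_encoder, latent_mapping, stylegan_synthesis,
            id_img, attr_img):
    """Mix the identity of one image with the attributes of another."""
    z_id = id_encoder(id_img)          # identity embedding (frozen encoder)
    z_attr = attr_encoder(attr_img)    # attribute embedding (trained)
    z = tf.concat([z_id, z_attr], -1)  # disentangled code z'
    w = latent_mapping(z)              # learned mapping into StyleGAN's latent space
    out = stylegan_synthesis(w)        # frozen pretrained StyleGAN synthesis
    return (out + 1) / 2               # from [-1, 1] to roughly [0, 1]
```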
## Composition Results

We demonstrate our method on the domain of human faces, specifically disentangling
identity from all other attributes.

In the following tables, the identity is taken from the image on top and the
attributes are taken from the leftmost image.

In this figure, the inputs themselves are StyleGAN-generated images.

More results, but this time the input images are real.
## Disentangled Interpolation

Thanks to our disentangled representations, we are able to interpolate a single
factor (identity or attributes) in the generator's latent space. This enables
finer control and opens the door for new disentangled editing capabilities.
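A condensed sketch of identity-only interpolation, following `Inference.interpolate` in `inference.py` (the function name and step handling are illustrative; unlike the repository code, this version includes both endpoints):

```python
import tensorflow as tf

def interpolate_identity(G, id_img_a, id_img_b, attr_img, steps=9):
    """Walk between two identities while the attribute code stays fixed."""
    attr = G.attr_encoder(attr_img)
    w_a = G.latent_spaces_mapping(tf.concat([G.id_encoder(id_img_a), attr], -1))
    w_b = G.latent_spaces_mapping(tf.concat([G.id_encoder(id_img_b), attr], -1))
    frames = []
    for i in range(steps):
        t = i / (steps - 1)                          # include both endpoints
        out = G.stylegan_s((1 - t) * w_a + t * w_b)  # lerp directly in W
        frames.append((out + 1) / 2)                 # to roughly [0, 1]
    return frames
```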
## Contact
yotamnitzan at gmail dot com
================================================
FILE: docs/mainpage.css
================================================
body {
font-family: 'Lato', sans-serif;
font-weight: 300;
color: #333;
font-size: 16px;
}
h1 {
font-size: 40px;
color: #555;
font-weight: 400;
text-align: center;
margin: 0;
padding: 0;
margin-top: 30px;
margin-bottom: 10px;
}
.authors {
color: #222;
font-size: 24px;
font-weight: 300;
text-align: center;
margin: 0;
padding: 0;
margin-bottom: 0px;
}
.logoimg {
text-align: center;
margin-bottom: 30px;
}
.container-fluid {
margin-top: 5px;
margin-bottom: 5px;
}
.container {
margin-top: 10px;
}
#footer {
margin-bottom: 100px;
}
.thumbs {
-webkit-box-shadow: 1px 1px 3px #999;
-moz-box-shadow: 1px 1px 3px #999;
box-shadow: 1px 1px 3px #999;
margin-bottom: 20px;
}
h2 {
font-size: 24px;
font-weight: 900;
border-bottom: 1px solid #999;
margin-bottom: 20px;
}
.space {
margin-bottom: 1.5cm;
}
.text-primary {
color: #5da2d5 !important;
}
.text-primary:hover {
color: #f3d250 !important;
opacity: 1.0;
}
================================================
FILE: docs/setup.md
================================================
# Setup
## Environment
The code is designed for TensorFlow 2.x on Python 3.7, with CUDA 10.1 and cuDNN 7.6.5.
Run `conda env create -f environment.yml` to create a conda environment with the needed dependencies.
Tested with TensorFlow 2.0.0, Python 3.7.9, and Ubuntu 14.04.
## Third-party pretrained networks
Our method relies on several pretrained networks.
Some are needed only for training and some also for inference;
download them according to your intended use.
Put all downloaded files/directories under a single directory, which will
serve as the base path for all pretrained networks.
| Name | Training | Inference | Description |
| :--- | :----------: | :----------: | :---------- |
|[FFHQ StyleGAN 256x256](https://drive.google.com/drive/folders/1OgLvUhd9FX9_mPXrfqAWaLZsceQzE9l4?usp=sharing) | :heavy_check_mark: | :heavy_check_mark: | StyleGAN model pretrained on FFHQ with 256x256 resolution. Converted using [StyleGAN-Tensorflow2](https://github.com/YotamNitzan/StyleGAN-Tensorflow2)
|[FFHQ StyleGAN 1024x1024](https://drive.google.com/drive/folders/1jQxJsmapu6SjygvJfvP4-YVxZ9f5Hu_N?usp=sharing) | :heavy_check_mark: | :heavy_check_mark: | StyleGAN model pretrained on FFHQ with 1024x1024 resolution. Converted using [StyleGAN-Tensorflow2](https://github.com/YotamNitzan/StyleGAN-Tensorflow2)
|[VGGFace2](https://drive.google.com/file/d/1I_JyR7LH-30hEIpD4OSFVg2TOf9Q8cqU/view?usp=sharing) | :heavy_check_mark: | :heavy_check_mark: | Pretrained VGGFace2 model taken from [WeidiXie](https://github.com/WeidiXie/Keras-VGGFace2-ResNet50).
|[dlib landmarks model](http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2) | | :heavy_check_mark: | dlib landmarks model, used to align images.
|[ArcFace](https://drive.google.com/drive/folders/1F-Ll9Nw7I1FGP61cpQxOdhs2nxi0E5mg?usp=sharing) | :heavy_check_mark: | | Pretrained ArcFace model taken from [dmonterom](https://github.com/dmonterom/face_recognition_TF2).
|[Face & Landmarks Detection](https://drive.google.com/drive/folders/1D__J9UMwzBNR9eVrQGYuL9ueYGi7G4qh?usp=sharing) | :heavy_check_mark: | | Pretrained face detection and differentiable facial landmarks detection from [610265158](https://github.com/610265158/face_landmark).
### Other StyleGANs
To try out our method with other StyleGAN checkpoints, first obtain a trained StyleGAN pkl file using the [original StyleGAN repository](https://github.com/NVlabs/stylegan).
Next, convert it to TensorFlow 2 using this [repository](https://github.com/YotamNitzan/StyleGAN-Tensorflow2).
================================================
FILE: environment.yml
================================================
name: id_disen
channels:
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _tflow_select=2.1.0=gpu
- absl-py=0.11.0=py37h06a4308_0
- aiohttp=3.6.3=py37h7b6447c_0
- astor=0.8.1=py37_0
- async-timeout=3.0.1=py37_0
- attrs=20.3.0=pyhd3eb1b0_0
- blas=1.0=mkl
- blinker=1.4=py37_0
- blosc=1.20.1=he1b5a44_0
- brotli=1.0.9=he1b5a44_3
- brotlipy=0.7.0=py37h27cfd23_1003
- bzip2=1.0.8=h516909a_3
- c-ares=1.17.1=h27cfd23_0
- ca-certificates=2020.11.8=ha878542_0
- cachetools=4.1.1=py_0
- certifi=2020.11.8=py37h89c1867_0
- cffi=1.14.3=py37h261ae71_2
- chardet=3.0.4=py37h06a4308_1003
- charls=2.1.0=he1b5a44_2
- click=7.1.2=py_0
- cloudpickle=1.6.0=py_0
- cryptography=3.2.1=py37h3c74f83_1
- cudatoolkit=10.0.130=0
- cudnn=7.6.5=cuda10.0_0
- cupti=10.0.130=0
- cycler=0.10.0=py_2
- cytoolz=0.11.0=py37h8f50634_1
- dask-core=2.30.0=py_0
- decorator=4.4.2=py_0
- freetype=2.10.4=h7ca028e_0
- gast=0.2.2=py37_0
- giflib=5.2.1=h36c2ea0_2
- google-auth=1.23.0=pyhd3eb1b0_0
- google-auth-oauthlib=0.4.2=pyhd3eb1b0_2
- google-pasta=0.2.0=py_0
- grpcio=1.31.0=py37hf8bcb03_0
- h5py=2.10.0=py37hd6299e0_1
- hdf5=1.10.6=hb1b8bf9_0
- idna=2.10=py_0
- imagecodecs=2020.5.30=py37hda6ee5b_1
- imageio=2.9.0=py_0
- importlib-metadata=2.0.0=py_1
- intel-openmp=2020.2=254
- jpeg=9d=h36c2ea0_0
- jxrlib=1.1=h516909a_2
- keras-applications=1.0.8=py_1
- keras-preprocessing=1.1.0=py_1
- kiwisolver=1.3.1=py37hc928c03_0
- lcms2=2.11=hcbb858e_1
- ld_impl_linux-64=2.33.1=h53a641e_7
- libaec=1.0.4=he1b5a44_1
- libedit=3.1.20191231=h14c3975_1
- libffi=3.3=he6710b0_2
- libgcc-ng=9.1.0=hdf63c60_0
- libgfortran-ng=7.3.0=hdf63c60_0
- libpng=1.6.37=h21135ba_2
- libprotobuf=3.13.0.1=hd408876_0
- libstdcxx-ng=9.1.0=hdf63c60_0
- libtiff=4.1.0=h4f3a223_6
- libwebp-base=1.1.0=h36c2ea0_3
- libzopfli=1.0.3=he1b5a44_0
- lz4-c=1.9.2=he1b5a44_3
- markdown=3.3.3=py37h06a4308_0
- matplotlib-base=3.3.3=py37h4f6019d_0
- mkl=2020.2=256
- mkl-service=2.3.0=py37he904b0f_0
- mkl_fft=1.2.0=py37h23d657b_0
- mkl_random=1.1.1=py37h0573a6f_0
- multidict=4.7.6=py37h7b6447c_1
- ncurses=6.2=he6710b0_1
- networkx=2.5=py_0
- numpy=1.19.2=py37h54aff64_0
- numpy-base=1.19.2=py37hfa32c7d_0
- oauthlib=3.1.0=py_0
- olefile=0.46=pyh9f0ad1d_1
- openjpeg=2.3.1=h981e76c_3
- openssl=1.1.1h=h516909a_0
- opt_einsum=3.1.0=py_0
- pillow=8.0.1=py37h63a5d19_0
- pip=20.2.4=py37h06a4308_0
- protobuf=3.13.0.1=py37he6710b0_1
- pyasn1=0.4.8=py_0
- pyasn1-modules=0.2.8=py_0
- pycparser=2.20=py_2
- pyjwt=1.7.1=py37_0
- pyopenssl=19.1.0=pyhd3eb1b0_1
- pyparsing=2.4.7=pyh9f0ad1d_0
- pysocks=1.7.1=py37_1
- python=3.7.9=h7579374_0
- python-dateutil=2.8.1=py_0
- python_abi=3.7=1_cp37m
- pywavelets=1.1.1=py37h161383b_3
- readline=8.0=h7b6447c_0
- requests=2.24.0=py_0
- requests-oauthlib=1.3.0=py_0
- rsa=4.6=py_0
- scikit-image=0.17.2=py37h10a2094_4
- scipy=1.5.2=py37h0b6359f_0
- setuptools=50.3.1=py37h06a4308_1
- six=1.15.0=py37h06a4308_0
- snappy=1.1.8=he1b5a44_3
- sqlite=3.33.0=h62c20be_0
- tensorboard-plugin-wit=1.6.0=py_0
- tensorflow=2.0.0=gpu_py37h768510d_0
- tensorflow-base=2.0.0=gpu_py37h0ec5d1f_0
- tensorflow-estimator=2.0.0=pyh2649769_0
- termcolor=1.1.0=py37_1
- tifffile=2020.11.18=pyhd8ed1ab_0
- tk=8.6.10=hbc83047_0
- toolz=0.11.1=py_0
- tornado=6.1=py37h4abf009_0
- urllib3=1.25.11=py_0
- werkzeug=0.16.1=py_0
- wheel=0.35.1=pyhd3eb1b0_0
- wrapt=1.12.1=py37h7b6447c_1
- xz=5.2.5=h7b6447c_0
- yaml=0.2.5=h516909a_0
- yarl=1.6.2=py37h7b6447c_0
- zipp=3.4.0=pyhd3eb1b0_0
- zlib=1.2.11=h7b6447c_3
- zstd=1.4.5=h6597ccf_2
- pip:
- dlib==19.21.0
- keras==2.3.1
- mtcnn==0.1.0
- opencv-python==4.4.0.46
- pyyaml==5.3.1
- tensorboard==2.0.2
- tensorflow-addons==0.6.0
- tensorflow-gpu==2.0.0
- tqdm==4.53.0
================================================
FILE: inference.py
================================================
from pathlib import Path
from tqdm import tqdm
import tensorflow as tf
from writer import Writer
from utils import general_utils as utils
class Inference(object):
def __init__(self, args, model):
self.args = args
self.G = model.G
def infer_pairs(self):
names = [f for f in self.args.id_dir.iterdir() if f.suffix[1:] in self.args.img_suffixes]
names.extend([f for f in self.args.attr_dir.iterdir() if f.suffix[1:] in self.args.img_suffixes])
for img_name in tqdm(names):
id_path = utils.find_file_by_str(self.args.id_dir, img_name.stem)
attr_path = utils.find_file_by_str(self.args.attr_dir, img_name.stem)
if len(id_path) != 1 or len(attr_path) != 1:
print(f'Could not find a single pair with name: {img_name.stem}')
continue
id_img = utils.read_image(id_path, self.args.resolution, self.args.reals)
attr_img = utils.read_image(attr_path, self.args.resolution, self.args.reals)
out_img = self.G(id_img, attr_img)[0]
utils.save_image(out_img, self.args.output_dir.joinpath(f'{img_name.name}'))
def infer_on_dirs(self):
attr_paths = list(self.args.attr_dir.iterdir())
attr_paths.sort()
id_paths = list(self.args.id_dir.iterdir())
id_paths.sort()
for attr_num, attr_img_path in tqdm(enumerate(attr_paths)):
if not attr_img_path.is_file() or attr_img_path.suffix[1:] not in self.args.img_suffixes:
continue
attr_img = utils.read_image(attr_img_path, self.args.resolution, self.args.reals)
attr_dir = self.args.output_dir.joinpath(f'attr_{attr_num}')
attr_dir.mkdir(exist_ok=True)
utils.save_image(attr_img, attr_dir.joinpath(f'attr_image.png'))
for id_num, id_img_path in enumerate(id_paths):
if not id_img_path.is_file() or id_img_path.suffix[1:] not in self.args.img_suffixes:
continue
id_img = utils.read_image(id_img_path, self.args.resolution, self.args.reals)
pred = self.G(id_img, attr_img)[0]
utils.save_image(pred, attr_dir.joinpath(f'prediction_{id_num}.png'))
utils.save_image(id_img, attr_dir.joinpath(f'id_{id_num}.png'))
def interpolate(self, w_space=True):
        # Set extra_start, extra_end to 0, 1 for plain interpolation
        extra_start = 0
        extra_end = 1
        L = extra_end - extra_start
        # The sampled values include both 0 and 1 iff:
        #   N - 1 is divisible by L, when the endpoint is included;
        #   N is divisible by L, otherwise.
        # Here L is the length of the extrapolation range (L = b - a for [a, b])
        # and N is the number of jumps.
        num_jumps = 8 * L + 1
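        # Example: with extra_start=0, extra_end=1 we get L=1 and num_jumps=9;
        # the loops below sample t = i/9 for i = 0..8, so they start exactly at
        # the first code and stop one step short of the second.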
for d in self.args.input_dir.iterdir():
out_d = self.args.output_dir.joinpath(d.name)
out_d.mkdir(exist_ok=True)
ids = list(d.glob('*id*'))
attrs = list(d.glob('*attr*'))
if len(ids) == 1 and len(attrs) == 2:
const = 'id'
elif len(ids) == 2 and len(attrs) == 1:
const = 'attr'
else:
print(f'Wrong data format for {d.name}')
continue
if const == 'id':
start_img = utils.read_image(attrs[0], self.args.resolution, self.args.real_attr)
end_img = utils.read_image(attrs[1], self.args.resolution, self.args.real_attr)
const_img = utils.read_image(ids[0], self.args.resolution, self.args.real_id)
                if self.args.loop_fake:
                    if not self.args.real_attr:
                        # G returns a tuple; take the output image (element 0)
                        start_img = self.G(start_img, start_img)[0]
                        end_img = self.G(end_img, end_img)[0]
                    if not self.args.real_id:
                        const_img = self.G(const_img, const_img)[0]
const_id = self.G.id_encoder(const_img)
start_attr = self.G.attr_encoder(start_img)
end_attr = self.G.attr_encoder(end_img)
s_z = tf.concat([const_id, start_attr], -1)
e_z = tf.concat([const_id, end_attr], -1)
elif const == 'attr':
start_img = utils.read_image(ids[0], self.args.resolution, self.args.real_id)
end_img = utils.read_image(ids[1], self.args.resolution, self.args.real_id)
const_img = utils.read_image(attrs[0], self.args.resolution, self.args.real_attr)
if self.args.loop_fake:
if not self.args.real_attr:
const_img = self.G(const_img, const_img)[0]
if not self.args.real_id:
start_img = self.G(start_img, start_img)[0]
end_img = self.G(end_img, end_img)[0]
start_id = self.G.id_encoder(start_img)
end_id = self.G.id_encoder(end_img)
const_attr = self.G.attr_encoder(const_img)
s_z = tf.concat([start_id, const_attr], -1)
e_z = tf.concat([end_id, const_attr], -1)
utils.save_image(const_img, out_d.joinpath(f'const_{const}.png'))
utils.save_image(start_img, out_d.joinpath(f'start.png'))
utils.save_image(end_img, out_d.joinpath(f'end.png'))
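            # Interpolate either directly in W (map once, then lerp) or in Z
            # (lerp first, re-mapping each intermediate code).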
if w_space:
s_w = self.G.latent_spaces_mapping(s_z)
e_w = self.G.latent_spaces_mapping(e_z)
for i in range(num_jumps):
inter_w = (1 - i / num_jumps) * s_w + (i / num_jumps) * e_w
out = self.G.stylegan_s(inter_w)
out = (out + 1) / 2
utils.save_image(out[0],
out_d.joinpath(f'inter_{i:03}.png'))
else:
for i in range(num_jumps):
inter_z = (1 - i / num_jumps) * s_z + (i / num_jumps) * e_z
inter_w = self.G.latent_spaces_mapping(inter_z)
out = self.G.stylegan_s(inter_w)
out = (out + 1) / 2
utils.save_image(out[0],
out_d.joinpath(f'inter_{i:03}.png'))
================================================
FILE: main.py
================================================
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['USE_SIMPLE_THREADED_LEVEL3'] = '1'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
import sys
import logging
from model.stylegan import StyleGAN_G_synthesis
from model.model import Network
from data_loader.data_loader import DataLoader
from writer import Writer
from trainer import Trainer
from arglib import arglib
from utils import general_utils as utils
sys.path.insert(0, 'model/face_utils')
def init_logger(args):
root_logger = logging.getLogger()
level = logging.DEBUG if args.log_debug else logging.INFO
root_logger.setLevel(level)
file_handler = logging.FileHandler(f'{args.results_dir}/log.txt')
console_handler = logging.StreamHandler()
datefmt = '%Y-%m-%d %H:%M:%S'
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt)
file_handler.setLevel(level)
console_handler.setLevel(level)
file_handler.setFormatter(formatter)
console_handler.setFormatter(formatter)
root_logger.addHandler(file_handler)
root_logger.addHandler(console_handler)
pil_logger = logging.getLogger('PIL.PngImagePlugin')
pil_logger.setLevel(logging.INFO)
def main():
train_args = arglib.TrainArgs()
args, str_args = train_args.args, train_args.str_args
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
init_logger(args)
logger = logging.getLogger('main')
cmd_line = ' '.join(sys.argv)
logger.info(f'cmd line is: \n {cmd_line}')
logger.info(str_args)
logger.debug('Copying src to results dir')
Writer.set_writer(args.results_dir)
if not args.debug:
description = input('Please write a short description of this run\n')
desc_file = args.results_dir.joinpath('description.txt')
with desc_file.open('w') as f:
f.write(description)
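    # Expected layout under args.pretrained_models_path (as wired below):
    #   vggface2.h5, stylegan_G_<res>x<res>_synthesis,
    #   face_utils/keypoints, face_utils/detector,
    #   arcface_weights/weights-b, shape_predictor_68_face_landmarks.dat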
id_model_path = args.pretrained_models_path.joinpath('vggface2.h5')
stylegan_G_synthesis_path = str(
args.pretrained_models_path.joinpath(f'stylegan_G_{args.resolution}x{args.resolution}_synthesis'))
landmarks_model_path = str(args.pretrained_models_path.joinpath('face_utils/keypoints'))
face_detection_model_path = str(args.pretrained_models_path.joinpath('face_utils/detector'))
arcface_model_path = str(args.pretrained_models_path.joinpath('arcface_weights/weights-b'))
utils.landmarks_model_path = str(args.pretrained_models_path.joinpath('shape_predictor_68_face_landmarks.dat'))
stylegan_G_synthesis = StyleGAN_G_synthesis(resolution=args.resolution, is_const_noise=args.const_noise)
stylegan_G_synthesis.load_weights(stylegan_G_synthesis_path)
network = Network(args, id_model_path, stylegan_G_synthesis, landmarks_model_path,
face_detection_model_path, arcface_model_path)
data_loader = DataLoader(args)
trainer = Trainer(args, network, data_loader)
trainer.train()
if __name__ == '__main__':
main()
================================================
FILE: model/__init__.py
================================================
================================================
FILE: model/arcface/arcface.py
================================================
import tensorflow as tf
import math
num_classes = 85742  # MS1M has 85,742 identities; CASIA-WebFace has 10,572
initializer = 'glorot_normal'
# initializer = tf.keras.initializers.TruncatedNormal(
# mean=0.0, stddev=0.05, seed=None)
# initializer = tf.keras.initializers.VarianceScaling(
# scale=0.05, mode='fan_avg', distribution='normal', seed=None)
class Arcfacelayer(tf.keras.layers.Layer):
    def __init__(self, output_dim=num_classes, s=64., m=0.50):
        super(Arcfacelayer, self).__init__()
        self.output_dim = output_dim
        self.s = s
        self.m = m
def build(self, input_shape):
self.kernel = self.add_weight(name='kernel',
shape=(input_shape[-1],
self.output_dim),
initializer=initializer,
regularizer=tf.keras.regularizers.l2(
l=5e-4),
trainable=True)
super(Arcfacelayer, self).build(input_shape)
def call(self, embedding, labels):
cos_m = math.cos(self.m)
sin_m = math.sin(self.m)
mm = sin_m * self.m # issue 1
threshold = math.cos(math.pi - self.m)
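        # ArcFace margin: the target-class logit s*cos(theta) is replaced by
        # s*cos(theta + m) = s*(cos(theta)*cos_m - sin(theta)*sin_m). When
        # theta + m would exceed pi (i.e., cos(theta) <= cos(pi - m)), the
        # monotonic fallback s*(cos(theta) - m*sin(m)) is used instead.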
# inputs and weights norm
embedding_norm = tf.norm(embedding, axis=1, keepdims=True)
embedding = embedding / embedding_norm
weights_norm = tf.norm(self.kernel, axis=0, keepdims=True)
weights = self.kernel / weights_norm
# cos(theta+m)
cos_t = tf.matmul(embedding, weights, name='cos_t')
cos_t2 = tf.square(cos_t, name='cos_2')
sin_t2 = tf.subtract(1., cos_t2, name='sin_2')
sin_t = tf.sqrt(sin_t2, name='sin_t')
cos_mt = self.s * tf.subtract(tf.multiply(cos_t, cos_m),
tf.multiply(sin_t, sin_m), name='cos_mt')
# this condition controls the theta+m should in range [0, pi]
# 0<=theta+m<=pi
# -m<=theta<=pi-m
cond_v = cos_t - threshold
cond = tf.cast(tf.nn.relu(cond_v, name='if_else'), dtype=tf.bool)
keep_val = self.s * (cos_t - mm)
cos_mt_temp = tf.where(cond, cos_mt, keep_val)
mask = tf.one_hot(labels, depth=self.output_dim, name='one_hot_mask')
# mask = tf.squeeze(mask, 1)
inv_mask = tf.subtract(1., mask, name='inverse_mask')
s_cos_t = tf.multiply(self.s, cos_t, name='scalar_cos_t')
output = tf.add(tf.multiply(s_cos_t, inv_mask), tf.multiply(
cos_mt_temp, mask), name='arcface_loss_output')
return output
def compute_output_shape(self, input_shape):
return (input_shape[0], self.output_dim)
================================================
FILE: model/arcface/inference.py
================================================
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
import cv2
from model.arcface.resnet import ResNet50, train_model
from mtcnn import MTCNN
from skimage import transform as trans
class MyArcFace:
def __init__(self, path_to_weights):
self.model = train_model()
self.model.load_weights(path_to_weights)
self.model_resnet = self.model.resnet
self.model.resnet.trainable = False
self.mtcnn = MTCNN(min_face_size=80)
def get_best_face(self, faces, resolution):
if len(faces) == 0:
raise IndexError('No faces found')
if len(faces) == 1:
return faces[0]
print('Found more than one face')
indices = list(range(len(faces)))
# filter low confidence
new_indices = [ind for ind in indices if faces[ind]['confidence'] > 0.99]
# print(f'after confidence filtering: {len(new_indices)}')
if len(new_indices) == 1:
return faces[new_indices[0]]
elif len(new_indices) > 1:
indices = new_indices
        # filter faces that are not centered: the distance between the box's
        # x and y coordinates must be relatively small
new_indices = [ind for ind in indices if np.abs(faces[ind]['box'][0] - faces[ind]['box'][1]) < resolution / 2.5]
# print(f'after center filtering: {len(new_indices)}')
if len(new_indices) == 1:
return faces[new_indices[0]]
elif len(new_indices) > 1:
indices = new_indices
# Take box with biggest height
ind = max(indices, key=lambda ind: faces[ind]['box'][-1])
return faces[ind]
def __detect_face(self, img):
# The assumption is that the image is RGB
faces = self.mtcnn.detect_faces(img)
face_obj = self.get_best_face(faces, img.shape[0])
face_box_obj = face_obj['box']
face_landmarks_obj = face_obj['keypoints']
face_landmarks = np.zeros((5, 2))
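        # Note: each eye/mouth point below takes its x from one side and its y
        # from the opposite side; for near-upright faces the two y values
        # nearly coincide, approximating the usual 5-point layout.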
face_landmarks[0] = [face_landmarks_obj['left_eye'][0], face_landmarks_obj['right_eye'][1]]
face_landmarks[1] = [face_landmarks_obj['right_eye'][0], face_landmarks_obj['left_eye'][1]]
face_landmarks[2] = [face_landmarks_obj['nose'][0], face_landmarks_obj['nose'][1]]
face_landmarks[3] = [face_landmarks_obj['mouth_left'][0], face_landmarks_obj['mouth_right'][1]]
face_landmarks[4] = [face_landmarks_obj['mouth_right'][0], face_landmarks_obj['mouth_left'][1]]
x = face_box_obj[0]
y = face_box_obj[1]
w = face_box_obj[2]
h = face_box_obj[3]
face_box = [x, y, x + w, y + h]
return face_box, face_landmarks
def __preprocess(self, img, bbox=None, landmark=None):
M = None
image_size = [112, 112]
assert landmark is not None
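        # Canonical ArcFace 5-point template (eyes, nose, mouth corners) for
        # 112x112 alignment; the detected landmarks are warped onto it.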
src = np.array([
[30.2946, 51.6963],
[65.5318, 51.5014],
[48.0252, 71.7366],
[33.5493, 92.3655],
[62.7299, 92.2041]], dtype=np.float32)
if image_size[1] == 112:
src[:, 0] += 8.0
dst = landmark.astype(np.float32)
tform = trans.SimilarityTransform()
tform.estimate(src, dst)
M = tform.params
assert M is not None
transforms = np.array(M).flatten()[:-1]
tf_transforms = tf.constant([transforms], tf.float32)
img_tensor = tf.convert_to_tensor(img.astype(np.float32))
batch = tf.stack([img_tensor])
output = tfa.image.transform(batch, tf_transforms, interpolation='BILINEAR', output_shape=image_size)
return output
def process_image(self, img):
if (isinstance(img, tf.Tensor) and img.dtype != tf.dtypes.uint8) or img.dtype != np.uint8:
img = np.uint8(img * 255)
face_box, face_landmarks = self.__detect_face(img)
aligned_face = self.__preprocess(img, face_box, face_landmarks)
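        # Normalize pixel values to roughly [-1, 1]: (x - 127.5) / 128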
aligned_face -= 127.5
aligned_face *= 0.0078125
embeddings = self.model_resnet(aligned_face)
        normalized_embeddings = tf.math.l2_normalize(embeddings)
        return normalized_embeddings
def __call__(self, img):
if img.ndim == 4:
embedding_list = []
for x in img:
norm_embedding = self.process_image(x)
embedding_list.append(norm_embedding)
return np.array(embedding_list)
else:
return self.process_image(img)
================================================
FILE: model/arcface/resnet.py
================================================
import tensorflow as tf
import os
from model.arcface.arcface import Arcfacelayer
bn_axis = -1
initializer = 'glorot_normal'
def residual_unit_v3(input, num_filter, stride, dim_match, name):
x = tf.keras.layers.BatchNormalization(axis=bn_axis,
scale=True,
momentum=0.9,
epsilon=2e-5,
# beta_regularizer=tf.keras.regularizers.l2(
# l=5e-4),
gamma_regularizer=tf.keras.regularizers.l2(
l=5e-4),
name=name + '_bn1')(input)
x = tf.keras.layers.ZeroPadding2D(
padding=(1, 1), name=name + '_conv1_pad')(x)
x = tf.keras.layers.Conv2D(num_filter, (3, 3),
strides=(1, 1),
padding='valid',
kernel_initializer=initializer,
use_bias=False,
kernel_regularizer=tf.keras.regularizers.l2(
l=5e-4),
name=name + '_conv1')(x)
x = tf.keras.layers.BatchNormalization(axis=bn_axis,
scale=True,
momentum=0.9,
epsilon=2e-5,
# beta_regularizer=tf.keras.regularizers.l2(
# l=5e-4),
gamma_regularizer=tf.keras.regularizers.l2(
l=5e-4),
name=name + '_bn2')(x)
x = tf.keras.layers.PReLU(name=name + '_relu1',
alpha_regularizer=tf.keras.regularizers.l2(
l=5e-4))(x)
x = tf.keras.layers.ZeroPadding2D(
padding=(1, 1), name=name + '_conv2_pad')(x)
x = tf.keras.layers.Conv2D(num_filter, (3, 3),
strides=stride,
padding='valid',
kernel_initializer=initializer,
use_bias=False,
kernel_regularizer=tf.keras.regularizers.l2(
l=5e-4),
name=name + '_conv2')(x)
x = tf.keras.layers.BatchNormalization(axis=bn_axis,
scale=True,
momentum=0.9,
epsilon=2e-5,
# beta_regularizer=tf.keras.regularizers.l2(
# l=5e-4),
gamma_regularizer=tf.keras.regularizers.l2(
l=5e-4),
name=name + '_bn3')(x)
if (dim_match):
shortcut = input
else:
shortcut = tf.keras.layers.Conv2D(num_filter, (1, 1),
strides=stride,
padding='valid',
kernel_initializer=initializer,
use_bias=False,
kernel_regularizer=tf.keras.regularizers.l2(
l=5e-4),
name=name + '_conv1sc')(input)
shortcut = tf.keras.layers.BatchNormalization(axis=bn_axis,
scale=True,
momentum=0.9,
epsilon=2e-5,
# beta_regularizer=tf.keras.regularizers.l2(
# l=5e-4),
gamma_regularizer=tf.keras.regularizers.l2(
l=5e-4),
name=name + '_sc')(shortcut)
return x + shortcut
def get_fc1(input):
x = tf.keras.layers.BatchNormalization(axis=bn_axis,
scale=True,
momentum=0.9,
epsilon=2e-5,
# beta_regularizer=tf.keras.regularizers.l2(
# l=5e-4),
gamma_regularizer=tf.keras.regularizers.l2(
l=5e-4),
name='bn1')(input)
x = tf.keras.layers.Dropout(0.4)(x)
resnet_shape = input.shape
x = tf.keras.layers.Reshape(
[resnet_shape[1] * resnet_shape[2] * resnet_shape[3]], name='reshapelayer')(x)
x = tf.keras.layers.Dense(512,
name='E_DenseLayer', kernel_initializer=initializer,
kernel_regularizer=tf.keras.regularizers.l2(
l=5e-4),
bias_regularizer=tf.keras.regularizers.l2(
l=5e-4))(x)
x = tf.keras.layers.BatchNormalization(axis=-1,
scale=False,
momentum=0.9,
epsilon=2e-5,
# beta_regularizer=tf.keras.regularizers.l2(
# l=5e-4),
name='fc1')(x)
return x
def ResNet50():
input_shape = [112, 112, 3]
filter_list = [64, 64, 128, 256, 512]
units = [3, 4, 14, 3]
num_stages = 4
img_input = tf.keras.layers.Input(shape=input_shape)
x = tf.keras.layers.ZeroPadding2D(
padding=(1, 1), name='conv0_pad')(img_input)
x = tf.keras.layers.Conv2D(64, (3, 3),
strides=(1, 1),
padding='valid',
kernel_initializer=initializer,
use_bias=False,
kernel_regularizer=tf.keras.regularizers.l2(
l=5e-4),
name='conv0')(x)
x = tf.keras.layers.BatchNormalization(axis=bn_axis,
scale=True,
momentum=0.9,
epsilon=2e-5,
# beta_regularizer=tf.keras.regularizers.l2(
# l=5e-4),
gamma_regularizer=tf.keras.regularizers.l2(
l=5e-4),
name='bn0')(x)
# x = tf.keras.layers.Activation('prelu')(x)
x = tf.keras.layers.PReLU(
name='prelu0',
alpha_regularizer=tf.keras.regularizers.l2(
l=5e-4))(x)
for i in range(num_stages):
x = residual_unit_v3(x, filter_list[i + 1], (2, 2), False,
name='stage%d_unit%d' % (i + 1, 1))
for j in range(units[i] - 1):
x = residual_unit_v3(x, filter_list[i + 1], (1, 1),
True, name='stage%d_unit%d' % (i + 1, j + 2))
x = get_fc1(x)
# Create model.
model = tf.keras.models.Model(img_input, x, name='resnet50')
model.trainable = True
for i in range(len(model.layers)):
model.layers[i].trainable = True
# if ('conv0' in model.layers[i].name):
# model.layers[i].trainable = False
# if ('bn0' in model.layers[i].name):
# model.layers[i].trainable = False
# if ('prelu0' in model.layers[i].name):
# model.layers[i].trainable = False
# if ('stage1' in model.layers[i].name):
# model.layers[i].trainable = False
# if ('stage2' in model.layers[i].name):
# model.layers[i].trainable = False
# if ('stage3' in model.layers[i].name):
# model.layers[i].trainable = False
# if ('stage4' in model.layers[i].name):
# model.layers[i].trainable = False
return model
class train_model(tf.keras.Model):
def __init__(self):
super(train_model, self).__init__()
self.resnet = ResNet50()
self.arcface = Arcfacelayer()
def call(self, x, y):
x = self.resnet(x)
return self.arcface(x, y)
================================================
FILE: model/attr_encoder.py
================================================
import logging
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input
class AttrEncoder(Model):
def __init__(self, args):
super().__init__()
self.args = args
self.logger = logging.getLogger(__class__.__name__)
attr_encoder = InceptionV3(include_top=False, pooling='avg')
self.model = attr_encoder
if self.args.load_checkpoint:
self.model.load_weights(str(self.args.load_checkpoint.joinpath(self.__class__.__name__ + '.h5')))
@tf.function
def call(self, input_x):
x = tf.image.resize(input_x, (299, 299))
x = preprocess_input(255 * x)
x = self.model(x)
x = tf.expand_dims(x, 1)
return x
def my_save(self, reason=''):
self.model.save_weights(str(self.args.weights_dir.joinpath(self.__class__.__name__ + reason + '.h5')))
================================================
FILE: model/discriminator.py
================================================
import tensorflow as tf
from tensorflow.keras import layers, Model
from utils.general_utils import get_weights
# Discriminate between my w's and StyleGAN's w's
class W_D(Model):
def __init__(self, args):
super().__init__()
self.args = args
slope = 0.2
# self.linear1 = layers.Dense(512, kernel_initializer=get_weights(slope), input_shape=(512,))
self.linear2 = layers.Dense(256, kernel_initializer=get_weights(slope), input_shape=(512,))
self.linear3 = layers.Dense(128, kernel_initializer=get_weights(slope))
self.linear4 = layers.Dense(64, kernel_initializer=get_weights(slope))
self.linear5 = layers.Dense(1, kernel_initializer=get_weights(slope))
self.relu = layers.LeakyReLU(slope)
if self.args.load_checkpoint:
self.build(input_shape=(1, 1, 512))
self.load_weights(str(self.args.load_checkpoint.joinpath(self.__class__.__name__ + '.h5')))
@tf.function
def call(self, x):
# x = self.linear1(x)
# x = self.relu(x)
x = self.linear2(x)
x = self.relu(x)
x = self.linear3(x)
x = self.relu(x)
x = self.linear4(x)
x = self.relu(x)
x = self.linear5(x)
return x
def my_save(self, reason=''):
self.save_weights(str(self.args.weights_dir.joinpath(self.__class__.__name__ + reason + '.h5')))
================================================
FILE: model/face_detector.py
================================================
import tensorflow as tf
class FaceDetector(object):
def __init__(self, args, model_path):
super().__init__()
self.args = args
self.model_path = model_path
self.model = None
def _build(self):
if not self.model:
self.model = tf.saved_model.load(self.model_path)
def __call__(self, input_x):
"""
Given a batch of images, return the face bounding box in (x1,y1,x2,y2) format
"""
if not self.model:
self._build()
boxes = []
for sample in input_x:
boxes.append(self.sample_call(sample))
boxes = tf.stack(boxes, axis=0)
boxes = boxes * self.args.resolution
return boxes
def sample_call(self, input_x):
boxes = self.model.inference(tf.expand_dims(input_x, axis=0))
boxes = tf.squeeze(boxes)
indices, scores = \
tf.image.non_max_suppression_with_scores(boxes[..., :4], boxes[..., 4],
max_output_size=1, iou_threshold=0.3, score_threshold=0.5)
i = indices.numpy()[0]
box = boxes[i, :4]
return box
================================================
FILE: model/generator.py
================================================
import logging
from model import id_encoder
from model import attr_encoder
from model import latent_mapping
from model import landmarks
from model.arcface.inference import MyArcFace
import tensorflow as tf
from tensorflow.keras import layers, Model
class G(Model):
def __init__(self, args, id_model_path, image_G,
landmarks_net_path, face_detection_model_path, test_id_model_path):
super().__init__()
self.args = args
self.logger = logging.getLogger(__class__.__name__)
self.id_encoder = id_encoder.IDEncoder(args, id_model_path)
self.id_encoder.trainable = False
self.attr_encoder = attr_encoder.AttrEncoder(args)
self.latent_spaces_mapping = latent_mapping.LatentMappingNetwork(args)
self.stylegan_s = image_G
self.stylegan_s.trainable = False
if args.train:
self.test_id_encoder = MyArcFace(test_id_model_path)
self.test_id_encoder.trainable = False
self.landmarks = landmarks.LandmarksDetector(args, landmarks_net_path, face_detection_model_path)
self.landmarks.trainable = False
@tf.function
def call(self, x1, x2):
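        # x1 supplies identity, x2 supplies attributes. Their embeddings are
        # concatenated into z' and mapped into StyleGAN's latent space, from
        # which the frozen synthesis network renders the output.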
id_embedding = self.id_encoder(x1)
if self.args.train:
lnds = self.landmarks(x2)
else:
lnds = None
attr_input = x2
attr_out = self.attr_encoder(attr_input)
attr_embedding = attr_out
z_tag = tf.concat([id_embedding, attr_embedding], -1)
w = self.latent_spaces_mapping(z_tag)
out = self.stylegan_s(w)
# Move to roughly [0,1]
out = (out + 1) / 2
return out, id_embedding, attr_out, w[:, 0, :], lnds
def my_save(self, reason=''):
self.attr_encoder.my_save(reason)
self.latent_spaces_mapping.my_save(reason)
================================================
FILE: model/id_encoder.py
================================================
import tensorflow as tf
import numpy as np
from tensorflow.keras import Model
class IDEncoder(Model):
def __init__(self, args, model_path, intermediate_layers_names=None):
super().__init__()
self.args = args
self.mean = (91.4953, 103.8827, 131.0912)
base_model = tf.keras.models.load_model(model_path)
if intermediate_layers_names:
outputs = [base_model.get_layer(name).output for name in intermediate_layers_names]
else:
outputs = []
# Add output of the network in any case
outputs.append(base_model.layers[-2].output)
self.model = tf.keras.Model(base_model.inputs, outputs)
def crop_faces(self, img):
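        # Note: this relies on a `self.mtcnn` detector that is not created in
        # __init__; when detection fails (or mtcnn is absent), the except
        # branch falls back to a fixed crop.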
ps = []
for i in range(img.shape[0]):
oneimg = img[i]
try:
box = tf.numpy_function(self.mtcnn.detect_faces, [oneimg], np.uint8)
box = [z.numpy() for z in box[:4]]
x1, y1, w, h = box
x_expand = w * 0.3
y_expand = h * 0.3
x1 = int(np.maximum(x1 - x_expand // 2, 0))
y1 = int(np.maximum(y1 - y_expand // 2, 0))
x2 = int(np.minimum(x1 + w + x_expand // 2, self.args.resolution))
y2 = int(np.minimum(y1 + h + y_expand // 2, self.args.resolution))
except Exception as e:
x1, y1, x2, y2 = 24, 50, 224, 250
p = oneimg[y1:y2, x1:x2, :]
p = tf.convert_to_tensor(p)
p = tf.image.resize(p, (self.args.resolution, self.args.resolution))
ps.append(p)
ps = tf.stack(ps, 0)
return ps
def preprocess(self, img):
"""
In VGGFace2 The preprocessing is:
1. Face detection
2. Expand bbox by factor of 0.3
3. Resize so shorter side is 256
4. Crop center 224x224
        In StyleGAN, faces are not in-the-wild; we get an image of the whole head.
        So we just crop a loose center instead of running face detection.
"""
# Go from [0, 1] to [0, 255]
img = 255 * img
min_x = int(0.1 * self.args.resolution)
max_x = int(0.9 * self.args.resolution)
min_y = int(0.1 * self.args.resolution)
max_y = int(0.9 * self.args.resolution)
img = img[:, min_x:max_x, min_y:max_y, :]
img = tf.image.resize(img, (256, 256))
start = (256 - 224) // 2
img = img[:, start: 224 + start, start: 224 + start, :]
img = img[:, :, :, ::-1] - self.mean
return img
@tf.function
def call(self, input_x, get_intermediate=False):
x = self.preprocess(input_x)
x = self.model(x)
if isinstance(x, list):
embedding = x[-1]
intermediates = x[:-1]
else:
embedding = x
intermediates = None
embedding = tf.math.l2_normalize(embedding, axis=-1)
embedding = tf.expand_dims(embedding, 1)
if get_intermediate and intermediates:
return embedding, intermediates
else:
return embedding
================================================
FILE: model/landmarks.py
================================================
import cv2
import tensorflow as tf
import numpy as np
from tensorflow.keras import Model
from utils import general_utils as utils
from model.face_detector import FaceDetector
class LandmarksDetector(Model):
def __init__(self, args, model_path, face_detection_model_path):
super().__init__()
self.args = args
self.face_detector = FaceDetector(args, face_detection_model_path)
self.expand_ratio = 0.2
# Load without source code
self.model = tf.saved_model.load(model_path)
# Preprocess
def preprocess(self, imgs, face_detection=False):
imgs *= 255
if face_detection:
imgs, details = self.hard_preprocess(imgs)
else:
imgs, details = self.lazy_preprocess(imgs)
return imgs, details
def lazy_preprocess(self, imgs):
imgs = tf.image.resize(imgs, (160, 160))
return imgs, 160
def hard_preprocess(self, imgs):
bboxes = self.face_detector(imgs)
centers = np.array([bboxes[:, 0] + bboxes[:, 2], bboxes[:, 1] + bboxes[:, 3]]).T // 2
# Duplicate center point into column order of x,x,y,y
centers = np.repeat(centers, repeats=2, axis=1)
# Permute columns order into x,y,x,y
centers[:] = utils.np_permute(centers, [0, 2, 1, 3])
# Calculate widths of current bboxes
widths = np.transpose([bboxes[:, 2] - bboxes[:, 0]])
# Calculate the maximal expansion
max_expand = int(np.ceil(np.max(widths) * self.expand_ratio))
# Pad the image with the maximal expansion.
# Useful in case an expanded bounding box goes outside image
paddings = tf.constant([[0, 0], [max_expand, max_expand], [max_expand, max_expand], [0, 0]])
pad_imgs = tf.pad(imgs, paddings, mode='CONSTANT', constant_values=127.)
# The size of the new square bounding box
new_scales = np.floor((1 + 2 * self.expand_ratio) * widths)
# Size of step from the center
new_half_scales = new_scales // 2
# Repeat step in all directions
# Decrease in start point, Increase in end point
new_half_scales = np.repeat(new_half_scales, repeats=4, axis=1) * [-1, -1, 1, 1]
# Bounding boxes in respect to padded image
new_bboxes = centers + new_half_scales + max_expand
# tf.image.crop_and_resize requires bounding boxes to be normalized
# i.e., between [0,1] and also in order (y,x)
normed_bboxes = utils.np_permute(new_bboxes, [1, 0, 3, 2]) / pad_imgs.shape[1]
cropped_imgs = tf.image.crop_and_resize(pad_imgs, normed_bboxes,
box_indices=range(self.args.batch_size), crop_size=(160, 160))
details = (new_scales, new_bboxes[:,:2], max_expand)
return cropped_imgs, details
# Postprocess
def postprocess(self, landmarks, details, face_detection=False):
landmarks = tf.reshape(landmarks, [-1, 68, 2])
if face_detection:
return self.hard_postprocess(landmarks, details)
else:
return self.lazy_postprocess(landmarks, details)
def lazy_postprocess(self, batch_lnds, details):
scale = details
return scale * batch_lnds
def hard_postprocess(self, batch_lnds, details):
scale, from_origin, pad = details
scale = tf.broadcast_to(scale, [scale.shape[0], 2])
scale = tf.expand_dims(scale, axis=1)
from_origin = tf.expand_dims(from_origin, axis=1)
from_origin = tf.cast(from_origin, tf.dtypes.float32)
lnds = batch_lnds * scale + from_origin - pad
return lnds
@tf.function
def call(self, input_x, face_detection=False):
# The network input format is a uint8 image (0-255) but in float32 dtype. ^__('')__^
x, details = self.preprocess(input_x, face_detection)
batch_lnds = self.model.inference(x)['landmark']
batch_lnds = self.postprocess(batch_lnds, details, face_detection)
return batch_lnds[:, 17:, :]
================================================
FILE: model/latent_mapping.py
================================================
from utils.general_utils import get_weights
import tensorflow as tf
import numpy as np
from tensorflow.keras import layers, Model
class LatentMappingNetwork(Model):
def __init__(self, args):
super().__init__()
self.args = args
input_shape = (2560,)
self.linear1 = layers.Dense(2048, input_shape=input_shape)
self.linear2 = layers.Dense(1024)
self.linear3 = layers.Dense(512, kernel_initializer=get_weights())
self.linear4 = layers.Dense(512, kernel_initializer=get_weights())
self.linears = [self.linear1, self.linear2, self.linear3, self.linear4]
self.relu = layers.LeakyReLU(0.2)
self.num_styles = int(np.log2(self.args.resolution)) * 2 - 2
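        # One w vector per AdaIN input; e.g., 14 styles for a 256x256
        # generator (log2(256) * 2 - 2).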
if self.args.load_checkpoint:
self.build(input_shape=(1, 1, 2560))
self.load_weights(str(self.args.load_checkpoint.joinpath(self.__class__.__name__ + '.h5')))
@tf.function
def call(self, x):
first = True
for layer in self.linears:
if not first:
x = self.relu(x)
x = layer(x)
first = False
s = list(x.shape)
# Duplicate the column vector w along columns for each AdaIN entry
s[1] = self.num_styles
x = tf.broadcast_to(x, s)
return x
def my_save(self, reason=''):
self.save_weights(str(self.args.weights_dir.joinpath(self.__class__.__name__ + reason + '.h5')))
================================================
FILE: model/model.py
================================================
import time
import sys
sys.path.append('..')
from utils import general_utils as utils
from model import id_encoder, latent_mapping, attr_encoder,\
generator, discriminator, landmarks
from model.stylegan import StyleGAN_G, StyleGAN_D
import tensorflow as tf
from tensorflow.keras import layers, Model
class Network(Model):
def __init__(self, args, id_net_path, base_generator,
landmarks_net_path=None, face_detection_model_path=None, test_id_net_path=None):
super().__init__()
self.args = args
self.G = generator.G(args, id_net_path, base_generator,
landmarks_net_path, face_detection_model_path, test_id_net_path)
if self.args.train:
self.W_D = discriminator.W_D(args)
def call(self):
        raise NotImplementedError()
def my_save(self, reason):
self.G.my_save(reason)
if self.args.W_D_loss:
self.W_D.my_save(reason)
def my_load(self):
        raise NotImplementedError()
def train(self):
self._set_trainable_behavior(True)
def test(self):
self._set_trainable_behavior(False)
def _set_trainable_behavior(self, trainable):
self.G.attr_encoder.trainable = trainable
self.G.latent_spaces_mapping.trainable = trainable
================================================
FILE: model/stylegan.py
================================================
import sys
import math
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Layer, InputLayer, Multiply, Lambda, Flatten, Dense, Conv2D, Conv2DTranspose
from tensorflow.keras.initializers import VarianceScaling
from tensorflow.python.keras import backend
from tensorflow.python.ops import array_ops
def nf(stage, fmap_base=8192, fmap_decay=1.0, fmap_max=512):
return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)
def LeakyReLU(alpha, name):
def lrelu(x, alpha):
alpha = tf.constant(alpha, dtype=x.dtype, name='alpha')
return tf.maximum(x, x * alpha)
return Lambda(lambda x: lrelu(x, alpha), name=name)
def GetWeights(gain=math.sqrt(2)):
return VarianceScaling(gain)
def runtime_coef(kernel_size, gain, fmaps_in, fmaps_out, lrmul=1.0):
# Equalized learning rate and custom learning rate multiplier.
shape = [kernel_size[0], kernel_size[1], fmaps_in, fmaps_out]
fan_in = np.prod(shape[:-1]) # [kernel, kernel, fmaps_in, fmaps_out] or [in, out]
he_std = gain / np.sqrt(fan_in) # He init
init_std = 1.0 / lrmul
return he_std * lrmul
def pixel_norm(x, epsilon=1e-8):
epsilon = tf.constant(epsilon, dtype=x.dtype, name='epsilon')
return x * tf.math.rsqrt(tf.reduce_mean(tf.square(x), axis=1, keepdims=True) + epsilon)
class PixelNorm(Layer):
def __init__(self, name):
super(PixelNorm, self).__init__(name=name)
def call(self, inputs):
return pixel_norm(inputs)
class InstanceNorm(Layer):
def __init__(self, name):
super(InstanceNorm, self).__init__(name=name)
def call(self, x):
epsilon=1e-8
orig_dtype = x.dtype
x = tf.cast(x, tf.float32)
x -= tf.reduce_mean(x, axis=[2,3], keepdims=True)
epsilon = tf.constant(epsilon, dtype=x.dtype, name='epsilon')
x *= tf.math.rsqrt(tf.reduce_mean(tf.square(x), axis=[2,3], keepdims=True) + epsilon)
x = tf.cast(x, orig_dtype)
return x
def Identity(name):
return Lambda(lambda x: x, name=name)
def Broadcast(name, dlatent_broadcast=18):
def broadcast(x):
return tf.tile(x[:, np.newaxis], [1, dlatent_broadcast, 1])
return Lambda(lambda x: broadcast(x), name=name)
class Truncation(Layer):
def __init__(self, name, num_layers=18, truncation_psi=0.7, truncation_cutoff=8):
super(Truncation, self).__init__(name=name)
self.num_layers = num_layers
self.truncation_psi = truncation_psi
self.truncation_cutoff = truncation_cutoff
def build(self, input_shape):
self.dlatent_avg = self.add_variable('dlatent_avg', shape=[int(input_shape[-1])])
def call(self, inputs):
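        # Truncation trick: pull dlatents toward the moving-average dlatent for
        # the first `truncation_cutoff` layers (coef = psi) and leave the
        # remaining layers unchanged (coef = 1).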
layer_idx = np.arange(self.num_layers)[np.newaxis, :, np.newaxis]
ones = np.ones(layer_idx.shape, dtype=np.float32)
coefs = tf.where(layer_idx < self.truncation_cutoff, self.truncation_psi * ones, ones)
def lerp(a,b,t): return a + (b - a) * t
return lerp(self.dlatent_avg, inputs, coefs)
class DenseLayer(Dense):
def __init__(self, units, name, kernel_initializer=GetWeights(), gain=math.sqrt(2), lrmul=1.0):
super(DenseLayer, self).__init__(units=units, kernel_initializer=kernel_initializer, name=name)
self.gain = gain
self.lrmul = lrmul
def call(self, inputs):
x, b, w = inputs, self.bias * self.lrmul, self.kernel * runtime_coef([1,1], self.gain, inputs.shape[1], self.units, lrmul=self.lrmul)
# Input x kernel
if len(x.shape) > 2: x = tf.reshape(x, [-1, np.prod([d for d in x.shape[1:]])])
x = tf.matmul(x, w)
# Bias
if len(x.shape) == 2:
return x + b
return x + tf.reshape(b, [1, -1, 1, 1])
class Conv2d(Conv2D):
def __init__(self, filters, kernel_size, name, gain=math.sqrt(2), lrmul=1.0, kernel_modifier=None, strides=1, use_bias=True):
super(Conv2d, self).__init__(filters=filters, kernel_size=kernel_size, kernel_initializer=GetWeights(gain),
use_bias=use_bias, padding='same', data_format='channels_first', name=name, strides=strides)
self.gain = gain
self.lrmul = lrmul
self.kernel_modifier = kernel_modifier
# Perform convolution with modified kernel then add bias
def call(self, inputs):
if self.kernel_modifier is None:
w = self.kernel
else:
w = self.kernel_modifier(self.kernel)
outputs = self._convolution_op(inputs, w * runtime_coef(self.kernel_size, self.gain, inputs.shape[1], self.filters))
if self.use_bias:
b = self.bias * self.lrmul
if self.data_format == 'channels_first':
outputs = tf.nn.bias_add(outputs, b, data_format='NCHW')
else:
outputs = tf.nn.bias_add(outputs, b, data_format='NHWC')
return outputs
class Const(Layer):
def __init__(self, name):
super(Const, self).__init__(name=name)
def build(self, input_shape):
self.const = self.add_variable('const', shape=[1,512,4,4])
def call(self, inputs):
return tf.tile(self.const, [tf.shape(inputs)[0], 1, 1, 1])
class RandomNoise(Layer):
def __init__(self, name, layer_idx):
super(RandomNoise, self).__init__(name=name)
res = layer_idx // 2 + 2
self.layer_idx = layer_idx
self.noise_shape = [1, 1, 2**res, 2**res]
def build(self, input_shape):
self.noise = self.add_variable('noise', shape=self.noise_shape, initializer=tf.initializers.zeros(), trainable=False)
def call(self, inputs):
return self.noise
class ApplyNoise(Layer):
def __init__(self, name, is_const_noise):
super(ApplyNoise, self).__init__(name=name)
self.is_const_noise = is_const_noise
def build(self, input_shape):
input_shape = input_shape[0]
self.weight = self.add_variable('weight', shape=[input_shape[1]], initializer=tf.initializers.zeros())
def call(self, inputs):
x, noise = inputs
if not self.is_const_noise:
noise = tf.random.normal([tf.shape(x)[0], 1, x.shape[2], x.shape[3]], dtype=x.dtype)
return x + noise * tf.reshape(self.weight, [1, -1, 1, 1])
class ApplyBias(Layer):
def __init__(self, name, lrmul=1.0):
super(ApplyBias, self).__init__(name=name)
self.lrmul = lrmul
def build(self, input_shape):
self.bias = self.add_variable('bias', shape=[input_shape[1]])
def call(self, x):
b = self.bias * self.lrmul
if len(x.shape) == 2: return x + b
return x + tf.reshape(b, [1, -1, 1, 1])
class StridedSlice(Layer):
def __init__(self, layer_idx, name):
super(StridedSlice, self).__init__(name=name)
self.layer_idx = layer_idx
def call(self, inputs):
return inputs[:, self.layer_idx]
class StyleModApply(Layer):
def __init__(self, name):
super(StyleModApply, self).__init__(name=name)
def call(self, inputs):
x, style = inputs
style = tf.reshape(style, [-1, 2, x.shape[1]] + [1] * (len(x.shape) - 2))
return x * (style[:,0] + 1) + style[:,1]
def _blur2d(x, f=[1,2,1], normalize=True, flip=False, stride=1):
assert x.shape.ndims == 4 and all(dim is not None for dim in x.shape[1:])
assert isinstance(stride, int) and stride >= 1
    # Finalize filter kernel.
    f = np.array(f, dtype=np.float32)
    if f.ndim == 1:
        f = f[:, np.newaxis] * f[np.newaxis, :]
    assert f.ndim == 2
    if normalize:
        f /= np.sum(f)
    if flip:
        f = f[::-1, ::-1]
    # No-op => early exit (checked before the kernel gains extra axes).
    if f.shape == (1, 1) and f[0, 0] == 1:
        return x
    f = f[:, :, np.newaxis, np.newaxis]
    f = np.tile(f, [1, 1, int(x.shape[1]), 1])
# Convolve using depthwise_conv2d.
orig_dtype = x.dtype
x = tf.cast(x, tf.float32) # tf.nn.depthwise_conv2d() doesn't support fp16
f = tf.constant(f, dtype=x.dtype, name='filter')
strides = [1, 1, stride, stride]
x = tf.nn.depthwise_conv2d(x, f, strides=strides, padding='SAME', data_format='NCHW')
x = tf.cast(x, orig_dtype)
return x
def Blur(name, blur_filter=[1,2,1]):
def blur2d(x, f=[1,2,1], normalize=True):
return _blur2d(x, f, normalize)
return Lambda(lambda x: blur2d(x, blur_filter), name=name)
def _downscale2d(x, factor=2, gain=1):
assert x.shape.ndims == 4 and all(dim is not None for dim in x.shape[1:])
assert isinstance(factor, int) and factor >= 1
# 2x2, float32 => downscale using _blur2d().
if factor == 2 and x.dtype == tf.float32:
f = [np.sqrt(gain) / factor] * factor
return _blur2d(x, f=f, normalize=False, stride=factor)
# Apply gain.
if gain != 1:
x *= gain
# No-op => early exit.
if factor == 1:
return x
# Large factor => downscale using tf.nn.avg_pool().
# NOTE: Requires tf_config['graph_options.place_pruned_graph']=True to work.
ksize = [1, 1, factor, factor]
return tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding='VALID', data_format='NCHW')
def _upscale2d(x, factor=2, gain=1):
assert x.shape.ndims == 4 and all(dim is not None for dim in x.shape[1:])
assert isinstance(factor, int) and factor >= 1
# Apply gain.
if gain != 1:
x *= gain
# No-op => early exit.
if factor == 1:
return x
# Upscale using tf.tile().
s = x.shape
x = tf.reshape(x, [-1, s[1], s[2], 1, s[3], 1])
x = tf.tile(x, [1, 1, 1, factor, 1, factor])
x = tf.reshape(x, [-1, s[1], s[2] * factor, s[3] * factor])
return x
def Downscaled2d(name, factor=2, gain=1):
return Lambda(lambda x: _downscale2d(x, factor, gain), name=name+'/Downscaled2d')
def Upscaled2d(name, factor=2, gain=1):
return Lambda(lambda x: _upscale2d(x, factor, gain), name=name+'/Upscaled2d')
def Conv2d_downscale2d(model, filters, kernel_size, name, gain=math.sqrt(2), fused_scale='auto'):
if fused_scale == 'auto':
x = model.layers[-1].output
fused_scale = min(x.shape[2:]) >= 128
if not fused_scale:
# Not fused => call the individual ops directly.
model.add( Conv2d(filters, kernel_size, name, gain) )
model.add( Downscaled2d(name) )
else:
# Fused => perform both ops simultaneously using tf.nn.conv2d().
def fused_op(w):
w = tf.pad(w, [[1,1], [1,1], [0,0], [0,0]], mode='CONSTANT')
w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]]) * 0.25
return w
model.add( Conv2d(filters, kernel_size, name, gain, kernel_modifier=fused_op, strides=2) )
def Upscale2d_conv2d(x, filters, kernel_size, name, use_bias, gain=math.sqrt(2), fused_scale='auto'):
if fused_scale == 'auto':
fused_scale = min(x.shape[2:]) * 2 >= 128
if not fused_scale:
x = Upscaled2d(name)(x)
x = Conv2d(filters, kernel_size, name=name, gain=gain, use_bias=use_bias)(x)
return x
return Conv2d_transpose(filters, kernel_size, name, gain, strides=2)(x)
class Conv2d_transpose(Conv2DTranspose):
def __init__(self, filters, kernel_size, name, gain=math.sqrt(2), lrmul=1.0, kernel_modifier=None, strides=2, use_bias=False):
super(Conv2d_transpose, self).__init__(filters=filters, kernel_size=kernel_size, kernel_initializer=GetWeights(gain),
use_bias=use_bias, padding='same', data_format='channels_first', name=name, strides=strides)
self.gain = gain
self.lrmul = lrmul
self.kernel_modifier = kernel_modifier
def build(self, input_shape):
shape = [self.kernel_size[0], self.kernel_size[1], input_shape[1], self.filters]
self.kernel = self.add_variable('weight', shape=shape, initializer=tf.initializers.zeros())
def call(self, inputs):
# Fused => perform both ops simultaneously using tf.nn.conv2d_transpose().
def fused_op(w):
w = tf.transpose(w, [0, 1, 3, 2]) # [kernel, kernel, fmaps_out, fmaps_in]
w = tf.pad(w, [[1,1], [1,1], [0,0], [0,0]], mode='CONSTANT')
w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]])
return w
x, w = inputs, fused_op(self.kernel * runtime_coef(self.kernel_size, self.gain, inputs.shape[1], self.filters, lrmul=self.lrmul))
os = [tf.shape(inputs)[0], self.filters, inputs.shape[2] * 2, inputs.shape[3] * 2]
outputs = tf.nn.conv2d_transpose(x, w, os, strides=[1,1,2,2], padding='SAME', data_format='NCHW')
return outputs
class MinibatchStddevLayer(tf.keras.layers.Layer):
    def __init__(self, group_size=4, num_new_features=1):
super().__init__()
self.group_size = group_size
self.num_new_features = num_new_features
def __call__(self, x, *args, **kwargs):
group_size = tf.minimum(self.group_size,
tf.shape(x)[0]) # Minibatch must be divisible by (or smaller than) group_size.
s = x.shape # [NCHW] Input shape.
y = tf.reshape(x, [group_size, -1, self.num_new_features, s[1] // self.num_new_features, s[2], s[
3]]) # [GMncHW] Split minibatch into M groups of size G. Split channels into n channel groups c.
y = tf.cast(y, tf.float32) # [GMncHW] Cast to FP32.
y -= tf.reduce_mean(y, axis=0, keepdims=True) # [GMncHW] Subtract mean over group.
y = tf.reduce_mean(tf.square(y), axis=0) # [MncHW] Calc variance over group.
y = tf.sqrt(y + 1e-8) # [MncHW] Calc stddev over group.
y = tf.reduce_mean(y, axis=[2, 3, 4], keepdims=True) # [Mn111] Take average over fmaps and pixels.
y = tf.reduce_mean(y, axis=[2]) # [Mn11] Split channels into c channel groups
y = tf.cast(y, x.dtype) # [Mn11] Cast back to original data type.
y = tf.tile(y, [group_size, 1, s[2], s[3]]) # [NnHW] Replicate over group and pixels.
return tf.concat([x, y], axis=1) # [NCHW] Append as new fmap.
def minibatch_stddev_layer(x, group_size=4, num_new_features=1):
with tf.compat.v1.variable_scope('MinibatchStddev'):
group_size = tf.minimum(group_size, tf.shape(x)[0]) # Minibatch must be divisible by (or smaller than) group_size.
s = x.shape # [NCHW] Input shape.
y = tf.reshape(x, [group_size, -1, num_new_features, s[1]//num_new_features, s[2], s[3]]) # [GMncHW] Split minibatch into M groups of size G. Split channels into n channel groups c.
y = tf.cast(y, tf.float32) # [GMncHW] Cast to FP32.
y -= tf.reduce_mean(y, axis=0, keepdims=True) # [GMncHW] Subtract mean over group.
y = tf.reduce_mean(tf.square(y), axis=0) # [MncHW] Calc variance over group.
y = tf.sqrt(y + 1e-8) # [MncHW] Calc stddev over group.
y = tf.reduce_mean(y, axis=[2,3,4], keepdims=True) # [Mn111] Take average over fmaps and pixels.
y = tf.reduce_mean(y, axis=[2]) # [Mn11] Average over channel groups.
y = tf.cast(y, x.dtype) # [Mn11] Cast back to original data type.
y = tf.tile(y, [group_size, 1, s[2], s[3]]) # [NnHW] Replicate over group and pixels.
return tf.concat([x, y], axis=1) # [NCHW] Append as new fmap.
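# Example (illustrative sketch): minibatch_stddev_layer appends num_new_features
# extra channels holding the average per-group standard deviation, so the
# discriminator can sense minibatch diversity. With a batch of 4:
def _minibatch_stddev_example():
x = tf.random.normal([4, 8, 16, 16]) # NCHW
y = minibatch_stddev_layer(x, group_size=4, num_new_features=1)
# One stddev feature map is appended along C: y.shape == [4, 9, 16, 16]
return y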
def StyleGAN_G_mapping( latent_size=512, dlatent_size=512, mapping_layers=8, mapping_fmaps=512, mapping_lrmul=0.01,
truncation_psi=1, resolution=1024):
resolution_log2 = int(np.log2(resolution))
num_layers = resolution_log2 * 2 - 2
model = Sequential(name='G_mapping')
model.add( InputLayer(input_shape=[latent_size], name='G_mapping/latents_in') )
# Normalize latents.
model.add( PixelNorm(name='G_mapping/PixelNorm') )
# Mapping layers.
for layer_idx in range(mapping_layers):
name = 'G_mapping/Dense{}'.format(layer_idx)
fmaps = dlatent_size if layer_idx == mapping_layers - 1 else mapping_fmaps
model.add( DenseLayer(units=fmaps, kernel_initializer=GetWeights(), name=name, lrmul=mapping_lrmul) )
model.add( LeakyReLU(alpha=0.2, name=name+'/LeakyReLU') )
# Broadcast.
model.add( Broadcast(name='G_mapping/Broadcast', dlatent_broadcast=num_layers) )
# Output.
model.add( Identity(name='G_mapping/dlatents_out') )
# Apply truncation trick.
model.add( Truncation(name='Truncation', num_layers=num_layers, truncation_psi=truncation_psi ))
return model
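# Example (illustrative sketch): the mapping network turns a latent z of shape
# [batch, latent_size] into per-layer dlatents of shape
# [batch, 2 * log2(resolution) - 2, dlatent_size], i.e. 18 layers at 1024x1024:
def _g_mapping_example():
g_mapping = StyleGAN_G_mapping(resolution=1024)
z = tf.random.normal([2, 512])
dlatents = g_mapping(z) # shape [2, 18, 512]
return dlatents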
def StyleGAN_G_synthesis(dlatent_size=512, resolution=1024, is_const_noise=True):
# General parameters
num_channels = 3
resolution_log2 = int(np.log2(resolution))
num_layers = resolution_log2 * 2 - 2
# Primary inputs.
dlatents_in = tf.keras.layers.Input(shape=[num_layers, dlatent_size], name='G_synthesis/dlatents_in')
# Noise inputs.
noise_inputs = []
for layer_idx in range(num_layers):
noise_inputs.append( RandomNoise(name='G_synthesis/noise%d'%layer_idx, layer_idx=layer_idx)(dlatents_in) )
# Things to do at the end of each layer.
def layer_epilogue(x, layer_idx, name):
name = 'G_synthesis/{}x{}/{}/'.format(x.shape[2], x.shape[2], name)
x = ApplyNoise(name=name+'Noise', is_const_noise=is_const_noise)([x, noise_inputs[layer_idx]])
x = ApplyBias(name=name+'bias')(x)
x = LeakyReLU(alpha=0.2, name=name+'LeakyReLU')(x)
x = InstanceNorm(name=name+'InstanceNorm')(x)
style = DenseLayer(units=x.shape[1]*2, gain=1, name=name+'StyleMod') (StridedSlice(layer_idx, name=name+'StridedSlice')(dlatents_in))
x = StyleModApply(name=name+'StyleModApply')([x, style])
return x
# Building blocks for remaining layers.
def block(res, x): # res = 3..resolution_log2
name, name0, name1 = '%dx%d' % (2**res, 2**res), 'Conv0_up', 'Conv1'
# Conv0_up
upscaled = Upscale2d_conv2d(x, name='G_synthesis/{}/{}'.format(name, name0), filters=nf(res-1), kernel_size=3, use_bias=False)
x = layer_epilogue( Blur(name='G_synthesis/{}/{}/Blur'.format(name, name0))(upscaled), res*2-4, name0 )
# Conv1
x = layer_epilogue( Conv2d(name='G_synthesis/{}/{}'.format(name, name1), filters=nf(res-1), kernel_size=3, use_bias=False)(x), res*2-3, name1 )
return x
def torgb(res, x): # res = 2..resolution_log2
lod = resolution_log2 - res
return Conv2d(name='G_synthesis/ToRGB_lod%d' % lod, filters=num_channels, kernel_size=1, gain=1, use_bias=True)(x)
# Early layers.
x = layer_epilogue(Const(name='G_synthesis/4x4/Const')(dlatents_in), 0, name='Const')
x = layer_epilogue(Conv2d(name='G_synthesis/4x4/Conv', filters=nf(1), kernel_size=3, use_bias=False)(x), 1, 'Conv')
# Fixed structure: simple and efficient, but does not support progressive growing.
for res in range(3, resolution_log2 + 1):
x = block(res, x)
x = torgb(resolution_log2, x)
# change output to the default NHWC format, so it will be compatible with other networks
x = tf.transpose(x, (0, 2, 3, 1))
return Model(inputs=dlatents_in, outputs=x, name='G_synthesis')
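# Example (illustrative sketch): G_synthesis consumes the broadcast dlatents and,
# thanks to the transpose above, emits an NHWC image in roughly [-1, 1]:
def _g_synthesis_example():
g_synthesis = StyleGAN_G_synthesis(resolution=256)
dlatents = tf.random.normal([1, 14, 512]) # 2 * log2(256) - 2 = 14 layers
img = g_synthesis(dlatents) # shape [1, 256, 256, 3]
return img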
class StyleGAN_G(Model):
def __init__(self, resolution=1024, latent_size=512, dlatent_size=512, mapping_layers=8, mapping_fmaps=512,
mapping_lrmul=0.01, truncation_psi=1):
super(StyleGAN_G, self).__init__()
self.model_mapping = StyleGAN_G_mapping(latent_size, dlatent_size, mapping_layers, mapping_fmaps,
mapping_lrmul, truncation_psi, resolution)
self.model_synthesis = StyleGAN_G_synthesis(dlatent_size, resolution)
print('Model created.')
def call(self, inputs):
x = self.model_mapping(inputs)
x = self.model_synthesis(x)
return x
def generate_sample(self, seed=5, is_visualize=False):
rnd = np.random.RandomState(seed)
latents = rnd.randn(1, 512)
y = self.predict(latents)
images = y.transpose([0, 2, 3, 1])
images = np.clip((images+1)*0.5, 0, 1)
# print(images.shape, np.min(images), np.max(images))
plt.figure(figsize=(10, 10))
plt.imshow(images[0])
if is_visualize:
plt.show()
return images
class StyleGAN_D(Model):
def __init__(self, resolution=1024, mbstd_group_size=4, mbstd_num_features=1):
super(StyleGAN_D, self).__init__()
resolution_log2 = int(math.log2(resolution))
model = Sequential(name='Discriminator')
model.add(InputLayer(input_shape=[3, resolution, resolution]))
def fromrgb(res):
name = 'FromRGB_lod%d' % (resolution_log2 - res)
model.add( Conv2d(filters=nf(res-1), kernel_size=1, name=name) )
model.add( LeakyReLU(alpha=0.2, name=name+'/LeakyReLU') )
def block(res):
name = '%dx%d' % (2**res, 2**res)
if res >= 3: # 8x8 and up
model.add( Conv2d(filters=nf(res-1), kernel_size=3, name=name+'/Conv0') )
model.add( LeakyReLU(alpha=0.2, name=name+'/Conv0/LeakyReLU') )
model.add( Blur(name=name+'/Blur') )
Conv2d_downscale2d(model=model, filters=nf(res-2), kernel_size=3, name=name+'/Conv1_down')
model.add( LeakyReLU(alpha=0.2, name=name+'/Conv1_down/LeakyReLU') )
else: # 4x4
if mbstd_group_size > 1:
model.add( Lambda(lambda x: minibatch_stddev_layer(x, mbstd_group_size, mbstd_num_features), name=name+'/MinibatchStddev') )
model.add( Conv2d(filters=nf(res-1), kernel_size=3, name=name+'/Conv') )
model.add( LeakyReLU(alpha=0.2, name=name+'/Conv/LeakyReLU') )
model.add( Flatten() )
model.add( DenseLayer(units=nf(res-2), kernel_initializer=GetWeights(), name=name+'/Dense0') )
model.add( LeakyReLU(alpha=0.2, name=name+'/Dense0/LeakyReLU') )
model.add( DenseLayer(units=1, kernel_initializer=GetWeights(1), gain=1, name=name+'/Dense1') )
# Blocks
fromrgb(resolution_log2)
for res in range(resolution_log2, 2, -1): block(res)
block(2)
self.model = model
def call(self, inputs):
inputs = tf.transpose(inputs, (0, 3, 1, 2))
return self.model(inputs)
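# Example (illustrative sketch): the discriminator takes NHWC images (transposed
# to NCHW internally) and returns one realness logit per image:
def _d_example():
d = StyleGAN_D(resolution=256)
images = tf.random.uniform([2, 256, 256, 3])
logits = d(images) # shape [2, 1]
return logits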
def copy_weights_to_keras_model(model, all_weights):
c = 0
od = all_weights
for l in model.layers:
try:
values = l.get_weights()
weights = list(map(lambda x: x.shape, values))
if not len(weights): continue
num_params = values[0].size
# Special weights
if len(weights) == 1:
weights_list = []
# The learned constant variable
if l.name == 'G_synthesis/4x4/Const': weights_list.append( od[l.name+'/const'] )
# Truncation trick variable
if 'Truncation' in l.name: weights_list.append( od['dlatent_avg'] )
# Input noise
if 'G_synthesis/noise' in l.name: weights_list.append( od[l.name] )
# Noise variables
if 'Noise' in l.name: weights_list.append( od[l.name+'/weight'] )
# Bias variables
if 'bias' in l.name: weights_list.append( od[l.name] )
# Conv with no bias
if l.name.endswith('Conv') or l.name.endswith('Conv1') or l.name.endswith('Conv0_up'):
weights_list.append( od[l.name+'/weight'] )
if len(weights_list) > 0:
l.set_weights( weights_list )
print('.', end='')
c = c + num_params
else:
print('WARNING: weights not found for ', l.name, ' of size', weights[0])
else:
# Standard weights (weight + bias)
assert len(weights) == 2
num_params = num_params + values[1].size
layer_name = l.name
var_names = ['{}/{}'.format(layer_name, 'weight'), '{}/{}'.format(layer_name, 'bias')]
if var_names[0] in od and var_names[1] in od:
weight = od[var_names[0]]
bias = od[var_names[1]]
l.set_weights( [ weight, bias ] )
print('.', end='')
c = c + num_params
else:
print('WARNING: not found', var_names)
except Exception as e:
print(e)
print('skipping...')
print('Total number of parameters copied:', c)
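# Example (illustrative usage sketch; the .npz file name below is hypothetical):
# all_weights is expected to map variable names such as
# 'G_synthesis/4x4/Conv/weight' to numpy arrays, e.g. converted from the
# official TF1 StyleGAN checkpoint:
#
# all_weights = dict(np.load('stylegan_ffhq_weights.npz'))
# g = StyleGAN_G(resolution=1024)
# g.predict(np.zeros((1, 512))) # build all layers once
# copy_weights_to_keras_model(g.model_mapping, all_weights)
# copy_weights_to_keras_model(g.model_synthesis, all_weights)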
================================================
FILE: test.py
================================================
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['USE_SIMPLE_THREADED_LEVEL3'] = '1'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
import sys
from model.stylegan import StyleGAN_G_synthesis
from model.model import Network
from writer import Writer
from inference import Inference
from arglib import arglib
import utils
sys.path.insert(0, 'model/face_utils')
def main():
test_args = arglib.TestArgs()
args, str_args = test_args.args, test_args.str_args
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
Writer.set_writer(args.results_dir)
id_model_path = args.pretrained_models_path.joinpath('vggface2.h5')
stylegan_G_synthesis_path = str(
args.pretrained_models_path.joinpath(f'stylegan_G_{args.resolution}x{args.resolution}_synthesis'))
utils.landmarks_model_path = str(args.pretrained_models_path.joinpath('shape_predictor_68_face_landmarks.dat'))
stylegan_G_synthesis = StyleGAN_G_synthesis(resolution=args.resolution, is_const_noise=args.const_noise)
stylegan_G_synthesis.load_weights(stylegan_G_synthesis_path)
network = Network(args, id_model_path, stylegan_G_synthesis)
network.test()
inference = Inference(args, network)
test_func = getattr(inference, args.test_func)
test_func()
if __name__ == '__main__':
main()
================================================
FILE: trainer.py
================================================
import logging
import numpy as np
import tensorflow as tf
from writer import Writer
from utils import general_utils as utils
def id_loss_func(y_gt, y_pred):
return tf.reduce_mean(tf.keras.losses.MAE(y_gt, y_pred))
class Trainer(object):
def __init__(self, args, model, data_loader):
self.args = args
self.logger = logging.getLogger(__class__.__name__)
self.model = model
self.data_loader = data_loader
# lrs & optimizers
lr = 5e-5 if self.args.resolution == 256 else 1e-5
self.g_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
self.g_gan_optimizer = tf.keras.optimizers.Adam(learning_rate=0.1 * lr)
self.w_d_optimizer = tf.keras.optimizers.Adam(learning_rate=0.4 * lr)
self.im_d_optimizer = tf.keras.optimizers.Adam(learning_rate=0.4 * lr)
# Losses
self.gan_loss_func = tf.keras.losses.BinaryCrossentropy(from_logits=True)
self.pixel_loss_func = tf.keras.losses.MeanAbsoluteError(tf.keras.losses.Reduction.SUM)
self.id_loss_func = id_loss_func
if args.pixel_mask_type == 'gaussian':
sigma = int(80 * (self.args.resolution / 256))
self.pixel_mask = utils.inverse_gaussian_image(self.args.resolution, sigma)
else:
self.pixel_mask = tf.ones([self.args.resolution, self.args.resolution])
self.pixel_mask = self.pixel_mask / tf.reduce_sum(self.pixel_mask)
self.pixel_mask = tf.broadcast_to(self.pixel_mask, [self.args.batch_size, *self.pixel_mask.shape])
self.num_epoch = 0
self.is_cross_epoch = False
# Lambdas
if args.unified:
self.lambda_gan = 0.5
else:
self.lambda_gan = 1
self.lambda_pixel = 0.02
self.lambda_id = 1
self.lambda_attr_id = 1
self.lambda_landmarks = 0.001
self.r1_gamma = 10
# Test
self.test_not_improved = 0
self.max_id_preserve = 0
self.min_lnd_dist = np.inf
def train(self):
while self.num_epoch <= self.args.num_epochs:
self.logger.info('---------------------------------------')
self.logger.info(f'Start training epoch: {self.num_epoch}')
if self.args.cross_frequency and (self.num_epoch % self.args.cross_frequency == 0):
self.is_cross_epoch = True
self.logger.info('This epoch is cross-face')
else:
self.is_cross_epoch = False
self.logger.info('This epoch is same-face')
try:
if self.num_epoch % self.args.test_frequency == 0:
self.test()
self.train_epoch()
except Exception as e:
self.logger.exception(e)
raise
if self.test_not_improved > self.args.not_improved_exit:
self.logger.info(f'Test has not improved for {self.args.not_improved_exit} epochs. Exiting...')
break
self.num_epoch += 1
def train_epoch(self):
id_loss = 0
landmarks_loss = 0
g_w_gan_loss = 0
pixel_loss = 0
w_d_loss = 0
w_loss = 0
self.logger.info(f'train in epoch: {self.num_epoch}')
self.model.train()
use_w_d = self.args.W_D_loss
# if use_w_d and use_im_d and not self.args.unified:
if not self.args.unified:
if self.num_epoch % 2 == 0:
# This epoch is not using image_D
use_im_d = False
# self.logger.info(f'Not using Image D in epoch: {self.num_epoch}')
if self.num_epoch % 2 != 0:
# This epoch is not using W_D
use_w_d = False
# self.logger.info(f'Not using W_d in epoch: {self.num_epoch}')
attr_img, id_img, real_w, real_img, matching_ws = self.data_loader.get_batch(is_cross=self.is_cross_epoch)
# Forward that does not require grads
id_embedding = self.model.G.id_encoder(id_img)
src_landmarks = self.model.G.landmarks(attr_img)
attr_input = attr_img
with tf.GradientTape(persistent=True) as g_tape:
attr_out = self.model.G.attr_encoder(attr_input)
attr_embedding = attr_out
self.logger.info(f'attr embedding stats- mean: {tf.reduce_mean(tf.abs(attr_embedding)):.5f},'
f' variance: {tf.math.reduce_variance(attr_embedding):.5f}')
z_tag = tf.concat([id_embedding, attr_embedding], -1)
w = self.model.G.latent_spaces_mapping(z_tag)
fake_w = w[:, 0, :]
self.logger.info(
f'w stats- mean: {tf.reduce_mean(tf.abs(fake_w)):.5f}, variance: {tf.math.reduce_variance(fake_w):.5f}')
pred = self.model.G.stylegan_s(w)
# Move to roughly [0,1]
pred = (pred + 1) / 2
if use_w_d:
with tf.GradientTape() as w_d_tape:
fake_w_logit = self.model.W_D(fake_w)
g_w_gan_loss = self.generator_gan_loss(fake_w_logit)
self.logger.info(f'g W loss is {g_w_gan_loss:.3f}')
self.logger.info(f'fake W logit: {tf.squeeze(fake_w_logit)}')
with g_tape.stop_recording():
real_w_logit = self.model.W_D(real_w)
w_d_loss = self.discriminator_loss(fake_w_logit, real_w_logit)
w_d_total_loss = w_d_loss
if self.args.gp:
w_d_gp = self.R1_gp(self.model.W_D, real_w)
w_d_total_loss += w_d_gp
self.logger.info(f'w_d_gp : {w_d_gp}')
self.logger.info(f'W_D loss is {w_d_loss:.3f}')
self.logger.info(f'real W logit: {tf.squeeze(real_w_logit)}')
if self.args.id_loss:
pred_id_embedding = self.model.G.id_encoder(pred)
id_loss = self.lambda_id * id_loss_func(pred_id_embedding, tf.stop_gradient(id_embedding))
self.logger.info(f'id loss is {id_loss:.3f}')
if self.args.landmarks_loss:
try:
dst_landmarks = self.model.G.landmarks(pred)
except Exception as e:
self.logger.warning(f'Failed to find landmarks on the prediction; skipping landmarks loss. Error: {e}')
dst_landmarks = None
if dst_landmarks is None or src_landmarks is None:
landmarks_loss = 0
else:
landmarks_loss = self.lambda_landmarks * \
tf.reduce_mean(tf.keras.losses.MSE(src_landmarks, dst_landmarks))
self.logger.info(f'landmarks loss is: {landmarks_loss:.3f}')
# if landmarks_loss > 5:
# landmarks_loss = 0
# id_loss = 0
if not self.is_cross_epoch and self.args.pixel_loss:
l1_loss = self.pixel_loss_func(attr_img, pred, sample_weight=self.pixel_mask)
self.logger.info(f'L1 pixel loss is {l1_loss:.3f}')
if self.args.pixel_loss_type == 'mix':
mssim = tf.reduce_mean(1 - tf.image.ssim_multiscale(attr_img, pred, 1.0))
self.logger.info(f'mssim loss is {mssim:.3f}')
pixel_loss = self.lambda_pixel * (0.84 * mssim + 0.16 * l1_loss)
else:
pixel_loss = self.lambda_pixel * l1_loss
self.logger.info(f'pixel loss is {pixel_loss:.3f}')
g_gan_loss = g_w_gan_loss
total_g_not_gan_loss = id_loss \
+ landmarks_loss \
+ pixel_loss \
+ w_loss
self.logger.info(f'total G (not gan) loss is {total_g_not_gan_loss:.3f}')
self.logger.info(f'G gan loss is {g_gan_loss:.3f}')
Writer.add_scalar('loss/landmarks_loss', landmarks_loss, step=self.num_epoch)
Writer.add_scalar('loss/total_g_not_gan_loss', total_g_not_gan_loss, step=self.num_epoch)
Writer.add_scalar('loss/id_loss', id_loss, step=self.num_epoch)
if use_w_d:
Writer.add_scalar('loss/g_w_gan_loss', g_w_gan_loss, step=self.num_epoch)
Writer.add_scalar('loss/W_D_loss', w_d_loss, step=self.num_epoch)
if self.args.gp:
Writer.add_scalar('loss/w_d_gp', w_d_gp, step=self.num_epoch)
if not self.is_cross_epoch:
Writer.add_scalar('loss/pixel_loss', pixel_loss, step=self.num_epoch)
Writer.add_scalar('loss/w_loss', w_loss, step=self.num_epoch)
if self.args.debug or \
(self.num_epoch < 1e3 and self.num_epoch % 1e2 == 0) or \
(self.num_epoch < 1e4 and self.num_epoch % 1e3 == 0) or \
(self.num_epoch % 1e4 == 0):
utils.save_image(pred[0], self.args.images_results.joinpath(f'{self.num_epoch}_prediction_step.png'))
utils.save_image(id_img[0], self.args.images_results.joinpath(f'{self.num_epoch}_id_step.png'))
utils.save_image(attr_img[0], self.args.images_results.joinpath(f'{self.num_epoch}_attr_step.png'))
Writer.add_image('input/id image', tf.expand_dims(id_img[0], 0), step=self.num_epoch)
Writer.add_image('Prediction', tf.expand_dims(pred[0], 0), step=self.num_epoch)
if total_g_not_gan_loss != 0:
g_grads = g_tape.gradient(total_g_not_gan_loss, self.model.G.trainable_variables)
g_grads_global_norm = tf.linalg.global_norm(g_grads)
self.logger.info(f'global norm G not gan grad: {g_grads_global_norm}')
self.g_optimizer.apply_gradients(zip(g_grads, self.model.G.trainable_variables))
if use_w_d:
g_gan_grads = g_tape.gradient(g_gan_loss, self.model.G.trainable_variables)
g_gan_grad_global_norm = tf.linalg.global_norm(g_gan_grads)
self.logger.info(f'global norm G gan grad: {g_gan_grad_global_norm}')
self.g_gan_optimizer.apply_gradients(zip(g_gan_grads, self.model.G.trainable_variables))
w_d_grads = w_d_tape.gradient(w_d_total_loss, self.model.W_D.trainable_variables)
self.logger.info(f'global W_D gan grad: {tf.linalg.global_norm(w_d_grads)}')
self.w_d_optimizer.apply_gradients(zip(w_d_grads, self.model.W_D.trainable_variables))
del g_tape
# Common
# Test
def test(self):
self.logger.info(f'Testing in epoch: {self.num_epoch}')
self.model.test()
similarities = {'id_to_pred': [], 'id_to_attr': [], 'attr_to_pred': []}
fake_reconstruction = {'MSE': [], 'PSNR': [], 'ID': []}
real_reconstruction = {'MSE': [], 'PSNR': [], 'ID': []}
if self.args.test_with_arcface:
test_similarities = {'id_to_pred': [], 'id_to_attr': [], 'attr_to_pred': []}
lnd_dist = []
for i in range(self.args.test_size):
attr_img, id_img = self.data_loader.get_batch(is_train=False, is_cross=True)
pred, id_embedding, w, attr_embedding, src_lnds = self.model.G(id_img, attr_img)
image = tf.clip_by_value(pred, 0, 1)
pred_id = self.model.G.id_encoder(image)
attr_id = self.model.G.id_encoder(attr_img)
similarities['id_to_pred'].extend(tf.keras.losses.cosine_similarity(id_embedding, pred_id).numpy())
similarities['id_to_attr'].extend(tf.keras.losses.cosine_similarity(id_embedding, attr_id).numpy())
similarities['attr_to_pred'].extend(tf.keras.losses.cosine_similarity(attr_id, pred_id).numpy())
if self.args.test_with_arcface:
try:
arc_id_embedding = self.model.G.test_id_encoder(id_img)
arc_pred_id = self.model.G.test_id_encoder(image)
arc_attr_id = self.model.G.test_id_encoder(attr_img)
test_similarities['id_to_attr'].extend(
tf.keras.losses.cosine_similarity(arc_id_embedding, arc_attr_id).numpy())
test_similarities['id_to_pred'].extend(
tf.keras.losses.cosine_similarity(arc_id_embedding, arc_pred_id).numpy())
test_similarities['attr_to_pred'].extend(
tf.keras.losses.cosine_similarity(arc_attr_id, arc_pred_id).numpy())
except Exception as e:
self.logger.warning(f'Not calculating test similarities for iteration: {i} because: {e}')
# Landmarks
dst_lnds = self.model.G.landmarks(image)
lnd_dist.extend(tf.reduce_mean(tf.keras.losses.MSE(src_lnds, dst_lnds), axis=-1).numpy())
# Fake Reconstruction
self.test_reconstruction(id_img, fake_reconstruction, display=(i==0), display_name='id_img')
if self.args.test_real_attr:
# Real Reconstruction
self.test_reconstruction(attr_img, real_reconstruction, display=(i==0), display_name='attr_img')
if i == 0:
utils.save_image(image[0], self.args.images_results.joinpath(f'test_prediction_{self.num_epoch}.png'))
utils.save_image(id_img[0], self.args.images_results.joinpath(f'test_id_{self.num_epoch}.png'))
utils.save_image(attr_img[0],
self.args.images_results.joinpath(f'test_attr_{self.num_epoch}.png'))
Writer.add_image('test/prediction', image, step=self.num_epoch)
Writer.add_image('test input/id image', id_img, step=self.num_epoch)
Writer.add_image('test input/attr image', attr_img, step=self.num_epoch)
for j in range(np.minimum(3, src_lnds.shape[0])):
src_xy = src_lnds[j] # GT
dst_xy = dst_lnds[j] # pred
attr_marked = utils.mark_landmarks(attr_img[j], src_xy, color=(0, 0, 0))
pred_marked = utils.mark_landmarks(pred[j], src_xy, color=(0, 0, 0))
pred_marked = utils.mark_landmarks(pred_marked, dst_xy, color=(255, 112, 112))
Writer.add_image(f'landmarks/overlay-{j}', pred_marked, step=self.num_epoch)
Writer.add_image(f'landmarks/src-{j}', attr_marked, step=self.num_epoch)
# Similarity
self.logger.info('Similarities:')
for k, v in similarities.items():
self.logger.info(f'{k}: MEAN: {np.mean(v)}, STD: {np.std(v)}')
mean_lnd_dist = np.mean(lnd_dist)
self.logger.info(f'Mean landmarks L2: {mean_lnd_dist}')
id_to_pred = np.mean(similarities['id_to_pred'])
attr_to_pred = np.mean(similarities['attr_to_pred'])
mean_disen = attr_to_pred - id_to_pred
Writer.add_scalar('similarity/score', mean_disen, step=self.num_epoch)
Writer.add_scalar('similarity/id_to_pred', id_to_pred, step=self.num_epoch)
Writer.add_scalar('similarity/attr_to_pred', attr_to_pred, step=self.num_epoch)
if self.args.test_with_arcface:
arc_id_to_pred = np.mean(test_similarities['id_to_pred'])
arc_attr_to_pred = np.mean(test_similarities['attr_to_pred'])
arc_mean_disen = arc_attr_to_pred - arc_id_to_pred
Writer.add_scalar('arc_similarity/score', arc_mean_disen, step=self.num_epoch)
Writer.add_scalar('arc_similarity/id_to_pred', arc_id_to_pred, step=self.num_epoch)
Writer.add_scalar('arc_similarity/attr_to_pred', arc_attr_to_pred, step=self.num_epoch)
self.logger.info(f'Mean disentanglement score is {mean_disen}')
Writer.add_scalar('landmarks/L2', np.mean(lnd_dist), step=self.num_epoch)
# Reconstruction
if self.args.test_real_attr:
Writer.add_scalar('reconstruction/real_MSE', np.mean(real_reconstruction['MSE']), step=self.num_epoch)
Writer.add_scalar('reconstruction/real_PSNR', np.mean(real_reconstruction['PSNR']), step=self.num_epoch)
Writer.add_scalar('reconstruction/real_ID', np.mean(real_reconstruction['ID']), step=self.num_epoch)
Writer.add_scalar('reconstruction/fake_MSE', np.mean(fake_reconstruction['MSE']), step=self.num_epoch)
Writer.add_scalar('reconstruction/fake_PSNR', np.mean(fake_reconstruction['PSNR']), step=self.num_epoch)
Writer.add_scalar('reconstruction/fake_ID', np.mean(fake_reconstruction['ID']), step=self.num_epoch)
if mean_lnd_dist < self.min_lnd_dist:
self.logger.info('Minimum landmarks distance achieved. Saving checkpoint')
self.test_not_improved = 0
self.min_lnd_dist = mean_lnd_dist
self.model.my_save(f'_best_landmarks_epoch_{self.num_epoch}')
if np.abs(id_to_pred) > self.max_id_preserve:
self.logger.info('Max ID preservation achieved! Saving checkpoint')
self.test_not_improved = 0
self.max_id_preserve = np.abs(id_to_pred)
self.model.my_save(f'_best_id_epoch_{self.num_epoch}')
else:
self.test_not_improved += 1
def test_reconstruction(self, img, errors_dict, display=False, display_name=None):
pred, id_embedding, w, attr_embedding, src_lnds = self.model.G(img, img)
recon_image = tf.clip_by_value(pred, 0, 1)
recon_pred_id = self.model.G.id_encoder(recon_image)
mse = tf.reduce_mean((img - recon_image) ** 2, axis=[1, 2, 3]).numpy()
psnr = tf.image.psnr(img, recon_image, 1).numpy()
errors_dict['MSE'].extend(mse)
errors_dict['PSNR'].extend(psnr)
errors_dict['ID'].extend(tf.keras.losses.cosine_similarity(id_embedding, recon_pred_id).numpy())
if display:
Writer.add_image(f'reconstruction/{display_name}', pred, step=self.num_epoch)
# Helpers
def generator_gan_loss(self, fake_logit):
"""
G logistic non-saturating loss, to be minimized
"""
g_gan_loss = self.gan_loss_func(tf.ones_like(fake_logit), fake_logit)
return self.lambda_gan * g_gan_loss
def discriminator_loss(self, fake_logit, real_logit):
"""
D logistic loss, to be minimized
verified as identical to StyleGAN's loss.D_logistic
"""
fake_gt = tf.zeros_like(fake_logit)
real_gt = tf.ones_like(real_logit)
d_fake_loss = self.gan_loss_func(fake_gt, fake_logit)
d_real_loss = self.gan_loss_func(real_gt, real_logit)
d_loss = d_real_loss + d_fake_loss
return self.lambda_gan * d_loss
def R1_gp(self, D, x):
with tf.GradientTape() as t:
t.watch(x)
pred = D(x)
pred_sum = tf.reduce_sum(pred)
grad = t.gradient(pred_sum, x)
# Reshape as a vector
norm = tf.norm(tf.reshape(grad, [tf.shape(grad)[0], -1]), axis=1)
gp = tf.reduce_mean(norm ** 2)
gp = 0.5 * self.r1_gamma * gp
return gp
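# Example (illustrative sketch): R1_gp implements the R1 regularizer of
# Mescheder et al. (2018), gp = (r1_gamma / 2) * E[||dD(x)/dx||^2], applied to
# real samples only. For a toy discriminator D(x) = sum(x ** 2) with gradient
# 2x and x = ones([4, 8]), each per-sample ||grad||^2 is 8 * 2^2 = 32, so with
# r1_gamma = 10 the penalty is 0.5 * 10 * 32 = 160:
def _r1_gp_example():
x = tf.ones([4, 8])
with tf.GradientTape() as t:
t.watch(x)
pred_sum = tf.reduce_sum(tf.reduce_sum(x ** 2, axis=1))
grad = t.gradient(pred_sum, x) # grad == 2 * x
norm = tf.norm(tf.reshape(grad, [tf.shape(grad)[0], -1]), axis=1)
return 0.5 * 10 * tf.reduce_mean(norm ** 2) # == 160.0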
================================================
FILE: utils/__init__.py
================================================
================================================
FILE: utils/general_utils.py
================================================
from pathlib import Path
import cv2
import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow.keras.initializers import VarianceScaling
import numbers
import scipy
import dlib
landmarks_model_path = None
def read_image(img_path, resolution, align=False):
if align:
img = read_and_align_image(img_path, resolution)
else:
img = read_SG_image(img_path, resolution)
return img
def find_file_by_str(search_dir, s):
files = [f for f in search_dir.iterdir() if s in f.name]
return files
def read_SG_image(img_path, size=256, resize=True):
img = Image.open(str(img_path))
img = img.convert('RGB')
if img.size != (size, size) and resize:
img = img.resize((size, size))
img = np.asarray(img)
img = np.expand_dims(img, axis=0)
# Images in [0, 1]
img = np.float32(img) / 255
return img
def read_and_align_image(img_path, output_size=1024):
global landmarks_model_path
if not landmarks_model_path:
raise ValueError('Please initialize the landmarks model path (landmarks_model_path) first')
transform_size = 4096
enable_padding = True
img = Image.open(img_path)
img = img.convert('RGB')
npimg = np.asarray(img)
face_detector = dlib.get_frontal_face_detector()
landmarks_network = dlib.shape_predictor(landmarks_model_path)
try:
bbox = face_detector(npimg, 0)[0]
except IndexError:
print('face not found!')
raise
# rect = np.array([det.left(), det.top(), det.right(), det.bottom()])
shape = landmarks_network(npimg, bbox)
lm = np.array([[shape.part(n).x + 0.5, shape.part(n).y + 0.5] for n in range(shape.num_parts)])
lm = np.round(lm) + 0.5
lm_chin = lm[0: 17] # left-right
lm_eyebrow_left = lm[17: 22] # left-right
lm_eyebrow_right = lm[22: 27] # left-right
lm_nose = lm[27: 31] # top-down
lm_nostrils = lm[31: 36] # top-down
lm_eye_left = lm[36: 42] # left-clockwise
lm_eye_right = lm[42: 48] # left-clockwise
lm_mouth_outer = lm[48: 60] # left-clockwise
lm_mouth_inner = lm[60: 68] # left-clockwise
# Calculate auxiliary vectors.
eye_left = np.mean(lm_eye_left, axis=0)
eye_right = np.mean(lm_eye_right, axis=0)
eye_avg = (eye_left + eye_right) * 0.5
# nose_mock_avg = (lm_nose[0] + lm_nose[1]) * 0.5
# eye_avg = nose_mock_avg
eye_to_eye = eye_right - eye_left
mouth_left = lm_mouth_outer[0]
mouth_right = lm_mouth_outer[6]
mouth_avg = (mouth_left + mouth_right) * 0.5
eye_to_mouth = mouth_avg - eye_avg
# Choose oriented crop rectangle.
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
x /= np.hypot(*x)
x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
y = np.flipud(x) * [-1, 1]
c = eye_avg + eye_to_mouth * 0.1
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
qsize = np.hypot(*x) * 2
# Shrink.
shrink = int(np.floor(qsize / output_size * 0.5))
if shrink > 1:
rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
img = img.resize(rsize, Image.ANTIALIAS)
quad /= shrink
qsize /= shrink
# Crop.
border = max(int(np.rint(qsize * 0.1)), 3)
crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
int(np.ceil(max(quad[:, 1]))))
crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
min(crop[3] + border, img.size[1]))
crop = np.array(crop)
if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
img = img.crop(tuple(crop))
quad -= crop[0:2]
# Pad.
pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
int(np.ceil(max(quad[:, 1]))))
pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
max(pad[3] - img.size[1] + border, 0))
if enable_padding and max(pad) > border - 4:
pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
h, w, _ = img.shape
y, x, _ = np.ogrid[:h, :w, :1]
mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
blur = qsize * 0.02
img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
quad += pad[:2]
# Transform.
img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(),
Image.BILINEAR)
if output_size < transform_size:
img = img.resize((output_size, output_size), Image.ANTIALIAS)
img = np.asarray(img)
img = np.expand_dims(img, axis=0)
img = np.float32(img) / 255
return img
def gaussian_image(size, sigma, dim=2):
if isinstance(size, numbers.Number):
size = [size] * dim
if isinstance(sigma, numbers.Number):
sigma = [sigma] * dim
weight_kernel = 1
meshgrids = np.meshgrid(*[np.arange(size, dtype=np.float32) for size in size])
# The gaussian kernel is the product of the
# gaussian function of each dimension.
for size, std, mgrid in zip(size, sigma, meshgrids):
mean = (size - 1) / 2
weight_kernel *= 1 / (std * np.sqrt(2 * np.pi)) * np.exp(-((mgrid - mean) / std) ** 2 / 2)
weight_kernel = weight_kernel / np.sum(weight_kernel)
return weight_kernel
def inverse_gaussian_image(size, sigma, dim=2):
gauss = gaussian_image(size, sigma, dim)
# Inversion achieved by max - gauss, but adding min as well to
# prevent regions of zeros which don't exist in normal gaussian
inv_gauss = np.max(gauss) + np.min(gauss) - gauss
inv_gauss = inv_gauss / np.sum(inv_gauss)
return inv_gauss
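# Example (illustrative sketch): the inverse Gaussian mask is used as the
# trainer's pixel_mask; it down-weights the central face region, up-weights the
# surroundings, and sums to 1:
def _inverse_gaussian_example():
mask = inverse_gaussian_image(256, 80)
# mask.shape == (256, 256); mask[0, 0] > mask[128, 128]; mask.sum() ~= 1.0
return mask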
def is_float(tensor):
"""
Check if the input tensor holds float data (i.e. not uint8); tensor may be a tf.Tensor or np.array
"""
return (isinstance(tensor, tf.Tensor) and tensor.dtype != tf.dtypes.uint8) or tensor.dtype != np.uint8
def convert_tensor_to_image(tensor):
"""
Converts tensor to image, and saturate output's range
:param tensor: tf.Tensor, dtype float32, range [0,1]
:return: np.array, dtype uint8, range [0, 255]
"""
if is_float(tensor):
tensor = tf.clip_by_value(tensor, 0., 1.)
tensor = 255 * tensor
if tensor.ndim == 4 and tensor.shape[0] == 1:
tensor = tf.squeeze(tensor)
tensor = np.uint8(np.round(tensor))
return tensor
def save_image(img, file_path):
"""
:param img: Could be either tf tensor or numpy array
:param file_path:
"""
if isinstance(file_path, Path):
file_path = str(file_path)
img = convert_tensor_to_image(img)
img = Image.fromarray(img)
img.save(file_path)
def mark_landmarks(img, lnd, color=None):
"""
landmarks in (x,y) format
"""
img = convert_tensor_to_image(img)
radius = int(img.shape[0] / 256)
lnd = (img.shape[0] / 160) * lnd
if not color:
color = (255, 255, 255)
for i in range(lnd.shape[0]):
x_y = lnd[i]
img = cv2.circle(img, center=(int(x_y[0]), int(x_y[1])),
color=color, radius=radius, thickness=-1)
return img
def get_weights(slope=0.2):
"""
The scale is calculated according to:
https://pytorch.org/docs/stable/nn.init.html
and
https://towardsdatascience.com/weight-initialization-in-neural-networks-a-journey-from-the-basics-to-kaiming-954fb9b47c79
For ReLU and LeakyReLU activations, the preferable initialization is kaiming.
In Pytorch, the gain for LeakyReLU is calculated by: sqrt(2 / ( 1 + leaky_relu_slope ^ 2)) and the weights are
sampled from N(0, std^2) where std = gain / sqrt(fan_in)
To mimic this in TF, I am using VarianceScaling. The weights are sampled from N(0, std^2)
where std = sqrt(scale / fan_in). Therefore, scale = gain^2
"""
scale = 2 / (1 + slope ** 2)
return VarianceScaling(scale)
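# Worked example (illustrative sketch): for the default slope of 0.2,
# scale = 2 / (1 + 0.2 ** 2) ~= 1.923, so a layer with fan_in = 512 samples
# weights with std = sqrt(1.923 / 512) ~= 0.0613:
def _get_weights_example():
init = get_weights(slope=0.2) # VarianceScaling(scale=1.923...)
return init(shape=(512, 256)) # fan_in = 512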
def np_permute(tensor, permute):
idx = np.empty_like(permute)
idx[permute] = np.arange(len(permute))
return tensor[:, idx]
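# Example (illustrative sketch): np_permute applies the inverse of `permute`
# along axis 1, so it undoes a previous column shuffle:
def _np_permute_example():
t = np.arange(6).reshape(1, 6)
perm = np.array([2, 0, 1, 5, 3, 4])
shuffled = t[:, perm]
return np_permute(shuffled, perm) # equal to t again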
================================================
FILE: utils/generate_fake_data.py
================================================
import sys
from pathlib import Path
import os
sys.path.append('..')
import argparse
import tensorflow as tf
import numpy as np
from tqdm import tqdm
from utils.general_utils import save_image
from model.stylegan import StyleGAN_G
def main(args):
base_dir = Path(args.output_path).joinpath(f'dataset_{args.resolution}')
base_w_dir = base_dir.joinpath('ws')
base_w_dir.mkdir(parents=True, exist_ok=True)
base_im_dir = base_dir.joinpath('images')
base_im_dir.mkdir(parents=True, exist_ok=True)
existing_files = list(base_dir.joinpath('images').iterdir())
if existing_files:
max_exist = max([int(x.name) for x in existing_files])
max_exist = int(max_exist - max_exist % 1e3 + 1e3)
else:
max_exist = 0
stylegan_G_path = args.pretrained_models_path.joinpath(f'stylegan_G_{args.resolution}x{args.resolution}.h5')
stylegan_G = StyleGAN_G(resolution=args.resolution, truncation_psi=args.truncation)
stylegan_G.load_weights(str(stylegan_G_path))
num_samples = args.num_images
batch_size = args.batch_size
num_batches = int(num_samples / batch_size)
curr_ind = max_exist
for _ in tqdm(range(num_batches)):
z = tf.random.normal((batch_size, 512))
w = stylegan_G.model_mapping(z)
images = stylegan_G.model_synthesis(w)
images = (images + 1) / 2
if curr_ind % 1000 == 0:
curr_w_dir = base_w_dir.joinpath(f'{curr_ind:05d}')
curr_w_dir.mkdir(exist_ok=True)
curr_im_dir = base_im_dir.joinpath(f'{curr_ind:05d}')
curr_im_dir.mkdir(exist_ok=True)
for j in range(batch_size):
w_path = curr_w_dir.joinpath(f'{curr_ind:05d}.npy')
np.save(str(w_path), w[j], allow_pickle=False)
im_path = curr_im_dir.joinpath(f'{curr_ind:05d}.png')
save_image(images[j], im_path)
curr_ind += 1
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--resolution', type=int, choices=[256, 1024], default=256)
parser.add_argument('--batch_size', type=int, default=50)
parser.add_argument('--truncation', type=float, default=0.7)
parser.add_argument('--output_path', required=True)
parser.add_argument('--pretrained_models_path', type=Path, required=True)
parser.add_argument('--num_images', type=int, default=10000)
parser.add_argument('--gpu', default='0')
args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
assert args.num_images % 1e3 == 0
main(args)
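# Example invocation (illustrative; paths are placeholders):
# python utils/generate_fake_data.py --output_path /path/to/data \
# --pretrained_models_path /path/to/pretrained --resolution 256 \
# --num_images 10000 --batch_size 50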
================================================
FILE: writer.py
================================================
from utils.general_utils import convert_tensor_to_image
from pathlib import Path
import tensorflow as tf
class Writer(object):
writer = None
@staticmethod
def set_writer(results_dir):
if isinstance(results_dir, str):
results_dir = Path(results_dir)
results_dir.mkdir(exist_ok=True, parents=True)
Writer.writer = tf.summary.create_file_writer(str(results_dir))
@staticmethod
def add_scalar(tag, val, step):
with Writer.writer.as_default():
tf.summary.scalar(tag, val, step=step)
@staticmethod
def add_image(tag, val, step):
val = convert_tensor_to_image(val)
if tf.rank(val) == 3:
val = tf.expand_dims(val, 0)
with Writer.writer.as_default():
tf.summary.image(tag, val, step)
@staticmethod
def flush():
with Writer.writer.as_default():
Writer.writer.flush()
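# Example (illustrative usage sketch):
# Writer.set_writer('results/my_run') # creates the summary dir
# Writer.add_scalar('loss/total', 0.5, step=0)
# Writer.add_image('sample', tf.zeros([1, 64, 64, 3]), step=0)
# Writer.flush()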