Repository: YotamNitzan/ID-disentanglement
Branch: main
Commit: 342a782364fd
Files: 32
Total size: 129.6 KB

Directory structure:
ID-disentanglement/

├── .gitignore
├── LICENSE
├── README.md
├── arglib/
│   ├── __init__.py
│   └── arglib.py
├── data_loader/
│   ├── __init__.py
│   └── data_loader.py
├── docs/
│   ├── index.html
│   ├── mainpage.css
│   └── setup.md
├── environment.yml
├── inference.py
├── main.py
├── model/
│   ├── __init__.py
│   ├── arcface/
│   │   ├── arcface.py
│   │   ├── inference.py
│   │   └── resnet.py
│   ├── attr_encoder.py
│   ├── discriminator.py
│   ├── face_detector.py
│   ├── generator.py
│   ├── id_encoder.py
│   ├── landmarks.py
│   ├── latent_mapping.py
│   ├── model.py
│   └── stylegan.py
├── test.py
├── trainer.py
├── utils/
│   ├── __init__.py
│   ├── general_utils.py
│   └── generate_fake_data.py
└── writer.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
.idea
*.png
*.jpg


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2020 YotamNitzan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# Face Identity Disentanglement via Latent Space Mapping

<p align="center">
<img src="docs/imgs/teaser.png" width="400px"/>
</p>


## Description   
Official Implementation of the paper *Face Identity Disentanglement via Latent Space Mapping*
for both training and evaluation.

> **Face Identity Disentanglement via Latent Space Mapping**<br>
> Yotam Nitzan<sup>1</sup>, Amit Bermano<sup>1</sup>, Yangyan Li<sup>2</sup>, Daniel Cohen-Or<sup>1</sup><br>
> <sup>1</sup>Tel-Aviv University, <sup>2</sup>Alibaba <br>
> https://arxiv.org/abs/2005.07728
>
> <p align="justify"><b>Abstract:</b> <i>Learning disentangled representations of data is a fundamental problem in artificial intelligence. Specifically, disentangled latent representations allow generative models to control and compose the disentangled factors in the synthesis process. Current methods, however, require extensive supervision and training, or instead, noticeably compromise quality. In this paper, we present a method that learns how to represent data in a disentangled way, with minimal supervision, manifested solely using available pre-trained networks. Our key insight is to decouple the processes of disentanglement and synthesis, by employing a leading pre-trained unconditional image generator, such as StyleGAN. By learning to map into its latent space, we leverage both its state-of-the-art quality, and its rich and expressive latent space, without the burden of training it. We demonstrate our approach on the complex and high dimensional domain of human heads. We evaluate our method qualitatively and quantitatively, and exhibit its success with de-identification operations and with temporal identity coherency in image sequences. Through extensive experimentation, we show that our method successfully disentangles identity from other facial attributes, surpassing existing methods, even though they require more training and supervision.</i></p>

## Setup

To setup everything you need check out the [setup instructions](docs/setup.md).

## Training

### Preparing the Dataset

The dataset comprises StyleGAN-generated images and their W latent codes, both produced by a single
StyleGAN model.

We also use real images from FFHQ to evaluate quality at test time.

The dataset is assumed to be in the following structure:

| Path | Description
| :--- | :---
| base directory | Directory for all datasets
| &boxvr;&nbsp; real | FFHQ image dataset
| &boxvr;&nbsp; dataset_N | dataset for resolution NxN
| &boxv;&nbsp; &boxvr;&nbsp; images | images generated by StyleGAN
| &boxv;&nbsp; &boxur;&nbsp; ws | W latent codes generated by StyleGAN

To generate the `dataset_N` directory, run:

```
cd utils/
python generate_fake_data.py \
    --resolution N \
    --batch_size BATCH_SIZE \
    --output_path OUTPUT_PATH \
    --pretrained_models_path PRETRAINED_MODELS_PATH \
    --num_images NUM_IMAGES \
    --gpu GPU
```

This generates an image dataset in the same format as FFHQ.
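
For reference, both the images and the latent codes are stored in chunks of 1,000 files, and samples are resolved by index. A minimal sketch of the naming scheme, mirroring the logic in `data_loader.py` (the helper name is illustrative):

```
from pathlib import Path

def sample_paths(dataset_dir: Path, ind: int):
    """Map a sample index to its image and latent paths (1,000 files per chunk)."""
    chunk = f'{ind - ind % 1000:05d}'            # e.g. index 12345 -> '12000'
    img_path = dataset_dir / 'images' / chunk / f'{ind:05d}.png'
    w_path = dataset_dir / 'ws' / chunk / f'{ind:05d}.npy'
    return img_path, w_path
```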

### Start training

To train the model as done in the paper, run:

```
python main.py \
    NAME \
    --resolution N \
    --pretrained_models_path PRETRAINED_MODELS_PATH \
    --dataset_path BASE_DATASET_DIR \
    --batch_size BATCH_SIZE \
    --cross_frequency 3 \
    --train_data_size 70000 \
    --results_dir RESULTS_DIR
```

Please run `python main.py -h` for more details.

## Inference

For convenience, there are a few inference functions, each serving a different use case.
The function to run is selected by name via the `--test_func` argument.

### All possible combinations in dirs

<p align="center">
<img src="docs/imgs/table_results.jpg"/>
</p>

**Input data: two directories, one for identity inputs and another for attribute inputs.** <br>
Runs over all N×M combinations of images from the two directories.

```
python test.py \
    NAME \
    --pretrained_models_path PRETRAINED_MODELS_PATH \
    --load_checkpoint PATH_TO_WEIGHTS \
    --id_dir DIR_OF_IMAGES_FOR_ID \
    --attr_dir DIR_OF_IMAGES_FOR_ATTR \
    --output_dir DIR_FOR_OUTPUTS \
    --test_func infer_on_dirs
```


### Paired data

**Input data: two directories, one for identity inputs and another for attribute inputs.** <br>
The two directories are assumed to be paired; inference runs on images that share the same name.

```
python test.py \
    NAME \
    --pretrained_models_path PRETRAINED_MODELS_PATH \
    --load_checkpoint PATH_TO_WEIGHTS \
    --id_dir DIR_OF_IMAGES_FOR_ID \
    --attr_dir DIR_OF_IMAGES_FOR_ATTR \
    --output_dir DIR_FOR_OUTPUTS \
    --test_func infer_pairs
```

### Disentangled interpolation

#### Interpolating attributes

<p align="center">
<img src="docs/imgs/interpolate_attr.jpg"/>
</p>

#### Interpolating identity

<p align="center">
<img src="docs/imgs/interpolate_id.jpg"/>
</p>

**Input data: a directory with any number of subdirectories, each containing three images.**
Every image must have exactly one of *attr* or *id* in its name.
If a subdirectory holds two *attr* images and one *id* image, the attributes are interpolated.
If it holds one *attr* image and two *id* images, the identity is interpolated.


```
python test.py \
    NAME \
    --pretrained_models_path PRETRAINED_MODELS_PATH \
    --load_checkpoint PATH_TO_WEIGHTS \
    --input_dir PARENT_DIR \
    --output_dir DIR_FOR_OUTPUTS \
    --test_func interpolate
```
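
Under the hood, `interpolate` encodes the two endpoint images, concatenates identity and attribute codes into a single latent z, maps z into the generator's W space, and interpolates linearly there. A minimal sketch of the W-space path for a fixed identity, using the model attributes from `inference.py`:

```
import tensorflow as tf

def interpolate_attrs(G, id_img, attr_start_img, attr_end_img, num_jumps=9):
    """Interpolate attributes in W space while keeping the identity code fixed."""
    const_id = G.id_encoder(id_img)
    s_z = tf.concat([const_id, G.attr_encoder(attr_start_img)], -1)
    e_z = tf.concat([const_id, G.attr_encoder(attr_end_img)], -1)
    s_w = G.latent_spaces_mapping(s_z)
    e_w = G.latent_spaces_mapping(e_z)
    for i in range(num_jumps):
        t = i / num_jumps
        out = G.stylegan_s((1 - t) * s_w + t * e_w)
        yield (out + 1) / 2  # generator outputs [-1, 1]; map to [0, 1]
```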

## Checkpoints

Our pretrained 256x256 [checkpoint](https://drive.google.com/drive/folders/1lVizq4hCq-zTf8Q3fDqqfSnV6jIYEgY_?usp=sharing) is also available.

## Citation
If you use this code for your research, please cite our paper using:

```
@article{Nitzan2020FaceID,
  title={Face identity disentanglement via latent space mapping},
  author={Yotam Nitzan and Amit Bermano and Yangyan Li and Daniel Cohen-Or},
  journal={ACM Transactions on Graphics (TOG)},
  year={2020},
  volume={39},
  pages={1--14}
}
```

================================================
FILE: arglib/__init__.py
================================================


================================================
FILE: arglib/arglib.py
================================================
import math
import shutil
import logging
import argparse
from pathlib import Path
from abc import ABC, abstractmethod


class BaseArgs(ABC):
    def __init__(self):
        self.args = None
        self.parser = argparse.ArgumentParser()
        self.logger = logging.getLogger(self.__class__.__name__)

        self.add_args()
        self.parse()
        self.validate()
        self.process()
        self.str_args = self.log()

    @abstractmethod
    def add_args(self):
        # Hardware
        self.parser.add_argument('--gpu', type=str, default='0')

        # Model
        self.parser.add_argument('--face_detection', action='store_true')
        self.parser.add_argument('--resolution', type=int, default=256, choices=[256, 1024])
        self.parser.add_argument('--load_checkpoint')
        self.parser.add_argument('--pretrained_models_path', type=Path, required=True)

        BaseArgs.add_bool_arg(self.parser, 'const_noise')

        # Data
        self.parser.add_argument('--batch_size', type=int, default=6)
        self.parser.add_argument('--reals', action='store_true', help='Use real inputs')
        BaseArgs.add_bool_arg(self.parser, 'test_real_attr')

        # Log & Results
        self.parser.add_argument('name', type=str, help='Name under which run will be saved')
        self.parser.add_argument('--results_dir', type=str, default='../results')
        self.parser.add_argument('--log_debug', action='store_true')

        # Other
        self.parser.add_argument('--debug', action='store_true')

    def parse(self):
        self.args = self.parser.parse_args()

    def log(self):
        out_str = 'The arguments are:\n'
        for k, v in self.args.__dict__.items():
            out_str += f'{k}: {v}\n'

        return out_str

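    # Registers a complementary --<name>/--no_<name> flag pair that writes to a
    # single destination, e.g. add_bool_arg(parser, 'id_loss') lets a run turn
    # the identity loss off with --no_id_loss.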
    @staticmethod
    def add_bool_arg(parser, name, default=True):
        group = parser.add_mutually_exclusive_group(required=False)
        group.add_argument('--' + name, dest=name, action='store_true')
        group.add_argument('--no_' + name, dest=name, action='store_false')
        parser.set_defaults(**{name: default})

    @abstractmethod
    def validate(self):
        if self.args.load_checkpoint and not Path(self.args.load_checkpoint).exists():
            raise ValueError(f'Checkpoint directory {self.args.load_checkpoint} does not exist')

    @abstractmethod
    def process(self):
        # Log & Results
        self.args.results_dir = Path(self.args.results_dir).joinpath(self.args.name)

        if self.args.debug:
            self.args.log_debug = True
        if self.args.debug or not self.args.train:
            shutil.rmtree(self.args.results_dir, ignore_errors=True)

        self.args.results_dir.mkdir(parents=True, exist_ok=True)
        self.args.images_results = self.args.results_dir.joinpath('images')
        self.args.images_results.mkdir(exist_ok=True)

        # Model
        if self.args.load_checkpoint:
            self.args.load_checkpoint = Path(self.args.load_checkpoint)


class TrainArgs(BaseArgs):
    def __init__(self):
        super().__init__()

    def add_args(self):
        super().add_args()

        self.parser.add_argument('--dataset_path', type=str, default='../my_dataset')

        self.parser.add_argument('--num_epochs', type=int, default=math.inf)
        self.parser.add_argument('--cross_frequency', type=int, default=3,
                                 help='How often (in epochs) to run a cross-training epoch (0 to disable)')

        self.parser.add_argument('--unified', action='store_true')

        # Data
        BaseArgs.add_bool_arg(self.parser, 'train_real_attr', default=False)
        self.parser.add_argument('--train_data_size', type=int, default=70000,
                                 help='How many images to use for training. Others are used as validation')

        # Losses
        BaseArgs.add_bool_arg(self.parser, 'id_loss')
        BaseArgs.add_bool_arg(self.parser, 'landmarks_loss')
        BaseArgs.add_bool_arg(self.parser, 'pixel_loss')
        BaseArgs.add_bool_arg(self.parser, 'W_D_loss')
        BaseArgs.add_bool_arg(self.parser, 'gp')

        self.parser.add_argument('--pixel_mask_type', choices=['uniform', 'gaussian'], default='gaussian')
        self.parser.add_argument('--pixel_loss_type', choices=['L1', 'mix'], default='mix')

        # Test During training
        self.parser.add_argument('--test_frequency', type=int, default=1000,
                                 help='How often (in epochs) to run a test')
        self.parser.add_argument('--test_size', type=int, default=50,
                                 help='How many mini-batches to use for a test')
        self.parser.add_argument('--not_improved_exit', type=int, default=math.inf,
                                 help='After how many non-improving tests to exit')
        BaseArgs.add_bool_arg(self.parser, 'test_with_arcface')

    def validate(self):
        super().validate()
        if not Path(self.args.dataset_path).exists():
            raise ValueError(f'Dataset at path: {self.args.dataset_path} does not exist')

    def process(self):
        self.args.train = True

        super().process()

        # Dataset
        self.args.dataset_path = Path(self.args.dataset_path)

        self.args.weights_dir = self.args.results_dir.joinpath('weights')
        self.args.weights_dir.mkdir(exist_ok=True)
        backup_code_dir = self.args.results_dir.joinpath('code')
        code_dir = Path().cwd()
        shutil.copytree(code_dir, backup_code_dir)


class TestArgs(BaseArgs):
    def __init__(self):
        super().__init__()

    def add_args(self):
        super().add_args()
        self.parser.set_defaults(batch_size=1)

        self.parser.add_argument('--id_dir', type=Path)
        self.parser.add_argument('--attr_dir', type=Path)
        self.parser.add_argument('--output_dir', type=Path)
        self.parser.add_argument('--input_dir', type=Path)

        self.parser.add_argument('--real_id', action='store_true')
        self.parser.add_argument('--real_attr', action='store_true')
        BaseArgs.add_bool_arg(self.parser, 'loop_fake')

        # nargs='+' accepts multiple suffixes; type=list would split a string into characters
        self.parser.add_argument('--img_suffixes', nargs='+', default=['png', 'jpg', 'jpeg'])

        self.parser.add_argument('--test_func', type=str, choices=['infer_on_dirs', 'infer_pairs', 'interpolate'])

        self.parser.add_argument('--input', type=str)

    def validate(self):
        super().validate()

        # if not self.args.input:
        #     raise ValueError('Input needed for inference')
        # if not Path(self.args.input).exists():
        #     raise ValueError(f'Input {self.args.input} does not exist')

    def process(self):
        self.args.train = False

        super().process()

        if self.args.output_dir:
            self.args.output_dir.mkdir(exist_ok=True, parents=True)

        # self.args.input = Path(self.args.input)
        # Split frame sit alongside input, so not every run needs to preprocess
        # input_name = self.args.input.stem
        # self.args.extracted_frames_dir = self.args.input.parent.joinpath(f'{input_name}_frames')
        # self.args.extracted_frames_dir.mkdir(exist_ok=True)


================================================
FILE: data_loader/__init__.py
================================================


================================================
FILE: data_loader/data_loader.py
================================================
import logging

import numpy as np
import tensorflow as tf

from utils.general_utils import read_image


class DataLoader(object):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.logger = logging.getLogger(self.__class__.__name__)

        self.real_dataset = args.dataset_path.joinpath(f'real')

        dataset = args.dataset_path.joinpath(f'dataset_{args.resolution}')

        self.ws_dataset = dataset.joinpath('ws')
        self.image_dataset = dataset.joinpath('images')

        max_dir = max([x.name for x in self.image_dataset.iterdir()])
        self.max_ind = max([int(x.stem) for x in self.image_dataset.joinpath(max_dir).iterdir()])
        self.train_max_ind = args.train_data_size

        if self.train_max_ind >= self.max_ind:
            self.logger.warning('There is no validation data... using training data')
            self.min_val_ind = 0
            self.train_max_ind = self.max_ind
        else:
            self.min_val_ind = self.train_max_ind + 1

    def get_image(self, is_train, black_list=None, is_real=False):
        # Avoid a mutable default argument
        if black_list is None:
            black_list = []

        max_fails = 10
        curr_fail = 0
        if is_train:
            min_ind, max_ind = 0, self.train_max_ind
        else:
            min_ind, max_ind = self.min_val_ind, self.max_ind

        while True:
            ind = np.random.randint(min_ind, max_ind)

            if ind in black_list:
                continue

            img_name = f'{ind:05d}.png'
            dir_name = f'{int(ind - ind % 1e3):05d}'
            if is_real:
                img_path = self.real_dataset.joinpath(dir_name, img_name)
            else:
                img_path = self.image_dataset.joinpath(dir_name, img_name)

            try:
                img = read_image(img_path, self.args.resolution)
                break
            except Exception as e:
                self.logger.warning(f'Failed reading image at {ind}. Error: {e}')

                # Try again with a different image...
                curr_fail += 1
                if curr_fail > max_fails:
                    raise IOError('Failed reading multiple images')
                continue

        return ind, img

    def get_w_by_ind(self, ind):
        dir_name = f'{int(ind - ind % 1e3):05d}'
        img_name = f'{ind:05d}.npy'
        w_path = self.ws_dataset.joinpath(dir_name, img_name)

        w = np.load(w_path)

        # Take one row while keeping dimension
        w = w[np.newaxis, 0]

        return w

    def get_real_w(self, is_train, black_list=None, is_real=False):
        ind = np.random.randint(0, self.max_ind)
        w = self.get_w_by_ind(ind)

        return ind, w

    def batch_samples(self, get_sample_func, is_train, black_list=None, is_real=False):
        batch = []
        indices = []

        if not black_list:
            black_list = []
        for i in range(self.args.batch_size):
            ind, sample = get_sample_func(is_train, black_list, is_real)

            batch.append(sample)
            indices.append(ind)

        batch = tf.concat(batch, 0)

        return indices, batch

    def get_batch(self, is_train=True, is_cross=False, ws=True):
        black_list = []
        id_imgs_indices, id_img = self.batch_samples(self.get_image, is_train)
        matching_ws = None

        self.logger.debug(f'ID images read: {id_imgs_indices}')
        black_list.extend(id_imgs_indices)

        if is_cross:
            # Use real attr when args say so or when testing
            is_real_attr = (is_train and self.args.train_real_attr) or (not is_train and self.args.test_real_attr)
            black_list = [] if is_real_attr else black_list

            attr_imgs_indices, attr_img = self.batch_samples(self.get_image,
                                                             is_train,
                                                             black_list=black_list,
                                                             is_real=is_real_attr)

            self.logger.debug(f'Attr images read: {attr_imgs_indices}')

        else:
            if is_train:
                attr_img = id_img
                matching_ws = [self.get_w_by_ind(ind) for ind in id_imgs_indices]
                matching_ws = tf.concat(matching_ws, 0)
            else:
                attr_img = id_img

        if not is_train:
            return attr_img, id_img

        # Only for training
        real_img = None
        real_ws = None

        if self.args.train and self.args.reals:
            real_imgs_indices, real_img = self.batch_samples(self.get_image, is_train, black_list=[], is_real=True)
            self.logger.debug(f'Real images read: {real_imgs_indices}')

        if ws:
            _, real_ws = self.batch_samples(self.get_real_w, is_train)

        return attr_img, id_img, real_ws, real_img, matching_ws
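
# Minimal usage sketch (assuming `args` is a parsed arglib.TrainArgs namespace):
#   loader = DataLoader(args)
#   attr_img, id_img, real_ws, real_img, matching_ws = loader.get_batch(is_train=True)
#   attr_img, id_img = loader.get_batch(is_train=False)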



================================================
FILE: docs/index.html
================================================
<html>
<head>
    <meta charset="utf-8">
    <title>ID disentanglement</title>

    <!-- CSS includes -->
    <link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.8.1/css/all.css"
          integrity="sha384-50oBUHEmvpQ+1lW4y57PTFmhCaXp0ML5d60M1M7uH2+nqUivzIebhndOJK28anvf" crossorigin="anonymous">
    <link href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.5/css/bootstrap.min.css" rel="stylesheet">
    <link href="mainpage.css" rel="stylesheet">
</head>
<body>

<div class="container-fluid">
    <div class="row">
        <h1><span style="font-size:36px">Face Identity Disentanglement via Latent Space Mapping</span></h1>
        <h1><span style="font-size:22px">SIGGRAPH ASIA 2020</span></h1>

        <div class="authors">
            <span style="font-size:18px"><a href="https://yotamnitzan.github.io/" target="new">Yotam Nitzan<sup>1</sup></a></span>
            &nbsp;
            <span style="font-size:18px"><a href="https://www.cs.tau.ac.il/~amberman/"
                                            target="new">Amit Bermano<sup>1</sup></a></span>
            &nbsp;
            <span style="font-size:18px"><a href="http://yangyan.li/" target="new">Yangyan Li<sup>2</sup></a></span>
            &nbsp;
            <span style="font-size:18px"><a href="https://danielcohenor.com/"
                                            target="new">Daniel Cohen-Or<sup>1</sup></a></span>
            <br>
            <span style="font-size:18px"><sup>1</sup>Tel-Aviv University &nbsp;&nbsp;&nbsp; <sup>2</sup>Alibaba Cloud Intelligence Business Group<br><br></span>
        </div>
    </div>

    <div class="row" style="text-align:center;padding:0;margin:0">
        <div class="container">
            <img src="imgs/teaser.png" height="650px">
        </div>
    </div>

    <div class="container">

        <div class="row">
            <div class="col-lg-1 col-md-0 col-sm-0"></div>
            <div class="col-lg-1 col-md-0 col-sm-0"></div>

            <div class="col-lg-3 col-md-4 col-sm-4 text-center">
                <div class="service-box mt-5 mx-auto">
                    <a href="https://arxiv.org/abs/2005.07728" target="_blank">
                        <i class="far fa-4x fa-file text-primary mb-3 "></i>
                    </a>
                    <h3 class="mb-3">Paper</h3>
                </div>
            </div>

            <div class="col-lg-1 col-md-0 col-sm-0"></div>
            <div class="col-lg-1 col-md-0 col-sm-0"></div>

            <div class="col-lg-2 col-md-4 col-sm-6 text-center">
                <div class="service-box mt-5 mx-auto">
                    <a href="https://github.com/YotamNitzan/ID-disentanglement" target="_blank">
                        <i class="fab fa-4x fa-github text-primary mb-3 "></i>
                    </a>
                    <h3 class="mb-3">Code</h3>
                </div>
            </div>

        </div>
    </div>

    <div class="container">
        <h2>Abstract</h2>
        Learning disentangled representations of data is a fundamental problem in
        artificial intelligence. Specifically, disentangled latent representations allow
        generative models to control and compose the disentangled factors in the
        synthesis process. Current methods, however, require extensive supervision
        and training, or instead, noticeably compromise quality.
        In this paper, we present a method that learns how to represent data
        in a disentangled way, with minimal supervision, manifested solely using
        available pre-trained networks. Our key insight is to decouple the processes
        of disentanglement and synthesis, by employing a leading pre-trained unconditional image generator, such as
        StyleGAN. By learning to map into its
        latent space, we leverage both its state-of-the-art quality, and its rich and
        expressive latent space, without the burden of training it.
        We demonstrate our approach on the complex and high dimensional
        domain of human heads. We evaluate our method qualitatively and quantitatively, and exhibit its success with
        de-identification operations and with
        temporal identity coherency in image sequences. Through extensive experimentation, we show that our method
        successfully disentangles identity
        from other facial attributes, surpassing existing methods, even though they
        require more training and supervision.
    </div>

    <div class="container">
        <h2>Motivation</h2>
        Learning disentangled representations and image synthesis are different tasks.
        However, it is common practice to solve both simultaneously:
        this way, the image generator learns the semantics of the representations,
        and can then take multiple representations from different sources and mix them to generate novel images.
        But this comes at a price: one now needs to solve two difficult tasks at once.
        This often requires devising dedicated architectures, and even then the visual quality is sub-optimal.
        <br><br>
        We propose a different approach. Unconditional generators have recently achieved amazing image quality.
        We take advantage of this fact and avoid solving this task ourselves. Instead, we suggest using a pretrained
        generator, such as StyleGAN. But now, how can the generator, which is pretrained and unconditional, make sense of
        the disentangled representations?
        <br>
        We suggest mapping the disentangled representations directly into the latent space of the generator.
        The mapping produces, in a single feed-forward pass, a new, never-before-seen latent code that corresponds to a novel
        image.

        <div class="row" style="text-align:center;padding:0;margin:0">
            <img src="imgs/architecture.jpg" height="512px">
        </div>


    </div>

    <div class="container">
        <h2>Composition Results</h2>

        We demonstrate our method on the domain of human faces - specifically disentangling identity from all other
        attributes.
        <br>
        In the following tables, the identity is taken from the image on top and the attributes are taken from the
        leftmost image.
        In this figure, the inputs themselves are StyleGAN generated images.
        <div class="row" style="text-align:center;padding:0;margin:0">
            <img src="imgs/table_results.jpg" height="850px">
        </div>
        <div class="space"></div>

        More results, but this time, the input images are real.
        <div class="row" style="text-align:center;padding:0;margin:0">
            <img src="imgs/ffhq_table_results.jpg" height="850px">
        </div>
    </div>

    <div class="container">
        <h2>Disentangled Interpolation</h2>

        Thanks to our disentangled representations, we are able to interpolate only a single feature
        (identity or attributes) in the generator's latent space.
        This enables more control and opens the door for new disentangled editing capabilities.
        <br><br>
        <div class="row" style="text-align:center;padding:0;margin:0">
            <img src="imgs/interpolate_attr.jpg" width="1024px">
        </div>
        <div class="space"></div>
        <div class="row" style="text-align:center;padding:0;margin:0">
            <img src="imgs/interpolate_id.jpg" width="1024px">
        </div>

    </div>

    <div class="container">
        <h2>Contact</h2>
        <div>
            yotamnitzan at gmail dot com
        </div>
    </div>

    <div id="footer">
    </div>


</body>
</html>


================================================
FILE: docs/mainpage.css
================================================
body {
  font-family: 'Lato', sans-serif;
  font-weight: 300;
  color: #333;
  font-size: 16px;
}
h1 {
  font-size: 40px;
  color: #555;
  font-weight: 400;
  text-align: center;
  margin: 0;
  padding: 0;
  margin-top: 30px;
  margin-bottom: 10px;
}
.authors {
  color: #222;
  font-size: 24px;
  font-weight: 300;
  text-align: center;
  margin: 0;
  padding: 0;
  margin-bottom: 0px;
}
.logoimg {
  text-align: center;
  margin-bottom: 30px;
}
.container-fluid {
  margin-top: 5px;
  margin-bottom: 5px;
}
.container {
  margin-top: 10px;
}
#footer {
  margin-bottom: 100px;
}
.thumbs {
  -webkit-box-shadow: 1px 1px 3px #999;
  -moz-box-shadow: 1px 1px 3px #999;
  box-shadow: 1px 1px 3px #999;
  margin-bottom: 20px;
}
h2 {
  font-size: 24px;
  font-weight: 900;
  border-bottom: 1px solid #999;
  margin-bottom: 20px;
}

.space {
   margin-bottom: 1.5cm;
}


.text-primary {
  color: #5da2d5 !important;
}
.text-primary:hover {
  color: #f3d250  !important;
  opacity: 1.0;
}

================================================
FILE: docs/setup.md
================================================
# Setup

## Environment

The code is designed for TensorFlow 2.x on Python 3.7, with CUDA 10.0 and cuDNN 7.6.5 (matching `environment.yml`).
Run `conda env create -f environment.yml` to create a conda environment with the needed dependencies.

Tested with TensorFlow 2.0.0, Python 3.7.9, and Ubuntu 14.04.
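
To sanity-check the environment after creation, a minimal snippet (note that in TF 2.0 the device listing API still lives under `tf.config.experimental`):

```
import tensorflow as tf

print(tf.__version__)  # expect 2.0.0
# List visible GPUs; an empty list means TF will fall back to the CPU
print(tf.config.experimental.list_physical_devices('GPU'))
```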


## Third-party pretrained networks

Our method relies on several pretrained networks.
Some are needed only for training and some also for inference.
Download according to your intention.

Put all downloaded files/directories under a single directory, which will
be the baseline path for all pretrained networks.

| Name | Training | Inference |Description
| :--- | :----------:| :----------:| :----------
|[FFHQ StyleGAN 256x256](https://drive.google.com/drive/folders/1OgLvUhd9FX9_mPXrfqAWaLZsceQzE9l4?usp=sharing) | :heavy_check_mark: | :heavy_check_mark:  | StyleGAN model pretrained on FFHQ with 256x256 resolution. Converted using [StyleGAN-Tensorflow2](https://github.com/YotamNitzan/StyleGAN-Tensorflow2)
|[FFHQ StyleGAN 1024x1024](https://drive.google.com/drive/folders/1jQxJsmapu6SjygvJfvP4-YVxZ9f5Hu_N?usp=sharing) | :heavy_check_mark: | :heavy_check_mark:  | StyleGAN model pretrained on FFHQ with 1024x1024 resolution. Converted using [StyleGAN-Tensorflow2](https://github.com/YotamNitzan/StyleGAN-Tensorflow2)
|[VGGFace2](https://drive.google.com/file/d/1I_JyR7LH-30hEIpD4OSFVg2TOf9Q8cqU/view?usp=sharing) | :heavy_check_mark: | :heavy_check_mark:  | Pretrained VGGFace2 model taken from [WeidiXie](https://github.com/WeidiXie/Keras-VGGFace2-ResNet50).
|[dlib landmarks model](http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2) |  | :heavy_check_mark: | dlib landmarks model, used to align images.
|[ArcFace](https://drive.google.com/drive/folders/1F-Ll9Nw7I1FGP61cpQxOdhs2nxi0E5mg?usp=sharing) | :heavy_check_mark: |   | Pretrained ArcFace model taken from [dmonterom](https://github.com/dmonterom/face_recognition_TF2).
|[Face & Landmarks Detection](https://drive.google.com/drive/folders/1D__J9UMwzBNR9eVrQGYuL9ueYGi7G4qh?usp=sharing) | :heavy_check_mark: |   | Pretrained face detection and differentiable facial landmarks detection from [610265158](https://github.com/610265158/face_landmark).


### Other StyleGANs

To try out our method with other checkpoints of StyleGAN, first obtain a trained StyleGAN pkl file using the [original StyleGAN repository](https://github.com/NVlabs/stylegan).
Next, convert it to TensorFlow 2 using this [repository](https://github.com/YotamNitzan/StyleGAN-Tensorflow2).





================================================
FILE: environment.yml
================================================
name: id_disen
channels:
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _tflow_select=2.1.0=gpu
  - absl-py=0.11.0=py37h06a4308_0
  - aiohttp=3.6.3=py37h7b6447c_0
  - astor=0.8.1=py37_0
  - async-timeout=3.0.1=py37_0
  - attrs=20.3.0=pyhd3eb1b0_0
  - blas=1.0=mkl
  - blinker=1.4=py37_0
  - blosc=1.20.1=he1b5a44_0
  - brotli=1.0.9=he1b5a44_3
  - brotlipy=0.7.0=py37h27cfd23_1003
  - bzip2=1.0.8=h516909a_3
  - c-ares=1.17.1=h27cfd23_0
  - ca-certificates=2020.11.8=ha878542_0
  - cachetools=4.1.1=py_0
  - certifi=2020.11.8=py37h89c1867_0
  - cffi=1.14.3=py37h261ae71_2
  - chardet=3.0.4=py37h06a4308_1003
  - charls=2.1.0=he1b5a44_2
  - click=7.1.2=py_0
  - cloudpickle=1.6.0=py_0
  - cryptography=3.2.1=py37h3c74f83_1
  - cudatoolkit=10.0.130=0
  - cudnn=7.6.5=cuda10.0_0
  - cupti=10.0.130=0
  - cycler=0.10.0=py_2
  - cytoolz=0.11.0=py37h8f50634_1
  - dask-core=2.30.0=py_0
  - decorator=4.4.2=py_0
  - freetype=2.10.4=h7ca028e_0
  - gast=0.2.2=py37_0
  - giflib=5.2.1=h36c2ea0_2
  - google-auth=1.23.0=pyhd3eb1b0_0
  - google-auth-oauthlib=0.4.2=pyhd3eb1b0_2
  - google-pasta=0.2.0=py_0
  - grpcio=1.31.0=py37hf8bcb03_0
  - h5py=2.10.0=py37hd6299e0_1
  - hdf5=1.10.6=hb1b8bf9_0
  - idna=2.10=py_0
  - imagecodecs=2020.5.30=py37hda6ee5b_1
  - imageio=2.9.0=py_0
  - importlib-metadata=2.0.0=py_1
  - intel-openmp=2020.2=254
  - jpeg=9d=h36c2ea0_0
  - jxrlib=1.1=h516909a_2
  - keras-applications=1.0.8=py_1
  - keras-preprocessing=1.1.0=py_1
  - kiwisolver=1.3.1=py37hc928c03_0
  - lcms2=2.11=hcbb858e_1
  - ld_impl_linux-64=2.33.1=h53a641e_7
  - libaec=1.0.4=he1b5a44_1
  - libedit=3.1.20191231=h14c3975_1
  - libffi=3.3=he6710b0_2
  - libgcc-ng=9.1.0=hdf63c60_0
  - libgfortran-ng=7.3.0=hdf63c60_0
  - libpng=1.6.37=h21135ba_2
  - libprotobuf=3.13.0.1=hd408876_0
  - libstdcxx-ng=9.1.0=hdf63c60_0
  - libtiff=4.1.0=h4f3a223_6
  - libwebp-base=1.1.0=h36c2ea0_3
  - libzopfli=1.0.3=he1b5a44_0
  - lz4-c=1.9.2=he1b5a44_3
  - markdown=3.3.3=py37h06a4308_0
  - matplotlib-base=3.3.3=py37h4f6019d_0
  - mkl=2020.2=256
  - mkl-service=2.3.0=py37he904b0f_0
  - mkl_fft=1.2.0=py37h23d657b_0
  - mkl_random=1.1.1=py37h0573a6f_0
  - multidict=4.7.6=py37h7b6447c_1
  - ncurses=6.2=he6710b0_1
  - networkx=2.5=py_0
  - numpy=1.19.2=py37h54aff64_0
  - numpy-base=1.19.2=py37hfa32c7d_0
  - oauthlib=3.1.0=py_0
  - olefile=0.46=pyh9f0ad1d_1
  - openjpeg=2.3.1=h981e76c_3
  - openssl=1.1.1h=h516909a_0
  - opt_einsum=3.1.0=py_0
  - pillow=8.0.1=py37h63a5d19_0
  - pip=20.2.4=py37h06a4308_0
  - protobuf=3.13.0.1=py37he6710b0_1
  - pyasn1=0.4.8=py_0
  - pyasn1-modules=0.2.8=py_0
  - pycparser=2.20=py_2
  - pyjwt=1.7.1=py37_0
  - pyopenssl=19.1.0=pyhd3eb1b0_1
  - pyparsing=2.4.7=pyh9f0ad1d_0
  - pysocks=1.7.1=py37_1
  - python=3.7.9=h7579374_0
  - python-dateutil=2.8.1=py_0
  - python_abi=3.7=1_cp37m
  - pywavelets=1.1.1=py37h161383b_3
  - readline=8.0=h7b6447c_0
  - requests=2.24.0=py_0
  - requests-oauthlib=1.3.0=py_0
  - rsa=4.6=py_0
  - scikit-image=0.17.2=py37h10a2094_4
  - scipy=1.5.2=py37h0b6359f_0
  - setuptools=50.3.1=py37h06a4308_1
  - six=1.15.0=py37h06a4308_0
  - snappy=1.1.8=he1b5a44_3
  - sqlite=3.33.0=h62c20be_0
  - tensorboard-plugin-wit=1.6.0=py_0
  - tensorflow=2.0.0=gpu_py37h768510d_0
  - tensorflow-base=2.0.0=gpu_py37h0ec5d1f_0
  - tensorflow-estimator=2.0.0=pyh2649769_0
  - termcolor=1.1.0=py37_1
  - tifffile=2020.11.18=pyhd8ed1ab_0
  - tk=8.6.10=hbc83047_0
  - toolz=0.11.1=py_0
  - tornado=6.1=py37h4abf009_0
  - urllib3=1.25.11=py_0
  - werkzeug=0.16.1=py_0
  - wheel=0.35.1=pyhd3eb1b0_0
  - wrapt=1.12.1=py37h7b6447c_1
  - xz=5.2.5=h7b6447c_0
  - yaml=0.2.5=h516909a_0
  - yarl=1.6.2=py37h7b6447c_0
  - zipp=3.4.0=pyhd3eb1b0_0
  - zlib=1.2.11=h7b6447c_3
  - zstd=1.4.5=h6597ccf_2
  - pip:
    - dlib==19.21.0
    - keras==2.3.1
    - mtcnn==0.1.0
    - opencv-python==4.4.0.46
    - pyyaml==5.3.1
    - tensorboard==2.0.2
    - tensorflow-addons==0.6.0
    - tensorflow-gpu==2.0.0
    - tqdm==4.53.0


================================================
FILE: inference.py
================================================
from pathlib import Path

from tqdm import tqdm
import tensorflow as tf

from utils import general_utils as utils


class Inference(object):
    def __init__(self, args, model):
        self.args = args
        self.G = model.G

    def infer_pairs(self):
        names = [f for f in self.args.id_dir.iterdir() if f.suffix[1:] in self.args.img_suffixes]
        names.extend([f for f in self.args.attr_dir.iterdir() if f.suffix[1:] in self.args.img_suffixes])

        for img_name in tqdm(names):
            id_path = utils.find_file_by_str(self.args.id_dir, img_name.stem)
            attr_path = utils.find_file_by_str(self.args.attr_dir, img_name.stem)
            if len(id_path) != 1 or len(attr_path) != 1:
                print(f'Could not find a single pair with name: {img_name.stem}')
                continue

            # find_file_by_str returns a list of matches; use the single match
            id_img = utils.read_image(id_path[0], self.args.resolution, self.args.reals)
            attr_img = utils.read_image(attr_path[0], self.args.resolution, self.args.reals)

            out_img = self.G(id_img, attr_img)[0]

            utils.save_image(out_img, self.args.output_dir.joinpath(f'{img_name.name}'))

    def infer_on_dirs(self):
        attr_paths = list(self.args.attr_dir.iterdir())
        attr_paths.sort()

        id_paths = list(self.args.id_dir.iterdir())
        id_paths.sort()

        for attr_num, attr_img_path in tqdm(enumerate(attr_paths)):
            if not attr_img_path.is_file() or attr_img_path.suffix[1:] not in self.args.img_suffixes:
                continue

            attr_img = utils.read_image(attr_img_path, self.args.resolution, self.args.reals)

            attr_dir = self.args.output_dir.joinpath(f'attr_{attr_num}')
            attr_dir.mkdir(exist_ok=True)

            utils.save_image(attr_img, attr_dir.joinpath(f'attr_image.png'))

            for id_num, id_img_path in enumerate(id_paths):
                if not id_img_path.is_file() or id_img_path.suffix[1:] not in self.args.img_suffixes:
                    continue

                id_img = utils.read_image(id_img_path, self.args.resolution, self.args.reals)

                pred = self.G(id_img, attr_img)[0]

                utils.save_image(pred, attr_dir.joinpath(f'prediction_{id_num}.png'))
                utils.save_image(id_img, attr_dir.joinpath(f'id_{id_num}.png'))

    def interpolate(self, w_space=True):
        # Endpoints are 0,1 for plain interpolation; widen the range to extrapolate
        extra_start = 0
        extra_end = 1
        L = extra_end - extra_start
        # The sampled values include 0 and 1 iff:
        #   N - 1 is divisible by L when the endpoint is included,
        #   N is divisible by L otherwise,
        #   where L is the length of the range (L = b - a for [a, b])
        #   and N is the number of jumps
        num_jumps = 8 * L + 1

        for d in self.args.input_dir.iterdir():
            out_d = self.args.output_dir.joinpath(d.name)
            out_d.mkdir(exist_ok=True)

            ids = list(d.glob('*id*'))
            attrs = list(d.glob('*attr*'))

            if len(ids) == 1 and len(attrs) == 2:
                const = 'id'
            elif len(ids) == 2 and len(attrs) == 1:
                const = 'attr'
            else:
                print(f'Wrong data format for {d.name}')
                continue

            if const == 'id':
                start_img = utils.read_image(attrs[0], self.args.resolution, self.args.real_attr)
                end_img = utils.read_image(attrs[1], self.args.resolution, self.args.real_attr)
                const_img = utils.read_image(ids[0], self.args.resolution, self.args.real_id)

                if self.args.loop_fake:
                    if not self.args.real_attr:
                        start_img = self.G(start_img, start_img)
                        end_img = self.G(end_img, end_img)
                    if not self.args.real_id:
                        const_img = self.G(const_img, const_img)

                const_id = self.G.id_encoder(const_img)
                start_attr = self.G.attr_encoder(start_img)
                end_attr = self.G.attr_encoder(end_img)

                s_z = tf.concat([const_id, start_attr], -1)
                e_z = tf.concat([const_id, end_attr], -1)

            elif const == 'attr':
                start_img = utils.read_image(ids[0], self.args.resolution, self.args.real_id)
                end_img = utils.read_image(ids[1], self.args.resolution, self.args.real_id)
                const_img = utils.read_image(attrs[0], self.args.resolution, self.args.real_attr)

                if self.args.loop_fake:
                    if not self.args.real_attr:
                        const_img = self.G(const_img, const_img)[0]
                    if not self.args.real_id:
                        start_img = self.G(start_img, start_img)[0]
                        end_img = self.G(end_img, end_img)[0]

                start_id = self.G.id_encoder(start_img)
                end_id = self.G.id_encoder(end_img)

                const_attr = self.G.attr_encoder(const_img)

                s_z = tf.concat([start_id, const_attr], -1)
                e_z = tf.concat([end_id, const_attr], -1)


            utils.save_image(const_img, out_d.joinpath(f'const_{const}.png'))
            utils.save_image(start_img, out_d.joinpath(f'start.png'))
            utils.save_image(end_img, out_d.joinpath(f'end.png'))

            if w_space:
                s_w = self.G.latent_spaces_mapping(s_z)
                e_w = self.G.latent_spaces_mapping(e_z)
                for i in range(num_jumps):
                    inter_w = (1 - i / num_jumps) * s_w + (i / num_jumps) * e_w
                    out = self.G.stylegan_s(inter_w)
                    out = (out + 1) / 2
                    utils.save_image(out[0],
                                     out_d.joinpath(f'inter_{i:03}.png'))
            else:
                for i in range(num_jumps):
                    inter_z = (1 - i / num_jumps) * s_z + (i / num_jumps) * e_z
                    inter_w = self.G.latent_spaces_mapping(inter_z)
                    out = self.G.stylegan_s(inter_w)
                    out = (out + 1) / 2
                    utils.save_image(out[0],
                                     out_d.joinpath(f'inter_{i:03}.png'))



================================================
FILE: main.py
================================================
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['USE_SIMPLE_THREADED_LEVEL3'] = '1'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

import sys
import logging
from model.stylegan import StyleGAN_G_synthesis
from model.model import Network
from data_loader.data_loader import DataLoader
from writer import Writer
from trainer import Trainer
from arglib import arglib
from utils import general_utils as utils

sys.path.insert(0, 'model/face_utils')


def init_logger(args):
    root_logger = logging.getLogger()

    level = logging.DEBUG if args.log_debug else logging.INFO
    root_logger.setLevel(level)

    file_handler = logging.FileHandler(f'{args.results_dir}/log.txt')
    console_handler = logging.StreamHandler()

    datefmt = '%Y-%m-%d %H:%M:%S'
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt)

    file_handler.setLevel(level)
    console_handler.setLevel(level)

    file_handler.setFormatter(formatter)
    console_handler.setFormatter(formatter)

    root_logger.addHandler(file_handler)
    root_logger.addHandler(console_handler)

    pil_logger = logging.getLogger('PIL.PngImagePlugin')
    pil_logger.setLevel(logging.INFO)


def main():
    train_args = arglib.TrainArgs()
    args, str_args = train_args.args, train_args.str_args
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    init_logger(args)

    logger = logging.getLogger('main')

    cmd_line = ' '.join(sys.argv)
    logger.info(f'cmd line is: \n {cmd_line}')

    logger.info(str_args)
    logger.debug('Copying src to results dir')

    Writer.set_writer(args.results_dir)

    if not args.debug:
        description = input('Please write a short description of this run\n')
        desc_file = args.results_dir.joinpath('description.txt')
        with desc_file.open('w') as f:
            f.write(description)

    id_model_path = args.pretrained_models_path.joinpath('vggface2.h5')
    stylegan_G_synthesis_path = str(
        args.pretrained_models_path.joinpath(f'stylegan_G_{args.resolution}x{args.resolution}_synthesis'))
    landmarks_model_path = str(args.pretrained_models_path.joinpath('face_utils/keypoints'))
    face_detection_model_path = str(args.pretrained_models_path.joinpath('face_utils/detector'))

    arcface_model_path = str(args.pretrained_models_path.joinpath('arcface_weights/weights-b'))
    utils.landmarks_model_path = str(args.pretrained_models_path.joinpath('shape_predictor_68_face_landmarks.dat'))

    stylegan_G_synthesis = StyleGAN_G_synthesis(resolution=args.resolution, is_const_noise=args.const_noise)
    stylegan_G_synthesis.load_weights(stylegan_G_synthesis_path)

    network = Network(args, id_model_path, stylegan_G_synthesis, landmarks_model_path,
                      face_detection_model_path, arcface_model_path)
    data_loader = DataLoader(args)

    trainer = Trainer(args, network, data_loader)
    trainer.train()


if __name__ == '__main__':
    main()


================================================
FILE: model/__init__.py
================================================


================================================
FILE: model/arcface/arcface.py
================================================
import tensorflow as tf
import math

num_classes = 85742  # 10572
initializer = 'glorot_normal'
# initializer = tf.keras.initializers.TruncatedNormal(
#     mean=0.0, stddev=0.05, seed=None)
# initializer = tf.keras.initializers.VarianceScaling(
#     scale=0.05, mode='fan_avg', distribution='normal', seed=None)


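# ArcFace applies an additive angular margin to the target-class logit: the true
# class is scored s * cos(theta + m) while all other classes keep s * cos(theta).
# call() below expands this via cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m).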
class Arcfacelayer(tf.keras.layers.Layer):
    def __init__(self, output_dim=num_classes, s=64., m=0.50):
        self.output_dim = output_dim
        self.s = s
        self.m = m
        super(Arcfacelayer, self).__init__()

    def build(self, input_shape):
        self.kernel = self.add_weight(name='kernel',
                                      shape=(input_shape[-1],
                                             self.output_dim),
                                      initializer=initializer,
                                      regularizer=tf.keras.regularizers.l2(
                                          l=5e-4),
                                      trainable=True)
        super(Arcfacelayer, self).build(input_shape)

    def call(self, embedding, labels):
        cos_m = math.cos(self.m)
        sin_m = math.sin(self.m)
        mm = sin_m * self.m  # issue 1
        threshold = math.cos(math.pi - self.m)
        # inputs and weights norm
        embedding_norm = tf.norm(embedding, axis=1, keepdims=True)
        embedding = embedding / embedding_norm
        weights_norm = tf.norm(self.kernel, axis=0, keepdims=True)
        weights = self.kernel / weights_norm
        # cos(theta+m)
        cos_t = tf.matmul(embedding, weights, name='cos_t')
        cos_t2 = tf.square(cos_t, name='cos_2')
        sin_t2 = tf.subtract(1., cos_t2, name='sin_2')
        sin_t = tf.sqrt(sin_t2, name='sin_t')
        cos_mt = self.s * tf.subtract(tf.multiply(cos_t, cos_m),
                                      tf.multiply(sin_t, sin_m), name='cos_mt')

        # this condition controls the theta+m should in range [0, pi]
        #      0<=theta+m<=pi
        #     -m<=theta<=pi-m
        cond_v = cos_t - threshold
        cond = tf.cast(tf.nn.relu(cond_v, name='if_else'), dtype=tf.bool)

        keep_val = self.s * (cos_t - mm)
        cos_mt_temp = tf.where(cond, cos_mt, keep_val)

        mask = tf.one_hot(labels, depth=self.output_dim, name='one_hot_mask')
        # mask = tf.squeeze(mask, 1)
        inv_mask = tf.subtract(1., mask, name='inverse_mask')

        s_cos_t = tf.multiply(self.s, cos_t, name='scalar_cos_t')

        output = tf.add(tf.multiply(s_cos_t, inv_mask), tf.multiply(
            cos_mt_temp, mask), name='arcface_loss_output')

        return output

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)


================================================
FILE: model/arcface/inference.py
================================================
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
import cv2
from model.arcface.resnet import ResNet50, train_model
from mtcnn import MTCNN
from skimage import transform as trans


class MyArcFace:
    def __init__(self, path_to_weights):
        self.model = train_model()
        self.model.load_weights(path_to_weights)
        self.model_resnet = self.model.resnet
        self.model.resnet.trainable = False
        self.mtcnn = MTCNN(min_face_size=80)

    def get_best_face(self, faces, resolution):
        if len(faces) == 0:
            raise IndexError('No faces found')
        if len(faces) == 1:
            return faces[0]

        print('Found more than one face')

        indices = list(range(len(faces)))

        # filter low confidence
        new_indices = [ind for ind in indices if faces[ind]['confidence'] > 0.99]
        # print(f'after confidence filtering: {len(new_indices)}')
        if len(new_indices) == 1:
            return faces[new_indices[0]]
        elif len(new_indices) > 1:
            indices = new_indices

        # filter faces that are not centered; the distance between x and y must be relatively small
        new_indices = [ind for ind in indices if np.abs(faces[ind]['box'][0] - faces[ind]['box'][1]) < resolution / 2.5]
        # print(f'after center filtering: {len(new_indices)}')
        if len(new_indices) == 1:
            return faces[new_indices[0]]
        elif len(new_indices) > 1:
            indices = new_indices

        # Take box with biggest height
        ind = max(indices, key=lambda ind: faces[ind]['box'][-1])
        return faces[ind]

    def __detect_face(self, img):
        # The assumption is that the image is RGB
        faces = self.mtcnn.detect_faces(img)
        face_obj = self.get_best_face(faces, img.shape[0])

        face_box_obj = face_obj['box']
        face_landmarks_obj = face_obj['keypoints']
        face_landmarks = np.zeros((5, 2))
        face_landmarks[0] = [face_landmarks_obj['left_eye'][0], face_landmarks_obj['right_eye'][1]]
        face_landmarks[1] = [face_landmarks_obj['right_eye'][0], face_landmarks_obj['left_eye'][1]]
        face_landmarks[2] = [face_landmarks_obj['nose'][0], face_landmarks_obj['nose'][1]]
        face_landmarks[3] = [face_landmarks_obj['mouth_left'][0], face_landmarks_obj['mouth_right'][1]]
        face_landmarks[4] = [face_landmarks_obj['mouth_right'][0], face_landmarks_obj['mouth_left'][1]]
        x = face_box_obj[0]
        y = face_box_obj[1]
        w = face_box_obj[2]
        h = face_box_obj[3]
        face_box = [x, y, x + w, y + h]
        return face_box, face_landmarks

    def __preprocess(self, img, bbox=None, landmark=None):
        M = None
        image_size = [112, 112]
        assert landmark is not None
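        # Canonical ArcFace 5-point template (eyes, nose, mouth corners) for
        # 112x112 crops; the detected landmarks are warped onto these points.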
        src = np.array([
            [30.2946, 51.6963],
            [65.5318, 51.5014],
            [48.0252, 71.7366],
            [33.5493, 92.3655],
            [62.7299, 92.2041]], dtype=np.float32)
        if image_size[1] == 112:
            src[:, 0] += 8.0
        dst = landmark.astype(np.float32)
        tform = trans.SimilarityTransform()
        tform.estimate(src, dst)
        M = tform.params
        assert M is not None
        transforms = np.array(M).flatten()[:-1]
        tf_transforms = tf.constant([transforms], tf.float32)
        img_tensor = tf.convert_to_tensor(img.astype(np.float32))
        batch = tf.stack([img_tensor])
        output = tfa.image.transform(batch, tf_transforms, interpolation='BILINEAR', output_shape=image_size)
        return output

    def process_image(self, img):

        if (isinstance(img, tf.Tensor) and img.dtype != tf.dtypes.uint8) or img.dtype != np.uint8:
            img = np.uint8(img * 255)

        face_box, face_landmarks = self.__detect_face(img)
        aligned_face = self.__preprocess(img, face_box, face_landmarks)
        aligned_face -= 127.5
        aligned_face *= 0.0078125
        embeddings = self.model_resnet(aligned_face)
        normalized_embeddings = tf.math.l2_normalize(embeddings)
        return normalized_embeddings

    def __call__(self, img):
        if img.ndim == 4:
            embedding_list = []
            for x in img:
                norm_embedding = self.process_image(x)
                embedding_list.append(norm_embedding)
            return np.array(embedding_list)
        else:
            return self.process_image(img)
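
# Minimal usage sketch (the weights path follows main.py and is an assumption):
#   arcface = MyArcFace('pretrained/arcface_weights/weights-b')
#   embedding = arcface(img)  # L2-normalized 512-d identity embedding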


================================================
FILE: model/arcface/resnet.py
================================================
import tensorflow as tf
import os
from model.arcface.arcface import Arcfacelayer

bn_axis = -1
initializer = 'glorot_normal'


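# Pre-activation residual unit ("v3" in the ArcFace reference code):
# BN -> 3x3 conv -> BN -> PReLU -> strided 3x3 conv -> BN, summed with an
# identity shortcut when dim_match, else a strided 1x1 projection shortcut.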
def residual_unit_v3(input, num_filter, stride, dim_match, name):
    x = tf.keras.layers.BatchNormalization(axis=bn_axis,
                                           scale=True,
                                           momentum=0.9,
                                           epsilon=2e-5,
                                           #    beta_regularizer=tf.keras.regularizers.l2(
                                           #        l=5e-4),
                                           gamma_regularizer=tf.keras.regularizers.l2(
                                               l=5e-4),
                                           name=name + '_bn1')(input)
    x = tf.keras.layers.ZeroPadding2D(
        padding=(1, 1), name=name + '_conv1_pad')(x)
    x = tf.keras.layers.Conv2D(num_filter, (3, 3),
                               strides=(1, 1),
                               padding='valid',
                               kernel_initializer=initializer,
                               use_bias=False,
                               kernel_regularizer=tf.keras.regularizers.l2(
                                   l=5e-4),
                               name=name + '_conv1')(x)
    x = tf.keras.layers.BatchNormalization(axis=bn_axis,
                                           scale=True,
                                           momentum=0.9,
                                           epsilon=2e-5,
                                           #    beta_regularizer=tf.keras.regularizers.l2(
                                           #        l=5e-4),
                                           gamma_regularizer=tf.keras.regularizers.l2(
                                               l=5e-4),
                                           name=name + '_bn2')(x)
    x = tf.keras.layers.PReLU(name=name + '_relu1',
                              alpha_regularizer=tf.keras.regularizers.l2(
                                  l=5e-4))(x)
    x = tf.keras.layers.ZeroPadding2D(
        padding=(1, 1), name=name + '_conv2_pad')(x)
    x = tf.keras.layers.Conv2D(num_filter, (3, 3),
                               strides=stride,
                               padding='valid',
                               kernel_initializer=initializer,
                               use_bias=False,
                               kernel_regularizer=tf.keras.regularizers.l2(
                                   l=5e-4),
                               name=name + '_conv2')(x)
    x = tf.keras.layers.BatchNormalization(axis=bn_axis,
                                           scale=True,
                                           momentum=0.9,
                                           epsilon=2e-5,
                                           #    beta_regularizer=tf.keras.regularizers.l2(
                                           #        l=5e-4),
                                           gamma_regularizer=tf.keras.regularizers.l2(
                                               l=5e-4),
                                           name=name + '_bn3')(x)
    if (dim_match):
        shortcut = input
    else:
        shortcut = tf.keras.layers.Conv2D(num_filter, (1, 1),
                                          strides=stride,
                                          padding='valid',
                                          kernel_initializer=initializer,
                                          use_bias=False,
                                          kernel_regularizer=tf.keras.regularizers.l2(
                                              l=5e-4),
                                          name=name + '_conv1sc')(input)
        shortcut = tf.keras.layers.BatchNormalization(axis=bn_axis,
                                                      scale=True,
                                                      momentum=0.9,
                                                      epsilon=2e-5,
                                                      #   beta_regularizer=tf.keras.regularizers.l2(
                                                      #       l=5e-4),
                                                      gamma_regularizer=tf.keras.regularizers.l2(
                                                          l=5e-4),
                                                      name=name + '_sc')(shortcut)
    return x + shortcut


def get_fc1(input):
    x = tf.keras.layers.BatchNormalization(axis=bn_axis,
                                           scale=True,
                                           momentum=0.9,
                                           epsilon=2e-5,
                                           #    beta_regularizer=tf.keras.regularizers.l2(
                                           #        l=5e-4),
                                           gamma_regularizer=tf.keras.regularizers.l2(
                                               l=5e-4),
                                           name='bn1')(input)
    x = tf.keras.layers.Dropout(0.4)(x)
    resnet_shape = input.shape
    x = tf.keras.layers.Reshape(
        [resnet_shape[1] * resnet_shape[2] * resnet_shape[3]], name='reshapelayer')(x)
    x = tf.keras.layers.Dense(512,
                              name='E_DenseLayer', kernel_initializer=initializer,
                              kernel_regularizer=tf.keras.regularizers.l2(
                                  l=5e-4),
                              bias_regularizer=tf.keras.regularizers.l2(
                                  l=5e-4))(x)
    x = tf.keras.layers.BatchNormalization(axis=-1,
                                           scale=False,
                                           momentum=0.9,
                                           epsilon=2e-5,
                                           #    beta_regularizer=tf.keras.regularizers.l2(
                                           #        l=5e-4),
                                           name='fc1')(x)
    return x


def ResNet50():

    input_shape = [112, 112, 3]
    filter_list = [64, 64, 128, 256, 512]
    units = [3, 4, 14, 3]
    num_stages = 4

    img_input = tf.keras.layers.Input(shape=input_shape)

    x = tf.keras.layers.ZeroPadding2D(
        padding=(1, 1), name='conv0_pad')(img_input)
    x = tf.keras.layers.Conv2D(64, (3, 3),
                               strides=(1, 1),
                               padding='valid',
                               kernel_initializer=initializer,
                               use_bias=False,
                               kernel_regularizer=tf.keras.regularizers.l2(
                                   l=5e-4),
                               name='conv0')(x)
    x = tf.keras.layers.BatchNormalization(axis=bn_axis,
                                           scale=True,
                                           momentum=0.9,
                                           epsilon=2e-5,
                                           #    beta_regularizer=tf.keras.regularizers.l2(
                                           #        l=5e-4),
                                           gamma_regularizer=tf.keras.regularizers.l2(
                                               l=5e-4),
                                           name='bn0')(x)
    # x = tf.keras.layers.Activation('prelu')(x)
    x = tf.keras.layers.PReLU(
        name='prelu0',
        alpha_regularizer=tf.keras.regularizers.l2(
            l=5e-4))(x)

    for i in range(num_stages):
        x = residual_unit_v3(x, filter_list[i + 1], (2, 2), False,
                             name='stage%d_unit%d' % (i + 1, 1))
        for j in range(units[i] - 1):
            x = residual_unit_v3(x, filter_list[i + 1], (1, 1),
                                 True, name='stage%d_unit%d' % (i + 1, j + 2))

    x = get_fc1(x)

    # Create model.
    model = tf.keras.models.Model(img_input, x, name='resnet50')
    model.trainable = True
    for i in range(len(model.layers)):
        model.layers[i].trainable = True
        # if ('conv0' in model.layers[i].name):
        #     model.layers[i].trainable = False
        # if ('bn0' in model.layers[i].name):
        #     model.layers[i].trainable = False
        # if ('prelu0' in model.layers[i].name):
        #     model.layers[i].trainable = False
        # if ('stage1' in model.layers[i].name):
        #     model.layers[i].trainable = False
        # if ('stage2' in model.layers[i].name):
        #     model.layers[i].trainable = False
        # if ('stage3' in model.layers[i].name):
        #     model.layers[i].trainable = False
        # if ('stage4' in model.layers[i].name):
        #     model.layers[i].trainable = False

    return model


class train_model(tf.keras.Model):
    def __init__(self):
        super(train_model, self).__init__()
        self.resnet = ResNet50()
        self.arcface = Arcfacelayer()

    def call(self, x, y):
        x = self.resnet(x)
        return self.arcface(x, y)
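

# Usage sketch (illustrative, not part of the original pipeline): ResNet50()
# maps a batch of 112x112 RGB face crops to 512-d embeddings (the 'fc1'
# output above). Guarded so importing this module has no side effects.
if __name__ == '__main__':
    backbone = ResNet50()
    dummy = tf.random.uniform([2, 112, 112, 3])  # stand-in for face crops
    print(backbone(dummy).shape)  # expected: (2, 512)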


================================================
FILE: model/attr_encoder.py
================================================
import logging

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input


class AttrEncoder(Model):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.logger = logging.getLogger(__class__.__name__)

        attr_encoder = InceptionV3(include_top=False, pooling='avg')
        self.model = attr_encoder

        if self.args.load_checkpoint:
            self.model.load_weights(str(self.args.load_checkpoint.joinpath(self.__class__.__name__ + '.h5')))

    @tf.function
    def call(self, input_x):
        x = tf.image.resize(input_x, (299, 299))
        x = preprocess_input(255 * x)
        x = self.model(x)
        x = tf.expand_dims(x, 1)

        return x

    def my_save(self, reason=''):
        self.model.save_weights(str(self.args.weights_dir.joinpath(self.__class__.__name__ + reason + '.h5')))
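

# Usage sketch (illustrative; `args` is stubbed here, the real one comes from
# arglib). Inputs are images in [0, 1]; the output is a (batch, 1, 2048)
# attribute code (InceptionV3 global-average-pool features). Note that
# InceptionV3 downloads its ImageNet weights on first use.
if __name__ == '__main__':
    from types import SimpleNamespace
    enc = AttrEncoder(SimpleNamespace(load_checkpoint=None))
    x = tf.random.uniform([2, 256, 256, 3])
    print(enc(x).shape)  # expected: (2, 1, 2048)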


================================================
FILE: model/discriminator.py
================================================
import tensorflow as tf
from tensorflow.keras import layers, Model
from utils.general_utils import get_weights


# Latent discriminator: tells our mapped w vectors apart from StyleGAN's own w's
class W_D(Model):
    def __init__(self, args):
        super().__init__()
        self.args = args
        slope = 0.2

        # self.linear1 = layers.Dense(512, kernel_initializer=get_weights(slope), input_shape=(512,))
        self.linear2 = layers.Dense(256, kernel_initializer=get_weights(slope), input_shape=(512,))
        self.linear3 = layers.Dense(128, kernel_initializer=get_weights(slope))
        self.linear4 = layers.Dense(64, kernel_initializer=get_weights(slope))
        self.linear5 = layers.Dense(1, kernel_initializer=get_weights(slope))
        self.relu = layers.LeakyReLU(slope)

        if self.args.load_checkpoint:
            self.build(input_shape=(1, 1, 512))
            self.load_weights(str(self.args.load_checkpoint.joinpath(self.__class__.__name__ + '.h5')))

    @tf.function
    def call(self, x):
        # x = self.linear1(x)
        # x = self.relu(x)
        x = self.linear2(x)
        x = self.relu(x)
        x = self.linear3(x)
        x = self.relu(x)
        x = self.linear4(x)
        x = self.relu(x)
        x = self.linear5(x)

        return x

    def my_save(self, reason=''):
        self.save_weights(str(self.args.weights_dir.joinpath(self.__class__.__name__ + reason + '.h5')))
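

# Usage sketch (illustrative; `args` is stubbed). W_D scores 512-d latent
# vectors and returns raw logits (the trainer applies
# BinaryCrossentropy(from_logits=True) on top).
if __name__ == '__main__':
    from types import SimpleNamespace
    w_d = W_D(SimpleNamespace(load_checkpoint=None))
    print(w_d(tf.random.normal([4, 512])).shape)  # expected: (4, 1)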


================================================
FILE: model/face_detector.py
================================================
import tensorflow as tf


class FaceDetector(object):
    def __init__(self, args, model_path):
        super().__init__()
        self.args = args
        self.model_path = model_path
        self.model = None

    def _build(self):
        if not self.model:
            self.model = tf.saved_model.load(self.model_path)

    def __call__(self, input_x):
        """
        Given a batch of images, return the face bounding box in (x1,y1,x2,y2) format
        """

        if not self.model:
            self._build()

        boxes = []
        for sample in input_x:
            boxes.append(self.sample_call(sample))

        boxes = tf.stack(boxes, axis=0)
        boxes = boxes * self.args.resolution

        return boxes

    def sample_call(self, input_x):
        boxes = self.model.inference(tf.expand_dims(input_x, axis=0))
        boxes = tf.squeeze(boxes)
        indices, scores = \
            tf.image.non_max_suppression_with_scores(boxes[..., :4], boxes[..., 4],
                                                     max_output_size=1, iou_threshold=0.3, score_threshold=0.5)
        i = indices.numpy()[0]
        box = boxes[i, :4]
        return box
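

# Usage sketch (illustrative; the path is hypothetical). The SavedModel is
# assumed to expose an `inference` method returning rows of
# (x1, y1, x2, y2, score) in normalized coordinates, which __call__ rescales
# by args.resolution:
#
#     detector = FaceDetector(args, 'pretrained/face_detector')
#     boxes = detector(images)  # (batch, 4) boxes in pixel coordinates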


================================================
FILE: model/generator.py
================================================
import logging

from model import id_encoder
from model import attr_encoder
from model import latent_mapping
from model import landmarks
from model.arcface.inference import MyArcFace

import tensorflow as tf
from tensorflow.keras import layers, Model


class G(Model):
    def __init__(self, args, id_model_path, image_G,
                 landmarks_net_path, face_detection_model_path, test_id_model_path):

        super().__init__()
        self.args = args
        self.logger = logging.getLogger(__class__.__name__)

        self.id_encoder = id_encoder.IDEncoder(args, id_model_path)
        self.id_encoder.trainable = False

        self.attr_encoder = attr_encoder.AttrEncoder(args)

        self.latent_spaces_mapping = latent_mapping.LatentMappingNetwork(args)

        self.stylegan_s = image_G
        self.stylegan_s.trainable = False

        if args.train:
            self.test_id_encoder = MyArcFace(test_id_model_path)
            self.test_id_encoder.trainable = False

            self.landmarks = landmarks.LandmarksDetector(args, landmarks_net_path, face_detection_model_path)
            self.landmarks.trainable = False

    @tf.function
    def call(self, x1, x2):
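        # Pipeline: x1 -> identity embedding, x2 -> 2048-d attribute embedding
        # (InceptionV3 pooling); concatenated into the 2560-d z' expected by
        # the latent mapping, whose W+ output drives StyleGAN synthesis.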
        id_embedding = self.id_encoder(x1)

        if self.args.train:
            lnds = self.landmarks(x2)
        else:
            lnds = None
        attr_input = x2

        attr_out = self.attr_encoder(attr_input)
        attr_embedding = attr_out
        z_tag = tf.concat([id_embedding, attr_embedding], -1)
        w = self.latent_spaces_mapping(z_tag)

        out = self.stylegan_s(w)

        # Move to roughly [0,1]
        out = (out + 1) / 2

        return out, id_embedding, attr_out, w[:, 0, :], lnds

    def my_save(self, reason=''):
        self.attr_encoder.my_save(reason)
        self.latent_spaces_mapping.my_save(reason)




================================================
FILE: model/id_encoder.py
================================================
import tensorflow as tf
import numpy as np
from tensorflow.keras import Model

class IDEncoder(Model):

    def __init__(self, args, model_path, intermediate_layers_names=None):
        super().__init__()
        self.args = args
        self.mean = (91.4953, 103.8827, 131.0912)
        base_model = tf.keras.models.load_model(model_path)

        if intermediate_layers_names:
            outputs = [base_model.get_layer(name).output for name in intermediate_layers_names]
        else:
            outputs = []

        # Add output of the network in any case
        outputs.append(base_model.layers[-2].output)

        self.model = tf.keras.Model(base_model.inputs, outputs)


    def crop_faces(self, img):
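        # Note: this helper is not used by call() below. It assumes an MTCNN
        # detector has been assigned to self.mtcnn elsewhere (it is never set
        # in __init__) and falls back to a fixed crop if detection fails.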
        ps = []
        for i in range(img.shape[0]):
            oneimg = img[i]
            try:
                box = tf.numpy_function(self.mtcnn.detect_faces, [oneimg], np.uint8)
                box = [z.numpy() for z in box[:4]]

                x1, y1, w, h = box

                x_expand = w * 0.3
                y_expand = h * 0.3

                x1 = int(np.maximum(x1 - x_expand // 2, 0))
                y1 = int(np.maximum(y1 - y_expand // 2, 0))

                x2 = int(np.minimum(x1 + w + x_expand // 2, self.args.resolution))
                y2 = int(np.minimum(y1 + h + y_expand // 2, self.args.resolution))
            except Exception as e:
                x1, y1, x2, y2 = 24, 50, 224, 250

            p = oneimg[y1:y2, x1:x2, :]
            p = tf.convert_to_tensor(p)
            p = tf.image.resize(p, (self.args.resolution, self.args.resolution))
            ps.append(p)

        ps = tf.stack(ps, 0)
        return ps

    def preprocess(self, img):
        """
        In VGGFace2 The preprocessing is:
            1. Face detection
            2. Expand bbox by factor of 0.3
            3. Resize so shorter side is 256
            4. Crop center 224x224

        In StyleGAN faces are not in-the-wild, we get an image of the head.
        Just cropping a loose center instead of face detection
        """

        # Go from [0, 1] to [0, 255]
        img = 255 * img

        min_x = int(0.1 * self.args.resolution)
        max_x = int(0.9 * self.args.resolution)
        min_y = int(0.1 * self.args.resolution)
        max_y = int(0.9 * self.args.resolution)

        img = img[:, min_x:max_x, min_y:max_y, :]
        img = tf.image.resize(img, (256, 256))

        start = (256 - 224) // 2
        img = img[:, start: 224 + start, start: 224 + start, :]
        img = img[:, :, :, ::-1] - self.mean

        return img

    @tf.function
    def call(self, input_x, get_intermediate=False):
        x = self.preprocess(input_x)
        x = self.model(x)

        if isinstance(x, list):
            embedding = x[-1]
            intermediates = x[:-1]
        else:
            embedding = x
            intermediates = None

        embedding = tf.math.l2_normalize(embedding, axis=-1)
        embedding = tf.expand_dims(embedding, 1)

        if get_intermediate and intermediates:
            return embedding, intermediates
        else:
            return embedding


================================================
FILE: model/landmarks.py
================================================
import cv2
import tensorflow as tf
import numpy as np
from tensorflow.keras import Model

from utils import general_utils as utils
from model.face_detector import FaceDetector


class LandmarksDetector(Model):
    def __init__(self, args, model_path, face_detection_model_path):
        super().__init__()
        self.args = args
        self.face_detector = FaceDetector(args, face_detection_model_path)
        self.expand_ratio = 0.2

        # Load without source code
        self.model = tf.saved_model.load(model_path)

    # Preprocess
    def preprocess(self, imgs, face_detection=False):
        imgs *= 255
        if face_detection:
            imgs, details = self.hard_preprocess(imgs)
        else:
            imgs, details = self.lazy_preprocess(imgs)

        return imgs, details

    def lazy_preprocess(self, imgs):
        imgs = tf.image.resize(imgs, (160, 160))
        return imgs, 160

    def hard_preprocess(self, imgs):
        bboxes = self.face_detector(imgs)

        centers = np.array([bboxes[:, 0] + bboxes[:, 2], bboxes[:, 1] + bboxes[:, 3]]).T // 2

        # Duplicate center point into column order of x,x,y,y
        centers = np.repeat(centers, repeats=2, axis=1)

        # Permute columns order into x,y,x,y
        centers[:] = utils.np_permute(centers, [0, 2, 1, 3])

        # Calculate widths of current bboxes
        widths = np.transpose([bboxes[:, 2] - bboxes[:, 0]])

        # Calculate the maximal expansion
        max_expand = int(np.ceil(np.max(widths) * self.expand_ratio))

        # Pad the image with the maximal expansion.
        # Useful in case an expanded bounding box goes outside image
        paddings = tf.constant([[0, 0], [max_expand, max_expand], [max_expand, max_expand], [0, 0]])
        pad_imgs = tf.pad(imgs, paddings, mode='CONSTANT', constant_values=127.)

        # The size of the new square bounding box
        new_scales = np.floor((1 + 2 * self.expand_ratio) * widths)

        # Size of step from the center
        new_half_scales = new_scales // 2

        # Repeat step in all directions
        # Decrease in start point, Increase in end point
        new_half_scales = np.repeat(new_half_scales, repeats=4, axis=1) * [-1, -1, 1, 1]

        # Bounding boxes in respect to padded image
        new_bboxes = centers + new_half_scales + max_expand

        # tf.image.crop_and_resize requires bounding boxes to be normalized
        # i.e., between [0,1] and also in order (y,x)
        normed_bboxes = utils.np_permute(new_bboxes, [1, 0, 3, 2]) / pad_imgs.shape[1]

        cropped_imgs = tf.image.crop_and_resize(pad_imgs, normed_bboxes,
                                                box_indices=range(self.args.batch_size), crop_size=(160, 160))

        details = (new_scales, new_bboxes[:,:2], max_expand)
        return cropped_imgs, details

    # Postprocess
    def postprocess(self, landmarks, details, face_detection=False):
        landmarks = tf.reshape(landmarks, [-1, 68, 2])

        if face_detection:
            return self.hard_postprocess(landmarks, details)
        else:
            return self.lazy_postprocess(landmarks, details)

    def lazy_postprocess(self, batch_lnds, details):
        scale = details
        return scale * batch_lnds

    def hard_postprocess(self, batch_lnds, details):
        scale, from_origin, pad = details

        scale = tf.broadcast_to(scale, [scale.shape[0], 2])
        scale = tf.expand_dims(scale, axis=1)

        from_origin = tf.expand_dims(from_origin, axis=1)
        from_origin = tf.cast(from_origin, tf.dtypes.float32)

        lnds = batch_lnds * scale + from_origin - pad
        return lnds

    @tf.function
    def call(self, input_x, face_detection=False):

        # The network expects images in the uint8 range (0-255) but with float32 dtype.
        x, details = self.preprocess(input_x, face_detection)

        batch_lnds = self.model.inference(x)['landmark']

        batch_lnds = self.postprocess(batch_lnds, details, face_detection)
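        # Keep only the 51 inner-face landmarks: the first 17 of the standard
        # 68 points trace the jawline and are dropped below.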

        return batch_lnds[:, 17:, :]


================================================
FILE: model/latent_mapping.py
================================================
from utils.general_utils import get_weights

import tensorflow as tf
import numpy as np
from tensorflow.keras import layers, Model


class LatentMappingNetwork(Model):
    def __init__(self, args):
        super().__init__()
        self.args = args

        input_shape = (2560,)

        self.linear1 = layers.Dense(2048, input_shape=input_shape)
        self.linear2 = layers.Dense(1024)
        self.linear3 = layers.Dense(512, kernel_initializer=get_weights())
        self.linear4 = layers.Dense(512, kernel_initializer=get_weights())
        self.linears = [self.linear1, self.linear2, self.linear3, self.linear4]

        self.relu = layers.LeakyReLU(0.2)

        self.num_styles = int(np.log2(self.args.resolution)) * 2 - 2

        if self.args.load_checkpoint:
            self.build(input_shape=(1, 1, 2560))
            self.load_weights(str(self.args.load_checkpoint.joinpath(self.__class__.__name__ + '.h5')))

    @tf.function
    def call(self, x):
        first = True
        for layer in self.linears:
            if not first:
                x = self.relu(x)

            x = layer(x)
            first = False

        s = list(x.shape)

        # Broadcast w along the style axis: one copy for each AdaIN layer
        s[1] = self.num_styles
        x = tf.broadcast_to(x, s)

        return x

    def my_save(self, reason=''):
        self.save_weights(str(self.args.weights_dir.joinpath(self.__class__.__name__ + reason + '.h5')))
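

# Usage sketch (illustrative; `args` is stubbed). At resolution 256 the
# mapping emits num_styles = log2(256) * 2 - 2 = 14 copies of w, one per
# AdaIN layer of the synthesis network.
if __name__ == '__main__':
    from types import SimpleNamespace
    net = LatentMappingNetwork(SimpleNamespace(load_checkpoint=None, resolution=256))
    z_tag = tf.random.normal([2, 1, 2560])  # concat of id (512) + attr (2048)
    print(net(z_tag).shape)  # expected: (2, 14, 512)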


================================================
FILE: model/model.py
================================================
import time
import sys

sys.path.append('..')

from utils import general_utils as utils
from model import id_encoder, latent_mapping, attr_encoder,\
    generator, discriminator, landmarks

from model.stylegan import StyleGAN_G, StyleGAN_D

import tensorflow as tf
from tensorflow.keras import layers, Model


class Network(Model):
    def __init__(self, args, id_net_path, base_generator,
                 landmarks_net_path=None, face_detection_model_path=None, test_id_net_path=None):
        super().__init__()
        self.args = args
        self.G = generator.G(args, id_net_path, base_generator,
                             landmarks_net_path, face_detection_model_path, test_id_net_path)

        if self.args.train:
            self.W_D = discriminator.W_D(args)

    def call(self):
        raise NotImplementedError()

    def my_save(self, reason):
        self.G.my_save(reason)

        if self.args.W_D_loss:
            self.W_D.my_save(reason)

    def my_load(self):
        raise NotImplementedError()

    def train(self):
        self._set_trainable_behavior(True)

    def test(self):
        self._set_trainable_behavior(False)

    def _set_trainable_behavior(self, trainable):
        self.G.attr_encoder.trainable = trainable
        self.G.latent_spaces_mapping.trainable = trainable


================================================
FILE: model/stylegan.py
================================================
import sys
import math
import numpy as np
import tensorflow as tf

import matplotlib.pyplot as plt

from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Layer, InputLayer, Multiply, Lambda, Flatten, Dense, Conv2D, Conv2DTranspose
from tensorflow.keras.initializers import VarianceScaling

from tensorflow.python.keras import backend
from tensorflow.python.ops import array_ops

def nf(stage, fmap_base=8192, fmap_decay=1.0, fmap_max=512): 
    return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)

def LeakyReLU(alpha, name):
    def lrelu(x, alpha):
        alpha = tf.constant(alpha, dtype=x.dtype, name='alpha')
        return tf.maximum(x, x * alpha)
    return Lambda(lambda x: lrelu(x, alpha), name=name)

def GetWeights(gain=math.sqrt(2)):
    return VarianceScaling(gain)

def runtime_coef(kernel_size, gain, fmaps_in, fmaps_out, lrmul=1.0):
    # Equalized learning rate and custom learning rate multiplier.
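    # e.g., a 3x3 conv with 512 input fmaps: fan_in = 3*3*512 = 4608 and
    # he_std = sqrt(2) / sqrt(4608) ~= 0.0208; the kernel is multiplied by
    # this coefficient at run time instead of being scaled at initialization.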
    shape = [kernel_size[0], kernel_size[1], fmaps_in, fmaps_out]
    fan_in = np.prod(shape[:-1]) # [kernel, kernel, fmaps_in, fmaps_out] or [in, out]
    he_std = gain / np.sqrt(fan_in) # He init
    return he_std * lrmul

def pixel_norm(x, epsilon=1e-8):
    epsilon = tf.constant(epsilon, dtype=x.dtype, name='epsilon')
    return x * tf.math.rsqrt(tf.reduce_mean(tf.square(x), axis=1, keepdims=True) + epsilon)

class PixelNorm(Layer):
    def __init__(self, name):
        super(PixelNorm, self).__init__(name=name)
    
    def call(self, inputs):
        return pixel_norm(inputs)
    
class InstanceNorm(Layer):
    def __init__(self, name):
        super(InstanceNorm, self).__init__(name=name)
    
    def call(self, x):
        epsilon=1e-8
        orig_dtype = x.dtype
        x = tf.cast(x, tf.float32)
        x -= tf.reduce_mean(x, axis=[2,3], keepdims=True)
        epsilon = tf.constant(epsilon, dtype=x.dtype, name='epsilon')
        x *= tf.math.rsqrt(tf.reduce_mean(tf.square(x), axis=[2,3], keepdims=True) + epsilon)
        x = tf.cast(x, orig_dtype)
        return x
    
def Identity(name):
    return Lambda(lambda x: x, name=name)

def Broadcast(name, dlatent_broadcast=18):
    def broadcast(x):
        return tf.tile(x[:, np.newaxis], [1, dlatent_broadcast, 1])
    return Lambda(lambda x: broadcast(x), name=name)

class Truncation(Layer):
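    # Truncation trick: styles of layers below `truncation_cutoff` are
    # interpolated toward the learned average dlatent by `truncation_psi`;
    # with psi=1 (the value StyleGAN_G_mapping passes by default) it is a
    # no-op.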
    def __init__(self, name, num_layers=18, truncation_psi=0.7, truncation_cutoff=8):
        super(Truncation, self).__init__(name=name)
        self.num_layers = num_layers
        self.truncation_psi = truncation_psi
        self.truncation_cutoff = truncation_cutoff

    def build(self, input_shape):
        self.dlatent_avg = self.add_variable('dlatent_avg', shape=[int(input_shape[-1])])

    def call(self, inputs):
        layer_idx = np.arange(self.num_layers)[np.newaxis, :, np.newaxis]
        ones = np.ones(layer_idx.shape, dtype=np.float32)
        coefs = tf.where(layer_idx < self.truncation_cutoff, self.truncation_psi * ones, ones)
        
        def lerp(a,b,t): return a + (b - a) * t
        
        return lerp(self.dlatent_avg, inputs, coefs)

class DenseLayer(Dense):
    def __init__(self, units, name, kernel_initializer=GetWeights(), gain=math.sqrt(2), lrmul=1.0):
        super(DenseLayer, self).__init__(units=units, kernel_initializer=kernel_initializer, name=name)
        self.gain = gain
        self.lrmul = lrmul
    
    def call(self, inputs):
        x, b, w = inputs, self.bias * self.lrmul, self.kernel * runtime_coef([1,1], self.gain, inputs.shape[1], self.units, lrmul=self.lrmul)
        
        # Input x kernel
        if len(x.shape) > 2: x = tf.reshape(x, [-1, np.prod([d for d in x.shape[1:]])])
        x = tf.matmul(x, w)
        
        # Bias
        if len(x.shape) == 2:
            return x + b
        
        return x + tf.reshape(b, [1, -1, 1, 1])

class Conv2d(Conv2D):
    def __init__(self, filters, kernel_size, name, gain=math.sqrt(2), lrmul=1.0, kernel_modifier=None, strides=1, use_bias=True):
        super(Conv2d, self).__init__(filters=filters, kernel_size=kernel_size, kernel_initializer=GetWeights(gain), 
                                     use_bias=use_bias, padding='same', data_format='channels_first', name=name, strides=strides)
        self.gain = gain
        self.lrmul = lrmul
        self.kernel_modifier = kernel_modifier

    # Perform convolution with modified kernel then add bias
    def call(self, inputs):
        if self.kernel_modifier is None:
            w = self.kernel
        else:
            w = self.kernel_modifier(self.kernel)
            
        outputs = self._convolution_op(inputs, w * runtime_coef(self.kernel_size, self.gain, inputs.shape[1], self.filters))
        
        if self.use_bias:
            b = self.bias * self.lrmul        
            if self.data_format == 'channels_first':
                outputs = tf.nn.bias_add(outputs, b, data_format='NCHW')
            else:
                outputs = tf.nn.bias_add(outputs, b, data_format='NHWC')

        return outputs

class Const(Layer):
    def __init__(self, name):
        super(Const, self).__init__(name=name)
        
    def build(self, input_shape):
        self.const = self.add_variable('const', shape=[1,512,4,4])
        
    def call(self, inputs):
        return tf.tile(self.const, [tf.shape(inputs)[0], 1, 1, 1])
    
class RandomNoise(Layer):
    def __init__(self, name, layer_idx):
        super(RandomNoise, self).__init__(name=name)
        
        res = layer_idx // 2 + 2        
        self.layer_idx = layer_idx
        self.noise_shape = [1, 1, 2**res, 2**res]
    
    def build(self, input_shape):
        self.noise = self.add_variable('noise', shape=self.noise_shape, initializer=tf.initializers.zeros(), trainable=False)
        
    def call(self, inputs):
        return self.noise
    
class ApplyNoise(Layer):
    def __init__(self, name, is_const_noise):
        super(ApplyNoise, self).__init__(name=name)        
        self.is_const_noise = is_const_noise

    def build(self, input_shape):
        input_shape = input_shape[0]
        self.weight = self.add_variable('weight', shape=[input_shape[1]], initializer=tf.initializers.zeros())
        
    def call(self, inputs):
        x, noise = inputs
        if not self.is_const_noise:
            noise = tf.random.normal([tf.shape(x)[0], 1, x.shape[2], x.shape[3]], dtype=x.dtype)

        return x + noise * tf.reshape(self.weight, [1, -1, 1, 1])
    
class ApplyBias(Layer):
    def __init__(self, name, lrmul=1.0):
        super(ApplyBias, self).__init__(name=name)
        self.lrmul = lrmul
        
    def build(self, input_shape):
        self.bias = self.add_variable('bias', shape=[input_shape[1]])
        
    def call(self, x):
        b = self.bias * self.lrmul
        if len(x.shape) == 2: return x + b
        return x + tf.reshape(b, [1, -1, 1, 1])

class StridedSlice(Layer):
    def __init__(self, layer_idx, name):
        super(StridedSlice, self).__init__(name=name)
        self.layer_idx = layer_idx
    
    def call(self, inputs):
        return inputs[:, self.layer_idx]
    
class StyleModApply(Layer):
    def __init__(self, name):
        super(StyleModApply, self).__init__(name=name)
    
    def call(self, inputs):
        x, style = inputs
        
        style = tf.reshape(style, [-1, 2, x.shape[1]] + [1] * (len(x.shape) - 2))
        return x * (style[:,0] + 1) + style[:,1]

def _blur2d(x, f=[1,2,1], normalize=True, flip=False, stride=1):
    assert x.shape.ndims == 4 and all(dim is not None for dim in x.shape[1:])
    assert isinstance(stride, int) and stride >= 1

    # Finalize filter kernel.
    f = np.array(f, dtype=np.float32)
    if f.ndim == 1:
        f = f[:, np.newaxis] * f[np.newaxis, :]
    assert f.ndim == 2
    if normalize:
        f /= np.sum(f)
    if flip:
        f = f[::-1, ::-1]
    f = f[:, :, np.newaxis, np.newaxis]
    f = np.tile(f, [1, 1, int(x.shape[1]), 1])

    # No-op => early exit.
    if f.shape == (1, 1) and f[0,0] == 1:
        return x

    # Convolve using depthwise_conv2d.
    orig_dtype = x.dtype
    x = tf.cast(x, tf.float32)  # tf.nn.depthwise_conv2d() doesn't support fp16
    f = tf.constant(f, dtype=x.dtype, name='filter')
    strides = [1, 1, stride, stride]
    x = tf.nn.depthwise_conv2d(x, f, strides=strides, padding='SAME', data_format='NCHW')
    x = tf.cast(x, orig_dtype)
    return x

def Blur(name, blur_filter=[1,2,1]):
    def blur2d(x, f=[1,2,1], normalize=True):
        return _blur2d(x, f, normalize)
    return Lambda(lambda x: blur2d(x, blur_filter), name=name)

def _downscale2d(x, factor=2, gain=1):
    assert x.shape.ndims == 4 and all(dim is not None for dim in x.shape[1:])
    assert isinstance(factor, int) and factor >= 1

    # 2x2, float32 => downscale using _blur2d().
    if factor == 2 and x.dtype == tf.float32:
        f = [np.sqrt(gain) / factor] * factor
        return _blur2d(x, f=f, normalize=False, stride=factor)

    # Apply gain.
    if gain != 1:
        x *= gain

    # No-op => early exit.
    if factor == 1:
        return x

    # Large factor => downscale using tf.nn.avg_pool().
    # NOTE: Requires tf_config['graph_options.place_pruned_graph']=True to work.
    ksize = [1, 1, factor, factor]
    return tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding='VALID', data_format='NCHW')

def _upscale2d(x, factor=2, gain=1):
    assert x.shape.ndims == 4 and all(dim is not None for dim in x.shape[1:])
    assert isinstance(factor, int) and factor >= 1

    # Apply gain.
    if gain != 1:
        x *= gain

    # No-op => early exit.
    if factor == 1:
        return x

    # Upscale using tf.tile().
    s = x.shape
    x = tf.reshape(x, [-1, s[1], s[2], 1, s[3], 1])
    x = tf.tile(x, [1, 1, 1, factor, 1, factor])
    x = tf.reshape(x, [-1, s[1], s[2] * factor, s[3] * factor])
    return x
	
def Downscaled2d(name, factor=2, gain=1):
    return Lambda(lambda x: _downscale2d(x, factor, gain), name=name+'/Downscaled2d')
	
def Upscaled2d(name, factor=2, gain=1):
    return Lambda(lambda x: _upscale2d(x, factor, gain), name=name+'/Upscaled2d')

def Conv2d_downscale2d(model, filters, kernel_size, name, gain=math.sqrt(2), fused_scale='auto'):
    if fused_scale == 'auto':
        x = model.layers[-1].output
        fused_scale = min(x.shape[2:]) >= 128
        
    if not fused_scale:
        # Not fused => call the individual ops directly.
        model.add( Conv2d(filters, kernel_size, name, gain) )
        model.add( Downscaled2d(name) )        
    else:
        # Fused => perform both ops simultaneously using tf.nn.conv2d().
        def fused_op(w):
            w = tf.pad(w, [[1,1], [1,1], [0,0], [0,0]], mode='CONSTANT')
            w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]]) * 0.25
            return w
        model.add( Conv2d(filters, kernel_size, name, gain, kernel_modifier=fused_op, strides=2) )
		
def Upscale2d_conv2d(x, filters, kernel_size, name, use_bias, gain=math.sqrt(2), fused_scale='auto'):
    if fused_scale == 'auto':
        fused_scale = min(x.shape[2:]) * 2 >= 128

    if not fused_scale:
        x = Upscaled2d(name)(x)
        x = Conv2d(filters, kernel_size, name=name, gain=gain, use_bias=use_bias)(x)
        return x

    return Conv2d_transpose(filters, kernel_size, name, gain, strides=2)(x)

class Conv2d_transpose(Conv2DTranspose):
    def __init__(self, filters, kernel_size, name, gain=math.sqrt(2), lrmul=1.0, kernel_modifier=None, strides=2, use_bias=False):
        
        super(Conv2d_transpose, self).__init__(filters=filters, kernel_size=kernel_size, kernel_initializer=GetWeights(gain), 
                                     use_bias=use_bias, padding='same', data_format='channels_first', name=name, strides=strides)
        self.gain = gain
        self.lrmul = lrmul
        self.kernel_modifier = kernel_modifier
        
    def build(self, input_shape):
        shape = [self.kernel_size[0], self.kernel_size[1], input_shape[1], self.filters]
        self.kernel = self.add_variable('weight', shape=shape, initializer=tf.initializers.zeros())
            
    def call(self, inputs):
        # Fused => perform both ops simultaneously using tf.nn.conv2d_transpose().
        def fused_op(w):
            w = tf.transpose(w, [0, 1, 3, 2]) # [kernel, kernel, fmaps_out, fmaps_in]
            w = tf.pad(w, [[1,1], [1,1], [0,0], [0,0]], mode='CONSTANT')
            w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]])
            return w

        x, w = inputs, fused_op(self.kernel * runtime_coef(self.kernel_size, self.gain, inputs.shape[1], self.filters, lrmul=self.lrmul))
        
        os = [tf.shape(inputs)[0], self.filters, inputs.shape[2] * 2, inputs.shape[3] * 2]
        
        outputs = tf.nn.conv2d_transpose(x, w, os, strides=[1,1,2,2], padding='SAME', data_format='NCHW')
        
        return outputs


class MinibatchStddevLayer(tf.keras.layers.Layer):
    def __init__(self, group_size=4, num_new_features=1):
        super().__init__()
        self.group_size = group_size
        self.num_new_features = num_new_features

    def __call__(self, x, *args, **kwargs):
        group_size = tf.minimum(self.group_size,
                                tf.shape(x)[0])  # Minibatch must be divisible by (or smaller than) group_size.
        s = x.shape  # [NCHW]  Input shape.
        y = tf.reshape(x, [group_size, -1, self.num_new_features,
                           s[1] // self.num_new_features, s[2], s[3]])  # [GMncHW] Split minibatch into M groups of size G. Split channels into n channel groups c.
        y = tf.cast(y, tf.float32)  # [GMncHW] Cast to FP32.
        y -= tf.reduce_mean(y, axis=0, keepdims=True)  # [GMncHW] Subtract mean over group.
        y = tf.reduce_mean(tf.square(y), axis=0)  # [MncHW]  Calc variance over group.
        y = tf.sqrt(y + 1e-8)  # [MncHW]  Calc stddev over group.
        y = tf.reduce_mean(y, axis=[2, 3, 4], keepdims=True)  # [Mn111]  Take average over fmaps and pixels.
        y = tf.reduce_mean(y, axis=[2])  # [Mn11] Split channels into c channel groups
        y = tf.cast(y, x.dtype)  # [Mn11]  Cast back to original data type.
        y = tf.tile(y, [group_size, 1, s[2], s[3]])  # [NnHW]  Replicate over group and pixels.
        return tf.concat([x, y], axis=1)  # [NCHW]  Append as new fmap.

def minibatch_stddev_layer(x, group_size=4, num_new_features=1):
    with tf.compat.v1.variable_scope('MinibatchStddev'):
        group_size = tf.minimum(group_size, tf.shape(x)[0])     # Minibatch must be divisible by (or smaller than) group_size.
        s = x.shape                                             # [NCHW]  Input shape.
        y = tf.reshape(x, [group_size, -1, num_new_features, s[1]//num_new_features, s[2], s[3]])   # [GMncHW] Split minibatch into M groups of size G. Split channels into n channel groups c.
        y = tf.cast(y, tf.float32)                              # [GMncHW] Cast to FP32.
        y -= tf.reduce_mean(y, axis=0, keepdims=True)           # [GMncHW] Subtract mean over group.
        y = tf.reduce_mean(tf.square(y), axis=0)                # [MncHW]  Calc variance over group.
        y = tf.sqrt(y + 1e-8)                                   # [MncHW]  Calc stddev over group.
        y = tf.reduce_mean(y, axis=[2,3,4], keepdims=True)      # [Mn111]  Take average over fmaps and pixels.
        y = tf.reduce_mean(y, axis=[2])                         # [Mn11] Split channels into c channel groups
        y = tf.cast(y, x.dtype)                                 # [Mn11]  Cast back to original data type.
        y = tf.tile(y, [group_size, 1, s[2], s[3]])             # [NnHW]  Replicate over group and pixels.
        return tf.concat([x, y], axis=1)                        # [NCHW]  Append as new fmap.
		
def StyleGAN_G_mapping( latent_size=512, dlatent_size=512, mapping_layers=8, mapping_fmaps=512, mapping_lrmul=0.01,
                        truncation_psi=1, resolution=1024):

    resolution_log2 = int(np.log2(resolution))
    num_layers = resolution_log2 * 2 - 2

    model = Sequential(name='G_mapping')
    model.add( InputLayer(input_shape=[latent_size], name='G_mapping/latents_in') ) 

    # Normalize latents.
    model.add( PixelNorm(name='G_mapping/PixelNorm') )
    
    # Mapping layers.
    for layer_idx in range(mapping_layers):
        name = 'G_mapping/Dense{}'.format(layer_idx)
        fmaps = dlatent_size if layer_idx == mapping_layers - 1 else mapping_fmaps
        
        model.add( DenseLayer(units=fmaps, kernel_initializer=GetWeights(), name=name, lrmul=mapping_lrmul) )
        model.add( LeakyReLU(alpha=0.2, name=name+'/LeakyReLU') )
        
    # Broadcast.
    model.add( Broadcast(name='G_mapping/Broadcast', dlatent_broadcast=num_layers) )
    
    # Output.
    model.add( Identity(name='G_mapping/dlatents_out') )

    # Apply truncation trick.
    model.add( Truncation(name='Truncation', num_layers=num_layers, truncation_psi=truncation_psi ))
    
    return model

def StyleGAN_G_synthesis(dlatent_size=512, resolution=1024, is_const_noise=True):
    # General parameters
    num_channels = 3
    resolution_log2 = int(np.log2(resolution))
    num_layers = resolution_log2 * 2 - 2

    # Primary inputs.
    dlatents_in = tf.keras.layers.Input(shape=[num_layers, dlatent_size], name='G_synthesis/dlatents_in')
    
    # Noise inputs.
    noise_inputs = []
    for layer_idx in range(num_layers):
        noise_inputs.append( RandomNoise(name='G_synthesis/noise%d'%layer_idx, layer_idx=layer_idx)(dlatents_in) )
        
    # Things to do at the end of each layer.
    def layer_epilogue(x, layer_idx, name):
        name = 'G_synthesis/{}x{}/{}/'.format(x.shape[2], x.shape[2], name)
        
        x = ApplyNoise(name=name+'Noise', is_const_noise=is_const_noise)([x, noise_inputs[layer_idx]])
        x = ApplyBias(name=name+'bias')(x)
        x = LeakyReLU(alpha=0.2, name=name+'LeakyReLU')(x)
        x = InstanceNorm(name=name+'InstanceNorm')(x)       
        
        style = DenseLayer(units=x.shape[1]*2, gain=1, name=name+'StyleMod') (StridedSlice(layer_idx, name=name+'StridedSlice')(dlatents_in))
        x = StyleModApply(name=name+'StyleModApply')([x, style])
        
        return x
    
    # Building blocks for remaining layers.
    def block(res, x): # res = 3..resolution_log2
        name, name0, name1 = '%dx%d' % (2**res, 2**res), 'Conv0_up', 'Conv1'
        
        # Conv0_up
        upscaled = Upscale2d_conv2d(x, name='G_synthesis/{}/{}'.format(name, name0), filters=nf(res-1), kernel_size=3, use_bias=False)   
        x = layer_epilogue( Blur(name='G_synthesis/{}/{}/Blur'.format(name, name0))(upscaled), res*2-4, name0 )
        
        # Conv1
        x = layer_epilogue( Conv2d(name='G_synthesis/{}/{}'.format(name, name1), filters=nf(res-1), kernel_size=3, use_bias=False)(x), res*2-3, name1 )

        return x 
    
    def torgb(res, x): # res = 2..resolution_log2
        lod = resolution_log2 - res        
        return Conv2d(name='G_synthesis/ToRGB_lod%d' % lod, filters=num_channels, kernel_size=1, gain=1, use_bias=True)(x)
    
    # Early layers.
    x = layer_epilogue(Const(name='G_synthesis/4x4/Const')(dlatents_in), 0, name='Const')
    x = layer_epilogue(Conv2d(name='G_synthesis/4x4/Conv', filters=nf(1), kernel_size=3, use_bias=False)(x), 1, 'Conv')
    
    # Fixed structure: simple and efficient, but does not support progressive growing.
    for res in range(3, resolution_log2 + 1):
        x = block(res, x)
            
    x = torgb(resolution_log2, x)

    # change output to the default NHWC format, so it will be compatible with other networks
    x = tf.transpose(x, (0, 2, 3, 1))

    return Model(inputs=dlatents_in, outputs=x, name='G_synthesis')


class StyleGAN_G(Model):
    def __init__(self, resolution=1024, latent_size=512, dlatent_size=512, mapping_layers=8, mapping_fmaps=512,
                 mapping_lrmul=0.01, truncation_psi=1):
        super(StyleGAN_G, self).__init__()
        self.model_mapping = StyleGAN_G_mapping(latent_size, dlatent_size, mapping_layers, mapping_fmaps,
                                                mapping_lrmul, truncation_psi, resolution)
        self.model_synthesis = StyleGAN_G_synthesis(dlatent_size, resolution)
        print('Model created.')
        
    def call(self, inputs):
        x = self.model_mapping(inputs)
        x = self.model_synthesis(x)
        return x
    
    def generate_sample(self, seed=5, is_visualize=False):
        rnd = np.random.RandomState(seed)
        latents = rnd.randn(1, 512)

        y = self.predict(latents)

        # Synthesis already emits NHWC (see the transpose at the end of
        # StyleGAN_G_synthesis), so no axis permutation is needed here.
        images = np.clip((y + 1) * 0.5, 0, 1)
        
        # print(images.shape, np.min(images), np.max(images))

        plt.figure(figsize=(10, 10))
        plt.imshow(images[0])
        if is_visualize:
            plt.show()

        return images
    
class StyleGAN_D(Model):
    def __init__(self, resolution=1024, mbstd_group_size=4, mbstd_num_features=1):
        super(StyleGAN_D, self).__init__()

        resolution_log2 = int(math.log2(resolution))

        model = Sequential(name='Discriminator')
        model.add(InputLayer(input_shape=[3, resolution, resolution])) 

        def fromrgb(res):
            name = 'FromRGB_lod%d' % (resolution_log2 - res)
            model.add( Conv2d(filters=nf(res-1), kernel_size=1, name=name) )
            model.add( LeakyReLU(alpha=0.2, name=name+'/LeakyReLU') )

        def block(res):
            name = '%dx%d' % (2**res, 2**res)
            if res >= 3: # 8x8 and up
                model.add( Conv2d(filters=nf(res-1), kernel_size=3, name=name+'/Conv0') )
                model.add( LeakyReLU(alpha=0.2, name=name+'/Conv0/LeakyReLU') )

                model.add( Blur(name=name+'/Blur') )
                Conv2d_downscale2d(model=model, filters=nf(res-2), kernel_size=3, name=name+'/Conv1_down')
                model.add( LeakyReLU(alpha=0.2, name=name+'/Conv1_down/LeakyReLU') )

            else: # 4x4
                if mbstd_group_size > 1: 
                    model.add( Lambda(lambda x: minibatch_stddev_layer(x, mbstd_group_size, mbstd_num_features), name=name+'/MinibatchStddev') )

                model.add( Conv2d(filters=nf(res-1), kernel_size=3, name=name+'/Conv') )
                model.add( LeakyReLU(alpha=0.2, name=name+'/Conv/LeakyReLU') )

                model.add( Flatten() )
                model.add( DenseLayer(units=nf(res-2), kernel_initializer=GetWeights(), name=name+'/Dense0') )
                model.add( LeakyReLU(alpha=0.2, name=name+'/Dense0/LeakyReLU') )

                model.add( DenseLayer(units=1, kernel_initializer=GetWeights(1), gain=1, name=name+'/Dense1') )   

        # Blocks
        fromrgb(resolution_log2)
        for res in range(resolution_log2, 2, -1): block(res)
        block(2)

        self.model = model

    def call(self, inputs):
        inputs = tf.transpose(inputs, (0, 3, 1, 2))
        return self.model(inputs)

def copy_weights_to_keras_model(model, all_weights):
    c = 0
    od = all_weights
    for l in model.layers:
        try:
            values = l.get_weights()
            weights = list(map(lambda x: x.shape, values))
            if not len(weights):
                continue

            num_params = values[0].size

            # Special weights
            if len(weights) == 1:
                weights_list = []

                # The learned constant variable
                if l.name == 'G_synthesis/4x4/Const':
                    weights_list.append(od[l.name + '/const'])

                # Truncation trick variable
                if 'Truncation' in l.name:
                    weights_list.append(od['dlatent_avg'])

                # Input noise
                if 'G_synthesis/noise' in l.name:
                    weights_list.append(od[l.name])

                # Noise variables
                if 'Noise' in l.name:
                    weights_list.append(od[l.name + '/weight'])

                # Bias variables
                if 'bias' in l.name:
                    weights_list.append(od[l.name])

                # Conv with no bias
                if l.name.endswith('Conv') or l.name.endswith('Conv1') or l.name.endswith('Conv0_up'):
                    weights_list.append(od[l.name + '/weight'])

                if len(weights_list) > 0:
                    l.set_weights(weights_list)
                    print('.', end='')
                    c = c + num_params
                else:
                    print('WARNING: weights not found for ', l.name, '  of size', weights[0])
            else:
                # Standard weights (weight + bias)
                assert len(weights) == 2

                num_params = num_params + values[1].size

                layer_name = l.name
                var_names = ['{}/{}'.format(layer_name, 'weight'), '{}/{}'.format(layer_name, 'bias')]

                if var_names[0] in od and var_names[1] in od:
                    weight = od[var_names[0]]
                    bias = od[var_names[1]]

                    l.set_weights([weight, bias])

                    print('.', end='')
                    c = c + num_params
                else:
                    print('WARNING: not found', var_names)
        except Exception as e:
            print(e)
            print('skipping...')

    print('Total number of parameters copied:', c)
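

# Usage sketch (illustrative): exercises the mapping network alone, since the
# synthesis graph relies on NCHW ops (e.g. tf.nn.depthwise_conv2d with
# data_format='NCHW') that typically require a GPU build.
if __name__ == '__main__':
    mapping = StyleGAN_G_mapping(resolution=256)  # 14 style layers
    w_plus = mapping(tf.random.normal([2, 512]))
    print(w_plus.shape)  # expected: (2, 14, 512)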


================================================
FILE: test.py
================================================
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['USE_SIMPLE_THREADED_LEVEL3'] = '1'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

import sys
from model.stylegan import StyleGAN_G_synthesis
from model.model import Network
from writer import Writer
from inference import Inference
from arglib import arglib
import utils

sys.path.insert(0, 'model/face_utils')



def main():
    test_args = arglib.TestArgs()
    args, str_args = test_args.args, test_args.str_args
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    Writer.set_writer(args.results_dir)

    id_model_path = args.pretrained_models_path.joinpath('vggface2.h5')
    stylegan_G_synthesis_path = str(
        args.pretrained_models_path.joinpath(f'stylegan_G_{args.resolution}x{args.resolution}_synthesis'))

    utils.landmarks_model_path = str(args.pretrained_models_path.joinpath('shape_predictor_68_face_landmarks.dat'))

    stylegan_G_synthesis = StyleGAN_G_synthesis(resolution=args.resolution, is_const_noise=args.const_noise)
    stylegan_G_synthesis.load_weights(stylegan_G_synthesis_path)

    network = Network(args, id_model_path, stylegan_G_synthesis)

    network.test()
    inference = Inference(args, network)
    test_func = getattr(inference, args.test_func)
    test_func()


if __name__ == '__main__':
    main()


================================================
FILE: trainer.py
================================================
import logging

import numpy as np
import tensorflow as tf

from writer import Writer
from utils import general_utils as utils


def id_loss_func(y_gt, y_pred):
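    # Mean absolute error between identity embeddings; both are already
    # L2-normalized by the ID encoder.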
    return tf.reduce_mean(tf.keras.losses.MAE(y_gt, y_pred))


class Trainer(object):
    def __init__(self, args, model, data_loader):
        self.args = args
        self.logger = logging.getLogger(__class__.__name__)

        self.model = model
        self.data_loader = data_loader

        # lrs & optimizers
        lr = 5e-5 if self.args.resolution == 256 else 1e-5

        self.g_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
        self.g_gan_optimizer = tf.keras.optimizers.Adam(learning_rate=0.1 * lr)
        self.w_d_optimizer = tf.keras.optimizers.Adam(learning_rate=0.4 * lr)

        self.im_d_optimizer = tf.keras.optimizers.Adam(learning_rate=0.4 * lr)

        # Losses
        self.gan_loss_func = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        self.pixel_loss_func = tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.SUM)

        self.id_loss_func = id_loss_func
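
        # Pixel-loss spatial weighting: either an inverse-Gaussian mask
        # (presumably down-weighting the face-centered region so the pixel
        # loss constrains mostly the surroundings) or a uniform mask; the
        # uniform mask is normalized to sum to 1.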

        if args.pixel_mask_type == 'gaussian':
            sigma = int(80 * (self.args.resolution / 256))
            self.pixel_mask = utils.inverse_gaussian_image(self.args.resolution, sigma)
        else:
            self.pixel_mask = tf.ones([self.args.resolution, self.args.resolution])
            self.pixel_mask = self.pixel_mask / tf.reduce_sum(self.pixel_mask)

        self.pixel_mask = tf.broadcast_to(self.pixel_mask, [self.args.batch_size, *self.pixel_mask.shape])

        self.num_epoch = 0
        self.is_cross_epoch = False

        # Lambdas
        if args.unified:
            self.lambda_gan = 0.5
        else:
            self.lambda_gan = 1

        self.lambda_pixel = 0.02

        self.lambda_id = 1
        self.lambda_attr_id = 1
        self.lambda_landmarks = 0.001
        self.r1_gamma = 10

        # Test
        self.test_not_imporved = 0
        self.max_id_preserve = 0
        self.min_lnd_dist = np.inf

    def train(self):
        while self.num_epoch <= self.args.num_epochs:
            self.logger.info('---------------------------------------')
            self.logger.info(f'Start training epoch: {self.num_epoch}')

            if self.args.cross_frequency and (self.num_epoch % self.args.cross_frequency == 0):
                self.is_cross_epoch = True
                self.logger.info('This epoch is cross-face')
            else:
                self.is_cross_epoch = False
                self.logger.info('This epoch is same-face')

            try:
                if self.num_epoch % self.args.test_frequency == 0:
                    self.test()

                self.train_epoch()

            except Exception as e:
                self.logger.exception(e)
                raise

            if self.test_not_imporved > self.args.not_improved_exit:
                self.logger.info(f'Test has not improved for {self.args.not_improved_exit} epochs. Exiting...')
                break

            self.num_epoch += 1

    def train_epoch(self):
        id_loss = 0
        landmarks_loss = 0
        g_w_gan_loss = 0
        pixel_loss = 0
        w_d_loss = 0
        w_loss = 0

        self.logger.info(f'train in epoch: {self.num_epoch}')
        self.model.train()

        use_w_d = self.args.W_D_loss

        # if use_w_d and use_im_d and not self.args.unified:
        if not self.args.unified:
            if self.num_epoch % 2 == 0:
                # This epoch is not using image_D
                use_im_d = False  # currently unused beyond this assignment
                # self.logger.info(f'Not using Image D in epoch: {self.num_epoch}')
            else:
                # This epoch is not using W_D
                use_w_d = False
                # self.logger.info(f'Not using W_d in epoch: {self.num_epoch}')

        attr_img, id_img, real_w, real_img, matching_ws = self.data_loader.get_batch(is_cross=self.is_cross_epoch)

        # Forward that does not require grads
        id_embedding = self.model.G.id_encoder(id_img)
        src_landmarks = self.model.G.landmarks(attr_img)
        attr_input = attr_img

        with tf.GradientTape(persistent=True) as g_tape:

            attr_out = self.model.G.attr_encoder(attr_input)
            attr_embedding = attr_out

            self.logger.info(f'attr embedding stats- mean: {tf.reduce_mean(tf.abs(attr_embedding)):.5f},'
                             f' variance: {tf.math.reduce_variance(attr_embedding):.5f}')

            z_tag = tf.concat([id_embedding, attr_embedding], -1)
            w = self.model.G.latent_spaces_mapping(z_tag)
            fake_w = w[:, 0, :]

            self.logger.info(
                f'w stats- mean: {tf.reduce_mean(tf.abs(fake_w)):.5f}, variance: {tf.math.reduce_variance(fake_w):.5f}')

            pred = self.model.G.stylegan_s(w)

            # Move to roughly [0,1]
            pred = (pred + 1) / 2

            if use_w_d:
                with tf.GradientTape() as w_d_tape:
                    fake_w_logit = self.model.W_D(fake_w)
                    g_w_gan_loss = self.generator_gan_loss(fake_w_logit)

                    self.logger.info(f'g W loss is {g_w_gan_loss:.3f}')
                    self.logger.info(f'fake W logit: {tf.squeeze(fake_w_logit)}')

                    with g_tape.stop_recording():
                        real_w_logit = self.model.W_D(real_w)
                        w_d_loss = self.discriminator_loss(fake_w_logit, real_w_logit)
                        w_d_total_loss = w_d_loss

                        if self.args.gp:
                            w_d_gp = self.R1_gp(self.model.W_D, real_w)
                            w_d_total_loss += w_d_gp
                            self.logger.info(f'w_d_gp : {w_d_gp}')

                        self.logger.info(f'W_D loss is {w_d_loss:.3f}')
                        self.logger.info(f'real W logit: {tf.squeeze(real_w_logit)}')

            if self.args.id_loss:
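                # Embed the generated image and pull it toward the source
                # identity embedding; stop_gradient keeps the target fixed.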
                pred_id_embedding = self.model.G.id_encoder(pred)
                id_loss = self.lambda_id * id_loss_func(pred_id_embedding, tf.stop_gradient(id_embedding))
                self.logger.info(f'id loss is {id_loss:.3f}')

            if self.args.landmarks_loss:
                try:
                    dst_landmarks = self.model.G.landmarks(pred)
                except Exception as e:
                    self.logger.warning(f'Failed to find landmarks on the prediction; skipping landmarks loss. Error: {e}')
                    dst_landmarks = None

                if dst_landmarks is None or src_landmarks is None:
                    landmarks_loss = 0
                else:
                    landmarks_loss = self.lambda_landmarks * \
                                     tf.reduce_mean(tf.keras.losses.MSE(src_landmarks, dst_landmarks))
                    self.logger.info(f'landmarks loss is: {landmarks_loss:.3f}')
                    # if landmarks_loss > 5:
                    #     landmarks_loss = 0
                    #     id_loss = 0

            if not self.is_cross_epoch and self.args.pixel_loss:
                l1_loss = self.pixel_loss_func(attr_img, pred, sample_weight=self.pixel_mask)
                self.logger.info(f'L1 pixel loss is {l1_loss:.3f}')

                if self.args.pixel_loss_type == 'mix':
                    mssim = tf.reduce_mean(1 - tf.image.ssim_multiscale(attr_img, pred, 1.0))
                    self.logger.info(f'mssim loss is {mssim:.3f}')
                    pixel_loss = self.lambda_pixel * (0.84 * mssim + 0.16 * l1_loss)
                else:
                    pixel_loss = self.lambda_pixel * l1_loss

                self.logger.info(f'pixel loss is {pixel_loss:.3f}')

            g_gan_loss = g_w_gan_loss

            total_g_not_gan_loss = id_loss \
                                   + landmarks_loss \
                                   + pixel_loss \
                                   + w_loss

            self.logger.info(f'total G (not gan) loss is {total_g_not_gan_loss:.3f}')
            self.logger.info(f'G gan loss is {g_gan_loss:.3f}')

        Writer.add_scalar('loss/landmarks_loss', landmarks_loss, step=self.num_epoch)
        Writer.add_scalar('loss/total_g_not_gan_loss', total_g_not_gan_loss, step=self.num_epoch)

        Writer.add_scalar('loss/id_loss', id_loss, step=self.num_epoch)

        if use_w_d:
            Writer.add_scalar('loss/g_w_gan_loss', g_w_gan_loss, step=self.num_epoch)
            Writer.add_scalar('loss/W_D_loss', w_d_loss, step=self.num_epoch)
            if self.args.gp:
                Writer.add_scalar('loss/w_d_gp', w_d_gp, step=self.num_epoch)

        if not self.is_cross_epoch:
            Writer.add_scalar('loss/pixel_loss', pixel_loss, step=self.num_epoch)
            Writer.add_scalar('loss/w_loss', w_loss, step=self.num_epoch)

        # Save sample outputs at a decaying cadence: every 100 epochs up to
        # 1e3, every 1e3 up to 1e4, then every 1e4 epochs (and always in debug)
        if self.args.debug or \
                (self.num_epoch < 1e3 and self.num_epoch % 1e2 == 0) or \
                (self.num_epoch < 1e4 and self.num_epoch % 1e3 == 0) or \
                (self.num_epoch % 1e4 == 0):
            utils.save_image(pred[0], self.args.images_results.joinpath(f'{self.num_epoch}_prediction_step.png'))
            utils.save_image(id_img[0], self.args.images_results.joinpath(f'{self.num_epoch}_id_step.png'))
            utils.save_image(attr_img[0], self.args.images_results.joinpath(f'{self.num_epoch}_attr_step.png'))

            Writer.add_image('input/id image', tf.expand_dims(id_img[0], 0), step=self.num_epoch)
            Writer.add_image('Prediction', tf.expand_dims(pred[0], 0), step=self.num_epoch)

        if total_g_not_gan_loss != 0:
            g_grads = g_tape.gradient(total_g_not_gan_loss, self.model.G.trainable_variables)

            g_grads_global_norm = tf.linalg.global_norm(g_grads)
            self.logger.info(f'global norm G not gan grad: {g_grads_global_norm}')

            self.g_optimizer.apply_gradients(zip(g_grads, self.model.G.trainable_variables))

        if use_w_d:
            g_gan_grads = g_tape.gradient(g_gan_loss, self.model.G.trainable_variables)

            g_gan_grad_global_norm = tf.linalg.global_norm(g_gan_grads)
            self.logger.info(f'global norm G gan grad: {g_gan_grad_global_norm}')

            self.g_gan_optimizer.apply_gradients(zip(g_gan_grads, self.model.G.trainable_variables))

            w_d_grads = w_d_tape.gradient(w_d_total_loss, self.model.W_D.trainable_variables)

            self.logger.info(f'global norm W_D grad: {tf.linalg.global_norm(w_d_grads)}')
            self.w_d_optimizer.apply_gradients(zip(w_d_grads, self.model.W_D.trainable_variables))

        del g_tape

    # Common

    # Test
    def test(self):
        self.logger.info(f'Testing in epoch: {self.num_epoch}')
        self.model.test()

        similarities = {'id_to_pred': [], 'id_to_attr': [], 'attr_to_pred': []}

        fake_reconstruction = {'MSE': [], 'PSNR': [], 'ID': []}
        real_reconstruction = {'MSE': [], 'PSNR': [], 'ID': []}

        if self.args.test_with_arcface:
            test_similarities = {'id_to_pred': [], 'id_to_attr': [], 'attr_to_pred': []}

        lnd_dist = []

        for i in range(self.args.test_size):
            attr_img, id_img = self.data_loader.get_batch(is_train=False, is_cross=True)

            pred, id_embedding, w, attr_embedding, src_lnds = self.model.G(id_img, attr_img)
            image = tf.clip_by_value(pred, 0, 1)

            pred_id = self.model.G.id_encoder(image)
            attr_id = self.model.G.id_encoder(attr_img)

            similarities['id_to_pred'].extend(tf.keras.losses.cosine_similarity(id_embedding, pred_id).numpy())
            similarities['id_to_attr'].extend(tf.keras.losses.cosine_similarity(id_embedding, attr_id).numpy())
            similarities['attr_to_pred'].extend(tf.keras.losses.cosine_similarity(attr_id, pred_id).numpy())

            if self.args.test_with_arcface:
                try:
                    arc_id_embedding = self.model.G.test_id_encoder(id_img)
                    arc_pred_id = self.model.G.test_id_encoder(image)
                    arc_attr_id = self.model.G.test_id_encoder(attr_img)

                    test_similarities['id_to_attr'].extend(
                        tf.keras.losses.cosine_similarity(arc_id_embedding, arc_attr_id).numpy())
                    test_similarities['id_to_pred'].extend(
                        tf.keras.losses.cosine_similarity(arc_id_embedding, arc_pred_id).numpy())
                    test_similarities['attr_to_pred'].extend(
                        tf.keras.losses.cosine_similarity(arc_attr_id, arc_pred_id).numpy())
                except Exception as e:
                    self.logger.warning(f'Not calculating test similarities for iteration: {i} because: {e}')

            # Landmarks
            dst_lnds = self.model.G.landmarks(image)
            lnd_dist.extend(tf.reduce_mean(tf.keras.losses.MSE(src_lnds, dst_lnds), axis=-1).numpy())

            # Fake Reconstruction
            self.test_reconstruction(id_img, fake_reconstruction, display=(i==0), display_name='id_img')

            if self.args.test_real_attr:
                # Real Reconstruction
                self.test_reconstruction(attr_img, real_reconstruction, display=(i==0), display_name='attr_img')

            if i == 0:
                utils.save_image(image[0], self.args.images_results.joinpath(f'test_prediction_{self.num_epoch}.png'))
                utils.save_image(id_img[0], self.args.images_results.joinpath(f'test_id_{self.num_epoch}.png'))
                utils.save_image(attr_img[0],
                                 self.args.images_results.joinpath(f'test_attr_{self.num_epoch}.png'))

                Writer.add_image('test/prediction', image, step=self.num_epoch)
                Writer.add_image('test input/id image', id_img, step=self.num_epoch)
                Writer.add_image('test input/attr image', attr_img, step=self.num_epoch)

                for j in range(np.minimum(3, src_lnds.shape[0])):
                    src_xy = src_lnds[j]  # GT
                    dst_xy = dst_lnds[j]  # pred

                    attr_marked = utils.mark_landmarks(attr_img[j], src_xy, color=(0, 0, 0))
                    pred_marked = utils.mark_landmarks(pred[j], src_xy, color=(0, 0, 0))
                    pred_marked = utils.mark_landmarks(pred_marked, dst_xy, color=(255, 112, 112))

                    Writer.add_image(f'landmarks/overlay-{j}', pred_marked, step=self.num_epoch)
                    Writer.add_image(f'landmarks/src-{j}', attr_marked, step=self.num_epoch)

        # Similarity
        self.logger.info('Similarities:')
        for k, v in similarities.items():
            self.logger.info(f'{k}: MEAN: {np.mean(v)}, STD: {np.std(v)}')

        mean_lnd_dist = np.mean(lnd_dist)
        self.logger.info(f'Mean landmarks L2: {mean_lnd_dist}')

        id_to_pred = np.mean(similarities['id_to_pred'])
        attr_to_pred = np.mean(similarities['attr_to_pred'])
        mean_disen = attr_to_pred - id_to_pred
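        # (Added note) tf.keras.losses.cosine_similarity returns the negative
        # cosine, so id_to_pred approaches -1 when identity is preserved; a
        # larger (attr_to_pred - id_to_pred) gap therefore means stronger
        # disentanglement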

        Writer.add_scalar('similarity/score', mean_disen, step=self.num_epoch)
        Writer.add_scalar('similarity/id_to_pred', id_to_pred, step=self.num_epoch)
        Writer.add_scalar('similarity/attr_to_pred', attr_to_pred, step=self.num_epoch)

        if self.args.test_with_arcface:
            arc_id_to_pred = np.mean(test_similarities['id_to_pred'])
            arc_attr_to_pred = np.mean(test_similarities['attr_to_pred'])
            arc_mean_disen = arc_attr_to_pred - arc_id_to_pred

            Writer.add_scalar('arc_similarity/score', arc_mean_disen, step=self.num_epoch)
            Writer.add_scalar('arc_similarity/id_to_pred', arc_id_to_pred, step=self.num_epoch)
            Writer.add_scalar('arc_similarity/attr_to_pred', arc_attr_to_pred, step=self.num_epoch)

        self.logger.info(f'Mean disentanglement score is {mean_disen}')

        Writer.add_scalar('landmarks/L2', np.mean(lnd_dist), step=self.num_epoch)

        # Reconstruction
        if self.args.test_real_attr:
            Writer.add_scalar('reconstruction/real_MSE', np.mean(real_reconstruction['MSE']), step=self.num_epoch)
            Writer.add_scalar('reconstruction/real_PSNR', np.mean(real_reconstruction['PSNR']), step=self.num_epoch)
            Writer.add_scalar('reconstruction/real_ID', np.mean(real_reconstruction['ID']), step=self.num_epoch)

        Writer.add_scalar('reconstruction/fake_MSE', np.mean(fake_reconstruction['MSE']), step=self.num_epoch)
        Writer.add_scalar('reconstruction/fake_PSNR', np.mean(fake_reconstruction['PSNR']), step=self.num_epoch)
        Writer.add_scalar('reconstruction/fake_ID', np.mean(fake_reconstruction['ID']), step=self.num_epoch)

        if mean_lnd_dist < self.min_lnd_dist:
            self.logger.info('Minimum landmarks dist achieved. Saving checkpoint')
            self.test_not_imporved = 0
            self.min_lnd_dist = mean_lnd_dist
            self.model.my_save(f'_best_landmarks_epoch_{self.num_epoch}')

        if np.abs(id_to_pred) > self.max_id_preserve:
            self.logger.info('Max ID preservation achieved! Saving checkpoint')
            self.test_not_imporved = 0
            self.max_id_preserve = np.abs(id_to_pred)
            self.model.my_save(f'_best_id_epoch_{self.num_epoch}')

        else:
            self.test_not_imporved += 1

    def test_reconstruction(self, img, errors_dict, display=False, display_name=None):
        pred, id_embedding, w, attr_embedding, src_lnds = self.model.G(img, img)

        recon_image = tf.clip_by_value(pred, 0, 1)
        recon_pred_id = self.model.G.id_encoder(recon_image)

        mse = tf.reduce_mean((img - recon_image) ** 2, axis=[1, 2, 3]).numpy()
        psnr = tf.image.psnr(img, recon_image, 1).numpy()

        errors_dict['MSE'].extend(mse)
        errors_dict['PSNR'].extend(psnr)
        errors_dict['ID'].extend(tf.keras.losses.cosine_similarity(id_embedding, recon_pred_id).numpy())

        if display:
            Writer.add_image(f'reconstruction/{display_name}', pred, step=self.num_epoch)

    # Helpers

    def generator_gan_loss(self, fake_logit):
        """
        G logistic non-saturating loss, to be minimized
        """
        g_gan_loss = self.gan_loss_func(tf.ones_like(fake_logit), fake_logit)
        return self.lambda_gan * g_gan_loss
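
    # (Added note) Assuming gan_loss_func is binary cross-entropy on logits,
    # this is softplus(-fake_logit) = -log sigmoid(fake_logit), the standard
    # non-saturating generator loss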

    def discriminator_loss(self, fake_logit, real_logit):
        """
        D logistic loss, to be minimized
        verified as identical to StyleGAN's loss.D_logistic
        """
        fake_gt = tf.zeros_like(fake_logit)
        real_gt = tf.ones_like(real_logit)

        d_fake_loss = self.gan_loss_func(fake_gt, fake_logit)
        d_real_loss = self.gan_loss_func(real_gt, real_logit)

        d_loss = d_real_loss + d_fake_loss

        return self.lambda_gan * d_loss
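
    # (Added note) Assuming gan_loss_func is binary cross-entropy on logits,
    # this equals softplus(fake_logit) + softplus(-real_logit), the quantity
    # StyleGAN computes in loss.D_logistic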

    def R1_gp(self, D, x):
        with tf.GradientTape() as t:
            t.watch(x)
            pred = D(x)
            pred_sum = tf.reduce_sum(pred)

        grad = t.gradient(pred_sum, x)

        # Reshape as a vector
        norm = tf.norm(tf.reshape(grad, [tf.shape(grad)[0], -1]), axis=1)
        gp = tf.reduce_mean(norm ** 2)
        gp = 0.5 * self.r1_gamma * gp

        return gp
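
    # (Added note) R1_gp computes 0.5 * gamma * E[||grad_x D(x)||^2] on real
    # samples only, i.e. the R1 regularizer of Mescheder et al., "Which
    # Training Methods for GANs do actually Converge?" (2018), also used by
    # StyleGAN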


================================================
FILE: utils/__init__.py
================================================


================================================
FILE: utils/general_utils.py
================================================
from pathlib import Path

import cv2
import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow.keras.initializers import VarianceScaling
import numbers
import scipy
import dlib

landmarks_model_path = None


def read_image(img_path, resolution, align=False):
    if align:
        img = read_and_align_image(img_path, resolution)
    else:
        img = read_SG_image(img_path, resolution)

    return img


def find_file_by_str(search_dir, s):
    files = [f for f in search_dir.iterdir() if s in f.name]
    return files


def read_SG_image(img_path, size=256, resize=True):
    img = Image.open(str(img_path))
    img = img.convert('RGB')

    if img.size != (size, size) and resize:
        img = img.resize((size, size))
    img = np.asarray(img)

    img = np.expand_dims(img, axis=0)

    # Images in [0, 1]
    img = np.float32(img) / 255

    return img


def read_and_align_image(img_path, output_size=1024):
    global landmarks_model_path
    if not landmarks_model_path:
        raise ValueError('Please init the landmarks model path')

    transform_size = 4096
    enable_padding = True

    img = Image.open(img_path)
    img = img.convert('RGB')
    npimg = np.asarray(img)

    face_detector = dlib.get_frontal_face_detector()
    landmarks_network = dlib.shape_predictor(landmarks_model_path)

    try:
        bbox = face_detector(npimg, 0)[0]
    except IndexError:
        print('face not found!')
        raise

    # rect = np.array([det.left(), det.top(), det.right(), det.bottom()])
    shape = landmarks_network(npimg, bbox)
    lm = np.array([[shape.part(n).x + 0.5, shape.part(n).y + 0.5] for n in range(shape.num_parts)])

    lm = np.round(lm) + 0.5

    lm_chin = lm[0: 17]  # left-right
    lm_eyebrow_left = lm[17: 22]  # left-right
    lm_eyebrow_right = lm[22: 27]  # left-right
    lm_nose = lm[27: 31]  # top-down
    lm_nostrils = lm[31: 36]  # top-down
    lm_eye_left = lm[36: 42]  # left-clockwise
    lm_eye_right = lm[42: 48]  # left-clockwise
    lm_mouth_outer = lm[48: 60]  # left-clockwise
    lm_mouth_inner = lm[60: 68]  # left-clockwise

    # Calculate auxiliary vectors.
    eye_left = np.mean(lm_eye_left, axis=0)
    eye_right = np.mean(lm_eye_right, axis=0)

    eye_avg = (eye_left + eye_right) * 0.5

    # nose_mock_avg = (lm_nose[0] + lm_nose[1]) * 0.5
    # eye_avg = nose_mock_avg

    eye_to_eye = eye_right - eye_left
    mouth_left = lm_mouth_outer[0]
    mouth_right = lm_mouth_outer[6]
    mouth_avg = (mouth_left + mouth_right) * 0.5
    eye_to_mouth = mouth_avg - eye_avg

    # Choose oriented crop rectangle.
    x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
    x /= np.hypot(*x)
    x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)

    y = np.flipud(x) * [-1, 1]

    c = eye_avg + eye_to_mouth * 0.1

    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
    qsize = np.hypot(*x) * 2

    # Shrink.
    shrink = int(np.floor(qsize / output_size * 0.5))
    if shrink > 1:
        rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
        img = img.resize(rsize, Image.ANTIALIAS)  # ANTIALIAS is an alias of LANCZOS (removed in Pillow >= 10)
        quad /= shrink
        qsize /= shrink

    # Crop.
    border = max(int(np.rint(qsize * 0.1)), 3)
    crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
            int(np.ceil(max(quad[:, 1]))))
    crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
            min(crop[3] + border, img.size[1]))
    crop = np.array(crop)
    if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
        img = img.crop(tuple(crop))
        quad -= crop[0:2]

    # Pad.
    pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
           int(np.ceil(max(quad[:, 1]))))
    pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
           max(pad[3] - img.size[1] + border, 0))
    if enable_padding and max(pad) > border - 4:
        pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
        img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
        h, w, _ = img.shape
        y, x, _ = np.ogrid[:h, :w, :1]
        mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
                          1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
        blur = qsize * 0.02
        img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
        img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
        img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
        quad += pad[:2]

    # Transform.
    img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(),
                        Image.BILINEAR)

    if output_size < transform_size:
        img = img.resize((output_size, output_size), Image.ANTIALIAS)

    img = np.asarray(img)
    img = np.expand_dims(img, axis=0)
    img = np.float32(img) / 255

    return img
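
# (Added) Usage sketch, not part of the original file: read_and_align_image
# requires the module-level landmarks_model_path to be set first. The .dat
# path below is a placeholder for dlib's 68-point shape predictor.
#
#   from utils import general_utils
#   general_utils.landmarks_model_path = '/path/to/shape_predictor_68_face_landmarks.dat'
#   img = general_utils.read_image('face.jpg', resolution=1024, align=True)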


def gaussian_image(size, sigma, dim=2):
    if isinstance(size, numbers.Number):
        size = [size] * dim
    if isinstance(sigma, numbers.Number):
        sigma = [sigma] * dim

    weight_kernel = 1
    meshgrids = np.meshgrid(*[np.arange(s, dtype=np.float32) for s in size])
    # The gaussian kernel is the product of the
    # gaussian function of each dimension.
    for s, std, mgrid in zip(size, sigma, meshgrids):
        mean = (s - 1) / 2
        weight_kernel *= 1 / (std * np.sqrt(2 * np.pi)) * np.exp(-((mgrid - mean) / std) ** 2 / 2)

    weight_kernel = weight_kernel / np.sum(weight_kernel)
    return weight_kernel


def inverse_gaussian_image(size, sigma, dim=2):
    gauss = gaussian_image(size, sigma, dim)

    # Inversion is achieved by max - gauss; min is added back to avoid
    # exact zeros, which a true gaussian never contains
    inv_gauss = np.max(gauss) + np.min(gauss) - gauss
    inv_gauss = inv_gauss / np.sum(inv_gauss)

    return inv_gauss
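
# (Added note) The inverse map still sums to 1 but peaks at the borders:
# e.g. inverse_gaussian_image(64, 10)[0, 0] exceeds the center value.
# A map like this can serve as a per-pixel sample_weight (cf. pixel_mask
# in trainer.py) to down-weight the face-centered region of the pixel loss.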


def is_float(tensor):
    """
    Check whether the input holds float values (i.e. is not uint8).
    `tensor` may be a tf.Tensor or an np.array.
    """

    return (isinstance(tensor, tf.Tensor) and tensor.dtype != tf.dtypes.uint8) or tensor.dtype != np.uint8


def convert_tensor_to_image(tensor):
    """
    Converts tensor to image, and saturate output's range
    :param tensor: tf.Tensor, dtype float32, range [0,1]
    :return: np.array, dtype uint8, range [0, 255]
    """
    if is_float(tensor):
        tensor = tf.clip_by_value(tensor, 0., 1.)
        tensor = 255 * tensor

    if tensor.ndim == 4 and tensor.shape[0] == 1:
        tensor = tf.squeeze(tensor)

    tensor = np.uint8(np.round(tensor))

    return tensor


def save_image(img, file_path):
    """
    :param img: Could be either tf tensor or numpy array
    :param file_path:
    """

    if isinstance(file_path, Path):
        file_path = str(file_path)

    img = convert_tensor_to_image(img)
    img = Image.fromarray(img)
    img.save(file_path)


def mark_landmarks(img, lnd, color=None):
    """
    landmarks in (x,y) format
    """
    img = convert_tensor_to_image(img)
    radius = int(img.shape[0] / 256)

    # Landmarks arrive in a 160-sized coordinate frame; rescale to this image
    lnd = (img.shape[0] / 160) * lnd

    if not color:
        color = (255, 255, 255)

    for i in range(lnd.shape[0]):
        x_y = lnd[i]
        img = cv2.circle(img, center=(int(x_y[0]), int(x_y[1])),
                         color=color, radius=radius, thickness=-1)

    return img


def get_weights(slope=0.2):
    """
    The scale is calculated according to:
        https://pytorch.org/docs/stable/nn.init.html
    and
        https://towardsdatascience.com/weight-initialization-in-neural-networks-a-journey-from-the-basics-to-kaiming-954fb9b47c79

    For ReLU and LeakyReLU activations, the preferable initialization is kaiming.

    In Pytorch, the gain for LeakyReLU is calculated by: sqrt(2 / (1 + leaky_relu_slope ^ 2)) and the weights are
    sampled from N(0, std^2) where std = gain / sqrt(fan_in)

    To mimic this in TF, I am using VarianceScaling. The weights are sampled from N(0, std^2)
    where std = sqrt(scale / fan_in). Therefore, scale = gain^2
    """
    scale = 2 / (1 + slope ** 2)
    return VarianceScaling(scale)
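
# (Added) Sanity-check sketch: with slope 0.2 the empirical std of sampled
# weights should be close to gain / sqrt(fan_in), e.g. for fan_in = 512:
#
#   w = get_weights(slope=0.2)(shape=(512, 512)).numpy()
#   expected = np.sqrt((2 / (1 + 0.2 ** 2)) / 512)
#   assert abs(float(w.std()) - expected) / expected < 0.1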


def np_permute(tensor, permute):
    idx = np.empty_like(permute)
    idx[permute] = np.arange(len(permute))
    return tensor[:, idx]
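
# (Added) A small self-check sketch for np_permute, which applies the
# *inverse* of `permute` along axis 1 (so out[:, permute] == tensor)
if __name__ == '__main__':
    t = np.arange(6).reshape(2, 3)
    out = np_permute(t, [2, 0, 1])
    assert np.array_equal(out[:, [2, 0, 1]], t)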


================================================
FILE: utils/generate_fake_data.py
================================================
import sys
from pathlib import Path
import os

sys.path.append('..')

import argparse

import tensorflow as tf
import numpy as np
from tqdm import tqdm

from utils.general_utils import save_image
from model.stylegan import StyleGAN_G


def main(args):
    base_dir = Path(args.output_path).joinpath(f'dataset_{args.resolution}')

    base_w_dir = base_dir.joinpath('ws')
    base_w_dir.mkdir(parents=True, exist_ok=True)

    base_im_dir = base_dir.joinpath('images')
    base_im_dir.mkdir(parents=True, exist_ok=True)

    existing_files = list(base_dir.joinpath('images').iterdir())
    if existing_files:
        # Resume numbering at the next multiple of 1000 past existing images
        max_exist = max(int(x.name) for x in existing_files)
        max_exist = int(max_exist - max_exist % 1e3 + 1e3)
    else:
        max_exist = 0

    stylegan_G_path = args.pretrained_models_path.joinpath(f'stylegan_G_{args.resolution}x{args.resolution}.h5')
    stylegan_G = StyleGAN_G(resolution=args.resolution, truncation_psi=args.truncation)
    stylegan_G.load_weights(str(stylegan_G_path))

    num_samples = args.num_images
    batch_size = args.batch_size
    num_batches = num_samples // batch_size

    curr_ind = max_exist
    for _ in tqdm(range(num_batches)):
        z = tf.random.normal((batch_size, 512))
        w = stylegan_G.model_mapping(z)
        images = stylegan_G.model_synthesis(w)
        images = (images + 1) / 2

        # Start a new subdirectory every 1000 samples (assumes 1000 % batch_size == 0)
        if curr_ind % 1000 == 0:
            curr_w_dir = base_w_dir.joinpath(f'{curr_ind:05d}')
            curr_w_dir.mkdir(exist_ok=True)

            curr_im_dir = base_im_dir.joinpath(f'{curr_ind:05d}')
            curr_im_dir.mkdir(exist_ok=True)

        for j in range(batch_size):
            w_path = curr_w_dir.joinpath(f'{curr_ind:05d}.npy')
            np.save(str(w_path), w[j], allow_pickle=False)

            im_path = curr_im_dir.joinpath(f'{curr_ind:05d}.png')
            save_image(images[j], im_path)

            curr_ind += 1


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--resolution', type=int, choices=[256, 1024], default=256)
    parser.add_argument('--batch_size', type=int, default=50)
    parser.add_argument('--truncation', type=float, default=0.7)

    parser.add_argument('--output_path', required=True)
    parser.add_argument('--pretrained_models_path', type=Path, required=True)

    parser.add_argument('--num_images', type=int, default=10000)
    parser.add_argument('--gpu', default='0')

    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    assert args.num_images % 1e3 == 0

    main(args)
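
# (Added) Example invocation; both paths below are placeholders:
#   python utils/generate_fake_data.py \
#       --output_path /data/fake --pretrained_models_path /data/pretrained \
#       --resolution 256 --num_images 10000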


================================================
FILE: writer.py
================================================
from utils.general_utils import convert_tensor_to_image
from pathlib import Path

import tensorflow as tf


class Writer(object):
    writer = None

    @staticmethod
    def set_writer(results_dir):
        if isinstance(results_dir, str):
            results_dir = Path(results_dir)
        results_dir.mkdir(exist_ok=True, parents=True)
        Writer.writer = tf.summary.create_file_writer(str(results_dir))

    @staticmethod
    def add_scalar(tag, val, step):
        with Writer.writer.as_default():
            tf.summary.scalar(tag, val, step=step)

    @staticmethod
    def add_image(tag, val, step):
        val = convert_tensor_to_image(val)

        if tf.rank(val) == 3:
            val = tf.expand_dims(val, 0)

        with Writer.writer.as_default():
            tf.summary.image(tag, val, step)

    @staticmethod
    def flush():
        with Writer.writer.as_default():
            Writer.writer.flush()
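
# (Added) Usage sketch; 'results/tb' is a placeholder log directory:
#   Writer.set_writer('results/tb')
#   Writer.add_scalar('loss/id_loss', 0.42, step=0)
#   Writer.flush()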


================================================
SYMBOL INDEX (193 symbols across 21 files)
================================================

FILE: arglib/arglib.py
  class BaseArgs (line 9) | class BaseArgs(ABC):
    method __init__ (line 10) | def __init__(self):
    method add_args (line 22) | def add_args(self):
    method parse (line 47) | def parse(self):
    method log (line 50) | def log(self):
    method add_bool_arg (line 58) | def add_bool_arg(parser, name, default=True):
    method validate (line 65) | def validate(self):
    method process (line 70) | def process(self):
  class TrainArgs (line 88) | class TrainArgs(BaseArgs):
    method __init__ (line 89) | def __init__(self):
    method add_args (line 92) | def add_args(self):
    method validate (line 127) | def validate(self):
    method process (line 132) | def process(self):
  class TestArgs (line 147) | class TestArgs(BaseArgs):
    method __init__ (line 148) | def __init__(self):
    method add_args (line 151) | def add_args(self):
    method validate (line 170) | def validate(self):
    method process (line 178) | def process(self):

FILE: data_loader/data_loader.py
  class DataLoader (line 9) | class DataLoader(object):
    method __init__ (line 10) | def __init__(self, args):
    method get_image (line 33) | def get_image(self, is_train, black_list=None, is_real=False):
    method get_w_by_ind (line 72) | def get_w_by_ind(self, ind):
    method get_real_w (line 84) | def get_real_w(self, is_train, black_list=None, is_real=False):
    method batch_samples (line 90) | def batch_samples(self, get_sample_func, is_train, black_list=None, is...
    method get_batch (line 106) | def get_batch(self, is_train=True, is_cross=False, ws=True):

FILE: inference.py
  class Inference (line 10) | class Inference(object):
    method __init__ (line 11) | def __init__(self, args, model):
    method infer_pairs (line 15) | def infer_pairs(self):
    method infer_on_dirs (line 33) | def infer_on_dirs(self):
    method interpolate (line 62) | def interpolate(self, w_space=True):

FILE: main.py
  function init_logger (line 21) | def init_logger(args):
  function main (line 46) | def main():

FILE: model/arcface/arcface.py
  class Arcfacelayer (line 12) | class Arcfacelayer(tf.keras.layers.Layer):
    method __init__ (line 13) | def __init__(self, output_dim=num_classes, s=64., m=0.50):
    method build (line 19) | def build(self, input_shape):
    method call (line 29) | def call(self, embedding, labels):
    method compute_output_shape (line 67) | def compute_output_shape(self, input_shape):

FILE: model/arcface/inference.py
  class MyArcFace (line 10) | class MyArcFace:
    method __init__ (line 11) | def __init__(self, path_to_weights):
    method get_best_face (line 18) | def get_best_face(self, faces, resolution):
    method __detect_face (line 48) | def __detect_face(self, img):
    method __preprocess (line 68) | def __preprocess(self, img, bbox=None, landmark=None):
    method process_image (line 92) | def process_image(self, img):
    method __call__ (line 105) | def __call__(self, img):

FILE: model/arcface/resnet.py
  function residual_unit_v3 (line 9) | def residual_unit_v3(input, num_filter, stride, dim_match, name):
  function get_fc1 (line 83) | def get_fc1(input):
  function ResNet50 (line 113) | def ResNet50():
  class train_model (line 179) | class train_model(tf.keras.Model):
    method __init__ (line 180) | def __init__(self):
    method call (line 185) | def call(self, x, y):

FILE: model/attr_encoder.py
  class AttrEncoder (line 8) | class AttrEncoder(Model):
    method __init__ (line 9) | def __init__(self, args):
    method call (line 21) | def call(self, input_x):
    method my_save (line 29) | def my_save(self, reason=''):

FILE: model/discriminator.py
  class W_D (line 7) | class W_D(Model):
    method __init__ (line 8) | def __init__(self, args):
    method call (line 25) | def call(self, x):
    method my_save (line 38) | def my_save(self, reason=''):

FILE: model/face_detector.py
  class FaceDetector (line 4) | class FaceDetector(object):
    method __init__ (line 5) | def __init__(self, args, model_path):
    method _build (line 11) | def _build(self):
    method __call__ (line 15) | def __call__(self, input_x):
    method sample_call (line 32) | def sample_call(self, input_x):

FILE: model/generator.py
  class G (line 13) | class G(Model):
    method __init__ (line 14) | def __init__(self, args, id_model_path, image_G,
    method call (line 39) | def call(self, x1, x2):
    method my_save (line 60) | def my_save(self, reason=''):

FILE: model/id_encoder.py
  class IDEncoder (line 5) | class IDEncoder(Model):
    method __init__ (line 7) | def __init__(self, args, model_path, intermediate_layers_names=None):
    method crop_faces (line 24) | def crop_faces(self, img):
    method preprocess (line 53) | def preprocess(self, img):
    method call (line 83) | def call(self, input_x, get_intermediate=False):

FILE: model/landmarks.py
  class LandmarksDetector (line 10) | class LandmarksDetector(Model):
    method __init__ (line 11) | def __init__(self, args, model_path, face_detection_model_path):
    method preprocess (line 21) | def preprocess(self, imgs, face_detection=False):
    method lazy_preprocess (line 30) | def lazy_preprocess(self, imgs):
    method hard_preprocess (line 34) | def hard_preprocess(self, imgs):
    method postprocess (line 80) | def postprocess(self, landmarks, details, face_detection=False):
    method lazy_postprocess (line 88) | def lazy_postprocess(self, batch_lnds, details):
    method hard_postprocess (line 92) | def hard_postprocess(self, batch_lnds, details):
    method call (line 105) | def call(self, input_x, face_detection=False):

FILE: model/latent_mapping.py
  class LatentMappingNetwork (line 8) | class LatentMappingNetwork(Model):
    method __init__ (line 9) | def __init__(self, args):
    method call (line 30) | def call(self, x):
    method my_save (line 47) | def my_save(self, reason=''):

FILE: model/model.py
  class Network (line 16) | class Network(Model):
    method __init__ (line 17) | def __init__(self, args, id_net_path, base_generator,
    method call (line 27) | def call(self):
    method my_save (line 30) | def my_save(self, reason):
    method my_load (line 36) | def my_load(self):
    method train (line 39) | def train(self):
    method test (line 42) | def test(self):
    method _set_trainable_behavior (line 45) | def _set_trainable_behavior(self, trainable):

FILE: model/stylegan.py
  function nf (line 15) | def nf(stage, fmap_base=8192, fmap_decay=1.0, fmap_max=512):
  function LeakyReLU (line 18) | def LeakyReLU(alpha, name):
  function GetWeights (line 24) | def GetWeights(gain=math.sqrt(2)):
  function runtime_coef (line 27) | def runtime_coef(kernel_size, gain, fmaps_in, fmaps_out, lrmul=1.0):
  function pixel_norm (line 35) | def pixel_norm(x, epsilon=1e-8):
  class PixelNorm (line 39) | class PixelNorm(Layer):
    method __init__ (line 40) | def __init__(self, name):
    method call (line 43) | def call(self, inputs):
  class InstanceNorm (line 46) | class InstanceNorm(Layer):
    method __init__ (line 47) | def __init__(self, name):
    method call (line 50) | def call(self, x):
  function Identity (line 60) | def Identity(name):
  function Broadcast (line 63) | def Broadcast(name, dlatent_broadcast=18):
  class Truncation (line 68) | class Truncation(Layer):
    method __init__ (line 69) | def __init__(self, name, num_layers=18, truncation_psi=0.7, truncation...
    method build (line 75) | def build(self, input_shape):
    method call (line 78) | def call(self, inputs):
  class DenseLayer (line 87) | class DenseLayer(Dense):
    method __init__ (line 88) | def __init__(self, units, name, kernel_initializer=GetWeights(), gain=...
    method call (line 93) | def call(self, inputs):
  class Conv2d (line 106) | class Conv2d(Conv2D):
    method __init__ (line 107) | def __init__(self, filters, kernel_size, name, gain=math.sqrt(2), lrmu...
    method call (line 115) | def call(self, inputs):
  class Const (line 132) | class Const(Layer):
    method __init__ (line 133) | def __init__(self, name):
    method build (line 136) | def build(self, input_shape):
    method call (line 139) | def call(self, inputs):
  class RandomNoise (line 142) | class RandomNoise(Layer):
    method __init__ (line 143) | def __init__(self, name, layer_idx):
    method build (line 150) | def build(self, input_shape):
    method call (line 153) | def call(self, inputs):
  class ApplyNoise (line 156) | class ApplyNoise(Layer):
    method __init__ (line 157) | def __init__(self, name, is_const_noise):
    method build (line 161) | def build(self, input_shape):
    method call (line 165) | def call(self, inputs):
  class ApplyBias (line 172) | class ApplyBias(Layer):
    method __init__ (line 173) | def __init__(self, name, lrmul=1.0):
    method build (line 177) | def build(self, input_shape):
    method call (line 180) | def call(self, x):
  class StridedSlice (line 185) | class StridedSlice(Layer):
    method __init__ (line 186) | def __init__(self, layer_idx, name):
    method call (line 190) | def call(self, inputs):
  class StyleModApply (line 193) | class StyleModApply(Layer):
    method __init__ (line 194) | def __init__(self, name):
    method call (line 197) | def call(self, inputs):
  function _blur2d (line 203) | def _blur2d(x, f=[1,2,1], normalize=True, flip=False, stride=1):
  function Blur (line 232) | def Blur(name, blur_filter=[1,2,1]):
  function _downscale2d (line 237) | def _downscale2d(x, factor=2, gain=1):
  function _upscale2d (line 259) | def _upscale2d(x, factor=2, gain=1):
  function Downscaled2d (line 278) | def Downscaled2d(name, factor=2, gain=1):
  function Upscaled2d (line 281) | def Upscaled2d(name, factor=2, gain=1):
  function Conv2d_downscale2d (line 284) | def Conv2d_downscale2d(model, filters, kernel_size, name, gain=math.sqrt...
  function Upscale2d_conv2d (line 301) | def Upscale2d_conv2d(x, filters, kernel_size, name, use_bias, gain=math....
  class Conv2d_transpose (line 312) | class Conv2d_transpose(Conv2DTranspose):
    method __init__ (line 313) | def __init__(self, filters, kernel_size, name, gain=math.sqrt(2), lrmu...
    method build (line 321) | def build(self, input_shape):
    method call (line 325) | def call(self, inputs):
  class MinibatchStddevLayer (line 342) | class MinibatchStddevLayer(tf.keras.layers.Layer):
    method __init__ (line 343) | def __init__(self, group_size =4, num_new_features=1):
    method __call__ (line 348) | def __call__(self, x, *args, **kwargs):
  function minibatch_stddev_layer (line 364) | def minibatch_stddev_layer(x, group_size=4, num_new_features=1):
  function StyleGAN_G_mapping (line 379) | def StyleGAN_G_mapping( latent_size=512, dlatent_size=512, mapping_layer...
  function StyleGAN_G_synthesis (line 410) | def StyleGAN_G_synthesis(dlatent_size=512, resolution=1024, is_const_noi...
  class StyleGAN_G (line 471) | class StyleGAN_G(Model):
    method __init__ (line 472) | def __init__(self, resolution=1024, latent_size=512, dlatent_size=512,...
    method call (line 480) | def call(self, inputs):
    method generate_sample (line 485) | def generate_sample(self, seed=5, is_visualize=False):
  class StyleGAN_D (line 503) | class StyleGAN_D(Model):
    method __init__ (line 504) | def __init__(self, resolution=1024, mbstd_group_size=4, mbstd_num_feat...
    method call (line 547) | def call(self, inputs):
  function copy_weights_to_keras_model (line 551) | def copy_weights_to_keras_model(model, all_weights):

FILE: test.py
  function main (line 20) | def main():

FILE: trainer.py
  function id_loss_func (line 10) | def id_loss_func(y_gt, y_pred):
  class Trainer (line 14) | class Trainer(object):
    method __init__ (line 15) | def __init__(self, args, model, data_loader):
    method train (line 67) | def train(self):
    method train_epoch (line 95) | def train_epoch(self):
    method test (line 264) | def test(self):
    method test_reconstruction (line 392) | def test_reconstruction(self, img, errors_dict, display=False, display...
    method generator_gan_loss (line 410) | def generator_gan_loss(self, fake_logit):
    method discriminator_loss (line 417) | def discriminator_loss(self, fake_logit, real_logit):
    method R1_gp (line 432) | def R1_gp(self, D, x):

FILE: utils/general_utils.py
  function read_image (line 15) | def read_image(img_path, resolution, align=False):
  function find_file_by_str (line 24) | def find_file_by_str(search_dir, s):
  function read_SG_image (line 29) | def read_SG_image(img_path, size=256, resize=True):
  function read_and_align_image (line 45) | def read_and_align_image(img_path, output_size=1024):
  function gaussian_image (line 161) | def gaussian_image(size, sigma, dim=2):
  function inverse_gaussian_image (line 179) | def inverse_gaussian_image(size, sigma, dim=2):
  function is_float (line 190) | def is_float(tensor):
  function convert_tensor_to_image (line 198) | def convert_tensor_to_image(tensor):
  function save_image (line 216) | def save_image(img, file_path):
  function mark_landmarks (line 230) | def mark_landmarks(img, lnd, color=None):
  function get_weights (line 250) | def get_weights(slope=0.2):
  function np_permute (line 269) | def np_permute(tensor, permute):

FILE: utils/generate_fake_data.py
  function main (line 17) | def main(args):

FILE: writer.py
  class Writer (line 7) | class Writer(object):
    method set_writer (line 11) | def set_writer(results_dir):
    method add_scalar (line 18) | def add_scalar(tag, val, step):
    method add_image (line 23) | def add_image(tag, val, step):
    method flush (line 33) | def flush():