[
  {
    "path": ".gitignore",
    "content": ".idea\n.png\n.jpg\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2020 YotamNitzan\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# Face Identity Disentanglement via Latent Space Mapping\n\n<p align=\"center\">\n<img src=\"docs/imgs/teaser.png\" width=\"400px\"/>\n</p>\n\n\n## Description   \nOfficial Implementation of the paper *Face Identity Disentanglement via Latent Space Mapping*\nfor both training and evaluation.\n\n> **Face Identity Disentanglement via Latent Space Mapping**<br>\n> Yotam Nitzan<sup>1</sup>, Amit Bermano<sup>1</sup>, Yangyan Li<sup>2</sup>, Daniel Cohen-Or<sup>1</sup><br>\n> <sup>1</sup>Tel-Aviv University, <sup>2</sup>Alibaba <br>\n> https://arxiv.org/abs/2005.07728\n>\n> <p align=\"justify\"><b>Abstract:</b> <i>Learning disentangled representations of data is a fundamental problem in artificial intelligence. Specifically, disentangled latent representations allow generative models to control and compose the disentangled factors in the synthesis process. Current methods, however, require extensive supervision and training, or instead, noticeably compromise quality. In this paper, we present a method that learns how to represent data in a disentangled way, with minimal supervision, manifested solely using available pre-trained networks. Our key insight is to decouple the processes of disentanglement and synthesis, by employing a leading pre-trained unconditional image generator, such as StyleGAN. By learning to map into its latent space, we leverage both its state-of-the-art quality, and its rich and expressive latent space, without the burden of training it. We demonstrate our approach on the complex and high dimensional domain of human heads. We evaluate our method qualitatively and quantitatively, and exhibit its success with de-identification operations and with temporal identity coherency in image sequences. Through extensive experimentation, we show that our method successfully disentangles identity from other facial attributes, surpassing existing methods, even though they require more training and supervision.</i></p>\n\n## Setup\n\nTo setup everything you need check out the [setup instructions](docs/setup.md).\n\n## Training\n\n### Preparing the Dataset\n\nThe dataset is comprised of StyleGAN-generated images and W latent codes, both are generated from a single\nStyleGAN model.\n\nWe also use real images from FFHQ to evaluate quality at test time.\n\nThe dataset is assumed to be in the following structure:\n\n| Path | Description\n| :--- | :---\n| base directory | Directory for all datasets\n| &boxvr;&nbsp; real | FFHQ image dataset\n| &boxvr;&nbsp; dataset_N | dataset for resolution NxN\n| &boxv;&nbsp; &boxvr;&nbsp; images | images generated by StyleGAN\n| &boxv;&nbsp; &boxur;&nbsp; ws | W latent codes generated by StyleGAN\n\nTo generate the `dataset_N` directory, run:\n\n```\ncd utils\\\npython generate_fake_data.py \\ \n    --resolution N \\\n    --batch_size BATCH_SIZE \\\n    --output_path OUTPUT_PATH \\\n    --pretrained_models_path PRETRAINED_MODELS_PATH \\\n    --num_images NUM_IMAGES \\\n    --gpu GPU\n```\n\nIt will generate an image dataset in similar format to FFHQ.\n\n### Start training\n\nTo train the model as done in the paper\n\n```\npython main.py\n    NAME\n    --resolution N\n    --pretrained_models_path PRETRAINED_MODELS_PATH\n    --dataset BASE_DATASET_DIR\n    --batch_size BATCH_SIZE\n    --cross_frequency 3\n    --train_data_size 70000\n    --results_dir RESULTS_DIR        \n```\n\nPlease run `python main.py -h` for more details.\n\n## Inference\n\nFor convenience, there are a few inference functions - each serving a different use case.\nThe functions are resolved using the name of the function.\n\n### All possible combinations in dirs\n\n<p align=\"center\">\n<img src=\"docs/imgs/table_results.jpg\"/>\n</p>\n\n**Input data: Two directories, one identity inputs and another for attribute inputs.** <br>\nRuns over all N*M combinations in two directories.\n\n```\npython test.py \n    Name\n    --pretrained_models_path PRETRAINED_MODELS_PATH \\\n    --load_checkpoint PATH_TO_WEIGHTS \\\n    --id_dir DIR_OF_IMAGES_FOR_ID \\\n    --attr_dir DIR_OF_IMAGES_FOR_ATTR \\\n    --output_dir DIR_FOR_OUTPUTS \\\n    --test_func infer_on_dirs\n```\n\n\n### Paired data\n\n**Input data: Two directories, one identity inputs and another for attribute inputs**. <br>\nThe two directories are assumed to be paired. Inference runs on images with the same names.\n\n```\npython test.py \n    Name\n    --pretrained_models_path PRETRAINED_MODELS_PATH \\\n    --load_checkpoint PATH_TO_WEIGHTS \\\n    --id_dir DIR_OF_IMAGES_FOR_ID \\\n    --attr_dir DIR_OF_IMAGES_FOR_ATTR \\\n    --output_dir DIR_FOR_OUTPUTS \\\n    --test_func infer_pairs\n```\n\n### Disentangled interpolation\n\n#### Interpolating attributes\n\n<p align=\"center\">\n<img src=\"docs/imgs/interpolate_attr.jpg\"/>\n</p>\n\n#### Interpolating identity\n\n<p align=\"center\">\n<img src=\"docs/imgs/interpolate_id.jpg\"/>\n</p>\n\n**Input data: A directory with any number of subdirectories. In each subdir, there are three images.**\nAll images should have exactly one of *attr* or *id* in their name.\nIf there are two *attr* images and one *id* image, it will interpolate attribute.\nIf there is one *attr* images and two *id* images, it will interpolate identity.\n\n\n```\npython test.py \n    Name\n    --pretrained_models_path PRETRAINED_MODELS_PATH \\\n    --load_checkpoint PATH_TO_WEIGHTS \\\n    --input_dir PARENT_DIR \\\n    --output_dir DIR_FOR_OUTPUTS \\\n    --test_func interpolate\n```\n\n## Checkpoints\n\nOur pretrained 256x256 [checkpoint](https://drive.google.com/drive/folders/1lVizq4hCq-zTf8Q3fDqqfSnV6jIYEgY_?usp=sharing) is also available.\n\n## Citation\nIf you use this code for your research, please cite our paper using:\n\n```\n@article{Nitzan2020FaceID,\n  title={Face identity disentanglement via latent space mapping},\n  author={Yotam Nitzan and A. Bermano and Yangyan Li and D. Cohen-Or},\n  journal={ACM Transactions on Graphics (TOG)},\n  year={2020},\n  volume={39},\n  pages={1 - 14}\n}\n```"
  },
  {
    "path": "arglib/__init__.py",
    "content": ""
  },
  {
    "path": "arglib/arglib.py",
    "content": "import math\nimport shutil\nimport logging\nimport argparse\nfrom pathlib import Path\nfrom abc import ABC, abstractmethod\n\n\nclass BaseArgs(ABC):\n    def __init__(self):\n        self.args = None\n        self.parser = argparse.ArgumentParser()\n        self.logger = logging.getLogger(self.__class__.__name__)\n\n        self.add_args()\n        self.parse()\n        self.validate()\n        self.process()\n        self.str_args = self.log()\n\n    @abstractmethod\n    def add_args(self):\n        # Hardware\n        self.parser.add_argument('--gpu', type=str, default='0')\n\n        # Model\n        self.parser.add_argument('--face_detection', action='store_true')\n        self.parser.add_argument('--resolution', type=int, default=256, choices=[256, 1024])\n        self.parser.add_argument('--load_checkpoint')\n        self.parser.add_argument('--pretrained_models_path', type=Path, required=True)\n\n        BaseArgs.add_bool_arg(self.parser, 'const_noise')\n\n        # Data\n        self.parser.add_argument('--batch_size', type=int, default=6)\n        self.parser.add_argument('--reals', action='store_true', help='Use real inputs')\n        BaseArgs.add_bool_arg(self.parser, 'test_real_attr')\n\n        # Log & Results\n        self.parser.add_argument('name', type=str, help='Name under which run will be saved')\n        self.parser.add_argument('--results_dir', type=str, default='../results')\n        self.parser.add_argument('--log_debug', action='store_true')\n\n        # Other\n        self.parser.add_argument('--debug', action='store_true')\n\n    def parse(self):\n        self.args = self.parser.parse_args()\n\n    def log(self):\n        out_str = 'The arguments are:\\n'\n        for k, v in self.args.__dict__.items():\n            out_str += f'{k}: {v}\\n'\n\n        return out_str\n\n    @staticmethod\n    def add_bool_arg(parser, name, default=True):\n        group = parser.add_mutually_exclusive_group(required=False)\n        group.add_argument('--' + name, dest=name, action='store_true')\n        group.add_argument('--no_' + name, dest=name, action='store_false')\n        parser.set_defaults(**{name: default})\n\n    @abstractmethod\n    def validate(self):\n        if self.args.load_checkpoint and not Path(self.args.load_checkpoint).exists():\n            raise ValueError(f'Checkpoint directory {self.args.load_checkpoint} does not exist')\n\n    @abstractmethod\n    def process(self):\n        # Log & Results\n        self.args.results_dir = Path(self.args.results_dir).joinpath(self.args.name)\n\n        if self.args.debug:\n            self.args.log_debug = True\n        if self.args.debug or not self.args.train:\n            shutil.rmtree(self.args.results_dir, ignore_errors=True)\n\n        self.args.results_dir.mkdir(parents=True, exist_ok=True)\n        self.args.images_results = self.args.results_dir.joinpath('images')\n        self.args.images_results.mkdir(exist_ok=True)\n\n        # Model\n        if self.args.load_checkpoint:\n            self.args.load_checkpoint = Path(self.args.load_checkpoint)\n\n\nclass TrainArgs(BaseArgs):\n    def __init__(self):\n        super().__init__()\n\n    def add_args(self):\n        super().add_args()\n\n        self.parser.add_argument('--dataset_path', type=str, default='../my_dataset')\n\n        self.parser.add_argument('--num_epochs', type=int, default=math.inf)\n        self.parser.add_argument('--cross_frequency', type=int, default=3,\n                                 help='Once in how many epochs to perform cross-train epoch (0 for never)')\n\n        self.parser.add_argument('--unified', action='store_true')\n\n        # Data\n        BaseArgs.add_bool_arg(self.parser, 'train_real_attr', default=False)\n        self.parser.add_argument('--train_data_size', type=int, default=70000,\n                                 help='How many images to use for training. Others are used as validation')\n\n        # Losses\n        BaseArgs.add_bool_arg(self.parser, 'id_loss')\n        BaseArgs.add_bool_arg(self.parser, 'landmarks_loss')\n        BaseArgs.add_bool_arg(self.parser, 'pixel_loss')\n        BaseArgs.add_bool_arg(self.parser, 'W_D_loss')\n        BaseArgs.add_bool_arg(self.parser, 'gp')\n\n        self.parser.add_argument('--pixel_mask_type', choices=['uniform', 'gaussian'], default='gaussian')\n        self.parser.add_argument('--pixel_loss_type', choices=['L1', 'mix'], default='mix')\n\n        # Test During training\n        self.parser.add_argument('--test_frequency', type=int, default=1000,\n                                 help='Once in how many epochs to perform a test')\n        self.parser.add_argument('--test_size', type=int, default=50,\n                                 help='How many mini-batches should be used for a test')\n        self.parser.add_argument('--not_improved_exit', type=int, default=math.inf,\n                                 help='After how many not-improved test to exit')\n        BaseArgs.add_bool_arg(self.parser, 'test_with_arcface')\n\n    def validate(self):\n        super().validate()\n        if not Path(self.args.dataset_path).exists():\n            raise ValueError(f'Dataset at path: {self.args.dataset_path} does not exist')\n\n    def process(self):\n        self.args.train = True\n\n        super().process()\n\n        # Dataset\n        self.args.dataset_path = Path(self.args.dataset_path)\n\n        self.args.weights_dir = self.args.results_dir.joinpath('weights')\n        self.args.weights_dir.mkdir(exist_ok=True)\n        backup_code_dir = self.args.results_dir.joinpath('code')\n        code_dir = Path().cwd()\n        shutil.copytree(code_dir, backup_code_dir)\n\n\nclass TestArgs(BaseArgs):\n    def __init__(self):\n        super().__init__()\n\n    def add_args(self):\n        super().add_args()\n        self.parser.set_defaults(batch_size=1)\n\n        self.parser.add_argument('--id_dir', type=Path)\n        self.parser.add_argument('--attr_dir', type=Path)\n        self.parser.add_argument('--output_dir', type=Path)\n        self.parser.add_argument('--input_dir', type=Path)\n\n        self.parser.add_argument('--real_id', action='store_true')\n        self.parser.add_argument('--real_attr', action='store_true')\n        BaseArgs.add_bool_arg(self.parser, 'loop_fake')\n\n        self.parser.add_argument('--img_suffixes', type=list, default=['png', 'jpg', 'jpeg'])\n\n        self.parser.add_argument('--test_func', type=str, choices=['infer_on_dirs', 'infer_pairs', 'interpolate'])\n\n        self.parser.add_argument('--input', type=str)\n\n    def validate(self):\n        super().validate()\n\n        # if not self.args.input:\n        #     raise ValueError('Input needed for inference')\n        # if not Path(self.args.input).exists():\n        #     raise ValueError(f'Input {self.args.input} does not exist')\n\n    def process(self):\n        self.args.train = False\n\n        super().process()\n\n        self.args.output_dir.mkdir(exist_ok=True, parents=True)\n\n        # self.args.input = Path(self.args.input)\n        # Split frame sit alongside input, so not every run needs to preprocess\n        # input_name = self.args.input.stem\n        # self.args.extracted_frames_dir = self.args.input.parent.joinpath(f'{input_name}_frames')\n        # self.args.extracted_frames_dir.mkdir(exist_ok=True)\n"
  },
  {
    "path": "data_loader/__init__.py",
    "content": ""
  },
  {
    "path": "data_loader/data_loader.py",
    "content": "import logging\n\nimport numpy as np\nimport tensorflow as tf\n\nfrom utils.general_utils import read_image\n\n\nclass DataLoader(object):\n    def __init__(self, args):\n        super().__init__()\n        self.args = args\n        self.logger = logging.getLogger(self.__class__.__name__)\n\n        self.real_dataset = args.dataset_path.joinpath(f'real')\n\n        dataset = args.dataset_path.joinpath(f'dataset_{args.resolution}')\n\n        self.ws_dataset = dataset.joinpath('ws')\n        self.image_dataset = dataset.joinpath('images')\n\n        max_dir = max([x.name for x in self.image_dataset.iterdir()])\n        self.max_ind = max([int(x.stem) for x in self.image_dataset.joinpath(max_dir).iterdir()])\n        self.train_max_ind = args.train_data_size\n\n        if self.train_max_ind >= self.max_ind:\n            self.logger.warning('There is no validation data... using training data')\n            self.min_val_ind = 0\n            self.train_max_ind = self.max_ind\n        else:\n            self.min_val_ind = self.train_max_ind + 1\n\n    def get_image(self, is_train, black_list=None, is_real=False):\n        # Default should be non-mutable\n        if black_list is None:\n            black_list = []\n\n        max_fails = 10\n        curr_fail = 0\n        if is_train:\n            min_ind, max_ind = 0, self.train_max_ind\n        else:\n            min_ind, max_ind = self.min_val_ind, self.max_ind\n\n        while True:\n            ind = np.random.randint(min_ind, max_ind)\n\n            if ind in black_list:\n                continue\n\n            img_name = f'{ind:05d}.png'\n            dir_name = f'{int(ind - ind % 1e3):05d}'\n            if is_real:\n                img_path = self.real_dataset.joinpath(dir_name, img_name)\n            else:\n                img_path = self.image_dataset.joinpath(dir_name, img_name)\n\n            try:\n                img = read_image(img_path, self.args.resolution)\n                break\n            except Exception as e:\n                self.logger.warning(f'Failed reading image at {ind}. Error: {e}')\n\n                # Try again with a different image...\n                curr_fail += 1\n                if curr_fail > max_fails:\n                    raise IOError('Failed reading multiples images')\n                continue\n\n        return ind, img\n\n    def get_w_by_ind(self, ind):\n        dir_name = f'{int(ind - ind % 1e3):05d}'\n        img_name = f'{ind:05d}.npy'\n        w_path = self.ws_dataset.joinpath(dir_name, img_name)\n\n        w = np.load(w_path)\n\n        # Take one row while keeping dimension\n        w = w[np.newaxis, 0]\n\n        return w\n\n    def get_real_w(self, is_train, black_list=None, is_real=False):\n        ind = np.random.randint(0, self.max_ind)\n        w = self.get_w_by_ind(ind)\n\n        return ind, w\n\n    def batch_samples(self, get_sample_func, is_train, black_list=None, is_real=False):\n        batch = []\n        indices = []\n\n        if not black_list:\n            black_list = []\n        for i in range(self.args.batch_size):\n            ind, sample = get_sample_func(is_train, black_list, is_real)\n\n            batch.append(sample)\n            indices.append(ind)\n\n        batch = tf.concat(batch, 0)\n\n        return indices, batch\n\n    def get_batch(self, is_train=True, is_cross=False, ws=True):\n        black_list = []\n        id_imgs_indices, id_img = self.batch_samples(self.get_image, is_train)\n        matching_ws = None\n\n        self.logger.debug(f'ID images read: {id_imgs_indices}')\n        black_list.extend(id_imgs_indices)\n\n        if is_cross:\n            # Use real attr when args say so or when testing\n            is_real_attr = (is_train and self.args.train_real_attr) or (not is_train and self.args.test_real_attr)\n            black_list = [] if is_real_attr else black_list\n\n            attr_imgs_indices, attr_img = self.batch_samples(self.get_image,\n                                                             is_train,\n                                                             black_list=black_list,\n                                                             is_real=is_real_attr)\n\n            self.logger.debug(f'Attr images read: {attr_imgs_indices}')\n\n        else:\n            if is_train:\n                attr_img = id_img\n                matching_ws = [self.get_w_by_ind(ind) for ind in id_imgs_indices]\n                matching_ws = tf.concat(matching_ws, 0)\n            else:\n                attr_img = id_img\n\n        if not is_train:\n            return attr_img, id_img\n\n        # Only for training\n        real_img = None\n        real_ws = None\n\n        if self.args.train and self.args.reals:\n            real_imgs_indices, real_img = self.batch_samples(self.get_image, is_train, black_list=[], is_real=True)\n            self.logger.debug(f'Real images read: {real_imgs_indices}')\n\n        if ws:\n            _, real_ws = self.batch_samples(self.get_real_w, is_train)\n\n        return attr_img, id_img, real_ws, real_img, matching_ws\n\n"
  },
  {
    "path": "docs/index.html",
    "content": "<html>\n<head>\n    <meta charset=\"utf-8\">\n    <title>ID disentanglement</title>\n\n    <!-- CSS includes -->\n    <link rel=\"stylesheet\" href=\"https://use.fontawesome.com/releases/v5.8.1/css/all.css\"\n          integrity=\"sha384-50oBUHEmvpQ+1lW4y57PTFmhCaXp0ML5d60M1M7uH2+nqUivzIebhndOJK28anvf\" crossorigin=\"anonymous\">\n    <link href=\"https://maxcdn.bootstrapcdn.com/bootstrap/3.3.5/css/bootstrap.min.css\" rel=\"stylesheet\">\n    <link href=\"mainpage.css\" rel=\"stylesheet\">\n</head>\n<body>\n\n<div class=\"container-fluid\">\n    <div class=\"row\">\n        <h1><span style=\"font-size:36px\">Face Identity Disentanglement via Latent Space Mapping</span></h1>\n        <h1><span style=\"font-size:22px\">SIGGRAPH ASIA 2020</span></h1>\n\n        <div class=\"authors\">\n            <span style=\"font-size:18px\"><a href=\"https://yotamnitzan.github.io/\" target=\"new\">Yotam Nitzan<sup>1</sup></a></span>\n            &nbsp;\n            <span style=\"font-size:18px\"><a href=\"https://www.cs.tau.ac.il/~amberman/\"\n                                            target=\"new\">Amit Bermano<sup>1</sup></a></span>\n            &nbsp;\n            <span style=\"font-size:18px\"><a href=\"http://yangyan.li/\" target=\"new\">Yangyan Li<sup>2</sup></a></span>\n            &nbsp;\n            <span style=\"font-size:18px\"><a href=\"https://danielcohenor.com/\"\n                                            target=\"new\">Daniel Cohen-Or<sup>1</sup></a></span>\n            <br>\n            <span style=\"font-size:18px\"><sup>1</sup>Tel-Aviv University &nbsp;&nbsp;&nbsp; <sup>2</sup>Alibaba Cloud Intelligence Business Group<br><br></span>\n        </div>\n    </div>\n\n    <div class=\"row\" style=\"text-align:center;padding:0;margin:0\">\n        <div class=\"container\">\n            <img src=\"imgs/teaser.png\" height=\"650px\">\n        </div>\n    </div>\n\n    <div class=\"container\">\n\n        <div class=\"row\">\n            <div class=\"col-lg-1 col-md-0 col-sm-0\"></div>\n            <div class=\"col-lg-1 col-md-0 col-sm-0\"></div>\n\n            <div class=\"col-lg-3 col-md-4 col-sm-4 text-center\">\n                <div class=\"service-box mt-5 mx-auto\">\n                    <a href=\"https://arxiv.org/abs/2005.07728\" target=\"_blank\">\n                        <i class=\"far fa-4x fa-file text-primary mb-3 \"></i>\n                    </a>\n                    <h3 class=\"mb-3\">Paper</h3>\n                </div>\n            </div>\n\n            <div class=\"col-lg-1 col-md-0 col-sm-0\"></div>\n            <div class=\"col-lg-1 col-md-0 col-sm-0\"></div>\n\n            <div class=\"col-lg-2 col-md-4 col-sm-6 text-center\">\n                <div class=\"service-box mt-5 mx-auto\">\n                    <a href=\"https://github.com/YotamNitzan/ID-disentanglement\" target=\"_blank\">\n                        <i class=\"fab fa-4x fa-github text-primary mb-3 \"></i>\n                    </a>\n                    <h3 class=\"mb-3\">Code</h3>\n                </div>\n            </div>\n\n        </div>\n    </div>\n\n    <div class=\"container\">\n        <h2>Abstract</h2>\n        Learning disentangled representations of data is a fundamental problem in\n        artificial intelligence. Specifically, disentangled latent representations allow\n        generative models to control and compose the disentangled factors in the\n        synthesis process. Current methods, however, require extensive supervision\n        and training, or instead, noticeably compromise quality.\n        In this paper, we present a method that learns how to represent data\n        in a disentangled way, with minimal supervision, manifested solely using\n        available pre-trained networks. Our key insight is to decouple the processes\n        of disentanglement and synthesis, by employing a leading pre-trained unconditional image generator, such as\n        StyleGAN. By learning to map into its\n        latent space, we leverage both its state-of-the-art quality, and its rich and\n        expressive latent space, without the burden of training it.\n        We demonstrate our approach on the complex and high dimensional\n        domain of human heads. We evaluate our method qualitatively and quantitatively, and exhibit its success with\n        de-identification operations and with\n        temporal identity coherency in image sequences. Through extensive experimentation, we show that our method\n        successfully disentangles identity\n        from other facial attributes, surpassing existing methods, even though they\n        require more training and supervision.\n    </div>\n\n    <div class=\"container\">\n        <h2>motivation</h2>\n        learning disentangled representations and image synthesis are different tasks.\n        however, it is a common practice to solve both simultaneously.\n        this way, the image generator learns the semantics of the representations.\n        now it is able to take multiple representations from different sources and mix them to generate novel images.\n        but this comes at a price, one now needs to solve two difficult tasks simultaneously.\n        this often causes the need to devise dedicated architectures and even then, achieve sub-optimal visual quality.\n        <br><br>\n        we propose a different approach. unconditional generators have recently achieved amazing image quality.\n        we take advantage of this fact, and avoid solving this task ourselves. instead, we suggest to use a pretrained\n        generator, such as stylegan. but now, how can the generator, which is pretrained & unconditional, make sense of\n        the disentangled representations?\n        <br>\n        we suggest mapping the disentangled representations directly into the latent space of the generator.\n        the mapping produces in a single feed-forward a new, never before seen, latent code that corresponds to novel\n        images.\n\n        <div class=\"row\" style=\"text-align:center;padding:0;margin:0\">\n            <img src=\"imgs/architecture.jpg\" height=\"512px\">\n        </div>\n\n\n    </div>\n\n    <div class=\"container\">\n        <h2>Composition Results</h2>\n\n        We demonstrate our method on the domain of human faces - specifically disentangling identity from all other\n        attributes.\n        <br>\n        In the following tables the identity is taken from the image on top and the attributes are taken from the\n        left most image.\n        In this figure, the inputs themselves are StyleGAN generated images.\n        <div class=\"row\" style=\"text-align:center;padding:0;margin:0\">\n            <img src=\"imgs/table_results.jpg\" height=\"850px\">\n        </div>\n        <div class=\"space\"></div>\n\n        More results, but this time, the input images are real.\n        <div class=\"row\" style=\"text-align:center;padding:0;margin:0\">\n            <img src=\"imgs/ffhq_table_results.jpg\" height=\"850px\">\n        </div>\n    </div>\n\n    <div class=\"container\">\n        <h2>Disentangled Interpolation</h2>\n\n        Thanks to our disentangled representations, we are able to interpolate only a single feature\n        (identity or attributes) in the generator's latent space.\n        This enables more control and opens the door for new disentangled editing capabilities.\n        <br><br>\n        <div class=\"row\" style=\"text-align:center;padding:0;margin:0\">\n            <img src=\"imgs/interpolate_attr.jpg\" width=\"1024px\">\n        </div>\n        <div class=\"space\"></div>\n        <div class=\"row\" style=\"text-align:center;padding:0;margin:0\">\n            <img src=\"imgs/interpolate_id.jpg\" width=\"1024px\">\n        </div>\n\n    </div>\n\n    <div class=\"container\">\n        <h2>Contact</h2>\n        <div>\n            yotamnitzan at gmail dot com\n        </div>\n    </div>\n\n    <div id=\"footer\">\n    </div>\n\n\n</body>\n</html>\n"
  },
  {
    "path": "docs/mainpage.css",
    "content": "body {\n  font-family: 'Lato', sans-serif;\n  font-weight: 300;\n  color: #333;\n  font-size: 16px;\n}\nh1 {\n  font-size: 40px;\n  color: #555;\n  font-weight: 400;\n  text-align: center;\n  margin: 0;\n  padding: 0;\n  margin-top: 30px;\n  margin-bottom: 10px;\n}\n.authors {\n  color: #222;\n  font-size: 24px;\n  font-weight: 300;\n  text-align: center;\n  margin: 0;\n  padding: 0;\n  margin-bottom: 0px;\n}\n.logoimg {\n  text-align: center;\n  margin-bottom: 30px;\n}\n.container-fluid {\n  margin-top: 5px;\n  margin-bottom: 5px;\n}\n.container {\n  margin-top: 10px;\n}\n#footer {\n  margin-bottom: 100px;\n}\n.thumbs {\n  -webkit-box-shadow: 1px 1px 3px #999;\n  -moz-box-shadow: 1px 1px 3px #999;\n  box-shadow: 1px 1px 3px #999;\n  margin-bottom: 20px;\n}\nh2 {\n  font-size: 24px;\n  font-weight: 900;\n  border-bottom: 1px solid #999;\n  margin-bottom: 20px;\n}\n\n.space {\n   margin-bottom: 1.5cm;\n}\n\n\n.text-primary {\n  color: #5da2d5 !important;\n}\n.text-primary:hover {\n  color: #f3d250  !important;\n  opacity: 1.0;\n}"
  },
  {
    "path": "docs/setup.md",
    "content": "# Setup\n\n## Environment\n\nIt's designed to use Tensorflow 2.X on python (3.7), using cuda 10.1 and cudnn 7.6.5.\nRun `conda create -n environment.yml` to create a conda environment that has the needed dependencies.\n\nTested with Tensorflow 2.0.0, Python 3.7.9, Ubuntu 14.04. \n\n\n## Third-party pretrained networks\n\nOur method relies on several pretrained networks.\nSome are needed only for training and some also for inference.\nDownload according to your intention.\n\nPut all downloaded files/directories under a single directory, which will\nbe the baseline path for all pretrained networks.\n\n| Name | Training | Inference |Description\n| :--- | :----------:| :----------:| :----------\n|[FFHQ StyleGAN 256x256](https://drive.google.com/drive/folders/1OgLvUhd9FX9_mPXrfqAWaLZsceQzE9l4?usp=sharing) | :heavy_check_mark: | :heavy_check_mark:  | StyleGAN model pretrained on FFHQ with 256x256 resolution. Converted using [StyleGAN-Tensorflow2](https://github.com/YotamNitzan/StyleGAN-Tensorflow2)\n|[FFHQ StyleGAN 1024x1024](https://drive.google.com/drive/folders/1jQxJsmapu6SjygvJfvP4-YVxZ9f5Hu_N?usp=sharing) | :heavy_check_mark: | :heavy_check_mark:  | StyleGAN model pretrained on FFHQ with 1024x1024 resolution. Converted using [StyleGAN-Tensorflow2](https://github.com/YotamNitzan/StyleGAN-Tensorflow2)\n|[VGGFace2](https://drive.google.com/file/d/1I_JyR7LH-30hEIpD4OSFVg2TOf9Q8cqU/view?usp=sharing) | :heavy_check_mark: | :heavy_check_mark:  | Pretrained VGGFace2 model taken from [WeidiXie](https://github.com/WeidiXie/Keras-VGGFace2-ResNet50).\n|[dlib landmarks model](http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2) |  | :heavy_check_mark: | dlib landmarks model, used to align images.\n|[ArcFace](https://drive.google.com/drive/folders/1F-Ll9Nw7I1FGP61cpQxOdhs2nxi0E5mg?usp=sharing) | :heavy_check_mark: |   | Pretrained ArcFace model taken from [dmonterom](https://github.com/dmonterom/face_recognition_TF2).\n|[Face & Landmarks Detection](https://drive.google.com/drive/folders/1D__J9UMwzBNR9eVrQGYuL9ueYGi7G4qh?usp=sharing) | :heavy_check_mark: |   | Pretrained face detection and differentiable facial landmarks detection from [610265158](https://github.com/610265158/face_landmark).\n\n\n### Other StyleGANs\n\nTo try out our method with other checkpoints of StyleGAN, first obtain a trained StyleGAN pkl file using the [original StyleGAN repository](https://github.com/NVlabs/stylegan)  \nNext, convert it to Tensorflow-2.0 using this [repository](https://github.com/YotamNitzan/StyleGAN-Tensorflow2).\n\n\n\n"
  },
  {
    "path": "environment.yml",
    "content": "name: id_disen\nchannels:\n  - conda-forge\n  - defaults\ndependencies:\n  - _libgcc_mutex=0.1=main\n  - _tflow_select=2.1.0=gpu\n  - absl-py=0.11.0=py37h06a4308_0\n  - aiohttp=3.6.3=py37h7b6447c_0\n  - astor=0.8.1=py37_0\n  - async-timeout=3.0.1=py37_0\n  - attrs=20.3.0=pyhd3eb1b0_0\n  - blas=1.0=mkl\n  - blinker=1.4=py37_0\n  - blosc=1.20.1=he1b5a44_0\n  - brotli=1.0.9=he1b5a44_3\n  - brotlipy=0.7.0=py37h27cfd23_1003\n  - bzip2=1.0.8=h516909a_3\n  - c-ares=1.17.1=h27cfd23_0\n  - ca-certificates=2020.11.8=ha878542_0\n  - cachetools=4.1.1=py_0\n  - certifi=2020.11.8=py37h89c1867_0\n  - cffi=1.14.3=py37h261ae71_2\n  - chardet=3.0.4=py37h06a4308_1003\n  - charls=2.1.0=he1b5a44_2\n  - click=7.1.2=py_0\n  - cloudpickle=1.6.0=py_0\n  - cryptography=3.2.1=py37h3c74f83_1\n  - cudatoolkit=10.0.130=0\n  - cudnn=7.6.5=cuda10.0_0\n  - cupti=10.0.130=0\n  - cycler=0.10.0=py_2\n  - cytoolz=0.11.0=py37h8f50634_1\n  - dask-core=2.30.0=py_0\n  - decorator=4.4.2=py_0\n  - freetype=2.10.4=h7ca028e_0\n  - gast=0.2.2=py37_0\n  - giflib=5.2.1=h36c2ea0_2\n  - google-auth=1.23.0=pyhd3eb1b0_0\n  - google-auth-oauthlib=0.4.2=pyhd3eb1b0_2\n  - google-pasta=0.2.0=py_0\n  - grpcio=1.31.0=py37hf8bcb03_0\n  - h5py=2.10.0=py37hd6299e0_1\n  - hdf5=1.10.6=hb1b8bf9_0\n  - idna=2.10=py_0\n  - imagecodecs=2020.5.30=py37hda6ee5b_1\n  - imageio=2.9.0=py_0\n  - importlib-metadata=2.0.0=py_1\n  - intel-openmp=2020.2=254\n  - jpeg=9d=h36c2ea0_0\n  - jxrlib=1.1=h516909a_2\n  - keras-applications=1.0.8=py_1\n  - keras-preprocessing=1.1.0=py_1\n  - kiwisolver=1.3.1=py37hc928c03_0\n  - lcms2=2.11=hcbb858e_1\n  - ld_impl_linux-64=2.33.1=h53a641e_7\n  - libaec=1.0.4=he1b5a44_1\n  - libedit=3.1.20191231=h14c3975_1\n  - libffi=3.3=he6710b0_2\n  - libgcc-ng=9.1.0=hdf63c60_0\n  - libgfortran-ng=7.3.0=hdf63c60_0\n  - libpng=1.6.37=h21135ba_2\n  - libprotobuf=3.13.0.1=hd408876_0\n  - libstdcxx-ng=9.1.0=hdf63c60_0\n  - libtiff=4.1.0=h4f3a223_6\n  - libwebp-base=1.1.0=h36c2ea0_3\n  - libzopfli=1.0.3=he1b5a44_0\n  - lz4-c=1.9.2=he1b5a44_3\n  - markdown=3.3.3=py37h06a4308_0\n  - matplotlib-base=3.3.3=py37h4f6019d_0\n  - mkl=2020.2=256\n  - mkl-service=2.3.0=py37he904b0f_0\n  - mkl_fft=1.2.0=py37h23d657b_0\n  - mkl_random=1.1.1=py37h0573a6f_0\n  - multidict=4.7.6=py37h7b6447c_1\n  - ncurses=6.2=he6710b0_1\n  - networkx=2.5=py_0\n  - numpy=1.19.2=py37h54aff64_0\n  - numpy-base=1.19.2=py37hfa32c7d_0\n  - oauthlib=3.1.0=py_0\n  - olefile=0.46=pyh9f0ad1d_1\n  - openjpeg=2.3.1=h981e76c_3\n  - openssl=1.1.1h=h516909a_0\n  - opt_einsum=3.1.0=py_0\n  - pillow=8.0.1=py37h63a5d19_0\n  - pip=20.2.4=py37h06a4308_0\n  - protobuf=3.13.0.1=py37he6710b0_1\n  - pyasn1=0.4.8=py_0\n  - pyasn1-modules=0.2.8=py_0\n  - pycparser=2.20=py_2\n  - pyjwt=1.7.1=py37_0\n  - pyopenssl=19.1.0=pyhd3eb1b0_1\n  - pyparsing=2.4.7=pyh9f0ad1d_0\n  - pysocks=1.7.1=py37_1\n  - python=3.7.9=h7579374_0\n  - python-dateutil=2.8.1=py_0\n  - python_abi=3.7=1_cp37m\n  - pywavelets=1.1.1=py37h161383b_3\n  - readline=8.0=h7b6447c_0\n  - requests=2.24.0=py_0\n  - requests-oauthlib=1.3.0=py_0\n  - rsa=4.6=py_0\n  - scikit-image=0.17.2=py37h10a2094_4\n  - scipy=1.5.2=py37h0b6359f_0\n  - setuptools=50.3.1=py37h06a4308_1\n  - six=1.15.0=py37h06a4308_0\n  - snappy=1.1.8=he1b5a44_3\n  - sqlite=3.33.0=h62c20be_0\n  - tensorboard-plugin-wit=1.6.0=py_0\n  - tensorflow=2.0.0=gpu_py37h768510d_0\n  - tensorflow-base=2.0.0=gpu_py37h0ec5d1f_0\n  - tensorflow-estimator=2.0.0=pyh2649769_0\n  - termcolor=1.1.0=py37_1\n  - tifffile=2020.11.18=pyhd8ed1ab_0\n  - tk=8.6.10=hbc83047_0\n  - toolz=0.11.1=py_0\n  - tornado=6.1=py37h4abf009_0\n  - urllib3=1.25.11=py_0\n  - werkzeug=0.16.1=py_0\n  - wheel=0.35.1=pyhd3eb1b0_0\n  - wrapt=1.12.1=py37h7b6447c_1\n  - xz=5.2.5=h7b6447c_0\n  - yaml=0.2.5=h516909a_0\n  - yarl=1.6.2=py37h7b6447c_0\n  - zipp=3.4.0=pyhd3eb1b0_0\n  - zlib=1.2.11=h7b6447c_3\n  - zstd=1.4.5=h6597ccf_2\n  - pip:\n    - dlib==19.21.0\n    - keras==2.3.1\n    - mtcnn==0.1.0\n    - opencv-python==4.4.0.46\n    - pyyaml==5.3.1\n    - tensorboard==2.0.2\n    - tensorflow-addons==0.6.0\n    - tensorflow-gpu==2.0.0\n    - tqdm==4.53.0\n"
  },
  {
    "path": "inference.py",
    "content": "from pathlib import Path\n\nfrom tqdm import tqdm\nimport tensorflow as tf\n\nfrom writer import Writer\nfrom utils import general_utils as utils\n\n\nclass Inference(object):\n    def __init__(self, args, model):\n        self.args = args\n        self.G = model.G\n\n    def infer_pairs(self):\n        names = [f for f in self.args.id_dir.iterdir() if f.suffix[1:] in self.args.img_suffixes]\n        names.extend([f for f in self.args.attr_dir.iterdir() if f.suffix[1:] in self.args.img_suffixes])\n\n        for img_name in tqdm(names):\n            id_path = utils.find_file_by_str(self.args.id_dir, img_name.stem)\n            attr_path = utils.find_file_by_str(self.args.attr_dir, img_name.stem)\n            if len(id_path) != 1 or len(attr_path) != 1:\n                print(f'Could not find a single pair with name: {img_name.stem}')\n                continue\n\n            id_img = utils.read_image(id_path, self.args.resolution, self.args.reals)\n            attr_img = utils.read_image(attr_path, self.args.resolution, self.args.reals)\n\n            out_img = self.G(id_img, attr_img)[0]\n\n            utils.save_image(out_img, self.args.output_dir.joinpath(f'{img_name.name}'))\n\n    def infer_on_dirs(self):\n        attr_paths = list(self.args.attr_dir.iterdir())\n        attr_paths.sort()\n\n        id_paths = list(self.args.id_dir.iterdir())\n        id_paths.sort()\n\n        for attr_num, attr_img_path in tqdm(enumerate(attr_paths)):\n            if not attr_img_path.is_file() or attr_img_path.suffix[1:] not in self.args.img_suffixes:\n                continue\n\n            attr_img = utils.read_image(attr_img_path, self.args.resolution, self.args.reals)\n\n            attr_dir = self.args.output_dir.joinpath(f'attr_{attr_num}')\n            attr_dir.mkdir(exist_ok=True)\n\n            utils.save_image(attr_img, attr_dir.joinpath(f'attr_image.png'))\n\n            for id_num, id_img_path in enumerate(id_paths):\n                if not id_img_path.is_file() or id_img_path.suffix[1:] not in self.args.img_suffixes:\n                    continue\n\n                id_img = utils.read_image(id_img_path, self.args.resolution, self.args.reals)\n\n                pred = self.G(id_img, attr_img)[0]\n\n                utils.save_image(pred, attr_dir.joinpath(f'prediction_{id_num}.png'))\n                utils.save_image(id_img, attr_dir.joinpath(f'id_{id_num}.png'))\n\n    def interpolate(self, w_space=True):\n        # Change to 0,1 for interpolation\n        extra_start = 0\n        extra_end = 1\n        L = extra_end - extra_start\n        # Extrapolation values include the 0,1 iff\n        #   N-1 is divisible by L if including endpoint\n        #   N is divisble by L o.w\n        #   where L is the length of the extrapolation range ( L = b-a for [a,b] )\n        #   and N is number of jumps\n        num_jumps = 8 * L + 1\n\n        for d in self.args.input_dir.iterdir():\n            out_d = self.args.output_dir.joinpath(d.name)\n            out_d.mkdir(exist_ok=True)\n\n            ids = list(d.glob('*id*'))\n            attrs = list(d.glob('*attr*'))\n\n            if len(ids) == 1 and len(attrs) == 2:\n                const = 'id'\n            elif len(ids) == 2 and len(attrs) == 1:\n                const = 'attr'\n            else:\n                print(f'Wrong data format for {d.name}')\n                continue\n\n            if const == 'id':\n                start_img = utils.read_image(attrs[0], self.args.resolution, self.args.real_attr)\n                end_img = utils.read_image(attrs[1], self.args.resolution, self.args.real_attr)\n                const_img = utils.read_image(ids[0], self.args.resolution, self.args.real_id)\n\n                if self.args.loop_fake:\n                    if not self.args.real_attr:\n                        start_img = self.G(start_img, start_img)\n                        end_img = self.G(end_img, end_img)\n                    if not self.args.real_id:\n                        const_img = self.G(const_img, const_img)\n\n                const_id = self.G.id_encoder(const_img)\n                start_attr = self.G.attr_encoder(start_img)\n                end_attr = self.G.attr_encoder(end_img)\n\n                s_z = tf.concat([const_id, start_attr], -1)\n                e_z = tf.concat([const_id, end_attr], -1)\n\n            elif const == 'attr':\n                start_img = utils.read_image(ids[0], self.args.resolution, self.args.real_id)\n                end_img = utils.read_image(ids[1], self.args.resolution, self.args.real_id)\n                const_img = utils.read_image(attrs[0], self.args.resolution, self.args.real_attr)\n\n                if self.args.loop_fake:\n                    if not self.args.real_attr:\n                        const_img = self.G(const_img, const_img)[0]\n                    if not self.args.real_id:\n                        start_img = self.G(start_img, start_img)[0]\n                        end_img = self.G(end_img, end_img)[0]\n\n                start_id = self.G.id_encoder(start_img)\n                end_id = self.G.id_encoder(end_img)\n\n                const_attr = self.G.attr_encoder(const_img)\n\n                s_z = tf.concat([start_id, const_attr], -1)\n                e_z = tf.concat([end_id, const_attr], -1)\n\n\n            utils.save_image(const_img, out_d.joinpath(f'const_{const}.png'))\n            utils.save_image(start_img, out_d.joinpath(f'start.png'))\n            utils.save_image(end_img, out_d.joinpath(f'end.png'))\n\n            if w_space:\n                s_w = self.G.latent_spaces_mapping(s_z)\n                e_w = self.G.latent_spaces_mapping(e_z)\n                for i in range(num_jumps):\n                    inter_w = (1 - i / num_jumps) * s_w + (i / num_jumps) * e_w\n                    out = self.G.stylegan_s(inter_w)\n                    out = (out + 1) / 2\n                    utils.save_image(out[0],\n                                     out_d.joinpath(f'inter_{i:03}.png'))\n            else:\n                for i in range(num_jumps):\n                    inter_z = (1 - i / num_jumps) * s_z + (i / num_jumps) * e_z\n                    inter_w = self.G.latent_spaces_mapping(inter_z)\n                    out = self.G.stylegan_s(inter_w)\n                    out = (out + 1) / 2\n                    utils.save_image(out[0],\n                                     out_d.joinpath(f'inter_{i:03}.png'))\n\n"
  },
  {
    "path": "main.py",
    "content": "import os\n\nos.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\nos.environ['OMP_NUM_THREADS'] = '1'\nos.environ['USE_SIMPLE_THREADED_LEVEL3'] = '1'\nos.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'\n\nimport sys\nimport logging\nfrom model.stylegan import StyleGAN_G_synthesis\nfrom model.model import Network\nfrom data_loader.data_loader import DataLoader\nfrom writer import Writer\nfrom trainer import Trainer\nfrom arglib import arglib\nfrom utils import general_utils as utils\n\nsys.path.insert(0, 'model/face_utils')\n\n\ndef init_logger(args):\n    root_logger = logging.getLogger()\n\n    level = logging.DEBUG if args.log_debug else logging.INFO\n    root_logger.setLevel(level)\n\n    file_handler = logging.FileHandler(f'{args.results_dir}/log.txt')\n    console_handler = logging.StreamHandler()\n\n    datefmt = '%Y-%m-%d %H:%M:%S'\n    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt)\n\n    file_handler.setLevel(level)\n    console_handler.setLevel(level)\n\n    file_handler.setFormatter(formatter)\n    console_handler.setFormatter(formatter)\n\n    root_logger.addHandler(file_handler)\n    root_logger.addHandler(console_handler)\n\n    pil_logger = logging.getLogger('PIL.PngImagePlugin')\n    pil_logger.setLevel(logging.INFO)\n\n\ndef main():\n    train_args = arglib.TrainArgs()\n    args, str_args = train_args.args, train_args.str_args\n    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu\n\n    init_logger(args)\n\n    logger = logging.getLogger('main')\n\n    cmd_line = ' '.join(sys.argv)\n    logger.info(f'cmd line is: \\n {cmd_line}')\n\n    logger.info(str_args)\n    logger.debug('Copying src to results dir')\n\n    Writer.set_writer(args.results_dir)\n\n    if not args.debug:\n        description = input('Please write a short description of this run\\n')\n        desc_file = args.results_dir.joinpath('description.txt')\n        with desc_file.open('w') as f:\n            f.write(description)\n\n    id_model_path = args.pretrained_models_path.joinpath('vggface2.h5')\n    stylegan_G_synthesis_path = str(\n        args.pretrained_models_path.joinpath(f'stylegan_G_{args.resolution}x{args.resolution}_synthesis'))\n    landmarks_model_path = str(args.pretrained_models_path.joinpath('face_utils/keypoints'))\n    face_detection_model_path = str(args.pretrained_models_path.joinpath('face_utils/detector'))\n\n    arcface_model_path = str(args.pretrained_models_path.joinpath('arcface_weights/weights-b'))\n    utils.landmarks_model_path = str(args.pretrained_models_path.joinpath('shape_predictor_68_face_landmarks.dat'))\n\n    stylegan_G_synthesis = StyleGAN_G_synthesis(resolution=args.resolution, is_const_noise=args.const_noise)\n    stylegan_G_synthesis.load_weights(stylegan_G_synthesis_path)\n\n    network = Network(args, id_model_path, stylegan_G_synthesis, landmarks_model_path,\n                      face_detection_model_path, arcface_model_path)\n    data_loader = DataLoader(args)\n\n    trainer = Trainer(args, network, data_loader)\n    trainer.train()\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "model/__init__.py",
    "content": ""
  },
  {
    "path": "model/arcface/arcface.py",
    "content": "import tensorflow as tf\nimport math\n\nnum_classes = 85742  # 10572\ninitializer = 'glorot_normal'\n# initializer = tf.keras.initializers.TruncatedNormal(\n#     mean=0.0, stddev=0.05, seed=None)\n# initializer = tf.keras.initializers.VarianceScaling(\n#     scale=0.05, mode='fan_avg', distribution='normal', seed=None)\n\n\nclass Arcfacelayer(tf.keras.layers.Layer):\n    def __init__(self, output_dim=num_classes, s=64., m=0.50):\n        self.output_dim = output_dim\n        self.s = s\n        self.m = m\n        super(Arcfacelayer, self).__init__()\n\n    def build(self, input_shape):\n        self.kernel = self.add_weight(name='kernel',\n                                      shape=(input_shape[-1],\n                                             self.output_dim),\n                                      initializer=initializer,\n                                      regularizer=tf.keras.regularizers.l2(\n                                          l=5e-4),\n                                      trainable=True)\n        super(Arcfacelayer, self).build(input_shape)\n\n    def call(self, embedding, labels):\n        cos_m = math.cos(self.m)\n        sin_m = math.sin(self.m)\n        mm = sin_m * self.m  # issue 1\n        threshold = math.cos(math.pi - self.m)\n        # inputs and weights norm\n        embedding_norm = tf.norm(embedding, axis=1, keepdims=True)\n        embedding = embedding / embedding_norm\n        weights_norm = tf.norm(self.kernel, axis=0, keepdims=True)\n        weights = self.kernel / weights_norm\n        # cos(theta+m)\n        cos_t = tf.matmul(embedding, weights, name='cos_t')\n        cos_t2 = tf.square(cos_t, name='cos_2')\n        sin_t2 = tf.subtract(1., cos_t2, name='sin_2')\n        sin_t = tf.sqrt(sin_t2, name='sin_t')\n        cos_mt = self.s * tf.subtract(tf.multiply(cos_t, cos_m),\n                                      tf.multiply(sin_t, sin_m), name='cos_mt')\n\n        # this condition controls the theta+m should in range [0, pi]\n        #      0<=theta+m<=pi\n        #     -m<=theta<=pi-m\n        cond_v = cos_t - threshold\n        cond = tf.cast(tf.nn.relu(cond_v, name='if_else'), dtype=tf.bool)\n\n        keep_val = self.s * (cos_t - mm)\n        cos_mt_temp = tf.where(cond, cos_mt, keep_val)\n\n        mask = tf.one_hot(labels, depth=self.output_dim, name='one_hot_mask')\n        # mask = tf.squeeze(mask, 1)\n        inv_mask = tf.subtract(1., mask, name='inverse_mask')\n\n        s_cos_t = tf.multiply(self.s, cos_t, name='scalar_cos_t')\n\n        output = tf.add(tf.multiply(s_cos_t, inv_mask), tf.multiply(\n            cos_mt_temp, mask), name='arcface_loss_output')\n\n        return output\n\n    def compute_output_shape(self, input_shape):\n        return (input_shape[0], self.output_dim)\n"
  },
  {
    "path": "model/arcface/inference.py",
    "content": "import tensorflow as tf\nimport tensorflow_addons as tfa\nimport numpy as np\nimport cv2\nfrom model.arcface.resnet import ResNet50, train_model\nfrom mtcnn import MTCNN\nfrom skimage import transform as trans\n\n\nclass MyArcFace:\n    def __init__(self, path_to_weights):\n        self.model = train_model()\n        self.model.load_weights(path_to_weights)\n        self.model_resnet = self.model.resnet\n        self.model.resnet.trainable = False\n        self.mtcnn = MTCNN(min_face_size=80)\n\n    def get_best_face(self, faces, resolution):\n        if len(faces) == 0:\n            raise IndexError('No faces found')\n        if len(faces) == 1:\n            return faces[0]\n\n        print('Found more than one face')\n\n        indices = list(range(len(faces)))\n\n        # filter low confidence\n        new_indices = [ind for ind in indices if faces[ind]['confidence'] > 0.99]\n        # print(f'after confidence filtering: {len(new_indices)}')\n        if len(new_indices) == 1:\n            return faces[new_indices[0]]\n        elif len(new_indices) > 1:\n            indices = new_indices\n\n        # filter not centered, distance between x and y must relatively small\n        new_indices = [ind for ind in indices if np.abs(faces[ind]['box'][0] - faces[ind]['box'][1]) < resolution / 2.5]\n        # print(f'after center filtering: {len(new_indices)}')\n        if len(new_indices) == 1:\n            return faces[new_indices[0]]\n        elif len(new_indices) > 1:\n            indices = new_indices\n\n        # Take box with biggest height\n        ind = max(indices, key=lambda ind: faces[ind]['box'][-1])\n        return faces[ind]\n\n    def __detect_face(self, img):\n        # The assumption is that the image is RGB\n        faces = self.mtcnn.detect_faces(img)\n        face_obj = self.get_best_face(faces, img.shape[0])\n\n        face_box_obj = face_obj['box']\n        face_landmarks_obj = face_obj['keypoints']\n        face_landmarks = np.zeros((5, 2))\n        face_landmarks[0] = [face_landmarks_obj['left_eye'][0], face_landmarks_obj['right_eye'][1]]\n        face_landmarks[1] = [face_landmarks_obj['right_eye'][0], face_landmarks_obj['left_eye'][1]]\n        face_landmarks[2] = [face_landmarks_obj['nose'][0], face_landmarks_obj['nose'][1]]\n        face_landmarks[3] = [face_landmarks_obj['mouth_left'][0], face_landmarks_obj['mouth_right'][1]]\n        face_landmarks[4] = [face_landmarks_obj['mouth_right'][0], face_landmarks_obj['mouth_left'][1]]\n        x = face_box_obj[0]\n        y = face_box_obj[1]\n        w = face_box_obj[2]\n        h = face_box_obj[3]\n        face_box = [x, y, x + w, y + h]\n        return face_box, face_landmarks\n\n    def __preprocess(self, img, bbox=None, landmark=None):\n        M = None\n        image_size = [112, 112]\n        assert landmark is not None\n        src = np.array([\n            [30.2946, 51.6963],\n            [65.5318, 51.5014],\n            [48.0252, 71.7366],\n            [33.5493, 92.3655],\n            [62.7299, 92.2041]], dtype=np.float32)\n        if image_size[1] == 112:\n            src[:, 0] += 8.0\n        dst = landmark.astype(np.float32)\n        tform = trans.SimilarityTransform()\n        tform.estimate(src, dst)\n        M = tform.params\n        assert M is not None\n        transforms = np.array(M).flatten()[:-1]\n        tf_transforms = tf.constant([transforms], tf.float32)\n        img_tensor = tf.convert_to_tensor(img.astype(np.float32))\n        batch = tf.stack([img_tensor])\n        output = tfa.image.transform(batch, tf_transforms, interpolation='BILINEAR', output_shape=image_size)\n        return output\n\n    def process_image(self, img):\n\n        if (isinstance(img, tf.Tensor) and img.dtype != tf.dtypes.uint8) or img.dtype != np.uint8:\n            img = np.uint8(img * 255)\n\n        face_box, face_landmarks = self.__detect_face(img)\n        aligned_face = self.__preprocess(img, face_box, face_landmarks)\n        aligned_face -= 127.5\n        aligned_face *= 0.0078125\n        embeddings = self.model_resnet(aligned_face)\n        normelized_embeddings = tf.math.l2_normalize(embeddings)\n        return normelized_embeddings\n\n    def __call__(self, img):\n        if img.ndim == 4:\n            embedding_list = []\n            for x in img:\n                norm_embedding = self.process_image(x)\n                embedding_list.append(norm_embedding)\n            return np.array(embedding_list)\n        else:\n            return self.process_image(img)\n"
  },
  {
    "path": "model/arcface/resnet.py",
    "content": "import tensorflow as tf\nimport os\nfrom model.arcface.arcface import Arcfacelayer\n\nbn_axis = -1\ninitializer = 'glorot_normal'\n\n\ndef residual_unit_v3(input, num_filter, stride, dim_match, name):\n    x = tf.keras.layers.BatchNormalization(axis=bn_axis,\n                                           scale=True,\n                                           momentum=0.9,\n                                           epsilon=2e-5,\n                                           #    beta_regularizer=tf.keras.regularizers.l2(\n                                           #        l=5e-4),\n                                           gamma_regularizer=tf.keras.regularizers.l2(\n                                               l=5e-4),\n                                           name=name + '_bn1')(input)\n    x = tf.keras.layers.ZeroPadding2D(\n        padding=(1, 1), name=name + '_conv1_pad')(x)\n    x = tf.keras.layers.Conv2D(num_filter, (3, 3),\n                               strides=(1, 1),\n                               padding='valid',\n                               kernel_initializer=initializer,\n                               use_bias=False,\n                               kernel_regularizer=tf.keras.regularizers.l2(\n                                   l=5e-4),\n                               name=name + '_conv1')(x)\n    x = tf.keras.layers.BatchNormalization(axis=bn_axis,\n                                           scale=True,\n                                           momentum=0.9,\n                                           epsilon=2e-5,\n                                           #    beta_regularizer=tf.keras.regularizers.l2(\n                                           #        l=5e-4),\n                                           gamma_regularizer=tf.keras.regularizers.l2(\n                                               l=5e-4),\n                                           name=name + '_bn2')(x)\n    x = tf.keras.layers.PReLU(name=name + '_relu1',\n                              alpha_regularizer=tf.keras.regularizers.l2(\n                                  l=5e-4))(x)\n    x = tf.keras.layers.ZeroPadding2D(\n        padding=(1, 1), name=name + '_conv2_pad')(x)\n    x = tf.keras.layers.Conv2D(num_filter, (3, 3),\n                               strides=stride,\n                               padding='valid',\n                               kernel_initializer=initializer,\n                               use_bias=False,\n                               kernel_regularizer=tf.keras.regularizers.l2(\n                                   l=5e-4),\n                               name=name + '_conv2')(x)\n    x = tf.keras.layers.BatchNormalization(axis=bn_axis,\n                                           scale=True,\n                                           momentum=0.9,\n                                           epsilon=2e-5,\n                                           #    beta_regularizer=tf.keras.regularizers.l2(\n                                           #        l=5e-4),\n                                           gamma_regularizer=tf.keras.regularizers.l2(\n                                               l=5e-4),\n                                           name=name + '_bn3')(x)\n    if (dim_match):\n        shortcut = input\n    else:\n        shortcut = tf.keras.layers.Conv2D(num_filter, (1, 1),\n                                          strides=stride,\n                                          padding='valid',\n                                          kernel_initializer=initializer,\n                                          use_bias=False,\n                                          kernel_regularizer=tf.keras.regularizers.l2(\n                                              l=5e-4),\n                                          name=name + '_conv1sc')(input)\n        shortcut = tf.keras.layers.BatchNormalization(axis=bn_axis,\n                                                      scale=True,\n                                                      momentum=0.9,\n                                                      epsilon=2e-5,\n                                                      #   beta_regularizer=tf.keras.regularizers.l2(\n                                                      #       l=5e-4),\n                                                      gamma_regularizer=tf.keras.regularizers.l2(\n                                                          l=5e-4),\n                                                      name=name + '_sc')(shortcut)\n    return x + shortcut\n\n\ndef get_fc1(input):\n    x = tf.keras.layers.BatchNormalization(axis=bn_axis,\n                                           scale=True,\n                                           momentum=0.9,\n                                           epsilon=2e-5,\n                                           #    beta_regularizer=tf.keras.regularizers.l2(\n                                           #        l=5e-4),\n                                           gamma_regularizer=tf.keras.regularizers.l2(\n                                               l=5e-4),\n                                           name='bn1')(input)\n    x = tf.keras.layers.Dropout(0.4)(x)\n    resnet_shape = input.shape\n    x = tf.keras.layers.Reshape(\n        [resnet_shape[1] * resnet_shape[2] * resnet_shape[3]], name='reshapelayer')(x)\n    x = tf.keras.layers.Dense(512,\n                              name='E_DenseLayer', kernel_initializer=initializer,\n                              kernel_regularizer=tf.keras.regularizers.l2(\n                                  l=5e-4),\n                              bias_regularizer=tf.keras.regularizers.l2(\n                                  l=5e-4))(x)\n    x = tf.keras.layers.BatchNormalization(axis=-1,\n                                           scale=False,\n                                           momentum=0.9,\n                                           epsilon=2e-5,\n                                           #    beta_regularizer=tf.keras.regularizers.l2(\n                                           #        l=5e-4),\n                                           name='fc1')(x)\n    return x\n\n\ndef ResNet50():\n\n    input_shape = [112, 112, 3]\n    filter_list = [64, 64, 128, 256, 512]\n    units = [3, 4, 14, 3]\n    num_stages = 4\n\n    img_input = tf.keras.layers.Input(shape=input_shape)\n\n    x = tf.keras.layers.ZeroPadding2D(\n        padding=(1, 1), name='conv0_pad')(img_input)\n    x = tf.keras.layers.Conv2D(64, (3, 3),\n                               strides=(1, 1),\n                               padding='valid',\n                               kernel_initializer=initializer,\n                               use_bias=False,\n                               kernel_regularizer=tf.keras.regularizers.l2(\n                                   l=5e-4),\n                               name='conv0')(x)\n    x = tf.keras.layers.BatchNormalization(axis=bn_axis,\n                                           scale=True,\n                                           momentum=0.9,\n                                           epsilon=2e-5,\n                                           #    beta_regularizer=tf.keras.regularizers.l2(\n                                           #        l=5e-4),\n                                           gamma_regularizer=tf.keras.regularizers.l2(\n                                               l=5e-4),\n                                           name='bn0')(x)\n    # x = tf.keras.layers.Activation('prelu')(x)\n    x = tf.keras.layers.PReLU(\n        name='prelu0',\n        alpha_regularizer=tf.keras.regularizers.l2(\n            l=5e-4))(x)\n\n    for i in range(num_stages):\n        x = residual_unit_v3(x, filter_list[i + 1], (2, 2), False,\n                             name='stage%d_unit%d' % (i + 1, 1))\n        for j in range(units[i] - 1):\n            x = residual_unit_v3(x, filter_list[i + 1], (1, 1),\n                                 True, name='stage%d_unit%d' % (i + 1, j + 2))\n\n    x = get_fc1(x)\n\n    # Create model.\n    model = tf.keras.models.Model(img_input, x, name='resnet50')\n    model.trainable = True\n    for i in range(len(model.layers)):\n        model.layers[i].trainable = True\n        # if ('conv0' in model.layers[i].name):\n        #     model.layers[i].trainable = False\n        # if ('bn0' in model.layers[i].name):\n        #     model.layers[i].trainable = False\n        # if ('prelu0' in model.layers[i].name):\n        #     model.layers[i].trainable = False\n        # if ('stage1' in model.layers[i].name):\n        #     model.layers[i].trainable = False\n        # if ('stage2' in model.layers[i].name):\n        #     model.layers[i].trainable = False\n        # if ('stage3' in model.layers[i].name):\n        #     model.layers[i].trainable = False\n        # if ('stage4' in model.layers[i].name):\n        #     model.layers[i].trainable = False\n\n    return model\n\n\nclass train_model(tf.keras.Model):\n    def __init__(self):\n        super(train_model, self).__init__()\n        self.resnet = ResNet50()\n        self.arcface = Arcfacelayer()\n\n    def call(self, x, y):\n        x = self.resnet(x)\n        return self.arcface(x, y)\n"
  },
  {
    "path": "model/attr_encoder.py",
    "content": "import logging\n\nimport tensorflow as tf\nfrom tensorflow.keras import Model\nfrom tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input\n\n\nclass AttrEncoder(Model):\n    def __init__(self, args):\n        super().__init__()\n        self.args = args\n        self.logger = logging.getLogger(__class__.__name__)\n\n        attr_encoder = InceptionV3(include_top=False, pooling='avg')\n        self.model = attr_encoder\n\n        if self.args.load_checkpoint:\n            self.model.load_weights(str(self.args.load_checkpoint.joinpath(self.__class__.__name__ + '.h5')))\n\n    @tf.function\n    def call(self, input_x):\n        x = tf.image.resize(input_x, (299, 299))\n        x = preprocess_input(255 * x)\n        x = self.model(x)\n        x = tf.expand_dims(x, 1)\n\n        return x\n\n    def my_save(self, reason=''):\n        self.model.save_weights(str(self.args.weights_dir.joinpath(self.__class__.__name__ + reason + '.h5')))\n"
  },
  {
    "path": "model/discriminator.py",
    "content": "import tensorflow as tf\nfrom tensorflow.keras import layers, Model\nfrom utils.general_utils import get_weights\n\n\n# Discriminate between my w's and StyleGAN's w's\nclass W_D(Model):\n    def __init__(self, args):\n        super().__init__()\n        self.args = args\n        slope = 0.2\n\n        # self.linear1 = layers.Dense(512, kernel_initializer=get_weights(slope), input_shape=(512,))\n        self.linear2 = layers.Dense(256, kernel_initializer=get_weights(slope), input_shape=(512,))\n        self.linear3 = layers.Dense(128, kernel_initializer=get_weights(slope))\n        self.linear4 = layers.Dense(64, kernel_initializer=get_weights(slope))\n        self.linear5 = layers.Dense(1, kernel_initializer=get_weights(slope))\n        self.relu = layers.LeakyReLU(slope)\n\n        if self.args.load_checkpoint:\n            self.build(input_shape=(1, 1, 512))\n            self.load_weights(str(self.args.load_checkpoint.joinpath(self.__class__.__name__ + '.h5')))\n\n    @tf.function\n    def call(self, x):\n        # x = self.linear1(x)\n        # x = self.relu(x)\n        x = self.linear2(x)\n        x = self.relu(x)\n        x = self.linear3(x)\n        x = self.relu(x)\n        x = self.linear4(x)\n        x = self.relu(x)\n        x = self.linear5(x)\n\n        return x\n\n    def my_save(self, reason=''):\n        self.save_weights(str(self.args.weights_dir.joinpath(self.__class__.__name__ + reason + '.h5')))\n"
  },
  {
    "path": "model/face_detector.py",
    "content": "import tensorflow as tf\n\n\nclass FaceDetector(object):\n    def __init__(self, args, model_path):\n        super().__init__()\n        self.args = args\n        self.model_path = model_path\n        self.model = None\n\n    def _build(self):\n        if not self.model:\n            self.model = tf.saved_model.load(self.model_path)\n\n    def __call__(self, input_x):\n        \"\"\"\n        Given a batch of images, return the face bounding box in (x1,y1,x2,y2) format\n        \"\"\"\n\n        if not self.model:\n            self._build()\n\n        boxes = []\n        for sample in input_x:\n            boxes.append(self.sample_call(sample))\n\n        boxes = tf.stack(boxes, axis=0)\n        boxes = boxes * self.args.resolution\n\n        return boxes\n\n    def sample_call(self, input_x):\n        boxes = self.model.inference(tf.expand_dims(input_x, axis=0))\n        boxes = tf.squeeze(boxes)\n        indices, scores = \\\n            tf.image.non_max_suppression_with_scores(boxes[..., :4], boxes[..., 4],\n                                                     max_output_size=1, iou_threshold=0.3, score_threshold=0.5)\n        i = indices.numpy()[0]\n        box = boxes[i, :4]\n        return box\n"
  },
  {
    "path": "model/generator.py",
    "content": "import logging\n\nfrom model import id_encoder\nfrom model import attr_encoder\nfrom model import latent_mapping\nfrom model import landmarks\nfrom model.arcface.inference import MyArcFace\n\nimport tensorflow as tf\nfrom tensorflow.keras import layers, Model\n\n\nclass G(Model):\n    def __init__(self, args, id_model_path, image_G,\n                 landmarks_net_path, face_detection_model_path, test_id_model_path):\n\n        super().__init__()\n        self.args = args\n        self.logger = logging.getLogger(__class__.__name__)\n\n        self.id_encoder = id_encoder.IDEncoder(args, id_model_path)\n        self.id_encoder.trainable = False\n\n        self.attr_encoder = attr_encoder.AttrEncoder(args)\n\n        self.latent_spaces_mapping = latent_mapping.LatentMappingNetwork(args)\n\n        self.stylegan_s = image_G\n        self.stylegan_s.trainable = False\n\n        if args.train:\n            self.test_id_encoder = MyArcFace(test_id_model_path)\n            self.test_id_encoder.trainable = False\n\n            self.landmarks = landmarks.LandmarksDetector(args, landmarks_net_path, face_detection_model_path)\n            self.landmarks.trainable = False\n\n    @tf.function\n    def call(self, x1, x2):\n        id_embedding = self.id_encoder(x1)\n\n        if self.args.train:\n            lnds = self.landmarks(x2)\n        else:\n            lnds = None\n        attr_input = x2\n\n        attr_out = self.attr_encoder(attr_input)\n        attr_embedding = attr_out\n        z_tag = tf.concat([id_embedding, attr_embedding], -1)\n        w = self.latent_spaces_mapping(z_tag)\n\n        out = self.stylegan_s(w)\n\n        # Move to roughly [0,1]\n        out = (out + 1) / 2\n\n        return out, id_embedding,  attr_out, w[:, 0, :], lnds\n\n    def my_save(self, reason=''):\n        self.attr_encoder.my_save(reason)\n        self.latent_spaces_mapping.my_save(reason)\n\n\n"
  },
  {
    "path": "model/id_encoder.py",
    "content": "import tensorflow as tf\nimport numpy as np\nfrom tensorflow.keras import Model\n\nclass IDEncoder(Model):\n\n    def __init__(self, args, model_path, intermediate_layers_names=None):\n        super().__init__()\n        self.args = args\n        self.mean = (91.4953, 103.8827, 131.0912)\n        base_model = tf.keras.models.load_model(model_path)\n\n        if intermediate_layers_names:\n            outputs = [base_model.get_layer(name).output for name in intermediate_layers_names]\n        else:\n            outputs = []\n\n        # Add output of the network in any case\n        outputs.append(base_model.layers[-2].output)\n\n        self.model = tf.keras.Model(base_model.inputs, outputs)\n\n\n    def crop_faces(self, img):\n        ps = []\n        for i in range(img.shape[0]):\n            oneimg = img[i]\n            try:\n                box = tf.numpy_function(self.mtcnn.detect_faces, [oneimg], np.uint8)\n                box = [z.numpy() for z in box[:4]]\n\n                x1, y1, w, h = box\n\n                x_expand = w * 0.3\n                y_expand = h * 0.3\n\n                x1 = int(np.maximum(x1 - x_expand // 2, 0))\n                y1 = int(np.maximum(y1 - y_expand // 2, 0))\n\n                x2 = int(np.minimum(x1 + w + x_expand // 2, self.args.resolution))\n                y2 = int(np.minimum(y1 + h + y_expand // 2, self.args.resolution))\n            except Exception as e:\n                x1, y1, x2, y2 = 24, 50, 224, 250\n\n            p = oneimg[y1:y2, x1:x2, :]\n            p = tf.convert_to_tensor(p)\n            p = tf.image.resize(p, (self.args.resolution, self.args.resolution))\n            ps.append(p)\n\n        ps = tf.stack(ps, 0)\n        return ps\n\n    def preprocess(self, img):\n        \"\"\"\n        In VGGFace2 The preprocessing is:\n            1. Face detection\n            2. Expand bbox by factor of 0.3\n            3. Resize so shorter side is 256\n            4. Crop center 224x224\n\n        In StyleGAN faces are not in-the-wild, we get an image of the head.\n        Just cropping a loose center instead of face detection\n        \"\"\"\n\n        # Go from [0, 1] to [0, 255]\n        img = 255 * img\n\n        min_x = int(0.1 * self.args.resolution)\n        max_x = int(0.9 * self.args.resolution)\n        min_y = int(0.1 * self.args.resolution)\n        max_y = int(0.9 * self.args.resolution)\n\n        img = img[:, min_x:max_x, min_y:max_y, :]\n        img = tf.image.resize(img, (256, 256))\n\n        start = (256 - 224) // 2\n        img = img[:, start: 224 + start, start: 224 + start, :]\n        img = img[:, :, :, ::-1] - self.mean\n\n        return img\n\n    @tf.function\n    def call(self, input_x, get_intermediate=False):\n        x = self.preprocess(input_x)\n        x = self.model(x)\n\n        if isinstance(x, list):\n            embedding = x[-1]\n            intermediates = x[:-1]\n        else:\n            embedding = x\n            intermediates = None\n\n        embedding = tf.math.l2_normalize(embedding, axis=-1)\n        embedding = tf.expand_dims(embedding, 1)\n\n        if get_intermediate and intermediates:\n            return embedding, intermediates\n        else:\n            return embedding\n"
  },
  {
    "path": "model/landmarks.py",
    "content": "import cv2\nimport tensorflow as tf\nimport numpy as np\nfrom tensorflow.keras import Model\n\nfrom utils import general_utils as utils\nfrom model.face_detector import FaceDetector\n\n\nclass LandmarksDetector(Model):\n    def __init__(self, args, model_path, face_detection_model_path):\n        super().__init__()\n        self.args = args\n        self.face_detector = FaceDetector(args, face_detection_model_path)\n        self.expand_ratio = 0.2\n\n        # Load without source code\n        self.model = tf.saved_model.load(model_path)\n\n    # Preprocess\n    def preprocess(self, imgs, face_detection=False):\n        imgs *= 255\n        if face_detection:\n            imgs, details = self.hard_preprocess(imgs)\n        else:\n            imgs, details = self.lazy_preprocess(imgs)\n\n        return imgs, details\n\n    def lazy_preprocess(self, imgs):\n        imgs = tf.image.resize(imgs, (160, 160))\n        return imgs, 160\n\n    def hard_preprocess(self, imgs):\n        bboxes = self.face_detector(imgs)\n\n        centers = np.array([bboxes[:, 0] + bboxes[:, 2], bboxes[:, 1] + bboxes[:, 3]]).T // 2\n\n        # Duplicate center point into column order of x,x,y,y\n        centers = np.repeat(centers, repeats=2, axis=1)\n\n        # Permute columns order into x,y,x,y\n        centers[:] = utils.np_permute(centers, [0, 2, 1, 3])\n\n        # Calculate widths of current bboxes\n        widths = np.transpose([bboxes[:, 2] - bboxes[:, 0]])\n\n        # Calculate the maximal expansion\n        max_expand = int(np.ceil(np.max(widths) * self.expand_ratio))\n\n        # Pad the image with the maximal expansion.\n        # Useful in case an expanded bounding box goes outside image\n        paddings = tf.constant([[0, 0], [max_expand, max_expand], [max_expand, max_expand], [0, 0]])\n        pad_imgs = tf.pad(imgs, paddings, mode='CONSTANT', constant_values=127.)\n\n        # The size of the new square bounding box\n        new_scales = np.floor((1 + 2 * self.expand_ratio) * widths)\n\n        # Size of step from the center\n        new_half_scales = new_scales // 2\n\n        # Repeat step in all directions\n        # Decrease in start point, Increase in end point\n        new_half_scales = np.repeat(new_half_scales, repeats=4, axis=1) * [-1, -1, 1, 1]\n\n        # Bounding boxes in respect to padded image\n        new_bboxes = centers + new_half_scales + max_expand\n\n        # tf.image.crop_and_resize requires bounding boxes to be normalized\n        # i.e., between [0,1] and also in order (y,x)\n        normed_bboxes = utils.np_permute(new_bboxes, [1, 0, 3, 2]) / pad_imgs.shape[1]\n\n        cropped_imgs = tf.image.crop_and_resize(pad_imgs, normed_bboxes,\n                                                box_indices=range(self.args.batch_size), crop_size=(160, 160))\n\n        details = (new_scales, new_bboxes[:,:2], max_expand)\n        return cropped_imgs, details\n\n    # Postprocess\n    def postprocess(self, landmarks, details, face_detection=False):\n        landmarks = tf.reshape(landmarks, [-1, 68, 2])\n\n        if face_detection:\n            return self.hard_postprocess(landmarks, details)\n        else:\n            return self.lazy_postprocess(landmarks, details)\n\n    def lazy_postprocess(self, batch_lnds, details):\n        scale = details\n        return scale * batch_lnds\n\n    def hard_postprocess(self, batch_lnds, details):\n        scale, from_origin, pad = details\n\n        scale = tf.broadcast_to(scale, [scale.shape[0], 2])\n        scale = tf.expand_dims(scale, axis=1)\n\n        from_origin = tf.expand_dims(from_origin, axis=1)\n        from_origin = tf.cast(from_origin, tf.dtypes.float32)\n\n        lnds = batch_lnds * scale + from_origin - pad\n        return lnds\n\n    @tf.function\n    def call(self, input_x, face_detection=False):\n\n        # The network input format is a uint8 image (0-255) but in float32 dtype. ^__('')__^\n        x, details = self.preprocess(input_x, face_detection)\n\n        batch_lnds = self.model.inference(x)['landmark']\n\n        batch_lnds = self.postprocess(batch_lnds, details, face_detection)\n\n        return batch_lnds[:, 17:, :]\n"
  },
  {
    "path": "model/latent_mapping.py",
    "content": "from utils.general_utils import get_weights\n\nimport tensorflow as tf\nimport numpy as np\nfrom tensorflow.keras import layers, Model\n\n\nclass LatentMappingNetwork(Model):\n    def __init__(self, args):\n        super().__init__()\n        self.args = args\n\n        input_shape = (2560,)\n\n        self.linear1 = layers.Dense(2048, input_shape=input_shape)\n        self.linear2 = layers.Dense(1024)\n        self.linear3 = layers.Dense(512, kernel_initializer=get_weights())\n        self.linear4 = layers.Dense(512, kernel_initializer=get_weights())\n        self.linears = [self.linear1, self.linear2, self.linear3, self.linear4]\n\n        self.relu = layers.LeakyReLU(0.2)\n\n        self.num_styles = int(np.log2(self.args.resolution)) * 2 - 2\n\n        if self.args.load_checkpoint:\n            self.build(input_shape=(1, 1, 2560))\n            self.load_weights(str(self.args.load_checkpoint.joinpath(self.__class__.__name__ + '.h5')))\n\n    @tf.function\n    def call(self, x):\n        first = True\n        for layer in self.linears:\n            if not first:\n                x = self.relu(x)\n\n            x = layer(x)\n            first = False\n\n        s = list(x.shape)\n\n        # Duplicate the column vector w along columns for each AdaIN entry\n        s[1] = self.num_styles\n        x = tf.broadcast_to(x, s)\n\n        return x\n\n    def my_save(self, reason=''):\n        self.save_weights(str(self.args.weights_dir.joinpath(self.__class__.__name__ + reason + '.h5')))\n"
  },
  {
    "path": "model/model.py",
    "content": "import time\nimport sys\n\nsys.path.append('..')\n\nfrom utils import general_utils as utils\nfrom model import id_encoder, latent_mapping, attr_encoder,\\\n    generator, discriminator, landmarks\n\nfrom model.stylegan import StyleGAN_G, StyleGAN_D\n\nimport tensorflow as tf\nfrom tensorflow.keras import layers, Model\n\n\nclass Network(Model):\n    def __init__(self, args, id_net_path, base_generator,\n                 landmarks_net_path=None, face_detection_model_path=None, test_id_net_path=None):\n        super().__init__()\n        self.args = args\n        self.G = generator.G(args, id_net_path, base_generator,\n                             landmarks_net_path, face_detection_model_path, test_id_net_path)\n\n        if self.args.train:\n            self.W_D = discriminator.W_D(args)\n\n    def call(self):\n        raise NotImplemented()\n\n    def my_save(self, reason):\n        self.G.my_save(reason)\n\n        if self.args.W_D_loss:\n            self.W_D.my_save(reason)\n\n    def my_load(self):\n        raise NotImplemented()\n\n    def train(self):\n        self._set_trainable_behavior(True)\n\n    def test(self):\n        self._set_trainable_behavior(False)\n\n    def _set_trainable_behavior(self, trainable):\n        self.G.attr_encoder.trainable = trainable\n        self.G.latent_spaces_mapping.trainable = trainable\n"
  },
  {
    "path": "model/stylegan.py",
    "content": "import sys\nimport math\nimport numpy as np\nimport tensorflow as tf\n\nimport matplotlib.pyplot as plt\n\nfrom tensorflow.keras import Model, Sequential\nfrom tensorflow.keras.layers import Layer, InputLayer, Multiply, Lambda, Flatten, Dense, Conv2D, Conv2DTranspose\nfrom tensorflow.keras.initializers import VarianceScaling\n\nfrom tensorflow.python.keras import backend\nfrom tensorflow.python.ops import array_ops\n\ndef nf(stage, fmap_base=8192, fmap_decay=1.0, fmap_max=512): \n    return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)\n\ndef LeakyReLU(alpha, name):\n    def lrelu(x, alpha):\n        alpha = tf.constant(alpha, dtype=x.dtype, name='alpha')\n        return tf.maximum(x, x * alpha)\n    return Lambda(lambda x: lrelu(x, alpha), name=name)\n\ndef GetWeights(gain=math.sqrt(2)):\n    return VarianceScaling(gain)\n\ndef runtime_coef(kernel_size, gain, fmaps_in, fmaps_out, lrmul=1.0):\n    # Equalized learning rate and custom learning rate multiplier.\n    shape = [kernel_size[0], kernel_size[1], fmaps_in, fmaps_out]\n    fan_in = np.prod(shape[:-1]) # [kernel, kernel, fmaps_in, fmaps_out] or [in, out]\n    he_std = gain / np.sqrt(fan_in) # He init\n    init_std = 1.0 / lrmul\n    return he_std * lrmul \n\ndef pixel_norm(x, epsilon=1e-8):\n    epsilon = tf.constant(epsilon, dtype=x.dtype, name='epsilon')\n    return x * tf.math.rsqrt(tf.reduce_mean(tf.square(x), axis=1, keepdims=True) + epsilon)\n\nclass PixelNorm(Layer):\n    def __init__(self, name):\n        super(PixelNorm, self).__init__(name=name)\n    \n    def call(self, inputs):\n        return pixel_norm(inputs)\n    \nclass InstanceNorm(Layer):\n    def __init__(self, name):\n        super(InstanceNorm, self).__init__(name=name)\n    \n    def call(self, x):\n        epsilon=1e-8\n        orig_dtype = x.dtype\n        x = tf.cast(x, tf.float32)\n        x -= tf.reduce_mean(x, axis=[2,3], keepdims=True)\n        epsilon = tf.constant(epsilon, dtype=x.dtype, name='epsilon')\n        x *= tf.math.rsqrt(tf.reduce_mean(tf.square(x), axis=[2,3], keepdims=True) + epsilon)\n        x = tf.cast(x, orig_dtype)\n        return x\n    \ndef Identity(name):\n    return Lambda(lambda x: x, name=name)\n\ndef Broadcast(name, dlatent_broadcast=18):\n    def broadcast(x):\n        return tf.tile(x[:, np.newaxis], [1, dlatent_broadcast, 1])\n    return Lambda(lambda x: broadcast(x), name=name)\n\nclass Truncation(Layer):\n    def __init__(self, name, num_layers=18, truncation_psi=0.7, truncation_cutoff=8):\n        super(Truncation, self).__init__(name=name)\n        self.num_layers = num_layers\n        self.truncation_psi = truncation_psi\n        self.truncation_cutoff = truncation_cutoff\n\n    def build(self, input_shape):\n        self.dlatent_avg = self.add_variable('dlatent_avg', shape=[int(input_shape[-1])])\n\n    def call(self, inputs):\n        layer_idx = np.arange(self.num_layers)[np.newaxis, :, np.newaxis]\n        ones = np.ones(layer_idx.shape, dtype=np.float32)\n        coefs = tf.where(layer_idx < self.truncation_cutoff, self.truncation_psi * ones, ones)\n        \n        def lerp(a,b,t): return a + (b - a) * t\n        \n        return lerp(self.dlatent_avg, inputs, coefs)\n\nclass DenseLayer(Dense):\n    def __init__(self, units, name, kernel_initializer=GetWeights(), gain=math.sqrt(2), lrmul=1.0):\n        super(DenseLayer, self).__init__(units=units, kernel_initializer=kernel_initializer, name=name)\n        self.gain = gain\n        self.lrmul = lrmul\n    \n    def call(self, inputs):\n        x, b, w = inputs, self.bias * self.lrmul, self.kernel * runtime_coef([1,1], self.gain, inputs.shape[1], self.units, lrmul=self.lrmul)\n        \n        # Input x kernel\n        if len(x.shape) > 2: x = tf.reshape(x, [-1, np.prod([d for d in x.shape[1:]])])\n        x = tf.matmul(x, w)\n        \n        # Bias\n        if len(x.shape) == 2:\n            return x + b\n        \n        return x + tf.reshape(b, [1, -1, 1, 1])\n\nclass Conv2d(Conv2D):\n    def __init__(self, filters, kernel_size, name, gain=math.sqrt(2), lrmul=1.0, kernel_modifier=None, strides=1, use_bias=True):\n        super(Conv2d, self).__init__(filters=filters, kernel_size=kernel_size, kernel_initializer=GetWeights(gain), \n                                     use_bias=use_bias, padding='same', data_format='channels_first', name=name, strides=strides)\n        self.gain = gain\n        self.lrmul = lrmul\n        self.kernel_modifier = kernel_modifier\n\n    # Perform convolution with modified kernel then add bias\n    def call(self, inputs):\n        if self.kernel_modifier is None:\n            w = self.kernel\n        else:\n            w = self.kernel_modifier(self.kernel)\n            \n        outputs = self._convolution_op(inputs, w * runtime_coef(self.kernel_size, self.gain, inputs.shape[1], self.filters))\n        \n        if self.use_bias:\n            b = self.bias * self.lrmul        \n            if self.data_format == 'channels_first':\n                outputs = tf.nn.bias_add(outputs, b, data_format='NCHW')\n            else:\n                outputs = tf.nn.bias_add(outputs, b, data_format='NHWC')\n\n        return outputs\n\nclass Const(Layer):\n    def __init__(self, name):\n        super(Const, self).__init__(name=name)\n        \n    def build(self, input_shape):\n        self.const = self.add_variable('const', shape=[1,512,4,4])\n        \n    def call(self, inputs):\n        return tf.tile(self.const, [tf.shape(inputs)[0], 1, 1, 1])\n    \nclass RandomNoise(Layer):\n    def __init__(self, name, layer_idx):\n        super(RandomNoise, self).__init__(name=name)\n        \n        res = layer_idx // 2 + 2        \n        self.layer_idx = layer_idx\n        self.noise_shape = [1, 1, 2**res, 2**res]\n    \n    def build(self, input_shape):\n        self.noise = self.add_variable('noise', shape=self.noise_shape, initializer=tf.initializers.zeros(), trainable=False)\n        \n    def call(self, inputs):\n        return self.noise\n    \nclass ApplyNoise(Layer):\n    def __init__(self, name, is_const_noise):\n        super(ApplyNoise, self).__init__(name=name)        \n        self.is_const_noise = is_const_noise\n\n    def build(self, input_shape):\n        input_shape = input_shape[0]\n        self.weight = self.add_variable('weight', shape=[input_shape[1]], initializer=tf.initializers.zeros())\n        \n    def call(self, inputs):\n        x, noise = inputs\n        if not self.is_const_noise:\n            noise = tf.random.normal([tf.shape(x)[0], 1, x.shape[2], x.shape[3]], dtype=x.dtype)\n\n        return x + noise * tf.reshape(self.weight, [1, -1, 1, 1])\n    \nclass ApplyBias(Layer):\n    def __init__(self, name, lrmul=1.0):\n        super(ApplyBias, self).__init__(name=name)\n        self.lrmul = lrmul\n        \n    def build(self, input_shape):\n        self.bias = self.add_variable('bias', shape=[input_shape[1]])\n        \n    def call(self, x):\n        b = self.bias * self.lrmul\n        if len(x.shape) == 2: return x + b\n        return x + tf.reshape(b, [1, -1, 1, 1])\n\nclass StridedSlice(Layer):\n    def __init__(self, layer_idx, name):\n        super(StridedSlice, self).__init__(name=name)\n        self.layer_idx = layer_idx\n    \n    def call(self, inputs):\n        return inputs[:, self.layer_idx]\n    \nclass StyleModApply(Layer):\n    def __init__(self, name):\n        super(StyleModApply, self).__init__(name=name)\n    \n    def call(self, inputs):\n        x, style = inputs\n        \n        style = tf.reshape(style, [-1, 2, x.shape[1]] + [1] * (len(x.shape) - 2))\n        return x * (style[:,0] + 1) + style[:,1]\n\ndef _blur2d(x, f=[1,2,1], normalize=True, flip=False, stride=1):\n    assert x.shape.ndims == 4 and all(dim is not None for dim in x.shape[1:])\n    assert isinstance(stride, int) and stride >= 1\n\n    # Finalize filter kernel.\n    f = np.array(f, dtype=np.float32)\n    if f.ndim == 1:\n        f = f[:, np.newaxis] * f[np.newaxis, :]\n    assert f.ndim == 2\n    if normalize:\n        f /= np.sum(f)\n    if flip:\n        f = f[::-1, ::-1]\n    f = f[:, :, np.newaxis, np.newaxis]\n    f = np.tile(f, [1, 1, int(x.shape[1]), 1])\n\n    # No-op => early exit.\n    if f.shape == (1, 1) and f[0,0] == 1:\n        return x\n\n    # Convolve using depthwise_conv2d.\n    orig_dtype = x.dtype\n    x = tf.cast(x, tf.float32)  # tf.nn.depthwise_conv2d() doesn't support fp16\n    f = tf.constant(f, dtype=x.dtype, name='filter')\n    strides = [1, 1, stride, stride]\n    x = tf.nn.depthwise_conv2d(x, f, strides=strides, padding='SAME', data_format='NCHW')\n    x = tf.cast(x, orig_dtype)\n    return x\n\ndef Blur(name, blur_filter=[1,2,1]):\n    def blur2d(x, f=[1,2,1], normalize=True):\n        return _blur2d(x, f, normalize)\n    return Lambda(lambda x: blur2d(x, blur_filter), name=name)\n\ndef _downscale2d(x, factor=2, gain=1):\n    assert x.shape.ndims == 4 and all(dim is not None for dim in x.shape[1:])\n    assert isinstance(factor, int) and factor >= 1\n\n    # 2x2, float32 => downscale using _blur2d().\n    if factor == 2 and x.dtype == tf.float32:\n        f = [np.sqrt(gain) / factor] * factor\n        return _blur2d(x, f=f, normalize=False, stride=factor)\n\n    # Apply gain.\n    if gain != 1:\n        x *= gain\n\n    # No-op => early exit.\n    if factor == 1:\n        return x\n\n    # Large factor => downscale using tf.nn.avg_pool().\n    # NOTE: Requires tf_config['graph_options.place_pruned_graph']=True to work.\n    ksize = [1, 1, factor, factor]\n    return tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding='VALID', data_format='NCHW')\n\ndef _upscale2d(x, factor=2, gain=1):\n    assert x.shape.ndims == 4 and all(dim is not None for dim in x.shape[1:])\n    assert isinstance(factor, int) and factor >= 1\n\n    # Apply gain.\n    if gain != 1:\n        x *= gain\n\n    # No-op => early exit.\n    if factor == 1:\n        return x\n\n    # Upscale using tf.tile().\n    s = x.shape\n    x = tf.reshape(x, [-1, s[1], s[2], 1, s[3], 1])\n    x = tf.tile(x, [1, 1, 1, factor, 1, factor])\n    x = tf.reshape(x, [-1, s[1], s[2] * factor, s[3] * factor])\n    return x\n\t\ndef Downscaled2d(name, factor=2, gain=1):\n    return Lambda(lambda x: _downscale2d(x, factor, gain), name=name+'/Downscaled2d')\n\t\ndef Upscaled2d(name, factor=2, gain=1):\n    return Lambda(lambda x: _upscale2d(x, factor, gain), name=name+'/Upscaled2d')\n\ndef Conv2d_downscale2d(model, filters, kernel_size, name, gain=math.sqrt(2), fused_scale='auto'):\n    if fused_scale == 'auto':\n        x = model.layers[-1].output\n        fused_scale = min(x.shape[2:]) >= 128\n        \n    if not fused_scale:\n        # Not fused => call the individual ops directly.\n        model.add( Conv2d(filters, kernel_size, name, gain) )\n        model.add( Downscaled2d(name) )        \n    else:\n        # Fused => perform both ops simultaneously using tf.nn.conv2d().\n        def fused_op(w):\n            w = tf.pad(w, [[1,1], [1,1], [0,0], [0,0]], mode='CONSTANT')\n            w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]]) * 0.25\n            return w\n        model.add( Conv2d(filters, kernel_size, name, gain, kernel_modifier=fused_op, strides=2) )\n\t\t\ndef Upscale2d_conv2d(x, filters, kernel_size, name, use_bias, gain=math.sqrt(2), fused_scale='auto'):\n    if fused_scale == 'auto':\n        fused_scale = min(x.shape[2:]) * 2 >= 128\n\n    if not fused_scale:\n        x = Upscaled2d(name)(x)\n        x = Conv2d(filters, kernel_size, name=name, gain=gain, use_bias=use_bias)(x)\n        return x\n\n    return Conv2d_transpose(filters, kernel_size, name, gain, strides=2)(x)\n\nclass Conv2d_transpose(Conv2DTranspose):\n    def __init__(self, filters, kernel_size, name, gain=math.sqrt(2), lrmul=1.0, kernel_modifier=None, strides=2, use_bias=False):\n        \n        super(Conv2d_transpose, self).__init__(filters=filters, kernel_size=kernel_size, kernel_initializer=GetWeights(gain), \n                                     use_bias=use_bias, padding='same', data_format='channels_first', name=name, strides=strides)\n        self.gain = gain\n        self.lrmul = lrmul\n        self.kernel_modifier = kernel_modifier\n        \n    def build(self, input_shape):\n        shape = [self.kernel_size[0], self.kernel_size[1], input_shape[1], self.filters]\n        self.kernel = self.add_variable('weight', shape=shape, initializer=tf.initializers.zeros())\n            \n    def call(self, inputs):\n        # Fused => perform both ops simultaneously using tf.nn.conv2d_transpose().\n        def fused_op(w):\n            w = tf.transpose(w, [0, 1, 3, 2]) # [kernel, kernel, fmaps_out, fmaps_in]\n            w = tf.pad(w, [[1,1], [1,1], [0,0], [0,0]], mode='CONSTANT')\n            w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]])\n            return w\n\n        x, w = inputs, fused_op(self.kernel * runtime_coef(self.kernel_size, self.gain, inputs.shape[1], self.filters, lrmul=self.lrmul))\n        \n        os = [tf.shape(inputs)[0], self.filters, inputs.shape[2] * 2, inputs.shape[3] * 2]\n        \n        outputs = tf.nn.conv2d_transpose(x, w, os, strides=[1,1,2,2], padding='SAME', data_format='NCHW')\n        \n        return outputs\n\n\nclass MinibatchStddevLayer(tf.keras.layers.Layer):\n    def __init__(self, group_size =4, num_new_features=1):\n        super().__init__()\n        self.group_size = group_size\n        self.num_new_features = num_new_features\n\n    def __call__(self, x, *args, **kwargs):\n        group_size = tf.minimum(self.group_size,\n                                tf.shape(x)[0])  # Minibatch must be divisible by (or smaller than) group_size.\n        s = x.shape  # [NCHW]  Input shape.\n        y = tf.reshape(x, [group_size, -1, self.num_new_features, s[1] // self.num_new_features, s[2], s[\n            3]])  # [GMncHW] Split minibatch into M groups of size G. Split channels into n channel groups c.\n        y = tf.cast(y, tf.float32)  # [GMncHW] Cast to FP32.\n        y -= tf.reduce_mean(y, axis=0, keepdims=True)  # [GMncHW] Subtract mean over group.\n        y = tf.reduce_mean(tf.square(y), axis=0)  # [MncHW]  Calc variance over group.\n        y = tf.sqrt(y + 1e-8)  # [MncHW]  Calc stddev over group.\n        y = tf.reduce_mean(y, axis=[2, 3, 4], keepdims=True)  # [Mn111]  Take average over fmaps and pixels.\n        y = tf.reduce_mean(y, axis=[2])  # [Mn11] Split channels into c channel groups\n        y = tf.cast(y, x.dtype)  # [Mn11]  Cast back to original data type.\n        y = tf.tile(y, [group_size, 1, s[2], s[3]])  # [NnHW]  Replicate over group and pixels.\n        return tf.concat([x, y], axis=1)  # [NCHW]  Append as new fmap.\n\ndef minibatch_stddev_layer(x, group_size=4, num_new_features=1):\n    with tf.compat.v1.variable_scope('MinibatchStddev'):\n        group_size = tf.minimum(group_size, tf.shape(x)[0])     # Minibatch must be divisible by (or smaller than) group_size.\n        s = x.shape                                             # [NCHW]  Input shape.\n        y = tf.reshape(x, [group_size, -1, num_new_features, s[1]//num_new_features, s[2], s[3]])   # [GMncHW] Split minibatch into M groups of size G. Split channels into n channel groups c.\n        y = tf.cast(y, tf.float32)                              # [GMncHW] Cast to FP32.\n        y -= tf.reduce_mean(y, axis=0, keepdims=True)           # [GMncHW] Subtract mean over group.\n        y = tf.reduce_mean(tf.square(y), axis=0)                # [MncHW]  Calc variance over group.\n        y = tf.sqrt(y + 1e-8)                                   # [MncHW]  Calc stddev over group.\n        y = tf.reduce_mean(y, axis=[2,3,4], keepdims=True)      # [Mn111]  Take average over fmaps and pixels.\n        y = tf.reduce_mean(y, axis=[2])                         # [Mn11] Split channels into c channel groups\n        y = tf.cast(y, x.dtype)                                 # [Mn11]  Cast back to original data type.\n        y = tf.tile(y, [group_size, 1, s[2], s[3]])             # [NnHW]  Replicate over group and pixels.\n        return tf.concat([x, y], axis=1)                        # [NCHW]  Append as new fmap.\n\t\t\ndef StyleGAN_G_mapping( latent_size=512, dlatent_size=512, mapping_layers=8, mapping_fmaps=512, mapping_lrmul=0.01,\n                        truncation_psi=1, resolution=1024):\n\n    resolution_log2 = int(np.log2(resolution))\n    num_layers = resolution_log2 * 2 - 2\n\n    model = Sequential(name='G_mapping')\n    model.add( InputLayer(input_shape=[latent_size], name='G_mapping/latents_in') ) \n\n    # Normalize latents.\n    model.add( PixelNorm(name='G_mapping/PixelNorm') )\n    \n    # Mapping layers.\n    for layer_idx in range(mapping_layers):\n        name = 'G_mapping/Dense{}'.format(layer_idx)\n        fmaps = dlatent_size if layer_idx == mapping_layers - 1 else mapping_fmaps\n        \n        model.add( DenseLayer(units=fmaps, kernel_initializer=GetWeights(), name=name, lrmul=mapping_lrmul) )\n        model.add( LeakyReLU(alpha=0.2, name=name+'/LeakyReLU') )\n        \n    # Broadcast.\n    model.add( Broadcast(name='G_mapping/Broadcast', dlatent_broadcast=num_layers) )\n    \n    # Output.\n    model.add( Identity(name='G_mapping/dlatents_out') )\n\n    # Apply truncation trick.\n    model.add( Truncation(name='Truncation', num_layers=num_layers, truncation_psi=truncation_psi ))\n    \n    return model\n\ndef StyleGAN_G_synthesis(dlatent_size=512, resolution=1024, is_const_noise=True):\n    # General parameters\n    num_channels = 3\n    resolution_log2 = int(np.log2(resolution))\n    num_layers = resolution_log2 * 2 - 2\n\n    # Primary inputs.\n    dlatents_in = tf.keras.layers.Input(shape=[num_layers, dlatent_size], name='G_synthesis/dlatents_in')\n    \n    # Noise inputs.\n    noise_inputs = []\n    for layer_idx in range(num_layers):\n        noise_inputs.append( RandomNoise(name='G_synthesis/noise%d'%layer_idx, layer_idx=layer_idx)(dlatents_in) )\n        \n    # Things to do at the end of each layer.\n    def layer_epilogue(x, layer_idx, name):\n        name = 'G_synthesis/{}x{}/{}/'.format(x.shape[2], x.shape[2], name)\n        \n        x = ApplyNoise(name=name+'Noise', is_const_noise=is_const_noise)([x, noise_inputs[layer_idx]])\n        x = ApplyBias(name=name+'bias')(x)\n        x = LeakyReLU(alpha=0.2, name=name+'LeakyReLU')(x)\n        x = InstanceNorm(name=name+'InstanceNorm')(x)       \n        \n        style = DenseLayer(units=x.shape[1]*2, gain=1, name=name+'StyleMod') (StridedSlice(layer_idx, name=name+'StridedSlice')(dlatents_in))\n        x = StyleModApply(name=name+'StyleModApply')([x, style])\n        \n        return x\n    \n    # Building blocks for remaining layers.\n    def block(res, x): # res = 3..resolution_log2\n        name, name0, name1 = '%dx%d' % (2**res, 2**res), 'Conv0_up', 'Conv1'\n        \n        # Conv0_up\n        upscaled = Upscale2d_conv2d(x, name='G_synthesis/{}/{}'.format(name, name0), filters=nf(res-1), kernel_size=3, use_bias=False)   \n        x = layer_epilogue( Blur(name='G_synthesis/{}/{}/Blur'.format(name, name0))(upscaled), res*2-4, name0 )\n        \n        # Conv1\n        x = layer_epilogue( Conv2d(name='G_synthesis/{}/{}'.format(name, name1), filters=nf(res-1), kernel_size=3, use_bias=False)(x), res*2-3, name1 )\n\n        return x \n    \n    def torgb(res, x): # res = 2..resolution_log2\n        lod = resolution_log2 - res        \n        return Conv2d(name='G_synthesis/ToRGB_lod%d' % lod, filters=num_channels, kernel_size=1, gain=1, use_bias=True)(x)\n    \n    # Early layers.\n    x = layer_epilogue(Const(name='G_synthesis/4x4/Const')(dlatents_in), 0, name='Const')\n    x = layer_epilogue(Conv2d(name='G_synthesis/4x4/Conv', filters=nf(1), kernel_size=3, use_bias=False)(x), 1, 'Conv')\n    \n    # Fixed structure: simple and efficient, but does not support progressive growing.\n    for res in range(3, resolution_log2 + 1):\n        x = block(res, x)\n            \n    x = torgb(resolution_log2, x)\n\n    # change output to the default NHWC format, so it will be compatible with other networks\n    x = tf.transpose(x, (0, 2, 3, 1))\n\n    return Model(inputs=dlatents_in, outputs=x, name='G_synthesis')\n\n\nclass StyleGAN_G(Model):\n    def __init__(self, resolution=1024, latent_size=512, dlatent_size=512, mapping_layers=8, mapping_fmaps=512,\n                 mapping_lrmul=0.01, truncation_psi=1):\n        super(StyleGAN_G, self).__init__()\n        self.model_mapping = StyleGAN_G_mapping(latent_size, dlatent_size, mapping_layers, mapping_fmaps,\n                                                mapping_lrmul, truncation_psi, resolution)\n        self.model_synthesis = StyleGAN_G_synthesis(dlatent_size, resolution)\n        print('Model created.')\n        \n    def call(self, inputs):\n        x = self.model_mapping(inputs)\n        x = self.model_synthesis(x)\n        return x\n    \n    def generate_sample(self, seed=5, is_visualize=False):\n        rnd = np.random.RandomState(seed)\n        latents = rnd.randn(1, 512)\n\n        y = self.predict(latents)\n\n        images = y.transpose([0, 2, 3, 1])\n        images = np.clip((images+1)*0.5, 0, 1)\n        \n        # print(images.shape, np.min(images), np.max(images))\n\n        plt.figure(figsize=(10, 10))\n        plt.imshow(images[0])\n        if is_visualize:\n            plt.show()\n\n        return images\n    \nclass StyleGAN_D(Model):\n    def __init__(self, resolution=1024, mbstd_group_size=4, mbstd_num_features=1):\n        super(StyleGAN_D, self).__init__()\n\n        resolution_log2 = int(math.log2(resolution))\n\n        model = Sequential(name='Discriminator')\n        model.add(InputLayer(input_shape=[3, resolution, resolution])) \n\n        def fromrgb(res):\n            name = 'FromRGB_lod%d' % (resolution_log2 - res)\n            model.add( Conv2d(filters=nf(res-1), kernel_size=1, name=name) )\n            model.add( LeakyReLU(alpha=0.2, name=name+'/LeakyReLU') )\n\n        def block(res):\n            name = '%dx%d' % (2**res, 2**res)\n            if res >= 3: # 8x8 and up\n                model.add( Conv2d(filters=nf(res-1), kernel_size=3, name=name+'/Conv0') )\n                model.add( LeakyReLU(alpha=0.2, name=name+'/Conv0/LeakyReLU') )\n\n                model.add( Blur(name=name+'/Blur') )\n                Conv2d_downscale2d(model=model, filters=nf(res-2), kernel_size=3, name=name+'/Conv1_down')\n                model.add( LeakyReLU(alpha=0.2, name=name+'/Conv1_down/LeakyReLU') )\n\n            else: # 4x4\n                if mbstd_group_size > 1: \n                    model.add( Lambda(lambda x: minibatch_stddev_layer(x, mbstd_group_size, mbstd_num_features), name=name+'/MinibatchStddev') )\n\n                model.add( Conv2d(filters=nf(res-1), kernel_size=3, name=name+'/Conv') )\n                model.add( LeakyReLU(alpha=0.2, name=name+'/Conv/LeakyReLU') )\n\n                model.add( Flatten() )\n                model.add( DenseLayer(units=nf(res-2), kernel_initializer=GetWeights(), name=name+'/Dense0') )\n                model.add( LeakyReLU(alpha=0.2, name=name+'/Dense0/LeakyReLU') )\n\n                model.add( DenseLayer(units=1, kernel_initializer=GetWeights(1), gain=1, name=name+'/Dense1') )   \n\n        # Blocks\n        fromrgb(resolution_log2)\n        for res in range(resolution_log2, 2, -1): block(res)\n        block(2)\n\n        self.model = model\n\n    def call(self, inputs):\n        inputs = tf.transpose(inputs, (0, 3, 1, 2))\n        return self.model(inputs)\n\ndef copy_weights_to_keras_model(model, all_weights):\n    c = 0\n    od = all_weights\n    for l in model.layers:\n      try:\n        values = l.get_weights()\n        weights = list(map(lambda x: x.shape, values))\n        if not len(weights): continue\n\n        num_params = values[0].size\n\n        # Special weights\n        if len(weights) == 1:    \n            weights_list = []\n\n            # The learned constant variable\n            if l.name == 'G_synthesis/4x4/Const': weights_list.append( od[l.name+'/const'] )\n\n            # Truncation trick variable\n            if 'Truncation' in l.name: weights_list.append( od['dlatent_avg'] )\n\n            # Input noise\n            if 'G_synthesis/noise' in l.name: weights_list.append( od[l.name] ) \n\n            # Noise variables\n            if 'Noise' in l.name: weights_list.append( od[l.name+'/weight'] )\n\n            # Bias variables\n            if 'bias' in l.name: weights_list.append( od[l.name] )\n\n            # Conv with no bias\n            if l.name.endswith('Conv') or l.name.endswith('Conv1') or l.name.endswith('Conv0_up'):\n                weights_list.append( od[l.name+'/weight'] )\n\n            if len(weights_list) > 0:\n                l.set_weights( weights_list )\n                print('.', end='')\n                c = c + num_params\n            else:\n                print('WARNING: weights not found for ', l.name, '  of size', weights[0])\n        else:  \n        # Standard weights (weight + bias)\n            assert len(weights) == 2\n\n            num_params = num_params + values[1].size\n\n            layer_name = l.name\n            var_names = ['{}/{}'.format(layer_name, 'weight'), '{}/{}'.format(layer_name, 'bias')]\n\n            if var_names[0] in od and var_names[1] in od:\n                weight = od[var_names[0]]\n                bias = od[var_names[1]]\n\n                l.set_weights( [ weight, bias ] )\n\n                print('.', end='')\n                c = c + num_params\n            else:\n                print('WARNING: not found', var_names)\n      except Exception as e:\n          print(e)\n          print('skipping...')\n\n    print('Total number of parameters copied:', c)\n"
  },
  {
    "path": "test.py",
    "content": "import os\n\nos.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\nos.environ['OMP_NUM_THREADS'] = '1'\nos.environ['USE_SIMPLE_THREADED_LEVEL3'] = '1'\nos.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'\n\nimport sys\nfrom model.stylegan import StyleGAN_G_synthesis\nfrom model.model import Network\nfrom writer import Writer\nfrom inference import Inference\nfrom arglib import arglib\nimport utils\n\nsys.path.insert(0, 'model/face_utils')\n\n\n\ndef main():\n    test_args = arglib.TestArgs()\n    args, str_args = test_args.args, test_args.str_args\n    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu\n\n    Writer.set_writer(args.results_dir)\n\n    id_model_path = args.pretrained_models_path.joinpath('vggface2.h5')\n    stylegan_G_synthesis_path = str(\n        args.pretrained_models_path.joinpath(f'stylegan_G_{args.resolution}x{args.resolution}_synthesis'))\n\n    utils.landmarks_model_path = str(args.pretrained_models_path.joinpath('shape_predictor_68_face_landmarks.dat'))\n\n    stylegan_G_synthesis = StyleGAN_G_synthesis(resolution=args.resolution, is_const_noise=args.const_noise)\n    stylegan_G_synthesis.load_weights(stylegan_G_synthesis_path)\n\n    network = Network(args, id_model_path, stylegan_G_synthesis)\n\n    network.test()\n    inference = Inference(args, network)\n    test_func = getattr(inference, args.test_func)\n    test_func()\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "trainer.py",
    "content": "import logging\n\nimport numpy as np\nimport tensorflow as tf\n\nfrom writer import Writer\nfrom utils import general_utils as utils\n\n\ndef id_loss_func(y_gt, y_pred):\n    return tf.reduce_mean(tf.keras.losses.MAE(y_gt, y_pred))\n\n\nclass Trainer(object):\n    def __init__(self, args, model, data_loader):\n        self.args = args\n        self.logger = logging.getLogger(__class__.__name__)\n\n        self.model = model\n        self.data_loader = data_loader\n\n        # lrs & optimizers\n        lr = 5e-5 if self.args.resolution == 256 else 1e-5\n\n        self.g_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)\n        self.g_gan_optimizer = tf.keras.optimizers.Adam(learning_rate=0.1 * lr)\n        self.w_d_optimizer = tf.keras.optimizers.Adam(learning_rate=0.4 * lr)\n\n        self.im_d_optimizer = tf.keras.optimizers.Adam(learning_rate=0.4 * lr)\n\n        # Losses\n        self.gan_loss_func = tf.keras.losses.BinaryCrossentropy(from_logits=True)\n        self.pixel_loss_func = tf.keras.losses.MeanAbsoluteError(tf.keras.losses.Reduction.SUM)\n\n        self.id_loss_func = id_loss_func\n\n        if args.pixel_mask_type == 'gaussian':\n            sigma = int(80 * (self.args.resolution / 256))\n            self.pixel_mask = utils.inverse_gaussian_image(self.args.resolution, sigma)\n        else:\n            self.pixel_mask = tf.ones([self.args.resolution, self.args.resolution])\n            self.pixel_mask = self.pixel_mask / tf.reduce_sum(self.pixel_mask)\n\n        self.pixel_mask = tf.broadcast_to(self.pixel_mask, [self.args.batch_size, *self.pixel_mask.shape])\n\n        self.num_epoch = 0\n        self.is_cross_epoch = False\n\n        # Lambdas\n        if args.unified:\n            self.lambda_gan = 0.5\n        else:\n            self.lambda_gan = 1\n\n        self.lambda_pixel = 0.02\n\n        self.lambda_id = 1\n        self.lambda_attr_id = 1\n        self.lambda_landmarks = 0.001\n        self.r1_gamma = 10\n\n        # Test\n        self.test_not_imporved = 0\n        self.max_id_preserve = 0\n        self.min_lnd_dist = np.inf\n\n    def train(self):\n        while self.num_epoch <= self.args.num_epochs:\n            self.logger.info('---------------------------------------')\n            self.logger.info(f'Start training epoch: {self.num_epoch}')\n\n            if self.args.cross_frequency and (self.num_epoch % self.args.cross_frequency == 0):\n                self.is_cross_epoch = True\n                self.logger.info('This epoch is cross-face')\n            else:\n                self.is_cross_epoch = False\n                self.logger.info('This epoch is same-face')\n\n            try:\n                if self.num_epoch % self.args.test_frequency == 0:\n                    self.test()\n\n                self.train_epoch()\n\n            except Exception as e:\n                self.logger.exception(e)\n                raise\n\n            if self.test_not_imporved > self.args.not_improved_exit:\n                self.logger.info(f'Test has not improved for {self.args.not_improved_exit} epochs. Exiting...')\n                break\n\n            self.num_epoch += 1\n\n    def train_epoch(self):\n        id_loss = 0\n        landmarks_loss = 0\n        g_w_gan_loss = 0\n        pixel_loss = 0\n        w_d_loss = 0\n        w_loss = 0\n\n        self.logger.info(f'train in epoch: {self.num_epoch}')\n        self.model.train()\n\n        use_w_d = self.args.W_D_loss\n\n        # if use_w_d and use_im_d and not self.args.unified:\n        if not self.args.unified:\n            if self.num_epoch % 2 == 0:\n                # This epoch is not using image_D\n                use_im_d = False\n                # self.logger.info(f'Not using Image D in epoch: {self.num_epoch}')\n            if self.num_epoch % 2 != 0:\n                # This epoch is not using W_D\n                use_w_d = False\n                # self.logger.info(f'Not using W_d in epoch: {self.num_epoch}')\n\n        attr_img, id_img, real_w, real_img, matching_ws = self.data_loader.get_batch(is_cross=self.is_cross_epoch)\n\n        # Forward that does not require grads\n        id_embedding = self.model.G.id_encoder(id_img)\n        src_landmarks = self.model.G.landmarks(attr_img)\n        attr_input = attr_img\n\n        with tf.GradientTape(persistent=True) as g_tape:\n\n            attr_out = self.model.G.attr_encoder(attr_input)\n            attr_embedding = attr_out\n\n            self.logger.info(f'attr embedding stats- mean: {tf.reduce_mean(tf.abs(attr_embedding)):.5f},'\n                             f' variance: {tf.math.reduce_variance(attr_embedding):.5f}')\n\n            z_tag = tf.concat([id_embedding, attr_embedding], -1)\n            w = self.model.G.latent_spaces_mapping(z_tag)\n            fake_w = w[:, 0, :]\n\n            self.logger.info(\n                f'w stats- mean: {tf.reduce_mean(tf.abs(fake_w)):.5f}, variance: {tf.math.reduce_variance(fake_w):.5f}')\n\n            pred = self.model.G.stylegan_s(w)\n\n            # Move to roughly [0,1]\n            pred = (pred + 1) / 2\n\n            if use_w_d:\n                with tf.GradientTape() as w_d_tape:\n                    fake_w_logit = self.model.W_D(fake_w)\n                    g_w_gan_loss = self.generator_gan_loss(fake_w_logit)\n\n                    self.logger.info(f'g W loss is {g_w_gan_loss:.3f}')\n                    self.logger.info(f'fake W logit: {tf.squeeze(fake_w_logit)}')\n\n                    with g_tape.stop_recording():\n                        real_w_logit = self.model.W_D(real_w)\n                        w_d_loss = self.discriminator_loss(fake_w_logit, real_w_logit)\n                        w_d_total_loss = w_d_loss\n\n                        if self.args.gp:\n                            w_d_gp = self.R1_gp(self.model.W_D, real_w)\n                            w_d_total_loss += w_d_gp\n                            self.logger.info(f'w_d_gp : {w_d_gp}')\n\n                        self.logger.info(f'W_D loss is {w_d_loss:.3f}')\n                        self.logger.info(f'real W logit: {tf.squeeze(real_w_logit)}')\n\n            if self.args.id_loss:\n                pred_id_embedding = self.model.G.id_encoder(pred)\n                id_loss = self.lambda_id * id_loss_func(pred_id_embedding, tf.stop_gradient(id_embedding))\n                self.logger.info(f'id loss is {id_loss:.3f}')\n\n            if self.args.landmarks_loss:\n                try:\n                    dst_landmarks = self.model.G.landmarks(pred)\n                except Exception as e:\n                    self.logger.warning(f'Failed finding landmarks on prediction. Dont use landmarks loss. Error:{e}')\n                    dst_landmarks = None\n\n                if dst_landmarks is None or src_landmarks is None:\n                    landmarks_loss = 0\n                else:\n                    landmarks_loss = self.lambda_landmarks * \\\n                                     tf.reduce_mean(tf.keras.losses.MSE(src_landmarks, dst_landmarks))\n                    self.logger.info(f'landmarks loss is: {landmarks_loss:.3f}')\n                    # if landmarks_loss > 5:\n                    #     landmarks_loss = 0\n                    #     id_loss = 0\n\n            if not self.is_cross_epoch and self.args.pixel_loss:\n                l1_loss = self.pixel_loss_func(attr_img, pred, sample_weight=self.pixel_mask)\n                self.logger.info(f'L1 pixel loss is {l1_loss:.3f}')\n\n                if self.args.pixel_loss_type == 'mix':\n                    mssim = tf.reduce_mean(1 - tf.image.ssim_multiscale(attr_img, pred, 1.0))\n                    self.logger.info(f'mssim loss is {l1_loss:.3f}')\n                    pixel_loss = self.lambda_pixel * (0.84 * mssim + 0.16 * l1_loss)\n                else:\n                    pixel_loss = self.lambda_pixel * l1_loss\n\n                self.logger.info(f'pixel loss is {pixel_loss:.3f}')\n\n            g_gan_loss = g_w_gan_loss\n\n            total_g_not_gan_loss = id_loss \\\n                                   + landmarks_loss \\\n                                   + pixel_loss \\\n                                   + w_loss\n\n            self.logger.info(f'total G (not gan) loss is {total_g_not_gan_loss:.3f}')\n            self.logger.info(f'G gan loss is {g_gan_loss:.3f}')\n\n        Writer.add_scalar('loss/landmarks_loss', landmarks_loss, step=self.num_epoch)\n        Writer.add_scalar('loss/total_g_not_gan_loss', total_g_not_gan_loss, step=self.num_epoch)\n\n        Writer.add_scalar('loss/id_loss', id_loss, step=self.num_epoch)\n\n        if use_w_d:\n            Writer.add_scalar('loss/g_w_gan_loss', g_w_gan_loss, step=self.num_epoch)\n            Writer.add_scalar('loss/W_D_loss', w_d_loss, step=self.num_epoch)\n            if self.args.gp:\n                Writer.add_scalar('loss/w_d_gp', w_d_gp, step=self.num_epoch)\n\n        if not self.is_cross_epoch:\n            Writer.add_scalar('loss/pixel_loss', pixel_loss, step=self.num_epoch)\n            Writer.add_scalar('loss/w_loss', w_loss, step=self.num_epoch)\n\n        if self.args.debug or \\\n                (self.num_epoch < 1e3 and self.num_epoch % 1e2 == 0) or \\\n                (self.num_epoch < 1e4 and self.num_epoch % 1e3 == 0) or \\\n                (self.num_epoch % 1e4 == 0):\n            utils.save_image(pred[0], self.args.images_results.joinpath(f'{self.num_epoch}_prediction_step.png'))\n            utils.save_image(id_img[0], self.args.images_results.joinpath(f'{self.num_epoch}_id_step.png'))\n            utils.save_image(attr_img[0], self.args.images_results.joinpath(f'{self.num_epoch}_attr_step.png'))\n\n            Writer.add_image('input/id image', tf.expand_dims(id_img[0], 0), step=self.num_epoch)\n            Writer.add_image('Prediction', tf.expand_dims(pred[0], 0), step=self.num_epoch)\n\n        if total_g_not_gan_loss != 0:\n            g_grads = g_tape.gradient(total_g_not_gan_loss, self.model.G.trainable_variables)\n\n            g_grads_global_norm = tf.linalg.global_norm(g_grads)\n            self.logger.info(f'global norm G not gan grad: {g_grads_global_norm}')\n\n            self.g_optimizer.apply_gradients(zip(g_grads, self.model.G.trainable_variables))\n\n        if use_w_d:\n            g_gan_grads = g_tape.gradient(g_gan_loss, self.model.G.trainable_variables)\n\n            g_gan_grad_global_norm = tf.linalg.global_norm(g_gan_grads)\n            self.logger.info(f'global norm G gan grad: {g_gan_grad_global_norm}')\n\n            self.g_gan_optimizer.apply_gradients(zip(g_gan_grads, self.model.G.trainable_variables))\n\n            w_d_grads = w_d_tape.gradient(w_d_total_loss, self.model.W_D.trainable_variables)\n\n            self.logger.info(f'global W_D gan grad: {tf.linalg.global_norm(w_d_grads)}')\n            self.w_d_optimizer.apply_gradients(zip(w_d_grads, self.model.W_D.trainable_variables))\n\n        del g_tape\n\n    # Common\n\n    # Test\n    def test(self):\n        self.logger.info(f'Testing in epoch: {self.num_epoch}')\n        self.model.test()\n\n        similarities = {'id_to_pred': [], 'id_to_attr': [], 'attr_to_pred': []}\n\n        fake_reconstruction = {'MSE': [], 'PSNR': [], 'ID': []}\n        real_reconstruction = {'MSE': [], 'PSNR': [], 'ID': []}\n\n        if self.args.test_with_arcface:\n            test_similarities = {'id_to_pred': [], 'id_to_attr': [], 'attr_to_pred': []}\n\n        lnd_dist = []\n\n        for i in range(self.args.test_size):\n            attr_img, id_img = self.data_loader.get_batch(is_train=False, is_cross=True)\n\n            pred, id_embedding, w, attr_embedding, src_lnds = self.model.G(id_img, attr_img)\n            image = tf.clip_by_value(pred, 0, 1)\n\n            pred_id = self.model.G.id_encoder(image)\n            attr_id = self.model.G.id_encoder(attr_img)\n\n            similarities['id_to_pred'].extend(tf.keras.losses.cosine_similarity(id_embedding, pred_id).numpy())\n            similarities['id_to_attr'].extend(tf.keras.losses.cosine_similarity(id_embedding, attr_id).numpy())\n            similarities['attr_to_pred'].extend(tf.keras.losses.cosine_similarity(attr_id, pred_id).numpy())\n\n            if self.args.test_with_arcface:\n                try:\n                    arc_id_embedding = self.model.G.test_id_encoder(id_img)\n                    arc_pred_id = self.model.G.test_id_encoder(image)\n                    arc_attr_id = self.model.G.test_id_encoder(attr_img)\n\n                    test_similarities['id_to_attr'].extend(\n                        tf.keras.losses.cosine_similarity(arc_id_embedding, arc_attr_id).numpy())\n                    test_similarities['id_to_pred'].extend(\n                        tf.keras.losses.cosine_similarity(arc_id_embedding, arc_pred_id).numpy())\n                    test_similarities['attr_to_pred'].extend(\n                        tf.keras.losses.cosine_similarity(arc_attr_id, arc_pred_id).numpy())\n                except Exception as e:\n                    self.logger.warning(f'Not calculating test similarities for iteration: {i} because: {e}')\n\n            # Landmarks\n            dst_lnds = self.model.G.landmarks(image)\n            lnd_dist.extend(tf.reduce_mean(tf.keras.losses.MSE(src_lnds, dst_lnds), axis=-1).numpy())\n\n            # Fake Reconstruction\n            self.test_reconstruction(id_img, fake_reconstruction, display=(i==0), display_name='id_img')\n\n            if self.args.test_real_attr:\n                # Real Reconstruction\n                self.test_reconstruction(attr_img, real_reconstruction, display=(i==0), display_name='attr_img')\n\n            if i == 0:\n                utils.save_image(image[0], self.args.images_results.joinpath(f'test_prediction_{self.num_epoch}.png'))\n                utils.save_image(id_img[0], self.args.images_results.joinpath(f'test_id_{self.num_epoch}.png'))\n                utils.save_image(attr_img[0],\n                                 self.args.images_results.joinpath(f'test_attr_{self.num_epoch}.png'))\n\n                Writer.add_image('test/prediction', image, step=self.num_epoch)\n                Writer.add_image('test input/id image', id_img, step=self.num_epoch)\n                Writer.add_image('test input/attr image', attr_img, step=self.num_epoch)\n\n                for j in range(np.minimum(3, src_lnds.shape[0])):\n                    src_xy = src_lnds[j]  # GT\n                    dst_xy = dst_lnds[j]  # pred\n\n                    attr_marked = utils.mark_landmarks(attr_img[j], src_xy, color=(0, 0, 0))\n                    pred_marked = utils.mark_landmarks(pred[j], src_xy, color=(0, 0, 0))\n                    pred_marked = utils.mark_landmarks(pred_marked, dst_xy, color=(255, 112, 112))\n\n                    Writer.add_image(f'landmarks/overlay-{j}', pred_marked, step=self.num_epoch)\n                    Writer.add_image(f'landmarks/src-{j}', attr_marked, step=self.num_epoch)\n\n        # Similarity\n        self.logger.info('Similarities:')\n        for k, v in similarities.items():\n            self.logger.info(f'{k}: MEAN: {np.mean(v)}, STD: {np.std(v)}')\n\n        mean_lnd_dist = np.mean(lnd_dist)\n        self.logger.info(f'Mean landmarks L2: {mean_lnd_dist}')\n\n        id_to_pred = np.mean(similarities['id_to_pred'])\n        attr_to_pred = np.mean(similarities['attr_to_pred'])\n        mean_disen = attr_to_pred - id_to_pred\n\n        Writer.add_scalar('similarity/score', mean_disen, step=self.num_epoch)\n        Writer.add_scalar('similarity/id_to_pred', id_to_pred, step=self.num_epoch)\n        Writer.add_scalar('similarity/attr_to_pred', attr_to_pred, step=self.num_epoch)\n\n        if self.args.test_with_arcface:\n            arc_id_to_pred = np.mean(test_similarities['id_to_pred'])\n            arc_attr_to_pred = np.mean(test_similarities['attr_to_pred'])\n            arc_mean_disen = arc_attr_to_pred - arc_id_to_pred\n\n            Writer.add_scalar('arc_similarity/score', arc_mean_disen, step=self.num_epoch)\n            Writer.add_scalar('arc_similarity/id_to_pred', arc_id_to_pred, step=self.num_epoch)\n            Writer.add_scalar('arc_similarity/attr_to_pred', arc_attr_to_pred, step=self.num_epoch)\n\n        self.logger.info(f'Mean disentanglement score is {mean_disen}')\n\n        Writer.add_scalar('landmarks/L2', np.mean(lnd_dist), step=self.num_epoch)\n\n        # Reconstruction\n        if self.args.test_real_attr:\n            Writer.add_scalar('reconstruction/real_MSE', np.mean(real_reconstruction['MSE']), step=self.num_epoch)\n            Writer.add_scalar('reconstruction/real_PSNR', np.mean(real_reconstruction['PSNR']), step=self.num_epoch)\n            Writer.add_scalar('reconstruction/real_ID', np.mean(real_reconstruction['ID']), step=self.num_epoch)\n\n        Writer.add_scalar('reconstruction/fake_MSE', np.mean(fake_reconstruction['MSE']), step=self.num_epoch)\n        Writer.add_scalar('reconstruction/fake_PSNR', np.mean(fake_reconstruction['PSNR']), step=self.num_epoch)\n        Writer.add_scalar('reconstruction/fake_ID', np.mean(fake_reconstruction['ID']), step=self.num_epoch)\n\n        if mean_lnd_dist < self.min_lnd_dist:\n            self.logger.info('Minimum landmarks dist achieved. saving checkpoint')\n            self.test_not_imporved = 0\n            self.min_lnd_dist = mean_lnd_dist\n            self.model.my_save(f'_best_landmarks_epoch_{self.num_epoch}')\n\n        if np.abs(id_to_pred) > self.max_id_preserve:\n            self.logger.info(f'Max ID preservation achieved! saving checkpoint')\n            self.test_not_imporved = 0\n            self.max_id_preserve = np.abs(id_to_pred)\n            self.model.my_save(f'_best_id_epoch_{self.num_epoch}')\n\n        else:\n            self.test_not_imporved += 1\n\n    def test_reconstruction(self, img, errors_dict, display=False, display_name=None):\n        pred, id_embedding, w, attr_embedding, src_lnds = self.model.G(img, img)\n\n        recon_image = tf.clip_by_value(pred, 0, 1)\n        recon_pred_id = self.model.G.id_encoder(recon_image)\n\n        mse = tf.reduce_mean((img - recon_image) ** 2, axis=[1, 2, 3]).numpy()\n        psnr = tf.image.psnr(img, recon_image, 1).numpy()\n\n        errors_dict['MSE'].extend(mse)\n        errors_dict['PSNR'].extend(psnr)\n        errors_dict['ID'].extend(tf.keras.losses.cosine_similarity(id_embedding, recon_pred_id).numpy())\n\n        if display:\n            Writer.add_image(f'reconstruction/{display_name}', pred, step=self.num_epoch)\n\n    # Helpers\n\n    def generator_gan_loss(self, fake_logit):\n        \"\"\"\n        G logistic non saturating loss, to be minimized\n        \"\"\"\n        g_gan_loss = self.gan_loss_func(tf.ones_like(fake_logit), fake_logit)\n        return self.lambda_gan * g_gan_loss\n\n    def discriminator_loss(self, fake_logit, real_logit):\n        \"\"\"\n        D logistic loss, to be minimized\n        verified as identical to StyleGAN's loss.D_logistic\n        \"\"\"\n        fake_gt = tf.zeros_like(fake_logit)\n        real_gt = tf.ones_like(real_logit)\n\n        d_fake_loss = self.gan_loss_func(fake_gt, fake_logit)\n        d_real_loss = self.gan_loss_func(real_gt, real_logit)\n\n        d_loss = d_real_loss + d_fake_loss\n\n        return self.lambda_gan * d_loss\n\n    def R1_gp(self, D, x):\n        with tf.GradientTape() as t:\n            t.watch(x)\n            pred = D(x)\n            pred_sum = tf.reduce_sum(pred)\n\n        grad = t.gradient(pred_sum, x)\n\n        # Reshape as a vector\n        norm = tf.norm(tf.reshape(grad, [tf.shape(grad)[0], -1]), axis=1)\n        gp = tf.reduce_mean(norm ** 2)\n        gp = 0.5 * self.r1_gamma * gp\n\n        return gp\n"
  },
  {
    "path": "utils/__init__.py",
    "content": ""
  },
  {
    "path": "utils/general_utils.py",
    "content": "from pathlib import Path\n\nimport cv2\nimport numpy as np\nfrom PIL import Image\nimport tensorflow as tf\nfrom tensorflow.keras.initializers import VarianceScaling\nimport numbers\nimport scipy\nimport dlib\n\nlandmarks_model_path = None\n\n\ndef read_image(img_path, resolution, align=False):\n    if align:\n        img = read_and_align_image(img_path, resolution)\n    else:\n        img = read_SG_image(img_path, resolution)\n\n    return img\n\n\ndef find_file_by_str(search_dir, s):\n    files = [f for f in search_dir.iterdir() if s in f.name]\n    return files\n\n\ndef read_SG_image(img_path, size=256, resize=True):\n    img = Image.open(str(img_path))\n    img = img.convert('RGB')\n\n    if img.size != (size, size) and resize:\n        img = img.resize((size, size))\n    img = np.asarray(img)\n\n    img = np.expand_dims(img, axis=0)\n\n    # Images in [0, 1]\n    img = np.float32(img) / 255\n\n    return img\n\n\ndef read_and_align_image(img_path, output_size=1024):\n    global landmarks_model_path\n    if not landmarks_model_path:\n        raise ValueError('Please init the landmarks model path')\n\n    transform_size = 4096\n    enable_padding = True\n\n    img = Image.open(img_path)\n    img = img.convert('RGB')\n    npimg = np.asarray(img)\n\n    # states is a 4x1 array with confidence for : [left eye closed, right eye closed, mouth closed, mouth open big]\n    face_detector = dlib.get_frontal_face_detector()\n    landmarks_network = dlib.shape_predictor(landmarks_model_path)\n\n    try:\n        bbox = face_detector(npimg, 0)[0]\n    except:\n        print('face not found!')\n        raise\n\n    # rect = np.array([det.left(), det.top(), det.right(), det.bottom()])\n    shape = landmarks_network(npimg, bbox)\n    lm = np.array([[shape.part(n).x + 0.5, shape.part(n).y + 0.5] for n in range(shape.num_parts)])\n\n    lm = np.round(lm) + 0.5\n\n    lm_chin = lm[0: 17]  # left-right\n    lm_eyebrow_left = lm[17: 22]  # left-right\n    lm_eyebrow_right = lm[22: 27]  # left-right\n    lm_nose = lm[27: 31]  # top-down\n    lm_nostrils = lm[31: 36]  # top-down\n    lm_eye_left = lm[36: 42]  # left-clockwise\n    lm_eye_right = lm[42: 48]  # left-clockwise\n    lm_mouth_outer = lm[48: 60]  # left-clockwise\n    lm_mouth_inner = lm[60: 68]  # left-clockwise\n\n    # Calculate auxiliary vectors.\n    eye_left = np.mean(lm_eye_left, axis=0)\n    eye_right = np.mean(lm_eye_right, axis=0)\n\n    eye_avg = (eye_left + eye_right) * 0.5\n\n    # nose_mock_avg = (lm_nose[0] + lm_nose[1]) * 0.5\n    # eye_avg = nose_mock_avg\n\n    eye_to_eye = eye_right - eye_left\n    mouth_left = lm_mouth_outer[0]\n    mouth_right = lm_mouth_outer[6]\n    mouth_avg = (mouth_left + mouth_right) * 0.5\n    eye_to_mouth = mouth_avg - eye_avg\n\n    # Choose oriented crop rectangle.\n    x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]\n    x /= np.hypot(*x)\n    x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)\n\n    y = np.flipud(x) * [-1, 1]\n\n    c = eye_avg + eye_to_mouth * 0.1\n\n    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])\n    qsize = np.hypot(*x) * 2\n\n    # Shrink.\n    shrink = int(np.floor(qsize / output_size * 0.5))\n    if shrink > 1:\n        rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))\n        img = img.resize(rsize, Image.ANTIALIAS)\n        quad /= shrink\n        qsize /= shrink\n\n    # Crop.\n    border = max(int(np.rint(qsize * 0.1)), 3)\n    crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),\n            int(np.ceil(max(quad[:, 1]))))\n    crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),\n            min(crop[3] + border, img.size[1]))\n    crop = np.array(crop)\n    if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:\n        img = img.crop(tuple(crop))\n        quad -= crop[0:2]\n\n    # Pad.\n    pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),\n           int(np.ceil(max(quad[:, 1]))))\n    pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),\n           max(pad[3] - img.size[1] + border, 0))\n    if enable_padding and max(pad) > border - 4:\n        pad = np.maximum(pad, int(np.rint(qsize * 0.3)))\n        img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')\n        h, w, _ = img.shape\n        y, x, _ = np.ogrid[:h, :w, :1]\n        mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),\n                          1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))\n        blur = qsize * 0.02\n        img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)\n        img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)\n        img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')\n        quad += pad[:2]\n\n    # Transform.\n    img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(),\n                        Image.BILINEAR)\n\n    if output_size < transform_size:\n        img = img.resize((output_size, output_size), Image.ANTIALIAS)\n\n    img = np.asarray(img)\n    img = np.expand_dims(img, axis=0)\n    img = np.float32(img) / 255\n\n    return img\n\n\ndef gaussian_image(size, sigma, dim=2):\n    if isinstance(size, numbers.Number):\n        size = [size] * dim\n    if isinstance(sigma, numbers.Number):\n        sigma = [sigma] * dim\n\n    weight_kernel = 1\n    meshgrids = np.meshgrid(*[np.arange(size, dtype=np.float32) for size in size])\n    # The gaussian kernel is the product of the\n    # gaussian function of each dimension.\n    for size, std, mgrid in zip(size, sigma, meshgrids):\n        mean = (size - 1) / 2\n        weight_kernel *= 1 / (std * np.sqrt(2 * np.pi)) * np.exp(-((mgrid - mean) / std) ** 2 / 2)\n\n    weight_kernel = weight_kernel / np.sum(weight_kernel)\n    return weight_kernel\n\n\ndef inverse_gaussian_image(size, sigma, dim=2):\n    gauss = gaussian_image(size, sigma, dim)\n\n    # Inversion achieved by max - gauss, but adding min as well to\n    # prevent regions of zeros which don't exist in normal gaussian\n    inv_gauss = np.max(gauss) + np.min(gauss) - gauss\n    inv_gauss = inv_gauss / np.sum(inv_gauss)\n\n    return inv_gauss\n\n\ndef is_float(tensor):\n    \"\"\"\n    Check if input tensor is float32, tensor maybe tf.Tensor or np.array\n    \"\"\"\n\n    return (isinstance(tensor, tf.Tensor) and tensor.dtype != tf.dtypes.uint8) or tensor.dtype != np.uint8\n\n\ndef convert_tensor_to_image(tensor):\n    \"\"\"\n    Converts tensor to image, and saturate output's range\n    :param tensor: tf.Tensor, dtype float32, range [0,1]\n    :return: np.array, dtype uint8, range [0, 255]\n    \"\"\"\n    if is_float(tensor):\n        tensor = tf.clip_by_value(tensor, 0., 1.)\n        tensor = 255 * tensor\n\n    if tensor.ndim == 4 and tensor.shape[0] == 1:\n        tensor = tf.squeeze(tensor)\n\n    tensor = np.uint8(np.round(tensor))\n\n    return tensor\n\n\ndef save_image(img, file_path):\n    \"\"\"\n    :param img: Could be either tf tensor or numpy array\n    :param file_path:\n    \"\"\"\n\n    if isinstance(file_path, Path):\n        file_path = str(file_path)\n\n    img = convert_tensor_to_image(img)\n    img = Image.fromarray(img)\n    img.save(file_path)\n\n\ndef mark_landmarks(img, lnd, color=None):\n    \"\"\"\n    landmarks in (x,y) format\n    \"\"\"\n    img = convert_tensor_to_image(img)\n    radius = int(img.shape[0] / 256)\n\n    lnd = (img.shape[0] / 160) * lnd\n\n    if not color:\n        color = (255, 255, 255)\n\n    for i in range(lnd.shape[0]):\n        x_y = lnd[i]\n        img = cv2.circle(img, center=(int(x_y[0]), int(x_y[1])),\n                         color=color, radius=radius, thickness=-1)\n\n    return img\n\n\ndef get_weights(slope=0.2):\n    \"\"\"\n    The scale is calculated according to:\n        https://pytorch.org/docs/stable/nn.init.html\n    and\n        https://towardsdatascience.com/weight-initialization-in-neural-networks-a-journey-from-the-basics-to-kaiming-954fb9b47c79\n\n    For ReLU and LeakyReLU activations, the preferable initialization is kaiming.\n\n    In Pytorch, the gain for LeakyReLU is calcaulted by: sqrt(2 / ( 1 + leaky_relu_slope ^ 2)) and the weights are\n    sampled from N(0, std^2) where std = gain / sqrt(fan_in)\n\n    To mimic this in TF, I am using VarianceScaling. The weights are sampled from N(0, std^2)\n    where std = sqrt(scale / fan_in). Therefore, scale = gain^2\n    \"\"\"\n    scale = 2 / (1 + slope ** 2)\n    return VarianceScaling(scale)\n\n\ndef np_permute(tensor, permute):\n    idx = np.empty_like(permute)\n    idx[permute] = np.arange(len(permute))\n    return tensor[:, idx]\n"
  },
  {
    "path": "utils/generate_fake_data.py",
    "content": "import sys\nfrom pathlib import Path\nimport os\n\nsys.path.append('..')\n\nimport argparse\n\nimport tensorflow as tf\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom utils.general_utils import save_image\nfrom model.stylegan import StyleGAN_G\n\n\ndef main(args):\n    base_dir = Path(args.output_path).joinpath(f'dataset_{args.resolution}')\n\n    base_w_dir = base_dir.joinpath('ws')\n    base_w_dir.mkdir(parents=True, exist_ok=True)\n\n    base_im_dir = base_dir.joinpath('images')\n    base_im_dir.mkdir(parents=True, exist_ok=True)\n\n    existing_files = list(base_dir.joinpath('images').iterdir())\n    if existing_files:\n        max_exist = max([int(x.name) for x in existing_files])\n        max_exist = int(max_exist - max_exist % 1e3 + 1e3)\n    else:\n        max_exist = 0\n\n    stylegan_G_path = args.pretrained_models_path.joinpath(f'stylegan_G_{args.resolution}x{args.resolution}.h5')\n    stylegan_G = StyleGAN_G(resolution=args.resolution, truncation_psi=args.truncation)\n    stylegan_G.load_weights(str(stylegan_G_path))\n\n    num_samples = args.num_images\n    batch_size = args.batch_size\n    num_batches = int(num_samples / batch_size)\n\n    curr_ind = max_exist\n    for _ in tqdm(range(num_batches)):\n        z = tf.random.normal((batch_size, 512))\n        w = stylegan_G.model_mapping(z)\n        images = stylegan_G.model_synthesis(w)\n        images = (images + 1) / 2\n\n        if curr_ind % 1000 == 0:\n            curr_w_dir = base_w_dir.joinpath(f'{curr_ind:05d}')\n            curr_w_dir.mkdir(exist_ok=True)\n\n            curr_im_dir = base_im_dir.joinpath(f'{curr_ind:05d}')\n            curr_im_dir.mkdir(exist_ok=True)\n\n        for j in range(batch_size):\n            w_path = curr_w_dir.joinpath(f'{curr_ind:05d}.npy')\n            np.save(str(w_path), w[j], allow_pickle=False)\n\n            im_path = curr_im_dir.joinpath(f'{curr_ind:05d}.png')\n            save_image(images[j], im_path)\n\n            curr_ind += 1\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument('--resolution', type=int, choices=[256, 1024], default=256)\n    parser.add_argument('--batch_size', type=int, default=50)\n    parser.add_argument('--truncation', type=float, default=0.7)\n\n    parser.add_argument('--output_path', required=True)\n    parser.add_argument('--pretrained_models_path', type=Path, required=True)\n\n    parser.add_argument('--num_images', type=int, default=10000)\n    parser.add_argument('--gpu', default='0')\n\n    args = parser.parse_args()\n\n    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu\n    assert args.num_images % 1e3 == 0\n\n    main(args)\n"
  },
  {
    "path": "writer.py",
    "content": "from utils.general_utils import convert_tensor_to_image\nfrom pathlib import Path\n\nimport tensorflow as tf\n\n\nclass Writer(object):\n    writer = None\n\n    @staticmethod\n    def set_writer(results_dir):\n        if isinstance(results_dir, str):\n            results_dir = Path(results_dir)\n        results_dir.mkdir(exist_ok=True, parents=True)\n        Writer.writer = tf.summary.create_file_writer(str(results_dir))\n\n    @staticmethod\n    def add_scalar(tag, val, step):\n        with Writer.writer.as_default():\n            tf.summary.scalar(tag, val, step=step)\n\n    @staticmethod\n    def add_image(tag, val, step):\n        val = convert_tensor_to_image(val)\n\n        if tf.rank(val) == 3:\n            val = tf.expand_dims(val, 0)\n\n        with Writer.writer.as_default():\n            tf.summary.image(tag, val, step)\n\n    @staticmethod\n    def flush():\n        with Writer.writer.as_default():\n            Writer.writer.flush()"
  }
]