[
  {
    "path": "README.md",
    "content": "# DeltaSpace: A Semantic-aligned Feature Space for Flexible Text-guided Image Editing\r\n\r\n## Overview\r\n\r\nThis repository contains the **offical** PyTorch implementation of paper:\r\n\r\n*DeltaEdit: Exploring Text-free Training for Text-driven Image Manipulation*, CVPR 2023\r\n\r\n*DeltaSpace: A Semantic-aligned Feature Space for Flexible Text-guided Image Editing*, Arxiv 2023\r\n\r\n## News\r\n\r\n- [2025-06-22] Upload t-SNE Code for Alignment Validation​ (◍＞◡＜◍).\r\n\r\n- [2023-03-11] Upload the training and inference code for the facial domain (◍•ڡ•◍).\r\n\r\n*To be continued...*\r\n\r\n<!-- We will release the training and inference code for the LSUN cat, church, horse later : ) -->\r\n\r\n## Dependences\r\n\r\n- Install CLIP:\r\n\r\n  ```shell script\r\n  conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=<CUDA_VERSION>\r\n  pip install ftfy regex tqdm gdown\r\n  pip install git+https://github.com/openai/CLIP.git\r\n  ```\r\n\r\n- Download pre-trained models :\r\n\r\n  - The code relies on the [Rosinality](https://github.com/rosinality/stylegan2-pytorch/) pytorch implementation of StyleGAN2.\r\n  - Download the pre-trained StyleGAN2 generator model for the faical domain from [here](https://drive.google.com/file/d/1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT/view?usp=sharing), and then place it into the folder `./models/pretrained_models`.\r\n  - Download the pre-trained StyleGAN2 generator model for the LSUN cat, church, horse domains from [here](https://drive.google.com/drive/folders/1YRhXGM-2xk7A4TExM_jXaNg1f2AiCRlw?usp=share_link) and then place them into the folder `./models/pretrained_models/stylegan2-{cat/church/horse}`.\r\n\r\n## Training\r\n\r\n### Data preparing\r\n\r\n- DeltaEdit is trained on latent vectors. \r\n\r\n- For the facial domain,  58,000 real images from [FFHQ](https://github.com/NVlabs/ffhq-dataset) dataset are randomly selected and 200,000 fake images from the z space in StyleGAN are sampled for training. 
Note that all real images are inverted with the [e4e](https://github.com/omertov/encoder4editing) encoder.\r\n\r\n- Download the provided FFHQ latent vectors from [here](https://drive.google.com/drive/folders/13NLq4giSgdcMVkYQIiPj4Xhxz4-wlXSD?usp=sharing) and then place all NumPy files into the folder `./latent_code/ffhq`.\r\n\r\n- Generate the 200,000 sampled latent vectors by running the command for each domain you need:\r\n\r\n  ```shell\r\n  CUDA_VISIBLE_DEVICES=0 python generate_codes.py --classname ffhq --samples 200000\r\n  CUDA_VISIBLE_DEVICES=0 python generate_codes.py --classname cat --samples 200000\r\n  CUDA_VISIBLE_DEVICES=0 python generate_codes.py --classname church --samples 200000\r\n  CUDA_VISIBLE_DEVICES=0 python generate_codes.py --classname horse --samples 200000\r\n  ```\r\n\r\n### Usage\r\n\r\n- The main training script is located at `./scripts/train.py`.\r\n- Training arguments can be found in `./options/train_options.py`.\r\n\r\nTo train the model, run the following command:\r\n\r\n```shell\r\nCUDA_VISIBLE_DEVICES=0 python scripts/train.py\r\n```\r\n\r\n## Inference\r\n\r\n- The main inference script is located at `./scripts/inference.py`.\r\n- Inference arguments can be found in `./options/test_options.py`.
\r\n- Download the pretrained DeltaMapper model for human face editing from [here](https://drive.google.com/file/d/1Mb2WiELoVDPDIi24tIfoWsjn1l2xTjtZ/view?usp=sharing), and then place it into the folder `./checkpoints`.\r\n- Some inference data are provided in `./examples`.\r\n\r\nTo produce editing results, run the following command:\r\n\r\n```shell\r\nCUDA_VISIBLE_DEVICES=1 python scripts/inference.py --target \"chubby face\",\"face with eyeglasses\",\"face with smile\",\"face with pale skin\",\"face with tanned skin\",\"face with big eyes\",\"face with black clothes\",\"face with blue suit\",\"happy face\",\"face with bangs\",\"face with red hair\",\"face with black hair\",\"face with blond hair\",\"face with curly hair\",\"face with receding hairline\",\"face with bowlcut hairstyle\"\r\n```\r\n\r\nThe produced results are shown in the Results section below.\r\n\r\nYou can also pass your own target attributes to the `--target` flag.\r\n\r\n## Inference for real images\r\n\r\n- The main inference script is located at `./scripts/inference_real.py`.\r\n- Inference arguments can be found in `./options/test_options.py`.
\r\n- Download the pretrained DeltaMapper model for human face editing from [here](https://drive.google.com/file/d/1Mb2WiELoVDPDIi24tIfoWsjn1l2xTjtZ/view?usp=sharing), and then place it into the folder `./checkpoints`.\r\n- Download the pretrained e4e encoder `e4e_ffhq_encode.pt` from [e4e](https://github.com/omertov/encoder4editing).\r\n- One test image is provided in `./test_imgs`.\r\n\r\nTo produce editing results, run the following command:\r\n\r\n```shell\r\nCUDA_VISIBLE_DEVICES=1 python scripts/inference_real.py --target \"chubby face\",\"face with eyeglasses\",\"face with smile\",\"face with pale skin\",\"face with tanned skin\",\"face with big eyes\",\"face with black clothes\",\"face with blue suit\",\"happy face\",\"face with bangs\",\"face with red hair\",\"face with black hair\",\"face with blond hair\",\"face with curly hair\",\"face with receding hairline\",\"face with bowlcut hairstyle\"\r\n```\r\n\r\n## Alignment Validation: CLIP Space vs. DeltaSpace via t-SNE Visualization\r\n\r\n```shell\r\ncd tSNE\r\npython compute_tsne.py\r\n```\r\n\r\nRunning this script produces a 2D t-SNE projection of embeddings from both spaces (CLIP and DeltaSpace). 
The resulting plot is shown below for reference.\r\n\r\n![tsne](./tsne.jpg)\r\n\r\n## Results\r\n\r\n![results](./results.jpg)\r\n\r\n## Acknowledgements\r\n\r\nThis code is developed based on the code of [orpatashnik/StyleCLIP](https://github.com/orpatashnik/StyleCLIP) by Or Patashnik et al.\r\n\r\n## Citation\r\nIf you use this code for your research, please cite our papers:\r\n```\r\n@InProceedings{lyu2023deltaedit,\r\n    author    = {Lyu, Yueming and Lin, Tianwei and Li, Fu and He, Dongliang and Dong, Jing and Tan, Tieniu},\r\n    title     = {DeltaEdit: Exploring Text-free Training for Text-Driven Image Manipulation},\r\n    booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},\r\n    year      = {2023}\r\n}\r\n\r\n@article{lyu2023deltaspace,\r\n    author    = {Lyu, Yueming and Zhao, Kang and Peng, Bo and Chen, Huafeng and Jiang, Yue and Zhang, Yingya and Dong, Jing and Shan, Caifeng},\r\n    title     = {DeltaSpace: A Semantic-aligned Feature Space for Flexible Text-guided Image Editing},\r\n    journal   = {arXiv preprint arXiv:2310.08785},\r\n    year      = {2023},\r\n}\r\n\r\n```\r\n"
  },
  {
    "path": "clip/__init__.py",
    "content": "from .clip import *\n"
  },
  {
    "path": "clip/clip.py",
    "content": "import hashlib\nimport os\nimport urllib\nimport warnings\nfrom typing import Union, List\n\nimport torch\nfrom PIL import Image\nfrom torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize\nfrom tqdm import tqdm\n\nfrom .model import build_model\nfrom .simple_tokenizer import SimpleTokenizer as _Tokenizer\n\ntry:\n    from torchvision.transforms import InterpolationMode\n    BICUBIC = InterpolationMode.BICUBIC\nexcept ImportError:\n    BICUBIC = Image.BICUBIC\n\n\nif torch.__version__.split(\".\") < [\"1\", \"7\", \"1\"]:\n    warnings.warn(\"PyTorch version 1.7.1 or higher is recommended\")\n\n\n__all__ = [\"available_models\", \"load\", \"tokenize\"]\n_tokenizer = _Tokenizer()\n\n_MODELS = {\n    \"RN50\": \"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt\",\n    \"RN101\": \"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt\",\n    \"RN50x4\": \"https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt\",\n    \"RN50x16\": \"https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt\",\n    \"ViT-B/32\": \"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt\",\n    \"ViT-B/16\": \"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt\",\n}\n\n\ndef _download(url: str, root: str = os.path.expanduser(\"~/.cache/clip\")):\n    os.makedirs(root, exist_ok=True)\n    filename = os.path.basename(url)\n\n    expected_sha256 = url.split(\"/\")[-2]\n    download_target = os.path.join(root, filename)\n\n    if os.path.exists(download_target) and not os.path.isfile(download_target):\n        raise 
RuntimeError(f\"{download_target} exists and is not a regular file\")\n\n    if os.path.isfile(download_target):\n        if hashlib.sha256(open(download_target, \"rb\").read()).hexdigest() == expected_sha256:\n            return download_target\n        else:\n            warnings.warn(f\"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file\")\n\n    with urllib.request.urlopen(url) as source, open(download_target, \"wb\") as output:\n        with tqdm(total=int(source.info().get(\"Content-Length\")), ncols=80, unit='iB', unit_scale=True) as loop:\n            while True:\n                buffer = source.read(8192)\n                if not buffer:\n                    break\n\n                output.write(buffer)\n                loop.update(len(buffer))\n\n    if hashlib.sha256(open(download_target, \"rb\").read()).hexdigest() != expected_sha256:\n        raise RuntimeError(\"Model has been downloaded but the SHA256 checksum does not match\")\n\n    return download_target\n\n\ndef _transform(n_px):\n    return Compose([\n        Resize(n_px, interpolation=BICUBIC),\n        CenterCrop(n_px),\n        lambda image: image.convert(\"RGB\"),\n        ToTensor(),\n        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),\n    ])\n\n\ndef available_models() -> List[str]:\n    \"\"\"Returns the names of available CLIP models\"\"\"\n    return list(_MODELS.keys())\n\n\ndef load(name: str, device: Union[str, torch.device] = \"cuda\" if torch.cuda.is_available() else \"cpu\", jit=False):\n    \"\"\"Load a CLIP model\n\n    Parameters\n    ----------\n    name : str\n        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict\n\n    device : Union[str, torch.device]\n        The device to put the loaded model\n\n    jit : bool\n        Whether to load the optimized JIT model or more hackable non-JIT model (default).\n\n    Returns\n    
-------\n    model : torch.nn.Module\n        The CLIP model\n\n    preprocess : Callable[[PIL.Image], torch.Tensor]\n        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input\n    \"\"\"\n    if name in _MODELS:\n        model_path = _download(_MODELS[name])\n    elif os.path.isfile(name):\n        model_path = name\n    else:\n        raise RuntimeError(f\"Model {name} not found; available models = {available_models()}\")\n\n    try:\n        # loading JIT archive\n        model = torch.jit.load(model_path, map_location=device if jit else \"cpu\").eval()\n        state_dict = None\n    except RuntimeError:\n        # loading saved state dict\n        if jit:\n            warnings.warn(f\"File {model_path} is not a JIT archive. Loading as a state dict instead\")\n            jit = False\n        state_dict = torch.load(model_path, map_location=\"cpu\")\n\n    if not jit:\n        model = build_model(state_dict or model.state_dict()).to(device)\n        if str(device) == \"cpu\":\n            model.float()\n        return model, _transform(model.visual.input_resolution)\n\n    # patch the device names\n    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])\n    device_node = [n for n in device_holder.graph.findAllNodes(\"prim::Constant\") if \"Device\" in repr(n)][-1]\n\n    def patch_device(module):\n        try:\n            graphs = [module.graph] if hasattr(module, \"graph\") else []\n        except RuntimeError:\n            graphs = []\n\n        if hasattr(module, \"forward1\"):\n            graphs.append(module.forward1.graph)\n\n        for graph in graphs:\n            for node in graph.findAllNodes(\"prim::Constant\"):\n                if \"value\" in node.attributeNames() and str(node[\"value\"]).startswith(\"cuda\"):\n                    node.copyAttributes(device_node)\n\n    model.apply(patch_device)\n    patch_device(model.encode_image)\n    
patch_device(model.encode_text)\n\n    # patch dtype to float32 on CPU\n    if str(device) == \"cpu\":\n        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])\n        float_input = list(float_holder.graph.findNode(\"aten::to\").inputs())[1]\n        float_node = float_input.node()\n\n        def patch_float(module):\n            try:\n                graphs = [module.graph] if hasattr(module, \"graph\") else []\n            except RuntimeError:\n                graphs = []\n\n            if hasattr(module, \"forward1\"):\n                graphs.append(module.forward1.graph)\n\n            for graph in graphs:\n                for node in graph.findAllNodes(\"aten::to\"):\n                    inputs = list(node.inputs())\n                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()\n                        if inputs[i].node()[\"value\"] == 5:\n                            inputs[i].node().copyAttributes(float_node)\n\n        model.apply(patch_float)\n        patch_float(model.encode_image)\n        patch_float(model.encode_text)\n\n        model.float()\n\n    return model, _transform(model.input_resolution.item())\n\n\ndef tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:\n    \"\"\"\n    Returns the tokenized representation of given input string(s)\n\n    Parameters\n    ----------\n    texts : Union[str, List[str]]\n        An input string or a list of input strings to tokenize\n\n    context_length : int\n        The context length to use; all CLIP models use 77 as the context length\n\n    truncate: bool\n        Whether to truncate the text in case its encoding is longer than the context length\n\n    Returns\n    -------\n    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]\n    \"\"\"\n    if isinstance(texts, str):\n        texts = [texts]\n\n    sot_token = 
_tokenizer.encoder[\"<|startoftext|>\"]\n    eot_token = _tokenizer.encoder[\"<|endoftext|>\"]\n    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]\n    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)\n\n    for i, tokens in enumerate(all_tokens):\n        if len(tokens) > context_length:\n            if truncate:\n                tokens = tokens[:context_length]\n                tokens[-1] = eot_token\n            else:\n                raise RuntimeError(f\"Input {texts[i]} is too long for context length {context_length}\")\n        result[i, :len(tokens)] = torch.tensor(tokens)\n\n    return result\n"
  },
  {
    "path": "clip/model.py",
    "content": "from collections import OrderedDict\nfrom typing import Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1):\n        super().__init__()\n\n        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1\n        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes)\n\n        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes)\n\n        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()\n\n        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)\n        self.bn3 = nn.BatchNorm2d(planes * self.expansion)\n\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = None\n        self.stride = stride\n\n        if stride > 1 or inplanes != planes * Bottleneck.expansion:\n            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1\n            self.downsample = nn.Sequential(OrderedDict([\n                (\"-1\", nn.AvgPool2d(stride)),\n                (\"0\", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),\n                (\"1\", nn.BatchNorm2d(planes * self.expansion))\n            ]))\n\n    def forward(self, x: torch.Tensor):\n        identity = x\n\n        out = self.relu(self.bn1(self.conv1(x)))\n        out = self.relu(self.bn2(self.conv2(out)))\n        out = self.avgpool(out)\n        out = self.bn3(self.conv3(out))\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n        return out\n\n\nclass AttentionPool2d(nn.Module):\n    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):\n       
 super().__init__()\n        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)\n        self.k_proj = nn.Linear(embed_dim, embed_dim)\n        self.q_proj = nn.Linear(embed_dim, embed_dim)\n        self.v_proj = nn.Linear(embed_dim, embed_dim)\n        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)\n        self.num_heads = num_heads\n\n    def forward(self, x):\n        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC\n        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC\n        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC\n        x, _ = F.multi_head_attention_forward(\n            query=x, key=x, value=x,\n            embed_dim_to_check=x.shape[-1],\n            num_heads=self.num_heads,\n            q_proj_weight=self.q_proj.weight,\n            k_proj_weight=self.k_proj.weight,\n            v_proj_weight=self.v_proj.weight,\n            in_proj_weight=None,\n            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),\n            bias_k=None,\n            bias_v=None,\n            add_zero_attn=False,\n            dropout_p=0,\n            out_proj_weight=self.c_proj.weight,\n            out_proj_bias=self.c_proj.bias,\n            use_separate_proj_weight=True,\n            training=self.training,\n            need_weights=False\n        )\n\n        return x[0]\n\n\nclass ModifiedResNet(nn.Module):\n    \"\"\"\n    A ResNet class that is similar to torchvision's but contains the following changes:\n    - There are now 3 \"stem\" convolutions as opposed to 1, with an average pool instead of a max pool.\n    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1\n    - The final pooling layer is a QKV attention instead of an average pool\n    \"\"\"\n\n    def __init__(self, layers, output_dim, heads, 
input_resolution=224, width=64):\n        super().__init__()\n        self.output_dim = output_dim\n        self.input_resolution = input_resolution\n\n        # the 3-layer stem\n        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(width // 2)\n        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(width // 2)\n        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)\n        self.bn3 = nn.BatchNorm2d(width)\n        self.avgpool = nn.AvgPool2d(2)\n        self.relu = nn.ReLU(inplace=True)\n\n        # residual layers\n        self._inplanes = width  # this is a *mutable* variable used during construction\n        self.layer1 = self._make_layer(width, layers[0])\n        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)\n        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)\n        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)\n\n        embed_dim = width * 32  # the ResNet feature dimension\n        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)\n\n    def _make_layer(self, planes, blocks, stride=1):\n        layers = [Bottleneck(self._inplanes, planes, stride)]\n\n        self._inplanes = planes * Bottleneck.expansion\n        for _ in range(1, blocks):\n            layers.append(Bottleneck(self._inplanes, planes))\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        def stem(x):\n            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:\n                x = self.relu(bn(conv(x)))\n            x = self.avgpool(x)\n            return x\n\n        x = x.type(self.conv1.weight.dtype)\n        x = stem(x)\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n        x = 
self.attnpool(x)\n\n        return x\n\n\nclass LayerNorm(nn.LayerNorm):\n    \"\"\"Subclass torch's LayerNorm to handle fp16.\"\"\"\n\n    def forward(self, x: torch.Tensor):\n        orig_type = x.dtype\n        ret = super().forward(x.type(torch.float32))\n        return ret.type(orig_type)\n\n\nclass QuickGELU(nn.Module):\n    def forward(self, x: torch.Tensor):\n        return x * torch.sigmoid(1.702 * x)\n\n\nclass ResidualAttentionBlock(nn.Module):\n    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):\n        super().__init__()\n\n        self.attn = nn.MultiheadAttention(d_model, n_head)\n        self.ln_1 = LayerNorm(d_model)\n        self.mlp = nn.Sequential(OrderedDict([\n            (\"c_fc\", nn.Linear(d_model, d_model * 4)),\n            (\"gelu\", QuickGELU()),\n            (\"c_proj\", nn.Linear(d_model * 4, d_model))\n        ]))\n        self.ln_2 = LayerNorm(d_model)\n        self.attn_mask = attn_mask\n\n    def attention(self, x: torch.Tensor):\n        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None\n        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]\n\n    def forward(self, x: torch.Tensor):\n        x = x + self.attention(self.ln_1(x))\n        x = x + self.mlp(self.ln_2(x))\n        return x\n\n\nclass Transformer(nn.Module):\n    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):\n        super().__init__()\n        self.width = width\n        self.layers = layers\n        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])\n\n    def forward(self, x: torch.Tensor):\n        return self.resblocks(x)\n\n\nclass VisionTransformer(nn.Module):\n    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):\n        super().__init__()\n        self.input_resolution = 
input_resolution\n        self.output_dim = output_dim\n        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)\n\n        scale = width ** -0.5\n        self.class_embedding = nn.Parameter(scale * torch.randn(width))\n        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))\n        self.ln_pre = LayerNorm(width)\n\n        self.transformer = Transformer(width, layers, heads)\n\n        self.ln_post = LayerNorm(width)\n        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))\n\n    def forward(self, x: torch.Tensor):\n        x = self.conv1(x)  # shape = [*, width, grid, grid]\n        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]\n        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]\n        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]\n        x = x + self.positional_embedding.to(x.dtype)\n        x = self.ln_pre(x)\n\n        x = x.permute(1, 0, 2)  # NLD -> LND\n        x = self.transformer(x)\n        x = x.permute(1, 0, 2)  # LND -> NLD\n\n        x = self.ln_post(x[:, 0, :])\n\n        if self.proj is not None:\n            x = x @ self.proj\n\n        return x\n\n\nclass CLIP(nn.Module):\n    def __init__(self,\n                 embed_dim: int,\n                 # vision\n                 image_resolution: int,\n                 vision_layers: Union[Tuple[int, int, int, int], int],\n                 vision_width: int,\n                 vision_patch_size: int,\n                 # text\n                 context_length: int,\n                 vocab_size: int,\n                 transformer_width: int,\n                 transformer_heads: int,\n                 transformer_layers: int\n                 ):\n        super().__init__()\n\n        
self.context_length = context_length\n\n        if isinstance(vision_layers, (tuple, list)):\n            vision_heads = vision_width * 32 // 64\n            self.visual = ModifiedResNet(\n                layers=vision_layers,\n                output_dim=embed_dim,\n                heads=vision_heads,\n                input_resolution=image_resolution,\n                width=vision_width\n            )\n        else:\n            vision_heads = vision_width // 64\n            self.visual = VisionTransformer(\n                input_resolution=image_resolution,\n                patch_size=vision_patch_size,\n                width=vision_width,\n                layers=vision_layers,\n                heads=vision_heads,\n                output_dim=embed_dim\n            )\n\n        self.transformer = Transformer(\n            width=transformer_width,\n            layers=transformer_layers,\n            heads=transformer_heads,\n            attn_mask=self.build_attention_mask()\n        )\n\n        self.vocab_size = vocab_size\n        self.token_embedding = nn.Embedding(vocab_size, transformer_width)\n        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))\n        self.ln_final = LayerNorm(transformer_width)\n\n        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))\n        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))\n\n        self.initialize_parameters()\n\n    def initialize_parameters(self):\n        nn.init.normal_(self.token_embedding.weight, std=0.02)\n        nn.init.normal_(self.positional_embedding, std=0.01)\n\n        if isinstance(self.visual, ModifiedResNet):\n            if self.visual.attnpool is not None:\n                std = self.visual.attnpool.c_proj.in_features ** -0.5\n                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)\n                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)\n                
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)\n                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)\n\n            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:\n                for name, param in resnet_block.named_parameters():\n                    if name.endswith(\"bn3.weight\"):\n                        nn.init.zeros_(param)\n\n        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)\n        attn_std = self.transformer.width ** -0.5\n        fc_std = (2 * self.transformer.width) ** -0.5\n        for block in self.transformer.resblocks:\n            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)\n            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)\n            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)\n            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)\n\n        if self.text_projection is not None:\n            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)\n\n    def build_attention_mask(self):\n        # lazily create causal attention mask, with full attention between the vision tokens\n        # pytorch uses additive attention mask; fill with -inf\n        mask = torch.empty(self.context_length, self.context_length)\n        mask.fill_(float(\"-inf\"))\n        mask.triu_(1)  # zero out the lower diagonal\n        return mask\n\n    @property\n    def dtype(self):\n        return self.visual.conv1.weight.dtype\n\n    def encode_image(self, image):\n        return self.visual(image.type(self.dtype))\n\n    def encode_text(self, text):\n        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]\n\n        x = x + self.positional_embedding.type(self.dtype)\n        x = x.permute(1, 0, 2)  # NLD -> LND\n        x = self.transformer(x)\n        x = x.permute(1, 0, 2)  # LND -> NLD\n        x = 
self.ln_final(x).type(self.dtype)\n\n        # x.shape = [batch_size, n_ctx, transformer.width]\n        # take features from the eot embedding (eot_token is the highest number in each sequence)\n        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection\n\n        return x\n\n    def forward(self, image, text):\n        image_features = self.encode_image(image)\n        text_features = self.encode_text(text)\n\n        # normalized features\n        image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n        text_features = text_features / text_features.norm(dim=-1, keepdim=True)\n\n        # cosine similarity as logits\n        logit_scale = self.logit_scale.exp()\n        logits_per_image = logit_scale * image_features @ text_features.t()\n        logits_per_text = logit_scale * text_features @ image_features.t()\n\n        # shape = [global_batch_size, global_batch_size]\n        return logits_per_image, logits_per_text\n    \n\ndef convert_weights(model: nn.Module):\n    \"\"\"Convert applicable model parameters to fp16\"\"\"\n\n    def _convert_weights_to_fp16(l):\n        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):\n            l.weight.data = l.weight.data.half()\n            if l.bias is not None:\n                l.bias.data = l.bias.data.half()\n\n        if isinstance(l, nn.MultiheadAttention):\n            for attr in [*[f\"{s}_proj_weight\" for s in [\"in\", \"q\", \"k\", \"v\"]], \"in_proj_bias\", \"bias_k\", \"bias_v\"]:\n                tensor = getattr(l, attr)\n                if tensor is not None:\n                    tensor.data = tensor.data.half()\n\n        for name in [\"text_projection\", \"proj\"]:\n            if hasattr(l, name):\n                attr = getattr(l, name)\n                if attr is not None:\n                    attr.data = attr.data.half()\n\n    model.apply(_convert_weights_to_fp16)\n\n\ndef build_model(state_dict: dict):\n    vit = \"visual.proj\" in 
state_dict\n\n    if vit:\n        vision_width = state_dict[\"visual.conv1.weight\"].shape[0]\n        vision_layers = len([k for k in state_dict.keys() if k.startswith(\"visual.\") and k.endswith(\".attn.in_proj_weight\")])\n        vision_patch_size = state_dict[\"visual.conv1.weight\"].shape[-1]\n        grid_size = round((state_dict[\"visual.positional_embedding\"].shape[0] - 1) ** 0.5)\n        image_resolution = vision_patch_size * grid_size\n    else:\n        counts: list = [len(set(k.split(\".\")[2] for k in state_dict if k.startswith(f\"visual.layer{b}\"))) for b in [1, 2, 3, 4]]\n        vision_layers = tuple(counts)\n        vision_width = state_dict[\"visual.layer1.0.conv1.weight\"].shape[0]\n        output_width = round((state_dict[\"visual.attnpool.positional_embedding\"].shape[0] - 1) ** 0.5)\n        vision_patch_size = None\n        assert output_width ** 2 + 1 == state_dict[\"visual.attnpool.positional_embedding\"].shape[0]\n        image_resolution = output_width * 32\n\n    embed_dim = state_dict[\"text_projection\"].shape[1]\n    context_length = state_dict[\"positional_embedding\"].shape[0]\n    vocab_size = state_dict[\"token_embedding.weight\"].shape[0]\n    transformer_width = state_dict[\"ln_final.weight\"].shape[0]\n    transformer_heads = transformer_width // 64\n    transformer_layers = len(set(k.split(\".\")[2] for k in state_dict if k.startswith(f\"transformer.resblocks\")))\n\n    model = CLIP(\n        embed_dim,\n        image_resolution, vision_layers, vision_width, vision_patch_size,\n        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers\n    )\n\n    for key in [\"input_resolution\", \"context_length\", \"vocab_size\"]:\n        if key in state_dict:\n            del state_dict[key]\n\n    convert_weights(model)\n    model.load_state_dict(state_dict)\n    return model.eval()\n"
  },
  {
    "path": "clip/simple_tokenizer.py",
    "content": "import gzip\nimport html\nimport os\nfrom functools import lru_cache\n\nimport ftfy\nimport regex as re\n\n\n@lru_cache()\ndef default_bpe():\n    return os.path.join(os.path.dirname(os.path.abspath(__file__)), \"bpe_simple_vocab_16e6.txt.gz\")\n\n\n@lru_cache()\ndef bytes_to_unicode():\n    \"\"\"\n    Returns list of utf-8 byte and a corresponding list of unicode strings.\n    The reversible bpe codes work on unicode strings.\n    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.\n    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.\n    This is a significant percentage of your normal, say, 32K bpe vocab.\n    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.\n    And avoids mapping to whitespace/control characters the bpe code barfs on.\n    \"\"\"\n    bs = list(range(ord(\"!\"), ord(\"~\")+1))+list(range(ord(\"¡\"), ord(\"¬\")+1))+list(range(ord(\"®\"), ord(\"ÿ\")+1))\n    cs = bs[:]\n    n = 0\n    for b in range(2**8):\n        if b not in bs:\n            bs.append(b)\n            cs.append(2**8+n)\n            n += 1\n    cs = [chr(n) for n in cs]\n    return dict(zip(bs, cs))\n\n\ndef get_pairs(word):\n    \"\"\"Return set of symbol pairs in a word.\n    Word is represented as tuple of symbols (symbols being variable-length strings).\n    \"\"\"\n    pairs = set()\n    prev_char = word[0]\n    for char in word[1:]:\n        pairs.add((prev_char, char))\n        prev_char = char\n    return pairs\n\n\ndef basic_clean(text):\n    text = ftfy.fix_text(text)\n    text = html.unescape(html.unescape(text))\n    return text.strip()\n\n\ndef whitespace_clean(text):\n    text = re.sub(r'\\s+', ' ', text)\n    text = text.strip()\n    return text\n\n\nclass SimpleTokenizer(object):\n    def __init__(self, bpe_path: str = default_bpe()):\n        self.byte_encoder = bytes_to_unicode()\n        self.byte_decoder = {v: k for k, v in 
self.byte_encoder.items()}\n        merges = gzip.open(bpe_path).read().decode(\"utf-8\").split('\\n')\n        merges = merges[1:49152-256-2+1]\n        merges = [tuple(merge.split()) for merge in merges]\n        vocab = list(bytes_to_unicode().values())\n        vocab = vocab + [v+'</w>' for v in vocab]\n        for merge in merges:\n            vocab.append(''.join(merge))\n        vocab.extend(['<|startoftext|>', '<|endoftext|>'])\n        self.encoder = dict(zip(vocab, range(len(vocab))))\n        self.decoder = {v: k for k, v in self.encoder.items()}\n        self.bpe_ranks = dict(zip(merges, range(len(merges))))\n        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}\n        self.pat = re.compile(r\"\"\"<\\|startoftext\\|>|<\\|endoftext\\|>|'s|'t|'re|'ve|'m|'ll|'d|[\\p{L}]+|[\\p{N}]|[^\\s\\p{L}\\p{N}]+\"\"\", re.IGNORECASE)\n\n    def bpe(self, token):\n        if token in self.cache:\n            return self.cache[token]\n        word = tuple(token[:-1]) + ( token[-1] + '</w>',)\n        pairs = get_pairs(word)\n\n        if not pairs:\n            return token+'</w>'\n\n        while True:\n            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))\n            if bigram not in self.bpe_ranks:\n                break\n            first, second = bigram\n            new_word = []\n            i = 0\n            while i < len(word):\n                try:\n                    j = word.index(first, i)\n                    new_word.extend(word[i:j])\n                    i = j\n                except:\n                    new_word.extend(word[i:])\n                    break\n\n                if word[i] == first and i < len(word)-1 and word[i+1] == second:\n                    new_word.append(first+second)\n                    i += 2\n                else:\n                    new_word.append(word[i])\n                    i += 1\n            new_word = tuple(new_word)\n            word = 
new_word\n            if len(word) == 1:\n                break\n            else:\n                pairs = get_pairs(word)\n        word = ' '.join(word)\n        self.cache[token] = word\n        return word\n\n    def encode(self, text):\n        bpe_tokens = []\n        text = whitespace_clean(basic_clean(text)).lower()\n        for token in re.findall(self.pat, text):\n            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))\n            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))\n        return bpe_tokens\n\n    def decode(self, tokens):\n        text = ''.join([self.decoder[token] for token in tokens])\n        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=\"replace\").replace('</w>', ' ')\n        return text\n"
  },
  {
    "path": "datasets/test_dataset.py",
    "content": "import numpy as np\n\nimport torch\nfrom torch.utils.data import Dataset\n\nclass TestLatentsDataset(Dataset):\n    def __init__(self):\n\n        style_latents_list = []\n        clip_latents_list = []\n        wplus_latents_list = []\n        \n        #change the paths here for testing other latent codes\n        style_latents_list.append(torch.Tensor(np.load(\"./examples/sspace_img_feat.npy\")))\n        clip_latents_list.append(torch.Tensor(np.load(\"./examples/cspace_img_feat.npy\")))\n        wplus_latents_list.append(torch.Tensor(np.load(\"./examples/wplus_img_feat.npy\")))\n        \n        self.style_latents = torch.cat(style_latents_list, dim=0)\n        self.clip_latents = torch.cat(clip_latents_list, dim=0)\n        self.wplus_latents = torch.cat(wplus_latents_list, dim=0)\n        \n    def __len__(self):\n\n        return self.style_latents.shape[0]\n\n    def __getitem__(self, index):\n\n        latent_s1 = self.style_latents[index]\n        latent_c1 = self.clip_latents[index]\n        latent_w1 = self.wplus_latents[index]\n        latent_c1 = latent_c1 / latent_c1.norm(dim=-1, keepdim=True).float()\n        \n        delta_c = torch.cat([latent_c1, latent_c1], dim=0)\n        \n        return latent_s1, delta_c, latent_w1"
  },
  {
    "path": "datasets/train_dataset.py",
    "content": "import copy\nimport random\nimport numpy as np\n\nimport torch\nfrom torch.utils.data import Dataset\n\nclass TrainLatentsDataset(Dataset):\n    def __init__(self, opts, cycle=True):\n\n        style_latents_list = []\n        clip_latents_list = []\n        wplus_latents_list = []\n\n        style_latents_list.append(torch.Tensor(np.load(f\"./latent_code/{opts.classname}/sspace_noise_feat.npy\")))\n        clip_latents_list.append(torch.Tensor(np.load(f\"./latent_code/{opts.classname}/cspace_noise_feat.npy\")))\n        wplus_latents_list.append(torch.Tensor(np.load(f\"./latent_code/{opts.classname}/wspace_noise_feat.npy\")))\n        \n        style_latents_list.append(torch.Tensor(np.load(f\"./latent_code/{opts.classname}/sspace_ffhq_feat.npy\")))\n        clip_latents_list.append(torch.Tensor(np.load(f\"./latent_code/{opts.classname}/cspace_ffhq_feat.npy\")))\n        wplus_latents_list.append(torch.Tensor(np.load(f\"./latent_code/{opts.classname}/wspace_ffhq_feat.npy\")))\n        \n        self.style_latents = torch.cat(style_latents_list, dim=0)\n        self.clip_latents = torch.cat(clip_latents_list, dim=0)\n        self.wplus_latents = torch.cat(wplus_latents_list, dim=0)\n\n        self.style_latents = self.style_latents[:200000+58000]\n        self.clip_latents = self.clip_latents[:200000+58000]\n        self.wplus_latents = self.wplus_latents[:200000+58000]\n\n        self.dataset_size = self.style_latents.shape[0]\n        print(\"dataset size\", self.dataset_size)\n        self.cycle = cycle\n        \n    def __len__(self):\n        if self.cycle:\n            return self.style_latents.shape[0] * 50\n        else:\n            return self.style_latents.shape[0]\n\n    def __getitem__(self, index):\n        if self.cycle:\n            index = index % self.dataset_size\n\n        latent_s1 = self.style_latents[index]\n        latent_c1 = self.clip_latents[index]\n        latent_w1 = self.wplus_latents[index]\n        latent_c1 = 
latent_c1 / latent_c1.norm(dim=-1, keepdim=True).float()\n\n        random_index = random.randint(0, self.dataset_size - 1)\n        latent_s2 = self.style_latents[random_index]\n        latent_c2 = self.clip_latents[random_index]\n        latent_w2 = self.wplus_latents[random_index]\n        latent_c2 = latent_c2 / latent_c2.norm(dim=-1, keepdim=True).float()\n\n        delta_s1 = latent_s2 - latent_s1\n        delta_c = latent_c2 - latent_c1\n        \n        delta_c = delta_c / delta_c.norm(dim=-1, keepdim=True).float().clamp(min=1e-5)\n        delta_c = torch.cat([latent_c1, delta_c], dim=0)\n\n        return latent_s1, delta_c, delta_s1"
  },
  {
    "path": "delta_mapper.py",
    "content": "import math\n\nimport torch\nfrom torch import nn\nfrom torch.nn import Module\nimport torch.nn.functional as F\n\nfrom models.stylegan2.model import EqualLinear, PixelNorm\n\nclass Mapper(Module):\n\n    def __init__(self, in_channel=512, out_channel=512, norm=True, num_layers=4):\n        super(Mapper, self).__init__()\n\n        layers = [PixelNorm()] if norm else []\n        \n        layers.append(EqualLinear(in_channel, out_channel, lr_mul=0.01, activation='fused_lrelu'))\n        for _ in range(num_layers-1):\n            layers.append(EqualLinear(out_channel, out_channel, lr_mul=0.01, activation='fused_lrelu'))\n        self.mapping = nn.Sequential(*layers)\n\n    def forward(self, x):\n        x = self.mapping(x)\n        return x\n\nclass DeltaMapper(Module):\n\n    def __init__(self):\n        super(DeltaMapper, self).__init__()\n\n        #Style Module(sm)\n        self.sm_coarse = Mapper(512,  512)\n        self.sm_medium = Mapper(512,  512)\n        self.sm_fine   = Mapper(2464, 2464)\n\n        #Condition Module(cm)\n        self.cm_coarse = Mapper(1024, 512)\n        self.cm_medium = Mapper(1024, 512)\n        self.cm_fine   = Mapper(1024, 2464)\n\n        #Fusion Module(fm)\n        self.fm_coarse = Mapper(512*2,  512,  norm=False)\n        self.fm_medium = Mapper(512*2,  512,  norm=False)\n        self.fm_fine   = Mapper(2464*2, 2464, norm=False)\n        \n    def forward(self, sspace_feat, clip_feat):\n\n        s_coarse = sspace_feat[:, :3*512].view(-1,3,512)\n        s_medium = sspace_feat[:, 3*512:7*512].view(-1,4,512)\n        s_fine   = sspace_feat[:, 7*512:] #channels:2464\n\n        s_coarse = self.sm_coarse(s_coarse)\n        s_medium = self.sm_medium(s_medium)\n        s_fine   = self.sm_fine(s_fine)\n\n        c_coarse = self.cm_coarse(clip_feat)\n        c_medium = self.cm_medium(clip_feat)\n        c_fine   = self.cm_fine(clip_feat)\n\n        x_coarse = torch.cat([s_coarse, torch.stack([c_coarse]*3, dim=1)], dim=2) 
#[b,3,1024]\n        x_medium = torch.cat([s_medium, torch.stack([c_medium]*4, dim=1)], dim=2) #[b,4,1024]\n        x_fine   = torch.cat([s_fine, c_fine], dim=1) #[b,2464*2]\n\n        x_coarse = self.fm_coarse(x_coarse)\n        x_coarse = x_coarse.view(-1,3*512)\n\n        x_medium = self.fm_medium(x_medium)\n        x_medium = x_medium.view(-1,4*512)\n\n        x_fine   = self.fm_fine(x_fine)\n\n        out = torch.cat([x_coarse, x_medium, x_fine], dim=1)\n        return out"
  },
  {
    "path": "generate_codes.py",
    "content": "import os\nimport argparse\nimport clip\n\nimport random\nimport numpy as np\nimport torch\nfrom torchvision import utils\nfrom utils import stylespace_util\nfrom models.stylegan2.model import Generator\n\ndef save_image_pytorch(img, name):\n    \"\"\"Helper function to save torch tensor into an image file.\"\"\"\n    utils.save_image(\n        img,\n        name,\n        nrow=1,\n        padding=0,\n        normalize=True,\n        range=(-1, 1),\n    )\n\n\ndef generate(args, netG, device, mean_latent):\n\n    # Use the device passed in by the caller (args.device) so CLIP, the sampled noise, and netG share one device\n    model, preprocess = clip.load(\"ViT-B/32\", device=device)\n    avg_pool = torch.nn.AvgPool2d(kernel_size=1024 // 32)\n    upsample = torch.nn.Upsample(scale_factor=7)\n\n    ind = 0\n    with torch.no_grad():\n        netG.eval()\n\n        # Generate images from a file of input noises\n        if args.fixed_z is not None:\n            sample_z = torch.load(args.fixed_z, map_location=device)\n            for start in range(0, sample_z.size(0), args.batch_size):\n                end = min(start + args.batch_size, sample_z.size(0))\n                z_batch = sample_z[start:end]\n                sample, _ = netG([z_batch], truncation=args.truncation, truncation_latent=mean_latent)\n                for s in sample:\n                    save_image_pytorch(s, f'{args.save_dir}/{str(ind).zfill(6)}.png')\n                    ind += 1\n            return\n\n        # Generate image by sampling input noises\n        w_latents_list = []\n        s_latents_list = []\n        c_latents_list = []\n        for start in range(0, args.samples, args.batch_size):\n            end = min(start + args.batch_size, args.samples)\n            batch_sz = end - start\n            print(f'current_num:{start}')\n            sample_z = torch.randn(batch_sz, 512, device=device)\n\n            sample, w_latents = netG([sample_z], truncation=args.truncation, 
truncation_latent=mean_latent,return_latents=True)\n            style_space, noise = stylespace_util.encoder_latent(netG, w_latents)\n            s_latents = torch.cat(style_space, dim=1)\n\n            tmp_imgs = stylespace_util.decoder(netG, style_space, w_latents, noise)\n            # for s in tmp_imgs:\n            #     save_image_pytorch(s, f'{args.save_dir}/{str(ind).zfill(6)}.png')\n            #     ind += 1\n\n            img_gen_for_clip = upsample(tmp_imgs)\n            img_gen_for_clip = avg_pool(img_gen_for_clip)\n            c_latents = model.encode_image(img_gen_for_clip)\n\n            w_latents_list.append(w_latents)\n            s_latents_list.append(s_latents)\n            c_latents_list.append(c_latents)\n        w_all_latents = torch.cat(w_latents_list, dim=0)\n        s_all_latents = torch.cat(s_latents_list, dim=0)\n        c_all_latents = torch.cat(c_latents_list, dim=0)\n\n        print(w_all_latents.size())\n        print(s_all_latents.size())\n        print(c_all_latents.size())\n\n        w_all_latents = w_all_latents.cpu().numpy()\n        s_all_latents = s_all_latents.cpu().numpy()\n        c_all_latents = c_all_latents.cpu().numpy()\n\n        os.makedirs(os.path.join(args.save_dir, args.classname), exist_ok=True)\n        np.save(f\"{args.save_dir}/{args.classname}/wspace_noise_feat.npy\", w_all_latents)\n        np.save(f\"{args.save_dir}/{args.classname}/sspace_noise_feat.npy\", s_all_latents)\n        np.save(f\"{args.save_dir}/{args.classname}/cspace_noise_feat.npy\", c_all_latents)\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument('--classname', type=str, default='ffhq', help=\"domain name (e.g. ffhq); selects the generator checkpoint and names the output subdirectory\")\n    parser.add_argument('--save_dir', type=str, default='./latent_code', help=\"place to save the output\")\n    parser.add_argument('--ckpt', type=str, default='./models/pretrained_models', help=\"directory containing the pre-trained generator checkpoints\")\n    parser.add_argument('--size', type=int, 
default=1024, help=\"output size of the generator\")\n    parser.add_argument('--fixed_z', type=str, default=None, help=\"expect a .pth file. If given, will use this file as the input noise for the output\")\n    parser.add_argument('--w_shift', type=str, default=None, help=\"expect a .pth file. Apply a w-latent shift to the generator\")\n    parser.add_argument('--batch_size', type=int, default=10, help=\"batch size used to generate outputs\")\n    parser.add_argument('--samples', type=int, default=200000, help=\"number of samples to generate; ignored if --fixed_z is given\")\n    parser.add_argument('--truncation', type=float, default=1, help=\"strength of truncation:0.5ori\")\n    parser.add_argument('--truncation_mean', type=int, default=4096, help=\"number of samples to calculate the mean latent for truncation\")\n    parser.add_argument('--seed', type=int, default=None, help=\"if specified, use a fixed random seed\")\n    parser.add_argument('--device', type=str, default='cuda')\n\n    args = parser.parse_args()\n\n    device = args.device\n    # use a fixed seed if given\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        torch.cuda.manual_seed_all(args.seed)\n\n    if not os.path.exists(args.save_dir):\n        os.makedirs(args.save_dir)\n\n    netG = Generator(args.size, 512, 8).to(device)\n    if args.classname == 'ffhq':\n        ckpt_path = os.path.join(args.ckpt,f'stylegan2-{args.classname}-config-f.pt')\n    else:\n        ckpt_path = os.path.join(args.ckpt,f'stylegan2-{args.classname}','netG.pth')\n    print(ckpt_path)\n    checkpoint = torch.load(ckpt_path, map_location='cpu')\n\n    if args.classname == 'ffhq':\n        netG.load_state_dict(checkpoint['g_ema'])\n    else:\n        netG.load_state_dict(checkpoint)\n\n    # get mean latent if truncation is applied\n    if args.truncation < 1:\n        with torch.no_grad():\n            mean_latent = 
netG.mean_latent(args.truncation_mean)\n    else:\n        mean_latent = None\n\n    generate(args, netG, device, mean_latent)\n"
  },
  {
    "path": "models/encoders/__init__.py",
    "content": ""
  },
  {
    "path": "models/encoders/helpers.py",
    "content": "from collections import namedtuple\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module\n\n\"\"\"\nArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)\n\"\"\"\n\n\nclass Flatten(Module):\n    def forward(self, input):\n        return input.view(input.size(0), -1)\n\n\ndef l2_norm(input, axis=1):\n    norm = torch.norm(input, 2, axis, True)\n    output = torch.div(input, norm)\n    return output\n\n\nclass Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):\n    \"\"\" A named tuple describing a ResNet block. \"\"\"\n\n\ndef get_block(in_channel, depth, num_units, stride=2):\n    return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]\n\n\ndef get_blocks(num_layers):\n    if num_layers == 50:\n        blocks = [\n            get_block(in_channel=64, depth=64, num_units=3),\n            get_block(in_channel=64, depth=128, num_units=4),\n            get_block(in_channel=128, depth=256, num_units=14),\n            get_block(in_channel=256, depth=512, num_units=3)\n        ]\n    elif num_layers == 100:\n        blocks = [\n            get_block(in_channel=64, depth=64, num_units=3),\n            get_block(in_channel=64, depth=128, num_units=13),\n            get_block(in_channel=128, depth=256, num_units=30),\n            get_block(in_channel=256, depth=512, num_units=3)\n        ]\n    elif num_layers == 152:\n        blocks = [\n            get_block(in_channel=64, depth=64, num_units=3),\n            get_block(in_channel=64, depth=128, num_units=8),\n            get_block(in_channel=128, depth=256, num_units=36),\n            get_block(in_channel=256, depth=512, num_units=3)\n        ]\n    else:\n        raise ValueError(\"Invalid number of layers: {}. 
Must be one of [50, 100, 152]\".format(num_layers))\n    return blocks\n\n\nclass SEModule(Module):\n    def __init__(self, channels, reduction):\n        super(SEModule, self).__init__()\n        self.avg_pool = AdaptiveAvgPool2d(1)\n        self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)\n        self.relu = ReLU(inplace=True)\n        self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)\n        self.sigmoid = Sigmoid()\n\n    def forward(self, x):\n        module_input = x\n        x = self.avg_pool(x)\n        x = self.fc1(x)\n        x = self.relu(x)\n        x = self.fc2(x)\n        x = self.sigmoid(x)\n        return module_input * x\n\n\nclass bottleneck_IR(Module):\n    def __init__(self, in_channel, depth, stride):\n        super(bottleneck_IR, self).__init__()\n        if in_channel == depth:\n            self.shortcut_layer = MaxPool2d(1, stride)\n        else:\n            self.shortcut_layer = Sequential(\n                Conv2d(in_channel, depth, (1, 1), stride, bias=False),\n                BatchNorm2d(depth)\n            )\n        self.res_layer = Sequential(\n            BatchNorm2d(in_channel),\n            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),\n            Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)\n        )\n\n    def forward(self, x):\n        shortcut = self.shortcut_layer(x)\n        res = self.res_layer(x)\n        return res + shortcut\n\n\nclass bottleneck_IR_SE(Module):\n    def __init__(self, in_channel, depth, stride):\n        super(bottleneck_IR_SE, self).__init__()\n        if in_channel == depth:\n            self.shortcut_layer = MaxPool2d(1, stride)\n        else:\n            self.shortcut_layer = Sequential(\n                Conv2d(in_channel, depth, (1, 1), stride, bias=False),\n                BatchNorm2d(depth)\n            )\n        self.res_layer = Sequential(\n            
BatchNorm2d(in_channel),\n            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),\n            PReLU(depth),\n            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),\n            BatchNorm2d(depth),\n            SEModule(depth, 16)\n        )\n\n    def forward(self, x):\n        shortcut = self.shortcut_layer(x)\n        res = self.res_layer(x)\n        return res + shortcut\n\n\ndef _upsample_add(x, y):\n    \"\"\"Upsample and add two feature maps.\n    Args:\n      x: (Variable) top feature map to be upsampled.\n      y: (Variable) lateral feature map.\n    Returns:\n      (Variable) added feature map.\n    Note: in PyTorch, when the input size is odd, the feature map upsampled\n    with `F.upsample(..., scale_factor=2, mode='nearest')`\n    may not match the lateral feature map size.\n    e.g.\n    original input size: [N,_,15,15] ->\n    conv2d feature map size: [N,_,8,8] ->\n    upsampled feature map size: [N,_,16,16]\n    So we use bilinear upsampling to the lateral map's exact size, which supports arbitrary output sizes.\n    \"\"\"\n    _, _, H, W = y.size()\n    return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y\n"
  },
  {
    "path": "models/encoders/model_irse.py",
    "content": "from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module\nfrom models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm\n\n\"\"\"\nModified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)\n\"\"\"\n\n\nclass Backbone(Module):\n    def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):\n        super(Backbone, self).__init__()\n        assert input_size in [112, 224], \"input_size should be 112 or 224\"\n        assert num_layers in [50, 100, 152], \"num_layers should be 50, 100 or 152\"\n        assert mode in ['ir', 'ir_se'], \"mode should be ir or ir_se\"\n        blocks = get_blocks(num_layers)\n        if mode == 'ir':\n            unit_module = bottleneck_IR\n        elif mode == 'ir_se':\n            unit_module = bottleneck_IR_SE\n        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),\n                                      BatchNorm2d(64),\n                                      PReLU(64))\n        if input_size == 112:\n            self.output_layer = Sequential(BatchNorm2d(512),\n                                           Dropout(drop_ratio),\n                                           Flatten(),\n                                           Linear(512 * 7 * 7, 512),\n                                           BatchNorm1d(512, affine=affine))\n        else:\n            self.output_layer = Sequential(BatchNorm2d(512),\n                                           Dropout(drop_ratio),\n                                           Flatten(),\n                                           Linear(512 * 14 * 14, 512),\n                                           BatchNorm1d(512, affine=affine))\n\n        modules = []\n        for block in blocks:\n            for bottleneck in block:\n                modules.append(unit_module(bottleneck.in_channel,\n                                         
  bottleneck.depth,\n                                           bottleneck.stride))\n        self.body = Sequential(*modules)\n\n    def forward(self, x):\n        x = self.input_layer(x)\n        x = self.body(x)\n        x = self.output_layer(x)\n        return l2_norm(x)\n\n\ndef IR_50(input_size):\n    \"\"\"Constructs a ir-50 model.\"\"\"\n    model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)\n    return model\n\n\ndef IR_101(input_size):\n    \"\"\"Constructs a ir-101 model.\"\"\"\n    model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)\n    return model\n\n\ndef IR_152(input_size):\n    \"\"\"Constructs a ir-152 model.\"\"\"\n    model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)\n    return model\n\n\ndef IR_SE_50(input_size):\n    \"\"\"Constructs a ir_se-50 model.\"\"\"\n    model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)\n    return model\n\n\ndef IR_SE_101(input_size):\n    \"\"\"Constructs a ir_se-101 model.\"\"\"\n    model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)\n    return model\n\n\ndef IR_SE_152(input_size):\n    \"\"\"Constructs a ir_se-152 model.\"\"\"\n    model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)\n    return model\n"
  },
  {
    "path": "models/encoders/psp_encoders.py",
    "content": "from enum import Enum\nimport math\nimport numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module\n\nfrom models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add\nfrom models.stylegan2.model import EqualLinear\n\n\nclass ProgressiveStage(Enum):\n    WTraining = 0\n    Delta1Training = 1\n    Delta2Training = 2\n    Delta3Training = 3\n    Delta4Training = 4\n    Delta5Training = 5\n    Delta6Training = 6\n    Delta7Training = 7\n    Delta8Training = 8\n    Delta9Training = 9\n    Delta10Training = 10\n    Delta11Training = 11\n    Delta12Training = 12\n    Delta13Training = 13\n    Delta14Training = 14\n    Delta15Training = 15\n    Delta16Training = 16\n    Delta17Training = 17\n    Inference = 18\n\n\nclass GradualStyleBlock(Module):\n    def __init__(self, in_c, out_c, spatial):\n        super(GradualStyleBlock, self).__init__()\n        self.out_c = out_c\n        self.spatial = spatial\n        num_pools = int(np.log2(spatial))\n        modules = []\n        modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),\n                    nn.LeakyReLU()]\n        for i in range(num_pools - 1):\n            modules += [\n                Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),\n                nn.LeakyReLU()\n            ]\n        self.convs = nn.Sequential(*modules)\n        self.linear = EqualLinear(out_c, out_c, lr_mul=1)\n\n    def forward(self, x):\n        x = self.convs(x)\n        x = x.view(-1, self.out_c)\n        x = self.linear(x)\n        return x\n\n\nclass GradualStyleEncoder(Module):\n    def __init__(self, num_layers, mode='ir', opts=None):\n        super(GradualStyleEncoder, self).__init__()\n        assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'\n        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'\n        blocks = get_blocks(num_layers)\n        if mode == 
'ir':\n            unit_module = bottleneck_IR\n        elif mode == 'ir_se':\n            unit_module = bottleneck_IR_SE\n        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),\n                                      BatchNorm2d(64),\n                                      PReLU(64))\n        modules = []\n        for block in blocks:\n            for bottleneck in block:\n                modules.append(unit_module(bottleneck.in_channel,\n                                           bottleneck.depth,\n                                           bottleneck.stride))\n        self.body = Sequential(*modules)\n\n        self.styles = nn.ModuleList()\n        log_size = int(math.log(opts.stylegan_size, 2))\n        self.style_count = 2 * log_size - 2\n        self.coarse_ind = 3\n        self.middle_ind = 7\n        for i in range(self.style_count):\n            if i < self.coarse_ind:\n                style = GradualStyleBlock(512, 512, 16)\n            elif i < self.middle_ind:\n                style = GradualStyleBlock(512, 512, 32)\n            else:\n                style = GradualStyleBlock(512, 512, 64)\n            self.styles.append(style)\n        self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)\n        self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)\n\n    def forward(self, x):\n        x = self.input_layer(x)\n\n        latents = []\n        modulelist = list(self.body._modules.values())\n        for i, l in enumerate(modulelist):\n            x = l(x)\n            if i == 6:\n                c1 = x\n            elif i == 20:\n                c2 = x\n            elif i == 23:\n                c3 = x\n\n        for j in range(self.coarse_ind):\n            latents.append(self.styles[j](c3))\n\n        p2 = _upsample_add(c3, self.latlayer1(c2))\n        for j in range(self.coarse_ind, self.middle_ind):\n            latents.append(self.styles[j](p2))\n\n        p1 = _upsample_add(p2, 
self.latlayer2(c1))\n        for j in range(self.middle_ind, self.style_count):\n            latents.append(self.styles[j](p1))\n\n        out = torch.stack(latents, dim=1)\n        return out\n\n\nclass Encoder4Editing(Module):\n    def __init__(self, num_layers, stylegan_size, mode='ir'):\n        super(Encoder4Editing, self).__init__()\n        assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'\n        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'\n        blocks = get_blocks(num_layers)\n        if mode == 'ir':\n            unit_module = bottleneck_IR\n        elif mode == 'ir_se':\n            unit_module = bottleneck_IR_SE\n        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),\n                                      BatchNorm2d(64),\n                                      PReLU(64))\n        modules = []\n        for block in blocks:\n            for bottleneck in block:\n                modules.append(unit_module(bottleneck.in_channel,\n                                           bottleneck.depth,\n                                           bottleneck.stride))\n        self.body = Sequential(*modules)\n\n        self.styles = nn.ModuleList()\n        log_size = int(math.log(stylegan_size, 2))\n        self.style_count = 2 * log_size - 2\n        self.coarse_ind = 3\n        self.middle_ind = 7\n\n        for i in range(self.style_count):\n            if i < self.coarse_ind:\n                style = GradualStyleBlock(512, 512, 16)\n            elif i < self.middle_ind:\n                style = GradualStyleBlock(512, 512, 32)\n            else:\n                style = GradualStyleBlock(512, 512, 64)\n            self.styles.append(style)\n\n        self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)\n        self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)\n\n        self.progressive_stage = ProgressiveStage.Inference\n\n    def 
get_deltas_starting_dimensions(self):\n        ''' Get a list of the initial dimension of every delta from which it is applied '''\n        return list(range(self.style_count))  # Each dimension has a delta applied to it\n\n    def set_progressive_stage(self, new_stage: ProgressiveStage):\n        self.progressive_stage = new_stage\n        print('Changed progressive stage to: ', new_stage)\n\n    def forward(self, x):\n        x = self.input_layer(x)\n\n        modulelist = list(self.body._modules.values())\n        for i, l in enumerate(modulelist):\n            x = l(x)\n            if i == 6:\n                c1 = x\n            elif i == 20:\n                c2 = x\n            elif i == 23:\n                c3 = x\n\n        # Infer main W and duplicate it\n        w0 = self.styles[0](c3)\n        w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2)\n        stage = self.progressive_stage.value\n        features = c3\n        for i in range(1, min(stage + 1, self.style_count)):  # Infer additional deltas\n            if i == self.coarse_ind:\n                p2 = _upsample_add(c3, self.latlayer1(c2))  # FPN's middle features\n                features = p2\n            elif i == self.middle_ind:\n                p1 = _upsample_add(p2, self.latlayer2(c1))  # FPN's fine features\n                features = p1\n            delta_i = self.styles[i](features)\n            w[:, i] += delta_i\n        return w\n\n\nclass BackboneEncoderUsingLastLayerIntoW(Module):\n    def __init__(self, num_layers, mode='ir', opts=None):\n        super(BackboneEncoderUsingLastLayerIntoW, self).__init__()\n        print('Using BackboneEncoderUsingLastLayerIntoW')\n        assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'\n        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'\n        blocks = get_blocks(num_layers)\n        if mode == 'ir':\n            unit_module = bottleneck_IR\n        elif mode == 'ir_se':\n            unit_module = 
bottleneck_IR_SE\n        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),\n                                      BatchNorm2d(64),\n                                      PReLU(64))\n        self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1))\n        self.linear = EqualLinear(512, 512, lr_mul=1)\n        modules = []\n        for block in blocks:\n            for bottleneck in block:\n                modules.append(unit_module(bottleneck.in_channel,\n                                           bottleneck.depth,\n                                           bottleneck.stride))\n        self.body = Sequential(*modules)\n        log_size = int(math.log(opts.stylegan_size, 2))\n        self.style_count = 2 * log_size - 2\n\n    def forward(self, x):\n        x = self.input_layer(x)\n        x = self.body(x)\n        x = self.output_pool(x)\n        x = x.view(-1, 512)\n        x = self.linear(x)\n        return x.repeat(self.style_count, 1, 1).permute(1, 0, 2)\n"
  },
  {
    "path": "models/stylegan2/__init__.py",
    "content": ""
  },
  {
    "path": "models/stylegan2/model.py",
    "content": "import math\nimport random\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d\n\nclass PixelNorm(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, input):\n        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)\n\n\ndef make_kernel(k):\n    k = torch.tensor(k, dtype=torch.float32)\n\n    if k.ndim == 1:\n        k = k[None, :] * k[:, None]\n\n    k /= k.sum()\n\n    return k\n\n\nclass Upsample(nn.Module):\n    def __init__(self, kernel, factor=2):\n        super().__init__()\n\n        self.factor = factor\n        kernel = make_kernel(kernel) * (factor ** 2)\n        self.register_buffer('kernel', kernel)\n\n        p = kernel.shape[0] - factor\n\n        pad0 = (p + 1) // 2 + factor - 1\n        pad1 = p // 2\n\n        self.pad = (pad0, pad1)\n\n    def forward(self, input):\n        out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)\n\n        return out\n\n\nclass Downsample(nn.Module):\n    def __init__(self, kernel, factor=2):\n        super().__init__()\n\n        self.factor = factor\n        kernel = make_kernel(kernel)\n        self.register_buffer('kernel', kernel)\n\n        p = kernel.shape[0] - factor\n\n        pad0 = (p + 1) // 2\n        pad1 = p // 2\n\n        self.pad = (pad0, pad1)\n\n    def forward(self, input):\n        out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)\n\n        return out\n\n\nclass Blur(nn.Module):\n    def __init__(self, kernel, pad, upsample_factor=1):\n        super().__init__()\n\n        kernel = make_kernel(kernel)\n\n        if upsample_factor > 1:\n            kernel = kernel * (upsample_factor ** 2)\n\n        self.register_buffer('kernel', kernel)\n\n        self.pad = pad\n\n    def forward(self, input):\n        out = upfirdn2d(input, self.kernel, pad=self.pad)\n\n        
return out\n\n\nclass EqualConv2d(nn.Module):\n    def __init__(\n        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True\n    ):\n        super().__init__()\n\n        self.weight = nn.Parameter(\n            torch.randn(out_channel, in_channel, kernel_size, kernel_size)\n        )\n        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)\n\n        self.stride = stride\n        self.padding = padding\n\n        if bias:\n            self.bias = nn.Parameter(torch.zeros(out_channel))\n\n        else:\n            self.bias = None\n\n    def forward(self, input):\n        out = F.conv2d(\n            input,\n            self.weight * self.scale,\n            bias=self.bias,\n            stride=self.stride,\n            padding=self.padding,\n        )\n\n        return out\n\n    def __repr__(self):\n        return (\n            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'\n            f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'\n        )\n\n\nclass EqualLinear(nn.Module):\n    def __init__(\n        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None\n    ):\n        super().__init__()\n\n        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))\n\n        if bias:\n            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))\n\n        else:\n            self.bias = None\n\n        self.activation = activation\n\n        self.scale = (1 / math.sqrt(in_dim)) * lr_mul\n        self.lr_mul = lr_mul\n\n    def forward(self, input):\n        if self.activation:\n            out = F.linear(input, self.weight * self.scale)\n            out = fused_leaky_relu(out, self.bias * self.lr_mul)\n\n        else:\n            out = F.linear(\n                input, self.weight * self.scale, bias=self.bias * self.lr_mul\n            )\n\n        return out\n\n    def __repr__(self):\n        return (\n            
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'\n        )\n\n\nclass ScaledLeakyReLU(nn.Module):\n    def __init__(self, negative_slope=0.2):\n        super().__init__()\n\n        self.negative_slope = negative_slope\n\n    def forward(self, input):\n        out = F.leaky_relu(input, negative_slope=self.negative_slope)\n\n        return out * math.sqrt(2)\n\n\nclass ModulatedConv2d(nn.Module):\n    def __init__(\n        self,\n        in_channel,\n        out_channel,\n        kernel_size,\n        style_dim,\n        demodulate=True,\n        upsample=False,\n        downsample=False,\n        blur_kernel=[1, 3, 3, 1],\n    ):\n        super().__init__()\n\n        self.eps = 1e-8\n        self.kernel_size = kernel_size\n        self.in_channel = in_channel\n        self.out_channel = out_channel\n        self.upsample = upsample\n        self.downsample = downsample\n\n        if upsample:\n            factor = 2\n            p = (len(blur_kernel) - factor) - (kernel_size - 1)\n            pad0 = (p + 1) // 2 + factor - 1\n            pad1 = p // 2 + 1\n\n            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)\n\n        if downsample:\n            factor = 2\n            p = (len(blur_kernel) - factor) + (kernel_size - 1)\n            pad0 = (p + 1) // 2\n            pad1 = p // 2\n\n            self.blur = Blur(blur_kernel, pad=(pad0, pad1))\n\n        fan_in = in_channel * kernel_size ** 2\n        self.scale = 1 / math.sqrt(fan_in)\n        self.padding = kernel_size // 2\n\n        self.weight = nn.Parameter(\n            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)\n        )\n\n        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)\n\n        self.demodulate = demodulate\n\n    def __repr__(self):\n        return (\n            f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '\n            f'upsample={self.upsample}, 
downsample={self.downsample})'\n        )\n\n    def forward(self, input, style):\n        batch, in_channel, height, width = input.shape\n\n        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)\n        weight = self.scale * self.weight * style\n\n        if self.demodulate:\n            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)\n            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)\n\n        weight = weight.view(\n            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size\n        )\n\n        if self.upsample:\n            input = input.view(1, batch * in_channel, height, width)\n            weight = weight.view(\n                batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size\n            )\n            weight = weight.transpose(1, 2).reshape(\n                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size\n            )\n            out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)\n            _, _, height, width = out.shape\n            out = out.view(batch, self.out_channel, height, width)\n            out = self.blur(out)\n\n        elif self.downsample:\n            input = self.blur(input)\n            _, _, height, width = input.shape\n            input = input.view(1, batch * in_channel, height, width)\n            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)\n            _, _, height, width = out.shape\n            out = out.view(batch, self.out_channel, height, width)\n\n        else:\n            input = input.view(1, batch * in_channel, height, width)\n            out = F.conv2d(input, weight, padding=self.padding, groups=batch)\n            _, _, height, width = out.shape\n            out = out.view(batch, self.out_channel, height, width)\n\n        return out\n\n\nclass NoiseInjection(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n        self.weight = 
nn.Parameter(torch.zeros(1))\n\n    def forward(self, image, noise=None):\n        if noise is None:\n            batch, _, height, width = image.shape\n            noise = image.new_empty(batch, 1, height, width).normal_()\n\n        return image + self.weight * noise\n\n\nclass ConstantInput(nn.Module):\n    def __init__(self, channel, size=4):\n        super().__init__()\n\n        self.input = nn.Parameter(torch.randn(1, channel, size, size))\n\n    def forward(self, input):\n        batch = input.shape[0]\n        out = self.input.repeat(batch, 1, 1, 1)\n\n        return out\n\n\nclass StyledConv(nn.Module):\n    def __init__(\n        self,\n        in_channel,\n        out_channel,\n        kernel_size,\n        style_dim,\n        upsample=False,\n        blur_kernel=[1, 3, 3, 1],\n        demodulate=True,\n    ):\n        super().__init__()\n\n        self.conv = ModulatedConv2d(\n            in_channel,\n            out_channel,\n            kernel_size,\n            style_dim,\n            upsample=upsample,\n            blur_kernel=blur_kernel,\n            demodulate=demodulate,\n        )\n\n        self.noise = NoiseInjection()\n        # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))\n        # self.activate = ScaledLeakyReLU(0.2)\n        self.activate = FusedLeakyReLU(out_channel)\n\n    def forward(self, input, style, noise=None):\n        out = self.conv(input, style)\n        out = self.noise(out, noise=noise)\n        # out = out + self.bias\n        out = self.activate(out)\n\n        return out\n\n\nclass ToRGB(nn.Module):\n    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):\n        super().__init__()\n\n        if upsample:\n            self.upsample = Upsample(blur_kernel)\n\n        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)\n        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))\n\n    def forward(self, input, style, skip=None):\n        out = 
self.conv(input, style)\n        out = out + self.bias\n\n        if skip is not None:\n            skip = self.upsample(skip)\n\n            out = out + skip\n\n        return out\n\n\nclass Generator(nn.Module):\n    def __init__(\n        self,\n        size,\n        style_dim,\n        n_mlp,\n        channel_multiplier=2,\n        blur_kernel=[1, 3, 3, 1],\n        lr_mlp=0.01,\n    ):\n        super().__init__()\n\n        self.size = size\n\n        self.style_dim = style_dim\n\n        layers = [PixelNorm()]\n\n        for i in range(n_mlp):\n            layers.append(\n                EqualLinear(\n                    style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'\n                )\n            )\n\n        self.style = nn.Sequential(*layers)\n\n        self.channels = {\n            4: 512,\n            8: 512,\n            16: 512,\n            32: 512,\n            64: 256 * channel_multiplier,\n            128: 128 * channel_multiplier,\n            256: 64 * channel_multiplier,\n            512: 32 * channel_multiplier,\n            1024: 16 * channel_multiplier,\n        }\n\n        self.input = ConstantInput(self.channels[4])\n        self.conv1 = StyledConv(\n            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel\n        )\n        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)\n\n        self.log_size = int(math.log(size, 2))\n        self.num_layers = (self.log_size - 2) * 2 + 1\n\n        self.convs = nn.ModuleList()\n        self.upsamples = nn.ModuleList()\n        self.to_rgbs = nn.ModuleList()\n        self.noises = nn.Module()\n\n        in_channel = self.channels[4]\n\n        for layer_idx in range(self.num_layers):\n            res = (layer_idx + 5) // 2\n            shape = [1, 1, 2 ** res, 2 ** res]\n            self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))\n\n        for i in range(3, self.log_size + 1):\n            out_channel = 
self.channels[2 ** i]\n\n            self.convs.append(\n                StyledConv(\n                    in_channel,\n                    out_channel,\n                    3,\n                    style_dim,\n                    upsample=True,\n                    blur_kernel=blur_kernel,\n                )\n            )\n\n            self.convs.append(\n                StyledConv(\n                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel\n                )\n            )\n\n            self.to_rgbs.append(ToRGB(out_channel, style_dim))\n\n            in_channel = out_channel\n\n        self.n_latent = self.log_size * 2 - 2\n\n    def make_noise(self):\n        device = self.input.input.device\n\n        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]\n\n        for i in range(3, self.log_size + 1):\n            for _ in range(2):\n                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))\n\n        return noises\n\n    def mean_latent(self, n_latent):\n        latent_in = torch.randn(\n            n_latent, self.style_dim, device=self.input.input.device\n        )\n        latent = self.style(latent_in).mean(0, keepdim=True)\n\n        return latent\n\n    def get_latent(self, input):\n        return self.style(input)\n\n    def forward(\n        self,\n        styles,\n        return_latents=False,\n        inject_index=None,\n        truncation=1,\n        truncation_latent=None,\n        input_is_latent=False,\n        noise=None,\n        randomize_noise=True,\n    ):\n        if not input_is_latent:\n            styles = [self.style(s) for s in styles]\n\n        if noise is None:\n            if randomize_noise:\n                noise = [None] * self.num_layers\n            else:\n                noise = [\n                    getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)\n                ]\n\n        if truncation < 1:\n            style_t = []\n\n            for style in 
styles:\n                style_t.append(\n                    truncation_latent + truncation * (style - truncation_latent)\n                )\n\n            styles = style_t\n\n        if len(styles) < 2:\n            inject_index = self.n_latent\n\n            if styles[0].ndim < 3:\n                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)\n\n            else:\n                latent = styles[0]\n\n        else:\n            if inject_index is None:\n                inject_index = random.randint(1, self.n_latent - 1)\n\n            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)\n            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)\n\n            latent = torch.cat([latent, latent2], 1)\n\n        out = self.input(latent)\n        out = self.conv1(out, latent[:, 0], noise=noise[0])\n\n        skip = self.to_rgb1(out, latent[:, 1])\n\n        i = 1\n        for conv1, conv2, noise1, noise2, to_rgb in zip(\n            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs\n        ):\n            out = conv1(out, latent[:, i], noise=noise1)\n            out = conv2(out, latent[:, i + 1], noise=noise2)\n            skip = to_rgb(out, latent[:, i + 2], skip)\n\n            i += 2\n\n        image = skip\n\n        if return_latents:\n            return image, latent\n\n        else:\n            return image, None\n\n\nclass ConvLayer(nn.Sequential):\n    def __init__(\n        self,\n        in_channel,\n        out_channel,\n        kernel_size,\n        downsample=False,\n        blur_kernel=[1, 3, 3, 1],\n        bias=True,\n        activate=True,\n    ):\n        layers = []\n\n        if downsample:\n            factor = 2\n            p = (len(blur_kernel) - factor) + (kernel_size - 1)\n            pad0 = (p + 1) // 2\n            pad1 = p // 2\n\n            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))\n\n            stride = 2\n            self.padding = 0\n\n        
else:\n            stride = 1\n            self.padding = kernel_size // 2\n\n        layers.append(\n            EqualConv2d(\n                in_channel,\n                out_channel,\n                kernel_size,\n                padding=self.padding,\n                stride=stride,\n                bias=bias and not activate,\n            )\n        )\n\n        if activate:\n            if bias:\n                layers.append(FusedLeakyReLU(out_channel))\n\n            else:\n                layers.append(ScaledLeakyReLU(0.2))\n\n        super().__init__(*layers)\n\n\nclass ResBlock(nn.Module):\n    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):\n        super().__init__()\n\n        self.conv1 = ConvLayer(in_channel, in_channel, 3)\n        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)\n\n        self.skip = ConvLayer(\n            in_channel, out_channel, 1, downsample=True, activate=False, bias=False\n        )\n\n    def forward(self, input):\n        out = self.conv1(input)\n        out = self.conv2(out)\n\n        skip = self.skip(input)\n        out = (out + skip) / math.sqrt(2)\n\n        return out\n\n\nclass Discriminator(nn.Module):\n    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):\n        super().__init__()\n\n        channels = {\n            4: 512,\n            8: 512,\n            16: 512,\n            32: 512,\n            64: 256 * channel_multiplier,\n            128: 128 * channel_multiplier,\n            256: 64 * channel_multiplier,\n            512: 32 * channel_multiplier,\n            1024: 16 * channel_multiplier,\n        }\n\n        convs = [ConvLayer(3, channels[size], 1)]\n\n        log_size = int(math.log(size, 2))\n\n        in_channel = channels[size]\n\n        for i in range(log_size, 2, -1):\n            out_channel = channels[2 ** (i - 1)]\n\n            convs.append(ResBlock(in_channel, out_channel, blur_kernel))\n\n            in_channel = 
out_channel\n\n        self.convs = nn.Sequential(*convs)\n\n        self.stddev_group = 4\n        self.stddev_feat = 1\n\n        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)\n        self.final_linear = nn.Sequential(\n            EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),\n            EqualLinear(channels[4], 1),\n        )\n\n    def forward(self, input):\n        out = self.convs(input)\n\n        batch, channel, height, width = out.shape\n        group = min(batch, self.stddev_group)\n        stddev = out.view(\n            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width\n        )\n        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)\n        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)\n        stddev = stddev.repeat(group, 1, height, width)\n        out = torch.cat([out, stddev], 1)\n\n        out = self.final_conv(out)\n\n        out = out.view(batch, -1)\n        out = self.final_linear(out)\n\n        return out\n\n"
  },
  {
    "path": "models/stylegan2/op/__init__.py",
    "content": "from .fused_act import FusedLeakyReLU, fused_leaky_relu\nfrom .upfirdn2d import upfirdn2d\n"
  },
  {
    "path": "models/stylegan2/op/fused_act.py",
    "content": "import os\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nmodule_path = os.path.dirname(__file__)\n\nclass FusedLeakyReLU(nn.Module):\n    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):\n        super().__init__()\n\n        self.bias = nn.Parameter(torch.zeros(channel))\n        self.negative_slope = negative_slope\n        self.scale = scale\n\n    def forward(self, input):\n        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)\n\n\ndef fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):\n    rest_dim = [1] * (input.ndim - bias.ndim - 1)\n    input = input.cuda()\n    if input.ndim == 3:\n        return (\n            F.leaky_relu(\n                input + bias.view(1, *rest_dim, bias.shape[0]), negative_slope=negative_slope\n            )\n            * scale\n        )\n    else:\n        return (\n            F.leaky_relu(\n                input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope\n            )\n            * scale\n        )\n\n"
  },
  {
    "path": "models/stylegan2/op/upfirdn2d.py",
    "content": "import os\nimport torch\nfrom torch.nn import functional as F\n\nmodule_path = os.path.dirname(__file__)\n\ndef upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):\n    out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])\n\n    return out\n\ndef upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):\n    \n    _, channel, in_h, in_w = input.shape\n    input = input.reshape(-1, in_h, in_w, 1)\n\n    _, in_h, in_w, minor = input.shape\n    kernel_h, kernel_w = kernel.shape\n\n    out = input.view(-1, in_h, 1, in_w, 1, minor)\n    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])\n    out = out.view(-1, in_h * up_y, in_w * up_x, minor)\n\n    out = F.pad(\n        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]\n    )\n    out = out[\n        :,\n        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),\n        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),\n        :,\n    ]\n\n    out = out.permute(0, 3, 1, 2)\n    out = out.reshape(\n        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]\n    )\n    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)\n    out = F.conv2d(out, w)\n    out = out.reshape(\n        -1,\n        minor,\n        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,\n        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,\n    )\n    out = out.permute(0, 2, 3, 1)\n    out = out[:, ::down_y, ::down_x, :]\n\n    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1\n    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1\n\n    return out.view(-1, channel, out_h, out_w)"
  },
  {
    "path": "options/test_options.py",
    "content": "from argparse import ArgumentParser\n\nclass TestOptions:\n\n\tdef __init__(self):\n\t\tself.parser = ArgumentParser()\n\t\tself.initialize()\n\n\tdef initialize(self):\n\t\t# arguments for inference script\n\t\t\n\t\tself.parser.add_argument('--batch_size', default=1, type=int, help='Batch size for inference')\n\t\tself.parser.add_argument('--workers', default=4, type=int, help='Number of test dataloader workers')\n\t\t\n\t\tself.parser.add_argument('--stylegan_weights', default='models/pretrained_models/stylegan2-ffhq-config-f.pt', type=str, help='Path to StyleGAN model weights')\n\t\tself.parser.add_argument('--stylegan_size', default=1024, type=int)\n\t\t\n\t\tself.parser.add_argument(\"--threshold\", type=int, default=0.03)\n\t\tself.parser.add_argument(\"--checkpoint_path\", type=str, default='checkpoints/net_face.pth')\n\t\tself.parser.add_argument(\"--save_dir\", type=str, default='output')\n\t\tself.parser.add_argument(\"--num_all\", type=int, default=20)\n\t\t\n\t\tself.parser.add_argument(\"--target\", type=str, required=True, help='Specify the target attributes to be edited')\n\n\tdef parse(self):\n\t\topts = self.parser.parse_args()\n\t\treturn opts"
  },
  {
    "path": "options/train_options.py",
    "content": "from argparse import ArgumentParser\n\nclass TrainOptions:\n\n\tdef __init__(self):\n\t\tself.parser = ArgumentParser()\n\t\tself.initialize()\n\n\tdef initialize(self):\n\n\t\tself.parser.add_argument('--batch_size', default=64, type=int, help='Batch size for training')\n\t\tself.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers')\n\n\t\tself.parser.add_argument('--learning_rate', default=0.5, type=float, help='Optimizer learning rate')\n\n\t\tself.parser.add_argument('--l2_lambda', default=1.0, type=float, help='l2 loss')\n\t\tself.parser.add_argument('--cos_lambda', default=1.0, type=float, help='cos loss')\n\n\t\tself.parser.add_argument('--checkpoint_path', default='checkpoints', type=str, help='Path to StyleCLIPModel model checkpoint')\n\t\tself.parser.add_argument('--classname', type=str, default='ffhq', help=\"which specific domain for training\")\n\t\tself.parser.add_argument('--print_interval', default=1000, type=int, help='Interval for printing loss values during training')\n\t\tself.parser.add_argument('--val_interval', default=5000, type=int, help='Validation interval')\n\t\tself.parser.add_argument('--save_interval', default=10000, type=int, help='Model checkpoint interval')\n\n\tdef parse(self):\n\t\topts = self.parser.parse_args()\n\t\treturn opts"
  },
  {
    "path": "scripts/inference.py",
    "content": "import os\nimport sys\nsys.path.append(\".\")\nsys.path.append(\"..\")\n\nimport copy\nimport clip\nimport numpy as np\n\nimport torch\nimport torchvision\nfrom torch.utils.data import DataLoader\n\nimport torch.nn.functional as F\n\nfrom datasets.test_dataset import TestLatentsDataset\n\nfrom models.stylegan2.model import Generator\nfrom delta_mapper import DeltaMapper\n\nfrom options.test_options import TestOptions\n\nfrom utils import map_tool\nfrom utils import stylespace_util\n\ndef GetBoundary(fs3,dt,threshold):\n    tmp=np.dot(fs3,dt)\n    \n    select=np.abs(tmp)<threshold\n    return select\n\ndef improved_ds(ds, select):\n    ds_imp = copy.copy(ds)\n    ds_imp[select] = 0\n    ds_imp = ds_imp.unsqueeze(0)\n    return ds_imp\n\ndef main(opts):\n\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n    #Initialize test dataset\n    test_dataset = TestLatentsDataset()\n    test_dataloader = DataLoader(test_dataset, \n                                 batch_size=opts.batch_size,\n                                 shuffle=False,\n                                 num_workers=int(opts.workers),\n                                 drop_last=True)\n\n    #Initialize generator\n    print('Loading stylegan weights from pretrained!')\n    g_ema = Generator(size=opts.stylegan_size, style_dim=512, n_mlp=8)\n    g_ema_ckpt = torch.load(opts.stylegan_weights)\n    g_ema.load_state_dict(g_ema_ckpt['g_ema'], strict=False)\n    g_ema.eval()\n    g_ema = g_ema.to(device)\n\n    #load relevance matrix Rs\n    fs3=np.load('./models/stylegan2/npy_ffhq/fs3.npy')\n    np.set_printoptions(suppress=True)\n\n    #Initialze DeltaMapper\n    net = DeltaMapper()\n    net_ckpt = torch.load(opts.checkpoint_path)\n    net.load_state_dict(net_ckpt)\n    net = net.to(device)\n    \n    #Load CLIP model\n    clip_model, preprocess = clip.load(\"ViT-B/32\", device=device)\n\n    os.makedirs(opts.save_dir, exist_ok=True)\n\n    neutral='face'\n    target_list = 
opts.target.split(',')\n    # print(target_list)\n\n    dt_list = []\n    select_list = []\n    for target in target_list:\n        classnames=[target,neutral]\n        dt = map_tool.GetDt(classnames,clip_model)\n        select = GetBoundary(fs3, dt, opts.threshold)\n        dt = torch.Tensor(dt).to(device)\n        dt = dt / dt.norm(dim=-1, keepdim=True).float().clamp(min=1e-5)\n\n        select_list.append(select)\n        dt_list.append(dt)\n\n    for bid, batch in enumerate(test_dataloader):\n        if bid == opts.num_all:\n            break\n        \n        latent_s, delta_c, latent_w = batch\n        latent_s = latent_s.to(device)\n        delta_c = delta_c.to(device)\n        latent_w = latent_w.to(device)\n        delta_s_list = []\n\n        for i, dt in enumerate(dt_list):\n            delta_c[0, 512:] = dt\n            with torch.no_grad():\n                fake_delta_s = net(latent_s, delta_c)\n                improved_fake_delta_s = improved_ds(fake_delta_s[0], select_list[i])\n            delta_s_list.append(improved_fake_delta_s)\n\n        with torch.no_grad():\n            img_ori = stylespace_util.decoder_validate(g_ema, latent_s, latent_w)\n\n            img_list = [img_ori]\n            for delta_s in delta_s_list:\n                img_gen = stylespace_util.decoder_validate(g_ema, latent_s + delta_s, latent_w)\n                img_list.append(img_gen)\n            img_gen_all = torch.cat(img_list, dim=3)\n            torchvision.utils.save_image(img_gen_all, os.path.join(opts.save_dir, \"%04d.jpg\" %(bid+1)), normalize=True, range=(-1, 1))\n    print(f'completed👍! Please check results in {opts.save_dir}')\n\nif __name__ == \"__main__\":\n    opts = TestOptions().parse()\n    main(opts)"
  },
  {
    "path": "scripts/inference_real.py",
    "content": "import os\nimport sys\nsys.path.append(\".\")\nsys.path.append(\"..\")\n\nimport copy\nimport clip\nimport numpy as np\nfrom PIL import Image\n\nimport torch\nimport torchvision\nfrom torch.utils.data import Dataset\nfrom torch.utils.data import DataLoader\nfrom torchvision import transforms\n\nimport torch.nn.functional as F\n\n\nfrom datasets.test_dataset import TestLatentsDataset\n\nfrom models.stylegan2.model import Generator\nfrom models.encoders import psp_encoders\nfrom delta_mapper import DeltaMapper\n\nfrom options.test_options import TestOptions\n\nfrom utils import map_tool\nfrom utils import stylespace_util\n\n\n\ndef get_keys(d, name):\n    if 'state_dict' in d:\n        d = d['state_dict']\n    d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}\n    return d_filt\n\nclass Imagedataset(Dataset):\n    def __init__(self,\n                 path,\n                 image_size=256,\n                 split=None):\n\n        self.path = path\n        self.images = os.listdir(path)\n\n        self.image_size = image_size\n\n        self.length = len(self.images)\n\n        transform = [\n            transforms.Resize(image_size),\n            transforms.ToTensor(),\n            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n        ]\n\n        self.transform = transforms.Compose(transform)\n\n    def __len__(self):\n        return self.length\n\n    def __getitem__(self, index):\n        cur_name = self.images[index]\n        img_path = os.path.join(self.path, cur_name)\n\n        img = Image.open(img_path).convert(\"RGB\") \n\n        if self.transform is not None:\n            img = self.transform(img)\n        return img\n\ndef encoder_latent(G, latent):\n    # an encoder warper for G\n    #styles = [noise]\n    style_space = []\n    \n    #styles = [G.style(s) for s in styles]\n    noise = [getattr(G.noises, 'noise_{}'.format(i)) for i in range(G.num_layers)]\n    # inject_index = G.n_latent\n    #latent 
= styles[0].unsqueeze(1).repeat(1, inject_index, 1)\n    style_space.append(G.conv1.conv.modulation(latent[:, 0]))\n\n    i = 1\n    for conv1, conv2, to_rgb in zip(\n        G.convs[::2], G.convs[1::2], G.to_rgbs\n    ):\n        style_space.append(conv1.conv.modulation(latent[:, i]))\n        style_space.append(conv2.conv.modulation(latent[:, i+1]))\n        i += 2\n        \n    return style_space, noise\n\ndef GetBoundary(fs3,dt,threshold):\n    tmp=np.dot(fs3,dt)\n    \n    select=np.abs(tmp)<threshold\n    return select\n\ndef improved_ds(ds, select):\n    ds_imp = copy.copy(ds)\n    ds_imp[select] = 0\n    ds_imp = ds_imp.unsqueeze(0)\n    return ds_imp\n\ndef main(opts):\n\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n    # NOTE load e4e\n    checkpoint_path = \"encoder4editing-main/e4e_ffhq_encode.pt\"\n    ckpt_enc = torch.load(checkpoint_path, map_location='cpu') #dict_keys(['state_dict', 'latent_avg', 'opts'])\n    encoder = psp_encoders.Encoder4Editing(50, 1024, 'ir_se')\n    encoder.load_state_dict(get_keys(ckpt_enc, 'encoder'), strict=True)\n    encoder.eval()\n    encoder.to(device)\n\n    #Initialize test dataset\n    test_dataset = Imagedataset('./test_imgs', image_size=256)\n    test_dataloader = DataLoader(test_dataset, \n                                 batch_size=opts.batch_size,\n                                 shuffle=False,\n                                 num_workers=int(opts.workers),\n                                 drop_last=True)\n\n    #Initialize generator\n    print('Loading stylegan weights from pretrained!')\n    g_ema = Generator(size=opts.stylegan_size, style_dim=512, n_mlp=8)\n    g_ema_ckpt = torch.load(opts.stylegan_weights)\n    g_ema.load_state_dict(g_ema_ckpt['g_ema'], strict=False)\n    g_ema.eval()\n    g_ema = g_ema.to(device)\n\n    #load relevance matrix Rs\n    fs3=np.load('./models/stylegan2/npy_ffhq/fs3.npy')\n    np.set_printoptions(suppress=True)\n\n    #Initialze DeltaMapper\n    net = 
DeltaMapper()\n    net_ckpt = torch.load(opts.checkpoint_path, map_location=device)\n    net.load_state_dict(net_ckpt)\n    net = net.to(device)\n    \n    #Load CLIP model\n    clip_model, preprocess = clip.load(\"ViT-B/32\", device=device)\n    avg_pool = torch.nn.AvgPool2d(kernel_size=256//32)\n    upsample = torch.nn.Upsample(scale_factor=7)\n\n    os.makedirs(opts.save_dir, exist_ok=True)\n\n    neutral='face'\n    target_list = opts.target.split(',')\n    # print(target_list)\n\n    dt_list = []\n    select_list = []\n    for target in target_list:\n        classnames=[target,neutral]\n        dt = map_tool.GetDt(classnames,clip_model)\n        select = GetBoundary(fs3, dt, opts.threshold)\n        dt = torch.Tensor(dt).to(device)\n        dt = dt / dt.norm(dim=-1, keepdim=True).float().clamp(min=1e-5)\n\n        select_list.append(select)\n        dt_list.append(dt)\n\n    for bid, batch in enumerate(test_dataloader):\n        if bid == opts.num_all:\n            break\n        input_img = batch.to(device)\n        with torch.no_grad():\n            latent_w = encoder(input_img)\n            #use the selected device instead of a hard-coded .cuda() so CPU-only runs work\n            latent_avg = ckpt_enc['latent_avg'].to(device)\n            latent_w = latent_w + latent_avg.repeat(latent_w.shape[0], 1, 1)\n\n            style_space, noise = encoder_latent(g_ema, latent_w)\n            latent_s = torch.cat(style_space, dim=1)\n\n            img_gen_for_clip = upsample(input_img)\n            img_gen_for_clip = avg_pool(img_gen_for_clip)\n            c_latents = clip_model.encode_image(img_gen_for_clip)\n            c_latents = c_latents / c_latents.norm(dim=-1, keepdim=True).float()\n\n        delta_s_list = []\n\n        for i, dt in enumerate(dt_list):\n            delta_c = torch.cat((c_latents, dt.unsqueeze(0)), dim=1)\n            with torch.no_grad():\n                fake_delta_s = net(latent_s, delta_c)\n                improved_fake_delta_s = improved_ds(fake_delta_s[0], select_list[i])\n            delta_s_list.append(improved_fake_delta_s)\n\n        with 
torch.no_grad():\n            img_ori = stylespace_util.decoder_validate(g_ema, latent_s, latent_w)\n\n            img_list = [img_ori]\n            for delta_s in delta_s_list:\n                img_gen = stylespace_util.decoder_validate(g_ema, latent_s + delta_s, latent_w)\n                img_list.append(img_gen)\n            img_gen_all = torch.cat(img_list, dim=3)\n            torchvision.utils.save_image(img_gen_all, os.path.join(opts.save_dir, \"%04d.jpg\" %(bid+1)), normalize=True, range=(-1, 1))\n    print(f'completed👍! Please check results in {opts.save_dir}')\n\nif __name__ == \"__main__\":\n    opts = TestOptions().parse()\n    main(opts)\n"
  },
  {
    "path": "scripts/train.py",
    "content": "import os\nimport sys\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nsys.path.append(\".\")\nsys.path.append(\"..\")\n\nfrom datasets.train_dataset import TrainLatentsDataset\nfrom options.train_options import TrainOptions\nfrom delta_mapper import DeltaMapper\n\ndef main(opts):\n\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n    train_dataset = TrainLatentsDataset(opts)\n    train_dataloader = DataLoader(train_dataset,\n                                  batch_size=opts.batch_size,\n                                  shuffle=True,\n                                  num_workers=int(opts.workers),\n                                  drop_last=True)\n\n    #Initialze DeltaMapper\n    net = DeltaMapper().to(device)\n\n    #Initialize optimizer\n    optimizer = torch.optim.Adam(list(net.parameters()), lr=opts.learning_rate)\n\n    #Initialize loss\n    l2_loss = torch.nn.MSELoss().to(device)\n    cosine_loss = torch.nn.CosineSimilarity(dim=-1).to(device)\n\n    #save dir\n    os.makedirs(os.path.join(opts.checkpoint_path, opts.classname), exist_ok=True)\n\n    for batch_idx, batch in enumerate(train_dataloader):\n\n        latent_s, delta_c, delta_s = batch\n        latent_s = latent_s.to(device)\n        delta_c = delta_c.to(device)\n        delta_s = delta_s.to(device)\n\n        fake_delta_s = net(latent_s, delta_c)\n\n        optimizer.zero_grad()\n        loss_l2 = l2_loss(fake_delta_s, delta_s)\n        loss_cos = 1 - torch.mean(cosine_loss(fake_delta_s, delta_s))\n\n        loss = opts.l2_lambda * loss_l2 + opts.cos_lambda * loss_cos\n        loss.backward()\n        optimizer.step()\n\n        if batch_idx % opts.print_interval == 0 :\n            print(batch_idx, loss.detach().cpu().numpy(), loss_l2.detach().cpu().numpy(), loss_cos.detach().cpu().numpy())\n\n        if batch_idx % opts.save_interval == 0:\n            torch.save(net.state_dict(), os.path.join(opts.checkpoint_path, opts.classname, \"net_%06d.pth\" 
% batch_idx))\n\nif __name__ == \"__main__\":\n    opts = TrainOptions().parse()\n    main(opts)"
  },
  {
    "path": "tSNE/compute_tsne.py",
    "content": "import numpy as np\nimport matplotlib.pyplot as plt\nfrom sklearn.manifold import TSNE\n\ndata_name = 'celeba' # 'celeba' or 'cocoval'\n\n#for img/text\ncspace_img = np.load(f'./{data_name}/cspace_{data_name}_i.npy')\ncspace_text = np.load(f'./{data_name}/cspace_{data_name}_t.npy')\n\n#for deltaimg\ncspace_deltaimg = np.load(f'./{data_name}/cspace_{data_name}_deltai.npy')\ncspace_deltatext = np.load(f'./{data_name}/cspace_{data_name}_deltat.npy')\n\nnum=1000\n\ndata_ori = np.concatenate([cspace_img[:num], cspace_text[:num]], axis=0)\ndata_delta = np.concatenate([cspace_deltaimg[:num], cspace_deltatext[:num]], axis=0)\n\ntsne = TSNE(n_components=2, init='pca')\n\nresult_ori = tsne.fit_transform(data_ori)\nresult_delta = tsne.fit_transform(data_delta)\n\nfor i in range(result_ori.shape[0]):\n    x_min, x_max = np.min(result_ori, 0), np.max(result_ori, 0)\n    data = (result_ori - x_min) / (x_max - x_min)\n    if i < result_ori.shape[0]//2:\n        s0 = plt.scatter(data[i, 0], data[i, 1], color=plt.cm.Set1(0/4), s=12, marker='o')\n    elif i < result_ori.shape[0]:\n        s1 = plt.scatter(data[i, 0], data[i, 1], color=plt.cm.Set1(1/4), s=12, marker='o')\n    \nplt.legend((s0, s1), ('CLIP Image Space', 'CLIP Text Space'), fontsize=10)\nplt.xticks()\nplt.yticks()\nplt.title('t-SNE Results')\nplt.tight_layout()\nplt.savefig(f'tSNE-{data_name}-{num}_ori.png')\n\nplt.close()\n\nfor i in range(result_delta.shape[0]):\n    x_min, x_max = np.min(result_delta, 0), np.max(result_delta, 0)\n    data = (result_delta - x_min) / (x_max - x_min)\n    if i < result_delta.shape[0]//2:\n        s0 = plt.scatter(data[i, 0], data[i, 1], color=plt.cm.Set1(2/4), s=12, marker='o')\n    elif i < result_delta.shape[0]:\n        s1 = plt.scatter(data[i, 0], data[i, 1], color=plt.cm.Set1(3/4), s=12, marker='o')\n    \nplt.legend((s0, s1), ('CLIP Delta Image Space', 'CLIP Delta Text Space'), fontsize=10)\nplt.xticks()\nplt.yticks()\nplt.title('t-SNE 
Results')\nplt.tight_layout()\nplt.savefig(f'tSNE-{data_name}-{num}_delta.png')"
  },
  {
    "path": "utils/map_tool.py",
    "content": "import torch\nimport clip\nimport os\nimport numpy as np\n\nimagenet_templates = [\n    'a bad photo of a {}.',\n#    'a photo of many {}.',\n    'a sculpture of a {}.',\n    'a photo of the hard to see {}.',\n    'a low resolution photo of the {}.',\n    'a rendering of a {}.',\n    'graffiti of a {}.',\n    'a bad photo of the {}.',\n    'a cropped photo of the {}.',\n    'a tattoo of a {}.',\n    'the embroidered {}.',\n    'a photo of a hard to see {}.',\n    'a bright photo of a {}.',\n    'a photo of a clean {}.',\n    'a photo of a dirty {}.',\n    'a dark photo of the {}.',\n    'a drawing of a {}.',\n    'a photo of my {}.',\n    'the plastic {}.',\n    'a photo of the cool {}.',\n    'a close-up photo of a {}.',\n    'a black and white photo of the {}.',\n    'a painting of the {}.',\n    'a painting of a {}.',\n    'a pixelated photo of the {}.',\n    'a sculpture of the {}.',\n    'a bright photo of the {}.',\n    'a cropped photo of a {}.',\n    'a plastic {}.',\n    'a photo of the dirty {}.',\n    'a jpeg corrupted photo of a {}.',\n    'a blurry photo of the {}.',\n    'a photo of the {}.',\n    'a good photo of the {}.',\n    'a rendering of the {}.',\n    'a {} in a video game.',\n    'a photo of one {}.',\n    'a doodle of a {}.',\n    'a close-up photo of the {}.',\n    'a photo of a {}.',\n    'the origami {}.',\n    'the {} in a video game.',\n    'a sketch of a {}.',\n    'a doodle of the {}.',\n    'a origami {}.',\n    'a low resolution photo of a {}.',\n    'the toy {}.',\n    'a rendition of the {}.',\n    'a photo of the clean {}.',\n    'a photo of a large {}.',\n    'a rendition of a {}.',\n    'a photo of a nice {}.',\n    'a photo of a weird {}.',\n    'a blurry photo of a {}.',\n    'a cartoon {}.',\n    'art of a {}.',\n    'a sketch of the {}.',\n    'a embroidered {}.',\n    'a pixelated photo of a {}.',\n    'itap of the {}.',\n    'a jpeg corrupted photo of the {}.',\n    'a good photo of a {}.',\n    'a plushie 
{}.',\n    'a photo of the nice {}.',\n    'a photo of the small {}.',\n    'a photo of the weird {}.',\n    'the cartoon {}.',\n    'art of the {}.',\n    'a drawing of the {}.',\n    'a photo of the large {}.',\n    'a black and white photo of a {}.',\n    'the plushie {}.',\n    'a dark photo of a {}.',\n    'itap of a {}.',\n    'graffiti of the {}.',\n    'a toy {}.',\n    'itap of my {}.',\n    'a photo of a cool {}.',\n    'a photo of a small {}.',\n    'a tattoo of the {}.',\n]\n\ndef zeroshot_classifier(classnames, templates,model):\n    device = next(model.parameters()).device #run on whichever device the CLIP model lives on\n    with torch.no_grad():\n        zeroshot_weights = []\n        for classname in classnames:\n            texts = [template.format(classname) for template in templates] #format with class\n            texts = clip.tokenize(texts).to(device) #tokenize\n            class_embeddings = model.encode_text(texts) #embed with text encoder\n            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)\n            class_embedding = class_embeddings.mean(dim=0)\n            class_embedding /= class_embedding.norm()\n            zeroshot_weights.append(class_embedding)\n        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)\n    return zeroshot_weights\n\ndef GetDt(classnames,model):\n    #text direction: averaged prompt embedding of the target minus that of the neutral class\n    text_features=zeroshot_classifier(classnames, imagenet_templates,model).t()\n    \n    dt=text_features[0]-text_features[1]\n    dt=dt.cpu().numpy()\n    \n    return dt\n\n\nif __name__ == \"__main__\":\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n    model, preprocess = clip.load(\"ViT-B/32\", device=device)\n\n    neutral='face with eyes' #@param {type:\"string\"}\n    target='face with blue eyes' #@param {type:\"string\"}\n    classnames=[target,neutral]\n    dt = GetDt(classnames,model)\n    print(dt.shape)"
  },
  {
    "path": "utils/stylespace_util.py",
    "content": "import torch\nimport numpy as np\nimport matplotlib.pyplot as plt\nimport torchvision\n\nfrom torch.nn import functional as F\n\nindex = [0,1,1,2,2,3,4,4,5,6,6,7,8,8,9,10,10,11,12,12,13,14,14,15,16,16]\n\ndef conv_warper(layer, input, style, noise):\n\n    conv = layer.conv\n    batch, in_channel, height, width = input.shape\n\n    style = style.view(batch, 1, in_channel, 1, 1)\n    weight = conv.scale * conv.weight * style\n\n    if conv.demodulate:\n        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)\n        weight = weight * demod.view(batch, conv.out_channel, 1, 1, 1)\n\n    weight = weight.view(\n        batch * conv.out_channel, in_channel, conv.kernel_size, conv.kernel_size\n    )\n\n    if conv.upsample:\n        input = input.view(1, batch * in_channel, height, width)\n        weight = weight.view(\n            batch, conv.out_channel, in_channel, conv.kernel_size, conv.kernel_size\n        )\n        weight = weight.transpose(1, 2).reshape(\n            batch * in_channel, conv.out_channel, conv.kernel_size, conv.kernel_size\n        )\n        out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)\n        _, _, height, width = out.shape\n        out = out.view(batch, conv.out_channel, height, width)\n        out = conv.blur(out)\n\n    elif conv.downsample:\n        input = conv.blur(input)\n        _, _, height, width = input.shape\n        input = input.view(1, batch * in_channel, height, width)\n        out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)\n        _, _, height, width = out.shape\n        out = out.view(batch, conv.out_channel, height, width)\n\n    else:\n        input = input.view(1, batch * in_channel, height, width)\n        out = F.conv2d(input, weight, padding=conv.padding, groups=batch)\n        _, _, height, width = out.shape\n        out = out.view(batch, conv.out_channel, height, width)\n        \n    out = layer.noise(out, noise=noise)\n    out = 
layer.activate(out)\n    \n    return out\n\ndef decoder(G, style_space, latent, noise):\n\n    out = G.input(latent)\n    out = conv_warper(G.conv1, out, style_space[0], noise[0])\n    skip = G.to_rgb1(out, latent[:, 1])\n\n    i = 1\n    for conv1, conv2, noise1, noise2, to_rgb in zip(\n        G.convs[::2], G.convs[1::2], noise[1::2], noise[2::2], G.to_rgbs\n    ):\n        out = conv_warper(conv1, out, style_space[i], noise=noise1)\n        out = conv_warper(conv2, out, style_space[i+1], noise=noise2)\n        skip = to_rgb(out, latent[:, i + 2], skip)\n\n        i += 2\n\n    image = skip\n\n    return image\n\ndef decoder_validate(G, style_space, latent):\n\n    style_space = split_stylespace(style_space)\n    noise = [getattr(G.noises, 'noise_{}'.format(i)) for i in range(G.num_layers)]\n\n    out = G.input(latent)\n    out = conv_warper(G.conv1, out, style_space[0], noise[0])\n    skip = G.to_rgb1(out, latent[:, 1])\n\n    i = 1\n    for conv1, conv2, noise1, noise2, to_rgb in zip(\n        G.convs[::2], G.convs[1::2], noise[1::2], noise[2::2], G.to_rgbs\n    ):\n        out = conv_warper(conv1, out, style_space[i], noise=noise1)\n        out = conv_warper(conv2, out, style_space[i+1], noise=noise2)\n        skip = to_rgb(out, latent[:, i + 2], skip)\n\n        i += 2\n\n    image = skip\n\n    return image\n\ndef encoder_noise(G, noise):\n\n    styles = [noise]\n    style_space = []\n    \n    styles = [G.style(s) for s in styles]\n    noise = [getattr(G.noises, 'noise_{}'.format(i)) for i in range(G.num_layers)]\n    inject_index = G.n_latent\n    latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)\n    style_space.append(G.conv1.conv.modulation(latent[:, 0]))\n\n    i = 1\n    for conv1, conv2, noise1, noise2, to_rgb in zip(\n        G.convs[::2], G.convs[1::2], noise[1::2], noise[2::2], G.to_rgbs\n    ):\n        style_space.append(conv1.conv.modulation(latent[:, i]))\n        style_space.append(conv2.conv.modulation(latent[:, i+1]))\n        i 
+= 2\n        \n    return style_space, latent, noise\n\ndef encoder_latent(G, latent):\n    # an encoder wrapper for G: maps each latent[:, i] to its per-layer style code\n\n    style_space = []\n    \n    noise = [getattr(G.noises, 'noise_{}'.format(i)) for i in range(G.num_layers)]\n\n    style_space.append(G.conv1.conv.modulation(latent[:, 0]))\n\n    i = 1\n    for conv1, conv2, to_rgb in zip(\n        G.convs[::2], G.convs[1::2], G.to_rgbs\n    ):\n        style_space.append(conv1.conv.modulation(latent[:, i]))\n        style_space.append(conv2.conv.modulation(latent[:, i+1]))\n        i += 2\n        \n    return style_space, noise\n\ndef split_stylespace(style):\n    # split the flat style vector into 17 per-layer chunks:\n    # ten of 512 channels, then 256, 256, 128, 128, 64, 64 and 32 channels\n    style_space = []\n\n    for idx in range(10):\n        style_space.append(style[:, idx*512 : (idx+1) * 512])\n    \n    style_space.append(style[:, 10*512: 10*512 + 256])\n    style_space.append(style[:, 10*512 + 256: 10*512 + 256*2])\n    style_space.append(style[:, 10*512 + 256*2: 10*512 + 256*2 + 128])\n    style_space.append(style[:, 10*512 + 256*2 + 128: 10*512 + 256*2 + 128 * 2])\n    style_space.append(style[:, 10*512 + 256*2 + 128*2: 10*512 + 256*2 + 128*2 + 64])\n    style_space.append(style[:, 10*512 + 256*2 + 128*2 + 64: 10*512 + 256*2 + 128*2 + 64*2])\n    style_space.append(style[:, 10*512 + 256*2 + 128*2 + 64*2: 10*512 + 256*2 + 128*2 + 64*2 + 32])\n\n    return style_space\n\ndef fuse_stylespace(style):\n    # inverse of split_stylespace: concatenate per-layer chunks along the channel axis\n    new_s = torch.cat(style, dim=1)\n\n    return new_s"
  }
]