Full Code of Yueming6568/DeltaEdit for AI

Repository: Yueming6568/DeltaEdit
Branch: main
Commit: d80bc4bfdc6d
Files: 38
Total size: 109.1 KB

Directory structure:
gitextract_2trjho5x/

├── README.md
├── clip/
│   ├── __init__.py
│   ├── clip.py
│   ├── model.py
│   └── simple_tokenizer.py
├── datasets/
│   ├── test_dataset.py
│   └── train_dataset.py
├── delta_mapper.py
├── examples/
│   ├── cspace_img_feat.npy
│   ├── sspace_img_feat.npy
│   └── wplus_img_feat.npy
├── generate_codes.py
├── models/
│   ├── encoders/
│   │   ├── __init__.py
│   │   ├── helpers.py
│   │   ├── model_irse.py
│   │   └── psp_encoders.py
│   └── stylegan2/
│       ├── __init__.py
│       ├── model.py
│       ├── npy_ffhq/
│       │   └── fs3.npy
│       └── op/
│           ├── __init__.py
│           ├── fused_act.py
│           └── upfirdn2d.py
├── options/
│   ├── test_options.py
│   └── train_options.py
├── scripts/
│   ├── inference.py
│   ├── inference_real.py
│   └── train.py
├── tSNE/
│   ├── celeba/
│   │   ├── cspace_celeba_deltai.npy
│   │   ├── cspace_celeba_deltat.npy
│   │   ├── cspace_celeba_i.npy
│   │   └── cspace_celeba_t.npy
│   ├── cocoval/
│   │   ├── cspace_cocoval_deltai.npy
│   │   ├── cspace_cocoval_deltat.npy
│   │   ├── cspace_cocoval_i.npy
│   │   └── cspace_cocoval_t.npy
│   └── compute_tsne.py
└── utils/
    ├── map_tool.py
    └── stylespace_util.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# DeltaSpace: A Semantic-aligned Feature Space for Flexible Text-guided Image Editing

## Overview

This repository contains the **official** PyTorch implementation of the papers:

*DeltaEdit: Exploring Text-free Training for Text-driven Image Manipulation*, CVPR 2023

*DeltaSpace: A Semantic-aligned Feature Space for Flexible Text-guided Image Editing*, arXiv 2023

## News

- [2025-06-22] Upload t-SNE code for alignment validation (◍>◡<◍).

- [2023-03-11] Upload the training and inference code for the facial domain (◍•ڡ•◍).

*To be continued...*

<!-- We will release the training and inference code for the LSUN cat, church, horse later : ) -->

## Dependencies

- Install CLIP:

  ```shell
  conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=<CUDA_VERSION>
  pip install ftfy regex tqdm gdown
  pip install git+https://github.com/openai/CLIP.git
  ```

- Download the pre-trained models:

  - The code relies on the [Rosinality](https://github.com/rosinality/stylegan2-pytorch/) PyTorch implementation of StyleGAN2.
  - Download the pre-trained StyleGAN2 generator model for the facial domain from [here](https://drive.google.com/file/d/1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT/view?usp=sharing), and then place it into the folder `./models/pretrained_models`.
  - Download the pre-trained StyleGAN2 generator model for the LSUN cat, church, horse domains from [here](https://drive.google.com/drive/folders/1YRhXGM-2xk7A4TExM_jXaNg1f2AiCRlw?usp=share_link) and then place them into the folder `./models/pretrained_models/stylegan2-{cat/church/horse}`.

## Training

### Data preparation

- DeltaEdit is trained on latent vectors.

- For the facial domain, 58,000 real images randomly selected from the [FFHQ](https://github.com/NVlabs/ffhq-dataset) dataset and 200,000 fake images sampled from the z space of StyleGAN are used for training. Note that all real images are inverted by the [e4e](https://github.com/omertov/encoder4editing) encoder.

- Download the provided FFHQ latent vectors from [here](https://drive.google.com/drive/folders/13NLq4giSgdcMVkYQIiPj4Xhxz4-wlXSD?usp=sharing) and then place all numpy files into the folder `./latent_code/ffhq`.

- Generate the 200,000 sampled latent vectors by running the following commands for each specific domain:

  ```shell
  CUDA_VISIBLE_DEVICES=0 python generate_codes.py --classname ffhq --samples 200000
  CUDA_VISIBLE_DEVICES=0 python generate_codes.py --classname cat --samples 200000
  CUDA_VISIBLE_DEVICES=0 python generate_codes.py --classname church --samples 200000
  CUDA_VISIBLE_DEVICES=0 python generate_codes.py --classname horse --samples 200000
  ```
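
As a rough picture of what "sampling from the z space" means, the sketch below draws standard-normal latent vectors with NumPy. The 512-dimensional size follows the common StyleGAN2 configuration, and the helper name `sample_z_codes` and the seed are illustrative assumptions. This is not the code in `generate_codes.py`, which remains the script to run for producing the training data.

```python
import numpy as np


def sample_z_codes(num_samples: int, latent_dim: int = 512, seed: int = 0) -> np.ndarray:
    """Draw `num_samples` latent vectors z ~ N(0, I) (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    return rng.standard_normal((num_samples, latent_dim)).astype(np.float32)


# A small batch for illustration; the README commands use --samples 200000.
z = sample_z_codes(8)
print(z.shape, z.dtype)
```

Fixing the seed keeps the sampled set reproducible across runs, which may help when comparing training results.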

### Usage

- The main training script is placed in `./scripts/train.py`.
- Training arguments can be found at `./options/train_options.py`. 

To train the model, run the following command:

```shell
CUDA_VISIBLE_DEVICES=0 python scripts/train.py
```

## Inference

- The main inference script is placed in `./scripts/inference.py`.
- Inference arguments can be found at `./options/test_options.py`.
- Download the pretrained DeltaMapper model for editing human faces from [here](https://drive.google.com/file/d/1Mb2WiELoVDPDIi24tIfoWsjn1l2xTjtZ/view?usp=sharing), and then place it into the folder `./checkpoints`.
- Some inference data are provided in `./examples`.
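
Before running inference, it can help to inspect the provided example arrays. The sketch below loads whichever of the three `.npy` files from `./examples` are present and prints their shapes. The helper name `load_example_features` is hypothetical, and the short keys (`wplus`, `sspace`, `cspace`) are only a reading of the file names, not something the scripts are confirmed to use.

```python
from pathlib import Path

import numpy as np

# File names as they appear in the repository's ./examples folder.
EXAMPLE_FILES = {
    "wplus": "wplus_img_feat.npy",
    "sspace": "sspace_img_feat.npy",
    "cspace": "cspace_img_feat.npy",
}


def load_example_features(root="./examples"):
    """Load the example arrays found under `root` (hypothetical helper)."""
    root = Path(root)
    features = {}
    for name, filename in EXAMPLE_FILES.items():
        path = root / filename
        if path.is_file():  # skip files that have not been downloaded
            features[name] = np.load(path)
    return features


for name, array in load_example_features().items():
    print(name, array.shape, array.dtype)
```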

To produce editing results, run the following command:

```shell
CUDA_VISIBLE_DEVICES=1 python scripts/inference.py --target "chubby face","face with eyeglasses","face with smile","face with pale skin","face with tanned skin","face with big eyes","face with black clothes","face with blue suit","happy face","face with bangs","face with red hair","face with black hair","face with blond hair","face with curly hair","face with receding hairline","face with bowlcut hairstyle"
```

The produced results are shown in the Results section below.

You can also pass your own target attributes via the `--target` flag.
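
Note that the shell joins adjacent quoted strings, so `"chubby face","face with eyeglasses"` reaches the script as the single argument `chubby face,face with eyeglasses`. A minimal sketch of turning such a value into individual prompts is shown below; it assumes the script splits on commas, and `parse_targets` is a hypothetical helper rather than a function from this repository.

```python
from typing import List


def parse_targets(target_arg: str) -> List[str]:
    """Split a comma-separated --target value into prompts (hypothetical helper)."""
    # Strip surrounding whitespace and drop empty entries such as "a,,b".
    return [part.strip() for part in target_arg.split(",") if part.strip()]


targets = parse_targets("chubby face,face with eyeglasses,face with smile")
print(targets)
```

Because the comma acts as the separator under this assumption, a single target description should not itself contain a comma.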

## Inference for real images

- The main inference script is placed in `./scripts/inference_real.py`.
- Inference arguments can be found at `./options/test_options.py`.
- Download the pretrained DeltaMapper model for editing human faces from [here](https://drive.google.com/file/d/1Mb2WiELoVDPDIi24tIfoWsjn1l2xTjtZ/view?usp=sharing), and then place it into the folder `./checkpoints`.
- Download the pretrained e4e encoder `e4e_ffhq_encode.pt` from [e4e](https://github.com/omertov/encoder4editing).
- One test image is provided in `./test_imgs`.

To produce editing results, run the following command:

```shell
CUDA_VISIBLE_DEVICES=1 python scripts/inference_real.py --target "chubby face","face with eyeglasses","face with smile","face with pale skin","face with tanned skin","face with big eyes","face with black clothes","face with blue suit","happy face","face with bangs","face with red hair","face with black hair","face with blond hair","face with curly hair","face with receding hairline","face with bowlcut hairstyle"
```

## Alignment Validation: CLIP Space vs. DeltaSpace via t-SNE Visualization

```shell
cd tSNE
python compute_tsne.py
```

Running the script produces a 2D t-SNE projection of the embeddings from both spaces (CLIP space and DeltaSpace). The results are shown below.

![tsne](./tsne.jpg)
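
For readers unfamiliar with the technique, the sketch below shows one way to project two sets of embeddings to 2D with scikit-learn's `TSNE`. It uses random stand-in data rather than the `.npy` files in `tSNE/`, and the L2 normalization, perplexity, and seed are assumptions made for illustration; `compute_tsne.py` may use different settings.

```python
import numpy as np
from sklearn.manifold import TSNE


def project_2d(features: np.ndarray, perplexity: float = 10.0, seed: int = 0) -> np.ndarray:
    """L2-normalize the rows, then project them to 2D with t-SNE."""
    normed = features / np.linalg.norm(features, axis=1, keepdims=True)
    tsne = TSNE(n_components=2, perplexity=perplexity, init="pca", random_state=seed)
    return tsne.fit_transform(normed)


# Random stand-ins for two sets of 512-dimensional embeddings (not real data).
rng = np.random.default_rng(0)
image_like = rng.standard_normal((40, 512))
text_like = rng.standard_normal((40, 512)) + 0.5

# Project both sets together so they share one 2D coordinate system.
points = project_2d(np.concatenate([image_like, text_like], axis=0))
print(points.shape)  # (80, 2)
```

The first 40 rows of `points` correspond to the first set and the remaining 40 to the second, so the two groups can be plotted in different colors to judge how well they overlap.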

## Results

![results](./results.jpg)

## Acknowledgements

This code is developed based on the code of [orpatashnik/StyleCLIP](https://github.com/orpatashnik/StyleCLIP) by Or Patashnik et al.

## Citation
If you use this code for your research, please cite our papers:
```
@InProceedings{lyu2023deltaedit,
    author    = {Lyu, Yueming and Lin, Tianwei and Li, Fu and He, Dongliang and Dong, Jing and Tan, Tieniu},
    title     = {DeltaEdit: Exploring Text-free Training for Text-Driven Image Manipulation},
    booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
    year      = {2023}
}

@article{lyu2023deltaspace,
    author    = {Lyu, Yueming and Zhao, Kang and Peng, Bo and Chen, Huafeng and Jiang, Yue and Zhang, Yingya and Dong, Jing and Shan, Caifeng},
    title     = {DeltaSpace: A Semantic-aligned Feature Space for Flexible Text-guided Image Editing},
    journal   = {arXiv preprint arXiv:2310.08785},
    year      = {2023},
}

```


================================================
FILE: clip/__init__.py
================================================
from .clip import *


================================================
FILE: clip/clip.py
================================================
import hashlib
import os
import urllib.request
import warnings
from typing import Union, List

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm

from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


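# Compare version parts numerically; comparing strings would rank "1.10" below "1.7".
_torch_version = tuple(int(p) for p in torch.__version__.split("+")[0].split(".")[:3] if p.isdigit())
if _torch_version < (1, 7, 1):
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")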


__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()

_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
}


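def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
    """Download `url` into `root` unless a cached copy matches the SHA256 embedded in the URL; return the local path."""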
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target


def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        lambda image: image.convert("RGB"),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())


def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name])
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())


def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result


================================================
FILE: clip/model.py
================================================
from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        return x[0]


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x


class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logit_scale * text_features @ image_features.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
    

def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)


def build_model(state_dict: dict):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
    )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()


================================================
FILE: clip/simple_tokenizer.py
================================================
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns a dict mapping utf-8 bytes to corresponding unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        with gzip.open(bpe_path) as f:
            merges = f.read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text


================================================
FILE: datasets/test_dataset.py
================================================
import numpy as np

import torch
from torch.utils.data import Dataset

class TestLatentsDataset(Dataset):
    def __init__(self):

        style_latents_list = []
        clip_latents_list = []
        wplus_latents_list = []
        
        # Change the paths here to test other latent codes.
        style_latents_list.append(torch.Tensor(np.load("./examples/sspace_img_feat.npy")))
        clip_latents_list.append(torch.Tensor(np.load("./examples/cspace_img_feat.npy")))
        wplus_latents_list.append(torch.Tensor(np.load("./examples/wplus_img_feat.npy")))
        
        self.style_latents = torch.cat(style_latents_list, dim=0)
        self.clip_latents = torch.cat(clip_latents_list, dim=0)
        self.wplus_latents = torch.cat(wplus_latents_list, dim=0)
        
    def __len__(self):

        return self.style_latents.shape[0]

    def __getitem__(self, index):

        latent_s1 = self.style_latents[index]
        latent_c1 = self.clip_latents[index]
        latent_w1 = self.wplus_latents[index]
        latent_c1 = latent_c1 / latent_c1.norm(dim=-1, keepdim=True).float()
        
        delta_c = torch.cat([latent_c1, latent_c1], dim=0)
        
        return latent_s1, delta_c, latent_w1

================================================
FILE: datasets/train_dataset.py
================================================
import copy
import random
import numpy as np

import torch
from torch.utils.data import Dataset

class TrainLatentsDataset(Dataset):
    def __init__(self, opts, cycle=True):

        style_latents_list = []
        clip_latents_list = []
        wplus_latents_list = []

        style_latents_list.append(torch.Tensor(np.load(f"./latent_code/{opts.classname}/sspace_noise_feat.npy")))
        clip_latents_list.append(torch.Tensor(np.load(f"./latent_code/{opts.classname}/cspace_noise_feat.npy")))
        wplus_latents_list.append(torch.Tensor(np.load(f"./latent_code/{opts.classname}/wspace_noise_feat.npy")))
        
        style_latents_list.append(torch.Tensor(np.load(f"./latent_code/{opts.classname}/sspace_ffhq_feat.npy")))
        clip_latents_list.append(torch.Tensor(np.load(f"./latent_code/{opts.classname}/cspace_ffhq_feat.npy")))
        wplus_latents_list.append(torch.Tensor(np.load(f"./latent_code/{opts.classname}/wspace_ffhq_feat.npy")))
        
        self.style_latents = torch.cat(style_latents_list, dim=0)
        self.clip_latents = torch.cat(clip_latents_list, dim=0)
        self.wplus_latents = torch.cat(wplus_latents_list, dim=0)

        # Keep at most 258,000 codes (presumably 200,000 sampled-noise codes plus 58,000 FFHQ codes).
        self.style_latents = self.style_latents[:200000+58000]
        self.clip_latents = self.clip_latents[:200000+58000]
        self.wplus_latents = self.wplus_latents[:200000+58000]

        self.dataset_size = self.style_latents.shape[0]
        print("dataset size", self.dataset_size)
        self.cycle = cycle
        
    def __len__(self):
        if self.cycle:
            return self.style_latents.shape[0] * 50
        else:
            return self.style_latents.shape[0]

    def __getitem__(self, index):
        if self.cycle:
            index = index % self.dataset_size

        latent_s1 = self.style_latents[index]
        latent_c1 = self.clip_latents[index]
        latent_w1 = self.wplus_latents[index]
        latent_c1 = latent_c1 / latent_c1.norm(dim=-1, keepdim=True).float()

        random_index = random.randint(0, self.dataset_size - 1)
        latent_s2 = self.style_latents[random_index]
        latent_c2 = self.clip_latents[random_index]
        latent_w2 = self.wplus_latents[random_index]
        latent_c2 = latent_c2 / latent_c2.norm(dim=-1, keepdim=True).float()

        delta_s1 = latent_s2 - latent_s1
        delta_c = latent_c2 - latent_c1
        
        delta_c = delta_c / delta_c.norm(dim=-1, keepdim=True).float().clamp(min=1e-5)
        delta_c = torch.cat([latent_c1, delta_c], dim=0)

        return latent_s1, delta_c, delta_s1

================================================
FILE: delta_mapper.py
================================================
import math

import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F

from models.stylegan2.model import EqualLinear, PixelNorm

class Mapper(Module):

    def __init__(self, in_channel=512, out_channel=512, norm=True, num_layers=4):
        super(Mapper, self).__init__()

        layers = [PixelNorm()] if norm else []
        
        layers.append(EqualLinear(in_channel, out_channel, lr_mul=0.01, activation='fused_lrelu'))
        for _ in range(num_layers-1):
            layers.append(EqualLinear(out_channel, out_channel, lr_mul=0.01, activation='fused_lrelu'))
        self.mapping = nn.Sequential(*layers)

    def forward(self, x):
        x = self.mapping(x)
        return x

class DeltaMapper(Module):

    def __init__(self):
        super(DeltaMapper, self).__init__()

        #Style Module(sm)
        self.sm_coarse = Mapper(512,  512)
        self.sm_medium = Mapper(512,  512)
        self.sm_fine   = Mapper(2464, 2464)

        #Condition Module(cm)
        self.cm_coarse = Mapper(1024, 512)
        self.cm_medium = Mapper(1024, 512)
        self.cm_fine   = Mapper(1024, 2464)

        #Fusion Module(fm)
        self.fm_coarse = Mapper(512*2,  512,  norm=False)
        self.fm_medium = Mapper(512*2,  512,  norm=False)
        self.fm_fine   = Mapper(2464*2, 2464, norm=False)
        
    def forward(self, sspace_feat, clip_feat):

        s_coarse = sspace_feat[:, :3*512].view(-1,3,512)
        s_medium = sspace_feat[:, 3*512:7*512].view(-1,4,512)
        s_fine   = sspace_feat[:, 7*512:] #channels:2464

        s_coarse = self.sm_coarse(s_coarse)
        s_medium = self.sm_medium(s_medium)
        s_fine   = self.sm_fine(s_fine)

        c_coarse = self.cm_coarse(clip_feat)
        c_medium = self.cm_medium(clip_feat)
        c_fine   = self.cm_fine(clip_feat)

        x_coarse = torch.cat([s_coarse, torch.stack([c_coarse]*3, dim=1)], dim=2) #[b,3,1024]
        x_medium = torch.cat([s_medium, torch.stack([c_medium]*4, dim=1)], dim=2) #[b,4,1024]
        x_fine   = torch.cat([s_fine, c_fine], dim=1) #[b,2464*2]

        x_coarse = self.fm_coarse(x_coarse)
        x_coarse = x_coarse.view(-1,3*512)

        x_medium = self.fm_medium(x_medium)
        x_medium = x_medium.view(-1,4*512)

        x_fine   = self.fm_fine(x_fine)

        out = torch.cat([x_coarse, x_medium, x_fine], dim=1)
        return out

================================================
FILE: generate_codes.py
================================================
import os
import argparse
import clip

import random
import numpy as np
import torch
from torchvision import utils
from utils import stylespace_util
from models.stylegan2.model import Generator

def save_image_pytorch(img, name):
    """Helper function to save torch tensor into an image file."""
    utils.save_image(
        img,
        name,
        nrow=1,
        padding=0,
        normalize=True,
        range=(-1, 1),  # note: newer torchvision releases renamed this argument to `value_range`
    )


def generate(args, netG, device, mean_latent):

    # Use the device passed in by the caller so that CLIP and the sampled
    # noise live on the same device as netG.
    model, preprocess = clip.load("ViT-B/32", device=device)
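    # Upsample by 7, then average-pool by size/32, so CLIP receives 224x224
    # inputs (e.g. 1024 * 7 / 32 = 224).
    avg_pool = torch.nn.AvgPool2d(kernel_size=args.size // 32)
    upsample = torch.nn.Upsample(scale_factor=7)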

    ind = 0
    with torch.no_grad():
        netG.eval()

        # Generate images from a file of input noises
        if args.fixed_z is not None:
            sample_z = torch.load(args.fixed_z, map_location=device)
            for start in range(0, sample_z.size(0), args.batch_size):
                end = min(start + args.batch_size, sample_z.size(0))
                z_batch = sample_z[start:end]
                sample, _ = netG([z_batch], truncation=args.truncation, truncation_latent=mean_latent)
                for s in sample:
                    save_image_pytorch(s, f'{args.save_dir}/{str(ind).zfill(6)}.png')
                    ind += 1
            return

        # Generate image by sampling input noises
        w_latents_list = []
        s_latents_list = []
        c_latents_list = []
        for start in range(0, args.samples, args.batch_size):
            end = min(start + args.batch_size, args.samples)
            batch_sz = end - start
            print(f'current_num:{start}')
            sample_z = torch.randn(batch_sz, 512, device=device)

            sample, w_latents = netG([sample_z], truncation=args.truncation, truncation_latent=mean_latent,return_latents=True)
            style_space, noise = stylespace_util.encoder_latent(netG, w_latents)
            s_latents = torch.cat(style_space, dim=1)

            tmp_imgs = stylespace_util.decoder(netG, style_space, w_latents, noise)
            # for s in tmp_imgs:
            #     save_image_pytorch(s, f'{args.save_dir}/{str(ind).zfill(6)}.png')
            #     ind += 1

            img_gen_for_clip = upsample(tmp_imgs)
            img_gen_for_clip = avg_pool(img_gen_for_clip)
            c_latents = model.encode_image(img_gen_for_clip)

            w_latents_list.append(w_latents)
            s_latents_list.append(s_latents)
            c_latents_list.append(c_latents)
        w_all_latents = torch.cat(w_latents_list, dim=0)
        s_all_latents = torch.cat(s_latents_list, dim=0)
        c_all_latents = torch.cat(c_latents_list, dim=0)

        print(w_all_latents.size())
        print(s_all_latents.size())
        print(c_all_latents.size())

        w_all_latents = w_all_latents.cpu().numpy()
        s_all_latents = s_all_latents.cpu().numpy()
        c_all_latents = c_all_latents.cpu().numpy()

        os.makedirs(os.path.join(args.save_dir, args.classname), exist_ok=True)
        np.save(f"{args.save_dir}/{args.classname}/wspace_noise_feat.npy", w_all_latents)
        np.save(f"{args.save_dir}/{args.classname}/sspace_noise_feat.npy", s_all_latents)
        np.save(f"{args.save_dir}/{args.classname}/cspace_noise_feat.npy", c_all_latents)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--classname', type=str, default='ffhq', help="dataset/class name; selects the StyleGAN2 checkpoint and the output subdirectory")
    parser.add_argument('--save_dir', type=str, default='./latent_code', help="place to save the output")
    parser.add_argument('--ckpt', type=str, default='./models/pretrained_models', help="checkpoint file for the generator")
    parser.add_argument('--size', type=int, default=1024, help="output size of the generator")
    parser.add_argument('--fixed_z', type=str, default=None, help="expect a .pth file. If given, will use this file as the input noise for the output")
    parser.add_argument('--w_shift', type=str, default=None, help="expect a .pth file. Apply a w-latent shift to the generator")
    parser.add_argument('--batch_size', type=int, default=10, help="batch size used to generate outputs")
    parser.add_argument('--samples', type=int, default=200000, help="number of samples to generate; ignored if --fixed_z is given")
    parser.add_argument('--truncation', type=float, default=1, help="truncation strength; values below 1 enable truncation (0.5 in the original setting)")
    parser.add_argument('--truncation_mean', type=int, default=4096, help="number of samples to calculate the mean latent for truncation")
    parser.add_argument('--seed', type=int, default=None, help="if specified, use a fixed random seed")
    parser.add_argument('--device', type=str, default='cuda')

    args = parser.parse_args()

    device = args.device
    # use a fixed seed if given
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    os.makedirs(args.save_dir, exist_ok=True)

    netG = Generator(args.size, 512, 8).to(device)
    if args.classname == 'ffhq':
        ckpt_path = os.path.join(args.ckpt,f'stylegan2-{args.classname}-config-f.pt')
    else:
        ckpt_path = os.path.join(args.ckpt,f'stylegan2-{args.classname}','netG.pth')
    print(ckpt_path)
    checkpoint = torch.load(ckpt_path, map_location='cpu')

    if args.classname == 'ffhq':
        netG.load_state_dict(checkpoint['g_ema'])
    else:
        netG.load_state_dict(checkpoint)

    # get mean latent if truncation is applied
    if args.truncation < 1:
        with torch.no_grad():
            mean_latent = netG.mean_latent(args.truncation_mean)
    else:
        mean_latent = None

    generate(args, netG, device, mean_latent)


================================================
FILE: models/encoders/__init__.py
================================================


================================================
FILE: models/encoders/helpers.py
================================================
from collections import namedtuple
import torch
import torch.nn.functional as F
from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module

"""
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""


class Flatten(Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


def l2_norm(input, axis=1):
    norm = torch.norm(input, 2, axis, True)
    output = torch.div(input, norm)
    return output


class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
    """ A named tuple describing a ResNet block. """


def get_block(in_channel, depth, num_units, stride=2):
    return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]


def get_blocks(num_layers):
    if num_layers == 50:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 100:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=13),
            get_block(in_channel=128, depth=256, num_units=30),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 152:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=8),
            get_block(in_channel=128, depth=256, num_units=36),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    else:
        raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
    return blocks


class SEModule(Module):
    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return module_input * x


class bottleneck_IR(Module):
    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth)
            )
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
        )

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


class bottleneck_IR_SE(Module):
    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR_SE, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth)
            )
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
            SEModule(depth, 16)
        )

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


def _upsample_add(x, y):
    """Upsample and add two feature maps.
    Args:
      x: (Variable) top feature map to be upsampled.
      y: (Variable) lateral feature map.
    Returns:
      (Variable) added feature map.
    Note that in PyTorch, when the input size is odd, the feature map upsampled
    with `F.upsample(..., scale_factor=2, mode='nearest')`
    may not match the lateral feature map size.
    e.g.
    original input size: [N,_,15,15] ->
    conv2d feature map size: [N,_,8,8] ->
    upsampled feature map size: [N,_,16,16]
    So we choose bilinear upsample which supports arbitrary output sizes.
    """
    _, _, H, W = y.size()
    return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y


================================================
FILE: models/encoders/model_irse.py
================================================
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm

"""
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""


class Backbone(Module):
    def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
        super(Backbone, self).__init__()
        assert input_size in [112, 224], "input_size should be 112 or 224"
        assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
        assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        if input_size == 112:
            self.output_layer = Sequential(BatchNorm2d(512),
                                           Dropout(drop_ratio),
                                           Flatten(),
                                           Linear(512 * 7 * 7, 512),
                                           BatchNorm1d(512, affine=affine))
        else:
            self.output_layer = Sequential(BatchNorm2d(512),
                                           Dropout(drop_ratio),
                                           Flatten(),
                                           Linear(512 * 14 * 14, 512),
                                           BatchNorm1d(512, affine=affine))

        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(unit_module(bottleneck.in_channel,
                                           bottleneck.depth,
                                           bottleneck.stride))
        self.body = Sequential(*modules)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_layer(x)
        return l2_norm(x)


def IR_50(input_size):
    """Constructs a ir-50 model."""
    model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
    return model


def IR_101(input_size):
    """Constructs a ir-101 model."""
    model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
    return model


def IR_152(input_size):
    """Constructs a ir-152 model."""
    model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
    return model


def IR_SE_50(input_size):
    """Constructs a ir_se-50 model."""
    model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
    return model


def IR_SE_101(input_size):
    """Constructs a ir_se-101 model."""
    model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
    return model


def IR_SE_152(input_size):
    """Constructs a ir_se-152 model."""
    model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
    return model


================================================
FILE: models/encoders/psp_encoders.py
================================================
from enum import Enum
import math
import numpy as np
import torch
from torch import nn
from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module

from models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add
from models.stylegan2.model import EqualLinear


class ProgressiveStage(Enum):
    WTraining = 0
    Delta1Training = 1
    Delta2Training = 2
    Delta3Training = 3
    Delta4Training = 4
    Delta5Training = 5
    Delta6Training = 6
    Delta7Training = 7
    Delta8Training = 8
    Delta9Training = 9
    Delta10Training = 10
    Delta11Training = 11
    Delta12Training = 12
    Delta13Training = 13
    Delta14Training = 14
    Delta15Training = 15
    Delta16Training = 16
    Delta17Training = 17
    Inference = 18


class GradualStyleBlock(Module):
    def __init__(self, in_c, out_c, spatial):
        super(GradualStyleBlock, self).__init__()
        self.out_c = out_c
        self.spatial = spatial
        num_pools = int(np.log2(spatial))
        modules = []
        modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
                    nn.LeakyReLU()]
        for i in range(num_pools - 1):
            modules += [
                Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
                nn.LeakyReLU()
            ]
        self.convs = nn.Sequential(*modules)
        self.linear = EqualLinear(out_c, out_c, lr_mul=1)

    def forward(self, x):
        x = self.convs(x)
        x = x.view(-1, self.out_c)
        x = self.linear(x)
        return x


class GradualStyleEncoder(Module):
    def __init__(self, num_layers, mode='ir', opts=None):
        super(GradualStyleEncoder, self).__init__()
        assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(unit_module(bottleneck.in_channel,
                                           bottleneck.depth,
                                           bottleneck.stride))
        self.body = Sequential(*modules)

        self.styles = nn.ModuleList()
        log_size = int(math.log(opts.stylegan_size, 2))
        self.style_count = 2 * log_size - 2
        self.coarse_ind = 3
        self.middle_ind = 7
        for i in range(self.style_count):
            if i < self.coarse_ind:
                style = GradualStyleBlock(512, 512, 16)
            elif i < self.middle_ind:
                style = GradualStyleBlock(512, 512, 32)
            else:
                style = GradualStyleBlock(512, 512, 64)
            self.styles.append(style)
        self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.input_layer(x)

        latents = []
        modulelist = list(self.body._modules.values())
        for i, l in enumerate(modulelist):
            x = l(x)
            if i == 6:
                c1 = x
            elif i == 20:
                c2 = x
            elif i == 23:
                c3 = x

        for j in range(self.coarse_ind):
            latents.append(self.styles[j](c3))

        p2 = _upsample_add(c3, self.latlayer1(c2))
        for j in range(self.coarse_ind, self.middle_ind):
            latents.append(self.styles[j](p2))

        p1 = _upsample_add(p2, self.latlayer2(c1))
        for j in range(self.middle_ind, self.style_count):
            latents.append(self.styles[j](p1))

        out = torch.stack(latents, dim=1)
        return out


class Encoder4Editing(Module):
    def __init__(self, num_layers, stylegan_size, mode='ir'):
        super(Encoder4Editing, self).__init__()
        assert num_layers in [50, 100, 152], 'num_layers should be 50, 100, or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(unit_module(bottleneck.in_channel,
                                           bottleneck.depth,
                                           bottleneck.stride))
        self.body = Sequential(*modules)

        self.styles = nn.ModuleList()
        log_size = int(math.log(stylegan_size, 2))
        self.style_count = 2 * log_size - 2
        self.coarse_ind = 3
        self.middle_ind = 7

        for i in range(self.style_count):
            if i < self.coarse_ind:
                style = GradualStyleBlock(512, 512, 16)
            elif i < self.middle_ind:
                style = GradualStyleBlock(512, 512, 32)
            else:
                style = GradualStyleBlock(512, 512, 64)
            self.styles.append(style)

        self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)

        self.progressive_stage = ProgressiveStage.Inference

    def get_deltas_starting_dimensions(self):
        ''' Get a list of the initial dimension of every delta from which it is applied '''
        return list(range(self.style_count))  # Each dimension has a delta applied to it

    def set_progressive_stage(self, new_stage: ProgressiveStage):
        self.progressive_stage = new_stage
        print('Changed progressive stage to: ', new_stage)

    def forward(self, x):
        x = self.input_layer(x)

        modulelist = list(self.body._modules.values())
        for i, l in enumerate(modulelist):
            x = l(x)
            if i == 6:
                c1 = x
            elif i == 20:
                c2 = x
            elif i == 23:
                c3 = x

        # Infer main W and duplicate it
        w0 = self.styles[0](c3)
        w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2)
        stage = self.progressive_stage.value
        features = c3
        for i in range(1, min(stage + 1, self.style_count)):  # Infer additional deltas
            if i == self.coarse_ind:
                p2 = _upsample_add(c3, self.latlayer1(c2))  # FPN's middle features
                features = p2
            elif i == self.middle_ind:
                p1 = _upsample_add(p2, self.latlayer2(c1))  # FPN's fine features
                features = p1
            delta_i = self.styles[i](features)
            w[:, i] += delta_i
        return w


class BackboneEncoderUsingLastLayerIntoW(Module):
    def __init__(self, num_layers, mode='ir', opts=None):
        super(BackboneEncoderUsingLastLayerIntoW, self).__init__()
        print('Using BackboneEncoderUsingLastLayerIntoW')
        assert num_layers in [50, 100, 152], 'num_layers should be 50, 100, or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.linear = EqualLinear(512, 512, lr_mul=1)
        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(unit_module(bottleneck.in_channel,
                                           bottleneck.depth,
                                           bottleneck.stride))
        self.body = Sequential(*modules)
        log_size = int(math.log(opts.stylegan_size, 2))
        self.style_count = 2 * log_size - 2

    def forward(self, x):
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_pool(x)
        x = x.view(-1, 512)
        x = self.linear(x)
        return x.repeat(self.style_count, 1, 1).permute(1, 0, 2)


================================================
FILE: models/stylegan2/__init__.py
================================================


================================================
FILE: models/stylegan2/model.py
================================================
import math
import random

import torch
from torch import nn
from torch.nn import functional as F

from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d

class PixelNorm(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)


def make_kernel(k):
    k = torch.tensor(k, dtype=torch.float32)

    if k.ndim == 1:
        k = k[None, :] * k[:, None]

    k /= k.sum()

    return k


class Upsample(nn.Module):
    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel) * (factor ** 2)
        self.register_buffer('kernel', kernel)

        p = kernel.shape[0] - factor

        pad0 = (p + 1) // 2 + factor - 1
        pad1 = p // 2

        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)

        return out


class Downsample(nn.Module):
    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel)
        self.register_buffer('kernel', kernel)

        p = kernel.shape[0] - factor

        pad0 = (p + 1) // 2
        pad1 = p // 2

        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)

        return out


class Blur(nn.Module):
    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        kernel = make_kernel(kernel)

        if upsample_factor > 1:
            kernel = kernel * (upsample_factor ** 2)

        self.register_buffer('kernel', kernel)

        self.pad = pad

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, pad=self.pad)

        return out


class EqualConv2d(nn.Module):
    def __init__(
        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
    ):
        super().__init__()

        self.weight = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size)
        )
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

        self.stride = stride
        self.padding = padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))

        else:
            self.bias = None

    def forward(self, input):
        out = F.conv2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

        return out

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
            f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
        )


class EqualLinear(nn.Module):
    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))

        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

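    def forward(self, input):
        if self.activation:
            # fused_leaky_relu needs a bias tensor, so this path assumes bias=True.
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, self.bias * self.lr_mul)

        else:
            # Guard against bias=False, where self.bias is None.
            bias = None if self.bias is None else self.bias * self.lr_mul
            out = F.linear(input, self.weight * self.scale, bias=bias)

        return out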

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
        )


class ScaledLeakyReLU(nn.Module):
    def __init__(self, negative_slope=0.2):
        super().__init__()

        self.negative_slope = negative_slope

    def forward(self, input):
        out = F.leaky_relu(input, negative_slope=self.negative_slope)

        return out * math.sqrt(2)


class ModulatedConv2d(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )

        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

        self.demodulate = demodulate

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
            f'upsample={self.upsample}, downsample={self.downsample})'
        )

    def forward(self, input, style):
        batch, in_channel, height, width = input.shape

        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
        weight = self.scale * self.weight * style

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

        weight = weight.view(
            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
        )

        if self.upsample:
            input = input.view(1, batch * in_channel, height, width)
            weight = weight.view(
                batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
            )
            weight = weight.transpose(1, 2).reshape(
                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
            )
            out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
            out = self.blur(out)

        elif self.downsample:
            input = self.blur(input)
            _, _, height, width = input.shape
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        else:
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        return out


class NoiseInjection(nn.Module):
    def __init__(self):
        super().__init__()

        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, image, noise=None):
        if noise is None:
            batch, _, height, width = image.shape
            noise = image.new_empty(batch, 1, height, width).normal_()

        return image + self.weight * noise


class ConstantInput(nn.Module):
    def __init__(self, channel, size=4):
        super().__init__()

        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, input):
        batch = input.shape[0]
        out = self.input.repeat(batch, 1, 1, 1)

        return out


class StyledConv(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        upsample=False,
        blur_kernel=[1, 3, 3, 1],
        demodulate=True,
    ):
        super().__init__()

        self.conv = ModulatedConv2d(
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=upsample,
            blur_kernel=blur_kernel,
            demodulate=demodulate,
        )

        self.noise = NoiseInjection()
        # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
        # self.activate = ScaledLeakyReLU(0.2)
        self.activate = FusedLeakyReLU(out_channel)

    def forward(self, input, style, noise=None):
        out = self.conv(input, style)
        out = self.noise(out, noise=noise)
        # out = out + self.bias
        out = self.activate(out)

        return out


class ToRGB(nn.Module):
    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        if upsample:
            self.upsample = Upsample(blur_kernel)

        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, input, style, skip=None):
        out = self.conv(input, style)
        out = out + self.bias

        if skip is not None:
            skip = self.upsample(skip)

            out = out + skip

        return out


class Generator(nn.Module):
    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=2,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
    ):
        super().__init__()

        self.size = size

        self.style_dim = style_dim

        layers = [PixelNorm()]

        for i in range(n_mlp):
            layers.append(
                EqualLinear(
                    style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
                )
            )

        self.style = nn.Sequential(*layers)

        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(
            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
        )
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        self.log_size = int(math.log(size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channel = self.channels[4]

        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2 ** res, 2 ** res]
            self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))

        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** i]

            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                )
            )

            self.convs.append(
                StyledConv(
                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
                )
            )

            self.to_rgbs.append(ToRGB(out_channel, style_dim))

            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2

    def make_noise(self):
        device = self.input.input.device

        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))

        return noises

    def mean_latent(self, n_latent):
        latent_in = torch.randn(
            n_latent, self.style_dim, device=self.input.input.device
        )
        latent = self.style(latent_in).mean(0, keepdim=True)

        return latent

    def get_latent(self, input):
        return self.style(input)

    def forward(
        self,
        styles,
        return_latents=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
                ]

        if truncation < 1:
            style_t = []

            for style in styles:
                style_t.append(
                    truncation_latent + truncation * (style - truncation_latent)
                )

            styles = style_t

        if len(styles) < 2:
            inject_index = self.n_latent

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)

            else:
                latent = styles[0]

        else:
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])

        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)

            i += 2

        image = skip

        if return_latents:
            return image, latent

        else:
            return image, None


class ConvLayer(nn.Sequential):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

            stride = 2
            self.padding = 0

        else:
            stride = 1
            self.padding = kernel_size // 2

        layers.append(
            EqualConv2d(
                in_channel,
                out_channel,
                kernel_size,
                padding=self.padding,
                stride=stride,
                bias=bias and not activate,
            )
        )

        if activate:
            if bias:
                layers.append(FusedLeakyReLU(out_channel))

            else:
                layers.append(ScaledLeakyReLU(0.2))

        super().__init__(*layers)


class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

        self.skip = ConvLayer(
            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
        )

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        skip = self.skip(input)
        out = (out + skip) / math.sqrt(2)

        return out


class Discriminator(nn.Module):
    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        convs = [ConvLayer(3, channels[size], 1)]

        log_size = int(math.log(size, 2))

        in_channel = channels[size]

        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]

            convs.append(ResBlock(in_channel, out_channel, blur_kernel))

            in_channel = out_channel

        self.convs = nn.Sequential(*convs)

        self.stddev_group = 4
        self.stddev_feat = 1

        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
            EqualLinear(channels[4], 1),
        )

    def forward(self, input):
        out = self.convs(input)

        batch, channel, height, width = out.shape
        group = min(batch, self.stddev_group)
        stddev = out.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)
        out = torch.cat([out, stddev], 1)

        out = self.final_conv(out)

        out = out.view(batch, -1)
        out = self.final_linear(out)

        return out



================================================
FILE: models/stylegan2/op/__init__.py
================================================
from .fused_act import FusedLeakyReLU, fused_leaky_relu
from .upfirdn2d import upfirdn2d


================================================
FILE: models/stylegan2/op/fused_act.py
================================================
import os

import torch
from torch import nn
from torch.nn import functional as F

module_path = os.path.dirname(__file__)

class FusedLeakyReLU(nn.Module):
    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()

        self.bias = nn.Parameter(torch.zeros(channel))
        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)


def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    rest_dim = [1] * (input.ndim - bias.ndim - 1)
    # Follow the bias device instead of forcing CUDA, so CPU-only runs also work.
    input = input.to(bias.device)
    if input.ndim == 3:
        return (
            F.leaky_relu(
                input + bias.view(1, *rest_dim, bias.shape[0]), negative_slope=negative_slope
            )
            * scale
        )
    else:
        return (
            F.leaky_relu(
                input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
            )
            * scale
        )



================================================
FILE: models/stylegan2/op/upfirdn2d.py
================================================
import os
import torch
from torch.nn import functional as F

module_path = os.path.dirname(__file__)

def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])

    return out

def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
    
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
    )
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)

================================================
FILE: options/test_options.py
================================================
from argparse import ArgumentParser

class TestOptions:

	def __init__(self):
		self.parser = ArgumentParser()
		self.initialize()

	def initialize(self):
		# arguments for inference script
		
		self.parser.add_argument('--batch_size', default=1, type=int, help='Batch size for inference')
		self.parser.add_argument('--workers', default=4, type=int, help='Number of test dataloader workers')
		
		self.parser.add_argument('--stylegan_weights', default='models/pretrained_models/stylegan2-ffhq-config-f.pt', type=str, help='Path to StyleGAN model weights')
		self.parser.add_argument('--stylegan_size', default=1024, type=int)
		
		self.parser.add_argument("--threshold", type=float, default=0.03)
		self.parser.add_argument("--checkpoint_path", type=str, default='checkpoints/net_face.pth')
		self.parser.add_argument("--save_dir", type=str, default='output')
		self.parser.add_argument("--num_all", type=int, default=20)
		
		self.parser.add_argument("--target", type=str, required=True, help='Specify the target attributes to be edited')

	def parse(self):
		opts = self.parser.parse_args()
		return opts

================================================
FILE: options/train_options.py
================================================
from argparse import ArgumentParser

class TrainOptions:

	def __init__(self):
		self.parser = ArgumentParser()
		self.initialize()

	def initialize(self):

		self.parser.add_argument('--batch_size', default=64, type=int, help='Batch size for training')
		self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers')

		self.parser.add_argument('--learning_rate', default=0.5, type=float, help='Optimizer learning rate')

		self.parser.add_argument('--l2_lambda', default=1.0, type=float, help='l2 loss')
		self.parser.add_argument('--cos_lambda', default=1.0, type=float, help='cos loss')

		self.parser.add_argument('--checkpoint_path', default='checkpoints', type=str, help='Path to StyleCLIPModel model checkpoint')
		self.parser.add_argument('--classname', type=str, default='ffhq', help="which specific domain for training")
		self.parser.add_argument('--print_interval', default=1000, type=int, help='Interval for printing loss values during training')
		self.parser.add_argument('--val_interval', default=5000, type=int, help='Validation interval')
		self.parser.add_argument('--save_interval', default=10000, type=int, help='Model checkpoint interval')

	def parse(self):
		opts = self.parser.parse_args()
		return opts

================================================
FILE: scripts/inference.py
================================================
import os
import sys
sys.path.append(".")
sys.path.append("..")

import copy
import clip
import numpy as np

import torch
import torchvision
from torch.utils.data import DataLoader

import torch.nn.functional as F

from datasets.test_dataset import TestLatentsDataset

from models.stylegan2.model import Generator
from delta_mapper import DeltaMapper

from options.test_options import TestOptions

from utils import map_tool
from utils import stylespace_util

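def GetBoundary(fs3, dt, threshold):
    # Project the text direction onto the relevance matrix and flag the
    # channels whose absolute response falls below the threshold.
    tmp = np.dot(fs3, dt)
    select = np.abs(tmp) < threshold
    return select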

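def improved_ds(ds, select):
    # Zero out the style channels flagged by `select`, working on a real copy
    # so the mapper output itself is left untouched.
    ds_imp = ds.clone()
    ds_imp[select] = 0
    ds_imp = ds_imp.unsqueeze(0)
    return ds_imp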

def main(opts):

    device = "cuda" if torch.cuda.is_available() else "cpu"

    #Initialize test dataset
    test_dataset = TestLatentsDataset()
    test_dataloader = DataLoader(test_dataset, 
                                 batch_size=opts.batch_size,
                                 shuffle=False,
                                 num_workers=int(opts.workers),
                                 drop_last=True)

    #Initialize generator
    print('Loading stylegan weights from pretrained!')
    g_ema = Generator(size=opts.stylegan_size, style_dim=512, n_mlp=8)
    g_ema_ckpt = torch.load(opts.stylegan_weights, map_location=device)
    g_ema.load_state_dict(g_ema_ckpt['g_ema'], strict=False)
    g_ema.eval()
    g_ema = g_ema.to(device)

    #load relevance matrix Rs
    fs3=np.load('./models/stylegan2/npy_ffhq/fs3.npy')
    np.set_printoptions(suppress=True)

    #Initialize DeltaMapper
    net = DeltaMapper()
    net_ckpt = torch.load(opts.checkpoint_path, map_location=device)
    net.load_state_dict(net_ckpt)
    net = net.to(device)
    
    #Load CLIP model
    clip_model, preprocess = clip.load("ViT-B/32", device=device)

    os.makedirs(opts.save_dir, exist_ok=True)

    neutral='face'
    target_list = opts.target.split(',')
    # print(target_list)

    dt_list = []
    select_list = []
    for target in target_list:
        classnames=[target,neutral]
        dt = map_tool.GetDt(classnames,clip_model)
        select = GetBoundary(fs3, dt, opts.threshold)
        dt = torch.Tensor(dt).to(device)
        dt = dt / dt.norm(dim=-1, keepdim=True).float().clamp(min=1e-5)

        select_list.append(select)
        dt_list.append(dt)

    for bid, batch in enumerate(test_dataloader):
        if bid == opts.num_all:
            break
        
        latent_s, delta_c, latent_w = batch
        latent_s = latent_s.to(device)
        delta_c = delta_c.to(device)
        latent_w = latent_w.to(device)
        delta_s_list = []

        for i, dt in enumerate(dt_list):
            # Only the first sample is edited here, so batch_size is expected to be 1.
            delta_c[0, 512:] = dt
            with torch.no_grad():
                fake_delta_s = net(latent_s, delta_c)
                improved_fake_delta_s = improved_ds(fake_delta_s[0], select_list[i])
            delta_s_list.append(improved_fake_delta_s)

        with torch.no_grad():
            img_ori = stylespace_util.decoder_validate(g_ema, latent_s, latent_w)

            img_list = [img_ori]
            for delta_s in delta_s_list:
                img_gen = stylespace_util.decoder_validate(g_ema, latent_s + delta_s, latent_w)
                img_list.append(img_gen)
            img_gen_all = torch.cat(img_list, dim=3)
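            # `value_range` replaces the `range` keyword, which newer torchvision releases no longer accept
            torchvision.utils.save_image(img_gen_all, os.path.join(opts.save_dir, "%04d.jpg" %(bid+1)), normalize=True, value_range=(-1, 1))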
    print(f'completed👍! Please check results in {opts.save_dir}')

if __name__ == "__main__":
    opts = TestOptions().parse()
    main(opts)

================================================
FILE: scripts/inference_real.py
================================================
import os
import sys
sys.path.append(".")
sys.path.append("..")

import copy
import clip
import numpy as np
from PIL import Image

import torch
import torchvision
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms

import torch.nn.functional as F


from datasets.test_dataset import TestLatentsDataset

from models.stylegan2.model import Generator
from models.encoders import psp_encoders
from delta_mapper import DeltaMapper

from options.test_options import TestOptions

from utils import map_tool
from utils import stylespace_util



def get_keys(d, name):
    if 'state_dict' in d:
        d = d['state_dict']
    d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
    return d_filt

class Imagedataset(Dataset):
    def __init__(self,
                 path,
                 image_size=256,
                 split=None):

        self.path = path
        self.images = os.listdir(path)

        self.image_size = image_size

        self.length = len(self.images)

        transform = [
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]

        self.transform = transforms.Compose(transform)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        cur_name = self.images[index]
        img_path = os.path.join(self.path, cur_name)

        img = Image.open(img_path).convert("RGB") 

        if self.transform is not None:
            img = self.transform(img)
        return img

def encoder_latent(G, latent):
    # an encoder wrapper for G
    #styles = [noise]
    style_space = []
    
    #styles = [G.style(s) for s in styles]
    noise = [getattr(G.noises, 'noise_{}'.format(i)) for i in range(G.num_layers)]
    # inject_index = G.n_latent
    #latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
    style_space.append(G.conv1.conv.modulation(latent[:, 0]))

    i = 1
    for conv1, conv2, to_rgb in zip(
        G.convs[::2], G.convs[1::2], G.to_rgbs
    ):
        style_space.append(conv1.conv.modulation(latent[:, i]))
        style_space.append(conv2.conv.modulation(latent[:, i+1]))
        i += 2
        
    return style_space, noise

def GetBoundary(fs3,dt,threshold):
    tmp=np.dot(fs3,dt)
    
    select=np.abs(tmp)<threshold
    return select

def improved_ds(ds, select):
    # clone() gives the mask its own storage, so zeroing channels cannot alter the mapper output
    ds_imp = ds.clone()
    ds_imp[select] = 0
    ds_imp = ds_imp.unsqueeze(0)
    return ds_imp
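
# --- Illustrative sketch (not part of the original script) ---
# GetBoundary() and improved_ds() together act as a relevance mask: a
# StyleSpace channel keeps its predicted delta only when the absolute value of
# its relevance to the text direction (one entry of fs3 @ dt) reaches
# `threshold`. The helper below restates that rule on plain Python lists.
# `mask_irrelevant_channels` is a hypothetical name introduced for
# illustration only; nothing in this repository calls it.
def mask_irrelevant_channels(delta_s, relevance, threshold):
    """Zero every entry of delta_s whose |relevance| is below threshold."""
    return [d if abs(r) >= threshold else 0.0
            for d, r in zip(delta_s, relevance)]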

def main(opts):

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE load e4e
    checkpoint_path = "encoder4editing-main/e4e_ffhq_encode.pt"
    ckpt_enc = torch.load(checkpoint_path, map_location='cpu') #dict_keys(['state_dict', 'latent_avg', 'opts'])
    encoder = psp_encoders.Encoder4Editing(50, 1024, 'ir_se')
    encoder.load_state_dict(get_keys(ckpt_enc, 'encoder'), strict=True)
    encoder.eval()
    encoder.to(device)

    #Initialize test dataset
    test_dataset = Imagedataset('./test_imgs', image_size=256)
    test_dataloader = DataLoader(test_dataset, 
                                 batch_size=opts.batch_size,
                                 shuffle=False,
                                 num_workers=int(opts.workers),
                                 drop_last=True)

    #Initialize generator
    print('Loading stylegan weights from pretrained!')
    g_ema = Generator(size=opts.stylegan_size, style_dim=512, n_mlp=8)
    g_ema_ckpt = torch.load(opts.stylegan_weights)
    g_ema.load_state_dict(g_ema_ckpt['g_ema'], strict=False)
    g_ema.eval()
    g_ema = g_ema.to(device)

    #load relevance matrix Rs
    fs3=np.load('./models/stylegan2/npy_ffhq/fs3.npy')
    np.set_printoptions(suppress=True)

    #Initialize DeltaMapper
    net = DeltaMapper()
    net_ckpt = torch.load(opts.checkpoint_path)
    net.load_state_dict(net_ckpt)
    net = net.to(device)
    
    #Load CLIP model
    clip_model, preprocess = clip.load("ViT-B/32", device=device)
    avg_pool = torch.nn.AvgPool2d(kernel_size=256//32)
    upsample = torch.nn.Upsample(scale_factor=7)

    os.makedirs(opts.save_dir, exist_ok=True)

    neutral='face'
    target_list = opts.target.split(',')
    # print(target_list)

    dt_list = []
    select_list = []
    for target in target_list:
        classnames=[target,neutral]
        dt = map_tool.GetDt(classnames,clip_model)
        select = GetBoundary(fs3, dt, opts.threshold)
        dt = torch.Tensor(dt).to(device)
        dt = dt / dt.norm(dim=-1, keepdim=True).float().clamp(min=1e-5)

        select_list.append(select)
        dt_list.append(dt)

    for bid, batch in enumerate(test_dataloader):
        if bid == opts.num_all:
            break
        input_img = batch.to(device)
        with torch.no_grad():
            latent_w = encoder(input_img)
            # follow the device chosen at the top of main() so the CPU fallback also works
            latent_avg = ckpt_enc['latent_avg'].to(device)
            latent_w = latent_w + latent_avg.repeat(latent_w.shape[0], 1, 1)

            style_space, noise = encoder_latent(g_ema, latent_w)
            latent_s = torch.cat(style_space, dim=1)

            img_gen_for_clip = upsample(input_img)
            img_gen_for_clip = avg_pool(img_gen_for_clip)
            c_latents = clip_model.encode_image(img_gen_for_clip)
            c_latents = c_latents / c_latents.norm(dim=-1, keepdim=True).float()

        delta_s_list = []

        for i, dt in enumerate(dt_list):
            delta_c = torch.cat((c_latents, dt.unsqueeze(0)), dim=1)
            with torch.no_grad():
                fake_delta_s = net(latent_s, delta_c)
                improved_fake_delta_s = improved_ds(fake_delta_s[0], select_list[i])
            delta_s_list.append(improved_fake_delta_s)

        with torch.no_grad():
            img_ori = stylespace_util.decoder_validate(g_ema, latent_s, latent_w)

            img_list = [img_ori]
            for delta_s in delta_s_list:
                img_gen = stylespace_util.decoder_validate(g_ema, latent_s + delta_s, latent_w)
                img_list.append(img_gen)
            img_gen_all = torch.cat(img_list, dim=3)
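            # `value_range` replaces the `range` keyword, which newer torchvision releases no longer accept
            torchvision.utils.save_image(img_gen_all, os.path.join(opts.save_dir, "%04d.jpg" %(bid+1)), normalize=True, value_range=(-1, 1))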
    print(f'completed👍! Please check results in {opts.save_dir}')

if __name__ == "__main__":
    opts = TestOptions().parse()
    main(opts)


================================================
FILE: scripts/train.py
================================================
import os
import sys

import torch
from torch.utils.data import DataLoader

sys.path.append(".")
sys.path.append("..")

from datasets.train_dataset import TrainLatentsDataset
from options.train_options import TrainOptions
from delta_mapper import DeltaMapper

def main(opts):

    device = "cuda" if torch.cuda.is_available() else "cpu"

    train_dataset = TrainLatentsDataset(opts)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=opts.batch_size,
                                  shuffle=True,
                                  num_workers=int(opts.workers),
                                  drop_last=True)

    #Initialize DeltaMapper
    net = DeltaMapper().to(device)

    #Initialize optimizer
    optimizer = torch.optim.Adam(list(net.parameters()), lr=opts.learning_rate)

    #Initialize loss
    l2_loss = torch.nn.MSELoss().to(device)
    cosine_loss = torch.nn.CosineSimilarity(dim=-1).to(device)

    #save dir
    os.makedirs(os.path.join(opts.checkpoint_path, opts.classname), exist_ok=True)

    for batch_idx, batch in enumerate(train_dataloader):

        latent_s, delta_c, delta_s = batch
        latent_s = latent_s.to(device)
        delta_c = delta_c.to(device)
        delta_s = delta_s.to(device)

        fake_delta_s = net(latent_s, delta_c)

        optimizer.zero_grad()
        loss_l2 = l2_loss(fake_delta_s, delta_s)
        loss_cos = 1 - torch.mean(cosine_loss(fake_delta_s, delta_s))

        loss = opts.l2_lambda * loss_l2 + opts.cos_lambda * loss_cos
        loss.backward()
        optimizer.step()

        if batch_idx % opts.print_interval == 0 :
            print(batch_idx, loss.detach().cpu().numpy(), loss_l2.detach().cpu().numpy(), loss_cos.detach().cpu().numpy())

        if batch_idx % opts.save_interval == 0:
            torch.save(net.state_dict(), os.path.join(opts.checkpoint_path, opts.classname, "net_%06d.pth" % batch_idx))

if __name__ == "__main__":
    opts = TrainOptions().parse()
    main(opts)
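
# --- Illustrative sketch (not part of the original script) ---
# The objective in main() is l2_lambda * MSE + cos_lambda * (1 - mean cosine
# similarity over the batch). The function below restates it on nested Python
# lists, assuming equal-length, non-zero rows. `reference_delta_loss` is a
# hypothetical name used only for illustration, and the default weights of 1.0
# are placeholders rather than the values defined in TrainOptions.
import math

def reference_delta_loss(pred, target, l2_lambda=1.0, cos_lambda=1.0):
    """Return (total, l2, cos) for a batch given as lists of rows."""
    n_elems = sum(len(row) for row in pred)
    # Mean squared error over every element, as torch.nn.MSELoss does by default.
    l2 = sum((p - t) ** 2
             for p_row, t_row in zip(pred, target)
             for p, t in zip(p_row, t_row)) / n_elems
    # Cosine similarity per row, then 1 - mean, as in the training loop.
    cos_sims = []
    for p_row, t_row in zip(pred, target):
        dot = sum(p * t for p, t in zip(p_row, t_row))
        norm = math.sqrt(sum(p * p for p in p_row)) * math.sqrt(sum(t * t for t in t_row))
        cos_sims.append(dot / norm)
    cos = 1.0 - sum(cos_sims) / len(cos_sims)
    return l2_lambda * l2 + cos_lambda * cos, l2, cos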

================================================
FILE: tSNE/compute_tsne.py
================================================
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

data_name = 'celeba' # 'celeba' or 'cocoval'

#for img/text
cspace_img = np.load(f'./{data_name}/cspace_{data_name}_i.npy')
cspace_text = np.load(f'./{data_name}/cspace_{data_name}_t.npy')

#for deltaimg
cspace_deltaimg = np.load(f'./{data_name}/cspace_{data_name}_deltai.npy')
cspace_deltatext = np.load(f'./{data_name}/cspace_{data_name}_deltat.npy')

num=1000

data_ori = np.concatenate([cspace_img[:num], cspace_text[:num]], axis=0)
data_delta = np.concatenate([cspace_deltaimg[:num], cspace_deltatext[:num]], axis=0)

tsne = TSNE(n_components=2, init='pca')

result_ori = tsne.fit_transform(data_ori)
result_delta = tsne.fit_transform(data_delta)

# Normalize the embedding to [0, 1] once, then draw each half with a single scatter call.
x_min, x_max = np.min(result_ori, 0), np.max(result_ori, 0)
data = (result_ori - x_min) / (x_max - x_min)
half = result_ori.shape[0] // 2
s0 = plt.scatter(data[:half, 0], data[:half, 1], color=plt.cm.Set1(0/4), s=12, marker='o')
s1 = plt.scatter(data[half:, 0], data[half:, 1], color=plt.cm.Set1(1/4), s=12, marker='o')
    
plt.legend((s0, s1), ('CLIP Image Space', 'CLIP Text Space'), fontsize=10)
plt.xticks()
plt.yticks()
plt.title('t-SNE Results')
plt.tight_layout()
plt.savefig(f'tSNE-{data_name}-{num}_ori.png')

plt.close()

# Normalize the embedding to [0, 1] once, then draw each half with a single scatter call.
x_min, x_max = np.min(result_delta, 0), np.max(result_delta, 0)
data = (result_delta - x_min) / (x_max - x_min)
half = result_delta.shape[0] // 2
s0 = plt.scatter(data[:half, 0], data[:half, 1], color=plt.cm.Set1(2/4), s=12, marker='o')
s1 = plt.scatter(data[half:, 0], data[half:, 1], color=plt.cm.Set1(3/4), s=12, marker='o')
    
plt.legend((s0, s1), ('CLIP Delta Image Space', 'CLIP Delta Text Space'), fontsize=10)
plt.xticks()
plt.yticks()
plt.title('t-SNE Results')
plt.tight_layout()
plt.savefig(f'tSNE-{data_name}-{num}_delta.png')

================================================
FILE: utils/map_tool.py
================================================
import torch
import clip
import os
import numpy as np

imagenet_templates = [
    'a bad photo of a {}.',
#    'a photo of many {}.',
    'a sculpture of a {}.',
    'a photo of the hard to see {}.',
    'a low resolution photo of the {}.',
    'a rendering of a {}.',
    'graffiti of a {}.',
    'a bad photo of the {}.',
    'a cropped photo of the {}.',
    'a tattoo of a {}.',
    'the embroidered {}.',
    'a photo of a hard to see {}.',
    'a bright photo of a {}.',
    'a photo of a clean {}.',
    'a photo of a dirty {}.',
    'a dark photo of the {}.',
    'a drawing of a {}.',
    'a photo of my {}.',
    'the plastic {}.',
    'a photo of the cool {}.',
    'a close-up photo of a {}.',
    'a black and white photo of the {}.',
    'a painting of the {}.',
    'a painting of a {}.',
    'a pixelated photo of the {}.',
    'a sculpture of the {}.',
    'a bright photo of the {}.',
    'a cropped photo of a {}.',
    'a plastic {}.',
    'a photo of the dirty {}.',
    'a jpeg corrupted photo of a {}.',
    'a blurry photo of the {}.',
    'a photo of the {}.',
    'a good photo of the {}.',
    'a rendering of the {}.',
    'a {} in a video game.',
    'a photo of one {}.',
    'a doodle of a {}.',
    'a close-up photo of the {}.',
    'a photo of a {}.',
    'the origami {}.',
    'the {} in a video game.',
    'a sketch of a {}.',
    'a doodle of the {}.',
    'a origami {}.',
    'a low resolution photo of a {}.',
    'the toy {}.',
    'a rendition of the {}.',
    'a photo of the clean {}.',
    'a photo of a large {}.',
    'a rendition of a {}.',
    'a photo of a nice {}.',
    'a photo of a weird {}.',
    'a blurry photo of a {}.',
    'a cartoon {}.',
    'art of a {}.',
    'a sketch of the {}.',
    'a embroidered {}.',
    'a pixelated photo of a {}.',
    'itap of the {}.',
    'a jpeg corrupted photo of the {}.',
    'a good photo of a {}.',
    'a plushie {}.',
    'a photo of the nice {}.',
    'a photo of the small {}.',
    'a photo of the weird {}.',
    'the cartoon {}.',
    'art of the {}.',
    'a drawing of the {}.',
    'a photo of the large {}.',
    'a black and white photo of a {}.',
    'the plushie {}.',
    'a dark photo of a {}.',
    'itap of a {}.',
    'graffiti of the {}.',
    'a toy {}.',
    'itap of my {}.',
    'a photo of a cool {}.',
    'a photo of a small {}.',
    'a tattoo of the {}.',
]

def zeroshot_classifier(classnames, templates,model):
    # Use the device the CLIP model already lives on instead of assuming CUDA.
    device = next(model.parameters()).device
    with torch.no_grad():
        zeroshot_weights = []
        for classname in classnames:
            texts = [template.format(classname) for template in templates] #format with class
            texts = clip.tokenize(texts).to(device) #tokenize
            class_embeddings = model.encode_text(texts) #embed with text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
    return zeroshot_weights

def GetDt(classnames,model):
    text_features=zeroshot_classifier(classnames, imagenet_templates,model).t()
    
    dt=text_features[0]-text_features[1]
    dt=dt.cpu().numpy()
    
    return dt


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    neutral='face with eyes' #@param {type:"string"}
    target='face with blue eyes' #@param {type:"string"}
    classnames=[target,neutral]
    dt = GetDt(classnames,model)
    print(dt.shape)

================================================
FILE: utils/stylespace_util.py
================================================
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision

from torch.nn import functional as F

index = [0,1,1,2,2,3,4,4,5,6,6,7,8,8,9,10,10,11,12,12,13,14,14,15,16,16]

def conv_warper(layer, input, style, noise):

    conv = layer.conv
    batch, in_channel, height, width = input.shape

    style = style.view(batch, 1, in_channel, 1, 1)
    weight = conv.scale * conv.weight * style

    if conv.demodulate:
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
        weight = weight * demod.view(batch, conv.out_channel, 1, 1, 1)

    weight = weight.view(
        batch * conv.out_channel, in_channel, conv.kernel_size, conv.kernel_size
    )

    if conv.upsample:
        input = input.view(1, batch * in_channel, height, width)
        weight = weight.view(
            batch, conv.out_channel, in_channel, conv.kernel_size, conv.kernel_size
        )
        weight = weight.transpose(1, 2).reshape(
            batch * in_channel, conv.out_channel, conv.kernel_size, conv.kernel_size
        )
        out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
        _, _, height, width = out.shape
        out = out.view(batch, conv.out_channel, height, width)
        out = conv.blur(out)

    elif conv.downsample:
        input = conv.blur(input)
        _, _, height, width = input.shape
        input = input.view(1, batch * in_channel, height, width)
        out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
        _, _, height, width = out.shape
        out = out.view(batch, conv.out_channel, height, width)

    else:
        input = input.view(1, batch * in_channel, height, width)
        out = F.conv2d(input, weight, padding=conv.padding, groups=batch)
        _, _, height, width = out.shape
        out = out.view(batch, conv.out_channel, height, width)
        
    out = layer.noise(out, noise=noise)
    out = layer.activate(out)
    
    return out

def decoder(G, style_space, latent, noise):

    out = G.input(latent)
    out = conv_warper(G.conv1, out, style_space[0], noise[0])
    skip = G.to_rgb1(out, latent[:, 1])

    i = 1
    for conv1, conv2, noise1, noise2, to_rgb in zip(
        G.convs[::2], G.convs[1::2], noise[1::2], noise[2::2], G.to_rgbs
    ):
        out = conv_warper(conv1, out, style_space[i], noise=noise1)
        out = conv_warper(conv2, out, style_space[i+1], noise=noise2)
        skip = to_rgb(out, latent[:, i + 2], skip)

        i += 2

    image = skip

    return image

def decoder_validate(G, style_space, latent):

    style_space = split_stylespace(style_space)
    noise = [getattr(G.noises, 'noise_{}'.format(i)) for i in range(G.num_layers)]

    out = G.input(latent)
    out = conv_warper(G.conv1, out, style_space[0], noise[0])
    skip = G.to_rgb1(out, latent[:, 1])

    i = 1
    for conv1, conv2, noise1, noise2, to_rgb in zip(
        G.convs[::2], G.convs[1::2], noise[1::2], noise[2::2], G.to_rgbs
    ):
        out = conv_warper(conv1, out, style_space[i], noise=noise1)
        out = conv_warper(conv2, out, style_space[i+1], noise=noise2)
        skip = to_rgb(out, latent[:, i + 2], skip)

        i += 2

    image = skip

    return image

def encoder_noise(G, noise):

    styles = [noise]
    style_space = []
    
    styles = [G.style(s) for s in styles]
    noise = [getattr(G.noises, 'noise_{}'.format(i)) for i in range(G.num_layers)]
    inject_index = G.n_latent
    latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
    style_space.append(G.conv1.conv.modulation(latent[:, 0]))

    i = 1
    for conv1, conv2, noise1, noise2, to_rgb in zip(
        G.convs[::2], G.convs[1::2], noise[1::2], noise[2::2], G.to_rgbs
    ):
        style_space.append(conv1.conv.modulation(latent[:, i]))
        style_space.append(conv2.conv.modulation(latent[:, i+1]))
        i += 2
        
    return style_space, latent, noise

def encoder_latent(G, latent):
    # an encoder wrapper for G

    style_space = []
    
    noise = [getattr(G.noises, 'noise_{}'.format(i)) for i in range(G.num_layers)]

    style_space.append(G.conv1.conv.modulation(latent[:, 0]))

    i = 1
    for conv1, conv2, to_rgb in zip(
        G.convs[::2], G.convs[1::2], G.to_rgbs
    ):
        style_space.append(conv1.conv.modulation(latent[:, i]))
        style_space.append(conv2.conv.modulation(latent[:, i+1]))
        i += 2
        
    return style_space, noise

def split_stylespace(style):
    style_space = []

    for idx in range(10):
        style_space.append(style[:, idx*512 : (idx+1) * 512])
    
    style_space.append(style[:, 10*512: 10*512 + 256])
    style_space.append(style[:, 10*512 + 256: 10*512 + 256*2])
    style_space.append(style[:, 10*512 + 256*2: 10*512 + 256*2 + 128])
    style_space.append(style[:, 10*512 + 256*2 + 128: 10*512 + 256*2 + 128 * 2])
    style_space.append(style[:, 10*512 + 256*2 + 128*2: 10*512 + 256*2 + 128*2 + 64])
    style_space.append(style[:, 10*512 + 256*2 + 128*2 + 64: 10*512 + 256*2 + 128*2 + 64*2])
    style_space.append(style[:, 10*512 + 256*2 + 128*2 + 64*2: 10*512 + 256*2 + 128*2 + 64*2 + 32])

    return style_space

def fuse_stylespace(style):
    new_s = torch.cat(style, dim=1)

    return new_s
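
# --- Illustrative sketch (not part of the original file) ---
# The slice boundaries hard-coded in split_stylespace() follow from the
# per-layer channel widths listed below, which are read off those slices
# (ten chunks of 512, then 256, 256, 128, 128, 64, 64, 32; 6048 columns in
# total). `STYLESPACE_WIDTHS` and `stylespace_offsets` are hypothetical names
# introduced only for illustration; nothing in this repository uses them.
STYLESPACE_WIDTHS = [512] * 10 + [256, 256, 128, 128, 64, 64, 32]

def stylespace_offsets(widths=STYLESPACE_WIDTHS):
    """Return the (start, end) column range of each chunk, in order."""
    offsets, start = [], 0
    for width in widths:
        offsets.append((start, start + width))
        start += width
    return offsets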
SYMBOL INDEX (195 symbols across 20 files)

FILE: clip/clip.py
  function _download (line 39) | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
  function _transform (line 71) | def _transform(n_px):
  function available_models (line 81) | def available_models() -> List[str]:
  function load (line 86) | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.c...
  function tokenize (line 185) | def tokenize(texts: Union[str, List[str]], context_length: int = 77, tru...

FILE: clip/model.py
  class Bottleneck (line 10) | class Bottleneck(nn.Module):
    method __init__ (line 13) | def __init__(self, inplanes, planes, stride=1):
    method forward (line 40) | def forward(self, x: torch.Tensor):
  class AttentionPool2d (line 56) | class AttentionPool2d(nn.Module):
    method __init__ (line 57) | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, o...
    method forward (line 66) | def forward(self, x):
  class ModifiedResNet (line 93) | class ModifiedResNet(nn.Module):
    method __init__ (line 101) | def __init__(self, layers, output_dim, heads, input_resolution=224, wi...
    method _make_layer (line 126) | def _make_layer(self, planes, blocks, stride=1):
    method forward (line 135) | def forward(self, x):
  class LayerNorm (line 153) | class LayerNorm(nn.LayerNorm):
    method forward (line 156) | def forward(self, x: torch.Tensor):
  class QuickGELU (line 162) | class QuickGELU(nn.Module):
    method forward (line 163) | def forward(self, x: torch.Tensor):
  class ResidualAttentionBlock (line 167) | class ResidualAttentionBlock(nn.Module):
    method __init__ (line 168) | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor ...
    method attention (line 181) | def attention(self, x: torch.Tensor):
    method forward (line 185) | def forward(self, x: torch.Tensor):
  class Transformer (line 191) | class Transformer(nn.Module):
    method __init__ (line 192) | def __init__(self, width: int, layers: int, heads: int, attn_mask: tor...
    method forward (line 198) | def forward(self, x: torch.Tensor):
  class VisionTransformer (line 202) | class VisionTransformer(nn.Module):
    method __init__ (line 203) | def __init__(self, input_resolution: int, patch_size: int, width: int,...
    method forward (line 219) | def forward(self, x: torch.Tensor):
  class CLIP (line 239) | class CLIP(nn.Module):
    method __init__ (line 240) | def __init__(self,
    method initialize_parameters (line 295) | def initialize_parameters(self):
    method build_attention_mask (line 324) | def build_attention_mask(self):
    method dtype (line 333) | def dtype(self):
    method encode_image (line 336) | def encode_image(self, image):
    method encode_text (line 339) | def encode_text(self, text):
    method forward (line 354) | def forward(self, image, text):
  function convert_weights (line 371) | def convert_weights(model: nn.Module):
  function build_model (line 395) | def build_model(state_dict: dict):

FILE: clip/simple_tokenizer.py
  function default_bpe (line 11) | def default_bpe():
  function bytes_to_unicode (line 16) | def bytes_to_unicode():
  function get_pairs (line 38) | def get_pairs(word):
  function basic_clean (line 50) | def basic_clean(text):
  function whitespace_clean (line 56) | def whitespace_clean(text):
  class SimpleTokenizer (line 62) | class SimpleTokenizer(object):
    method __init__ (line 63) | def __init__(self, bpe_path: str = default_bpe()):
    method bpe (line 80) | def bpe(self, token):
    method encode (line 121) | def encode(self, text):
    method decode (line 129) | def decode(self, tokens):

FILE: datasets/test_dataset.py
  class TestLatentsDataset (line 6) | class TestLatentsDataset(Dataset):
    method __init__ (line 7) | def __init__(self):
    method __len__ (line 22) | def __len__(self):
    method __getitem__ (line 26) | def __getitem__(self, index):

FILE: datasets/train_dataset.py
  class TrainLatentsDataset (line 8) | class TrainLatentsDataset(Dataset):
    method __init__ (line 9) | def __init__(self, opts, cycle=True):
    method __len__ (line 35) | def __len__(self):
    method __getitem__ (line 41) | def __getitem__(self, index):

FILE: delta_mapper.py
  class Mapper (line 10) | class Mapper(Module):
    method __init__ (line 12) | def __init__(self, in_channel=512, out_channel=512, norm=True, num_lay...
    method forward (line 22) | def forward(self, x):
  class DeltaMapper (line 26) | class DeltaMapper(Module):
    method __init__ (line 28) | def __init__(self):
    method forward (line 46) | def forward(self, sspace_feat, clip_feat):

FILE: generate_codes.py
  function save_image_pytorch (line 12) | def save_image_pytorch(img, name):
  function generate (line 24) | def generate(args, netG, device, mean_latent):

FILE: models/encoders/helpers.py
  class Flatten (line 11) | class Flatten(Module):
    method forward (line 12) | def forward(self, input):
  function l2_norm (line 16) | def l2_norm(input, axis=1):
  class Bottleneck (line 22) | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
  function get_block (line 26) | def get_block(in_channel, depth, num_units, stride=2):
  function get_blocks (line 30) | def get_blocks(num_layers):
  class SEModule (line 57) | class SEModule(Module):
    method __init__ (line 58) | def __init__(self, channels, reduction):
    method forward (line 66) | def forward(self, x):
  class bottleneck_IR (line 76) | class bottleneck_IR(Module):
    method __init__ (line 77) | def __init__(self, in_channel, depth, stride):
    method forward (line 92) | def forward(self, x):
  class bottleneck_IR_SE (line 98) | class bottleneck_IR_SE(Module):
    method __init__ (line 99) | def __init__(self, in_channel, depth, stride):
    method forward (line 117) | def forward(self, x):
  function _upsample_add (line 123) | def _upsample_add(x, y):

FILE: models/encoders/model_irse.py
  class Backbone (line 9) | class Backbone(Module):
    method __init__ (line 10) | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, ...
    method forward (line 44) | def forward(self, x):
  function IR_50 (line 51) | def IR_50(input_size):
  function IR_101 (line 57) | def IR_101(input_size):
  function IR_152 (line 63) | def IR_152(input_size):
  function IR_SE_50 (line 69) | def IR_SE_50(input_size):
  function IR_SE_101 (line 75) | def IR_SE_101(input_size):
  function IR_SE_152 (line 81) | def IR_SE_152(input_size):

FILE: models/encoders/psp_encoders.py
  class ProgressiveStage (line 12) | class ProgressiveStage(Enum):
  class GradualStyleBlock (line 34) | class GradualStyleBlock(Module):
    method __init__ (line 35) | def __init__(self, in_c, out_c, spatial):
    method forward (line 51) | def forward(self, x):
  class GradualStyleEncoder (line 58) | class GradualStyleEncoder(Module):
    method __init__ (line 59) | def __init__(self, num_layers, mode='ir', opts=None):
    method forward (line 95) | def forward(self, x):
  class Encoder4Editing (line 124) | class Encoder4Editing(Module):
    method __init__ (line 125) | def __init__(self, num_layers, stylegan_size, mode='ir'):
    method get_deltas_starting_dimensions (line 165) | def get_deltas_starting_dimensions(self):
    method set_progressive_stage (line 169) | def set_progressive_stage(self, new_stage: ProgressiveStage):
    method forward (line 173) | def forward(self, x):
  class BackboneEncoderUsingLastLayerIntoW (line 203) | class BackboneEncoderUsingLastLayerIntoW(Module):
    method __init__ (line 204) | def __init__(self, num_layers, mode='ir', opts=None):
    method forward (line 229) | def forward(self, x):

FILE: models/stylegan2/model.py
  class PixelNorm (line 10) | class PixelNorm(nn.Module):
    method __init__ (line 11) | def __init__(self):
    method forward (line 14) | def forward(self, input):
  function make_kernel (line 18) | def make_kernel(k):
  class Upsample (line 29) | class Upsample(nn.Module):
    method __init__ (line 30) | def __init__(self, kernel, factor=2):
    method forward (line 44) | def forward(self, input):
  class Downsample (line 50) | class Downsample(nn.Module):
    method __init__ (line 51) | def __init__(self, kernel, factor=2):
    method forward (line 65) | def forward(self, input):
  class Blur (line 71) | class Blur(nn.Module):
    method __init__ (line 72) | def __init__(self, kernel, pad, upsample_factor=1):
    method forward (line 84) | def forward(self, input):
  class EqualConv2d (line 90) | class EqualConv2d(nn.Module):
    method __init__ (line 91) | def __init__(
    method forward (line 110) | def forward(self, input):
    method __repr__ (line 121) | def __repr__(self):
  class EqualLinear (line 128) | class EqualLinear(nn.Module):
    method __init__ (line 129) | def __init__(
    method forward (line 147) | def forward(self, input):
    method __repr__ (line 159) | def __repr__(self):
  class ScaledLeakyReLU (line 165) | class ScaledLeakyReLU(nn.Module):
    method __init__ (line 166) | def __init__(self, negative_slope=0.2):
    method forward (line 171) | def forward(self, input):
  class ModulatedConv2d (line 177) | class ModulatedConv2d(nn.Module):
    method __init__ (line 178) | def __init__(
    method __repr__ (line 226) | def __repr__(self):
    method forward (line 232) | def forward(self, input, style):
  class NoiseInjection (line 276) | class NoiseInjection(nn.Module):
    method __init__ (line 277) | def __init__(self):
    method forward (line 282) | def forward(self, image, noise=None):
  class ConstantInput (line 290) | class ConstantInput(nn.Module):
    method __init__ (line 291) | def __init__(self, channel, size=4):
    method forward (line 296) | def forward(self, input):
  class StyledConv (line 303) | class StyledConv(nn.Module):
    method __init__ (line 304) | def __init__(
    method forward (line 331) | def forward(self, input, style, noise=None):
  class ToRGB (line 340) | class ToRGB(nn.Module):
    method __init__ (line 341) | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[...
    method forward (line 350) | def forward(self, input, style, skip=None):
  class Generator (line 362) | class Generator(nn.Module):
    method __init__ (line 363) | def __init__(
    method make_noise (line 448) | def make_noise(self):
    method mean_latent (line 459) | def mean_latent(self, n_latent):
    method get_latent (line 467) | def get_latent(self, input):
    method forward (line 470) | def forward(
  class ConvLayer (line 544) | class ConvLayer(nn.Sequential):
    method __init__ (line 545) | def __init__(
  class ResBlock (line 593) | class ResBlock(nn.Module):
    method __init__ (line 594) | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
    method forward (line 604) | def forward(self, input):
  class Discriminator (line 614) | class Discriminator(nn.Module):
    method __init__ (line 615) | def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
    method forward (line 654) | def forward(self, input):

FILE: models/stylegan2/op/fused_act.py
  class FusedLeakyReLU (line 9) | class FusedLeakyReLU(nn.Module):
    method __init__ (line 10) | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
    method forward (line 17) | def forward(self, input):
  function fused_leaky_relu (line 21) | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):

FILE: models/stylegan2/op/upfirdn2d.py
  function upfirdn2d (line 7) | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
  function upfirdn2d_native (line 12) | def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, ...

FILE: options/test_options.py
  class TestOptions (line 3) | class TestOptions:
    method __init__ (line 5) | def __init__(self):
    method initialize (line 9) | def initialize(self):
    method parse (line 25) | def parse(self):

FILE: options/train_options.py
  class TrainOptions (line 3) | class TrainOptions:
    method __init__ (line 5) | def __init__(self):
    method initialize (line 9) | def initialize(self):
    method parse (line 25) | def parse(self):

FILE: scripts/inference.py
  function GetBoundary (line 26) | def GetBoundary(fs3,dt,threshold):
  function improved_ds (line 32) | def improved_ds(ds, select):
  function main (line 38) | def main(opts):

FILE: scripts/inference_real.py
  function get_keys (line 33) | def get_keys(d, name):
  class Imagedataset (line 39) | class Imagedataset(Dataset):
    method __init__ (line 40) | def __init__(self,
    method __len__ (line 60) | def __len__(self):
    method __getitem__ (line 63) | def __getitem__(self, index):
  function encoder_latent (line 73) | def encoder_latent(G, latent):
  function GetBoundary (line 94) | def GetBoundary(fs3,dt,threshold):
  function improved_ds (line 100) | def improved_ds(ds, select):
  function main (line 106) | def main(opts):

FILE: scripts/train.py
  function main (line 14) | def main(opts):

FILE: utils/map_tool.py
  function zeroshot_classifier (line 89) | def zeroshot_classifier(classnames, templates,model):
  function GetDt (line 103) | def GetDt(classnames,model):

FILE: utils/stylespace_util.py
  function conv_warper (line 10) | def conv_warper(layer, input, style, noise):
  function decoder (line 58) | def decoder(G, style_space, latent, noise):
  function decoder_validate (line 78) | def decoder_validate(G, style_space, latent):
  function encoder_noise (line 101) | def encoder_noise(G, noise):
  function encoder_latent (line 122) | def encoder_latent(G, latent):
  function split_stylespace (line 141) | def split_stylespace(style):
  function fuse_stylespace (line 157) | def fuse_stylespace(style):

Condensed preview: 38 files, each showing path, character count, and a content snippet. The full structured content is about 117K chars.
[
  {
    "path": "README.md",
    "chars": 6431,
    "preview": "# DeltaSpace: A Semantic-aligned Feature Space for Flexible Text-guided Image Editing\r\n\r\n## Overview\r\n\r\nThis repository "
  },
  {
    "path": "clip/__init__.py",
    "chars": 20,
    "preview": "from .clip import *\n"
  },
  {
    "path": "clip/clip.py",
    "chars": 8256,
    "preview": "import hashlib\nimport os\nimport urllib\nimport warnings\nfrom typing import Union, List\n\nimport torch\nfrom PIL import Imag"
  },
  {
    "path": "clip/model.py",
    "chars": 17246,
    "preview": "from collections import OrderedDict\nfrom typing import Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn.fun"
  },
  {
    "path": "clip/simple_tokenizer.py",
    "chars": 4628,
    "preview": "import gzip\nimport html\nimport os\nfrom functools import lru_cache\n\nimport ftfy\nimport regex as re\n\n\n@lru_cache()\ndef def"
  },
  {
    "path": "datasets/test_dataset.py",
    "chars": 1220,
    "preview": "import numpy as np\n\nimport torch\nfrom torch.utils.data import Dataset\n\nclass TestLatentsDataset(Dataset):\n    def __init"
  },
  {
    "path": "datasets/train_dataset.py",
    "chars": 2552,
    "preview": "import copy\nimport random\nimport numpy as np\n\nimport torch\nfrom torch.utils.data import Dataset\n\nclass TrainLatentsDatas"
  },
  {
    "path": "delta_mapper.py",
    "chars": 2392,
    "preview": "import math\n\nimport torch\nfrom torch import nn\nfrom torch.nn import Module\nimport torch.nn.functional as F\n\nfrom models."
  },
  {
    "path": "generate_codes.py",
    "chars": 5856,
    "preview": "import os\nimport argparse\nimport clip\n\nimport random\nimport numpy as np\nimport torch\nfrom torchvision import utils\nfrom "
  },
  {
    "path": "models/encoders/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "models/encoders/helpers.py",
    "chars": 4882,
    "preview": "from collections import namedtuple\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import Conv2d, BatchNorm2d"
  },
  {
    "path": "models/encoders/model_irse.py",
    "chars": 3241,
    "preview": "from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module\nfrom models.encoders.h"
  },
  {
    "path": "models/encoders/psp_encoders.py",
    "chars": 8697,
    "preview": "from enum import Enum\nimport math\nimport numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import Conv2d, Batc"
  },
  {
    "path": "models/stylegan2/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "models/stylegan2/model.py",
    "chars": 18291,
    "preview": "import math\nimport random\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom models.stylegan2"
  },
  {
    "path": "models/stylegan2/op/__init__.py",
    "chars": 89,
    "preview": "from .fused_act import FusedLeakyReLU, fused_leaky_relu\nfrom .upfirdn2d import upfirdn2d\n"
  },
  {
    "path": "models/stylegan2/op/fused_act.py",
    "chars": 1040,
    "preview": "import os\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nmodule_path = os.path.dirname(__file_"
  },
  {
    "path": "models/stylegan2/op/upfirdn2d.py",
    "chars": 1629,
    "preview": "import os\nimport torch\nfrom torch.nn import functional as F\n\nmodule_path = os.path.dirname(__file__)\n\ndef upfirdn2d(inpu"
  },
  {
    "path": "options/test_options.py",
    "chars": 1105,
    "preview": "from argparse import ArgumentParser\n\nclass TestOptions:\n\n\tdef __init__(self):\n\t\tself.parser = ArgumentParser()\n\t\tself.in"
  },
  {
    "path": "options/train_options.py",
    "chars": 1269,
    "preview": "from argparse import ArgumentParser\n\nclass TrainOptions:\n\n\tdef __init__(self):\n\t\tself.parser = ArgumentParser()\n\t\tself.i"
  },
  {
    "path": "scripts/inference.py",
    "chars": 3583,
    "preview": "import os\nimport sys\nsys.path.append(\".\")\nsys.path.append(\"..\")\n\nimport copy\nimport clip\nimport numpy as np\n\nimport torc"
  },
  {
    "path": "scripts/inference_real.py",
    "chars": 6408,
    "preview": "import os\nimport sys\nsys.path.append(\".\")\nsys.path.append(\"..\")\n\nimport copy\nimport clip\nimport numpy as np\nfrom PIL imp"
  },
  {
    "path": "scripts/train.py",
    "chars": 2006,
    "preview": "import os\nimport sys\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nsys.path.append(\".\")\nsys.path.append(\"..\")\n\n"
  },
  {
    "path": "tSNE/compute_tsne.py",
    "chars": 1980,
    "preview": "import numpy as np\nimport matplotlib.pyplot as plt\nfrom sklearn.manifold import TSNE\n\ndata_name = 'celeba' # 'celeba' or"
  },
  {
    "path": "utils/map_tool.py",
    "chars": 3655,
    "preview": "import torch\nimport clip\nimport os\nimport numpy as np\n\nimagenet_templates = [\n    'a bad photo of a {}.',\n#    'a photo "
  },
  {
    "path": "utils/stylespace_util.py",
    "chars": 5221,
    "preview": "import torch\nimport numpy as np\nimport matplotlib.pyplot as plt\nimport torchvision\n\nfrom torch.nn import functional as F"
  }
]

// ... and 12 more files

