Repository: Yueming6568/DeltaEdit
Branch: main
Commit: d80bc4bfdc6d
Files: 38
Total size: 109.1 KB

Directory structure:
gitextract_2trjho5x/
├── README.md
├── clip/
│   ├── __init__.py
│   ├── clip.py
│   ├── model.py
│   └── simple_tokenizer.py
├── datasets/
│   ├── test_dataset.py
│   └── train_dataset.py
├── delta_mapper.py
├── examples/
│   ├── cspace_img_feat.npy
│   ├── sspace_img_feat.npy
│   └── wplus_img_feat.npy
├── generate_codes.py
├── models/
│   ├── encoders/
│   │   ├── __init__.py
│   │   ├── helpers.py
│   │   ├── model_irse.py
│   │   └── psp_encoders.py
│   └── stylegan2/
│       ├── __init__.py
│       ├── model.py
│       ├── npy_ffhq/
│       │   └── fs3.npy
│       └── op/
│           ├── __init__.py
│           ├── fused_act.py
│           └── upfirdn2d.py
├── options/
│   ├── test_options.py
│   └── train_options.py
├── scripts/
│   ├── inference.py
│   ├── inference_real.py
│   └── train.py
├── tSNE/
│   ├── celeba/
│   │   ├── cspace_celeba_deltai.npy
│   │   ├── cspace_celeba_deltat.npy
│   │   ├── cspace_celeba_i.npy
│   │   └── cspace_celeba_t.npy
│   ├── cocoval/
│   │   ├── cspace_cocoval_deltai.npy
│   │   ├── cspace_cocoval_deltat.npy
│   │   ├── cspace_cocoval_i.npy
│   │   └── cspace_cocoval_t.npy
│   └── compute_tsne.py
└── utils/
    ├── map_tool.py
    └── stylespace_util.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# DeltaSpace: A Semantic-aligned Feature Space for Flexible Text-guided Image Editing

## Overview

This repository contains the **official** PyTorch implementation of the papers:

*DeltaEdit: Exploring Text-free Training for Text-driven Image Manipulation*, CVPR 2023

*DeltaSpace: A Semantic-aligned Feature Space for Flexible Text-guided Image Editing*, arXiv 2023

## News

- [2025-06-22] Uploaded the t-SNE code for alignment validation (◍>◡<◍).
- [2023-03-11] Uploaded the training and inference code for the facial domain (◍•ڡ•◍).
*To be continued...*

## Dependencies

- Install CLIP:

```shell script
conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=
pip install ftfy regex tqdm gdown
pip install git+https://github.com/openai/CLIP.git
```

- Download pre-trained models:

  - The code relies on the [Rosinality](https://github.com/rosinality/stylegan2-pytorch/) PyTorch implementation of StyleGAN2.
  - Download the pre-trained StyleGAN2 generator model for the facial domain from [here](https://drive.google.com/file/d/1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT/view?usp=sharing), and then place it into the folder `./models/pretrained_models`.
  - Download the pre-trained StyleGAN2 generator models for the LSUN cat, church, and horse domains from [here](https://drive.google.com/drive/folders/1YRhXGM-2xk7A4TExM_jXaNg1f2AiCRlw?usp=share_link), and then place them into the folder `./models/pretrained_models/stylegan2-{cat/church/horse}`.

## Training

### Data preparation

- DeltaEdit is trained on latent vectors.

- For the facial domain, 58,000 real images are randomly selected from the [FFHQ](https://github.com/NVlabs/ffhq-dataset) dataset, and 200,000 fake images are sampled from the z space of StyleGAN for training. Note that all real images are inverted by the [e4e](https://github.com/omertov/encoder4editing) encoder.

- Download the provided FFHQ latent vectors from [here](https://drive.google.com/drive/folders/13NLq4giSgdcMVkYQIiPj4Xhxz4-wlXSD?usp=sharing) and then place all numpy files into the folder `./latent_code/ffhq`.
- Generate the 200,000 sampled latent vectors by running the following commands for each specific domain:

```shell
CUDA_VISIBLE_DEVICES=0 python generate_codes.py --classname ffhq --samples 200000
CUDA_VISIBLE_DEVICES=0 python generate_codes.py --classname cat --samples 200000
CUDA_VISIBLE_DEVICES=0 python generate_codes.py --classname church --samples 200000
CUDA_VISIBLE_DEVICES=0 python generate_codes.py --classname horse --samples 200000
```

### Usage

- The main training script is placed in `./scripts/train.py`.
- Training arguments can be found at `./options/train_options.py`.

For training, please run the following command:

```shell
CUDA_VISIBLE_DEVICES=0 python scripts/train.py
```

## Inference

- The main inference script is placed in `./scripts/inference.py`.
- Inference arguments can be found at `./options/test_options.py`.
- Download the pretrained DeltaMapper model for editing human faces from [here](https://drive.google.com/file/d/1Mb2WiELoVDPDIi24tIfoWsjn1l2xTjtZ/view?usp=sharing), and then place it into the folder `./checkpoints`.
- Some inference data are provided in `./examples`.

To produce editing results, please run the following command:

```shell
CUDA_VISIBLE_DEVICES=1 python scripts/inference.py --target "chubby face","face with eyeglasses","face with smile","face with pale skin","face with tanned skin","face with big eyes","face with black clothes","face with blue suit","happy face","face with bangs","face with red hair","face with black hair","face with blond hair","face with curly hair","face with receding hairline","face with bowlcut hairstyle"
```

The produced results are shown below. You can also pass your own target attributes to the `--target` flag.

## Inference for real images

- The main inference script is placed in `./scripts/inference_real.py`.
- Inference arguments can be found at `./options/test_options.py`.
- Download the pretrained DeltaMapper model for editing human faces from [here](https://drive.google.com/file/d/1Mb2WiELoVDPDIi24tIfoWsjn1l2xTjtZ/view?usp=sharing), and then place it into the folder `./checkpoints`.
- Download the pretrained e4e encoder `e4e_ffhq_encode.pt` from [e4e](https://github.com/omertov/encoder4editing).
- One test image is provided in `./test_imgs`.

To produce editing results, please run the following command:

```shell
CUDA_VISIBLE_DEVICES=1 python scripts/inference_real.py --target "chubby face","face with eyeglasses","face with smile","face with pale skin","face with tanned skin","face with big eyes","face with black clothes","face with blue suit","happy face","face with bangs","face with red hair","face with black hair","face with blond hair","face with curly hair","face with receding hairline","face with bowlcut hairstyle"
```

## Alignment Validation: CLIP Space vs. DeltaSpace via t-SNE Visualization

```shell
cd tSNE
python compute_tsne.py
```

After running the script, you obtain a 2D t-SNE projection of the embeddings from both spaces (CLIP space and DeltaSpace). The results are shown below for your convenience.

![tsne](./tsne.jpg)

## Results

![results](./results.jpg)

## Acknowledgements

This code is developed based on the code of [orpatashnik/StyleCLIP](https://github.com/orpatashnik/StyleCLIP) by Or Patashnik et al.

## Citation

If you use this code for your research, please cite our papers:

```
@InProceedings{lyu2023deltaedit,
    author    = {Lyu, Yueming and Lin, Tianwei and Li, Fu and He, Dongliang and Dong, Jing and Tan, Tieniu},
    title     = {DeltaEdit: Exploring Text-free Training for Text-Driven Image Manipulation},
    booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
    year      = {2023}
}

@article{lyu2023deltaspace,
    author    = {Lyu, Yueming and Zhao, Kang and Peng, Bo and Chen, Huafeng and Jiang, Yue and Zhang, Yingya and Dong, Jing and Shan, Caifeng},
    title     = {DeltaSpace: A Semantic-aligned Feature Space for Flexible Text-guided Image Editing},
    journal   = {arXiv preprint arXiv:2310.08785},
    year      = {2023},
}
```

================================================
FILE: clip/__init__.py
================================================
from .clip import *

================================================
FILE: clip/clip.py
================================================
import hashlib
import os
import urllib.request  # import the submodule explicitly; _download() calls urllib.request.urlopen
import warnings
from typing import Union, List

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm

from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


# Compare (major, minor) as integers: comparing the string parts
# lexicographically would rank "1.10" below "1.7".
if tuple(int(part) for part in torch.__version__.split(".")[:2]) < (1, 7):
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")


__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()

_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
}


def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    # The second-to-last URL component is the expected SHA256 digest of the file.
    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target


def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        lambda image: image.convert("RGB"),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return
list(_MODELS.keys()) def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False): """Load a CLIP model Parameters ---------- name : str A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict device : Union[str, torch.device] The device to put the loaded model jit : bool Whether to load the optimized JIT model or more hackable non-JIT model (default). Returns ------- model : torch.nn.Module The CLIP model preprocess : Callable[[PIL.Image], torch.Tensor] A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input """ if name in _MODELS: model_path = _download(_MODELS[name]) elif os.path.isfile(name): model_path = name else: raise RuntimeError(f"Model {name} not found; available models = {available_models()}") try: # loading JIT archive model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() state_dict = None except RuntimeError: # loading saved state dict if jit: warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") jit = False state_dict = torch.load(model_path, map_location="cpu") if not jit: model = build_model(state_dict or model.state_dict()).to(device) if str(device) == "cpu": model.float() return model, _transform(model.visual.input_resolution) # patch the device names device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] def patch_device(module): try: graphs = [module.graph] if hasattr(module, "graph") else [] except RuntimeError: graphs = [] if hasattr(module, "forward1"): graphs.append(module.forward1.graph) for graph in graphs: for node in graph.findAllNodes("prim::Constant"): if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): node.copyAttributes(device_node) model.apply(patch_device) patch_device(model.encode_image) patch_device(model.encode_text) # patch dtype to float32 on CPU if str(device) == "cpu": float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] float_node = float_input.node() def patch_float(module): try: graphs = [module.graph] if hasattr(module, "graph") else [] except RuntimeError: graphs = [] if hasattr(module, "forward1"): graphs.append(module.forward1.graph) for graph in graphs: for node in graph.findAllNodes("aten::to"): inputs = list(node.inputs()) for i in [1, 2]: # dtype can be the second or third argument to aten::to() if inputs[i].node()["value"] == 5: inputs[i].node().copyAttributes(float_node) model.apply(patch_float) patch_float(model.encode_image) patch_float(model.encode_text) model.float() return model, _transform(model.input_resolution.item()) def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: """ Returns the tokenized representation of given input string(s) 
Parameters ---------- texts : Union[str, List[str]] An input string or a list of input strings to tokenize context_length : int The context length to use; all CLIP models use 77 as the context length truncate: bool Whether to truncate the text in case its encoding is longer than the context length Returns ------- A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] """ if isinstance(texts, str): texts = [texts] sot_token = _tokenizer.encoder["<|startoftext|>"] eot_token = _tokenizer.encoder["<|endoftext|>"] all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) for i, tokens in enumerate(all_tokens): if len(tokens) > context_length: if truncate: tokens = tokens[:context_length] tokens[-1] = eot_token else: raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") result[i, :len(tokens)] = torch.tensor(tokens) return result ================================================ FILE: clip/model.py ================================================ from collections import OrderedDict from typing import Tuple, Union import numpy as np import torch import torch.nn.functional as F from torch import nn class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1): super().__init__() # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) self.bn3 = nn.BatchNorm2d(planes * self.expansion) self.relu = nn.ReLU(inplace=True) self.downsample = None self.stride = stride if stride > 1 or inplanes != planes * Bottleneck.expansion: # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 self.downsample = nn.Sequential(OrderedDict([ ("-1", nn.AvgPool2d(stride)), ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), ("1", nn.BatchNorm2d(planes * self.expansion)) ])) def forward(self, x: torch.Tensor): identity = x out = self.relu(self.bn1(self.conv1(x))) out = self.relu(self.bn2(self.conv2(out))) out = self.avgpool(out) out = self.bn3(self.conv3(out)) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out class AttentionPool2d(nn.Module): def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): super().__init__() self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) self.k_proj = nn.Linear(embed_dim, embed_dim) self.q_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) self.num_heads = num_heads def forward(self, x): x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC x, _ = F.multi_head_attention_forward( query=x, key=x, value=x, 
embed_dim_to_check=x.shape[-1], num_heads=self.num_heads, q_proj_weight=self.q_proj.weight, k_proj_weight=self.k_proj.weight, v_proj_weight=self.v_proj.weight, in_proj_weight=None, in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), bias_k=None, bias_v=None, add_zero_attn=False, dropout_p=0, out_proj_weight=self.c_proj.weight, out_proj_bias=self.c_proj.bias, use_separate_proj_weight=True, training=self.training, need_weights=False ) return x[0] class ModifiedResNet(nn.Module): """ A ResNet class that is similar to torchvision's but contains the following changes: - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 - The final pooling layer is a QKV attention instead of an average pool """ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): super().__init__() self.output_dim = output_dim self.input_resolution = input_resolution # the 3-layer stem self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(width // 2) self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(width // 2) self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) self.bn3 = nn.BatchNorm2d(width) self.avgpool = nn.AvgPool2d(2) self.relu = nn.ReLU(inplace=True) # residual layers self._inplanes = width # this is a *mutable* variable used during construction self.layer1 = self._make_layer(width, layers[0]) self.layer2 = self._make_layer(width * 2, layers[1], stride=2) self.layer3 = self._make_layer(width * 4, layers[2], stride=2) self.layer4 = self._make_layer(width * 8, layers[3], stride=2) embed_dim = width * 32 # the ResNet feature dimension self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) def _make_layer(self, planes, 
blocks, stride=1): layers = [Bottleneck(self._inplanes, planes, stride)] self._inplanes = planes * Bottleneck.expansion for _ in range(1, blocks): layers.append(Bottleneck(self._inplanes, planes)) return nn.Sequential(*layers) def forward(self, x): def stem(x): for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: x = self.relu(bn(conv(x))) x = self.avgpool(x) return x x = x.type(self.conv1.weight.dtype) x = stem(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.attnpool(x) return x class LayerNorm(nn.LayerNorm): """Subclass torch's LayerNorm to handle fp16.""" def forward(self, x: torch.Tensor): orig_type = x.dtype ret = super().forward(x.type(torch.float32)) return ret.type(orig_type) class QuickGELU(nn.Module): def forward(self, x: torch.Tensor): return x * torch.sigmoid(1.702 * x) class ResidualAttentionBlock(nn.Module): def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): super().__init__() self.attn = nn.MultiheadAttention(d_model, n_head) self.ln_1 = LayerNorm(d_model) self.mlp = nn.Sequential(OrderedDict([ ("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()), ("c_proj", nn.Linear(d_model * 4, d_model)) ])) self.ln_2 = LayerNorm(d_model) self.attn_mask = attn_mask def attention(self, x: torch.Tensor): self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] def forward(self, x: torch.Tensor): x = x + self.attention(self.ln_1(x)) x = x + self.mlp(self.ln_2(x)) return x class Transformer(nn.Module): def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): super().__init__() self.width = width self.layers = layers self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) def forward(self, x: torch.Tensor): return self.resblocks(x) class 
VisionTransformer(nn.Module): def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): super().__init__() self.input_resolution = input_resolution self.output_dim = output_dim self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) scale = width ** -0.5 self.class_embedding = nn.Parameter(scale * torch.randn(width)) self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) self.ln_pre = LayerNorm(width) self.transformer = Transformer(width, layers, heads) self.ln_post = LayerNorm(width) self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) def forward(self, x: torch.Tensor): x = self.conv1(x) # shape = [*, width, grid, grid] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] x = x + self.positional_embedding.to(x.dtype) x = self.ln_pre(x) x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer(x) x = x.permute(1, 0, 2) # LND -> NLD x = self.ln_post(x[:, 0, :]) if self.proj is not None: x = x @ self.proj return x class CLIP(nn.Module): def __init__(self, embed_dim: int, # vision image_resolution: int, vision_layers: Union[Tuple[int, int, int, int], int], vision_width: int, vision_patch_size: int, # text context_length: int, vocab_size: int, transformer_width: int, transformer_heads: int, transformer_layers: int ): super().__init__() self.context_length = context_length if isinstance(vision_layers, (tuple, list)): vision_heads = vision_width * 32 // 64 self.visual = ModifiedResNet( layers=vision_layers, output_dim=embed_dim, heads=vision_heads, input_resolution=image_resolution, width=vision_width ) else: vision_heads = vision_width // 64 
self.visual = VisionTransformer( input_resolution=image_resolution, patch_size=vision_patch_size, width=vision_width, layers=vision_layers, heads=vision_heads, output_dim=embed_dim ) self.transformer = Transformer( width=transformer_width, layers=transformer_layers, heads=transformer_heads, attn_mask=self.build_attention_mask() ) self.vocab_size = vocab_size self.token_embedding = nn.Embedding(vocab_size, transformer_width) self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) self.ln_final = LayerNorm(transformer_width) self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) self.initialize_parameters() def initialize_parameters(self): nn.init.normal_(self.token_embedding.weight, std=0.02) nn.init.normal_(self.positional_embedding, std=0.01) if isinstance(self.visual, ModifiedResNet): if self.visual.attnpool is not None: std = self.visual.attnpool.c_proj.in_features ** -0.5 nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: for name, param in resnet_block.named_parameters(): if name.endswith("bn3.weight"): nn.init.zeros_(param) proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) attn_std = self.transformer.width ** -0.5 fc_std = (2 * self.transformer.width) ** -0.5 for block in self.transformer.resblocks: nn.init.normal_(block.attn.in_proj_weight, std=attn_std) nn.init.normal_(block.attn.out_proj.weight, std=proj_std) nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) if self.text_projection is not None: nn.init.normal_(self.text_projection, 
std=self.transformer.width ** -0.5) def build_attention_mask(self): # lazily create causal attention mask, with full attention between the vision tokens # pytorch uses additive attention mask; fill with -inf mask = torch.empty(self.context_length, self.context_length) mask.fill_(float("-inf")) mask.triu_(1) # zero out the lower diagonal return mask @property def dtype(self): return self.visual.conv1.weight.dtype def encode_image(self, image): return self.visual(image.type(self.dtype)) def encode_text(self, text): x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] x = x + self.positional_embedding.type(self.dtype) x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer(x) x = x.permute(1, 0, 2) # LND -> NLD x = self.ln_final(x).type(self.dtype) # x.shape = [batch_size, n_ctx, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection return x def forward(self, image, text): image_features = self.encode_image(image) text_features = self.encode_text(text) # normalized features image_features = image_features / image_features.norm(dim=-1, keepdim=True) text_features = text_features / text_features.norm(dim=-1, keepdim=True) # cosine similarity as logits logit_scale = self.logit_scale.exp() logits_per_image = logit_scale * image_features @ text_features.t() logits_per_text = logit_scale * text_features @ image_features.t() # shape = [global_batch_size, global_batch_size] return logits_per_image, logits_per_text def convert_weights(model: nn.Module): """Convert applicable model parameters to fp16""" def _convert_weights_to_fp16(l): if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): l.weight.data = l.weight.data.half() if l.bias is not None: l.bias.data = l.bias.data.half() if isinstance(l, nn.MultiheadAttention): for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 
tensor = getattr(l, attr) if tensor is not None: tensor.data = tensor.data.half() for name in ["text_projection", "proj"]: if hasattr(l, name): attr = getattr(l, name) if attr is not None: attr.data = attr.data.half() model.apply(_convert_weights_to_fp16) def build_model(state_dict: dict): vit = "visual.proj" in state_dict if vit: vision_width = state_dict["visual.conv1.weight"].shape[0] vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) image_resolution = vision_patch_size * grid_size else: counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] vision_layers = tuple(counts) vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) vision_patch_size = None assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] image_resolution = output_width * 32 embed_dim = state_dict["text_projection"].shape[1] context_length = state_dict["positional_embedding"].shape[0] vocab_size = state_dict["token_embedding.weight"].shape[0] transformer_width = state_dict["ln_final.weight"].shape[0] transformer_heads = transformer_width // 64 transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) model = CLIP( embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size, context_length, vocab_size, transformer_width, transformer_heads, transformer_layers ) for key in ["input_resolution", "context_length", "vocab_size"]: if key in state_dict: del state_dict[key] convert_weights(model) model.load_state_dict(state_dict) return model.eval() ================================================ FILE: 
clip/simple_tokenizer.py
================================================
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        # '</w>' marks the end of a word in the BPE vocabulary.
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            # Merge the lowest-ranked (most frequent) known pair first.
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # `first` does not occur again; copy the remainder unchanged.
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def
encode(self, text): bpe_tokens = [] text = whitespace_clean(basic_clean(text)).lower() for token in re.findall(self.pat, text): token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) return bpe_tokens def decode(self, tokens): text = ''.join([self.decoder[token] for token in tokens]) text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') return text ================================================ FILE: datasets/test_dataset.py ================================================ import numpy as np import torch from torch.utils.data import Dataset class TestLatentsDataset(Dataset): def __init__(self): style_latents_list = [] clip_latents_list = [] wplus_latents_list = [] #change the paths here for testing other latent codes style_latents_list.append(torch.Tensor(np.load("./examples/sspace_img_feat.npy"))) clip_latents_list.append(torch.Tensor(np.load("./examples/cspace_img_feat.npy"))) wplus_latents_list.append(torch.Tensor(np.load("./examples/wplus_img_feat.npy"))) self.style_latents = torch.cat(style_latents_list, dim=0) self.clip_latents = torch.cat(clip_latents_list, dim=0) self.wplus_latents = torch.cat(wplus_latents_list, dim=0) def __len__(self): return self.style_latents.shape[0] def __getitem__(self, index): latent_s1 = self.style_latents[index] latent_c1 = self.clip_latents[index] latent_w1 = self.wplus_latents[index] latent_c1 = latent_c1 / latent_c1.norm(dim=-1, keepdim=True).float() delta_c = torch.cat([latent_c1, latent_c1], dim=0) return latent_s1, delta_c, latent_w1 ================================================ FILE: datasets/train_dataset.py ================================================ import copy import random import numpy as np import torch from torch.utils.data import Dataset class TrainLatentsDataset(Dataset): def __init__(self, opts, cycle=True): style_latents_list = [] clip_latents_list = 
[] wplus_latents_list = [] style_latents_list.append(torch.Tensor(np.load(f"./latent_code/{opts.classname}/sspace_noise_feat.npy"))) clip_latents_list.append(torch.Tensor(np.load(f"./latent_code/{opts.classname}/cspace_noise_feat.npy"))) wplus_latents_list.append(torch.Tensor(np.load(f"./latent_code/{opts.classname}/wspace_noise_feat.npy"))) style_latents_list.append(torch.Tensor(np.load(f"./latent_code/{opts.classname}/sspace_ffhq_feat.npy"))) clip_latents_list.append(torch.Tensor(np.load(f"./latent_code/{opts.classname}/cspace_ffhq_feat.npy"))) wplus_latents_list.append(torch.Tensor(np.load(f"./latent_code/{opts.classname}/wspace_ffhq_feat.npy"))) self.style_latents = torch.cat(style_latents_list, dim=0) self.clip_latents = torch.cat(clip_latents_list, dim=0) self.wplus_latents = torch.cat(wplus_latents_list, dim=0) self.style_latents = self.style_latents[:200000+58000] self.clip_latents = self.clip_latents[:200000+58000] self.wplus_latents = self.wplus_latents[:200000+58000] self.dataset_size = self.style_latents.shape[0] print("dataset size", self.dataset_size) self.cycle = cycle def __len__(self): if self.cycle: return self.style_latents.shape[0] * 50 else: return self.style_latents.shape[0] def __getitem__(self, index): if self.cycle: index = index % self.dataset_size latent_s1 = self.style_latents[index] latent_c1 = self.clip_latents[index] latent_w1 = self.wplus_latents[index] latent_c1 = latent_c1 / latent_c1.norm(dim=-1, keepdim=True).float() random_index = random.randint(0, self.dataset_size - 1) latent_s2 = self.style_latents[random_index] latent_c2 = self.clip_latents[random_index] latent_w2 = self.wplus_latents[random_index] latent_c2 = latent_c2 / latent_c2.norm(dim=-1, keepdim=True).float() delta_s1 = latent_s2 - latent_s1 delta_c = latent_c2 - latent_c1 delta_c = delta_c / delta_c.norm(dim=-1, keepdim=True).float().clamp(min=1e-5) delta_c = torch.cat([latent_c1, delta_c], dim=0) return latent_s1, delta_c, delta_s1 
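The pairing logic in `TrainLatentsDataset.__getitem__` above can be sketched without PyTorch. This is a minimal illustration on plain Python lists, not the repository's code: `l2_normalize` and `make_condition` are hypothetical helper names, and the toy 3-D vectors stand in for the CLIP image features loaded from `cspace_*_feat.npy`.

```python
import math


def l2_normalize(vec, min_norm=0.0):
    """Divide vec by its L2 norm; min_norm plays the role of clamp(min=...) on the norm."""
    norm = math.sqrt(sum(x * x for x in vec))
    norm = max(norm, min_norm)
    return [x / norm for x in vec]


def make_condition(clip_1, clip_2):
    """Build the condition vector: normalized clip_1 followed by the normalized CLIP delta."""
    c1 = l2_normalize(clip_1)
    c2 = l2_normalize(clip_2)
    # Direction from sample 1 to sample 2 in (normalized) CLIP space.
    delta_c = [b - a for a, b in zip(c1, c2)]
    # The clamp keeps the division finite when both samples coincide.
    delta_c = l2_normalize(delta_c, min_norm=1e-5)
    # List concatenation stands in for torch.cat([latent_c1, delta_c], dim=0).
    return c1 + delta_c


cond = make_condition([3.0, 0.0, 4.0], [0.0, 2.0, 0.0])
print(len(cond))  # twice the feature dimension
```

When the random partner happens to equal the anchor sample, the delta is all zeros and the clamp leaves it at zero instead of dividing by zero.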
================================================ FILE: delta_mapper.py ================================================ import math import torch from torch import nn from torch.nn import Module import torch.nn.functional as F from models.stylegan2.model import EqualLinear, PixelNorm class Mapper(Module): def __init__(self, in_channel=512, out_channel=512, norm=True, num_layers=4): super(Mapper, self).__init__() layers = [PixelNorm()] if norm else [] layers.append(EqualLinear(in_channel, out_channel, lr_mul=0.01, activation='fused_lrelu')) for _ in range(num_layers-1): layers.append(EqualLinear(out_channel, out_channel, lr_mul=0.01, activation='fused_lrelu')) self.mapping = nn.Sequential(*layers) def forward(self, x): x = self.mapping(x) return x class DeltaMapper(Module): def __init__(self): super(DeltaMapper, self).__init__() #Style Module(sm) self.sm_coarse = Mapper(512, 512) self.sm_medium = Mapper(512, 512) self.sm_fine = Mapper(2464, 2464) #Condition Module(cm) self.cm_coarse = Mapper(1024, 512) self.cm_medium = Mapper(1024, 512) self.cm_fine = Mapper(1024, 2464) #Fusion Module(fm) self.fm_coarse = Mapper(512*2, 512, norm=False) self.fm_medium = Mapper(512*2, 512, norm=False) self.fm_fine = Mapper(2464*2, 2464, norm=False) def forward(self, sspace_feat, clip_feat): s_coarse = sspace_feat[:, :3*512].view(-1,3,512) s_medium = sspace_feat[:, 3*512:7*512].view(-1,4,512) s_fine = sspace_feat[:, 7*512:] #channels:2464 s_coarse = self.sm_coarse(s_coarse) s_medium = self.sm_medium(s_medium) s_fine = self.sm_fine(s_fine) c_coarse = self.cm_coarse(clip_feat) c_medium = self.cm_medium(clip_feat) c_fine = self.cm_fine(clip_feat) x_coarse = torch.cat([s_coarse, torch.stack([c_coarse]*3, dim=1)], dim=2) #[b,3,1024] x_medium = torch.cat([s_medium, torch.stack([c_medium]*4, dim=1)], dim=2) #[b,4,1024] x_fine = torch.cat([s_fine, c_fine], dim=1) #[b,2464*2] x_coarse = self.fm_coarse(x_coarse) x_coarse = x_coarse.view(-1,3*512) x_medium = self.fm_medium(x_medium) x_medium = 
x_medium.view(-1,4*512) x_fine = self.fm_fine(x_fine) out = torch.cat([x_coarse, x_medium, x_fine], dim=1) return out ================================================ FILE: generate_codes.py ================================================ import os import argparse import clip import random import numpy as np import torch from torchvision import utils from utils import stylespace_util from models.stylegan2.model import Generator def save_image_pytorch(img, name): """Helper function to save torch tensor into an image file.""" utils.save_image( img, name, nrow=1, padding=0, normalize=True, range=(-1, 1), ) def generate(args, netG, device, mean_latent): device = "cuda" if torch.cuda.is_available() else "cpu" model, preprocess = clip.load("ViT-B/32", device=device) avg_pool = torch.nn.AvgPool2d(kernel_size=1024 // 32) upsample = torch.nn.Upsample(scale_factor=7) ind = 0 with torch.no_grad(): netG.eval() # Generate images from a file of input noises if args.fixed_z is not None: sample_z = torch.load(args.fixed_z, map_location=device) for start in range(0, sample_z.size(0), args.batch_size): end = min(start + args.batch_size, sample_z.size(0)) z_batch = sample_z[start:end] sample, _ = netG([z_batch], truncation=args.truncation, truncation_latent=mean_latent) for s in sample: save_image_pytorch(s, f'{args.save_dir}/{str(ind).zfill(6)}.png') ind += 1 return # Generate image by sampling input noises w_latents_list = [] s_latents_list = [] c_latents_list = [] for start in range(0, args.samples, args.batch_size): end = min(start + args.batch_size, args.samples) batch_sz = end - start print(f'current_num:{start}') sample_z = torch.randn(batch_sz, 512, device=device) sample, w_latents = netG([sample_z], truncation=args.truncation, truncation_latent=mean_latent,return_latents=True) style_space, noise = stylespace_util.encoder_latent(netG, w_latents) s_latents = torch.cat(style_space, dim=1) tmp_imgs = stylespace_util.decoder(netG, style_space, w_latents, noise) # for s in 
tmp_imgs: # save_image_pytorch(s, f'{args.save_dir}/{str(ind).zfill(6)}.png') # ind += 1 img_gen_for_clip = upsample(tmp_imgs) img_gen_for_clip = avg_pool(img_gen_for_clip) c_latents = model.encode_image(img_gen_for_clip) w_latents_list.append(w_latents) s_latents_list.append(s_latents) c_latents_list.append(c_latents) w_all_latents = torch.cat(w_latents_list, dim=0) s_all_latents = torch.cat(s_latents_list, dim=0) c_all_latents = torch.cat(c_latents_list, dim=0) print(w_all_latents.size()) print(s_all_latents.size()) print(c_all_latents.size()) w_all_latents = w_all_latents.cpu().numpy() s_all_latents = s_all_latents.cpu().numpy() c_all_latents = c_all_latents.cpu().numpy() os.makedirs(os.path.join(args.save_dir, args.classname), exist_ok=True) np.save(f"{args.save_dir}/{args.classname}/wspace_noise_feat.npy", w_all_latents) np.save(f"{args.save_dir}/{args.classname}/sspace_noise_feat.npy", s_all_latents) np.save(f"{args.save_dir}/{args.classname}/cspace_noise_feat.npy", c_all_latents) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--classname', type=str, default='ffhq', help="domain name (e.g. ffhq); selects the checkpoint and the output subfolder") parser.add_argument('--save_dir', type=str, default='./latent_code', help="place to save the output") parser.add_argument('--ckpt', type=str, default='./models/pretrained_models', help="directory containing the pre-trained generator checkpoints") parser.add_argument('--size', type=int, default=1024, help="output size of the generator") parser.add_argument('--fixed_z', type=str, default=None, help="expect a .pth file. If given, will use this file as the input noise for the output") parser.add_argument('--w_shift', type=str, default=None, help="expect a .pth file. 
Apply a w-latent shift to the generator") parser.add_argument('--batch_size', type=int, default=10, help="batch size used to generate outputs") parser.add_argument('--samples', type=int, default=200000, help="number of samples to generate; ignored if --fixed_z is given") parser.add_argument('--truncation', type=float, default=1, help="strength of truncation:0.5ori") parser.add_argument('--truncation_mean', type=int, default=4096, help="number of samples to calculate the mean latent for truncation") parser.add_argument('--seed', type=int, default=None, help="if specified, use a fixed random seed") parser.add_argument('--device', type=str, default='cuda') args = parser.parse_args() device = args.device # use a fixed seed if given if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) os.makedirs(args.save_dir, exist_ok=True) netG = Generator(args.size, 512, 8).to(device) if args.classname == 'ffhq': ckpt_path = os.path.join(args.ckpt,f'stylegan2-{args.classname}-config-f.pt') else: ckpt_path = os.path.join(args.ckpt,f'stylegan2-{args.classname}','netG.pth') print(ckpt_path) checkpoint = torch.load(ckpt_path, map_location='cpu') if args.classname == 'ffhq': netG.load_state_dict(checkpoint['g_ema']) else: netG.load_state_dict(checkpoint) # get mean latent if truncation is applied if args.truncation < 1: with torch.no_grad(): mean_latent = netG.mean_latent(args.truncation_mean) else: mean_latent = None generate(args, netG, device, mean_latent) ================================================ FILE: models/encoders/__init__.py ================================================ ================================================ FILE: models/encoders/helpers.py ================================================ from collections import namedtuple import torch import torch.nn.functional as F from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, 
AdaptiveAvgPool2d, Sequential, Module """ ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) """ class Flatten(Module): def forward(self, input): return input.view(input.size(0), -1) def l2_norm(input, axis=1): norm = torch.norm(input, 2, axis, True) output = torch.div(input, norm) return output class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): """ A named tuple describing a ResNet block. """ def get_block(in_channel, depth, num_units, stride=2): return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] def get_blocks(num_layers): if num_layers == 50: blocks = [ get_block(in_channel=64, depth=64, num_units=3), get_block(in_channel=64, depth=128, num_units=4), get_block(in_channel=128, depth=256, num_units=14), get_block(in_channel=256, depth=512, num_units=3) ] elif num_layers == 100: blocks = [ get_block(in_channel=64, depth=64, num_units=3), get_block(in_channel=64, depth=128, num_units=13), get_block(in_channel=128, depth=256, num_units=30), get_block(in_channel=256, depth=512, num_units=3) ] elif num_layers == 152: blocks = [ get_block(in_channel=64, depth=64, num_units=3), get_block(in_channel=64, depth=128, num_units=8), get_block(in_channel=128, depth=256, num_units=36), get_block(in_channel=256, depth=512, num_units=3) ] else: raise ValueError("Invalid number of layers: {}. 
Must be one of [50, 100, 152]".format(num_layers)) return blocks class SEModule(Module): def __init__(self, channels, reduction): super(SEModule, self).__init__() self.avg_pool = AdaptiveAvgPool2d(1) self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) self.relu = ReLU(inplace=True) self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) self.sigmoid = Sigmoid() def forward(self, x): module_input = x x = self.avg_pool(x) x = self.fc1(x) x = self.relu(x) x = self.fc2(x) x = self.sigmoid(x) return module_input * x class bottleneck_IR(Module): def __init__(self, in_channel, depth, stride): super(bottleneck_IR, self).__init__() if in_channel == depth: self.shortcut_layer = MaxPool2d(1, stride) else: self.shortcut_layer = Sequential( Conv2d(in_channel, depth, (1, 1), stride, bias=False), BatchNorm2d(depth) ) self.res_layer = Sequential( BatchNorm2d(in_channel), Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) ) def forward(self, x): shortcut = self.shortcut_layer(x) res = self.res_layer(x) return res + shortcut class bottleneck_IR_SE(Module): def __init__(self, in_channel, depth, stride): super(bottleneck_IR_SE, self).__init__() if in_channel == depth: self.shortcut_layer = MaxPool2d(1, stride) else: self.shortcut_layer = Sequential( Conv2d(in_channel, depth, (1, 1), stride, bias=False), BatchNorm2d(depth) ) self.res_layer = Sequential( BatchNorm2d(in_channel), Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth), SEModule(depth, 16) ) def forward(self, x): shortcut = self.shortcut_layer(x) res = self.res_layer(x) return res + shortcut def _upsample_add(x, y): """Upsample and add two feature maps. Args: x: (Variable) top feature map to be upsampled. y: (Variable) lateral feature map. 
Returns: (Variable) added feature map. Note in PyTorch, when input size is odd, the upsampled feature map with `F.upsample(..., scale_factor=2, mode='nearest')` maybe not equal to the lateral feature map size. e.g. original input size: [N,_,15,15] -> conv2d feature map size: [N,_,8,8] -> upsampled feature map size: [N,_,16,16] So we choose bilinear upsample which supports arbitrary output sizes. """ _, _, H, W = y.size() return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y ================================================ FILE: models/encoders/model_irse.py ================================================ from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm """ Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) """ class Backbone(Module): def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): super(Backbone, self).__init__() assert input_size in [112, 224], "input_size should be 112 or 224" assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" blocks = get_blocks(num_layers) if mode == 'ir': unit_module = bottleneck_IR elif mode == 'ir_se': unit_module = bottleneck_IR_SE self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64)) if input_size == 112: self.output_layer = Sequential(BatchNorm2d(512), Dropout(drop_ratio), Flatten(), Linear(512 * 7 * 7, 512), BatchNorm1d(512, affine=affine)) else: self.output_layer = Sequential(BatchNorm2d(512), Dropout(drop_ratio), Flatten(), Linear(512 * 14 * 14, 512), BatchNorm1d(512, affine=affine)) modules = [] for block in blocks: for bottleneck in block: modules.append(unit_module(bottleneck.in_channel, bottleneck.depth, bottleneck.stride)) self.body = Sequential(*modules) 
def forward(self, x): x = self.input_layer(x) x = self.body(x) x = self.output_layer(x) return l2_norm(x) def IR_50(input_size): """Constructs a ir-50 model.""" model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) return model def IR_101(input_size): """Constructs a ir-101 model.""" model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) return model def IR_152(input_size): """Constructs a ir-152 model.""" model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) return model def IR_SE_50(input_size): """Constructs a ir_se-50 model.""" model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) return model def IR_SE_101(input_size): """Constructs a ir_se-101 model.""" model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) return model def IR_SE_152(input_size): """Constructs a ir_se-152 model.""" model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) return model ================================================ FILE: models/encoders/psp_encoders.py ================================================ from enum import Enum import math import numpy as np import torch from torch import nn from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module from models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add from models.stylegan2.model import EqualLinear class ProgressiveStage(Enum): WTraining = 0 Delta1Training = 1 Delta2Training = 2 Delta3Training = 3 Delta4Training = 4 Delta5Training = 5 Delta6Training = 6 Delta7Training = 7 Delta8Training = 8 Delta9Training = 9 Delta10Training = 10 Delta11Training = 11 Delta12Training = 12 Delta13Training = 13 Delta14Training = 14 Delta15Training = 15 Delta16Training = 16 Delta17Training = 17 Inference = 18 class GradualStyleBlock(Module): def __init__(self, in_c, out_c, spatial): super(GradualStyleBlock, 
self).__init__() self.out_c = out_c self.spatial = spatial num_pools = int(np.log2(spatial)) modules = [] modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()] for i in range(num_pools - 1): modules += [ Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU() ] self.convs = nn.Sequential(*modules) self.linear = EqualLinear(out_c, out_c, lr_mul=1) def forward(self, x): x = self.convs(x) x = x.view(-1, self.out_c) x = self.linear(x) return x class GradualStyleEncoder(Module): def __init__(self, num_layers, mode='ir', opts=None): super(GradualStyleEncoder, self).__init__() assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' blocks = get_blocks(num_layers) if mode == 'ir': unit_module = bottleneck_IR elif mode == 'ir_se': unit_module = bottleneck_IR_SE self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64)) modules = [] for block in blocks: for bottleneck in block: modules.append(unit_module(bottleneck.in_channel, bottleneck.depth, bottleneck.stride)) self.body = Sequential(*modules) self.styles = nn.ModuleList() log_size = int(math.log(opts.stylegan_size, 2)) self.style_count = 2 * log_size - 2 self.coarse_ind = 3 self.middle_ind = 7 for i in range(self.style_count): if i < self.coarse_ind: style = GradualStyleBlock(512, 512, 16) elif i < self.middle_ind: style = GradualStyleBlock(512, 512, 32) else: style = GradualStyleBlock(512, 512, 64) self.styles.append(style) self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) def forward(self, x): x = self.input_layer(x) latents = [] modulelist = list(self.body._modules.values()) for i, l in enumerate(modulelist): x = l(x) if i == 6: c1 = x elif i == 20: c2 = x elif i == 23: c3 = x for j in range(self.coarse_ind): latents.append(self.styles[j](c3)) p2 = 
_upsample_add(c3, self.latlayer1(c2)) for j in range(self.coarse_ind, self.middle_ind): latents.append(self.styles[j](p2)) p1 = _upsample_add(p2, self.latlayer2(c1)) for j in range(self.middle_ind, self.style_count): latents.append(self.styles[j](p1)) out = torch.stack(latents, dim=1) return out class Encoder4Editing(Module): def __init__(self, num_layers, stylegan_size, mode='ir'): super(Encoder4Editing, self).__init__() assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' blocks = get_blocks(num_layers) if mode == 'ir': unit_module = bottleneck_IR elif mode == 'ir_se': unit_module = bottleneck_IR_SE self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64)) modules = [] for block in blocks: for bottleneck in block: modules.append(unit_module(bottleneck.in_channel, bottleneck.depth, bottleneck.stride)) self.body = Sequential(*modules) self.styles = nn.ModuleList() log_size = int(math.log(stylegan_size, 2)) self.style_count = 2 * log_size - 2 self.coarse_ind = 3 self.middle_ind = 7 for i in range(self.style_count): if i < self.coarse_ind: style = GradualStyleBlock(512, 512, 16) elif i < self.middle_ind: style = GradualStyleBlock(512, 512, 32) else: style = GradualStyleBlock(512, 512, 64) self.styles.append(style) self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) self.progressive_stage = ProgressiveStage.Inference def get_deltas_starting_dimensions(self): ''' Get a list of the initial dimension of every delta from which it is applied ''' return list(range(self.style_count)) # Each dimension has a delta applied to it def set_progressive_stage(self, new_stage: ProgressiveStage): self.progressive_stage = new_stage print('Changed progressive stage to: ', new_stage) def forward(self, x): x = self.input_layer(x) modulelist = 
list(self.body._modules.values()) for i, l in enumerate(modulelist): x = l(x) if i == 6: c1 = x elif i == 20: c2 = x elif i == 23: c3 = x # Infer main W and duplicate it w0 = self.styles[0](c3) w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2) stage = self.progressive_stage.value features = c3 for i in range(1, min(stage + 1, self.style_count)): # Infer additional deltas if i == self.coarse_ind: p2 = _upsample_add(c3, self.latlayer1(c2)) # FPN's middle features features = p2 elif i == self.middle_ind: p1 = _upsample_add(p2, self.latlayer2(c1)) # FPN's fine features features = p1 delta_i = self.styles[i](features) w[:, i] += delta_i return w class BackboneEncoderUsingLastLayerIntoW(Module): def __init__(self, num_layers, mode='ir', opts=None): super(BackboneEncoderUsingLastLayerIntoW, self).__init__() print('Using BackboneEncoderUsingLastLayerIntoW') assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' blocks = get_blocks(num_layers) if mode == 'ir': unit_module = bottleneck_IR elif mode == 'ir_se': unit_module = bottleneck_IR_SE self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64)) self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1)) self.linear = EqualLinear(512, 512, lr_mul=1) modules = [] for block in blocks: for bottleneck in block: modules.append(unit_module(bottleneck.in_channel, bottleneck.depth, bottleneck.stride)) self.body = Sequential(*modules) log_size = int(math.log(opts.stylegan_size, 2)) self.style_count = 2 * log_size - 2 def forward(self, x): x = self.input_layer(x) x = self.body(x) x = self.output_pool(x) x = x.view(-1, 512) x = self.linear(x) return x.repeat(self.style_count, 1, 1).permute(1, 0, 2) ================================================ FILE: models/stylegan2/__init__.py ================================================ ================================================ FILE: models/stylegan2/model.py 
================================================ import math import random import torch from torch import nn from torch.nn import functional as F from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d class PixelNorm(nn.Module): def __init__(self): super().__init__() def forward(self, input): return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) def make_kernel(k): k = torch.tensor(k, dtype=torch.float32) if k.ndim == 1: k = k[None, :] * k[:, None] k /= k.sum() return k class Upsample(nn.Module): def __init__(self, kernel, factor=2): super().__init__() self.factor = factor kernel = make_kernel(kernel) * (factor ** 2) self.register_buffer('kernel', kernel) p = kernel.shape[0] - factor pad0 = (p + 1) // 2 + factor - 1 pad1 = p // 2 self.pad = (pad0, pad1) def forward(self, input): out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) return out class Downsample(nn.Module): def __init__(self, kernel, factor=2): super().__init__() self.factor = factor kernel = make_kernel(kernel) self.register_buffer('kernel', kernel) p = kernel.shape[0] - factor pad0 = (p + 1) // 2 pad1 = p // 2 self.pad = (pad0, pad1) def forward(self, input): out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) return out class Blur(nn.Module): def __init__(self, kernel, pad, upsample_factor=1): super().__init__() kernel = make_kernel(kernel) if upsample_factor > 1: kernel = kernel * (upsample_factor ** 2) self.register_buffer('kernel', kernel) self.pad = pad def forward(self, input): out = upfirdn2d(input, self.kernel, pad=self.pad) return out class EqualConv2d(nn.Module): def __init__( self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True ): super().__init__() self.weight = nn.Parameter( torch.randn(out_channel, in_channel, kernel_size, kernel_size) ) self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) self.stride = stride self.padding = padding if bias: self.bias = 
nn.Parameter(torch.zeros(out_channel)) else: self.bias = None def forward(self, input): out = F.conv2d( input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding, ) return out def __repr__(self): return ( f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' ) class EqualLinear(nn.Module): def __init__( self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None ): super().__init__() self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) if bias: self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) else: self.bias = None self.activation = activation self.scale = (1 / math.sqrt(in_dim)) * lr_mul self.lr_mul = lr_mul def forward(self, input): if self.activation: out = F.linear(input, self.weight * self.scale) out = fused_leaky_relu(out, self.bias * self.lr_mul) else: out = F.linear( input, self.weight * self.scale, bias=self.bias * self.lr_mul ) return out def __repr__(self): return ( f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' ) class ScaledLeakyReLU(nn.Module): def __init__(self, negative_slope=0.2): super().__init__() self.negative_slope = negative_slope def forward(self, input): out = F.leaky_relu(input, negative_slope=self.negative_slope) return out * math.sqrt(2) class ModulatedConv2d(nn.Module): def __init__( self, in_channel, out_channel, kernel_size, style_dim, demodulate=True, upsample=False, downsample=False, blur_kernel=[1, 3, 3, 1], ): super().__init__() self.eps = 1e-8 self.kernel_size = kernel_size self.in_channel = in_channel self.out_channel = out_channel self.upsample = upsample self.downsample = downsample if upsample: factor = 2 p = (len(blur_kernel) - factor) - (kernel_size - 1) pad0 = (p + 1) // 2 + factor - 1 pad1 = p // 2 + 1 self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) if downsample: factor = 2 p = (len(blur_kernel) - 
factor) + (kernel_size - 1) pad0 = (p + 1) // 2 pad1 = p // 2 self.blur = Blur(blur_kernel, pad=(pad0, pad1)) fan_in = in_channel * kernel_size ** 2 self.scale = 1 / math.sqrt(fan_in) self.padding = kernel_size // 2 self.weight = nn.Parameter( torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) ) self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) self.demodulate = demodulate def __repr__(self): return ( f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' f'upsample={self.upsample}, downsample={self.downsample})' ) def forward(self, input, style): batch, in_channel, height, width = input.shape style = self.modulation(style).view(batch, 1, in_channel, 1, 1) weight = self.scale * self.weight * style if self.demodulate: demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) weight = weight.view( batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size ) if self.upsample: input = input.view(1, batch * in_channel, height, width) weight = weight.view( batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size ) weight = weight.transpose(1, 2).reshape( batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size ) out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) _, _, height, width = out.shape out = out.view(batch, self.out_channel, height, width) out = self.blur(out) elif self.downsample: input = self.blur(input) _, _, height, width = input.shape input = input.view(1, batch * in_channel, height, width) out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) _, _, height, width = out.shape out = out.view(batch, self.out_channel, height, width) else: input = input.view(1, batch * in_channel, height, width) out = F.conv2d(input, weight, padding=self.padding, groups=batch) _, _, height, width = out.shape out = out.view(batch, self.out_channel, height, width) return out 
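The weight modulation and demodulation step in `ModulatedConv2d.forward` above can be sketched for a single sample without PyTorch. This is an illustrative simplification, not the repository's implementation: `modulate_demodulate` is a hypothetical name, the kernel taps are flattened into one axis, and `style` is assumed to already be the output of the `modulation` layer.

```python
import math


def modulate_demodulate(weight, style, eps=1e-8, demodulate=True):
    """weight[o][i][k]: one filter per output channel, kernel taps flattened along k.
    style[i]: per-input-channel scale for one sample."""
    fan_in = len(weight[0]) * len(weight[0][0])
    scale = 1.0 / math.sqrt(fan_in)  # same role as self.scale in the layer
    out = []
    for filt in weight:
        # Modulate: scale every tap of input channel i by style[i].
        mod = [[scale * w * s for w in taps] for taps, s in zip(filt, style)]
        if demodulate:
            # Demodulate: rescale the whole filter to (approximately) unit L2 norm.
            sq_sum = sum(w * w for taps in mod for w in taps)
            d = 1.0 / math.sqrt(sq_sum + eps)
            mod = [[w * d for w in taps] for taps in mod]
        out.append(mod)
    return out


weight = [[[1.0], [2.0]], [[3.0], [4.0]]]  # 2 output channels, 2 input channels, 1 tap
style = [2.0, 0.5]
for filt in modulate_demodulate(weight, style):
    print(sum(w * w for taps in filt for w in taps))  # close to 1 for each filter
```

Because each sample ends up with its own set of filters, the layer folds the batch into the channel axis and runs one grouped convolution with `groups=batch`.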
class NoiseInjection(nn.Module): def __init__(self): super().__init__() self.weight = nn.Parameter(torch.zeros(1)) def forward(self, image, noise=None): if noise is None: batch, _, height, width = image.shape noise = image.new_empty(batch, 1, height, width).normal_() return image + self.weight * noise class ConstantInput(nn.Module): def __init__(self, channel, size=4): super().__init__() self.input = nn.Parameter(torch.randn(1, channel, size, size)) def forward(self, input): batch = input.shape[0] out = self.input.repeat(batch, 1, 1, 1) return out class StyledConv(nn.Module): def __init__( self, in_channel, out_channel, kernel_size, style_dim, upsample=False, blur_kernel=[1, 3, 3, 1], demodulate=True, ): super().__init__() self.conv = ModulatedConv2d( in_channel, out_channel, kernel_size, style_dim, upsample=upsample, blur_kernel=blur_kernel, demodulate=demodulate, ) self.noise = NoiseInjection() # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) # self.activate = ScaledLeakyReLU(0.2) self.activate = FusedLeakyReLU(out_channel) def forward(self, input, style, noise=None): out = self.conv(input, style) out = self.noise(out, noise=noise) # out = out + self.bias out = self.activate(out) return out class ToRGB(nn.Module): def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): super().__init__() if upsample: self.upsample = Upsample(blur_kernel) self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) def forward(self, input, style, skip=None): out = self.conv(input, style) out = out + self.bias if skip is not None: skip = self.upsample(skip) out = out + skip return out class Generator(nn.Module): def __init__( self, size, style_dim, n_mlp, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], lr_mlp=0.01, ): super().__init__() self.size = size self.style_dim = style_dim layers = [PixelNorm()] for i in range(n_mlp): layers.append( EqualLinear( style_dim, style_dim, 
lr_mul=lr_mlp, activation='fused_lrelu'
                )
            )

        self.style = nn.Sequential(*layers)

        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(
            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
        )
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        self.log_size = int(math.log(size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channel = self.channels[4]

        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2 ** res, 2 ** res]
            self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))

        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** i]

            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                )
            )

            self.convs.append(
                StyledConv(
                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
                )
            )

            self.to_rgbs.append(ToRGB(out_channel, style_dim))

            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2

    def make_noise(self):
        device = self.input.input.device

        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))

        return noises

    def mean_latent(self, n_latent):
        latent_in = torch.randn(
            n_latent, self.style_dim, device=self.input.input.device
        )
        latent = self.style(latent_in).mean(0, keepdim=True)

        return latent

    def get_latent(self, input):
        return self.style(input)

    def forward(
        self,
        styles,
        return_latents=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        if not input_is_latent:
            styles =
[self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
                ]

        if truncation < 1:
            style_t = []

            for style in styles:
                style_t.append(
                    truncation_latent + truncation * (style - truncation_latent)
                )

            styles = style_t

        if len(styles) < 2:
            inject_index = self.n_latent

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:
                latent = styles[0]

        else:
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])

        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)

            i += 2

        image = skip

        if return_latents:
            return image, latent

        else:
            return image, None


class ConvLayer(nn.Sequential):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

            stride = 2
            self.padding = 0

        else:
            stride = 1
            self.padding = kernel_size // 2

        layers.append(
            EqualConv2d(
                in_channel,
                out_channel,
                kernel_size,
                padding=self.padding,
                stride=stride,
                bias=bias and not activate,
            )
        )

        if activate:
            if bias:
                layers.append(FusedLeakyReLU(out_channel))

            else:
                layers.append(ScaledLeakyReLU(0.2))

        super().__init__(*layers)


class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3,
1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

        self.skip = ConvLayer(
            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
        )

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        skip = self.skip(input)
        out = (out + skip) / math.sqrt(2)

        return out


class Discriminator(nn.Module):
    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        convs = [ConvLayer(3, channels[size], 1)]

        log_size = int(math.log(size, 2))

        in_channel = channels[size]

        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]

            convs.append(ResBlock(in_channel, out_channel, blur_kernel))

            in_channel = out_channel

        self.convs = nn.Sequential(*convs)

        self.stddev_group = 4
        self.stddev_feat = 1

        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
            EqualLinear(channels[4], 1),
        )

    def forward(self, input):
        out = self.convs(input)

        batch, channel, height, width = out.shape
        group = min(batch, self.stddev_group)
        stddev = out.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)
        out = torch.cat([out, stddev], 1)

        out = self.final_conv(out)

        out = out.view(batch, -1)
        out = self.final_linear(out)

        return out


================================================
FILE: models/stylegan2/op/__init__.py
================================================
from .fused_act import FusedLeakyReLU, fused_leaky_relu
from .upfirdn2d import upfirdn2d


================================================
FILE: models/stylegan2/op/fused_act.py
================================================
import os

import torch
from torch import nn
from torch.nn import functional as F

module_path = os.path.dirname(__file__)


class FusedLeakyReLU(nn.Module):
    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()

        self.bias = nn.Parameter(torch.zeros(channel))
        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)


def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    rest_dim = [1] * (input.ndim - bias.ndim - 1)
    # Move the input to the bias's device instead of hard-coding CUDA,
    # so the op also runs on CPU-only machines.
    input = input.to(bias.device)
    if input.ndim == 3:
        # 3-D input: the bias is broadcast along the last dimension.
        return (
            F.leaky_relu(
                input + bias.view(1, *rest_dim, bias.shape[0]), negative_slope=negative_slope
            )
            * scale
        )
    else:
        # Otherwise the bias is broadcast along the channel dimension (dim 1).
        return (
            F.leaky_relu(
                input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
            )
            * scale
        )


================================================
FILE: models/stylegan2/op/upfirdn2d.py
================================================
import os

import torch
from torch.nn import functional as F

module_path = os.path.dirname(__file__)


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
    return out


def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1,
2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
    )
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)


================================================
FILE: options/test_options.py
================================================
from argparse import ArgumentParser


class TestOptions:

    def __init__(self):
        self.parser = ArgumentParser()
        self.initialize()

    def initialize(self):
        # arguments for inference script
        self.parser.add_argument('--batch_size', default=1, type=int, help='Batch size for inference')
        self.parser.add_argument('--workers', default=4, type=int, help='Number of test dataloader workers')

        self.parser.add_argument('--stylegan_weights', default='models/pretrained_models/stylegan2-ffhq-config-f.pt', type=str, help='Path to StyleGAN model weights')
        self.parser.add_argument('--stylegan_size', default=1024, type=int)

        # The threshold is fractional (default 0.03), so it must be parsed as a float;
        # with type=int, passing e.g. `--threshold 0.03` on the command line is rejected.
        self.parser.add_argument("--threshold", type=float, default=0.03)
        self.parser.add_argument("--checkpoint_path", type=str, default='checkpoints/net_face.pth')
        self.parser.add_argument("--save_dir", type=str, default='output')
        self.parser.add_argument("--num_all", type=int, default=20)

        self.parser.add_argument("--target", type=str, required=True, help='Specify the target attributes to be edited')

    def parse(self):
        opts = self.parser.parse_args()
        return opts


================================================
FILE: options/train_options.py
================================================
from argparse import ArgumentParser


class TrainOptions:

    def __init__(self):
        self.parser = ArgumentParser()
        self.initialize()

    def
initialize(self):
        self.parser.add_argument('--batch_size', default=64, type=int, help='Batch size for training')
        self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers')

        self.parser.add_argument('--learning_rate', default=0.5, type=float, help='Optimizer learning rate')

        self.parser.add_argument('--l2_lambda', default=1.0, type=float, help='l2 loss')
        self.parser.add_argument('--cos_lambda', default=1.0, type=float, help='cos loss')

        self.parser.add_argument('--checkpoint_path', default='checkpoints', type=str, help='Path to StyleCLIPModel model checkpoint')
        self.parser.add_argument('--classname', type=str, default='ffhq', help="which specific domain for training")
        self.parser.add_argument('--print_interval', default=1000, type=int, help='Interval for printing loss values during training')
        self.parser.add_argument('--val_interval', default=5000, type=int, help='Validation interval')
        self.parser.add_argument('--save_interval', default=10000, type=int, help='Model checkpoint interval')

    def parse(self):
        opts = self.parser.parse_args()
        return opts


================================================
FILE: scripts/inference.py
================================================
import os
import sys
sys.path.append(".")
sys.path.append("..")

import copy
import clip
import numpy as np

import torch
import torchvision
from torch.utils.data import DataLoader
import torch.nn.functional as F

from datasets.test_dataset import TestLatentsDataset

from models.stylegan2.model import Generator
from delta_mapper import DeltaMapper

from options.test_options import TestOptions

from utils import map_tool
from utils import stylespace_util


def GetBoundary(fs3,dt,threshold):
    tmp=np.dot(fs3,dt)

    select=np.abs(tmp)